In this HackerEarth The Grass Type problem solution, Ash Catch'Em apparently loves only two things in his life: graph theory and his grass-type Poke'mons. Instead of using his grass-type Poke'mons to become a better trainer, he just arranges them in a graph and solves random questions.

Currently Ash is playing with his Bulbasaur. With its help, he has managed to create a tree of N vertices rooted at vertex 1. Each vertex i has a label A_i, i.e. the labels are A_1, A_2, ..., A_{N-1}, A_N.

Ash is trying to explain the concept of the LCA (Lowest Common Ancestor) to his Poke'mons, so that they will understand how, and through what kind of evolution, they will have to go in the future.

But we all know that Ash is a little... slow at explaining things. So he wants you to find the total number of unordered pairs $(u, v)$ with $u \neq v$ such that $A_{\mathrm{LCA}(u,v)} = A_u \cdot A_v$. He thinks this equation will help him figure out the logic behind the evolution of grass-type Poke'mons. Can you help him find the total number of such pairs?


HackerEarth The Grass Type problem solution


HackerEarth The Grass Type problem solution.

#include <iostream>
#include <sstream>
#include <fstream>
#include <string>
#include <iomanip>
#include <vector>
#include <deque>
#include <queue>
#include <stack>
#include <set>
#include <map>
#include <algorithm>
#include <functional>
#include <utility>
#include <bitset>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <cassert>
using namespace std;

#define rep(i,n) for(int i=0;i<n;i++)
#define ll long long int
#define pi pair<ll,ll>
#define pii pair<ll,pi>
// NOTE: `f` and `s` are single-letter token macros: any identifier spelled
// `f` or `s` later in this file is silently rewritten to `first`/`second`.
#define f first
#define mp make_pair
#define mod 1000000007
#define s second
#define pb push_back

// A[i]: label of vertex i (vertices are 1-based; index 0 is unused).
ll A[100011];
// g[i]: neighbours of vertex i in the undirected input tree.
std::vector<int>g[100011];
// M[i]: a multiset of labels stored as label -> occurrence count. After dfs,
// the label multiset of a subtree lives in M[k] for the index k that dfs
// returns for it; maps that were folded into another map are cleared.
std::map<ll,ll>M[100011];
// Running total of qualifying unordered pairs.
ll ans=0;
// Number of vertices processed by dfs (not read by dfs or main).
int cnt=0;

/*
 * Counts the unordered pairs of distinct vertices (a, b) inside the subtree
 * of `v` (entered from its parent `p`; pass 0 for the root) whose lowest
 * common ancestor w satisfies A[w] == A[a] * A[b], and adds them to `ans`.
 *
 * Technique: small-to-large merging of per-subtree label multisets. At each
 * vertex w, the singleton {A[w]} and each child subtree's multiset are
 * folded together one at a time, always folding the smaller map into the
 * larger one. Before each fold, every pair "one label from the smaller map,
 * one from the larger map" whose product equals A[w] is counted. Such pairs
 * have LCA exactly w, so every pair is counted exactly once.
 *
 * The product is never formed (no overflow): for a label k != 0 the partner
 * must be A[w] / k, and only when k divides A[w]. A label 0 pairs with every
 * label when A[w] == 0; that case is handled separately so `% 0` never runs.
 *
 * The traversal is iterative (explicit stack), so a path-shaped tree with
 * ~1e5 vertices cannot overflow the call stack. Maps that are folded into a
 * larger one are cleared, so total map memory stays O(n).
 *
 * Returns the index k such that M[k] holds the label multiset of the whole
 * subtree of v (k is a vertex of that subtree, not necessarily v).
 */
int dfs(int v,int p){
    typedef std::map<long long, long long> LabelCount;

    // Pass 1: pre-order listing of (vertex, parent). A parent always
    // precedes its children, so walking it backwards visits children first.
    std::vector<std::pair<int,int> > order;
    std::vector<std::pair<int,int> > pending(1, std::make_pair(v, p));
    int max_id = v;
    while(!pending.empty()){
        const std::pair<int,int> cur = pending.back();
        pending.pop_back();
        order.push_back(cur);
        max_id = std::max(max_id, cur.first);
        const std::vector<int>& adj = g[cur.first];
        for(std::size_t k = 0; k < adj.size(); ++k){
            if(adj[k] != cur.second)
                pending.push_back(std::make_pair(adj[k], cur.first));
        }
    }

    // holder[w]: index of the map holding w's subtree multiset.
    // weight[k]: total number of labels (with multiplicity) held in M[k].
    std::vector<int> holder(max_id + 1, 0);
    std::vector<long long> weight(max_id + 1, 0);

    // Pass 2: process vertices children-first.
    for(std::size_t idx = order.size(); idx-- > 0; ){
        const int w = order[idx].first;
        const int parent = order[idx].second;
        const long long target = A[w];
        ++cnt;

        int r = w;
        M[w][target]++;
        weight[w] = 1;

        const std::vector<int>& adj = g[w];
        for(std::size_t k = 0; k < adj.size(); ++k){
            if(adj[k] == parent) continue;
            int x = holder[adj[k]];
            // Always fold the map with fewer distinct labels into the larger.
            if(M[r].size() < M[x].size()) std::swap(r, x);
            LabelCount& acc = M[r];
            LabelCount& part = M[x];

            // Count pairs (label in part, label in acc) whose product is target.
            for(LabelCount::const_iterator it = part.begin(); it != part.end(); ++it){
                const long long key = it->first;
                const long long mult = it->second;
                if(key == 0){
                    // 0 * anything == 0: every label in acc is a partner.
                    if(target == 0) ans += mult * weight[r];
                } else if(target % key == 0){
                    LabelCount::const_iterator hit = acc.find(target / key);
                    if(hit != acc.end()) ans += mult * hit->second;
                }
            }

            for(LabelCount::const_iterator it = part.begin(); it != part.end(); ++it)
                acc[it->first] += it->second;
            weight[r] += weight[x];
            // The folded map is never read again; free its nodes now.
            part.clear();
            weight[x] = 0;
        }
        holder[w] = r;
    }
    return holder[v];
}

/*
 * Reads N, then N-1 undirected edges "u v", then the labels A[1..N].
 * Counts the qualifying pairs with dfs rooted at vertex 1 and prints the
 * count followed by a newline.
 */
int main(){
    // The input holds ~3N integers; unsync/untie iostreams for speed.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(0);
    int N,u,v;
    std::cin >> N;
    rep(i,N-1){
        std::cin >> u >> v;
        g[u].pb(v);
        g[v].pb(u);
    }
    for(int i=1;i<=N;i++){
        std::cin >> A[i];
    }
    dfs(1,0);
    std::cout << ans << '\n';
    return 0;
}