# HackerEarth Colorful Tree problem solution

In the HackerEarth Colorful Tree problem, you are given a tree that contains N nodes, where every node i is colored with some color $C_i$.

The distance of a node V from a node U is the number of edges on the simple path between U and V. Your task is to answer M queries of the following type:
- `K C`: determine the distance from node K to the most distant node of color C. If no node in the tree has color C, print -1.

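## HackerEarth Colorful Tree problem solution

Both solutions below share the same idea. For any node K, the farthest node of color C from K is always one of the two endpoints of a diameter of the color-C nodes, that is, a pair of color-C nodes at maximum distance from each other. For every color, the solutions:

1. Build a virtual tree over that color's nodes, closed under LCA.
2. Find a diameter pair in that virtual tree.
3. Answer each query with two LCA-based distance computations.

The first solution skips the virtual tree for colors with at most two nodes, since such nodes are already their own diameter.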

`#include <bits/stdc++.h>using namespace std; #define IOS ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);#define endl "\n"const int N=5e5+5;const int LG=21;int n, k, q, tim=0, dist=0, node;int col[N];int parent[LG][N];int tin[N], tout[N], level[N], vertices[2*N];vector<int> g[N], contains[N], diam[N];vector<pair<int, int> > tree[N];void dfs(int k, int par, int lvl){    tin[k]=++tim;    parent[0][k]=par;    level[k]=lvl;    for(auto it:g[k])    {        if(it==par)            continue;        dfs(it, k, lvl+1);    }    tout[k]=tim;}void precompute(){    for(int i=1;i<LG;i++)        for(int j=1;j<=n;j++)            if(parent[i-1][j])                parent[i][j]=parent[i-1][parent[i-1][j]];}int LCA(int u, int v){    if(level[u]<level[v])        swap(u,v);    int diff=level[u]-level[v];    for(int i=LG-1;i>=0;i--)    {        if((1<<i) & diff)        {            u=parent[i][u];        }    }    if(u==v)        return u;    for(int i=LG-1;i>=0;i--)    {        if(parent[i][u] && parent[i][u]!=parent[i][v])        {            u=parent[i][u];            v=parent[i][v];        }    }    return parent[0][u];}int dist1(int u, int v){    return level[u]+level[v]-2*level[LCA(u, v)];}bool isancestor(int u, int v) //Check if u is an ancestor of v{    return (tin[u]<=tin[v]) && (tout[v]<=tout[u]);}int dfs2(int k, int par, int dis){    //cerr<<k<<endl;    if(dis>dist)    {        dist=dis;        node=k;    }    for(auto it:tree[k])    {        if(it.first==par)            continue;        dfs2(it.first, k, dis+it.second);    }}int work(int color){    sort(vertices+1, vertices+k+1, [](int a, int b)    {        return tin[a]<tin[b];    });    int idx=k;    for(int i=1;i<idx;i++)        vertices[++k]=LCA(vertices[i], vertices[i+1]);    sort(vertices+1, vertices+k+1);    k=unique(vertices+1, vertices+k+1) - vertices - 1;    sort(vertices+1, vertices+k+1, [](int a, int b)    {        return tin[a]<tin[b];    });    stack<int> s;    s.push(vertices[1]);    for(int i=2;i<=k;i++)    
{        while(!isancestor(s.top(), vertices[i]))            s.pop();        int u=s.top();        int v=vertices[i];        int w=dist1(u, v);        tree[u].push_back({v, w});        tree[v].push_back({u, w});        s.push(vertices[i]);    }    dist=0;    dfs2(vertices[1], vertices[1], 1);    diam[color].push_back(node);    dfs2(node, node, 1);    diam[color].push_back(node);    for(int i=1;i<=k;i++)        tree[vertices[i]].clear();}int32_t main(){    IOS;    cin>>n>>q;    for(int i=1;i<=n;i++)    {        cin>>col[i];        contains[col[i]].push_back(i);    }    for(int i=1;i<=n-1;i++)    {        int u, v;        cin>>u>>v;        g[u].push_back(v);        g[v].push_back(u);    }    dfs(1, 0, 1);    precompute();    for(int i=1;i<=5e5;i++)    {        if(contains[i].size()<=2)        {            for(auto &it:contains[i])                diam[i].push_back(it);            continue;        }        k=0;        for(auto &it:contains[i])            vertices[++k]=it;        work(i);    }    while(q--)    {        int k, c;        cin>>k>>c;        if(!contains[c].size())            cout<<"-1"<<endl;        else         {            int ans=0;            for(auto &it: diam[c])                ans=max(ans, dist1(k, it));            cout<<ans<<endl;        }    }    return 0;}`

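## Second solution

This solution validates the input format strictly with `assert`. Every number must be followed by exactly the expected separator, and the last query must be followed directly by end of file, with no trailing newline.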

`#ifndef _GLIBCXX_NO_ASSERT#include <cassert>#endif#include <cctype>#include <cerrno>#include <cfloat>#include <ciso646>#include <climits>#include <clocale>#include <cmath>#include <csetjmp>#include <csignal>#include <cstdarg>#include <cstddef>#include <cstdio>#include <cstdlib>#include <cstring>#include <ctime>#if __cplusplus >= 201103L#include <ccomplex>#include <cfenv>#include <cinttypes>#include <cstdbool>#include <cstdint>#include <ctgmath>#include <cwchar>#include <cwctype>#endif// C++#include <algorithm>#include <bitset>#include <complex>#include <deque>#include <exception>#include <fstream>#include <functional>#include <iomanip>#include <ios>#include <iosfwd>#include <iostream>#include <istream>#include <iterator>#include <limits>#include <list>#include <locale>#include <map>#include <memory>#include <new>#include <numeric>#include <ostream>#include <queue>#include <set>#include <sstream>#include <stack>#include <stdexcept>#include <streambuf>#include <string>#include <typeinfo>#include <utility>#include <valarray>#include <vector>#if __cplusplus >= 201103L#include <array>#include <atomic>#include <chrono>#include <condition_variable>#include <forward_list>#include <future>#include <initializer_list>#include <mutex>#include <random>#include <ratio>#include <regex>#include <scoped_allocator>#include <system_error>#include <thread>#include <tuple>#include <typeindex>#include <type_traits>#include <unordered_map>#include <unordered_set>#endif#define ll          long long#define pb          push_back#define mp          make_pair#define pii         pair<int,int>#define vi          vector<int>#define all(a)      (a).begin(),(a).end()#define F           first#define S           second#define sz(x)       (int)x.size()#define hell        1000000007#define endl        '\n'#define rep(i,a,b)  for(int i=a;i<b;i++)using namespace std;string to_string(string s) {    return '"' + s + '"';}string to_string(const char* s) {    return to_string((string) s);}string 
to_string(bool b) {    return (b ? "true" : "false");}string to_string(char ch) {    return string("'")+ch+string("'");}template <typename A, typename B>string to_string(pair<A, B> p) {    return "(" + to_string(p.first) + ", " + to_string(p.second) + ")";}template <class InputIterator>string to_string (InputIterator first, InputIterator last) {    bool start = true;    string res = "{";    while (first!=last) {        if (!start) {            res += ", ";        }        start = false;        res += to_string(*first);        ++first;    }    res += "}";    return res;}template <typename A>string to_string(A v) {    bool first = true;    string res = "{";    for (const auto &x : v) {        if (!first) {            res += ", ";        }        first = false;        res += to_string(x);    }    res += "}";    return res;}void debug_out() { cerr << endl; }template <typename Head, typename... Tail>void debug_out(Head H, Tail... T) {    cerr << " " << to_string(H);    debug_out(T...);}template <typename A, typename B>istream& operator>>(istream& input,pair<A,B>& x){    input>>x.F>>x.S;    return input;}template <typename A>istream& operator>>(istream& input,vector<A>& x){    for(auto& i:x)        input>>i;    return input;}#ifdef PRINTERS#define debug(...) cerr << "[" << #__VA_ARGS__ << "]:", debug_out(__VA_ARGS__)#else#define debug(...) 
42#endiflong long readInt(long long l,long long r,char endd){    long long x=0;    int cnt=0;    int fi=-1;    bool is_neg=false;    while(true){        char g=getchar();        if(g=='-'){            assert(fi==-1);            is_neg=true;            continue;        }        if('0'<=g && g<='9'){            x*=10;            x+=g-'0';            if(cnt==0){                fi=g-'0';            }            cnt++;            assert(fi!=0 || cnt==1);            assert(fi!=0 || is_neg==false);            assert(!(cnt>19 || ( cnt==19 && fi>1) ));        } else if(g==endd){            assert(cnt>0);            if(is_neg){                x= -x;            }            assert(l<=x && x<=r);            return x;        } else {            debug(int(g));            assert(false);        }    }}string readString(int l,int r,char endd){    string ret="";    int cnt=0;    while(true){        char g=getchar();        if(g==endd){            break;        }        else if(islower(g)){            cnt++;            ret+=g;        }        else{            assert(false);        }    }    assert(l<=cnt && cnt<=r);    return ret;}long long readIntSp(long long l,long long r){    return readInt(l,r,' ');}long long readIntLn(long long l,long long r){    return readInt(l,r,'\n');}string readStringLn(int l,int r){    return readString(l,r,'\n');}string readStringSp(int l,int r){    return readString(l,r,' ');}vi colnode[500005];int intime[500005];int outtime[500005];int height[500005];vi adj[500005];int dp[20][500005];vector<pii> temptree[500005];vi reqdnodes[500005];int color[500005];void dfs(int u,int p=0){    static int clck = 1;    intime[u] = clck;    colnode[color[u]].emplace_back(u);    height[u] = height[p]+1;    dp[0][u]=p;    clck++;    for(auto i:adj[u]){        if(i!=p)dfs(i,u);    }    outtime[u]=clck;    clck++;}int lca(int u,int v){    if(height[u]>height[v])swap(u,v);    for(int i=19;i>=0;i--){        if(height[v]-(1<<i)>=height[u])v=dp[i][v];    }    if(u==v)return u;    
for(int i=19;i>=0;i--){        if(dp[i][u]!=dp[i][v])u=dp[i][u],v=dp[i][v];    }    return dp[0][u];}struct diameter{    pii maxdep1,maxdep2;    pair<int,pii> best_res;    diameter(int u){        maxdep1.S=u;        maxdep2.S=u;        best_res.S={u,u};    }};diameter get_diameter(int u,int p){    diameter res(u);    for(auto i:temptree[u]){        if(i.F==p)continue;        auto new_res = get_diameter(i.F,u);        if(new_res.maxdep1.F+i.S>res.maxdep1.F){            res.maxdep2=res.maxdep1;            res.maxdep1=mp(new_res.maxdep1.F+i.S,new_res.maxdep1.S);        }        else if(new_res.maxdep1.F+i.S>res.maxdep2.F){            res.maxdep2=mp(new_res.maxdep1.F+i.S,new_res.maxdep1.S);        }        res.best_res=max(res.best_res,new_res.best_res);    }    res.best_res=max(res.best_res,mp(res.maxdep1.F+res.maxdep2.F,mp(res.maxdep1.S,res.maxdep2.S)));    return res;}void solve(){    auto comp = [](int a,int b){return intime[a]<intime[b];};    int N,M;    N = readIntSp(1,500000);    M = readIntLn(1,500000);    rep(i,1,N+1){        int col;        if(i==N) col = readIntLn(1,500000);        else col = readIntSp(1,500000);        color[i] = col;    }    rep(i,1,N){        int u,v;        u = readIntSp(1,N);        v = readIntLn(1,N);        adj[u].emplace_back(v);        adj[v].emplace_back(u);    }    dfs(1);    for(int i = 1; i < 20; i++){        for(int j = 1; j <= N; j++){            dp[i][j]=dp[i-1][dp[i-1][j]];        }    }    vector<bool>nodes(N+1);    rep(i,1,500005){        if(colnode[i].empty())continue;        if(sz(colnode[i])==1){            reqdnodes[i].emplace_back(colnode[i].front());            continue;        }        colnode[i].reserve(2*sz(colnode[i]));        for(auto j:colnode[i])nodes[j]=1;        int k = sz(colnode[i]);        rep(j,1,k){            int tmp = lca(colnode[i][j-1],colnode[i][j]);            if(!nodes[tmp]){                nodes[tmp]=1;                colnode[i].emplace_back(tmp);            }        }        for(auto 
j:colnode[i])nodes[j]=0;        sort(colnode[i].begin()+k,colnode[i].end(),comp);        inplace_merge(colnode[i].begin(),colnode[i].begin()+k,colnode[i].end(),comp);        stack<pii>stk;        stk.emplace(0,INT_MAX);        for(auto j:colnode[i]){            while(outtime[j]>stk.top().S)stk.pop();            temptree[stk.top().F].emplace_back(j,abs(height[j]-height[stk.top().F]));            temptree[j].emplace_back(stk.top().F,abs(height[j]-height[stk.top().F]));            stk.emplace(j,outtime[j]);        }        auto res = get_diameter(temptree[0][0].F,0);        reqdnodes[i].emplace_back(res.best_res.S.F);        reqdnodes[i].emplace_back(res.best_res.S.S);        for(auto j:colnode[i]){            temptree[j].clear();        }        temptree[0].clear();        vi().swap(colnode[i]);    }    rep(i,1,M+1){        int K,C;        K = readIntSp(1,N);        if(i==M) C = readInt(1,500000,EOF);        else C = readIntLn(1,500000);        int ans = -1;        for(auto j:reqdnodes[C]){            int LCA = lca(j,K);            ans = max(ans,height[j]+height[K]-2*height[LCA]);        }        cout << ans << endl;    }}int main(){    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);    int t=1;//  cin>>t;    while(t--){        solve();    }    return 0;}`