HackerEarth Separating Numbers problem solution

In this HackerEarth Separating Numbers problem we are given a tree with N nodes. Each node i has a color Ci. We are also given N-1 queries, and in each query we are told to destroy a previously undestroyed edge. Every time we destroy an edge, we have to report the number of pairs of nodes that get disconnected. Here, two nodes i and j are said to be disconnected if Ci = Cj and, before the destruction, you could reach j from i using edges not yet destroyed, but after the destruction you no longer can.

`#include <iostream>#include <vector>#include <algorithm>#include <fstream>#include <queue>#include <deque>#include <iomanip>#include <cmath>#include <set>#include <stack>#include <map>#include <unordered_map>#define FOR(i,n) for(int i=0;i<n;i++)#define FORE(i,a,b) for(int i=a;i<=b;i++)#define ll long long //#define int long long#define ld long double#define vi deque<int>#define pb push_back#define ff first#define ss second#define ii pair<int,int>#define iii pair<int,ii>#define il pair<int,ll>#define pll pair<ll,ll>#define _path pair<ll,pair<ll,int> > #define vv deque//#define endl '\n'//#define mp make_pairusing namespace std;const int MAXN = 3e5+5;const ll INF = 1e18;vi g[MAXN];struct DSU{    int parent[MAXN];    int sz[MAXN];    map<int,ll> mp[MAXN];    DSU(int* arr,int n){        FOR(i,n)parent[i] = i;        FOR(i,n)sz[i] = 1;        FOR(i,n)mp[i][arr[i]] = 1;    }    int find(int a){        if(parent[a] != a)parent[a] = find(parent[a]);        return parent[a];    }    ll merge(int a,int b){        int pa = find(a);        int pb = find(b);        ll sum = 0;        if(sz[pa] < sz[pb])swap(pa,pb);        for(auto e: mp[pb]){            sum += e.ss*mp[pa][e.ff];            mp[pa][e.ff] += e.ss;        }        sz[pa] += sz[pb];        parent[pb] = pa;        mp[pb].clear();        return sum;    }};void solve(){    int n,q;    cin >> n;    q = n-1;    ii edge[n];    FOR(i,n-1){        int a,b;        cin >> a >> b;        a--;b--;        g[a].pb(b);        g[b].pb(a);        edge[i] = {a,b};    }    int colors[n];    FOR(i,n)cin >> colors[i];    DSU dsu(colors,n);    int queries[q];    FOR(i,q){        int x;        cin >> x;        x--;        queries[i] = x;    }    ll ans[q];    for(int i = q-1;i >= 0;i--){        ans[i] = dsu.merge(edge[queries[i]].ff,edge[queries[i]].ss);    }    FOR(i,q){        cout << ans[i] << endl;    }}int main(){    ios_base::sync_with_stdio(0);    cin.tie(0);    cout.tie(0);        int t = 1;    while(t--){        solve();    }    
return 0;}`