# HackerEarth Tree and XOR problem solution

In this HackerEarth Tree and XOR problem solution: for a tree $X$ rooted at node 1, with a value at every node, and a node $i$ of $X$, let $S(X, i)$ be the bitwise XOR of all values in the subtree of node $i$.

You are given a tree $T$. Let $T_i$ be the tree that remains after removing all nodes in the subtree of $i$. Define

$$P(i) = \max_{j \in T_i} \bigl(S(T, i) \oplus S(T_i, j)\bigr).$$

You have to find $\sum_{i \ne 1} P(i)$, i.e. the sum of $P(i)$ over all nodes $i \ne 1$.

`#include<bits/stdc++.h>using namespace std;const int MAXN = 2e5 + 55;int ar[MAXN], S[MAXN], tin[MAXN], tout[MAXN];vector<int> START[MAXN], END[MAXN], adj[MAXN];int POWER[55];int timer = 0;void DFS(int s, int par = -1) {  tin[s] = ++timer;  START[timer].push_back(s);  S[s] = ar[s];  for(auto it : adj[s]) {    if(it == par) continue;    DFS(it, s);    S[s] ^= S[it];  }  tout[s] = timer;  END[timer].push_back(s);}typedef struct node* pnode;struct node {  pnode L, R;  node() {    L = R = nullptr;  }};pnode perL[MAXN], perR[MAXN];int get(pnode root, int val, int lev) {  if(lev == -1 or root == nullptr) return 0;  int ret = 0;  if(POWER[lev] & val) {    if(root -> L != nullptr) return POWER[lev] + get(root -> L, val, lev - 1);    else {      return get(root -> R, val, lev - 1);    }  }  else {    if(root -> R != nullptr) return POWER[lev] + get(root -> R, val, lev - 1);    else {      return get(root -> L, val, lev - 1);    }  }}int getLeft(int idx, int val) {  if(idx < 1) return 0;  return get(perL[idx], val, 29);}int n;int getRight(int idx, int val) {  if(idx > n) return 0;  return get(perR[idx], val, 29);}long long answer = 0;void dfs(int s, int XOR = 0, int par = -1) {  for(auto it : adj[s]) {    if(it == par) continue;    dfs(it, max(XOR, S[s]), s);  }  if(s != 1) {    int ans = XOR;    ans = max(ans, getLeft(tin[s] - 1, S[s]));    ans = max(ans, getRight(tout[s] + 1, S[s]));    answer += 1LL * ans;  }}pnode insert(pnode root, pnode other, int val, int lev) {  if(lev == -1) return root;  if(POWER[lev] & val) {    if(other == nullptr) {      root -> R = new node();      root -> R = insert(root -> R, other, val, lev - 1);    }    else {      root -> L = other -> L;      other = other -> R;      root -> R = new node();      root -> R = insert(root -> R, other, val, lev - 1);    }  }  else {    if(other == nullptr) {      root -> L = new node();      root -> L = insert(root -> L, other, val, lev - 1);    }    else {      root -> R = other -> R;      other = other -> L; 
     root -> L = new node();      root -> L = insert(root -> L, other, val, lev - 1);    }  }  return root;}void process() {  perL[0] = new node();  for(int i = 1; i <= n; i++) {    perL[i] = new node();    perL[i] -> L = perL[i - 1] -> L;    perL[i] -> R = perL[i - 1] -> R;    for(int it = 0; it < (int)END[i].size(); it++) {      int node = END[i][it];      if(it == 0) perL[i] = insert(perL[i], perL[i - 1], S[node], 29);      else perL[i] = insert(perL[i], perL[i], S[node], 29);    }  }  perR[n + 1] = new node();  for(int i = n; i >= 1; i--) {    perR[i] = new node();    perR[i] -> L = perR[i + 1] -> L;    perR[i] -> R = perR[i + 1] -> R;    assert((int)START[i].size() <= 1);    for(int it = 0; it < (int)START[i].size(); it++) {      int node = START[i][it];      perR[i] = insert(perR[i], perR[i + 1], S[node], 29);    }  }}int main() {  POWER[0] = 1;  for(int i = 1; i < 35; i++) POWER[i] = POWER[i - 1] << 1;  cin >> n;  assert(n >= 1 and n <= 200000);  for(int i = 1; i <= n; i++) {    scanf("%d", &ar[i]);    assert(ar[i] >= 1 and ar[i] <= 1000000000);  }  for(int i = 1; i < n; i++) {    int x, y;    scanf("%d%d", &x, &y);    assert(x != y and x >= 1 and x <= n and y >= 1 and y <= n);    adj[x].push_back(y);    adj[y].push_back(x);  }  DFS(1);  process();  dfs(1);  cout << answer << endl;  return 0;}`

### Second solution

`#include <bits/stdc++.h>#include <ext/pb_ds/assoc_container.hpp>#include <ext/pb_ds/tree_policy.hpp>using namespace std;using namespace __gnu_pbds;#define T pair<int, int>#define ordered_set tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>const int N = 200005;int val[N], tree_size[N], pos[N];vector<int> con[N];ordered_set path, global, subtree_size[N];int cnt = 0, CNT = 0;int dfs(int s = 1, int p = 0){    tree_size[s] = 1;    for(int v : con[s]) if(v != p){        val[s] ^= dfs(v, s);        tree_size[s] += tree_size[v];    }    global.insert({val[s], ++cnt});    return val[s];}int getCount(ordered_set & os, int l, int r){    return os.order_of_key({r + 1, 0}) - os.order_of_key({l, 0});}long long ans = 0;int getMax(int ind, int c_path, int c_subtree, int c_global){    int x = val[ind];    int position = pos[ind];    int ret = 0;    int lo = 0, hi = (1 << 30) - 1;    for(int bit = 29; bit >= 0; bit--){        int mid = (lo + hi) >> 1;        if(x >> bit & 1){            if(c_path * getCount(path, lo, mid) + c_subtree * getCount(subtree_size[position], lo, mid) +                 c_global * getCount(global, lo, mid) > 0){                ret += (1 << bit);                hi = mid;            } else{                lo = mid + 1;            }        } else{            if(c_path * getCount(path, mid + 1, hi) + c_subtree *                 getCount(subtree_size[position], mid + 1, hi) + c_global * getCount(global, mid + 1, hi) > 0){                lo = mid + 1;                ret += (1 << bit);            } else{                hi = mid;            }        }    }    return ret;}void dfs2(int s = 1, int p = 0){    int bigc = 0;    int curr_mx = path.rbegin()->first;    pair<int, int> to_insert = {val[s], ++cnt};    path.insert(to_insert);    for(int v : con[s]) if(v != p){        dfs2(v, s);        if(tree_size[v] > tree_size[bigc]) bigc = v;    }    if(!bigc){        pos[s] = ++CNT;    } else pos[s] = pos[bigc];    for(int v : con[s]) if(v != p 
&& v != bigc){        for(auto ele : subtree_size[pos[v]]) subtree_size[pos[s]].insert(ele);        subtree_size[pos[v]].clear();    }        ans += max(getMax(s, -1, -1, 1), curr_mx);    path.erase(to_insert);    subtree_size[pos[s]].insert({val[s], ++cnt});}int main(){    int n;    cin.tie(0); ios_base::sync_with_stdio(0);    cin >> n;    for(int i = 1; i <= n; i++) cin >> val[i];    for(int i = 1; i < n; i++) {        int u, v;        cin >> u >> v;        con[u].push_back(v);        con[v].push_back(u);    }       dfs();    dfs2();    cout << ans;}`