In this HackerEarth Tree and XOR problem, consider a tree $X$ rooted at node 1 with a value at each node, and a node $i$. Let $S(X, i)$ be the bitwise XOR of all values in the subtree of node $i$.

You are given a tree $T$. Let $T_i$ be the tree remaining after removing all nodes in the subtree of $i$. Define $P(i) = \max_{j \in T_i} \bigl( S(T, i) \oplus S(T_i, j) \bigr)$. You have to find $\sum_{i \neq 1} P(i)$, the sum of $P(i)$ over all nodes $i \neq 1$.


HackerEarth Tree and XOR problem solution


First solution

#include<bits/stdc++.h>
using namespace std;

// Node count is at most 2e5 (checked in main); the slack covers 1-based
// indexing and the sentinel slot n + 1 used by the suffix tries.
const int MAXN = 2e5 + 55;
// Indexed by node id: ar = input value, S = XOR of the node's subtree,
// tin = Euler entry time, tout = last entry time inside the node's subtree.
int ar[MAXN], S[MAXN], tin[MAXN], tout[MAXN];
// START[t]: node entered at Euler time t. END[t]: nodes whose subtree closes
// at time t (deepest first). adj: tree adjacency lists.
std::vector<int> START[MAXN], END[MAXN], adj[MAXN];
// POWER[b] = 2^b for the bit levels filled in main().
int POWER[55];
// Euler-tour clock advanced by DFS().
int timer = 0;

// Euler tour from s (parent `par`, -1 for the root). For every node v reached it:
//   - fills tin[v] and tout[v], registering v in START[tin[v]] and END[tout[v]];
//   - fills S[v] with the XOR of ar[] over v's subtree.
// Children are visited in adjacency-list order, exactly like a recursive DFS.
// The DFS uses an explicit stack so a path-shaped tree with ~2e5 nodes cannot
// overflow the call stack.
void DFS(int s, int par = -1) {
    struct Frame {
        int node;
        int parent;
        std::size_t next;  // index of the next adjacency entry to examine
    };
    std::vector<Frame> frames;

    // Pre-order work for v; pushes its frame.
    auto enter = [&frames](int v, int p) {
        tin[v] = ++timer;
        START[timer].push_back(v);
        S[v] = ar[v];
        frames.push_back({v, p, 0});
    };

    enter(s, par);
    while (!frames.empty()) {
        Frame &top = frames.back();
        if (top.next < adj[top.node].size()) {
            const int child = adj[top.node][top.next++];
            // enter() may reallocate `frames`; `top` is not used afterwards.
            if (child != top.parent) enter(child, top.node);
            continue;
        }
        // All children done: post-order work, then fold into the parent.
        const int v = top.node;
        frames.pop_back();
        tout[v] = timer;
        END[timer].push_back(v);
        if (!frames.empty()) S[frames.back().node] ^= S[v];
    }
}
struct node;
using pnode = node*;

// Binary trie node. insert() puts values with a 0 bit under L and values
// with a 1 bit under R; a null child means no stored value continues that way.
struct node {
    pnode L{nullptr};
    pnode R{nullptr};
};
// Roots of the persistent tries built by process():
// - perL[i] stores S[v] for every node v whose subtree closes at Euler time <= i.
// - perR[i] stores S[v] for every node v entered at Euler time >= i.
pnode perL[MAXN], perR[MAXN];

// Greedy max-XOR query on a binary trie (L = 0-branch, R = 1-branch).
// Looks at bits lev..0 of `val`. At each level it prefers the child whose bit
// differs from val's bit, which adds 2^lev to the result; otherwise it follows
// the other child.
// Returns the maximum of (x XOR val) over the stored x, restricted to those
// bits. This assumes, as insert() guarantees, that every non-root node lies on
// some inserted path. Returns 0 when `root` is null or has no children.
int get(pnode root, int val, int lev) {
    int best = 0;
    for (; lev >= 0 && root != nullptr; --lev) {
        const bool bitSet = (POWER[lev] & val) != 0;
        const pnode preferred = bitSet ? root->L : root->R;
        if (preferred != nullptr) {
            best += POWER[lev];
            root = preferred;
        } else {
            root = bitSet ? root->R : root->L;
        }
    }
    return best;
}
// Runs get() for `val` on prefix trie perL[idx] (nodes whose subtree closes at
// Euler time <= idx). Indices before the first slot yield 0.
int getLeft(int idx, int val) {
    return idx >= 1 ? get(perL[idx], val, 29) : 0;
}

// Number of nodes in the tree; read in main().
int n;

// Runs get() for `val` on suffix trie perR[idx] (nodes entered at Euler time
// >= idx). Indices past the last node yield 0.
int getRight(int idx, int val) {
    return idx <= n ? get(perR[idx], val, 29) : 0;
}
long long answer = 0;
void dfs(int s, int XOR = 0, int par = -1) {
for(auto it : adj[s]) {
if(it == par) continue;
dfs(it, max(XOR, S[s]), s);
}
if(s != 1) {
int ans = XOR;
ans = max(ans, getLeft(tin[s] - 1, S[s]));
ans = max(ans, getRight(tout[s] + 1, S[s]));
answer += 1LL * ans;
}
}
// Path-copying insert of `val` (bits lev..0) into a persistent binary trie.
// `root` is the node for the new version and is modified in place. `other` is
// the corresponding node of the version being extended; it may be null, or the
// same node as `root`.
// At every level one fresh node is allocated for the branch chosen by val's
// bit (L for 0, R for 1). The opposite branch is shared with `other`'s node at
// that depth, when one exists. Returns `root`.
pnode insert(pnode root, pnode other, int val, int lev) {
    pnode cur = root;
    for (int b = lev; b >= 0; --b) {
        pnode fresh = new node();
        if (POWER[b] & val) {
            if (other != nullptr) {
                cur->L = other->L;
                other = other->R;
            }
            cur->R = fresh;
        } else {
            if (other != nullptr) {
                cur->R = other->R;
                other = other->L;
            }
            cur->L = fresh;
        }
        cur = fresh;
    }
    return root;
}
void process() {
perL[0] = new node();
for(int i = 1; i <= n; i++) {
perL[i] = new node();
perL[i] -> L = perL[i - 1] -> L;
perL[i] -> R = perL[i - 1] -> R;
for(int it = 0; it < (int)END[i].size(); it++) {
int node = END[i][it];
if(it == 0) perL[i] = insert(perL[i], perL[i - 1], S[node], 29);
else perL[i] = insert(perL[i], perL[i], S[node], 29);
}
}
perR[n + 1] = new node();
for(int i = n; i >= 1; i--) {
perR[i] = new node();
perR[i] -> L = perR[i + 1] -> L;
perR[i] -> R = perR[i + 1] -> R;
assert((int)START[i].size() <= 1);
for(int it = 0; it < (int)START[i].size(); it++) {
int node = START[i][it];
perR[i] = insert(perR[i], perR[i + 1], S[node], 29);
}
}
}
// Reads n, the node values and the n - 1 edges, then prints the sum of P(i)
// over i != 1. Input constraints are checked with assert().
int main() {
    // POWER[b] = 2^b. The tries only look at bits 29..0 (values are <= 1e9 < 2^30).
    // Stop at bit 30: shifting further would move into, and then past, the sign
    // bit of int. POWER[32] = INT_MIN << 1 is undefined behaviour.
    POWER[0] = 1;
    for (int i = 1; i <= 30; i++) POWER[i] = POWER[i - 1] << 1;
    cin >> n;
    assert(n >= 1 and n <= 200000);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &ar[i]);
        assert(ar[i] >= 1 and ar[i] <= 1000000000);
    }
    for (int i = 1; i < n; i++) {
        int x, y;
        scanf("%d%d", &x, &y);
        assert(x != y and x >= 1 and x <= n and y >= 1 and y <= n);
        adj[x].push_back(y);
        adj[y].push_back(x);
    }
    DFS(1);     // subtree XORs and Euler-tour times
    process();  // prefix / suffix persistent tries
    dfs(1);     // accumulate P(i) into `answer`
    cout << answer << '\n';
    return 0;
}

Second solution

#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;

// Ordered multiset of (value, unique id) pairs. The unique id keeps equal
// values as distinct keys, and order_of_key() answers rank queries.
// Type aliases replace the former #define macros, so `T` no longer hijacks
// every later use of that identifier.
using T = std::pair<int, int>;
using ordered_set = __gnu_pbds::tree<T, __gnu_pbds::null_type, std::less<T>,
                                     __gnu_pbds::rb_tree_tag,
                                     __gnu_pbds::tree_order_statistics_node_update>;

const int N = 200005;

// val[s]: input value of s, replaced in place by the subtree XOR in dfs().
// tree_size[s]: number of nodes in s's subtree.
// pos[s]: index of the subtree_size[] set owned by s during dfs2().
int val[N], tree_size[N], pos[N];
// Tree adjacency lists.
std::vector<int> con[N];
// path: values on the current root-to-node path (dfs2).
// global: every node's subtree XOR.
// subtree_size[k]: descendant sets merged small-to-large in dfs2().
ordered_set path, global, subtree_size[N];
// cnt: source of unique ids for ordered_set keys.
// CNT: number of subtree_size[] sets handed out.
int cnt = 0, CNT = 0;

// Computes, for every node u in the subtree of s (parent p; 0 for the root):
//   - val[u] becomes the XOR of the input values in u's subtree (in place);
//   - tree_size[u] becomes the size of u's subtree.
// Each val[u] is inserted into `global` in post-order with a fresh id.
// Returns val[s].
// Uses an explicit stack so a path-shaped tree with ~2e5 nodes cannot
// overflow the call stack. Visit and insertion order match the recursive form.
int dfs(int s = 1, int p = 0){
    struct Frame {
        int node;
        int parent;
        std::size_t next;  // index of the next adjacency entry to examine
    };
    std::vector<Frame> frames;
    tree_size[s] = 1;
    frames.push_back({s, p, 0});
    while (!frames.empty()) {
        Frame &top = frames.back();
        const int u = top.node;
        if (top.next < con[u].size()) {
            const int v = con[u][top.next++];
            if (v != top.parent) {
                tree_size[v] = 1;
                frames.push_back({v, u, 0});  // may invalidate `top`
            }
            continue;
        }
        frames.pop_back();
        global.insert({val[u], ++cnt});
        if (!frames.empty()) {
            const int parent = frames.back().node;
            val[parent] ^= val[u];
            tree_size[parent] += tree_size[u];
        }
    }
    return val[s];
}


// Number of entries in `os` whose value (first component) lies in [l, r].
// Ids are always >= 1 (they come from ++cnt), so {l, 0} sorts before every
// entry with value l, and {r + 1, 0} sorts before every entry with value r + 1.
int getCount(ordered_set & os, int l, int r){
    const int atOrBelowR = os.order_of_key({r + 1, 0});
    const int belowL = os.order_of_key({l, 0});
    return atOrBelowR - belowL;
}

long long ans = 0;

int getMax(int ind, int c_path, int c_subtree, int c_global){
int x = val[ind];
int position = pos[ind];
int ret = 0;
int lo = 0, hi = (1 << 30) - 1;
for(int bit = 29; bit >= 0; bit--){
int mid = (lo + hi) >> 1;
if(x >> bit & 1){
if(c_path * getCount(path, lo, mid) + c_subtree * getCount(subtree_size[position], lo, mid) +
c_global * getCount(global, lo, mid) > 0){
ret += (1 << bit);
hi = mid;
} else{
lo = mid + 1;
}
} else{
if(c_path * getCount(path, mid + 1, hi) + c_subtree *
getCount(subtree_size[position], mid + 1, hi) + c_global * getCount(global, mid + 1, hi) > 0){
lo = mid + 1;
ret += (1 << bit);
} else{
hi = mid;
}
}
}
return ret;
}


// Second DFS: adds P(s) to `ans` for every non-root node s. The root adds
// nothing, because the required sum is over i != 1.
// While s's answer is computed:
//   - `path` holds the values of s and its ancestors;
//   - subtree_size[pos[s]] holds the values of s's strict descendants. Sets are
//     merged small-to-large: s reuses its heaviest child's set, and the light
//     children's sets are copied into it.
// The candidates for P(s) are the ancestor maximum (curr_mx) and
// S[s] XOR S[j] for nodes j that are neither ancestors nor descendants of s
// (global minus path minus subtree).
// NOTE(review): recursion depth equals the tree height (up to n); a path-shaped
// tree may exhaust the call stack on judges with small stack limits.
void dfs2(int s = 1, int p = 0){
    // Only the root is entered with an empty path. Reading path.rbegin() there
    // would dereference an empty tree (undefined behaviour).
    const bool is_root = path.empty();
    const int curr_mx = is_root ? 0 : path.rbegin()->first;  // max over strict ancestors

    pair<int, int> to_insert = {val[s], ++cnt};
    path.insert(to_insert);

    int bigc = 0;  // heaviest child; 0 means none (tree_size[0] == 0)
    for(int v : con[s]) if(v != p){
        dfs2(v, s);
        if(tree_size[v] > tree_size[bigc]) bigc = v;
    }

    if(!bigc){
        pos[s] = ++CNT;  // leaf: start a fresh, empty set
    } else pos[s] = pos[bigc];  // inherit the heavy child's set

    for(int v : con[s]) if(v != p && v != bigc){
        for(auto ele : subtree_size[pos[v]]) subtree_size[pos[s]].insert(ele);
        subtree_size[pos[v]].clear();
    }

    if (!is_root) ans += max(getMax(s, -1, -1, 1), curr_mx);

    path.erase(to_insert);
    subtree_size[pos[s]].insert({val[s], ++cnt});
}

// Reads n, the node values and the n - 1 edges, then prints the answer with no
// trailing newline.
int main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    int n;
    cin >> n;
    for (int node = 1; node <= n; ++node) cin >> val[node];
    for (int edge = 1; edge < n; ++edge) {
        int a, b;
        cin >> a >> b;
        con[a].push_back(b);
        con[b].push_back(a);
    }

    dfs();   // subtree XORs, subtree sizes, and the global multiset
    dfs2();  // accumulate P(i) into `ans`
    cout << ans;
}