# HackerEarth Operations on a tree problem solution

In the HackerEarth *Operations on a tree* problem, you are given an undirected tree G with N nodes. You are also given an array A of N integers, where A[i] is the value assigned to node i, and an integer K.

You may apply the following operation to the tree at most once:

1. Select a node x and make it the root of the tree.
2. Select a node y and XOR the value of every node in the subtree of y with K; that is, for each node u in the subtree of y, set $A[u] = A[u] \oplus K$.

Find the maximum possible sum of all node values after using the operation optimally (not using it at all is also allowed).

## HackerEarth Operations on a tree problem solution

// First solution: root the tree at node 1 and compute, for every node v, the
// sum of A and the sum of (A XOR K) over the subtree of v. Then evaluate every
// set of nodes that the operation can XOR.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

// Returns the maximum total node value achievable with at most one operation.
//
// N     - number of nodes; nodes are numbered 1..N in `edges`.
// K     - XOR mask applied by the operation.
// A     - node values; A[i] is the value of node i + 1.
// edges - the N - 1 undirected edges, each given as {u, v}.
//
// For a fixed root x, the subtree of y is either the whole tree (y == x) or
// one of the two components left after deleting a single edge. With the tree
// rooted once at node 1, every edge (v, parent(v)) therefore yields two
// candidates: subtree(v), and everything except subtree(v).
long long solve(long long N, long long K, const std::vector<long long>& A,
                const std::vector<std::vector<long long>>& edges) {
    assert(1 <= N && N <= 100000);
    const auto n = static_cast<std::size_t>(N);

    // Adjacency lists for nodes 1..n (index 0 is unused).
    std::vector<std::vector<std::size_t>> tree(n + 1);
    for (std::size_t i = 0; i + 1 < n; ++i) {
        const long long u = edges[i][0];
        const long long v = edges[i][1];
        assert(1 <= u && u <= N);
        assert(1 <= v && v <= N);
        assert(u != v);
        tree[static_cast<std::size_t>(u)].push_back(static_cast<std::size_t>(v));
        tree[static_cast<std::size_t>(v)].push_back(static_cast<std::size_t>(u));
    }

    // Iterative BFS from node 1. A recursive DFS would need up to 1e5 nested
    // stack frames on a path-shaped tree. The `seen` flags also keep the
    // traversal finite if the input is not actually a tree.
    std::vector<std::size_t> parent(n + 1, 0);
    std::vector<char> seen(n + 1, 0);
    std::vector<std::size_t> order;
    order.reserve(n);
    order.push_back(1);
    seen[1] = 1;
    for (std::size_t head = 0; head < order.size(); ++head) {
        const std::size_t v = order[head];
        for (const std::size_t w : tree[v]) {
            if (!seen[w]) {
                seen[w] = 1;
                parent[w] = v;
                order.push_back(w);
            }
        }
    }

    // plain[v] / flipped[v]: sum of A / of (A XOR K) over the subtree of v.
    std::vector<long long> plain(n + 1, 0);
    std::vector<long long> flipped(n + 1, 0);
    for (std::size_t v = 1; v <= n; ++v) {
        plain[v] = A[v - 1];
        flipped[v] = A[v - 1] ^ K;
    }
    // Children come after their parent in BFS order. Sweeping the order in
    // reverse therefore finishes each subtree before adding it to its parent.
    // The root (order[0]) has no parent and is skipped.
    for (std::size_t idx = order.size(); idx-- > 1;) {
        const std::size_t v = order[idx];
        plain[parent[v]] += plain[v];
        flipped[parent[v]] += flipped[v];
    }

    const long long total = plain[1];
    const long long total_flipped = flipped[1];

    // No operation, or XOR the whole tree (choose y == x).
    long long best = std::max(total, total_flipped);
    for (std::size_t idx = 1; idx < order.size(); ++idx) {
        const std::size_t v = order[idx];
        best = std::max({best,
                         // XOR only the subtree of v.
                         total - plain[v] + flipped[v],
                         // XOR everything except the subtree of v: root the
                         // tree inside subtree(v) and pick y = parent(v).
                         total_flipped - flipped[v] + plain[v]});
    }
    return best;
}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int T = 0;
    std::cin >> T;
    assert(1 <= T && T <= 10);
    for (int t = 0; t < T; ++t) {
        long long N = 0;
        long long K = 0;
        std::cin >> N >> K;
        assert(1 <= N && N <= 100000);
        assert(0 <= K && K <= 1000000000);

        std::vector<long long> A(static_cast<std::size_t>(N));
        for (long long& value : A) {
            std::cin >> value;
            assert(0 <= value && value <= 1000000000);
        }

        std::vector<std::vector<long long>> edges(static_cast<std::size_t>(N - 1),
                                                  std::vector<long long>(2));
        for (auto& edge : edges) {
            std::cin >> edge[0] >> edge[1];
        }

        std::cout << solve(N, K, A, edges) << '\n';
    }
}

### Second solution

// Second solution: instead of summing A XOR K directly, count, for each bit
// that is set in K, how many nodes of each subtree have that bit set.
// XOR-ing K into a set of `size` nodes, of which `ones` have bit i set, changes
// that bit's contribution by (size - 2 * ones) * 2^i.
#include <algorithm>
#include <array>
#include <cstddef>
#include <iostream>
#include <vector>

namespace {

// K <= 1e9 < 2^30, so only the low 30 bits can ever be flipped.
constexpr int kBits = 30;

// ones[i] = number of nodes in some set whose value has bit i set.
using BitCounts = std::array<long long, kBits>;

// Change in the total when every node of a set is XOR-ed with k.
// `size` is the number of nodes in the set and `ones` its per-bit counts.
long long flip_gain(long long k, long long size, const BitCounts& ones) {
    long long gain = 0;
    for (int i = 0; i < kBits; ++i) {
        if ((k >> i) & 1) {
            // Nodes without bit i gain 2^i; nodes that had it lose 2^i.
            gain += (size - 2 * ones[i]) * (1LL << i);
        }
    }
    return gain;
}

// Maximum total node value after at most one operation, for node values `a`
// (0-indexed) and adjacency lists `g`.
long long best_total(long long k, const std::vector<long long>& a,
                     const std::vector<std::vector<int>>& g) {
    const int n = static_cast<int>(a.size());
    long long sum = 0;
    for (const long long x : a) {
        sum += x;
    }
    if (n == 0) {
        return sum;
    }

    // Iterative BFS from node 0. A recursive DFS would go up to 1e5 frames
    // deep on a path-shaped tree.
    std::vector<int> order;
    order.reserve(static_cast<std::size_t>(n));
    std::vector<int> parent(n, -1);
    std::vector<char> seen(n, 0);
    order.push_back(0);
    seen[0] = 1;
    for (std::size_t head = 0; head < order.size(); ++head) {
        const int v = order[head];
        for (const int u : g[v]) {
            if (!seen[u]) {
                seen[u] = 1;
                parent[u] = v;
                order.push_back(u);
            }
        }
    }

    // sz[v] / cnt[v]: size and per-bit set counts of the subtree of v.
    std::vector<long long> sz(n, 1);
    std::vector<BitCounts> cnt(n);
    for (int v = 0; v < n; ++v) {
        for (int i = 0; i < kBits; ++i) {
            cnt[v][i] = (a[v] >> i) & 1;
        }
    }
    // Reverse BFS order finishes every child before its parent; the root
    // (order[0]) has no parent and is skipped.
    for (std::size_t idx = order.size(); idx-- > 1;) {
        const int v = order[idx];
        const int p = parent[v];
        sz[p] += sz[v];
        for (int i = 0; i < kBits; ++i) {
            cnt[p][i] += cnt[v][i];
        }
    }

    // Gain 0 = no operation. The whole tree (choosing y == x) is also a valid
    // target and must be considered explicitly; the per-edge candidates below
    // never include it.
    long long best = std::max(0LL, flip_gain(k, sz[0], cnt[0]));
    for (std::size_t idx = 1; idx < order.size(); ++idx) {
        const int v = order[idx];
        BitCounts rest{};
        for (int i = 0; i < kBits; ++i) {
            rest[i] = cnt[0][i] - cnt[v][i];
        }
        best = std::max({best,
                         // XOR only the subtree of v.
                         flip_gain(k, sz[v], cnt[v]),
                         // XOR everything except the subtree of v.
                         flip_gain(k, sz[0] - sz[v], rest)});
    }
    return sum + best;
}

}  // namespace

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int t = 0;
    std::cin >> t;
    while (t-- > 0) {
        int n = 0;
        long long k = 0;
        std::cin >> n >> k;

        std::vector<long long> a(static_cast<std::size_t>(std::max(n, 0)));
        for (long long& x : a) {
            std::cin >> x;
        }

        std::vector<std::vector<int>> g(a.size());
        for (int i = 0; i + 1 < n; ++i) {
            int v = 0;
            int u = 0;
            std::cin >> v >> u;
            --v;
            --u;
            g[v].push_back(u);
            g[u].push_back(v);
        }

        std::cout << best_total(k, a, g) << '\n';
    }
}