# Hackerrank Kitty's Calculations on a Tree problem solution

In this tutorial, we are going to solve the HackerRank problem Kitty's Calculations on a Tree. We are given a tree with n nodes labelled 1 to n and q queries, each of which is a set of nodes of that tree. For every query we need to compute the sum of u · v · dist(u, v) over all unordered pairs {u, v} of nodes in the set, where dist(u, v) is the number of edges on the path between u and v, and print the result modulo 10^9 + 7.

## Problem solution in Python programming.

```python
# Enter your code here. Read input from STDIN. Print output to STDOUT
from collections import Counter, defaultdict

MOD = 10**9 + 7

return (int(x) for x in input().split())

def mul(x, y):
    """Return the product of x and y reduced modulo MOD."""
    return (x * y) % MOD

def add(*args):
    """Return the sum of all positional arguments reduced modulo MOD."""
    # The `def` line was missing; the name comes from the `add(...)` call
    # further down and the parameter from the use of `args` here.
    return sum(args) % MOD

def sub(x, y):
    """Return x - y reduced modulo MOD; the result is never negative."""
    return (x - y) % MOD

# NOTE(review): this listing has lost its indentation and several lines.
# The notes below mark the gaps that are visible from the code itself;
# nothing here has been reconstructed.

# Construct adjacency list of the tree

# NOTE(review): `n`, `q` and `adj_list` are used below but never assigned in
# the visible code, and this loop has no body — the input-reading lines
# appear to be missing.
for _ in range(n - 1):

# Construct element to set mapping {element: [sets it belongs to]}
elements = {v: set() for v in adj_list}

# NOTE(review): loop body missing; presumably it filled `elements` with the
# query index `set_no` for every node of that query — confirm.
for set_no in range(q):

# Do BFS to find parent for each node and order them in reverse depth
# NOTE(review): `root` is never assigned in the visible code.
current = [root]
current_depth = 0
order = []
parent = {root: None}
depth = {root: current_depth}

while current:
current_depth += 1
order.extend(current)
nxt = []
for node in current:
# NOTE(review): `neighbor` is not bound here; a loop over adj_list[node]
# appears to be missing between these two lines.
if neighbor not in parent:
parent[neighbor] = node
depth[neighbor] = current_depth
nxt.append(neighbor)

current = nxt

# Process nodes in the order created above
# `score` is the per-query result that the final print statement reads.
score = Counter()
# {node: {set_a: [depth, sum of nodes, flow]}}
state = {}
for node in reversed(order):
# States already computed for the children of `node`.
states = [state[neighbor] for neighbor in adj_list[node] if neighbor != parent[node]]
# Initial state of `node` itself: one entry per set that contains it.
largest = {s: [depth[node], node, 0] for s in elements[node]}

if states:
# Reuse the biggest child dictionary as the merge target so that the
# smaller dictionaries are the ones iterated below.
max_index = max(range(len(states)), key=lambda x: len(states[x]))
if len(states[max_index]) > len(largest):
states[max_index], largest = largest, states[max_index]

# Group the remaining per-child entries by set number.
sets = defaultdict(list)
for cur_state in states:
for set_no, v in cur_state.items():
sets[set_no].append(v)

for set_no, states in sets.items():
# A set seen in only one subtree needs no combining at this node.
if len(states) == 1 and set_no not in largest:
largest[set_no] = states[0]
continue

if set_no in largest:
states.append(largest.pop(set_no))

total_flow = 0
total_node_sum = 0

for node_depth, node_sum, node_flow in states:
# NOTE(review): `flow_delta` is computed but never used and `total_flow`
# is never updated; an accumulation line appears to be missing.
flow_delta = mul(node_depth - depth[node], node_sum)
total_node_sum += node_sum

set_score = 0

for node_depth, node_sum, node_flow in states:
node_flow = add(mul(node_depth - depth[node], node_sum), node_flow)
# NOTE(review): `diff` is never added to `set_score`, and `score` is never
# written anywhere in the visible code.
diff = mul(sub(total_flow, node_flow), node_sum)

largest[set_no] = (depth[node], total_node_sum, total_flow)

state[node] = largest

print(*(score[i] for i in range(q)), sep='\n')
```

## Problem solution in Java Programming.

```java
import java.io.*;
import java.util.*;

public class Solution {

static final long MOD = 1_000_000_007;

/**
 * Multiplies three values modulo {@code MOD}.
 * The first product is reduced before the third factor is applied.
 */
static int mul(long x, long y, long z) {
    long firstProduct = (x * y) % MOD;
    long fullProduct = (firstProduct * z) % MOD;
    return (int) fullProduct;
}

/**
 * Multiplies two values and returns the product modulo {@code MOD}.
 */
static int mul(long x, long y) {
    long product = x * y;
    return (int) (product % MOD);
}

/**
 * Adds two values and returns the total modulo {@code MOD}.
 */
static int sum(long x, long y) {
    long total = x + y;
    return (int) (total % MOD);
}

/**
 * Adds three values and returns the total modulo {@code MOD}.
 * The reduction happens once, after all three terms are added.
 */
static int sum(long x, long y, long z) {
    long total = x + y + z;
    return (int) (total % MOD);
}

static int[] nxt;
static int[] succ;
static int[] ptr;
static int[] set;
static int[] dep;
static int[] parent;
static int index = 1;

/**
 * Prepends the directed edge {@code u -> v} to the linked adjacency list of
 * {@code u} and records {@code u} as the parent of {@code v}.
 * Slot 0 is never used, so a zero link marks the end of a list.
 */
static void addEdge(int u, int v) {
    final int slot = index;
    nxt[slot] = ptr[u];   // link to the previous head of u's list
    ptr[u] = slot;        // the new edge becomes the head
    parent[v] = u;
    succ[slot] = v;
    index = slot + 1;
}

/**
 * Breadth-first walk from {@code source} that sets {@code dep[v]} for every
 * node reachable through the adjacency lists built by {@code addEdge}.
 * {@code dep[source]} itself is left untouched.
 *
 * The listing polled a queue {@code q} that was never declared, seeded or
 * refilled; all three pieces are restored here.
 *
 * NOTE(review): there is no visited check, so this assumes every edge is
 * stored once, from parent to child — confirm against the caller.
 */
static void bfsDeep(int source) {
    Deque<Integer> q = new ArrayDeque<>();
    q.add(source);
    while (!q.isEmpty()) {
        int u = q.poll();
        for (int i = ptr[u]; i > 0; i = nxt[i]) {
            int v = succ[i];
            dep[v] = dep[u] + 1;
            q.add(v);   // continue the walk below v
        }
    }
}

/**
 * Returns the lowest common ancestor of {@code u} and {@code v} by climbing
 * the {@code parent} links, using {@code dep} to level the two nodes first.
 */
static int lowestCommonAncestor(int u, int v) {
    int deeper = u;
    int shallower = v;
    if (dep[u] < dep[v]) {
        deeper = v;
        shallower = u;
    }

    // Lift the deeper node until both are on the same level.
    for (; dep[deeper] > dep[shallower]; deeper = parent[deeper]) {
        // the update clause does the climbing
    }
    if (deeper == shallower) {
        return deeper;
    }

    // Climb in lock-step until the two nodes share a parent.
    while (parent[deeper] != parent[shallower]) {
        deeper = parent[deeper];
        shallower = parent[shallower];
    }
    return parent[deeper];
}

static boolean[] visited;

/**
 * Same climb as {@code lowestCommonAncestor}, but additionally clears
 * {@code visited[]} for both arguments, every node stepped through and the
 * returned ancestor.
 */
static int lowestCommonAncestorVis(int u, int v) {
    int deeper = u;
    int shallower = v;
    if (dep[u] < dep[v]) {
        deeper = v;
        shallower = u;
    }
    visited[deeper] = false;
    visited[shallower] = false;

    // Level the two nodes, clearing the flag along the way.
    while (dep[deeper] > dep[shallower]) {
        deeper = parent[deeper];
        visited[deeper] = false;
    }
    if (deeper == shallower) {
        return deeper;
    }

    // Climb both sides together until they share a parent.
    while (parent[deeper] != parent[shallower]) {
        deeper = parent[deeper];
        shallower = parent[shallower];
        visited[deeper] = false;
        visited[shallower] = false;
    }

    final int meet = parent[deeper];
    visited[meet] = false;
    return meet;
}

static boolean[] isSet;

/**
 * One frame of the explicit stack used by {@code dfs}. Frames are reused
 * between traversals through {@link #reset}.
 */
static class NodeDfs {
    // Zero-based node index this frame stands for.
    int u;
    // Starts at 1; dfs increments it for every chain node folded into this frame.
    int count = 1;
    // Accumulators combined by dfs; they start at zero.
    long parzialInv;
    long sumNode;
    long tot;
    long parz2;
    // Frame that pushed this one, or null for the traversal root.
    NodeDfs parent;
    // True until dfs has handled the frame for the first time.
    boolean start = true;

    public NodeDfs() {
    }

    /** Re-initialises the frame for node {@code u} pushed by {@code parent}. */
    public void reset(int u, NodeDfs parent) {
        this.u = u;
        this.parent = parent;
        this.count = 1;
        this.start = true;
        this.sumNode = 0;
        this.parz2 = 0;
        this.tot = 0;
        this.parzialInv = 0;
    }
}

static int stackIndex = 0;
static NodeDfs[] nodes;

// Iterative depth-first traversal from node u. The pre-allocated `nodes`
// array serves as the stack and `stackIndex` is its height. Every frame is
// handled twice: once to push its children (start == true) and once, after
// the children are done, to fold its accumulators into its parent frame.
// Returns the frame of u; the caller reads the query answer from its `tot`.
static NodeDfs dfs(int u) {
NodeDfs root = nodes[0];
root.reset(u, null);
stackIndex = 1;

while (stackIndex > 0) {
NodeDfs node = nodes[stackIndex-1];
if (node.start) {
// First visit: mark the node and push every successor not yet marked.
visited[node.u] = true;

if (isSet[node.u]) {
for (int i = ptr[node.u]; i > 0; i = nxt[i]) {
if (!visited[succ[i]]) {
nodes[stackIndex].reset(succ[i], node);
stackIndex++;
}
}
} else {
// Node is not in the query set: while the current node has exactly one
// pushed successor and that successor is not in the set either, pop it
// again and continue from it, counting the skipped step in node.count.
int uu = node.u;
while(true) {
int j = 0;
int v = 0;
for (int i = ptr[uu]; i > 0; i = nxt[i]) {
if (!visited[succ[i]]) {
nodes[stackIndex++].reset(v = succ[i], node);
j++;
}
}
if (isSet[v] || j != 1) {
break;
}
node.count++;
stackIndex--;
uu = v;
}
}

node.start = false;
} else {
// Second visit: all children have already been folded into this frame.
if (node.count > 1) {
node.tot = sum(node.tot, mul(node.sumNode, node.parzialInv), MOD - node.parz2);
// Account for the count-1 extra steps of the folded chain.
node.parzialInv = sum(node.parzialInv, mul(node.count-1, node.sumNode));
} else {
node.tot = sum(node.tot, mul(node.sumNode, node.parzialInv), MOD - node.parz2);
}
if (isSet[node.u]) {
// Node labels are one-based, indices zero-based.
node.sumNode += node.u+1;
}
if (node.u != u) {
// Fold this frame into the frame that pushed it.
// NOTE(review): sumNode looks like the sum of set labels below the node
// and parzialInv like the sum of label * distance — confirm.
NodeDfs nodeP = node.parent;
nodeP.sumNode = sum(nodeP.sumNode, node.sumNode);
nodeP.parzialInv = sum(nodeP.parzialInv, node.parzialInv, node.sumNode);
nodeP.parz2 = sum(nodeP.parz2, mul(node.parzialInv + node.sumNode, node.sumNode));
if (isSet[nodeP.u]) {
nodeP.tot = sum(nodeP.tot, node.tot, mul(node.sumNode + node.parzialInv, (nodeP.u + 1)));
} else {
nodeP.tot = sum(nodeP.tot, node.tot);
}
}

stackIndex--;
}
}

return root;
}

static final int MAX_SIMPLY = 3;

// Reads the tree and the queries and writes one answer per query to the
// file named by the OUTPUT_PATH environment variable.
// NOTE(review): `st` and `br` are used below but never declared in the
// visible code; the lines creating the reader and tokenizer (and any
// per-line refills of `st`) appear to be missing.
public static void main(String[] args) throws IOException {
BufferedWriter bw = new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")));

int n = Integer.parseInt(st.nextToken());
int q = Integer.parseInt(st.nextToken());

// Edge slots start at index 1, so the edge arrays are sized 2 * n.
nxt = new int[2 * n];
succ = new int[2 * n];
ptr = new int[n];
dep = new int[n];
parent = new int[n];
nodes = new NodeDfs[n];
for (int i = 0; i < n; i++) {
nodes[i] = new NodeDfs();
}

for (int i = 0; i < n - 1; i++) {
// Input labels are one-based; internal indices are zero-based.
int u = Integer.parseInt(st.nextToken()) - 1;
int v = Integer.parseInt(st.nextToken()) - 1;
// NOTE(review): both branches are empty, so no edge is ever stored;
// presumably each held an addEdge call with the arguments ordered by
// label — confirm against the original solution.
if (u < v) {
} else {
}
}
bfsDeep(0);

visited = new boolean[n];
isSet = new boolean[n];

for (int h = 1; h <= q; h++) {
int k = Integer.parseInt(st.nextToken());
set = new int[k];
// isSet is only cleared before large queries, the only ones that read it.
if (k >= MAX_SIMPLY) {
Arrays.fill(isSet, false);
}
for (int i = 0; i < k; i++) {
int u = Integer.parseInt(st.nextToken()) - 1;
isSet[u] = true;
set[i] = u;
}

long result = 0;
if (k < MAX_SIMPLY) {
// Small query: evaluate every pair directly through its LCA.
for (int i = 0; i < k - 1; i++) {
int x = set[i];
for (int j = i + 1; j < k; j++) {
int y = set[j];
int z = lowestCommonAncestor(x, y);
int dist = dep[y] + dep[x] - 2 * dep[z];
result = sum(result, mul(x + 1, y + 1, dist));
}
}
} else {
// Large query: clear visited[] along the paths joining the set nodes,
// then let dfs walk only those cleared nodes from the common ancestor x.
Arrays.fill(visited, true);
Arrays.sort(set);
int x = set[set.length -1];
for (int i = k-2; i >= 0; i--) {
// Nodes already cleared lie on a path handled earlier.
if (visited[set[i]]) {
x = lowestCommonAncestorVis(x, set[i]);
}
}
NodeDfs node = dfs(x);
result = node.tot;
}
bw.write(String.valueOf(result));
bw.newLine();
}

bw.close();
br.close();
}
}
```

### Problem solution in C++ programming.

```cpp
#include <iostream>
#include <cstdio>
#include <vector>
#include <cstring>
#include <utility>
using namespace std;

typedef long long LL;
typedef pair<LL,LL> pii;
const int MAX_N = 2e5 + 6;
const int MAX_P = 19;
const LL mod = 1e9 + 7;

vector<int> edg[MAX_N];
int dis[MAX_P][MAX_N];
bool visit[MAX_N];

// Per-node record of the centroid decomposition built by dfs().
// Field order matters: dfs() fills the struct with a positional initialiser.
struct Cen {
// Centroid passed down as the parent of this one; -1 for the first call.
int par;
// Level of this centroid; used as the first index into dis[][].
int depth;
// first: running sum of x * dis[depth][x] over inserted nodes x.
// second: negated running sum of x * distance to the parent centroid.
pii val_v_av;  //first --> val, second --> minus
// first: running sum of inserted node values x; second: its negation,
// stored only when the centroid has a parent.
pii val_v;
} cen[MAX_N];

vector<int> v;
int sz[MAX_N];
int mx[MAX_N];

void dfs2(int id) {
v.push_back(id);
visit[id]=1;
sz[id]=1;
mx[id]=0;
for (int i:edg[id]) {
if (!visit[i]) {
dfs2(i);
sz[id] += sz[i];
}
}
}

#define SZ(x) ((int)(x).size())

// Returns a centroid of the unvisited component containing `id`.
// The component is gathered by dfs2(); the temporary visit marks it leaves
// behind are cleared again before returning. When several nodes satisfy the
// balance test, the last one in collection order is returned.
int get_cen(int id) {
    v.clear();
    dfs2(id);
    const int total = static_cast<int>(v.size());
    int best = -1;
    for (size_t k = 0; k < v.size(); ++k) {
        const int node = v[k];
        const int heaviest = max(mx[node], total - sz[node]);
        if (heaviest <= total / 2) {
            best = node;
        }
        visit[node] = false;
    }
    return best;
}

void dfs3(int id,int par,int cen_depth,int dist)  {
dis[cen_depth][id] = dist;
for (int i:edg[id]) {
if (!visit[i] && i!=par) {
dfs3(i,id,cen_depth,dist+1);
}
}
}

void dfs(int id,int cen_par,int cen_depth) {
int ccen=get_cen(id);
dfs3(ccen,ccen,cen_depth,0);
cen[ccen]={cen_par,cen_depth,{0,0},{0,0}};
visit[ccen]=1;
for (int i:edg[ccen]) {
if (!visit[i]) dfs(i,ccen,cen_depth+1);
}
}

// Component-wise sum of two pairs.
pii operator+(const pii &p1,const pii &p2) {
    pii total(p1);
    total.first += p2.first;
    total.second += p2.second;
    return total;
}

// Component-wise difference of two pairs.
pii operator-(const pii &p1,const pii &p2) {
    pii diff(p1);
    diff.first -= p2.first;
    diff.second -= p2.second;
    return diff;
}

// Adds p2 to p1 component-wise in place and returns a copy of the result.
pii operator+=(pii &p1,const pii &p2) {
    p1.first += p2.first;
    p1.second += p2.second;
    return p1;
}

// Subtracts p2 from p1 component-wise in place and returns a copy of the result.
pii operator-=(pii &p1,const pii &p2) {
    p1.first -= p2.first;
    p1.second -= p2.second;
    return p1;
}

// Normalises both components of p into the range [0, mod).
void Pure(pii &p) {
    p.first %= mod;
    p.first += mod;    // lifts a negative remainder above zero
    p.first %= mod;

    p.second %= mod;
    p.second += mod;
    p.second %= mod;
}

LL p=x;
while (p!=-1) {
cen[p].val_v += {x,0};
cen[p].val_v_av += {x*dis[cen[p].depth][x],0};
if (cen[p].par != -1) {
int par=cen[p].par;
cen[p].val_v -= {0,x};
cen[p].val_v_av -= {0,x*dis[cen[par].depth][x]};
}
Pure(cen[p].val_v);
Pure(cen[p].val_v_av);
p=cen[p].par;
}
}

// Removes node x from the accumulators of every centroid on its chain of
// centroid ancestors, undoing exactly what an insertion of x added.
void dell(LL x) {
    for (LL p = x; p != -1; p = cen[p].par) {
        Cen &slot = cen[p];
        slot.val_v -= pii(x, 0);
        slot.val_v_av -= pii(x * dis[slot.depth][x], 0);
        if (slot.par != -1) {
            // Undo the negated share recorded against the parent centroid.
            const int up = slot.par;
            slot.val_v += pii(0, x);
            slot.val_v_av += pii(0, x * dis[cen[up].depth][x]);
        }
        Pure(slot.val_v);
        Pure(slot.val_v_av);
    }
}

// Returns, modulo `mod`, the sum over all currently inserted nodes y of
// x * y * dist(x, y), gathered along x's chain of centroid ancestors.
//
// Every intermediate is reduced before the next multiplication: the
// original `x * dis * v` could reach about 2e5 * 2e5 * 1e9 and overflow a
// signed 64-bit integer.
LL query(LL x) {
    LL ret=0;
    LL v=0;
    LL v_av=0;
    int p=x;
    while (p!=-1) {
        v = (v + cen[p].val_v.first) % mod;
        v_av = (v_av + cen[p].val_v_av.first) % mod;
        ret = (ret + x * v_av) % mod;
        // Reduce x * distance first so the product with v stays below 2^63.
        LL scaled = (x * dis[cen[p].depth][x]) % mod;
        ret = (ret + scaled * v) % mod;
        // Carry the negated share of this subtree up to the parent centroid.
        v = cen[p].val_v.second;
        v_av = cen[p].val_v_av.second;
        p=cen[p].par;
    }
    return ret;
}

// Returns a^n modulo `mod` by binary exponentiation (n >= 0, mod >= 1).
//
// The base is normalised into [0, mod) first: the original returned `a`
// unreduced for n == 1 and could overflow when squaring a base >= mod.
// The loop also removes the recursion.
LL pow(LL a,LL n,LL mod) {
    LL base = ((a % mod) + mod) % mod;
    LL ret = 1 % mod;          // gives 0 when mod == 1
    while (n > 0) {
        if (n & 1) {
            ret = ret * base % mod;
        }
        base = base * base % mod;
        n >>= 1;
    }
    return ret;
}

// Reads the tree, builds the centroid decomposition once, then answers
// every query by summing query() over the nodes of the set.
int main () {
int n,q;
scanf("%d %d",&n,&q);
for (int i=1;n-1>=i;i++) {
int a,b;
scanf("%d %d",&a,&b);
edg[a].push_back(b);
edg[b].push_back(a);
}
// Node 1 starts the decomposition; -1 means "no parent centroid".
dfs(1,-1,0);
while (q--) {
int k;
scanf("%d",&k);
// This local vector shadows the global `v` used by dfs2/get_cen.
vector<int> v;
while (k--) {
int x;
scanf("%d",&x);
v.push_back(x);
}
// NOTE(review): nothing visible here inserts the set nodes into the
// centroid accumulators before query() runs, while dell() below removes
// them; an insertion step (the mirror of dell) appears to be missing.
LL ans=0;
for (int i:v) {
ans += query(i);
ans%=mod;
}
for (int i:v) dell(i);
// pow(2, mod-2, mod) is the modular inverse of 2 (mod is prime), so the
// sum is halved; presumably every pair was counted in both directions.
printf("%lld\n",(ans*pow(2,mod-2,mod) + mod)%mod);
}
}```

### Problem solution in C programming.

```c
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define LIMIT 1000000007

/* Singly linked list cell holding a pointer to a tree node.
 * NOTE(review): not referenced anywhere else in the visible code. */
typedef struct tree_node_list {
struct tree_node *node;
struct tree_node_list *next;
} tree_node_list;

/* One node of the input tree. */
typedef struct tree_node {
/* Parent node, or NULL when none has been assigned. */
struct tree_node *parent;
/* One-based node label (main sets it to array index + 1). */
uint32_t num;
/* Depth of the node; main initialises it to -1, meaning "not computed". */
int32_t depth;
} tree_node;

/* Per-node scratch data for one query; main zeroes the whole array after
 * every query. */
typedef struct aux_info {
/* Sum of the labels of marked nodes folded into this node so far. */
uint32_t simple_sum;
/* Sum folded in from children after each child's "increment the level"
 * step. NOTE(review): looks like label * distance summed over the marked
 * nodes below — confirm. */
uint32_t level_sum;
/* Non-zero when the node belongs to the current query set. */
uint32_t marker;
} aux_info;

/* Debug helper: prints the label, parent label (0 when there is no parent)
 * and depth of one node, followed by a space.
 *
 * `num` is uint32_t and `depth` is int32_t, so passing them straight to
 * "%ld" is undefined where long is 64 bits wide. Each argument is cast to
 * long; the printed text is unchanged. */
static void print_node(tree_node *node) {
  long parent_num = node->parent != NULL ? (long)node->parent->num : 0L;
  printf("(num: %ld, parent: %ld, depth: %ld) ", (long)node->num, parent_num,
         (long)node->depth);
}

/* Debug helper: prints the first `count` nodes of `nodes` on one line.
 *
 * The index is a size_t so it matches `count`; the original compared a
 * signed int against the unsigned count. */
void print_tree(tree_node *nodes, size_t count) {
  for (size_t i = 0; i < count; i++) {
    print_node(&nodes[i]);
  }
  printf("\n");
}

static int order_tree(const void *lhs, const void *rhs) {
tree_node *a = *((tree_node **)lhs);
tree_node *b = *((tree_node **)rhs);
return a->depth - b->depth;
}

if (node->parent == NULL) {
node->depth = 0;
} else if (node->depth == -1) {
node->depth = node->parent->depth + 1;
}
}

int main() {
long num_nodes, num_queries;
scanf("%ld %ld", &num_nodes, &num_queries);
tree_node *nodes = calloc(num_nodes, sizeof(tree_node));
tree_node **order = calloc(num_nodes, sizeof(tree_node *));
aux_info *info = calloc(num_nodes, sizeof(aux_info));
for (long i = 0; i < num_nodes; ++i) {
tree_node *node = &nodes[i];
node->num = i + 1;
node->depth = -1;
order[i] = &nodes[i];
}
for (long i = 0; i < num_nodes - 1; i++) {
long a, b;
scanf("%ld %ld", &a, &b);
tree_node *node_a = &nodes[a - 1];
tree_node *node_b = &nodes[b - 1];
if (node_b->parent == NULL) {
node_b->parent = node_a;
} else if (node_a->parent == NULL) {
node_a->parent = node_b;
} else {
exit(1);
}
}
for (long i = 0; i < num_nodes; ++i) {
}
qsort(order, num_nodes, sizeof(tree_node *), order_tree);
for (long i = 0; i < num_queries; ++i) {
unsigned long k;
scanf("%ld", &k);
for (long j = 0; j < k; j++) {
long node_num;
scanf("%ld", &node_num);
info[node_num - 1].marker = 1;
}

uint64_t total = 0;
for (long j = num_nodes - 1; j >= 0; --j) {
tree_node node = *order[j];
uint64_t node_num = node.num;
uint64_t node_index = node_num - 1;
aux_info node_info = info[node_index];
if (node_info.marker == 0 && node.depth == 0) {
continue;
}
uint64_t node_simple_sum = node_info.simple_sum;
uint64_t node_level_sum = node_info.level_sum;
if (node_info.marker != 0) {
// Add all the combintations made with this node and its children
total = total + node_level_sum * node_num;
if (total > LIMIT) {
total = total % LIMIT;
}
node_simple_sum += node_num;
} else if (node_simple_sum == 0) {
continue;
}
// Increment the level
node_level_sum += node_simple_sum;
tree_node *parent = node.parent;
if (parent != NULL) {
uint64_t parent_index = parent->num - 1;
aux_info parent_info = info[parent_index];
uint64_t parent_simple_sum = parent_info.simple_sum;
uint64_t parent_level_sum = parent_info.level_sum;
// Add the combinations that this subtree makes with all sibling
// subtrees processed so far
total = (total + (parent_simple_sum * node_level_sum) +
(parent_level_sum * node_simple_sum));
if (total > LIMIT) {
total = total % LIMIT;
}
parent_simple_sum = parent_simple_sum + node_simple_sum;
if (parent_simple_sum > LIMIT) {
parent_simple_sum = parent_simple_sum % LIMIT;
}
parent_level_sum = parent_level_sum + node_level_sum;
if (parent_level_sum > LIMIT) {
parent_level_sum = parent_level_sum % LIMIT;
}
info[parent_index].simple_sum = parent_simple_sum;
info[parent_index].level_sum = parent_level_sum;
}
}

memset(info, 0, sizeof(aux_info) * num_nodes);
long ans = total;
printf("%ld\n", ans);
}

return 0;
}```