In this tutorial, we solve the HackerRank problem Kitty's Calculations on a Tree. We are given a tree with n nodes labelled 1 to n and q query sets of nodes. For each set we must compute the sum of u · v · dist(u, v) over all unordered pairs {u, v} of distinct nodes in the set, where dist(u, v) is the number of edges on the path between u and v, and print the result modulo 10^9 + 7.

Hackerrank Kitty's Calculations on a Tree problem solution


Problem solution in Python programming.

# Enter your code here. Read input from STDIN. Print output to STDOUT
from collections import Counter, defaultdict

MOD = 10 ** 9 + 7  # every answer is reported modulo this prime


def read_row():
    """Read one stdin line and return a lazy iterator over its integers."""
    fields = input().split()
    return (int(field) for field in fields)


def mul(x, y):
    """Return the product ``x * y`` reduced modulo ``MOD``."""
    product = x * y
    return product % MOD


def add(*args):
    """Return the sum of all arguments reduced modulo ``MOD`` (0 for no arguments)."""
    total = sum(args)
    return total % MOD


def sub(x, y):
    """Return ``x - y`` reduced into the range ``[0, MOD)``."""
    difference = x - y
    return difference % MOD

n, q = read_row()

# Adjacency list of the tree (nodes are labelled 1..n).
adj_list = defaultdict(list)

for _ in range(n - 1):
    u, v = read_row()
    adj_list[u].append(v)
    adj_list[v].append(u)

# Element-to-sets mapping {element: {indices of the query sets containing it}}.
# Keyed on 1..n rather than on adj_list: a single-node tree (n == 1) has no
# edges, so adj_list would be empty and node 1 would be missing.
elements = {v: set() for v in range(1, n + 1)}

for set_no in range(q):
    read_row()  # the set size k; the members on the next line are enough
    for x in read_row():
        elements[x].add(set_no)

# BFS from node 1 records every node's parent and depth.  `order` lists the
# nodes level by level, so walking it backwards visits each child before its
# parent.
root = 1
current = [root]
current_depth = 0
order = []
parent = {root: None}
depth = {root: current_depth}

while current:
    current_depth += 1
    order.extend(current)
    nxt = []
    for node in current:
        for neighbor in adj_list[node]:
            if neighbor not in parent:
                parent[neighbor] = node
                depth[neighbor] = current_depth
                nxt.append(neighbor)
    current = nxt

# Bottom-up pass with small-to-large merging of per-set summaries.
# state[node] maps a query-set index to an entry (d, S, F), where:
#   S = sum of the labels of that set's members merged so far,
#   F = sum(label * (member depth - d)) mod MOD,
#   d = the depth at which the entry was last rebased.
# An entry arriving from a single branch is passed up unchanged and is only
# rebased to the current node's depth when two or more branches meet.
score = Counter()
state = {}
for node in reversed(order):
    # Each child's state is consumed exactly once (here), so pop it to let the
    # smaller dictionaries be freed instead of keeping all of them alive.
    child_states = [state.pop(child) for child in adj_list[node] if child != parent[node]]
    largest = {s: [depth[node], node, 0] for s in elements[node]}

    # Reuse the biggest dictionary and merge every other one into it.
    if child_states:
        max_index = max(range(len(child_states)), key=lambda i: len(child_states[i]))
        if len(child_states[max_index]) > len(largest):
            child_states[max_index], largest = largest, child_states[max_index]

    # Group the entries of the smaller dictionaries by query set.
    sets = defaultdict(list)
    for cur_state in child_states:
        for set_no, entry in cur_state.items():
            sets[set_no].append(entry)

    for set_no, branches in sets.items():
        if len(branches) == 1 and set_no not in largest:
            # Only one branch holds members of this set: defer the rebase.
            largest[set_no] = branches[0]
            continue

        if set_no in largest:
            branches.append(largest.pop(set_no))

        # Rebase every branch to this node: F' = F + S * (d - depth[node]).
        rebased = [
            (node_sum, add(mul(node_depth - depth[node], node_sum), node_flow))
            for node_depth, node_sum, node_flow in branches
        ]
        total_flow = add(*(flow for _, flow in rebased))
        total_node_sum = sum(node_sum for node_sum, _ in rebased)

        # Members a, b in different branches meet at this node and contribute
        # a * b * (dist(a, node) + dist(b, node)); summed over all such pairs
        # that equals sum_i S_i * (total_flow - F_i).
        set_score = add(*(mul(sub(total_flow, flow), node_sum) for node_sum, flow in rebased))

        score[set_no] = add(score[set_no], set_score)
        largest[set_no] = (depth[node], total_node_sum, total_flow)

    state[node] = largest

print(*(score[i] for i in range(q)), sep='\n')


Problem solution in Java programming.

import java.io.*;
import java.util.*;

/**
 * Kitty's Calculations on a Tree: for every query set, prints the sum of
 * u * v * dist(u, v) over all unordered pairs {u, v} of the set, modulo MOD.
 *
 * Nodes are stored 0-based internally (label = index + 1). Sets with fewer
 * than MAX_SIMPLY members are answered pair by pair with a naive LCA; larger
 * sets by an iterative DFS over the part of the tree connecting the members.
 */
public class Solution {

  static final long MOD = 1_000_000_007;

  /** Returns x * y * z modulo MOD (x * y is reduced before multiplying by z). */
  static int mul(long x, long y, long z) {
    return (int) ((((x * y) % MOD) * z) % MOD);
  }

  /** Returns x * y modulo MOD. */
  static int mul(long x, long y) {
    return (int) ((x * y) % MOD);
  }

  /** Returns (x + y) modulo MOD. */
  static int sum(long x, long y) {
    return (int) ((x + y) % MOD);
  }

  /** Returns (x + y + z) modulo MOD. */
  static int sum(long x, long y, long z) {
    return (int) ((x + y + z) % MOD);
  }

  // Child lists stored as singly linked lists over edge slots: ptr[u] is the
  // first slot of u (0 = no children), nxt[slot] the next slot of the same
  // node and succ[slot] the child stored in that slot.
  static int[] nxt;
  static int[] succ;
  static int[] ptr;
  static int[] set;      // members of the current query (0-based)
  static int[] dep;      // depth below node 0
  static int[] parent;   // parent in the tree rooted at node 0 (parent[0] == 0)
  static int index = 1;  // next free edge slot (slot 0 means "none")

  /** Appends v to u's child list and records parent[v] = u. */
  static void addEdge(int u, int v) {
    nxt[index] = ptr[u];
    ptr[u] = index;
    parent[v] = u;
    succ[index++] = v;
  }

  /**
   * Roots the tree at node 0: a breadth-first search over the undirected
   * input edges {eu[i], ev[i]} (0-based) calls addEdge(parent, child) for
   * every tree edge. The input does not promise that the smaller label is the
   * parent, so the orientation must be discovered rather than assumed.
   */
  static void orientTree(int n, int[] eu, int[] ev) {
    if (n == 0) {
      return;
    }
    int m = eu.length;
    // Compressed undirected adjacency: the neighbours of u are
    // adj[start[u]] .. adj[start[u + 1] - 1].
    int[] start = new int[n + 1];
    for (int i = 0; i < m; i++) {
      start[eu[i] + 1]++;
      start[ev[i] + 1]++;
    }
    for (int i = 0; i < n; i++) {
      start[i + 1] += start[i];
    }
    int[] fill = Arrays.copyOf(start, n);
    int[] adj = new int[2 * m];
    for (int i = 0; i < m; i++) {
      adj[fill[eu[i]]++] = ev[i];
      adj[fill[ev[i]]++] = eu[i];
    }

    boolean[] seen = new boolean[n];
    int[] queue = new int[n];
    int head = 0;
    int tail = 0;
    queue[tail++] = 0;
    seen[0] = true;
    while (head < tail) {
      int u = queue[head++];
      for (int j = start[u]; j < start[u + 1]; j++) {
        int w = adj[j];
        if (!seen[w]) {
          seen[w] = true;
          addEdge(u, w);
          queue[tail++] = w;
        }
      }
    }
  }

  /** Fills dep[] for every node reachable from source through the child lists. */
  static void bfsDeep(int source) {
    Queue<Integer> q = new LinkedList<>();
    q.add(source);
    while (!q.isEmpty()) {
      int u = q.poll();
      for (int i = ptr[u]; i > 0; i = nxt[i]) {
        int v = succ[i];
        q.add(v);
        dep[v] = dep[u] + 1;
      }
    }
  }

  /**
   * Naive LCA: lifts the deeper node to the other's depth, then steps both up
   * together until their parents coincide. O(depth) per call.
   */
  static int lowestCommonAncestor(int u, int v) {
    if (dep[u] < dep[v]) {
      int temp = u;
      u = v;
      v = temp;
    }
    while (dep[u] > dep[v]) {
      u = parent[u];
    }

    if (u == v) {
      return u;
    }
    while (parent[u] != parent[v]) {
      u = parent[u];
      v = parent[v];
    }

    return parent[u];
  }

  // Cleared (false) for the nodes dfs() is allowed to enter.
  static boolean[] visited;

  /**
   * Same as lowestCommonAncestor, but also clears visited[] for both start
   * nodes, every node on the two climbing paths and the LCA, so that dfs()
   * can later walk exactly those nodes.
   */
  static int lowestCommonAncestorVis(int u, int v) {
    if (dep[u] < dep[v]) {
      int temp = u;
      u = v;
      v = temp;
    }
    visited[u] = false;
    visited[v] = false;
    while (dep[u] > dep[v]) {
      u = parent[u];
      visited[u] = false;
    }

    if (u == v) {
      return u;
    }
    while (parent[u] != parent[v]) {
      u = parent[u];
      v = parent[v];
      visited[u] = false;
      visited[v] = false;
    }
    visited[parent[u]] = false;

    return parent[u];
  }

  // isSet[u] is true when node u belongs to the current (large) query set.
  static boolean[] isSet;

  /** One entry of the explicit DFS stack used by dfs(). */
  static class NodeDfs {
    int u;                 // tree node of this entry
    int count = 1;         // 1 + number of chain nodes absorbed into this entry
    long parzialInv = 0;   // partial sums combined bottom-up (see dfs)
    long sumNode = 0;
    long tot = 0;          // accumulated answer for this entry's subtree
    long parz2 = 0;
    NodeDfs parent = null; // stack entry of the parent (null for the root)
    boolean start = true;  // true until the children have been pushed

    public NodeDfs() {
    }

    /** Reinitialises this pooled entry for node u under the given parent entry. */
    public void reset(int u, NodeDfs parent) {
      this.u = u;
      this.parent = parent;
      parzialInv = tot = parz2 = sumNode = 0;
      start = true;
      count = 1;
    }
  }

  static int stackIndex = 0;
  static NodeDfs[] nodes; // preallocated stack entries, reused across queries

  /**
   * Iterative post-order DFS from u over child links whose target has
   * visited[] == false. A non-set node whose only unvisited child is also not
   * in the set absorbs that child (chain compression, tracked in count). When
   * an entry is popped its partial sums are folded into its parent entry.
   * Returns the root entry; its tot is the answer for the query.
   */
  static NodeDfs dfs(int u) {
    NodeDfs root = nodes[0];
    root.reset(u, null);
    stackIndex = 1;

    while (stackIndex > 0) {
      NodeDfs node = nodes[stackIndex - 1];
      if (node.start) {
        visited[node.u] = true;

        if (isSet[node.u]) {
          for (int i = ptr[node.u]; i > 0; i = nxt[i]) {
            if (!visited[succ[i]]) {
              nodes[stackIndex].reset(succ[i], node);
              stackIndex++;
            }
          }
        } else {
          int uu = node.u;
          while (true) {
            int j = 0;
            int v = 0;
            for (int i = ptr[uu]; i > 0; i = nxt[i]) {
              if (!visited[succ[i]]) {
                nodes[stackIndex++].reset(v = succ[i], node);
                j++;
              }
            }
            if (isSet[v] || j != 1) {
              break;
            }
            // Single non-set child: absorb it and continue down the chain.
            node.count++;
            stackIndex--;
            uu = v;
          }
        }

        node.start = false;
      } else {
        if (node.count > 1) {
          node.tot = sum(node.tot, mul(node.sumNode, node.parzialInv), MOD - node.parz2);
          node.parzialInv = sum(node.parzialInv, mul(node.count - 1, node.sumNode));
        } else {
          node.tot = sum(node.tot, mul(node.sumNode, node.parzialInv), MOD - node.parz2);
        }
        if (isSet[node.u]) {
          node.sumNode += node.u + 1;
        }
        if (node.u != u) {
          NodeDfs nodeP = node.parent;
          nodeP.sumNode = sum(nodeP.sumNode, node.sumNode);
          nodeP.parzialInv = sum(nodeP.parzialInv, node.parzialInv, node.sumNode);
          nodeP.parz2 = sum(nodeP.parz2, mul(node.parzialInv + node.sumNode, node.sumNode));
          if (isSet[nodeP.u]) {
            nodeP.tot = sum(nodeP.tot, node.tot, mul(node.sumNode + node.parzialInv, (nodeP.u + 1)));
          } else {
            nodeP.tot = sum(nodeP.tot, node.tot);
          }
        }

        stackIndex--;
      }
    }

    return root;
  }

  // Sets smaller than this are answered with the pairwise LCA loop.
  static final int MAX_SIMPLY = 3;

  public static void main(String[] args) throws IOException {
    BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    BufferedWriter bw = new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")));

    StringTokenizer st = new StringTokenizer(br.readLine());
    int n = Integer.parseInt(st.nextToken());
    int q = Integer.parseInt(st.nextToken());

    nxt = new int[2 * n];
    succ = new int[2 * n];
    ptr = new int[n];
    dep = new int[n];
    parent = new int[n];
    nodes = new NodeDfs[n];
    for (int i = 0; i < n; i++) {
      nodes[i] = new NodeDfs();
    }

    int edgeCount = Math.max(n - 1, 0);
    int[] eu = new int[edgeCount];
    int[] ev = new int[edgeCount];
    for (int i = 0; i < edgeCount; i++) {
      st = new StringTokenizer(br.readLine());
      eu[i] = Integer.parseInt(st.nextToken()) - 1;
      ev[i] = Integer.parseInt(st.nextToken()) - 1;
    }
    orientTree(n, eu, ev);
    bfsDeep(0);

    visited = new boolean[n];
    isSet = new boolean[n];

    for (int h = 1; h <= q; h++) {
      st = new StringTokenizer(br.readLine());
      int k = Integer.parseInt(st.nextToken());
      st = new StringTokenizer(br.readLine());
      set = new int[k];
      if (k >= MAX_SIMPLY) {
        Arrays.fill(isSet, false);
      }
      for (int i = 0; i < k; i++) {
        int u = Integer.parseInt(st.nextToken()) - 1;
        isSet[u] = true;
        set[i] = u;
      }

      long result = 0;
      if (k < MAX_SIMPLY) {
        // Few members: sum every pair directly, distance via the LCA.
        for (int i = 0; i < k - 1; i++) {
          int x = set[i];
          for (int j = i + 1; j < k; j++) {
            int y = set[j];
            int z = lowestCommonAncestor(x, y);
            int dist = dep[y] + dep[x] - 2 * dep[z];
            result = sum(result, mul(x + 1, y + 1, dist));
          }
        }
      } else {
        // Clear visited[] on the paths joining all members, then DFS from
        // their common ancestor over just those nodes.
        Arrays.fill(visited, true);
        Arrays.sort(set);
        int x = set[set.length - 1];
        for (int i = k - 2; i >= 0; i--) {
          if (visited[set[i]]) {
            x = lowestCommonAncestorVis(x, set[i]);
          }
        }
        NodeDfs node = dfs(x);
        result = node.tot;
      }
      bw.write(String.valueOf(result));
      bw.newLine();
    }

    bw.close();
    br.close();
  }
}


Problem solution in C++ programming.

#include <iostream>
#include <cstdio>
#include <vector>
#include <cstring>
#include <utility>
using namespace std;

// 64-bit integer type and a pair of such values.
using LL = long long;
using pii = pair<LL, LL>;

const int MAX_N = 200006;   // size of node-indexed arrays (about 2 * 10^5 nodes)
const int MAX_P = 19;       // number of centroid-tree levels kept in dis[][]
const LL mod = 1000000007;  // answers are reported modulo this prime

// Adjacency lists of the input tree (1-based node labels).
vector<int> edg[MAX_N];
// dis[level][x]: distance from x to its centroid ancestor at that level.
int dis[MAX_P][MAX_N];
// Marks chosen centroids permanently; dfs2() also marks nodes temporarily.
bool visit[MAX_N];

// Centroid-tree record for node c (every node becomes a centroid once).
// Member order matters: dfs() initialises it with an aggregate initialiser.
struct Cen {
    int par;        // parent centroid, -1 for the top-level centroid
    int depth;      // level of c in the centroid tree (row of dis[][])
    pii val_v_av;  // first: sum of x * dist(x, c) over added x; second: minus x * dist(x, par)
    pii val_v;      // first: sum of added x in c's component; second: minus that sum (for par)
} cen[MAX_N];

// Scratch state written by dfs2() for the component currently being split:
vector<int> v;  // nodes of the component, in DFS preorder
int sz[MAX_N];  // subtree size of each node in that DFS
int mx[MAX_N];  // written by dfs2(); read by get_cen() as the largest child-subtree size

// DFS over the current component (nodes not marked in visit[]) starting at id.
// Appends each node to v in preorder and computes, for the DFS tree rooted at
// id, sz[] (subtree size) and mx[] (size of the largest child subtree).
// Nodes stay marked in visit[]; get_cen() clears the marks afterwards.
void dfs2(int id) {
    v.push_back(id);
    visit[id]=1;
    sz[id]=1;
    mx[id]=0;
    for (int i:edg[id]) {
        if (!visit[i]) {
            dfs2(i);
            sz[id] += sz[i];
            // get_cen() tests max(mx, tot - sz) <= tot / 2, so mx must hold the
            // largest child subtree (it was previously left at 0).
            if (sz[i] > mx[id]) mx[id] = sz[i];
        }
    }
}

#define SZ(x) ((int)(x).size())

// Returns a node of id's component that satisfies
// max(mx[node], total - sz[node]) <= total / 2, i.e. the centroid condition
// when mx[] holds the largest child-subtree size. If several qualify, the last
// one in DFS preorder wins. Also clears the visit[] marks left by dfs2().
int get_cen(int id) {
    v.clear();
    dfs2(id);
    const int total = SZ(v);
    const int half = total / 2;
    int chosen = -1;
    for (int node : v) {
        const int heaviest_piece = max(mx[node], total - sz[node]);
        if (heaviest_piece <= half) chosen = node;
        visit[node] = false;
    }
    return chosen;
}

// Writes into row cen_depth of dis[][] the distance from the current centroid
// (dist at id) to every node reachable from id without entering an
// already-chosen centroid; `par` is the node we arrived from.
void dfs3(int id,int par,int cen_depth,int dist)  {
    dis[cen_depth][id] = dist;
    for (int next : edg[id]) {
        if (next == par || visit[next]) continue;
        dfs3(next, id, cen_depth, dist + 1);
    }
}

void dfs(int id,int cen_par,int cen_depth) {
    int ccen=get_cen(id);
    dfs3(ccen,ccen,cen_depth,0);
    cen[ccen]={cen_par,cen_depth,{0,0},{0,0}};
    visit[ccen]=1;
    for (int i:edg[ccen]) {
        if (!visit[i]) dfs(i,ccen,cen_depth+1);
    }
}

// Component-wise arithmetic on pii, used for the (value, correction) sums.
pii operator+(const pii &p1,const pii &p2) {
    return pii(p1.first + p2.first, p1.second + p2.second);
}

pii operator-(const pii &p1,const pii &p2) {
    pii diff = p1;
    diff.first -= p2.first;
    diff.second -= p2.second;
    return diff;
}

// Compound forms update p1 in place and return a copy of the new value.
pii operator+=(pii &p1,const pii &p2) {
    p1.first += p2.first;
    p1.second += p2.second;
    return p1;
}

pii operator-=(pii &p1,const pii &p2) {
    p1.first -= p2.first;
    p1.second -= p2.second;
    return p1;
}

// Normalises both components of p into the range [0, mod).
void Pure(pii &p) {
    auto normalise = [](LL value) { return (value % mod + mod) % mod; };
    p.first = normalise(p.first);
    p.second = normalise(p.second);
}

// Adds node x to the active query set. Walking up the centroid tree from x,
// every centroid ancestor c gains x in val_v.first and x * dist(x, c) in
// val_v_av.first. It also records in .second the negated terms its parent
// centroid must subtract, then reduces everything with Pure().
void addd(LL x) {
    for (LL p = x; p != -1; p = cen[p].par) {
        Cen &c = cen[p];
        c.val_v += {x, 0};
        c.val_v_av += {x * dis[c.depth][x], 0};
        const int par = c.par;
        if (par != -1) {
            c.val_v -= {0, x};
            c.val_v_av -= {0, x * dis[cen[par].depth][x]};
        }
        Pure(c.val_v);
        Pure(c.val_v_av);
    }
}

// Removes node x from the active query set: the exact inverse of addd(),
// visiting the same centroid ancestors and undoing each update.
void dell(LL x) {
    for (LL p = x; p != -1; p = cen[p].par) {
        Cen &c = cen[p];
        c.val_v -= {x, 0};
        c.val_v_av -= {x * dis[c.depth][x], 0};
        const int par = c.par;
        if (par != -1) {
            c.val_v += {0, x};
            c.val_v_av += {0, x * dis[cen[par].depth][x]};
        }
        Pure(c.val_v);
        Pure(c.val_v_av);
    }
}

// Intended to return (x * sum over added nodes y of y * dist(x, y)) mod mod.
// It walks the centroid ancestors p of x. At each p it combines p's sums
// (val_*.first) with the correction terms carried up from the previous,
// lower centroid (val_*.second), which subtract the members already seen on
// x's side.
// All intermediate products are reduced so they fit in 64 bits.
LL query(LL x) {
    LL ret=0;
    LL v=0;
    LL v_av=0;
    int p=x;
    while (p!=-1) {
        v = (v + cen[p].val_v.first) % mod;
        v_av = (v_av + cen[p].val_v_av.first) % mod;
        ret = (ret + x % mod * v_av) % mod;
        // Unreduced, x * dist * v could reach ~4e19 and overflow a signed
        // 64-bit LL, so reduce x * dist first (x, dist <= ~2e5).
        const LL x_dist = x * dis[cen[p].depth][x] % mod;
        ret = (ret + x_dist * v) % mod;
        v = cen[p].val_v.second;
        v_av = cen[p].val_v_av.second;
        p=cen[p].par;
    }
    return ret;
}

// Recursive binary exponentiation: a^n reduced modulo mod for n >= 2.
// Note the base cases return 1 for n == 0 and a itself (unreduced) for n == 1.
LL pow(LL a,LL n,LL mod) {
    if (n == 0) return 1;
    if (n == 1) return a;
    const LL half = pow(a, n / 2, mod);
    const LL squared = half * half % mod;
    return (n & 1) ? squared * a % mod : squared;
}

int main () {
    int n,q;
    scanf("%d %d",&n,&q);
    for (int i=1;n-1>=i;i++) {
        int a,b;
        scanf("%d %d",&a,&b);
        edg[a].push_back(b);
        edg[b].push_back(a);
    }
    dfs(1,-1,0);
    while (q--) {
        int k;
        scanf("%d",&k);
        vector<int> v;
        while (k--) {
            int x;
            scanf("%d",&x);
            v.push_back(x);
        }
        for (int i:v) addd(i);
        LL ans=0;
        for (int i:v) {
            ans += query(i);
            ans%=mod;
        }
        for (int i:v) dell(i);
        printf("%lld\n",(ans*pow(2,mod-2,mod) + mod)%mod);
    }
}


Problem solution in C programming.

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Modulus for every answer (10^9 + 7). */
#define LIMIT 1000000007

/* Singly linked list of tree nodes. NOTE(review): not referenced by any code
   visible in this file. */
typedef struct tree_node_list {
  struct tree_node *node;
  struct tree_node_list *next;
} tree_node_list;

/* One vertex of the input tree. */
typedef struct tree_node {
  struct tree_node *parent; /* NULL for the root */
  uint32_t num;             /* 1-based node label */
  int32_t depth;            /* distance from the root; -1 until computed */
} tree_node;

/* Per-node scratch data for one query; main() zeroes it between queries. */
typedef struct aux_info {
  uint32_t simple_sum; /* sum (mod LIMIT) of marked labels folded in from child subtrees */
  uint32_t level_sum;  /* sum (mod LIMIT) of label * distance-to-this-node over those labels */
  uint32_t marker;     /* non-zero when the node belongs to the current query set */
} aux_info;

/* Debug helper: prints "(num: N, parent: P, depth: D) " for one node, using 0
   as the parent label of the root. The uint32_t/int32_t fields are cast to the
   types the conversions expect; passing them straight to %ld is undefined
   behaviour. */
static void print_node(tree_node *node) {
  printf("(num: %lu, parent: %lu, depth: %ld) ", (unsigned long)node->num,
         (unsigned long)(node->parent != NULL ? node->parent->num : 0),
         (long)node->depth);
}

/* Debug helper: prints the first `count` nodes of `nodes` on one line. The
   index is a size_t so it is compared against `count` without a signed/unsigned
   mismatch. */
void print_tree(tree_node *nodes, size_t count) {
  for (size_t i = 0; i < count; i++) {
    print_node(&nodes[i]);
  }
  printf("\n");
}

static int order_tree(const void *lhs, const void *rhs) {
  tree_node *a = *((tree_node **)lhs);
  tree_node *b = *((tree_node **)rhs);
  return a->depth - b->depth;
}

/* Ensures node->depth is set. The root (parent == NULL) always gets depth 0.
   A node whose depth is still -1 gets its depth from the nearest ancestor
   that is the root or already has a depth; every node in between is filled
   in too. A node that already has a depth is left unchanged.

   The original recursed once per ancestor, which can overflow the call stack
   on path-shaped trees of ~2e5 nodes. This version walks the parent chain
   twice without recursion and produces the same depths. */
static void add_depth(tree_node *node) {
  if (node->parent == NULL) {
    node->depth = 0;
    return;
  }
  if (node->depth != -1) {
    return;
  }
  /* Climb to the nearest ancestor whose depth is known (or the root). */
  int32_t steps = 0;
  tree_node *anchor = node;
  while (anchor->parent != NULL && anchor->depth == -1) {
    anchor = anchor->parent;
    steps++;
  }
  if (anchor->parent == NULL) {
    anchor->depth = 0;
  }
  /* Walk the same chain again, assigning decreasing depths. */
  int32_t depth = anchor->depth + steps;
  for (tree_node *cur = node; cur != anchor; cur = cur->parent) {
    cur->depth = depth--;
  }
}

int main() {
  long num_nodes, num_queries;
  scanf("%ld %ld", &num_nodes, &num_queries);
  tree_node *nodes = calloc(num_nodes, sizeof(tree_node));
  tree_node **order = calloc(num_nodes, sizeof(tree_node *));
  aux_info *info = calloc(num_nodes, sizeof(aux_info));
  for (long i = 0; i < num_nodes; ++i) {
    tree_node *node = &nodes[i];
    node->num = i + 1;
    node->depth = -1;
    order[i] = &nodes[i];
  }
  for (long i = 0; i < num_nodes - 1; i++) {
    long a, b;
    scanf("%ld %ld", &a, &b);
    tree_node *node_a = &nodes[a - 1];
    tree_node *node_b = &nodes[b - 1];
    if (node_b->parent == NULL) {
      node_b->parent = node_a;
    } else if (node_a->parent == NULL) {
      node_a->parent = node_b;
    } else {
      exit(1);
    }
  }
  for (long i = 0; i < num_nodes; ++i) {
    add_depth(&nodes[i]);
  }
  qsort(order, num_nodes, sizeof(tree_node *), order_tree);
  for (long i = 0; i < num_queries; ++i) {
    unsigned long k;
    scanf("%ld", &k);
    for (long j = 0; j < k; j++) {
      long node_num;
      scanf("%ld", &node_num);
      info[node_num - 1].marker = 1;
    }

    uint64_t total = 0;
    for (long j = num_nodes - 1; j >= 0; --j) {
      tree_node node = *order[j];
      uint64_t node_num = node.num;
      uint64_t node_index = node_num - 1;
      aux_info node_info = info[node_index];
      if (node_info.marker == 0 && node.depth == 0) {
        continue;
      }
      uint64_t node_simple_sum = node_info.simple_sum;
      uint64_t node_level_sum = node_info.level_sum;
      if (node_info.marker != 0) {
        // Add all the combintations made with this node and its children
        total = total + node_level_sum * node_num;
        if (total > LIMIT) {
          total = total % LIMIT;
        }
        node_simple_sum += node_num;
      } else if (node_simple_sum == 0) {
        continue;
      }
      // Increment the level
      node_level_sum += node_simple_sum;
      tree_node *parent = node.parent;
      if (parent != NULL) {
        uint64_t parent_index = parent->num - 1;
        aux_info parent_info = info[parent_index];
        uint64_t parent_simple_sum = parent_info.simple_sum;
        uint64_t parent_level_sum = parent_info.level_sum;
        // Add the combinations that this subtree makes with all sibling
        // subtrees processed so far
        total = (total + (parent_simple_sum * node_level_sum) +
                 (parent_level_sum * node_simple_sum));
        if (total > LIMIT) {
          total = total % LIMIT;
        }
        parent_simple_sum = parent_simple_sum + node_simple_sum;
        if (parent_simple_sum > LIMIT) {
          parent_simple_sum = parent_simple_sum % LIMIT;
        }
        parent_level_sum = parent_level_sum + node_level_sum;
        if (parent_level_sum > LIMIT) {
          parent_level_sum = parent_level_sum % LIMIT;
        }
        info[parent_index].simple_sum = parent_simple_sum;
        info[parent_index].level_sum = parent_level_sum;
      }
    }

    memset(info, 0, sizeof(aux_info) * num_nodes);
    long ans = total;
    printf("%ld\n", ans);
  }

  return 0;
}