HackerRank Tree Splitting problem solution

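In this HackerRank Tree Splitting problem, we are given a tree with vertices numbered from 1 to n and must process m queries. Each query names a vertex, encoded by XOR with the previous answer. The answer is the number of vertices in the tree that currently contains that vertex. The vertex is then removed, splitting its tree into smaller trees.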

Problem solution in Java.

import java.io.*;
import java.util.*;

public class Solution {

  static long x = 1;  
  // Xorshift random number generators
  static long marsagliaXor32() {
    x ^= x << 13;
    x ^= x >> 17;
    return x ^= x << 5;

  static class Node {
    int size = 1;
    long pri = marsagliaXor32();
    Node l = null;
    Node r = null;
    Node p = null;

    Node mconcat() {
      this.size = size(l) + 1 + size(r);
      if (l != null) {
        l.p = this;
      if (r != null) {
        r.p = this;
      return this;

  static int size(Node x) {
    return x != null ? x.size : 0;

  static Node root(Node x) {
    while (x.p != null) {
      x = x.p;
    return x;

  static long orderOf(Node x) {
    long r = size(x.l);
    while (x.p != null) {
      if (x.p.r == x) {
        r += size(x.p.l) + 1;
      x = x.p;
    return r;

  static Node join(Node x, Node y) {
    if (x == null) return y;
    if (y == null) return x;
    if (x.pri < y.pri) {
      x.r = join(x.r, y);
      return x.mconcat();
    } else {
      y.l = join(x, y.l);
      return y.mconcat();

  static long[] dep;
  static List<Integer>[] es;
  static Node[] pre;
  static Node[] post;
  static Node tr = null;

  static class NodeDfs {
    int u;
    int p;
    boolean start = true;

    public NodeDfs(int u, int p) {
      this.u = u;
      this.p = p;
  static void dfs(int u, int p) {
    Deque<NodeDfs> queue = new LinkedList<>();
    queue.add(new NodeDfs(u, p));
    while (!queue.isEmpty()) {
      NodeDfs node = queue.peek();
      if (node.start) {
        pre[node.u] = new Node();
        tr = join(tr, pre[node.u]);
        for (int v: es[node.u]) {
          if (v != node.p) {
            dep[v] = dep[node.u] + 1;
            queue.push(new NodeDfs(v, node.u));
        node.start = false;
      } else {
        post[node.u] = new Node();
        tr = join(tr, post[node.u]);

  static Node[] split(Node x, long k, Node l, Node r) {
    if (x == null) {
      l = r = null;
    } else {
      long c = size(x.l) + 1;
      if (k < c) {
        Node[] res = split(x.l, k, l, x.l);
        l = res[0];
        x.l = res[1];
        r = x;
      } else {
        Node[] res = split(x.r, k - c, x.r, r);
        x.r = res[0];
        r =  res[1];
        l = x;
      x.p = null;
    return new Node[] {l , r};
  static void cut(int u, int v) {
    if (dep[v] < dep[u]) {
      int t = v;
      v = u;
      u = t;
    long il = orderOf(pre[v]);
    long ir = orderOf(post[v])+1;
    Node y = root(pre[v]);
    Node z = null;
    Node[] res = split(y, ir, y, z);
    y = res[0];
    z = res[1];
    Node x = null;
    res = split(y, il, x, y);
    x = res[0];
    join(x, z);

  public static void main(String[] args) throws IOException {
    BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    BufferedWriter bw = new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")));

    StringTokenizer st = new StringTokenizer(br.readLine());
    int n = Integer.parseInt(st.nextToken());

    dep = new long[n];
    es = new List[n];
    pre = new Node[n];
    post = new Node[n];

    for (int i = 0; i < n; i++) {
      es[i] = new ArrayList<>();
    for (int i = 0; i < n - 1; i++) {
      st = new StringTokenizer(br.readLine());
      int u = Integer.parseInt(st.nextToken())-1;
      int v = Integer.parseInt(st.nextToken())-1;
    dfs(0, -1);

    st = new StringTokenizer(br.readLine());
    int queriesCount = Integer.parseInt(st.nextToken());

    int result = 0;
    for (int i = 0; i < queriesCount; i++) {
      st = new StringTokenizer(br.readLine());
      int u = Integer.parseInt(st.nextToken());
      u = (result ^ u) - 1;
      result = size(root(pre[u])) / 2;
      if (i != queriesCount - 1) {
        for (int v: es[u]) {
          cut(u, v);



Problem solution in C++.

#include <bits/stdc++.h>
using namespace std;

// Node of a randomized binary tree whose key is implicitly its in-order position.
// dfs allocates two per vertex (L[v] on entry, R[v] on exit of the Euler tour).
struct node {
	int size = 1;              // number of nodes in this subtree
	node *lch = nullptr;       // left child
	node *rch = nullptr;       // right child
	node *parent = nullptr;    // nullptr at a root
};

// Marsaglia xorshift32 step (shifts 13, 17, 5). Seeded from the wall clock, so runs differ.
// The seed is forced odd: a zero state would make the generator return 0 forever.
unsigned xor32() {
	static unsigned z = static_cast<unsigned>(time(nullptr)) | 1u;
	z ^= z << 13;
	z ^= z >> 17;
	z ^= z << 5;
	return z;
}
// Subtree size of x; an empty tree (nullptr) has size 0.
int size(node *x) {
	return x == nullptr ? 0 : x->size;
}
// Recomputes x's size, detaches x from its old parent and re-points its children at x.
// Callers re-attach x (via its new parent's push) when it is not a root.
node *push(node *x) {
	x->size = 1 + size(x->lch) + size(x->rch);
	x->parent = nullptr;
	if (x->lch != nullptr) x->lch->parent = x;
	if (x->rch != nullptr) x->rch->parent = x;
	return x;
}
// Concatenates sequence x followed by sequence y and returns the new root.
// The root is taken from x with probability size(x) / (size(x) + size(y)).
node *merge(node *x, node *y) {
	if (x == nullptr) return y;
	if (y == nullptr) return x;
	const unsigned total = static_cast<unsigned>(size(x) + size(y));
	if (xor32() % total < static_cast<unsigned>(size(x))) {
		x = push(x);
		x->rch = merge(x->rch, y);
		return push(x);
	}
	y = push(y);
	y->lch = merge(x, y->lch);
	return push(y);
}
// Splits sequence x into its first k nodes and the rest; both returned roots have parent nullptr.
pair<node *, node *> split(node *x, int k) {
	if (x == nullptr) return { nullptr, nullptr };
	x = push(x);
	if (size(x->lch) >= k) {
		// x and its right subtree all belong to the right-hand part.
		auto p = split(x->lch, k);
		x->lch = p.second;
		return { p.first, push(x) };
	}
	auto p = split(x->rch, k - size(x->lch) - 1);
	x->rch = p.first;
	return { push(x), p.second };
}
// Walks parent pointers up to the root of x's tree (iteratively, so no recursion depth limit).
node *root(node *x) {
	while (x->parent != nullptr) {
		x = x->parent;
	}
	return x;
}
// Returns the 0-based in-order position of x within its tree.
int index_of(node *x) {
	int result = -1;
	bool l = true;  // true when the node we stand on lies at or before x's position
	while (x != nullptr) {
		if (l) result += 1 + size(x->lch);
		if (x->parent == nullptr) break;
		// Arriving at the parent from its right child: the parent and its left subtree precede x.
		l = x->parent->rch == x;
		x = x->parent;
	}
	return result;
}

// Adjacency lists of the input tree (0-based vertices, iterated by dfs and main).
// NOTE(review): 200200 is presumably an upper bound on n from the problem constraints — confirm.
vector<int> g[200200];
// Depth of each vertex in the tree rooted at vertex 0; filled in by dfs and compared in cut.
int depth[200200];
// L[v] / R[v]: sequence nodes created when dfs enters / leaves vertex v.
node *L[200200];
node *R[200200];

// The whole Euler-tour sequence while dfs is building it.
node *tr = nullptr;

// Appends the Euler tour of the subtree rooted at curr to tr and records depths.
// On entry to a vertex it creates L[v]; on exit it creates R[v].
// Uses an explicit stack because a path-shaped tree would overflow the call stack when recursing.
// The merge order is identical to the recursive version.
void dfs(int curr, int prev) {
	struct Frame { int v; int p; size_t next; };
	vector<Frame> stack;
	L[curr] = new node();
	tr = merge(tr, L[curr]);
	stack.push_back({curr, prev, 0});
	while (!stack.empty()) {
		Frame &f = stack.back();
		if (f.next < g[f.v].size()) {
			int nxt = g[f.v][f.next++];
			if (nxt == f.p) continue;
			depth[nxt] = depth[f.v] + 1;
			L[nxt] = new node();
			tr = merge(tr, L[nxt]);
			stack.push_back({nxt, f.v, 0});  // may invalidate f; f is not used again this iteration
		} else {
			R[f.v] = new node();
			tr = merge(tr, R[f.v]);
			stack.pop_back();
		}
	}
}

// Removes edge (u, v) by cutting the child endpoint's Euler-tour interval into its own tree.
void cut(int u, int v) {
	// Make u the deeper endpoint, i.e. the child side of the edge.
	if (depth[u] < depth[v]) swap(u, v);

	int l = index_of(L[u]);
	int r = index_of(R[u]);

	node *rt = root(L[u]);
	auto x = split(rt, r + 1);   // x.first ends with R[u]
	auto y = split(x.first, l);  // y.second is exactly u's interval [L[u], R[u]]
	merge(y.first, x.second);    // rejoin what lies before and after that interval
}

// Reads the tree and the XOR-encoded queries.
// For each query it prints the component size of the named vertex, then cuts all of its edges.
int main() {
	int n;
	if (scanf("%d", &n) != 1) return 0;

	for (int i = 0; i < n - 1; i++) {
		int u, v;
		scanf("%d %d", &u, &v);
		u--; v--;
		g[u].push_back(v);
		g[v].push_back(u);
	}

	int m;
	scanf("%d", &m);

	dfs(0, -1);

	int ans = 0;
	for (int i = 0; i < m; i++) {
		int x;
		scanf("%d", &x);
		int v = (ans ^ x) - 1;
		// Two sequence nodes per vertex, so half the tree size is the vertex count.
		ans = size(root(L[v])) / 2;
		for (int u : g[v]) cut(u, v);
		printf("%d\n", ans);
	}
	return 0;
}


Problem solution in C.

#include <stdlib.h>
#include <stdio.h>

/* Shared size counter of one connected component; compute_below points every node of a subtree at the same Set. */
struct Set {
    int count;
};

typedef struct Set Set;

/*
 * Tree node. Children form a doubly linked list via next/prev, starting at first_child.
 * set points at the counter shared by the node's current component.
 */
struct node{
    int number;
    struct node * parent;
    struct node * next;
    struct node * prev;
    struct node * first_child;
    Set * set;
};

typedef struct node node;

/* Debug helper: prints the number of every direct child of n, one per line. */
void print_children(node * n){
    node * child = n->first_child;
    while (child) {
        printf("%d\n", child->number);
        child = child->next;
    }
}

/* Prepends c to n's child list. */
void add_child(node * n, node * c){
    node * cur = n->first_child;
    n->first_child = c;
    c->prev = 0;
    c->next = cur;
    if (cur){
        cur->prev = c;
    }
}

/*
 * Builds the rooted copy of the tree below root.
 * nodes[v] is the adjacency representative of v: its child list has one entry per neighbour.
 * result_nodes[v] is v's rooted copy and is non-zero once v is placed; that is how the already
 * visited parent is skipped.
 * NOTE(review): recursive, so a very deep (path-shaped) tree may exhaust the call stack.
 */
void fill_children(node * root, node ** nodes, node ** result_nodes){
    node * repr = nodes[root->number];
    if (repr == 0) {
        return;
    }
    node * child = repr->first_child;
    while (child) {
        if (result_nodes[child->number] != 0) {
            child = child->next;
            continue;
        }
        node * c = calloc(1, sizeof(node));
        c->number = child->number;
        c->parent = root;
        result_nodes[c->number] = c;
        add_child(root, c);
        fill_children(c, nodes, result_nodes);
        child = child->next;
    }
}

/*
 * Points root and all of its descendants at set and counts them into set->count.
 * When set is NULL, a fresh zeroed Set is allocated first.
 */
void compute_below(node * root, Set * set) {
    if (set == 0) {
        /* sizeof *set is the pointee size; the original sizeof(set) was the pointer size. */
        set = calloc(1, sizeof *set);
    }
    root->set = set;
    set->count++;
    node * child = root->first_child;
    while (child) {
        compute_below(child, set);
        child = child->next;
    }
}

/*
 * Deletes item from its component and turns each child subtree into a separate component.
 * - If item has a parent, every child subtree gets a fresh Set.
 * - If item is a root, its first child keeps item's Set; only later children get fresh ones.
 * item's old Set is then reduced by the item itself and every subtree moved to a new Set.
 */
void remove_node(node * item) {
    int everyChild = item->parent != 0;
    node * child = item->first_child;
    int childCount = 0;
    int toRemove = 1;
    while (child) {
        childCount++;
        if (everyChild || childCount > 1) {
            compute_below(child, 0);
            toRemove += child->set->count;
        }
        child->parent = 0;
        child = child->next;
    }
    item->set->count -= toRemove;
    node * parent = item->parent;
    if (parent) {
        /* Unlink item from its parent's child list; neighbours may be NULL at either end. */
        if (parent->first_child == item) {
            parent->first_child = item->next;
        }
        if (item->next) {
            item->next->prev = item->prev;
        }
        if (item->prev) {
            item->prev->next = item->next;
        }
    }
}

int main(int argc, char **argv){
    int n;
    scanf("%d\n", &n);
    int i = 0;
    node ** nodes = calloc(n+1, sizeof(node *));
    for(i = 0; i < n-1; i++){
        int a,b;
        scanf("%d %d\n", &a, &b);
        node * node_a = nodes[a];
        if(node_a == 0) {
            node_a = calloc(1, sizeof(node));
            node_a->number = a;
            nodes[a] = node_a;
        node * x = calloc(1, sizeof(node));
        x->number = b;
        node * node_b = nodes[b];
        if(node_b == 0){
            node_b = calloc(1, sizeof(node));
            node_b->number = b;
            nodes[b] = node_b;
        x = calloc(1, sizeof(node));
        x->number = a;
        add_child(node_b, x);
    node * root = calloc(1, sizeof(node));
    root->number = 1;
    node ** result_nodes = calloc(n+1, sizeof(node *));
    result_nodes[1] = root;
    fill_children(root, nodes, result_nodes);
    compute_below(result_nodes[1], 0);
    int ans = 0;
    int num_queries;
    scanf("%d\n", &num_queries);
    for(i = 0; i < num_queries; i++){
        int m;
        scanf("%d\n", &m);
        int q = m^ans;
        node * n = result_nodes[q];
        ans = n->set->count;
        printf("%d\n", ans);
    return 0;


