HackerRank Tree Splitting problem solution

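In this HackerRank Tree Splitting problem, we are given a tree whose vertices are numbered from 1 to n, and we must process m queries. Each query names a vertex v (encoded by XOR with the previous answer): print the number of vertices in the tree that currently contains v (v included), then delete v together with all of its incident edges.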

Problem solution in Java.

```java
import java.io.*;
import java.util.*;

public class Solution {

static long x = 1;
// Xorshift random number generators
static long marsagliaXor32() {
x ^= x << 13;
x ^= x >> 17;
return x ^= x << 5;
}

static class Node {
int size = 1;
long pri = marsagliaXor32();
Node l = null;
Node r = null;
Node p = null;

Node mconcat() {
this.size = size(l) + 1 + size(r);
if (l != null) {
l.p = this;
}
if (r != null) {
r.p = this;
}
return this;
}
}

static int size(Node x) {
return x != null ? x.size : 0;
}

static Node root(Node x) {
while (x.p != null) {
x = x.p;
}
return x;
}

static long orderOf(Node x) {
long r = size(x.l);
while (x.p != null) {
if (x.p.r == x) {
r += size(x.p.l) + 1;
}
x = x.p;
}
return r;
}

static Node join(Node x, Node y) {
if (x == null) return y;
if (y == null) return x;
if (x.pri < y.pri) {
x.r = join(x.r, y);
return x.mconcat();
} else {
y.l = join(x, y.l);
return y.mconcat();
}
}

static long[] dep;
static List<Integer>[] es;
static Node[] pre;
static Node[] post;
static Node tr = null;

static class NodeDfs {
int u;
int p;
boolean start = true;

public NodeDfs(int u, int p) {
this.u = u;
this.p = p;
}
}

static void dfs(int u, int p) {
Deque<NodeDfs> queue = new LinkedList<>();
while (!queue.isEmpty()) {
NodeDfs node = queue.peek();
if (node.start) {
pre[node.u] = new Node();
tr = join(tr, pre[node.u]);
for (int v: es[node.u]) {
if (v != node.p) {
dep[v] = dep[node.u] + 1;
queue.push(new NodeDfs(v, node.u));
}
}
node.start = false;
} else {
post[node.u] = new Node();
tr = join(tr, post[node.u]);
queue.remove();
}
}
}

static Node[] split(Node x, long k, Node l, Node r) {
if (x == null) {
l = r = null;
} else {
long c = size(x.l) + 1;
if (k < c) {
Node[] res = split(x.l, k, l, x.l);
l = res[0];
x.l = res[1];
r = x;
} else {
Node[] res = split(x.r, k - c, x.r, r);
x.r = res[0];
r =  res[1];
l = x;
}
x.mconcat();
x.p = null;
}
return new Node[] {l , r};
}

static void cut(int u, int v) {
if (dep[v] < dep[u]) {
int t = v;
v = u;
u = t;
}
long il = orderOf(pre[v]);
long ir = orderOf(post[v])+1;
Node y = root(pre[v]);
Node z = null;
Node[] res = split(y, ir, y, z);
y = res[0];
z = res[1];
Node x = null;
res = split(y, il, x, y);
x = res[0];
join(x, z);
}

public static void main(String[] args) throws IOException {
BufferedWriter bw = new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")));

StringTokenizer st = new StringTokenizer(br.readLine());
int n = Integer.parseInt(st.nextToken());

dep = new long[n];
es = new List[n];
pre = new Node[n];
post = new Node[n];

for (int i = 0; i < n; i++) {
es[i] = new ArrayList<>();
}
for (int i = 0; i < n - 1; i++) {
st = new StringTokenizer(br.readLine());
int u = Integer.parseInt(st.nextToken())-1;
int v = Integer.parseInt(st.nextToken())-1;
}
dfs(0, -1);

st = new StringTokenizer(br.readLine());
int queriesCount = Integer.parseInt(st.nextToken());

int result = 0;
for (int i = 0; i < queriesCount; i++) {
st = new StringTokenizer(br.readLine());
int u = Integer.parseInt(st.nextToken());
u = (result ^ u) - 1;
result = size(root(pre[u])) / 2;
bw.write(String.valueOf(result));
if (i != queriesCount - 1) {
bw.write("\n");
for (int v: es[u]) {
cut(u, v);
}
}
}

bw.newLine();
bw.close();
br.close();
}
}
```

Problem solution in C++.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Treap node standing for one event (entry or exit) of the Euler tour.
// `size` caches the number of nodes in this subtree, including the node itself;
// `parent` lets a node climb to its root and compute its in-order position.
struct node {
    int size{1};
    node *lch{nullptr};
    node *rch{nullptr};
    node *parent{nullptr};
};

// 32-bit xorshift generator (shift triple 13, 17, 5). It is seeded once from the
// wall clock on first use and returns the new state on every call.
unsigned xor32() {
    static unsigned state = time(NULL);
    const unsigned step1 = state ^ (state << 13);
    const unsigned step2 = step1 ^ (step1 >> 17);
    state = step2 ^ (step2 << 5);
    return state;
}
// Number of nodes in the treap rooted at x; an empty treap has size 0.
int size(node *x) {
    return x ? x->size : 0;
}
// Despite its name, this propagates no lazy tags. It re-derives x's cached size
// from its children, clears x's own parent link and points both children's parent
// links at x. Returns x so calls can be chained.
node *push(node *x) {
    node *const kids[] = {x->lch, x->rch};
    x->size = 1;
    for (node *kid : kids) {
        if (kid != nullptr) {
            x->size += kid->size;
            kid->parent = x;
        }
    }
    x->parent = nullptr;
    return x;
}
// Concatenates the sequences held by treaps x and y (all of x before all of y)
// and returns the new root. The root comes from x with probability roughly
// size(x) / (size(x) + size(y)), which is meant to keep the expected depth
// logarithmic without storing priorities.
node *merge(node *x, node *y) {
    if (x == nullptr || y == nullptr) return x != nullptr ? x : y;

    const unsigned total = size(x) + size(y);
    const unsigned left = size(x);
    if (xor32() % total < left) {
        push(x);
        x->rch = merge(x->rch, y);
        return push(x);
    }
    push(y);
    y->lch = merge(x, y->lch);
    return push(y);
}
// Splits treap x into (first k elements, remaining elements). Every returned
// non-null root has passed through push(), so its parent link is null.
pair<node *, node *> split(node *x, int k) {
    if (x == nullptr) return {nullptr, nullptr};
    push(x);
    const int left_size = size(x->lch);
    if (k <= left_size) {
        // x and its right subtree all land in the second piece.
        pair<node *, node *> parts = split(x->lch, k);
        x->lch = parts.second;
        parts.second = push(x);
        return parts;
    }
    // x and its left subtree all land in the first piece.
    pair<node *, node *> parts = split(x->rch, k - left_size - 1);
    x->rch = parts.first;
    parts.first = push(x);
    return parts;
}
// Climbs parent links from x until reaching the root of its treap.
node *root(node *x) {
    while (x->parent != nullptr) {
        x = x->parent;
    }
    return x;
}
// 0-based in-order position of x within its treap (-1 for a null x).
int index_of(node *x) {
    if (x == nullptr) return -1;
    int pos = size(x->lch);
    for (node *cur = x; cur->parent != nullptr; cur = cur->parent) {
        // Arriving from a right child: the parent and its left subtree come first.
        if (cur->parent->rch == cur) pos += 1 + size(cur->parent->lch);
    }
    return pos;
}

// Capacity of the per-vertex tables below; vertex ids (0-based) must be < MAXV.
const int MAXV = 200200;

vector<int> g[MAXV];   // adjacency lists of the input tree
int depth[MAXV];       // depth of each vertex, filled in by dfs from its start vertex
node *L[MAXV];         // tour node appended when the vertex is entered
node *R[MAXV];         // tour node appended when the vertex is left

node *tr = nullptr;    // treap accumulating the Euler tour while dfs runs

// Appends the Euler tour of the subtree rooted at `curr` (whose parent is `prev`,
// or -1 for the root) to the treap `tr`. L[v] is appended when v is entered and
// R[v] when it is left, so v's subtree occupies the contiguous tour range
// [index_of(L[v]), index_of(R[v])]. It also fills depth[] for every descendant,
// relative to depth[curr].
//
// The recursive form used one call frame per tree level, which can overflow the
// call stack on deep (e.g. path-shaped) trees of ~2*10^5 vertices. This version
// keeps an explicit stack instead. Vertices are entered and left in exactly the
// same order as the recursive version, so the produced tour is identical.
void dfs(int curr, int prev) {
    struct Frame {
        int v;             // vertex being expanded
        int parent;        // vertex we came from (-1 for the start vertex)
        std::size_t next;  // index of the next neighbour of v to examine
    };
    std::vector<Frame> stack;

    L[curr] = new node();
    tr = merge(tr, L[curr]);
    stack.push_back(Frame{curr, prev, 0});

    while (!stack.empty()) {
        Frame &top = stack.back();
        if (top.next < g[top.v].size()) {
            const int here = top.v;
            const int nxt = g[here][top.next++];
            if (nxt == top.parent) continue;
            depth[nxt] = depth[here] + 1;
            L[nxt] = new node();
            tr = merge(tr, L[nxt]);
            // push_back may reallocate and invalidate `top`; it is not used afterwards.
            stack.push_back(Frame{nxt, here, 0});
        } else {
            // Every neighbour has been handled: emit the exit event for this vertex.
            R[top.v] = new node();
            tr = merge(tr, R[top.v]);
            stack.pop_back();
        }
    }
}

void cut(int u, int v) {
if (depth[u] < depth[v]) swap(u, v);

int l = index_of(L[u]);
int r = index_of(R[u]);

node *rt = root(L[u]);
auto x = split(rt, r + 1);
auto y = split(x.first, l);
merge(y.first, x.second);
}

// Reads the tree and the XOR-encoded queries. For each query it prints the size
// of the component containing the decoded vertex, then deletes that vertex by
// cutting every incident edge.
int main() {
    int n = 0;
    scanf("%d", &n);

    for (int edge = 1; edge < n; ++edge) {
        int a, b;
        scanf("%d %d", &a, &b);
        --a;
        --b;
        g[a].push_back(b);
        g[b].push_back(a);
    }

    int m = 0;
    scanf("%d", &m);

    dfs(0, -1);

    int last = 0;
    while (m-- > 0) {
        int encoded;
        scanf("%d", &encoded);
        const int v = (last ^ encoded) - 1;
        // Each vertex contributes two tour nodes (entry and exit).
        last = size(root(L[v])) / 2;
        for (int nb : g[v]) cut(nb, v);
        printf("%d\n", last);
    }
}
```

Problem solution in C.

```c
#include <stdlib.h>
#include <stdio.h>

/* Shared counter for one component: all nodes of a component point at the same
   Set, and count is maintained by compute_below / remove_node. */
typedef struct Set {
    int count;
} Set;

/* A vertex (or an adjacency-list entry) of the tree.
   Children form a doubly-linked list: first_child, then next/prev among siblings. */
typedef struct node {
    int number;               /* vertex label */
    struct node * parent;
    struct node * next;       /* next sibling */
    struct node * prev;       /* previous sibling */
    struct node * first_child;
    struct Set * set;         /* component counter shared by this node's component */
} node;

/* Debug helper: prints the number of every direct child of n, one per line. */
void print_children(node * n){
    node * child;
    for (child = n->first_child; child != 0; child = child->next) {
        printf("%d\n", child->number);
    }
}

/* Pushes c onto the front of n's child list. Only c->next, the old head's prev
   and n->first_child are written; c->prev and c->parent are left untouched, so c
   is expected to be a fresh (zeroed) node. */
void add_child(node * n, node * c){
    node * old_head = n->first_child;
    if (old_head != 0) {
        old_head->prev = c;
        c->next = old_head;
    }
    n->first_child = c;
}

/*
 * Builds the rooted tree below `root` from the adjacency lists in `nodes`.
 * nodes[v]->first_child/next lists one entry per neighbour of v.
 *
 * For each neighbour not yet present in result_nodes, the function:
 *  - allocates a fresh node;
 *  - records it in result_nodes;
 *  - sets its parent;
 *  - links it into the parent's child list with add_child;
 *  - expands it in turn.
 * result_nodes[root->number] must already be set so the walk never returns to
 * the root.
 *
 * An explicit, heap-allocated stack replaces recursion so a deep tree (e.g. a
 * path) cannot overflow the call stack. Allocation failure aborts.
 */
void fill_children(node * root, node ** nodes, node ** result_nodes){
    size_t cap = 64;
    size_t top = 0;
    node ** stack = malloc(cap * sizeof *stack);
    if (stack == 0) {
        abort();
    }
    stack[top++] = root;
    while (top > 0) {
        node * cur = stack[--top];
        node * repr = nodes[cur->number];
        node * adj;
        if (repr == 0) {
            continue;   /* vertex has no recorded neighbours */
        }
        for (adj = repr->first_child; adj != 0; adj = adj->next) {
            node * c;
            if (result_nodes[adj->number] != 0) {
                continue;   /* already placed (for a tree input: cur's parent) */
            }
            c = calloc(1, sizeof(node));
            if (c == 0) {
                abort();
            }
            c->number = adj->number;
            c->parent = cur;
            add_child(cur, c);   /* link c under cur so traversals can reach it */
            result_nodes[c->number] = c;
            if (top == cap) {
                node ** bigger = realloc(stack, 2 * cap * sizeof *stack);
                if (bigger == 0) {
                    abort();
                }
                stack = bigger;
                cap *= 2;
            }
            stack[top++] = c;
        }
    }
    free(stack);
}

/*
 * Points root and every node below it at `set` and increments set->count once
 * per node. Afterwards the count has grown by the subtree size. When `set` is
 * null, a fresh zero-initialised Set is allocated (sized for a Set, not a
 * pointer); allocation failure aborts.
 *
 * The walk is iterative (first_child down, next across, parent up), so deep
 * subtrees cannot overflow the call stack. It assumes each descendant's parent
 * link points at the node whose child list holds it. The climb stops at `root`
 * by identity, so root's own parent/next links are never followed.
 */
void compute_below(node * root, Set * set) {
    node * cur = root;
    if (set == 0) {
        set = calloc(1, sizeof *set);
        if (set == 0) {
            abort();
        }
    }
    for (;;) {
        cur->set = set;
        set->count++;
        if (cur->first_child != 0) {
            cur = cur->first_child;
            continue;
        }
        /* Leaf: climb until a node with an unvisited next sibling, or back to root. */
        while (cur != root && cur->next == 0) {
            cur = cur->parent;
        }
        if (cur == root) {
            break;
        }
        cur = cur->next;
    }
}

/*
 * Deletes `item` from its tree, splitting its component.
 *
 * Every child of `item` becomes a root (child->parent = 0). Components are
 * tracked by shared Set counters:
 *  - if `item` has a parent, the parent's side keeps item->set, and every child
 *    subtree is re-counted into a fresh Set via compute_below(child, 0);
 *  - if `item` is a root, its first child keeps item->set, and only the later
 *    children get fresh Sets.
 * item->set->count is then reduced by 1 (item itself) plus the sizes of the
 * re-counted subtrees.
 *
 * Finally, if `item` has a parent, it is unlinked from the parent's
 * doubly-linked child list. item's own set/next/prev/first_child fields are not
 * modified.
 *
 * The work done is proportional to the number of nodes re-counted by
 * compute_below.
 */
void remove_node(node * item) {
/* When item has a parent, the parent's side keeps item->set, so every child needs a new Set. */
int everyChild = item->parent != 0;
node * child = item->first_child;
int childCount = 0;
/* Starts at 1 to account for item itself. */
int toRemove = 1;
while (child) {
childCount++;
if (everyChild || childCount > 1) {
/* Re-count this subtree into a fresh Set and subtract it from the old one. */
compute_below(child, 0);
toRemove += child->set->count;
}
/* The child becomes the root of its own tree. */
child->parent = 0;
child = child->next;
}
item->set->count -= toRemove;
node * parent = item->parent;
if(parent){
/* Unlink item from the parent's doubly-linked child list. */
if(parent->first_child == item){
parent->first_child = item->next;
}
if(item->next){
item->next->prev = item->prev;
}
if(item->prev){
item->prev->next = item->next;
}
}
}

int main(int argc, char **argv){
int n;
scanf("%d\n", &n);
int i = 0;
node ** nodes = calloc(n+1, sizeof(node *));
for(i = 0; i < n-1; i++){
int a,b;
scanf("%d %d\n", &a, &b);
node * node_a = nodes[a];
if(node_a == 0) {
node_a = calloc(1, sizeof(node));
node_a->number = a;
nodes[a] = node_a;
}
node * x = calloc(1, sizeof(node));
x->number = b;

node * node_b = nodes[b];
if(node_b == 0){
node_b = calloc(1, sizeof(node));
node_b->number = b;
nodes[b] = node_b;
}
x = calloc(1, sizeof(node));
x->number = a;
}

node * root = calloc(1, sizeof(node));
root->number = 1;
node ** result_nodes = calloc(n+1, sizeof(node *));
result_nodes[1] = root;
fill_children(root, nodes, result_nodes);
compute_below(result_nodes[1], 0);
int ans = 0;
int num_queries;
scanf("%d\n", &num_queries);
for(i = 0; i < num_queries; i++){
int m;
scanf("%d\n", &m);
int q = m^ans;
node * n = result_nodes[q];
ans = n->set->count;
printf("%d\n", ans);
remove_node(n);
}
return 0;
}

```

