# HackerRank Self-Driving Bus problem solution

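In the HackerRank Self-Driving Bus problem, you are given a tree whose n nodes are labeled 1 to n, and you need to count the segments [l, r] of consecutive labels whose nodes form a connected subtree.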

## Problem solution in Java programming.

```import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;
import java.util.Random;

/**
 * HackerRank "Self-Driving Bus": counts the segments {@code [l, r]} of consecutive node labels
 * whose nodes induce a connected subgraph of the input tree.
 *
 * <p>Input: {@code n}, then {@code n - 1} edges given as pairs of 1-based labels.
 * Output: the number of connected segments.
 */
public class G {
    InputStream is;
    PrintWriter out;

    /** Inline input for local testing; when empty, {@link #run()} reads {@code System.in}. */
    String INPUT = "";

    /**
     * Reads the tree and prints the number of connected label segments.
     *
     * <p>For every node {@code v} (0-based, tree rooted at label 0) it first computes
     * {@code [left[v], right[v]]}: the longest run of consecutive labels around {@code v} that
     * lies entirely inside {@code v}'s subtree. Each node then yields a triple that
     * {@link #go} counts.
     *
     * <p>{@code go} scans the left part of every triple one label at a time. The triples are
     * therefore oriented so that the left part is the shorter side. Triples that are longer
     * on the left are mirrored ({@code x -> n - 1 - x}) and counted in a second pass.
     */
    void solve()
    {
        int n = ni();
        int[] from = new int[n - 1];
        int[] to = new int[n - 1];
        for (int i = 0; i < n - 1; i++) {
            from[i] = ni() - 1;
            to[i] = ni() - 1;
        }
        int[][] g = packU(n, from, to);
        int[][] pars = parents3(g, 0);
        int[] par = pars[0], ord = pars[1];

        // nodes[v] is a treap of the sorted labels merged so far from v's descendants.
        // Children are merged small-to-large: the smaller set is drained into the larger one.
        Node[] nodes = new Node[n];
        int[] left = new int[n];
        int[] right = new int[n];
        // Reverse BFS order guarantees every child is finished before its parent.
        for(int i = n-1;i >= 0;i--){
            int cur = ord[i];
            // Number of descendant labels smaller than cur (cur itself is not in the set yet).
            int curind = -search(nodes[cur], cur)-1;
            assert curind >= 0;

            // Smallest rank h whose labels at ranks h..curind-1 are exactly the
            // consecutive labels ending at cur-1.
            {
                int low = -1, high = curind;
                while(high - low > 1){
                    int h = high+low>>>1;
                    if(cur-get(nodes[cur], h).v == curind-h){
                        high = h;
                    }else{
                        low = h;
                    }
                }
                left[cur] = high + cur - curind;
            }
            // Largest rank h whose labels at ranks curind..h are exactly the
            // consecutive labels starting at cur+1.
            {
                int low = curind-1, high = count(nodes[cur]);
                while(high - low > 1){
                    int h = high+low>>>1;
                    if(get(nodes[cur], h).v-cur == h-(curind-1)){
                        low = h;
                    }else{
                        high = h;
                    }
                }
                right[cur] = low + cur - curind + 1;
            }

            nodes[cur] = insertb(nodes[cur], new Node(cur));
            if(par[cur] != -1){
                if(count(nodes[cur]) > count(nodes[par[cur]])){
                    Node d = nodes[cur]; nodes[cur] = nodes[par[cur]]; nodes[par[cur]] = d;
                }
                // Drain the (smaller) set into the parent's set.
                while(nodes[cur] != null){
                    Node first = get(nodes[cur], 0);
                    nodes[cur] = erase(nodes[cur], 0);
                    nodes[par[cur]] = insertb(nodes[par[cur]], first);
                }
            }
        }

        int[][] rs = new int[n][];
        int[][] rs2 = new int[n][];
        int q = 0, q2 = 0;
        for(int i = 0;i < n;i++){
            if(right[i]-i >= i-left[i]){
                rs[q++] = new int[]{left[i], i, right[i]};
            }else{
                rs2[q2++] = new int[]{right[i], i, left[i]};
            }
        }
        long ret = 0;
        ret += go(Arrays.copyOf(rs, q), n, par);

        // Mirror every label x -> n-1-x for the second pass; the root's parent (-1) becomes n.
        for(int i = 0;i < q2;i++){
            rs2[i][0] = n-1-rs2[i][0];
            rs2[i][1] = n-1-rs2[i][1];
            rs2[i][2] = n-1-rs2[i][2];
        }
        for(int i = 0;i < n;i++)par[i] = n-1-par[i];
        for(int i = 0, j = n-1;i < j;i++,j--){
            int d = par[i]; par[i] = par[j]; par[j] = d;
        }
        // Mirroring reversed the order of the centers; go() needs them increasing again.
        for(int i = 0, j = q2-1;i < j;i++,j--){
            int[] d = rs2[i]; rs2[i] = rs2[j]; rs2[j] = d;
        }
        ret += go(Arrays.copyOf(rs2, q2), n, par);
        out.println(ret);
    }

    /**
     * Counts segments for the given {@code {li, i, ri}} triples.
     *
     * <p>NOTE(review): each triple appears to describe the candidate segments
     * {@code [j, r]} with {@code li <= j <= i <= r <= ri} whose topmost node is {@code i}.
     * A segment is counted only when no member's parent lies outside it.
     *
     * <p>Triples must be ordered by increasing center {@code i}. They are consumed from last
     * to first while {@code pre} only sweeps leftwards.
     *
     * @param rs  triples {li, i, ri}, ordered by increasing i
     * @param n   number of nodes
     * @param par 0-based parent labels; the root's entry is outside [0, n)
     * @return the number of counted segments
     */
    long go(int[][] rs, int n, int[] par){
        int m = rs.length;
        SegmentTreeRMQ stmin = new SegmentTreeRMQ(par);
        // Stack of labels > the current center, descending from bottom to top. An entry is
        // popped when a smaller label with a parent >= its parent is pushed.
        // has[k] is the running sum of max(0, rr - ll) over the first k stack entries.
        int[] stack = new int[n];
        long[] has = new long[n+1];
        long[] lhas = new long[n+1];
        int sp = 0;
        int pre = n-1;
        // Per-center stack over labels i, i-1, ..., j. lvals[k] is that entry's pj
        // (the label itself for the center, its parent otherwise); lhas mirrors has.
        int[] lstack = new int[n];
        int[] lvals = new int[n];
        Arrays.fill(stack, -1);
        Arrays.fill(lstack, -1);
        long ret = 0;
        for(int z = m-1;z >= 0;z--){
            int i = rs[z][1];
            int li = rs[z][0];
            int ri = rs[z][2];
            while(pre > i){
                while(sp > 0 && par[pre] >= par[stack[sp-1]])sp--;
                int ll = Math.max(pre, par[pre]);
                int rr = sp >= 1 ? stack[sp-1] : n;
                has[sp+1] = Math.max(0, rr-ll) + has[sp];
                stack[sp++] = pre;
                pre--;
            }

            int lsp = 0;
            // Scratch height of 'stack': it is popped per j below without destroying 'stack',
            // which later (smaller) centers still need.
            int tsp = sp;

            int lmin = i;
            for(int j = i;j >= li;j--){
                int pj = j == i ? j : par[j];
                while(lsp > 0 && pj >= lvals[lsp-1])lsp--;
                if(lsp == 0){
                    while(tsp > 0 && pj >= par[stack[tsp-1]])tsp--;

                    lvals[lsp] = pj;
                    int ll = Math.max(pre, pj);
                    int rr = tsp >= 1 ? stack[tsp-1] : n;
                    lhas[lsp+1] = Math.max(0, rr-ll) + lhas[lsp];
                    lstack[lsp++] = j;
                }else{
                    int ll = Math.max(pre, pj);
                    int rr = lsp >= 1 ? lstack[lsp-1] : tsp >= 1 ? stack[tsp-1] : n;
                    lvals[lsp] = pj;
                    lhas[lsp+1] = Math.max(0, rr-ll) + lhas[lsp];
                    lstack[lsp++] = j;
                }

                // lmin is the minimum pj over j..i; when it is >= j, no node in j..i
                // has its parent to the left of j.
                lmin = Math.min(lmin, pj);

                if(lmin >= j){
                    // fl: first label > i whose parent is < j; right endpoints stay below it.
                    int fl = stmin.firstle(i+1, j-1);
                    if(fl == -1){
                        fl = ri+1;
                    }
                    int lright = Math.min(ri, fl-1);
                    if(tsp-1 >= 0 && lright >= stack[tsp-1]){
                        int ub = upperBoundR(stack, 0, tsp, lright);
                        int ll = Math.max(stack[ub], par[stack[ub]]);
                        int rr = lright+1;
                        long valid = lhas[lsp]+has[tsp]-has[ub+1] + Math.max(0, rr-ll);
                        ret += valid;
                    }else{
                        int ub = upperBoundR(lstack, 0, lsp, lright);
                        int ll = Math.max(lstack[ub], lvals[ub]);
                        int rr = lright+1;
                        long valid = lhas[lsp]-lhas[ub+1] + Math.max(0, rr-ll);
                        ret += valid;
                    }
                }
            }
        }
        return ret;
    }

    /**
     * Binary search on a[l..r), assumed non-increasing.
     *
     * @return the first index h in [l, r) with {@code a[h] <= v}, or {@code r} if none
     */
    public static int upperBoundR(int[] a, int l, int r, int v)
    {
        int low = l-1, high = r;
        while(high-low > 1){
            int h = high+low>>>1;
            if(a[h] <= v){
                high = h;
            }else{
                low = h;
            }
        }
        return high;
    }

    /**
     * Range-minimum segment tree over int values.
     *
     * <p>Leaves live at {@code st[H + i]}; unused leaves hold {@code Integer.MAX_VALUE}.
     */
    public static class SegmentTreeRMQ {
        public int M, H, N;
        public int[] st;

        /** Creates a tree over n positions, all initialised to {@code Integer.MAX_VALUE}. */
        public SegmentTreeRMQ(int n)
        {
            N = n;
            M = Integer.highestOneBit(Math.max(N-1, 1))<<2;
            H = M>>>1;
            st = new int[M];
            Arrays.fill(st, 0, M, Integer.MAX_VALUE);
        }

        /** Creates a tree holding a copy of {@code a}. */
        public SegmentTreeRMQ(int[] a)
        {
            N = a.length;
            M = Integer.highestOneBit(Math.max(N-1, 1))<<2;
            H = M>>>1;
            st = new int[M];
            for(int i = 0;i < N;i++){
                st[H+i] = a[i];
            }
            Arrays.fill(st, H+N, M, Integer.MAX_VALUE);
            for(int i = H-1;i >= 1;i--)propagate(i);
        }

        /** Sets position {@code pos} to {@code x} and refreshes its ancestors. */
        public void update(int pos, int x)
        {
            st[H+pos] = x;
            for(int i = (H+pos)>>>1;i >= 1;i >>>= 1)propagate(i);
        }

        private void propagate(int i)
        {
            st[i] = Math.min(st[2*i], st[2*i+1]);
        }

        /** Minimum over [l, r), computed bottom-up; returns 0 when the range is empty. */
        public int minx(int l, int r){
            if(l >= r)return 0;
            int min = Integer.MAX_VALUE;
            while(l != 0){
                int f = l&-l;
                if(l+f > r)break;
                int v = st[(H+l)/f];
                if(v < min)min = v;
                l += f;
            }

            while(l < r){
                int f = r&-r;
                int v = st[(H+r)/f-1];
                if(v < min)min = v;
                r -= f;
            }
            return min;
        }

        /** Minimum over [l, r), computed top-down; returns 0 when the range is empty. */
        public int min(int l, int r){ return l >= r ? 0 : min(l, r, 0, H, 1);}

        private int min(int l, int r, int cl, int cr, int cur)
        {
            if(l <= cl && cr <= r){
                return st[cur];
            }else{
                int mid = cl+cr>>>1;
                int ret = Integer.MAX_VALUE;
                if(cl < r && l < mid){
                    ret = Math.min(ret, min(l, r, cl, mid, 2*cur));
                }
                if(mid < r && l < cr){
                    ret = Math.min(ret, min(l, r, mid, cr, 2*cur+1));
                }
                return ret;
            }
        }

        /** Smallest index >= l whose value is <= v, or -1 if there is none (or l >= N). */
        public int firstle(int l, int v) {
            if(l >= N)return -1;
            int cur = H+l;
            while(true){
                if(st[cur] <= v){
                    if(cur < H){
                        cur = 2*cur;
                    }else{
                        return cur-H;
                    }
                }else{
                    cur++;
                    if((cur&cur-1) == 0)return -1;
                    if((cur&1)==0)cur>>>=1;
                }
            }
        }

        /** Largest index <= l whose value is <= v, or -1 if there is none. */
        public int lastle(int l, int v) {
            int cur = H+l;
            while(true){
                if(st[cur] <= v){
                    if(cur < H){
                        cur = 2*cur+1;
                    }else{
                        return cur-H;
                    }
                }else{
                    if((cur&cur-1) == 0)return -1;
                    cur--;
                    if((cur&1)==1)cur>>>=1;
                }
            }
        }
    }

    /** Shared, fixed-seed source of treap priorities (deterministic runs). */
    public static Random gen = new Random(0);

    /** Treap node keyed by {@code v}; {@code count} is the size of the subtree rooted here. */
    static public class Node
    {
        public int v; // value
        public long priority;
        public Node left, right, parent;

        public int count;

        public Node(int v)
        {
            this.v = v;
            priority = gen.nextLong();
            update(this);
        }

        @Override
        public String toString() {
            StringBuilder builder = new StringBuilder();
            builder.append("Node [v=");
            builder.append(v);
            builder.append(", count=");
            builder.append(count);
            builder.append(", parent=");
            builder.append(parent != null ? parent.v : "null");
            builder.append("]");
            return builder.toString();
        }
    }

    /** Recomputes {@code a.count} from its children; returns {@code a} (null-safe). */
    public static Node update(Node a)
    {
        if(a == null)return null;
        a.count = 1;
        if(a.left != null)a.count += a.left.count;
        if(a.right != null)a.count += a.right.count;
        return a;
    }

    /** Recomputes counts from {@code x} up to the root. */
    public static void propagate(Node x)
    {
        for(;x != null;x = x.parent)update(x);
    }

    /** Detaches {@code a} from its children and parent, leaving a single-node treap. */
    public static Node disconnect(Node a)
    {
        if(a == null)return null;
        a.left = a.right = a.parent = null;
        return update(a);
    }

    /** Root of the treap containing {@code x}. */
    public static Node root(Node x)
    {
        if(x == null)return null;
        while(x.parent != null)x = x.parent;
        return x;
    }

    /** Size of the treap rooted at {@code a} (0 for null). */
    public static int count(Node a)
    {
        return a == null ? 0 : a.count;
    }

    public static void setParent(Node a, Node par)
    {
        if(a != null)a.parent = par;
    }

    /** Concatenates the treaps a, b, c... in order. */
    public static Node merge(Node a, Node b, Node... c)
    {
        Node x = merge(a, b);
        for(Node n : c)x = merge(x, n);
        return x;
    }

    /** Concatenates two treaps (all of a before all of b); higher priority becomes root. */
    public static Node merge(Node a, Node b)
    {
        if(b == null)return a;
        if(a == null)return b;
        if(a.priority > b.priority){
            setParent(a.right, null);
            setParent(b, null);
            a.right = merge(a.right, b);
            setParent(a.right, a);
            return update(a);
        }else{
            setParent(a, null);
            setParent(b.left, null);
            b.left = merge(a, b.left);
            setParent(b.left, b);
            return update(b);
        }
    }

    /** Splits the treap containing {@code x} into [elements before x, x and after]. */
    public static Node[] split(Node x)
    {
        if(x == null)return new Node[]{null, null};
        if(x.left != null)x.left.parent = null;
        Node[] sp = new Node[]{x.left, x};
        x.left = null;
        update(x);
        while(x.parent != null){
            Node p = x.parent;
            x.parent = null;
            if(x == p.left){
                p.left = sp[1];
                if(sp[1] != null)sp[1].parent = p;
                sp[1] = p;
            }else{
                p.right = sp[0];
                if(sp[0] != null)sp[0].parent = p;
                sp[0] = p;
            }
            update(p);
            x = p;
        }
        return sp;
    }

    /** Splits at each position in the non-decreasing list {@code ks}. */
    public static Node[] split(Node a, int... ks)
    {
        int n = ks.length;
        if(n == 0)return new Node[]{a};
        for(int i = 0;i < n-1;i++){
            if(ks[i] > ks[i+1])throw new IllegalArgumentException(Arrays.toString(ks));
        }

        Node[] ns = new Node[n+1];
        Node cur = a;
        for(int i = n-1;i >= 0;i--){
            Node[] sp = split(cur, ks[i]);
            cur = sp[0];
            ns[i] = sp[0];
            ns[i+1] = sp[1];
        }
        return ns;
    }

    /** Splits by position into ranks [0,K) and [K,N). */
    public static Node[] split(Node a, int K)
    {
        if(a == null)return new Node[]{null, null};
        if(K <= count(a.left)){
            setParent(a.left, null);
            Node[] s = split(a.left, K);
            a.left = s[1];
            setParent(a.left, a);
            s[1] = update(a);
            return s;
        }else{
            setParent(a.right, null);
            Node[] s = split(a.right, K-count(a.left)-1);
            a.right = s[0];
            setParent(a.right, a);
            s[0] = update(a);
            return s;
        }
    }

    /** Inserts {@code x} at its sorted position (by {@code v}). */
    public static Node insertb(Node root, Node x)
    {
        int ind = search(root, x.v);
        if(ind < 0)ind = -ind-1;
        return insert(root, ind, x);
    }

    /** Inserts node {@code b} so that it ends up at rank {@code K}. */
    public static Node insert(Node a, int K, Node b)
    {
        if(a == null)return b;
        if(b.priority < a.priority){
            if(K <= count(a.left)){
                a.left = insert(a.left, K, b);
                setParent(a.left, a);
            }else{
                a.right = insert(a.right, K-count(a.left)-1, b);
                setParent(a.right, a);
            }
            return update(a);
        }else{
            Node[] ch = split(a, K);
            b.left = ch[0]; b.right = ch[1];
            setParent(b.left, b);
            setParent(b.right, b);
            return update(b);
        }
    }

    /** Deletes the K-th element (0-based); the removed node is disconnected. */
    public static Node erase(Node a, int K)
    {
        if(a == null)return null;
        if(K < count(a.left)){
            a.left = erase(a.left, K);
            setParent(a.left, a);
            return update(a);
        }else if(K == count(a.left)){
            setParent(a.left, null);
            setParent(a.right, null);
            Node aa = merge(a.left, a.right);
            disconnect(a);
            return aa;
        }else{
            a.right = erase(a.right, K-count(a.left)-1);
            setParent(a.right, a);
            return update(a);
        }
    }

    /** K-th element (0-based), or null when K is out of range. */
    public static Node get(Node a, int K)
    {
        while(a != null){
            if(K < count(a.left)){
                a = a.left;
            }else if(K == count(a.left)){
                break;
            }else{
                K = K - count(a.left)-1;
                a = a.right;
            }
        }
        return a;
    }

    /** Rank of node {@code a} within its treap, or -1 for null. */
    public static int index(Node a)
    {
        if(a == null)return -1;
        int ind = count(a.left);
        while(a != null){
            Node par = a.parent;
            if(par != null && par.right == a){
                ind += count(par.left) + 1;
            }
            a = par;
        }
        return ind;
    }

    /**
     * Rank of key {@code q} if present; otherwise {@code -(insertion point) - 1},
     * as {@link Arrays#binarySearch(int[], int)} reports it.
     */
    public static int search(Node a, int q)
    {
        int lcount = 0;
        while(a != null){
            if(a.v == q){
                lcount += count(a.left);
                break;
            }
            if(q < a.v){
                a = a.left;
            }else{
                lcount += count(a.left) + 1;
                a = a.right;
            }
        }
        return a == null ? -(lcount+1) : lcount;
    }

    /** In-order successor of {@code x}, or null. */
    public static Node next(Node x)
    {
        if(x == null)return null;
        if(x.right != null){
            x = x.right;
            while(x.left != null)x = x.left;
            return x;
        }else{
            while(true){
                Node p = x.parent;
                if(p == null)return null;
                if(p.left == x)return p;
                x = p;
            }
        }
    }

    /** In-order predecessor of {@code x}, or null. */
    public static Node prev(Node x)
    {
        if(x == null)return null;
        if(x.left != null){
            x = x.left;
            while(x.right != null)x = x.right;
            return x;
        }else{
            while(true){
                Node p = x.parent;
                if(p == null)return null;
                if(p.right == x)return p;
                x = p;
            }
        }
    }

    /** All nodes in order; an empty array for an empty treap. */
    public static Node[] nodes(Node a) { return a == null ? new Node[0] : nodes(a, new Node[a.count], 0, a.count); }
    public static Node[] nodes(Node a, Node[] ns, int L, int R)
    {
        if(a == null)return ns;
        nodes(a.left, ns, L, L+count(a.left));
        ns[L+count(a.left)] = a;
        nodes(a.right, ns, R-count(a.right), R);
        return ns;
    }

    /** Indented in-order dump, for debugging. */
    public static String toString(Node a, String indent)
    {
        if(a == null)return "";
        StringBuilder sb = new StringBuilder();
        sb.append(toString(a.left, indent + "  "));
        sb.append(indent).append(a).append("\n");
        sb.append(toString(a.right, indent + "  "));
        return sb.toString();
    }

    /**
     * BFS from {@code root} over a tree.
     *
     * @return {@code {par, order, depth}}; {@code par[root] == -1} and {@code order} is the
     *         BFS visiting order
     */
    public static int[][] parents3(int[][] g, int root) {
        int n = g.length;
        int[] par = new int[n];
        Arrays.fill(par, -1);

        int[] depth = new int[n];
        depth[root] = 0;

        int[] q = new int[n];
        q[0] = root;
        for (int p = 0, r = 1; p < r; p++) {
            int cur = q[p];
            for (int nex : g[cur]) {
                if (par[cur] != nex) {
                    q[r++] = nex;
                    par[nex] = cur;
                    depth[nex] = depth[cur] + 1;
                }
            }
        }
        return new int[][] { par, q, depth };
    }

    /** Packs an undirected edge list into adjacency arrays. */
    static int[][] packU(int n, int[] from, int[] to) {
        int[][] g = new int[n][];
        int[] p = new int[n];
        for (int f : from)
            p[f]++;
        for (int t : to)
            p[t]++;
        for (int i = 0; i < n; i++)
            g[i] = new int[p[i]];
        for (int i = 0; i < from.length; i++) {
            g[from[i]][--p[from[i]]] = to[i];
            g[to[i]][--p[to[i]]] = from[i];
        }
        return g;
    }

    void run() throws Exception
    {
        is = INPUT.isEmpty() ? System.in : new ByteArrayInputStream(INPUT.getBytes());
        out = new PrintWriter(System.out);

        long s = System.currentTimeMillis();
        solve();
        out.flush();
        if(!INPUT.isEmpty())tr(System.currentTimeMillis()-s+"ms");
    }

    public static void main(String[] args) throws Exception { new G().run(); }

    private byte[] inbuf = new byte[1024];
    private int lenbuf = 0, ptrbuf = 0;

    /**
     * Next byte of input, or -1 at end of stream. After end of stream has been reported once,
     * a further call throws {@link InputMismatchException}.
     */
    private int readByte()
    {
        if(lenbuf == -1)throw new InputMismatchException();
        if(ptrbuf >= lenbuf){
            ptrbuf = 0;
            try { lenbuf = is.read(inbuf); } catch (IOException e) { throw new InputMismatchException(); }
            if(lenbuf <= 0)return -1;
        }
        return inbuf[ptrbuf++];
    }

    /** Everything outside printable ASCII (33..126), including -1 (EOF), counts as space. */
    private boolean isSpaceChar(int c) { return !(c >= 33 && c <= 126); }
    private int skip() { int b; while((b = readByte()) != -1 && isSpaceChar(b)); return b; }

    private double nd() { return Double.parseDouble(ns()); }
    private char nc() { return (char)skip(); }

    /** Next whitespace-delimited token. */
    private String ns()
    {
        int b = skip();
        StringBuilder sb = new StringBuilder();
        while(!(isSpaceChar(b))){ // when nextLine, (isSpaceChar(b) && b != ' ')
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    /** Next token, truncated to at most n characters. */
    private char[] ns(int n)
    {
        char[] buf = new char[n];
        int b = skip(), p = 0;
        while(p < n && !(isSpaceChar(b))){
            buf[p++] = (char)b;
            b = readByte();
        }
        return n == p ? buf : Arrays.copyOf(buf, p);
    }

    private char[][] nm(int n, int m)
    {
        char[][] map = new char[n][];
        for(int i = 0;i < n;i++)map[i] = ns(m);
        return map;
    }

    private int[] na(int n)
    {
        int[] a = new int[n];
        for(int i = 0;i < n;i++)a[i] = ni();
        return a;
    }

    /** Next (optionally negative) decimal int; anything but digits and '-' is skipped first. */
    private int ni()
    {
        int num = 0, b;
        boolean minus = false;
        while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }

        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    /** Next (optionally negative) decimal long; same scanning rules as {@link #ni()}. */
    private long nl()
    {
        long num = 0;
        int b;
        boolean minus = false;
        while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
        if(b == '-'){
            minus = true;
            b = readByte();
        }

        while(true){
            if(b >= '0' && b <= '9'){
                num = num * 10 + (b - '0');
            }else{
                return minus ? -num : num;
            }
            b = readByte();
        }
    }

    private static void tr(Object... o) { System.out.println(Arrays.deepToString(o)); }
}
```

## Problem solution in C++ programming.

```#ifdef _MSC_VER
#define _CRT_SECURE_NO_WARNINGS
#endif
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<set>
#include<map>
#include<queue>
#include<vector>
#include<string>
#include<cstring>
#include<unordered_map>
#include<cassert>
#include<cmath>

// Input helpers: declare int variable(s) and read them with scanf.
#define dri(X) int (X); scanf("%d", &X)
#define drii(X, Y) int X, Y; scanf("%d%d", &X, &Y)
#define driii(X, Y, Z) int X, Y, Z; scanf("%d%d%d", &X, &Y, &Z)
#define pb push_back
#define mp make_pair
// Half-open loop: i runs from s up to t - 1.
#define rep(i, s, t) for ( int i=(s) ; i <(t) ; i++)
// Byte-fills a whole array; note this shadows any later function-style use of fill(...).
#define fill(x, v) memset (x, v, sizeof(x))
#define all(x) (x).begin(), (x).end()
// Debug printing to stderr.
#define why(d) cerr << (d) << "!\n"
#define whisp(X, Y) cerr << (X) << " " << (Y) << "#\n"
#define exclam cerr << "!!\n"
// Children of segment-tree node p in a 1-based heap layout. The argument is not
// parenthesized, so pass a plain variable. mid expands to the midpoint of [l, r] and
// relies on locals named l and r being in scope.
#define left(p) (p << 1)
#define right(p) ((p << 1) + 1)
#define mid ((l + r) >> 1)
typedef long long ll;
using namespace std;
typedef pair<int, int> pii;
// Sentinel used to mark "no valid range" in goright/goleft.
const ll inf = (ll)1e9 + 70;
const ll mod = 1e9 + 7;
const int maxn = 2e5 + 1000;

// used[v]: v has already been taken as a centroid (removed from the tree).
bool used[maxn];
// sz[v]: size of v's subtree in the most recent dfs.
int sz[maxn];//subtree sizes

// mark[v]: value of tt when v was last visited by dfs; tt is bumped per centroid.
int mark[maxn];
int tt = 0;

// Max/min label on the path from the current dfs root to v.
int maax[maxn];
int miin[maxn];
// Per-endpoint range limits computed in perform().
int goright[maxn];
int goleft[maxn];

// Leaf values consumed by buildtree().
int val[maxn];
int ST[4 * maxn];// an all-purpose ST: min, max, and sum!!
// Sum query: total of the values at positions [i, j] inside node p, which covers [l, r].
int query(int p, int l, int r, int i, int j){
    bool disjoint = (j < l) || (i > r);
    if (disjoint) return 0;
    bool covered = (l >= i) && (j >= r);
    if (covered) return ST[p];
    int half = (l + r) >> 1;
    int fromLower = query(p << 1, l, half, i, j);
    int fromUpper = query((p << 1) + 1, half + 1, r, i, j);
    return fromLower + fromUpper;
}
// Point update for the sum tree: adds delta at position i and refreshes the sums of the
// nodes on the way back up to p (which covers [l, r]).
void update(int p, int l, int r, int i, int delta){
    if (i < l || i > r) return;
    if (l == r){
        ST[p] += delta;
        return;
    }
    int half = (l + r) >> 1;
    int lc = p << 1;
    int rc = lc + 1;
    update(lc, l, half, i, delta);
    update(rc, half + 1, r, i, delta);
    ST[p] = ST[lc] + ST[rc];
}
// Builds node p over [l, r] from val[]. Children are combined with min when m is true,
// otherwise with max.
void buildtree(int p, int l, int r, bool m){
    if (l == r){
        ST[p] = val[l];
        return;
    }
    int half = (l + r) >> 1;
    int lc = p << 1;
    int rc = lc + 1;
    buildtree(lc, l, half, m);
    buildtree(rc, half + 1, r, m);
    ST[p] = m ? min(ST[lc], ST[rc]) : max(ST[lc], ST[rc]);
}
// Canonical nodes gathered by decompose/decompose2, stored as (node, (lo, hi)).
vector<pair<int, pii>> blocks;

// Appends, from left to right, the maximal nodes under p ([l, r]) whose ranges lie
// entirely at or after position i.
void decompose(int p, int l, int r, int i){
    if (r >= i){
        if (i <= l){
            blocks.push_back(mp(p, pii(l, r)));
        } else {
            int half = (l + r) >> 1;
            decompose(p << 1, l, half, i);
            decompose((p << 1) + 1, half + 1, r, i);
        }
    }
}
// Appends, from left to right, the maximal nodes under p ([l, r]) whose ranges lie
// entirely at or before position i.
void decompose2(int p, int l, int r, int i){
    if (l <= i){
        if (r <= i){
            blocks.push_back(mp(p, pii(l, r)));
        } else {
            int half = (l + r) >> 1;
            decompose2(p << 1, l, half, i);
            decompose2((p << 1) + 1, half + 1, r, i);
        }
    }
}

// Walks down from node p ([l, r]) to the leftmost position whose value is < x, preferring
// the left child whenever its value is < x. The value at p must itself be < x; this is
// asserted at every level.
int find(int p, int l, int r, int x){
    for (;;){
        assert(ST[p] < x);
        if (l == r) return l;
        int half = (l + r) >> 1;
        int lc = p << 1;
        if (ST[lc] < x){
            p = lc;
            r = half;
        } else {
            p = lc + 1;
            l = half + 1;
        }
    }
}

// Walks down from node p ([l, r]) to the rightmost position whose value is > x, preferring
// the right child whenever its value is > x. The value at p must itself be > x; this is
// asserted at every level.
int find2(int p, int l, int r, int x){
    for (;;){
        assert(ST[p] > x);
        if (l == r) return l;
        int half = (l + r) >> 1;
        int rc = (p << 1) + 1;
        if (ST[rc] > x){
            p = rc;
            l = half + 1;
        } else {
            p = rc - 1;
            r = half;
        }
    }
}

void dfs(int v, int p){
mark[v] = tt;
sz[v] = 1;
if (p == -1){
maax[v] = v; miin[v] = v;
}
else{
maax[v] = max(maax[p], v);
miin[v] = min(miin[p], v);
}

if (u == p || used[u]) continue;
dfs(u, v);
sz[v] += sz[u];
}
}

ll perform(int v, int n){
if (n == 1){
return 1;
}
//first, FIND the centroid.
dfs(v, -1);
int g = v; int p = -1;
while (true){
int w = -1;
if (h == p || used[h]) continue;
if (w == -1 || sz[h] > sz[w]) w = h;
}
assert(w != -1);//g should NOT be a leaf.
if (2 * sz[w] <= n){
break;
}
p = g; g = w;
}
//g is the centroid.
tt++;
dfs(g, -1);
//here comes the HEART OF THE ALGORITHM.
int m = -800;
for (int l = g; l > 0; l--){
if (mark[l] != tt) break;
m = l;
}
int M = -800;
for (int r = g; r < maxn; r++){
if (mark[r] != tt) break;
M = r;
}
//Our working interval is m <= i <= M.
rep(i, m, M + 1){
val[i] = miin[i];
//cout << miin[i] << " ";
}//cout << endl;
buildtree(1, m, M, true);
rep(i, m, M + 1){
if (miin[i] < i){
goright[i] = -inf;
continue;
}
blocks.clear();
decompose(1, m, M, i);
reverse(blocks.begin(), blocks.end());
while (!blocks.empty() && ST[blocks.back().first] >= i) blocks.pop_back();
if (blocks.empty()){
goright[i] = M;
}
else{
auto ee = blocks.back();
goright[i] = find(ee.first, ee.second.first, ee.second.second, i) - 1;
}
}
//rep(i, m, M + 1)cout << goright[i] << " "; cout << endl;
//now, goleft!
rep(i, m, M + 1) val[i] = maax[i];
//rep(i, m, M + 1) cout << maax[i] << " "; cout << endl;
buildtree(1, m, M, false);
rep(i, m, M + 1){
if (maax[i] > i){
goleft[i] = inf;
continue;
}
blocks.clear();
decompose2(1, m, M, i);
while (!blocks.empty() && ST[blocks.back().first] <= i) blocks.pop_back();
if (blocks.empty()){
goleft[i] = m;
}
else{
auto ee = blocks.back();
goleft[i] = find2(ee.first, ee.second.first, ee.second.second, i) + 1;
}
}
//rep(i, m, M + 1) cout << goleft[i] << " "; cout << endl;
vector<pii> rs;
rep(r, m, M + 1){
if (goleft[r] != inf) rs.push_back(pii(goleft[r], r));
}
sort(all(rs)); reverse(all(rs));
rep(i, m, M + 1) val[i] = 0;
buildtree(1, m, M, true);//basically: just clear it.
ll ans = 0;
for (int l = m; l <= M; l++){
//whisp(l, goright[l]);
while (!rs.empty() && rs.back().first == l){
update(1, m, M, rs.back().second, 1);
//cout << "update " << rs.back().second << "\n";
rs.pop_back();
}
//cout << query(1, m, M, l, goright[l]) << "\n";
ans += query(1, m, M, l, goright[l]);
}
used[g] = true;
vector<pii> ls;
if (used[u]) continue;
ls.push_back(pii(u, sz[u]));
}
for (pii x : ls){
ans += perform(x.first, x.second);
}
return ans;
}

int main(){
if (fopen("input.txt", "r")) freopen("input.txt", "r", stdin);
dri(n);
rep(i, 1, n){
drii(a, b);
}
cout << perform(1, n) << "\n";
return 0;
}```

## Problem solution in C programming.

```#include <stdio.h>
#include <string.h>
#include <math.h>
#include <stdlib.h>

// Compiles out every fprintf(...) call that follows: each one expands to nothing, so the
// stderr debug tracing below costs nothing at runtime. Delete this line to re-enable it.
#define fprintf(...)

// Disjoint-set (union-find) element.
struct vertex {
// NULL for a set root; otherwise a link toward the root (compressed by vfind).
struct vertex* parent;
// Union-by-rank bound used by vunion.
int rank;
// Size of the set rooted here. A never-merged vertex has count 0, which vunion fixes up
// to 1 the first time it is merged.
int count;
};

// Returns the representative of v's set. Every node on the way is re-linked straight to
// that representative (path compression). A vertex with a NULL parent is its own root.
struct vertex* vfind(struct vertex *v) {
    struct vertex *up = v->parent;
    if (up == NULL || up == v) {
        return v;
    }
    v->parent = vfind(up);
    return v->parent;
}

// Merges the sets containing x and y (union by rank) and returns the surviving root.
// If both are already in the same set, that root is returned unchanged.
struct vertex* vunion(struct vertex *x, struct vertex* y) {
    struct vertex *lower = vfind(x);
    struct vertex *upper = vfind(y);
    if (lower == upper) return upper;
    // Never-merged vertices carry count 0; from their first union on they count as 1.
    if (lower->count == 0) lower->count = 1;
    if (upper->count == 0) upper->count = 1;

    // Make sure `lower` is the root with the smaller (or equal) rank.
    if (lower->rank > upper->rank) {
        struct vertex *t = lower;
        lower = upper;
        upper = t;
    }
    if (lower->rank == upper->rank) upper->rank += 1;
    lower->parent = upper;
    upper->count += lower->count;
    return upper;
}

// Undirected tree edge between labels a and b. The readers store each edge with a < b.
struct edge {
int a, b;
};

// qsort comparator: orders edges by b, breaking ties by a.
int ecmp(const void*a_in, const void*b_in) {
    const struct edge* lhs = a_in;
    const struct edge* rhs = b_in;
    if (lhs->b != rhs->b) {
        return lhs->b < rhs->b ? -1 : 1;
    }
    if (lhs->a != rhs->a) {
        return lhs->a < rhs->a ? -1 : 1;
    }
    return 0;
}

// qsort comparator: orders edges by a, breaking ties by b.
int ecmp_a(const void*a_in, const void*b_in) {
    const struct edge* lhs = a_in;
    const struct edge* rhs = b_in;
    if (lhs->a != rhs->a) {
        return lhs->a < rhs->a ? -1 : 1;
    }
    if (lhs->b != rhs->b) {
        return lhs->b < rhs->b ? -1 : 1;
    }
    return 0;
}

int count_components(int start, struct edge* edges, int ne, int n);

// n * ack-1(n) algorithm; needs to be run n times for n^2 ack-1(n).  Not the best, but gets 50%.
// This used to be a verbatim copy of count_components(). It now delegates to it so the
// two implementations cannot drift apart.
int count_components1(int start, struct edge* edges, int ne, int n) {
    return count_components(start, edges, ne, n);
}

int count_components(int start, struct edge* edges, int ne, int n) {
if (ne == 0) return 1;
fprintf(stderr, "start: %d, ne %d, n %d\n", start, ne, n);
int max_components = n - start + 1;
struct vertex v[max_components];
memset(v, 0, sizeof(v));
int components = 1;
struct edge* le = edges + ne;
for (int maxv = start + 1; maxv <= n; maxv++) {
struct vertex* join = NULL;
while (edges < le && edges->b <= maxv) {
if (edges->a >= start) {
join = vunion(&v[edges->a - start], &v[edges->b - start]);
fprintf(stderr, "Join: %d to %d, new count %d\n", edges->a, edges->b, join->count);
}
edges++;
}
if (join && join->count == maxv - start + 1) components++;
}
return components;
}

// Earlier O(n^2 alpha(n)) solution. For every left label i it counts the connected ranges
// [i, r] with count_components() and prints the total.
// NOTE(review): not called from the visible code; the main() later in this file takes a
// segment-tree approach instead.
// Returns 0 on success, 1 on malformed input or allocation failure.
int old_main() {
    int n;
    if (scanf("%d\n", &n) != 1 || n < 1) return 1;
    // Heap storage avoids a zero-length VLA when n == 1 and stack exhaustion for large n.
    size_t ne = (size_t)(n - 1);
    struct edge *edges = calloc(ne > 0 ? ne : 1, sizeof *edges);
    if (edges == NULL) return 1;
    for (int i = 0; i < n-1; i++) {
        int e1, e2;
        if (scanf("%d %d\n", &e1, &e2) != 2) {
            free(edges);
            return 1;
        }
        // Store every edge with a < b.
        if (e1 < e2) {
            edges[i].a = e1;
            edges[i].b = e2;
        } else {
            edges[i].a = e2;
            edges[i].b = e1;
        }
    }
    qsort(edges, ne, sizeof(struct edge), ecmp);
    for (int i = 0; i < n-1; i++) {
        fprintf(stderr, "Edge: %d %d\n", edges[i].a, edges[i].b);
    }
    // The answer can reach about n^2 / 2, well beyond INT_MAX for large n.
    long long result = 0;
    struct edge *ep = edges;
    struct edge *lp = edges + ne;
    for (int i = 1; i <= n; i++) {
        while(ep < lp && ep->a < i) ep++;
        int cc = count_components(i, ep, lp - ep, n);
        fprintf(stderr, "i: %d  cc: %d\n", i, cc);
        result += cc;
    }
    printf("%lld\n", result);
    free(edges);
    return 0;
}

// Per-label bookkeeping used by main(). The node's edges are the contiguous slice
// edges[first_edge .. first_edge + n_edges) of the ecmp-sorted edge array.
struct node {
// 1-based label of this node.
int nn;
// Index of this node's first edge in the sorted edge array. main() groups each edge under
// its higher endpoint b (node b - 1), not under the lower one.
int first_edge;
// Number of edges grouped under this node.
int n_edges;
};

// Pair of node indices delimiting a range.
// NOTE(review): not referenced anywhere in the visible code.
struct line {
int start_node;
int end_node;
};

// Node of the range-add / range-max segment tree used by update() and propagate().
struct segment_node {
// Pending addition not yet pushed down to the children.
int lazy;
int max_v;  // maximum value of any node below
// NOTE(review): num_v starts at 0 and is only set to 1 when a leaf is updated directly, so
// ranges whose leaves were never reached individually can report 0 here.
int num_v;  // number of nodes with that maximum value
};

// Children of segment-tree node i in a 0-based heap layout (root at index 0).
#define C1(i) ((i)*2+1)
#define C2(i) ((i)*2+2)

// Pushes the pending lazy addition of node `index` (covering [start, end]) down to its two
// children, adding it to their lazy tag and max_v, then clears it. A leaf just drops it.
void propagate(struct segment_node* tree, int index, int start, int end) {
    int pending = tree[index].lazy;
    if (pending == 0) return;
    if (start != end) {
        fprintf(stderr, "Prop: %d v: %d\n", index, tree[index].lazy);
        struct segment_node *kid = &tree[C1(index)];
        kid->lazy += pending;
        kid->max_v += pending;
        kid = &tree[C2(index)];
        kid->lazy += pending;
        kid->max_v += pending;
    }
    tree[index].lazy = 0;
}

// ns, ne = node start/end = recursion counter
// rs, re = input range start/end
// adds "v" to all nodes between rs and re.
// the segment tree is implicitly "complete", i.e. contains all integers in [ns, ne]
int treelim;
void update(struct segment_node* tree, int index, int ns, int ne, int rs, int re, int v) {
if (index >= treelim) exit(-1);
fprintf(stderr, "upd: i: %d ns,ne (%d %d) rs, re (%d %d), v %d\n", index, ns, ne, rs, re, v);
fprintf(stderr, "   prev max_v %d num_v %d\n", tree[index].max_v, tree[index].num_v);
propagate(tree, index, ns, ne);
if (ns == rs && ne == re) {
tree[index].max_v += v;
if (ns == ne) tree[index].num_v = 1;
tree[index].lazy += v;
return;
}
int mid = (ns + ne) / 2;
if (re <= mid) update(tree, C1(index), ns, mid, rs, re, v);
else if (rs > mid) update(tree, C2(index), mid + 1, ne, rs, re, v);
else {
update(tree, C1(index), ns, mid, rs, mid, v);
update(tree, C2(index), mid + 1, ne, mid + 1, re, v);
}
// now up-propagate.
if (tree[C1(index)].max_v > tree[C2(index)].max_v) {
fprintf(stderr, "C1\n");
tree[index].max_v = tree[C1(index)].max_v;
tree[index].num_v = tree[C1(index)].num_v;
} else if (tree[C1(index)].max_v < tree[C2(index)].max_v) {
fprintf(stderr, "C2\n");
tree[index].max_v = tree[C2(index)].max_v;
tree[index].num_v = tree[C2(index)].num_v;
} else {
fprintf(stderr, "BB\n");
tree[index].max_v = tree[C1(index)].max_v;
tree[index].num_v = tree[C1(index)].num_v + tree[C2(index)].num_v;
}
fprintf(stderr, "upd done: %d max_v %d num_v %d\n", index, tree[index].max_v, tree[index].num_v);
}

int main() {
int n;
scanf("%d\n", &n);

struct node nodes[n];
struct edge edges[n-1];
int nl = 0;
memset(nodes, 0, sizeof(nodes));
memset(edges, 0, sizeof(edges));
for (int i = 0; i < n-1; i++) {
int e1, e2;
scanf("%d %d\n", &e1, &e2);
if (e1 < e2) {
edges[i].a = e1;
edges[i].b = e2;
} else {
edges[i].a = e2;
edges[i].b = e1;
}
}
qsort(edges, n-1, sizeof(struct edge), ecmp);
for (int i = 0; i < n; i++) {
nodes[i].nn = i+1;
}
int cur_node = -1;
for (int i = 0; i < n-1; i++) {
fprintf(stderr, "Edge %d: %d %d\n", i, edges[i].a, edges[i].b);
if (edges[i].b - 1 > cur_node) {
for (int j = cur_node + 1; j < edges[i].b - 1; j++) {
// Make the zero-edge nodes have a "first edge" that makes sense
nodes[j].first_edge = i;
}
if (cur_node >= 0) {
nodes[cur_node].n_edges = i - nodes[cur_node].first_edge;
}
cur_node = edges[i].b - 1;
nodes[cur_node].first_edge = i;
}
}
fprintf(stderr, "n:%d, cur_node %d %d\n", n, cur_node, nodes[cur_node].nn);
nodes[cur_node].n_edges = n - 1 - nodes[cur_node].first_edge;
while (++cur_node < n) {
nodes[cur_node].first_edge = n - 1;
}
for (int i = 0; i < n; i++) {
fprintf(stderr, "Node: %d edges start at %d nedges %d\n", nodes[i].nn, nodes[i].first_edge, nodes[i].n_edges);
}
long result = 0;
treelim = 1<<((int)ceil(log2(n)) + 1);
struct segment_node stree[treelim];
memset(stree, 0, sizeof(stree));
for (int i = 0; i < n; i++) {
for (int j = nodes[i].first_edge; j < nodes[i].first_edge + nodes[i].n_edges; j++) {
update(stree, 0, 1, n, 1, edges[j].a, 1);
}