In the HackerRank "BST Maintenance" problem, consider a binary search tree T that is initially empty. Also consider the first N positive integers {1, 2, 3, ..., N} and a permutation P = {a1, a2, ..., aN} of them. We add these numbers to T one at a time: first a1, then a2, and so on, ending with aN. After every insertion, output the sum of the distances between every pair of nodes in T.

HackerRank BST maintenance problem solution


Problem solution in Java programming.

import java.io.*;
import java.util.*;

/**
 * HackerRank "BST maintenance": after every insertion of a permutation into an initially empty
 * binary search tree, reports the sum of distances over all pairs of nodes.
 *
 * <p>Approach: the final BST is built offline as a Cartesian tree (in-order by value, heap
 * ordered by insertion time), so node ids are insertion indices and node 0 is the root. The tree
 * is heavy-light decomposed, and Fenwick trees indexed by pre-order position are updated as nodes
 * are inserted, so that the distance sum from each new node to all earlier ones is accumulated by
 * walking the heavy chains on its root path.
 */
public class Solution {

  /** Next pre-order position handed out by {@link #heavyLight}; reset before every build. */
  static int tick;
  /** {@code l[v]}: left child of node v (an insertion index), or -1. */
  static int[] l;
  /** {@code r[v]}: right child of node v (an insertion index), or -1. */
  static int[] r;
  /** {@code size[v]}: number of nodes in v's subtree of the final tree, filled by {@link #dfs}. */
  static int[] size;

  /** Explicit-stack frame for {@link #dfs}: node, its parent, and whether it is being entered. */
  static class NodeDfs {
    int v;
    int p;
    boolean start = true;

    public NodeDfs(int v, int p) {
      this.v = v;
      this.p = p;
    }
  }

  /**
   * Computes subtree sizes of the tree rooted at {@code v} without recursion (a sorted input
   * turns the BST into a path of length n, far deeper than the call stack allows).
   *
   * @param v root node, or a negative value for an empty tree
   * @return size of the subtree rooted at {@code v}, or 0 when {@code v < 0}
   */
  static int dfs(int v) {
    if (v < 0) {
      return 0;
    }
    Deque<NodeDfs> deque = new ArrayDeque<>();
    deque.add(new NodeDfs(v, -1));
    while (!deque.isEmpty()) {
      NodeDfs node = deque.peekLast();
      if (node.start) {
        size[node.v] = 1;
        if (l[node.v] >= 0) {
          deque.add(new NodeDfs(l[node.v], node.v));
        }
        if (r[node.v] >= 0) {
          deque.add(new NodeDfs(r[node.v], node.v));
        }
        node.start = false;
      } else {
        // Every child pushed above this frame has already been popped, so size is final.
        if (node.p >= 0) {
          size[node.p] += size[node.v];
        }
        deque.removeLast();
      }
    }
    return size[v];
  }

  /** Adds {@code v} at index {@code i} of a 0-based Fenwick tree. */
  static void add(long[] fenwick, int i, long v) {
    for (; i < fenwick.length; i |= i + 1) {
      fenwick[i] += v;
    }
  }

  /** Returns the sum of the Fenwick tree's entries at indices {@code [0, i)}. */
  static long getSum(long[] fenwick, int i) {
    long s = 0;
    for (; i > 0; i &= i - 1) {
      s += fenwick[i - 1];
    }
    return s;
  }

  /** {@code pre[v]}: pre-order position of v; v's subtree occupies {@code [pre[v], post[v])}. */
  static int[] pre;
  /** {@code post[v]}: one past the last pre-order position inside v's subtree. */
  static int[] post;
  /** {@code heavy[v]}: head (topmost node) of the heavy chain containing v. */
  static int[] heavy;
  /** {@code dep[v]}: depth of v in the final tree (root = 0). */
  static int[] dep;
  /** {@code p[v]}: parent of v; the root's entry stays 0. */
  static int[] p;

  /** Explicit-stack frame for {@link #heavyLight}; {@code next} is the deferred light child. */
  static class NodeH {
    int d;
    int v;
    int chain;
    NodeH next = null;
    boolean start = true;

    public NodeH(int d, int v, int chain) {
      this.d = d;
      this.v = v;
      this.chain = chain;
    }
  }

  /**
   * Iterative heavy-light decomposition of the tree rooted at {@code v}. It assigns:
   *
   * <ul>
   *   <li>pre-order positions ({@code pre}/{@code post}), with the heavy child visited first so
   *       every chain occupies a contiguous range;
   *   <li>chain heads ({@code heavy});
   *   <li>depths ({@code dep});
   *   <li>parents ({@code p}).
   * </ul>
   *
   * <p>Requires {@code size} to be filled.
   *
   * @param d depth of {@code v}
   * @param v root of the subtree to decompose
   * @param chain head of the chain {@code v} belongs to
   */
  static void heavyLight(int d, int v, int chain) {
    Deque<NodeH> deque = new ArrayDeque<>();
    deque.add(new NodeH(d, v, chain));
    while (!deque.isEmpty()) {
      NodeH node = deque.peekLast();
      if (node.start) {
        int[] c = new int[2];
        int nc = 0;
        pre[node.v] = tick++;
        heavy[node.v] = node.chain;
        dep[node.v] = node.d;
        if (l[node.v] >= 0) {
          p[l[node.v]] = node.v;
          c[nc++] = l[node.v];
        }
        if (r[node.v] >= 0) {
          p[r[node.v]] = node.v;
          c[nc++] = r[node.v];
        }
        // c[0] becomes the heavy (larger) child, which continues the current chain.
        if (nc == 2 && size[c[0]] < size[c[1]]) {
          int tmp = c[0];
          c[0] = c[1];
          c[1] = tmp;
        }
        if (nc > 0) {
          deque.add(new NodeH(node.d + 1, c[0], node.chain));
          if (nc == 2) {
            // The light child starts a new chain; visit it after the heavy subtree is done.
            node.next = new NodeH(node.d + 1, c[1], c[1]);
          }
        }
        node.start = false;
      } else {
        if (node.next != null) {
          deque.add(node.next);
          node.next = null;
        } else {
          post[node.v] = tick;
          deque.removeLast();
        }
      }
    }
  }

  /** Fenwick tree over pre-order positions marking which nodes have been inserted so far. */
  static long[] fw_size;

  /** Returns how many already-inserted nodes lie in the subtree of {@code i}. */
  static long getSize(int i) {
    return getSum(fw_size, post[i]) - getSum(fw_size, pre[i]);
  }

  /** Fills {@code v} with consecutive integers starting at {@code val}. */
  static void iota(Integer[] v, int val) {
    for (int i = 0; i < v.length; i++) {
      v[i] = val++;
    }
  }

  /**
   * Builds the final BST for the insertion order {@code arr} and fills every per-node array.
   *
   * <p>Inserting a permutation into a BST yields its Cartesian tree: in-order by value, heap
   * ordered by insertion time. Nodes are identified by insertion index, so node 0 is the root.
   *
   * @param arr permutation of {@code 1..n} in insertion order
   * @param n number of elements; must be at least 1
   */
  static void cartesianTreeConstruction(int[] arr, int n) {
    // b lists insertion indices in increasing order of their values.
    Integer[] b = new Integer[n];
    iota(b, 0);
    Arrays.sort(b, (i, j) -> Integer.compare(arr[i], arr[j]));

    // Monotonic-stack Cartesian tree build (min-heap on insertion index). The extra pass with
    // j = -1 flushes the stack so every remaining node gets its right child assigned.
    Deque<Integer> stack = new ArrayDeque<>();
    l = new int[n];
    r = new int[n];
    for (int i = 0; i < n + 1; i++) {
      int j = i == n ? -1 : b[i];
      int x = -1;
      for (; !stack.isEmpty() && stack.peek() > j; stack.pop()) {
        int y = stack.peek();
        r[y] = x;
        x = y;
      }
      if (i < n) {
        stack.push(j);
        l[j] = x;
      }
    }

    size = new int[n];
    dfs(0);

    pre = new int[n];
    post = new int[n];
    heavy = new int[n];
    dep = new int[n];
    p = new int[n];

    // Pre-order positions must start at 0 for every build, not just the first one.
    tick = 0;
    heavyLight(0, 0, 0);
  }

  /**
   * Computes the answer after every insertion.
   *
   * @param arr permutation of {@code 1..n} in insertion order
   * @return {@code result[i]} = sum of pairwise distances after inserting {@code arr[0..i]}
   */
  static long[] solve(int[] arr) {
    int n = arr.length;
    long[] result = new long[n];

    int ascending = 0;
    int descending = 0;
    for (int i = 0; i < n; i++) {
      if (arr[i] == i + 1) {
        ascending++;
      }
      if (arr[i] == n - i) {
        descending++;
      }
    }
    if (ascending == n || descending == n) {
      // Sorted input builds a path. The pairwise-distance sum over a path of i + 1 nodes is
      // i (i + 1) (i + 2) / 6. This branch also covers n == 0.
      for (int i = 0; i < n; i++) {
        result[i] = (long) i * (i + 1) * (i + 2) / 6;
      }
      return result;
    }

    cartesianTreeConstruction(arr, n);

    long ans = 0;
    long sum = 0; // sum of depths of the nodes inserted so far
    long[] fwS = new long[n]; // per chain node: itself + inserted nodes entering via light edges
    long[] fwDs = new long[n]; // same counts weighted by the chain node's depth
    fw_size = new long[n];

    for (int i = 0; i < n; i++) {
      // dist(i, j) = dep[i] + dep[j] - 2 dep[lca]; the dep[j] parts total `sum`.
      ans += sum;
      for (int v = i; v > 0; ) {
        int u = p[v];
        // Inserted nodes whose LCA with i is u: u's subtree minus the branch containing i.
        ans += (long) (dep[i] - 2 * dep[u]) * (getSize(u) - getSize(v));
        // Inserted nodes whose LCA with i lies on u's chain strictly above u.
        long s = getSum(fwS, pre[u]) - getSum(fwS, pre[heavy[u]]);
        long ds = getSum(fwDs, pre[u]) - getSum(fwDs, pre[heavy[u]]);
        ans += dep[i] * s - 2 * ds;
        if (heavy[v] != heavy[u]) {
          // v is a light child of u: record that node i hangs off u's chain at u.
          add(fwS, pre[u], 1);
          add(fwDs, pre[u], dep[u]);
        }
        v = heavy[u];
      }
      add(fw_size, pre[i], 1);
      add(fwS, pre[i], 1);
      add(fwDs, pre[i], dep[i]);
      sum += dep[i];
      result[i] = ans;
    }
    return result;
  }

  /**
   * Reads n and the permutation, then prints one answer per line.
   *
   * <p>Tokens may be split across lines in any way.
   */
  public static void main(String[] args) throws IOException {
    try (BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = openOutput()) {
      StringBuilder text = new StringBuilder();
      for (String line = br.readLine(); line != null; line = br.readLine()) {
        text.append(line).append(' ');
      }
      StringTokenizer st = new StringTokenizer(text.toString());
      int n = Integer.parseInt(st.nextToken());
      int[] arr = new int[n];
      for (int i = 0; i < n; i++) {
        arr[i] = Integer.parseInt(st.nextToken());
      }
      for (long result : solve(arr)) {
        bw.write(result + "\n");
      }
    }
  }

  /** Opens the {@code OUTPUT_PATH} file, or standard output when the variable is not set. */
  private static BufferedWriter openOutput() throws IOException {
    String path = System.getenv("OUTPUT_PATH");
    Writer sink = path != null ? new FileWriter(path) : new OutputStreamWriter(System.out);
    return new BufferedWriter(sink);
  }
}


Problem solution in C++ programming.

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <vector>
#include <queue>
#include <set>
#include <map>
#include <memory.h>

using namespace std;

// Input permutation, stored 0-based (main() decrements every value as it is read).
int p[250005];

// ---- Brute-force reference solver (only used by checker()) ----

// BST node shared by the brute-force and the fast solver.
struct node {
  int val;       // stored key
  int size;      // subtree size, filled in by fillSize()
  int h;         // depth (root = 0); set by buildFullTree()
  node* left;
  node* right;
  node* parent;  // set only by buildFullTree(); NULL otherwise
  node(int _val)
      : val(_val), size(1), h(0), left(NULL), right(NULL), parent(NULL) {}
};

// Adjacency lists of the brute-force solver's tree, indexed by node value.
vector<vector<int> > g;

// Records the undirected edge a <-> b: b is appended to a's list, then a to b's list.
void addEdge(int a, int b) {
  const int ends[2][2] = {{a, b}, {b, a}};
  for (int k = 0; k < 2; ++k) g[ends[k][0]].push_back(ends[k][1]);
}

// Adds an edge in g for every parent/child link of the tree rooted at root: the node's own
// edges first (left, then right), then the left subtree, then the right subtree.
void buildGraph(node* root) {
  if (!root) return;
  node* kids[2] = {root->left, root->right};
  for (int k = 0; k < 2; ++k)
    if (kids[k]) addEdge(root->val, kids[k]->val);
  for (int k = 0; k < 2; ++k) buildGraph(kids[k]);
}

// Sums (d + distance from v) over every node reachable from v in g without going back
// through par. With the defaults this is the sum of distances from v to all nodes.
int findAll(int v, int d = 0, int par = -1) {
  int total = d;
  for (vector<int>::const_iterator it = g[v].begin(); it != g[v].end(); ++it) {
    if (*it == par) continue;
    total += findAll(*it, d + 1, v);
  }
  return total;
}

// Brute force, O(n^2): rebuilds g from the tree and sums findAll() over every node id in
// [0, n). Each unordered pair is counted twice, hence the final halving.
int countForTree(int n, node* root) {
  g.assign(n, vector<int>());
  buildGraph(root);
  int twice = 0;
  for (int v = 0; v < n; ++v) twice += findAll(v);
  return twice / 2;
}

vector<int> solveStupid(int n) {
  node* root = new node(p[0]);
  vector<int> ret(1, 0);
  for (int i = 1; i < n; ++i) {
    node* cur = root;
    while (true) {
      if (cur->val > p[i]) {
        if (cur->left == NULL) {
          cur->left = new node(p[i]);
          break;
        } else {
          cur = cur->left;
        }
      } else {
        if (cur->right == NULL) {
          cur->right = new node(p[i]);
          break;
        } else {
          cur = cur->right;
        }
      }
    }

    ret.push_back(countForTree(n, root));

  }
  return ret;
}

// Lazy-propagation segment tree (range add, range sum) over heavy-light positions

// Lazy segment tree storage: tree[] holds range sums, add[] the pending per-position increments.
// H[pos] holds the depth of the node at heavy-light position pos (written by dfs() and
// prefix-summed by solve()); getH does not read it, because every position has weight 1.
long long tree[1000005], H[250005];
int add[1000005];

// Weight of the position range [l, r]. Each position counts once, so this is the range length.
// (The former second `return H[r] - ...` statement was unreachable and has been removed.)
long long getH(int l, int r) {
  return r - l + 1;
}

// Hands the pending increment of segment v (covering [l, r]) down to its two children.
void push(int v, int l, int r) {
  if (!add[v]) return;
  const int mid = (l + r) / 2;
  const int lc = 2 * v, rc = lc + 1;
  add[lc] += add[v];
  tree[lc] += 1LL * getH(l, mid) * add[v];
  add[rc] += add[v];
  tree[rc] += 1LL * getH(mid + 1, r) * add[v];
  add[v] = 0;
}

// Adds val to every position of [L, R] inside segment i, which covers [l, r].
// Requires l <= L <= R <= r.
void update(int i, int l, int r, int L, int R, int val) {
  if (L == l && R == r) {
    add[i] += val;
    tree[i] += 1LL * getH(l, r) * val;
    return;
  }
  push(i, l, r);
  const int m = (l + r) >> 1;
  if (L <= m) update(2 * i, l, m, L, min(R, m), val);
  if (R > m) update(2 * i + 1, m + 1, r, max(L, m + 1), R, val);
  tree[i] = tree[2 * i] + tree[2 * i + 1];
}

// Returns the sum over positions [L, R] inside segment i, which covers [l, r].
// Requires l <= L <= R <= r.
long long find(int i, int l, int r, int L, int R) {
  if (L == l && R == r) return tree[i];
  push(i, l, r);
  const int m = (l + r) >> 1;
  long long total = 0;
  if (L <= m) total += find(2 * i, l, m, L, min(R, m));
  if (R > m) total += find(2 * i + 1, m + 1, r, max(L, m + 1), R);
  return total;
}

// Fast solver: offline BST construction + heavy-light decomposition + segment tree

// whatNode[v]: the node holding value v; f: the set of values inserted so far.
node* whatNode[250005];
set<int> f;

// Allocates a node for val and registers it in both whatNode and f.
node* createNode(int val) {
  node* created = new node(val);
  whatNode[val] = created;
  f.insert(val);
  return created;
}

// Builds the final BST for p[0..n-1] in O(n log n), without walking down from the root.
// A new key attaches as the left child of its in-order successor if that slot is free;
// otherwise it attaches as the right child of its in-order predecessor.
// Parent links and depths (h) are filled in as nodes are attached.
node* buildFullTree(int n) {
  node* root = createNode(p[0]);
  for (int i = 1; i < n; ++i) {
    const int key = p[i];
    set<int>::iterator neighbour = f.upper_bound(key);
    node* parent;
    node* child;
    if (neighbour != f.end() && whatNode[*neighbour]->left == NULL) {
      parent = whatNode[*neighbour];
      child = parent->left = createNode(key);
    } else {
      --neighbour;
      parent = whatNode[*neighbour];
      child = parent->right = createNode(key);
    }
    child->parent = parent;
    child->h = parent->h + 1;
  }
  return root;
}

// Adds every child's subtree size into its parent for the whole tree rooted at root
// (non-NULL), and returns root->size. Sizes are expected to start at 1 (as set by the node
// constructor).
// Iterative: a sorted input makes the BST a path of up to 250000 nodes, and one recursive
// frame per level risks overflowing the call stack.
int fillSize(node* root) {
  // Pre-order lists every parent before its children, so sweeping it backwards finishes
  // each child's size before the child is added into its parent.
  vector<node*> preorder;
  vector<node*> pending(1, root);
  while (!pending.empty()) {
    node* cur = pending.back();
    pending.pop_back();
    preorder.push_back(cur);
    if (cur->right != NULL) pending.push_back(cur->right);
    if (cur->left != NULL) pending.push_back(cur->left);
  }
  for (size_t k = preorder.size(); k-- > 0;) {
    node* cur = preorder[k];
    if (cur->left != NULL) cur->size += cur->left->size;
    if (cur->right != NULL) cur->size += cur->right->size;
  }
  return root->size;
}

vector<int> order;
int L[250005], where[250005];

void dfs(node* root, int left) {
  H[order.size()] = root->h;
  order.push_back(root->val);
  where[root->val] = order.size() - 1;
  L[root->val] = left;
  if (root->left != NULL && root->right != NULL) {
    if (root->left->size >= root->right->size) {
      dfs(root->left, left);
      dfs(root->right, order.size());
    } else {
      dfs(root->right, left);
      dfs(root->left, order.size());
    }
  } else if (root->left != NULL) {
    dfs(root->left, left);
  } else if (root->right != NULL) {
    dfs(root->right, left);
  }
}

// Returns the sum of segment-tree values over every position on the path from the root to
// node v. It queries one heavy chain at a time, then jumps to the parent of the chain head.
// The walk stops at the chain whose head sits at position 0, which dfs(root, 0) gives to
// the root.
long long findOnPath(int n, int v) {
  long long total = 0;
  for (;;) {
    const int head = L[v];
    total += find(1, 0, n - 1, head, where[v]);
    if (head == 0) break;
    v = whatNode[order[head]]->parent->val;
  }
  return total;
}

// Adds 1 at every position on the path from the root to node v, except the root's own
// position 0. It updates one heavy chain at a time, like findOnPath().
void updateOnPath(int n, int v) {
  for (int head = L[v];; head = L[v]) {
    if (head != 0) {
      update(1, 0, n - 1, head, where[v], 1);
      v = whatNode[order[head]]->parent->val;
      continue;
    }
    // Root chain: skip position 0; nothing remains when v is the root itself.
    if (where[v] >= 1) update(1, 0, n - 1, 1, where[v], 1);
    return;
  }
}

vector<long long> solve(int n) {
  vector<long long> ret(1, 0);
  node* root = buildFullTree(n);
  fillSize(root);
  dfs(root, 0);
  for (int i = 1; i < n; ++i) {
    H[i] += H[i - 1];
    // cout << H[i] << " ";
  }
  // cout << endl;
  // for (int i = 0; i < n; ++i)
  //   cout << order[i] + 1 << " ";
  // cout << endl;
  // for (int i = 0; i < n; ++i)
  //   cout << L[i] << " ";
  // cout << endl;
  long long hSum = 0, ans = 0;
  for (int i = 1; i < n; ++i) {
    ans += hSum;
    hSum += whatNode[p[i]]->h;
    ans += 1LL * i * whatNode[p[i]]->h;
    ans -= 2 * findOnPath(n, whatNode[p[i]]->parent->val);
    // long long K = findOnPath(n, whatNode[p[i]]->parent->val);
    // cout << whatNode[p[i]]->parent->val + 1 << "  " << K << endl;
    updateOnPath(n, p[i]);
    ret.push_back(ans);
  }
  return ret;
}

void checker() {
  for (int it = 0; it < 100; ++it) {
    int n = rand() % 100 + 1;
    for (int i = 0; i < n; ++i) {
      p[i] = i;
      int ind = rand() % i + 1;
      swap(p[i], p[ind]);
    }
    
    vector<int> real = solveStupid(n);
    vector<long long> my = solve(n);
    for (int i = 0; i < my.size(); ++i) {
      if (real[i] != my[i])
        puts("fuck");
    }
  }
}

// Reads n and a 1-based permutation, stores it 0-based in p, and prints one answer per
// insertion. Call checker() here instead to run the randomized self-test.
int main() {
  int n;
  scanf("%d", &n);
  for (int i = 0; i < n; ++i) {
    scanf("%d", &p[i]);
    p[i]--;
  }

  const vector<long long> answers = solve(n);
  for (size_t i = 0; i < answers.size(); ++i) printf("%lld\n", answers[i]);
  return 0;
}


Problem solution in C programming.

#include <stdio.h>
#include <stdlib.h>
/* Adjacency-list cell: neighbour index x and the next cell of the same list. */
typedef struct _node{
  int x;
  struct _node *next;
} lnode;
/* Segment tree over value positions with range add / point query (init, range_increment,
   query). All three trees share the leaf offset N set by init. */
void init( int n ,int *tree);
void range_increment( int i, int j, int val ,int *tree);
int query( int i ,int *tree);
/* BST edges, depths, binary-lifting LCA and distance. */
void insert_edge(int x,int y);
void dfs0(int u);
void preprocess();
int lca(int a,int b);
int dist(int u,int v);
/* Centroid decomposition. */
void dfs1(int u,int p);
int dfs2(int u,int p);
void decompose(int root,int p);
/* Globals (graph node id = value - 1):
 *   a               input permutation (1-based values) in insertion order
 *   cut             1 once a node has been chosen as a centroid
 *   parent          parent in the centroid tree (-1 for the top-level centroid)
 *   DP              DP[k][v] = 2^k-th ancestor of v in the BST (the root maps to itself)
 *   mid/left/right  segment trees main() uses to find each new key's BST parent;
 *                   each needs room for N + NN entries
 *   level           depth in the BST (root = 0)
 *   sub             subtree sizes inside the current centroid component (dfs1)
 *   N               segment-tree leaf offset; NN = number of keys;
 *                   nn = size of the current component
 */
int a[250000],cut[250000]={0},parent[250000],DP[18][250000],mid[750000],left[750000],right[750000],level[250000],sub[250000],N,NN,nn;
/* Per-centroid aggregates over the nodes inserted so far (maintained in main()):
 *   count[c]  inserted nodes in c's centroid subtree
 *   sum[c]    sum of their distances to c
 *   con[c]    sum of their distances to parent[c]
 */
long long count[250000]={0},sum[250000]={0},con[250000]={0};
/* Adjacency lists of the BST, one list head per node. */
lnode *table[250000]={0};

/*
 * Reads NN and a permutation a[0..NN-1] of 1..NN (insertion order). After each insertion it
 * prints the sum of distances over all pairs of inserted nodes.
 *
 * Phase 1 finds every key's BST parent with three segment trees instead of walking down from
 * the root. Phase 2 prepares LCA queries and a centroid decomposition of the final tree.
 * Phase 3 answers each insertion through the key's centroid ancestors.
 *
 * NOTE(review): dfs0/dfs1 recurse once per tree level, and a sorted input yields a path of
 * depth NN. Confirm the stack limit suffices.
 */
int main(){
  int x,y,z,leftd,rightd,i;
  long long ans,aa=0;
  scanf("%d",&NN);
  for(i=0;i<NN;i++)
    scanf("%d",a+i);
  /* Point queries at position v-1 (0 = none):
   *   left   largest inserted key below v
   *   right  smallest inserted key above v
   *   mid    key that v would be attached under
   * Inside one gap between inserted keys the stored value is uniform, so adding
   * (new - old) over the gap overwrites it. */
  init(NN,mid);
  init(NN,left);
  init(NN,right);
  for(i=0;i<NN;i++){
    /* x/y: predecessor/successor of a[i]. leftd..rightd (1-based) spans the gap between
       them, or reaches 1 / NN when one is missing. */
    leftd=x=query(a[i]-1,left);
    if(!x)
      leftd=1;
    rightd=y=query(a[i]-1,right);
    if(!y)
      rightd=NN;
    /* z: BST parent of a[i] (1-based), 0 for the very first key. */
    z=query(a[i]-1,mid);
    if(z)
      insert_edge(z-1,a[i]-1);
    /* a[i] now owns the gap. The ranges also touch the endpoint keys' own positions, which
       are already inserted and never queried again. */
    range_increment(leftd-1,rightd-1,a[i]-z,mid);
    range_increment(a[i]-1,rightd-1,a[i]-x,left);
    range_increment(leftd-1,a[i]-1,a[i]-y,right);
  }
  /* Depths and binary lifting rooted at the first key; centroid tree from an arbitrary start. */
  preprocess();
  decompose(a[NN/2]-1,-1);
  /* Distances are taken in the final tree. Inserted keys always form a connected subtree
     containing the root, so those distances equal the ones in the partial tree. */
  for(i=0;i<NN;i++){
    /* ans = sum of distances from a[i]-1 to every earlier key. Each centroid ancestor
       contributes its aggregates minus the part coming from the child branch below it. */
    for(ans=sum[a[i]-1],x=a[i]-1;1;x=parent[x]){
      if(parent[x]==-1)
        break;
      ans+=sum[parent[x]]-con[x]+dist(a[i]-1,parent[x])*(count[parent[x]]-count[x]);
    }
    /* Register a[i]-1 in the aggregates of itself and all of its centroid ancestors. */
    for(x=a[i]-1;x!=-1;x=parent[x]){
      sum[x]+=dist(a[i]-1,x);
      count[x]++;
      if(parent[x]!=-1)
        con[x]+=dist(a[i]-1,parent[x]);
    }
    /* aa accumulates the running answer over all pairs. */
    printf("%lld\n",aa+=ans);
  }
  return 0;
}
/* Sets N to the smallest power of two >= n (the leaf offset) and zeroes tree[1 .. N+n-1]. */
void init( int n ,int *tree){
  int i;
  for (N = 1; N < n; N <<= 1)
    ;
  for (i = 1; i < N + n; ++i)
    tree[i] = 0;
}
/* Adds val to every position in [i, j]. It climbs from both leaves and tags the O(log N)
   nodes that exactly cover the range; query() sums the tags on a leaf's root path. */
void range_increment( int i, int j, int val ,int *tree){
  int lo = i + N, hi = j + N;
  while (lo <= hi) {
    if (lo & 1)
      tree[lo] += val;     /* lo is a right child: take it, move past it */
    if (!(hi & 1))
      tree[hi] += val;     /* hi is a left child: take it, move before it */
    lo = (lo + 1) / 2;
    hi = (hi - 1) / 2;
  }
}
/* Returns the current value at position i: the sum of tags from leaf i+N up to the root. */
int query( int i ,int *tree){
  int total = 0;
  int node = i + N;
  while (node != 0) {
    total += tree[node];
    node /= 2;
  }
  return total;
}
/* Adds the undirected edge x <-> y by prepending y to x's list, then x to y's list. */
void insert_edge(int x,int y){
  int ends[2][2] = {{x, y}, {y, x}};
  int k;
  for (k = 0; k < 2; k++) {
    lnode *cell = malloc(sizeof(lnode));
    cell->x = ends[k][1];
    cell->next = table[ends[k][0]];
    table[ends[k][0]] = cell;
  }
}
/* Depth-first walk from u that fills DP[0] (direct parent) and level for every descendant.
   The neighbour equal to DP[0][u] is the parent and is skipped. */
void dfs0(int u){
  lnode *e;
  for (e = table[u]; e != NULL; e = e->next) {
    int v = e->x;
    if (v == DP[0][u])
      continue;
    DP[0][v] = u;
    level[v] = level[u] + 1;
    dfs0(v);
  }
}
/* Roots the BST at the first inserted key, computes depths, and builds the binary-lifting
   table. The root is its own parent, so jumps past it saturate there. */
void preprocess(){
  int root = a[0] - 1;
  int k, v;
  level[root] = 0;
  DP[0][root] = root;
  dfs0(root);
  for (k = 1; k < 18; k++)
    for (v = 0; v < NN; v++)
      DP[k][v] = DP[k - 1][DP[k - 1][v]];
}
/* Lowest common ancestor by binary lifting: lift the deeper node to the same level, then
   lift both while their ancestors differ. */
int lca(int a,int b){
  int k, diff;
  if (level[a] > level[b]) {
    int t = a;
    a = b;
    b = t;
  }
  diff = level[b] - level[a];
  for (k = 0; k < 18; k++)
    if ((diff >> k) & 1)
      b = DP[k][b];
  if (a == b)
    return a;
  for (k = 17; k >= 0; k--) {
    if (DP[k][a] != DP[k][b]) {
      a = DP[k][a];
      b = DP[k][b];
    }
  }
  return DP[0][a];
}
/* Number of edges between u and v in the BST. */
int dist(int u,int v){
  int w = lca(u, v);
  return (level[u] - level[w]) + (level[v] - level[w]);
}
/* Computes sub[] for the uncut component containing u (p is the node we came from) and
   adds the number of visited nodes to nn. */
void dfs1(int u,int p){
  lnode *e;
  nn++;
  sub[u] = 1;
  for (e = table[u]; e; e = e->next) {
    int v = e->x;
    if (v == p || cut[v])
      continue;
    dfs1(v, u);
    sub[u] += sub[v];
  }
}
/* Walks from u towards any uncut neighbour whose sub[] exceeds nn/2 and returns the first
   node with no such neighbour: the component's centroid. */
int dfs2(int u,int p){
  for (;;) {
    lnode *e;
    for (e = table[u]; e; e = e->next)
      if (e->x != p && sub[e->x] > nn / 2 && !cut[e->x])
        break;
    if (!e)
      return u;
    p = u;
    u = e->x;
  }
}
/* Centroid decomposition of the uncut component containing root: marks its centroid, links
   it to centroid-tree parent p, then recurses into each remaining piece. */
void decompose(int root,int p){
  int centroid;
  lnode *e;
  nn = 0;
  dfs1(root, root);
  centroid = dfs2(root, root);
  parent[centroid] = p;
  cut[centroid] = 1;
  for (e = table[centroid]; e; e = e->next)
    if (!cut[e->x])
      decompose(e->x, centroid);
}


Problem solution in JavaScript programming.

'use strict';


function update(n) {
    var w = 0;
    var c = 1;

    while (n != null) {
        if (n.dir == 1) {
            w += n.w2 + n.c2 * c;
            n.w1 += c;
            n.c1++;
        } else if (n.dir == 2) {
            w += n.w1 + n.c1 * c;
            n.w2 += c;
            n.c2++;
        }

        c++;
        n.dir = 0;
        n = n.p;
    }

    return w;
}


// Parses "N" followed by N keys, inserts them into a BST one by one, and prints the running
// sum of pairwise distances after every insertion (one number per line).
function processData(input) {
    // Tokenize on any whitespace. A plain split(' ') on one line produced NaN for '\r\n'
    // endings, repeated spaces, or a permutation wrapped onto several lines.
    var tokens = input.split(/\s+/).filter(function (s) { return s.length > 0; });
    var N = parseInt(tokens[0], 10);
    if (!(N > 0)) {
        // Empty input or N = 0: nothing to insert and nothing to print.
        return;
    }
    var A = tokens.slice(1, 1 + N).map(function (s) { return parseInt(s, 10); });

    // Node layout consumed by update(): per-side distance sums (w1/w2) and counts (c1/c2),
    // each starting with the node itself; dir = side taken by the current insertion.
    function makeNode(parent, value) {
        return { p: parent, v: value, n1: null, w1: 0, c1: 1, n2: null, w2: 0, c2: 1, dir: 0 };
    }

    var res = new Array(N);
    res[0] = 0;
    var w = 0;
    var root = makeNode(null, A[0]);

    for (var i = 1; i < N; i++) {
        var v = A[i];
        var n = root;
        var last = null;

        while (n != null) {
            if (v == n.v) {
                // Duplicate key: nothing is inserted, so update(null) adds 0.
                last = null;
                break;
            }

            last = n;
            if (v < n.v) {
                n.dir = 1;
                if (n.n1 == null) {
                    n.n1 = makeNode(n, v);
                    n = null;
                } else {
                    n = n.n1;
                }
            } else {
                n.dir = 2;
                if (n.n2 == null) {
                    n.n2 = makeNode(n, v);
                    n = null;
                } else {
                    n = n.n2;
                }
            }
        }

        w += update(last);
        res[i] = w;
    }

    process.stdout.write(res.join('\n') + '\n');
}


// Buffer all of stdin as ASCII text and hand it to processData once the stream ends.
process.stdin.resume();
process.stdin.setEncoding("ascii");
var _input = "";
process.stdin
    .on("data", function (chunk) { _input += chunk; })
    .on("end", function () { processData(_input); });