Header Ad

HackerRank Counting On a Tree problem solution

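In this HackerRank Counting On a Tree problem, we are given a tree t whose vertices are coloured, and q queries. Each query names two paths, one from w to x and one from y to z. For each query we must count the pairs (i, j) such that i lies on the path from w to x, j lies on the path from y to z, i ≠ j, and vertices i and j have the same colour. Process the queries in order and print each query's pair count on a new line.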

HackerRank Counting On a Tree problem solution


Problem solution in Java programming.

import java.io.*;
import java.util.*;

public class Solution {

  // Adjacency lists stored as singly linked lists in flat arrays: ptr[u] is u's first edge
  // slot, nxt[e] the next slot of the same vertex and succ[e] the target of slot e.
  // Slot 0 means "end of list", which is why index starts at 1.
  static int[] nxt;
  static int[] succ;
  static int[] ptr;
  static int index = 1;

  /** Prepends the directed edge {@code u -> v} to u's adjacency list. */
  static void addEdge(int u, int v) {
    nxt[index] = ptr[u];
    ptr[u] = index;
    succ[index++] = v;
  }

  // Tree data filled by dfs() for the tree rooted at vertex 1:
  //   tin[v], tout[v]  entry / exit timestamps, used by upper()
  //   pr[v]            parent of v (pr[1] stays 0)
  //   up[v][i]         2^i-th ancestor of v, saturating at the root
  //   w[v]             depth of v (the root has depth 0)
  static int timer = 0;
  static int[] tin;
  static int[] tout;
  static int[] pr;
  static int[][] up;
  static int[] w;

  /** Highest binary-lifting level; solve() picks it so that 2^log exceeds every depth. */
  static int log = 17;

  /** Stack frame of the iterative DFS: vertex, its depth, its parent and entry/exit state. */
  static class NodeDfs {
    int v;
    int lvl;
    int p;
    boolean start = true;

    public NodeDfs(int v, int lvl, int p) {
      this.v = v;
      this.lvl = lvl;
      this.p = p;
    }
  }

  /**
   * Walks the tree from vertex 1 with an explicit stack and fills tin, tout, w, up and pr.
   * No recursion is used, so deep path-shaped trees cannot overflow the call stack.
   */
  static void dfs() {
    Deque<NodeDfs> stack = new ArrayDeque<>();
    // The root is its own parent, so ancestor jumps saturate at vertex 1.
    stack.addLast(new NodeDfs(1, 0, 1));
    while (!stack.isEmpty()) {
      NodeDfs node = stack.peekLast();
      if (node.start) {
        tin[node.v] = timer;
        w[node.v] = node.lvl;
        up[node.v][0] = node.p;
        for (int i = 1; i <= log; i++) {
          up[node.v][i] = up[up[node.v][i - 1]][i - 1];
        }
        for (int e = ptr[node.v]; e > 0; e = nxt[e]) {
          int to = succ[e];
          if (to != node.p) {
            stack.addLast(new NodeDfs(to, node.lvl + 1, node.v));
            pr[to] = node.v;
          }
        }
        node.start = false;
      } else {
        tout[node.v] = timer++;
        stack.removeLast();
      }
    }
  }

  /** Returns whether x is an ancestor of y; every vertex counts as its own ancestor. */
  static boolean upper(int x, int y) {
    return tout[x] >= tout[y] && tin[x] <= tin[y];
  }

  /** Returns the lowest common ancestor of a and b using the binary-lifting table. */
  static int lca(int a, int b) {
    if (upper(a, b)) {
      return a;
    }
    if (upper(b, a)) {
      return b;
    }
    for (int i = log; i >= 0; --i) {
      if (!upper(up[a][i], b)) {
        a = up[a][i];
      }
    }
    return up[a][0];
  }

  /**
   * Replaces a[1..] by dense ids 0, 1, 2, ... in order of first appearance, so the ids can
   * index a counting array of size a.length. a[0] is left untouched.
   */
  static void normalize(int[] a) {
    Map<Integer, Integer> trans = new HashMap<>();
    int j = 0;
    for (int i = 1; i < a.length; i++) {
      if (!trans.containsKey(a[i])) {
        trans.put(a[i], j++);
      }
    }
    for (int i = 1; i < a.length; i++) {
      a[i] = trans.get(a[i]);
    }
  }

  /** Returns the number of vertices shared by path(x1, y1) and path(x2, y2). */
  static int sharedVertices(int x1, int y1, int x2, int y2) {
    int lc1 = lca(x1, y1);
    int lc2 = lca(x2, y2);
    int lc3 = lca(x1, x2);
    int lc4 = lca(x1, y2);
    int lc5 = lca(y1, x2);
    int lc6 = lca(y1, y2);
    int shared = 0;
    if (lc1 == lc2) {
      // Same top vertex: the top plus, for every pair of branches, their common descent.
      shared = 1
          + (w[lc3] - w[lc1])
          + (w[lc4] - w[lc1])
          + (w[lc5] - w[lc1])
          + (w[lc6] - w[lc1]);
    } else if (w[lc1] < w[lc2]) {
      // Path 2 can only meet path 1 if lc2 lies on one of path 1's branches.
      if (upper(lc2, x1) && upper(lc1, lc2)) {
        shared += Math.abs(w[lc3] - w[lc4]) + 1;
      }
      if (upper(lc2, y1) && upper(lc1, lc2)) {
        shared += Math.abs(w[lc5] - w[lc6]) + 1;
      }
    } else if (w[lc1] > w[lc2]) {
      // Symmetric case: lc1 must lie on one of path 2's branches.
      if (upper(lc1, x2) && upper(lc2, lc1)) {
        shared += Math.abs(w[lc3] - w[lc5]) + 1;
      }
      if (upper(lc1, y2) && upper(lc2, lc1)) {
        shared += Math.abs(w[lc4] - w[lc6]) + 1;
      }
    }
    // Different tops at equal depth cannot share a vertex.
    return shared;
  }

  /**
   * Returns the number of ordered pairs (u, v) with u on path(x1, y1), v on path(x2, y2) and
   * a[u] == a[v], including pairs with u == v. {@code cnt} is a per-colour scratch array that
   * must be all zeros on entry; it is zeroed again before returning.
   */
  static long sameColorPairs(int[] a, int[] cnt, int x1, int y1, int x2, int y2) {
    // Tally path(x1, y1): climb x1 until it is an ancestor of y1 (their LCA), then climb y1.
    int top1 = x1;
    while (!upper(top1, y1)) {
      cnt[a[top1]]++;
      top1 = pr[top1];
    }
    for (int v = y1; v != top1; v = pr[v]) {
      cnt[a[v]]++;
    }
    cnt[a[top1]]++;

    // Sum the tallies over path(x2, y2) the same way; long because this can exceed 2^31.
    long pairs = 0;
    int top2 = x2;
    while (!upper(top2, y2)) {
      pairs += cnt[a[top2]];
      top2 = pr[top2];
    }
    for (int v = y2; v != top2; v = pr[v]) {
      pairs += cnt[a[v]];
    }
    pairs += cnt[a[top2]];

    // Undo the tally so cnt can be reused by the next query.
    for (int v = x1; v != top1; v = pr[v]) {
      cnt[a[v]] = 0;
    }
    for (int v = y1; v != top1; v = pr[v]) {
      cnt[a[v]] = 0;
    }
    cnt[a[top1]] = 0;
    return pairs;
  }

  /**
   * Reads one test and writes one answer per query line to {@code out}.
   *
   * <p>Input format: "n m", n colours, n - 1 edges, then m queries "x1 y1 x2 y2". Each answer
   * is the number of ordered pairs (u, v) with u on path(x1, y1), v on path(x2, y2), u != v
   * and equal colours.
   *
   * @throws IOException if reading {@code br} or writing {@code out} fails
   */
  static void solve(BufferedReader br, Writer out) throws IOException {
    StringTokenizer st = new StringTokenizer(br.readLine());
    int n = Integer.parseInt(st.nextToken());
    int m = Integer.parseInt(st.nextToken());

    int[] a = new int[n + 1];
    st = new StringTokenizer(br.readLine());
    for (int i = 1; i <= n; i++) {
      a[i] = Integer.parseInt(st.nextToken());
    }
    normalize(a);

    // Edges are undirected and the input does not list them parent-first, so each edge is
    // stored in both directions: 2 * (n - 1) slots plus the unused slot 0.
    nxt = new int[2 * n];
    succ = new int[2 * n];
    ptr = new int[n + 1];
    index = 1;
    for (int i = 0; i < n - 1; i++) {
      st = new StringTokenizer(br.readLine());
      int u = Integer.parseInt(st.nextToken());
      int v = Integer.parseInt(st.nextToken());
      addEdge(u, v);
      addEdge(v, u);
    }

    // 2^log > n - 1 >= every depth, so levels 0..log cover any ancestor jump.
    log = Math.max(1, 32 - Integer.numberOfLeadingZeros(n));
    tin = new int[n + 1];
    tout = new int[n + 1];
    pr = new int[n + 1];
    up = new int[n + 1][log + 1];
    w = new int[n + 1];
    timer = 0;
    dfs();

    int[] colorCount = new int[n + 1];
    for (int i = 0; i < m; i++) {
      st = new StringTokenizer(br.readLine());
      int x1 = Integer.parseInt(st.nextToken());
      int y1 = Integer.parseInt(st.nextToken());
      int x2 = Integer.parseInt(st.nextToken());
      int y2 = Integer.parseInt(st.nextToken());
      // Pairs (v, v) on both paths are counted once per shared vertex, so subtract those.
      long pairs = sameColorPairs(a, colorCount, x1, y1, x2, y2);
      out.write((pairs - sharedVertices(x1, y1, x2, y2)) + "\n");
    }
  }

  public static void main(String[] args) throws IOException {
    try (BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")))) {
      solve(br, bw);
      // The output ends with one extra blank line; kept for output compatibility.
      bw.newLine();
    }
  }
}


Problem solution in C++14 programming.

#include <cstdlib>
#include <cstdio>
#include <iostream>
#include <cmath>
#include <algorithm>
#include <vector>
#include <set>
#include <map>
#include <cstring>
#include <cassert>

using namespace std;

// Shorthand 64-bit integer types.
typedef long long LL;
typedef unsigned long long ULL;

// Contest-style loop helpers:
//   SIZE(x)      container size as int
//   rep(i,l,r)   i = l, l+1, ..., r   (inclusive, ascending)
//   repd(i,r,l)  i = r, r-1, ..., l   (inclusive, descending)
//   rept(i,c)    iterator loop over container c (uses the GCC __typeof extension)
#define SIZE(x) (int((x).size()))
#define rep(i,l,r) for (int i=(l); i<=(r); i++)
#define repd(i,r,l) for (int i=(r); i>=(l); i--)
#define rept(i,c) for (__typeof((c).begin()) i=(c).begin(); i!=(c).end(); i++)

// debug(x) prints "x = <value>" to stderr, and expands to nothing when ONLINE_JUDGE is defined.
#ifndef ONLINE_JUDGE
#define debug(x) { cerr<<#x<<" = "<<(x)<<endl; }
#else
#define debug(x) {}
#endif

#define maxn 100010
#define LIM 100

// Fenwick (binary indexed) tree over positions 1..maxn-1. Through ds_modify / ds_query it
// acts as a "add to a range, read one position" structure.
int ta[maxn];

// Adds y to the stored difference at position x (x must be >= 1).
void ta_modify(int x, int y)
{
    for (int pos=x; pos<maxn; pos+=pos&-pos)
        ta[pos]+=y;
}

// Returns the sum of the stored differences at positions 1..x.
int ta_query(int x)
{
    int total=0;
    for (int pos=x; pos!=0; pos-=pos&-pos)
        total+=ta[pos];
    return total;
}

// Adds c to every position in the closed range [l, r].
void ds_modify(int l, int r, int c)
{
    int start=l, pastEnd=r+1;
    ta_modify(start,c);
    ta_modify(pastEnd,-c);
}

// Returns the current value at position v.
int ds_query(int v)
{
    int value=ta_query(v);
    return value;
}

int dfsN;
int dfsLeft[maxn], dfsRight[maxn];
int lg2[maxn], p[maxn][17], depth[maxn];
vector<int> e[maxn];

// Pre-order DFS from `cur` (parent `pre`, depth `dep`) over the adjacency lists e[].
// For every visited vertex v it records:
//   dfsLeft[v]   1-based entry index
//   dfsRight[v]  largest entry index inside v's subtree
//   depth[v]     distance from the start vertex (plus dep)
//   p[v][i]      2^i-th ancestor of v, for 2^i <= depth[v]
// lg2[] must already be filled by the caller.
// An explicit stack replaces recursion, so a path-shaped tree with ~1e5 vertices cannot
// exhaust the call stack. Children are still visited in e[] order, so every recorded value
// matches a recursive pre-order traversal.
void dfs(int cur, int pre, int dep)
{
    struct Frame { int v, parent; size_t next; };
    vector<Frame> frames;

    auto enter = [&frames](int v, int parent, int d) {
        dfsN++; dfsLeft[v]=dfsN;
        depth[v]=d;
        p[v][0]=parent;
        for (int i=1; i<=lg2[d]; i++) p[v][i]=p[p[v][i-1]][i-1];
        frames.push_back({v, parent, 0});
    };

    enter(cur, pre, dep);
    while (!frames.empty())
    {
        Frame &top=frames.back();
        if (top.next<e[top.v].size())
        {
            // Read everything needed from `top` before enter() may reallocate `frames`.
            int to=e[top.v][top.next++];
            int from=top.v;
            if (to!=top.parent) enter(to, from, depth[from]+1);
        }
        else
        {
            dfsRight[top.v]=dfsN;
            frames.pop_back();
        }
    }
}

// Returns the ancestor of x that lies y levels higher, jumping once per set bit of y via the
// p[][] table. A negative y yields 0 (the "no vertex" marker).
int movedep(int x, int y)
{
    if (y<0) return 0;
    int v=x;
    for (int remaining=y; remaining!=0; )
    {
        int lowBit=remaining&-remaining;
        v=p[v][lg2[lowBit]];
        remaining^=lowBit;
    }
    return v;
}

// Lowest common ancestor of x and y. The deeper vertex is lifted to the same depth, then
// both are lifted together while their 2^i-th ancestors differ.
int lca(int x, int y)
{
    int a=x, b=y;
    if (depth[a]<depth[b]) swap(a,b);
    a=movedep(a,depth[a]-depth[b]);
    if (a==b) return a;
    for (int level=16; level>=0; level--)
    {
        if (p[a][level]==p[b][level]) continue;
        a=p[a][level];
        b=p[b][level];
    }
    return p[a][0];
}

// Number of vertices on the tree path between x and y, both endpoints included.
int get_dist(int x, int y)
{
    int top=lca(x,y);
    int edges=(depth[x]-depth[top])+(depth[y]-depth[top]);
    return edges+1;
}

// Overlap segments recorded for the current query: entries 1..all of ti hold the two
// endpoints of each recorded segment.
int all, ti[5][2];

// p1 -> p2 and q1 -> q2 are vertical paths (callers pass an LCA as p2 / q2). If the
// shallower-topped path contains the other path's top, the overlap runs from
// lca(p1, q1) up to that top. It is appended to ti unless the same segment, in either
// orientation, is already there.
void check_intersect(int p1, int p2, int q1, int q2)
{
    if (depth[p2]>depth[q2])
    {
        swap(p1,q1);
        swap(p2,q2);
    }

    // q2 must lie on the path p1 -> p2.
    bool topOnPath = depth[p1]>=depth[q2] && lca(p1,q2)==q2 && lca(q2,p2)==p2;
    if (!topOnPath) return;

    int low=lca(p1,q1);
    for (int k=1; k<=all; k++)
    {
        bool same = ti[k][0]==low && ti[k][1]==q2;
        bool flipped = ti[k][1]==low && ti[k][0]==q2;
        if (same || flipped) return;
    }

    all++;
    ti[all][0]=low;
    ti[all][1]=q2;
}

// One sweep event. For add events, [x, y] is a range and c its weight (+1 / -1). For query
// events, x is the point, y the query index and c the sign of the contribution.
struct tasktype
{
    int x, y, c;
    tasktype() {}
    tasktype(int px, int py, int pc) : x(px), y(py), c(pc) {}
};

vector<tasktype> eventAddList[maxn], eventQueryList[maxn];

// Schedules a point query at sweep position dfsLeft[p1]: the sweep later adds
// c * ds_query(dfsLeft[q1]) to ans[i].
void addQueryEvent(int i, int p1, int q1, int c)
{
    int sweepPos=dfsLeft[p1];
    int point=dfsLeft[q1];
    eventQueryList[sweepPos].push_back(tasktype(point,i,c));
}

// Keeps range [q1, q2] active (+1) while the sweep position is inside [p1, p2].
void addContributionEvent(int p1, int p2, int q1, int q2)
{
    eventAddList[p1].push_back(tasktype(q1,q2,1));
    eventAddList[p2+1].push_back(tasktype(q1,q2,-1));
}

// For the vertical paths p1 -> p2 and q1 -> q2, schedules four inclusion-exclusion point
// queries. Each query is taken at the bottom endpoints or at the parents of the tops, and a
// parent query is skipped when that parent is 0. Does nothing if any endpoint is 0.
void add_task(int i, int p1, int p2, int q1, int q2)
{
    if (p1==0 || p2==0 || q1==0 || q2==0) return;
    int pAbove=p[p2][0];
    int qAbove=p[q2][0];
    addQueryEvent(i,p1,q1,1);
    if (qAbove) addQueryEvent(i,p1,qAbove,-1);
    if (pAbove) addQueryEvent(i,pAbove,q1,-1);
    if (pAbove && qAbove) addQueryEvent(i,pAbove,qAbove,1);
}

// Query state:
//   clist[c]   vertices of colour c
//   q[i][0..3] the endpoints of query i's two paths
//   q[i][4..5] the LCAs of those paths
//   ans[i]     running answer for query i
// ans is 64-bit because one query can pair up to ~1e5 x ~1e5 equally coloured vertices,
// which does not fit in an int.
map<int, vector<int> > clist;
int color[maxn];
int q[maxn][6];
long long ans[maxn];

// Reads the tree and the queries, then prints, for every query, the number of ordered pairs
// (u, v) with u on path q[i][0]-q[i][1], v on path q[i][2]-q[i][3], u != v and equal
// colours.
//   * Vertices shared by both paths are subtracted first (check_intersect / get_dist), since
//     both counting methods below also count u == v.
//   * Colours with more than LIM vertices: the Fenwick tree marks the subtree range of every
//     vertex of that colour, so ds_query(dfsLeft[v]) counts that colour on the root -> v
//     path. A query's contribution is the product of the two path counts.
//   * Colours with at most LIM vertices: every ordered pair (i1, j1) of that colour becomes a
//     rectangle subtree(i1) x subtree(j1). An offline sweep over dfsLeft answers the point
//     queries scheduled by add_task.
void lemon()
{
    // lg2[i] = floor(log2(i)), used by dfs and movedep.
    lg2[1]=0; rep(i,2,maxn-1) lg2[i]=lg2[i>>1]+1;
    int n,qa; scanf("%d%d",&n,&qa);
    rep(i,1,n)
    {
        scanf("%d",&color[i]);
        clist[color[i]].push_back(i);
    }
    rep(i,1,n-1)
    {
        int x,y; scanf("%d%d",&x,&y);
        e[x].push_back(y);
        e[y].push_back(x);
    }
    rep(i,1,qa)
    {
        scanf("%d%d%d%d",&q[i][0],&q[i][1],&q[i][2],&q[i][3]);
    }
    dfsN=0;
    dfs(1,0,0);
    rep(i,1,qa) ans[i]=0;
    rep(i,1,qa)
    {
        all=0;
        int z1=lca(q[i][0],q[i][1]);
        int z2=lca(q[i][2],q[i][3]);

        q[i][4]=z1; q[i][5]=z2;

        // Record the overlap of each half of path 1 with each half of path 2.
        check_intersect(q[i][0],z1,q[i][2],z2);
        check_intersect(q[i][0],z1,q[i][3],z2);
        check_intersect(q[i][1],z1,q[i][2],z2);
        check_intersect(q[i][1],z1,q[i][3],z2);

        // t1 / t2: child of the LCA on the way to q[i][1] / q[i][3]
        // (0 when that endpoint is the LCA itself, in which case add_task ignores it).
        int t1=movedep(q[i][1],depth[q[i][1]]-depth[z1]-1);
        int t2=movedep(q[i][3],depth[q[i][3]]-depth[z2]-1);

        // Each path = (endpoint 0 -> LCA) + (endpoint 1 -> child of LCA), so the light-colour
        // pair count splits into four vertical-path x vertical-path tasks.
        add_task(i,q[i][0],z1,q[i][2],z2);
        add_task(i,q[i][0],z1,q[i][3],t2);
        add_task(i,q[i][1],t1,q[i][2],z2);
        add_task(i,q[i][1],t1,q[i][3],t2);

        if (all>0)
        {
            rep(k,1,all)
                ans[i]-=get_dist(ti[k][0],ti[k][1]);

            // NOTE(review): presumably the recorded segments meet pairwise in one vertex,
            // which the subtraction above removed once too often -- confirm.
            ans[i]+=all-1;
        }
    }

    rept(it,clist)
    {
        int cl=it->first;
        if (it->second.size()<=LIM)
        {
            int s=it->second.size();
            rep(i,0,s-1)
                rep(j,0,s-1)
                {
                    int i1=it->second[i], j1=it->second[j];
                    addContributionEvent(dfsLeft[i1], dfsRight[i1], dfsLeft[j1], dfsRight[j1]);
                }
        }
        else
        {
            rept(it2,it->second)
                ds_modify(dfsLeft[*it2],dfsRight[*it2],1);

            rep(i,1,qa)
            {
                // Vertices of colour cl on each path: root-path counts at both endpoints,
                // minus twice the LCA's, plus the LCA itself if it has colour cl.
                int x1=ds_query(dfsLeft[q[i][0]])+ds_query(dfsLeft[q[i][1]])-2*ds_query(dfsLeft[q[i][4]]);
                if (color[q[i][4]]==cl) x1++;
                int x2=ds_query(dfsLeft[q[i][2]])+ds_query(dfsLeft[q[i][3]])-2*ds_query(dfsLeft[q[i][5]]);
                if (color[q[i][5]]==cl) x2++;
                // 64-bit product: x1 and x2 can each be ~1e5.
                ans[i]+=(long long)x1*x2;
            }

            rept(it2,it->second)
                ds_modify(dfsLeft[*it2],dfsRight[*it2],-1);
        }
    }

    // Offline sweep over dfsLeft positions: activate / deactivate rectangles, then answer the
    // point queries scheduled at this position.
    rep(i,1,n)
    {
        rept(it,eventAddList[i]) ds_modify(it->x,it->y,it->c);
        rept(it,eventQueryList[i]) ans[it->y]+=(long long)it->c*ds_query(it->x);
    }

    rep(i,1,qa) printf("%lld\n",ans[i]);
}

// Program entry: all work, including reading stdin and printing answers, happens in lemon().
int main()
{
    // Keep iostreams synchronised with C stdio (the default); the solver uses scanf/printf.
    ios::sync_with_stdio(true);
    lemon();
    return 0;
}


Problem solution in C programming.

#include <stdio.h>
#include <string.h>
#include <math.h>
#include <stdlib.h>


/* floor(log2(self)) for a nonzero value, assuming a 32-bit unsigned int.
 * __builtin_clz counts the leading zero bits, and XOR-ing that count with 31 gives the
 * index of the highest set bit. Per the GCC documentation __builtin_clz(0) is undefined.
 * NOTE(review): main calls this with vertex_cnt, which is assumed to be >= 1. */
#define floor_log2_X86(self) (__builtin_clz(self) ^ 31U)
#define floor_log2 floor_log2_X86

/* In-place heapsort of the index array self[0..length-1] into ascending order of
 * weights[self[i]]. It builds a max-heap, then repeatedly moves the maximum to the end.
 *
 * self is decremented once so the heap can use 1-based positions: the children of
 * position k are 2k and 2k + 1.
 * NOTE(review): forming self - 1 points before the array, which ISO C does not
 * strictly allow. */
void heap_sort(unsigned *self, unsigned *weights, unsigned length) {
    unsigned
        at = length >> 1,
        member,
        node;

    /* Phase 1: build the max-heap bottom-up. `member` is sifted down from position `at`,
     * and the for-increment drops it into the final hole at node >> 1. */
    for (self--; at; self[node >> 1] = member) {
        member = self[at];

        for (node = at-- << 1; node <= length; node <<= 1) {
            /* Move to the heavier child (node | 1 is the right sibling, when it exists). */
            node |= (node < length) && (weights[self[node]] < weights[self[node | 1]]);
            if (weights[self[node]] < weights[member])
                break ;
            self[node >> 1] = self[node];
        }
    }
    /* Phase 2: move the root (current maximum) to the last heap slot, shrink the heap and
     * sift the displaced element down from the root. */
    for (; length > 1; self[at >> 1] = member) {
        member = self[length];
        self[length--] = self[1];

        for (at = 2; at <= length; at <<= 1) {
            at |= (at < length) && (weights[self[at]] < weights[self[at | 1]]);
            if (weights[self[at]] < weights[member])
                break ;
            self[at >> 1] = self[at];
        }
    }
}

/* Replaces values[0..length-1] in place by dense ids 0, 1, 2, ...
 * Equal inputs get the same id. Ids are handed out by descending number of occurrences,
 * so the most frequent value gets 0. Values that occur equally often are ordered by
 * ascending original value.
 *
 * NOTE(review): order[] is initialised two 32-bit entries at a time through an
 * (unsigned long *) cast. This assumes a 64-bit little-endian unsigned long and
 * sidesteps strict-aliasing/alignment rules -- confirm for the target compiler. */
void compress(unsigned length, unsigned values[length]) {
    unsigned
        at,
        order[length];

    /* order[i] = i (identity permutation), written two entries per step; the last entry
     * is set separately so odd lengths are covered. */
    unsigned long sum = 0x0000000100000000UL;
    for (at = 0; at < (length >> 1); sum += 0x0000000200000002UL)
        ((unsigned long *)order)[at++] = sum;
    order[length - 1] = length - 1;

    /* Sort the positions by value, so equal values form consecutive runs in order[]. */
    heap_sort(order, values, length);

    /* roots[k] = last index in order[] of the k-th run of equal values (roots[0] = -1 is a
     * sentinel); seen - 1 = number of runs; max = length of the longest run. */
    unsigned roots[length], seen = 1, max = 0, others;
    for (roots[at = 0] = -1U; at < length; roots[seen++] = at - 1) {
        for (others = at; (at < length) && values[order[at]] == values[order[others]]; at++);

        if (max < (at - others))
            max = (at - others);
    }

    unsigned
        indices[max + 1],
        ranks[seen];

    /* Counting sort of the runs by size, largest first.
     *   indices[s]  first counts runs of size s, then (after the suffix sum) runs of size >= s.
     *   ranks[r]    receives the run that gets new id r. */
    memset(indices, 0, sizeof(indices));
    for (at = 0; ++at < seen; indices[roots[at] - roots[at - 1]]++);
    for (at = max; at--; indices[at] += indices[at + 1]);
    for (at = seen; --at; ranks[--indices[roots[at] - roots[at - 1]]] = at);

    /* Write each run's new id back to every position the run covers. */
    for (; at < (seen - 1); at++)
        for (others = roots[ranks[at] - 1]; ++others <= roots[ranks[at]]; values[order[others]] = at);
}


/* Returns the common ancestor of vertices `lower` and `upper` that main uses as the top of
 * a query path (intended to be their lowest common ancestor).
 *
 * Works on main's renumbered tree: ids are preorder positions, so v's subtree is the id
 * range [v, v + weights[v]) and ancestors have smaller ids than their descendants.
 *   depth     row length of bases[] (main passes dist)
 *   base_cnt  number of rows of bases[] (heavy chains)
 *   base_ids  heavy-chain index of each vertex
 *   bases     bases[c][k] = vertex reached from chain c after k + 1 jumps, each jump going
 *             from the current chain's head to that head's parent (main's ancestral_bases)
 *   depths    chain depth of each vertex. It is indexed by vertex id even though it is
 *             declared with base_cnt entries (the parameter decays to a pointer).
 *   weights   subtree sizes
 *
 * NOTE(review): assumes lower <= upper; main orders each path's endpoints that way. */
static inline unsigned nearest_common_ancestor(
    unsigned depth,
    unsigned base_cnt,
    unsigned vertex_cnt,
    unsigned base_ids[vertex_cnt],
    unsigned bases[base_cnt][depth],
    unsigned char depths[base_cnt],
    unsigned weights[vertex_cnt],
    unsigned lower,
    unsigned upper
) {
    /* upper lies in lower's subtree, so lower is the answer. */
    if (upper < (lower + weights[lower]))
        return lower;

    /* Jump upper up to the chain depth of lower (column k means k + 1 jumps). */
    if (depths[upper] > depths[lower])
        upper = bases[base_ids[upper]][depths[upper] - depths[lower] - 1];

    /* NOTE(review): a smaller id here is returned directly -- presumably it is then an
     * ancestor of lower; confirm. */
    if (upper < lower)
        return upper;

    /* Binary search along the jump list of upper's chain (ids decrease with every jump) for
     * the first entry not greater than lower; `depth` is reused as the remaining length. */
    unsigned *others = bases[base_ids[upper]];
    for (; depth > 1; depth >>= 1)
        if (others[depth >> 1] > lower) {
            others += depth >> 1;
            depth += depth & 1U;
        }

    return others[others[0] > lower];
}

/* A closed range [low, high] of renumbered vertex ids. It can also be viewed as one 64-bit
 * word (packd), so main and count_pairs can swap or adjust both bounds in one operation.
 * NOTE(review): that packd arithmetic assumes unsigned long is 64 bits and that `low` is
 * the lower half (little-endian); confirm for the target. */
typedef union {
    unsigned long packd;
    struct {
        int low, high;
    };
} range_t;

/* Vertices grouped by colour (all arrays live in main):
 *   colors[v]    colour id of vertex v
 *   indices[c]   start of colour c's group in members[]; indices[c + 1] is its end
 *   members[]    vertex ids grouped by colour, ascending within each group
 *   locations[v] position of vertex v inside members[] */
typedef struct {
    unsigned
        *members,
        *colors,
        *indices,
        *locations;
} colored_tree_t;


/* Counts the vertices that have the same colour as vertex `at` and whose ids lie in
 * [other.low, other.high], not counting `at` itself.
 *
 * Inside at's colour group in members[] (ascending), base becomes the first member
 * >= other.low and ceil the last member <= other.high; the answer is their distance + 1.
 * The colour checks on other.low / other.high and the use of at's own position only
 * narrow or skip the binary searches.
 * NOTE(review): range_t's fields are int while members[] is unsigned, so comparisons such
 * as other.high < base[0] are done in unsigned arithmetic. */
unsigned long query_all(colored_tree_t *self, unsigned at, range_t other) {
    unsigned
        color = self->colors[at],
        length = self->indices[color + 1] - self->indices[color],
        *base = &self->members[self->indices[color]];

    /* The whole colour group lies outside the range. */
    if (other.high < base[0] || other.low > base[length - 1])
        return 0;

    /* base := first member >= other.low. If other.low has this colour it is that member;
     * otherwise binary-search only the part of the group on other.low's side of `at`. */
    if (self->colors[other.low] != color) {
        if (at < other.low) {
            base += self->locations[at] - self->indices[color];
            length = self->indices[color + 1] - self->locations[at];
        } else
            length = self->locations[at] - self->indices[color]; // at > other.low

        for (; length > 1; length >>= 1)
            if (base[length >> 1] < other.low) {
                base += length >> 1;
                length += length & 1;
            }

        base += (base[0] < other.low);
    } else
        base += (self->locations[other.low] - self->indices[color]);

    /* No member of this colour inside the range. */
    if (base[0] > other.high)
        return 0;

    /* ceil := last member <= other.high. If other.high has this colour it is that member;
     * otherwise binary-search from at's position (when at lies between base[0] and
     * other.high) or from base. */
    unsigned *ceil;
    if (self->colors[other.high] != color) {
        ceil = (at > base[0] && at < other.high) ? &self->members[self->locations[at]] : base;

        for (length = self->indices[color + 1] - self->locations[ceil[0]]; length > 1; length >>= 1)
            if (ceil[length >> 1] <= other.high) {
                ceil += length >> 1;
                length += length & 1;
            }

        ceil -= (ceil[0] > other.high);
    } else
        ceil = &self->members[self->locations[other.high]];


    /* Members from base to ceil inclusive, minus `at` itself when it lies in the range. */
    return ceil - base + 1 - (at >= other.low && at <= other.high);
}

/* Returns pairs[row][col] of the 2-D prefix table, treating row -1 and column -1 as zero.
 * The check is done here rather than relying on padding around the table. */
static inline unsigned long pairs_prefix(unsigned cnt, unsigned long pairs[cnt][cnt], long row, long col) {
    return (row < 0 || col < 0) ? 0UL : pairs[row][col];
}

/* Returns overlapping[at], treating index -1 as zero (the array has no padding). */
static inline unsigned long overlapping_prefix(unsigned *overlapping, long at) {
    return (at < 0) ? 0UL : overlapping[at];
}

/* Counts ordered pairs (u, v) of same-coloured vertices with u in [self.low, self.high],
 * v in [other.low, other.high] and u != v.
 *
 * Ids that are not aligned to a block of `length` ids are peeled off both ends and counted
 * one vertex at a time with query_all(). What remains on each side is a run of whole blocks,
 * answered from tables precomputed by main:
 *   pairs[i][j]     2-D prefix sum over block pairs (bi <= i, bj <= j) of cross-block
 *                   same-colour pair counts, filled for bi < bj only
 *   overlapping[i]  prefix sum of ordered same-colour pairs inside single blocks
 *   cnt             number of blocks (row length of pairs)
 *   length          ids per block
 * All prefix lookups go through pairs_prefix / overlapping_prefix, so index -1 reads 0.
 * Both ranges may start at block 0, which makes "the part of self before other" empty with
 * last block -1. That index is held in a signed long; as an unsigned value it would index
 * far out of bounds. */
unsigned long count_pairs(
    unsigned cnt,
    unsigned length,
    unsigned long pairs[cnt][cnt],
    unsigned *overlapping,
    colored_tree_t *tree,
    range_t self,
    range_t other
) {
    unsigned long count = 0;

    /* Peel unaligned ids off both ends of self, counting each against all of other. */
    for (; (self.low % length) && (self.low <= self.high); count += query_all(tree, self.low++, other));
    for (; ((self.high + 1) % length) && (self.low <= self.high); count += query_all(tree, self.high--, other));

    if (self.low <= self.high) {
        /* Same for other, counting each peeled id against the (now aligned) self. */
        for (; (other.low % length) && (other.low <= other.high); count += query_all(tree, other.low++, self));
        for (; ((other.high + 1) % length) && (other.low <= other.high); count += query_all(tree, other.high--, self));

        if (other.low <= other.high) {
            /* Both ranges now consist of whole blocks: switch to block indices. */
            self.low   /= length;
            self.high  /= length;
            other.low  /= length;
            other.high /= length;

            /* Make self the range that starts first. */
            if (self.low > other.low) {
                self.packd  ^= other.packd;
                other.packd ^= self.packd;
                self.packd  ^= other.packd;
            }

            /* Last block of self that lies strictly before other; may be self.low - 1 (an
             * empty part), which is -1 when both ranges start at block 0. */
            long high = (self.high < other.low) ? (long)self.high : (long)other.low - 1L;

            count +=
                pairs_prefix(cnt, pairs, high, other.high)
                    - pairs_prefix(cnt, pairs, high, other.low - 1L)
                    - pairs_prefix(cnt, pairs, self.low - 1L, other.high)
                    + pairs_prefix(cnt, pairs, self.low - 1L, other.low - 1L);

            self.low = (int)(high + 1);

            /* Make self the range that ends first; if the overlap is non-empty both ranges
             * now start at the same block. */
            if (self.high > other.high) {
                self.packd  ^= other.packd;
                other.packd ^= self.packd;
                self.packd  ^= other.packd;
            }

            /* Overlap [self.low, self.high] contributes three terms:
             *   - ordered pairs inside single blocks;
             *   - cross-block pairs inside the overlap, stored once per unordered block pair,
             *     hence << 1;
             *   - pairs between the overlap and the rest of the longer range. */
            if (self.low <= self.high)
                count +=
                    (overlapping_prefix(overlapping, self.high) - overlapping_prefix(overlapping, self.low - 1L))
                        + ((
                        pairs_prefix(cnt, pairs, self.high, self.high)
                            - pairs_prefix(cnt, pairs, self.high, self.low - 1L)
                            - pairs_prefix(cnt, pairs, self.low - 1L, self.high)
                            + pairs_prefix(cnt, pairs, self.low - 1L, self.low - 1L)
                    ) << 1) + (
                        pairs_prefix(cnt, pairs, self.high, other.high)
                            - pairs_prefix(cnt, pairs, self.high, self.high)
                            - pairs_prefix(cnt, pairs, self.low - 1L, other.high)
                            + pairs_prefix(cnt, pairs, self.low - 1L, self.high)
                    );
        }
    }

    return count;
}

/* Reads the coloured tree and the queries from stdin and prints one pair count per query.
 *
 * Outline:
 *   1. Compress the colours (most frequent colour gets id 0). Read the edges into a parent
 *      array rooted at vertex 0; input vertices are 1-based.
 *   2. Compute subtree sizes and renumber the vertices in heavy-child-first preorder, so
 *      every subtree and every heavy chain is a contiguous id range.
 *   3. Split the ids into blocks of `levels` ids and precompute block-level same-colour
 *      pair counts (pairs, overlapping) for count_pairs().
 *   4. For each query, split both paths into id ranges along heavy chains and add
 *      count_pairs() over every (left range, right range) combination.
 *
 * NOTE(review): every per-vertex array is a VLA on the stack (many arrays of vertex_cnt
 * entries); large inputs may need a raised stack limit -- confirm on the judge. */
int main() {
    unsigned at, vertex_cnt;
    unsigned short query_cnt;
    scanf("%u %hu", &vertex_cnt, &query_cnt);

    /* Colours plus a sentinel entry 0xFFFFFFFF at index vertex_cnt. Assuming no input colour
     * equals 0xFFFFFFFF, the sentinel occurs once and is the largest value, so compress()
     * gives it the largest colour id. */
    unsigned colors[vertex_cnt + 1];
    for (at = 0; at < vertex_cnt; scanf("%u", &colors[at++]));
    colors[at] = 0xFFFFFFFFU;
    compress(at + 1, colors);

    /* Parent array over 0-based vertices; 0xFFFFFFFF means "no parent". Edges arrive
     * unoriented. When `descendant` already has a parent, the parent chains of both
     * endpoints are reversed so the two components merge into one rooted tree before the
     * edge is stored. */
    unsigned ancestors[at + 1];
    {
        unsigned ancestor, descendant;
        for (memset(ancestors, 0xFFU, sizeof(ancestors)); --at; ancestors[descendant] = ancestor) {
            scanf("%u %u", &ancestor, &descendant);
            --ancestor;
            if (ancestors[--descendant] != 0xFFFFFFFFU) {
                unsigned root = descendant, next;
                for (; ancestor != 0xFFFFFFFFU; ancestor = next) {
                    next = ancestors[ancestor];
                    ancestors[ancestor] = root;
                    root = ancestor;
                }
                for (; ancestors[descendant] != 0xFFFFFFFFU; descendant = next) {
                    next = ancestors[descendant];
                    ancestors[descendant] = ancestor;
                    ancestor = descendant;
                }
            }
        }

        /* at is 0 here: reverse the chain from vertex 0 to its root so vertex 0 becomes
         * the root. */
        for (ancestor = 0xFFFFFFFFU; at != 0xFFFFFFFFU; at = descendant) {
            descendant = ancestors[at];
            ancestors[at] = ancestor;
            ancestor = at;
        }
    }

    /* NOTE(review): this outer `history` appears unused; the block below declares its own
     * `history`, which shadows it. */
    unsigned
        others,
        ids[vertex_cnt + 1],
        weights[vertex_cnt],
        bases[vertex_cnt + 1],
        history[vertex_cnt];

    unsigned char
        base_depths[vertex_cnt],
        dist = 0;

    {
        unsigned
            history[vertex_cnt],
            indices[vertex_cnt + 1],
            descendants[vertex_cnt];

        /* Children lists by counting sort: descendants[indices[v] .. indices[v + 1]) are v's
         * children. The sentinel vertex_cnt is first made its own parent. */
        memset(indices, 0, sizeof(indices));
        for (ancestors[vertex_cnt] = (at = vertex_cnt); at; indices[ancestors[at--]]++);
        for (; ++at <= vertex_cnt; indices[at] += indices[at - 1]);
        for (; --at; descendants[--indices[ancestors[at]]] = at);

        /* Subtree sizes with an explicit stack. A vertex seen with weight 0 gets weight 1 and
         * its children are pushed above it. When it is seen again (weight already nonzero),
         * its children's weights are added. */
        history[0] = 0;
        memset(weights, 0, sizeof(weights));
        for (at = 1; at--; )
            if (weights[history[at]])
                for (others = indices[history[at]];
                     others < indices[history[at] + 1];
                     weights[history[at]] += weights[descendants[others++]]);
            else {
                weights[history[at]] = 1;
                memcpy(
                    &history[at + 1],
                    &descendants[indices[history[at]]],
                    (indices[history[at] + 1] - indices[history[at]]) * sizeof(descendants[0])
                );
                at += indices[history[at] + 1] - indices[history[at]] + 1;
            }

        unsigned
            orig_ancestors[vertex_cnt + 1],
            orig_colors[vertex_cnt + 1],
            orig_weights[vertex_cnt];

        memcpy(orig_ancestors, ancestors, sizeof(ancestors));
        memcpy(orig_weights, weights, sizeof(weights));
        memcpy(orig_colors, colors, sizeof(colors));

        /* Renumber in preorder with the heaviest child first (children are sorted ascending
         * by subtree size and taken from the end).
         *   ids[v]          new id of original vertex v
         *   ancestors, weights, colors   rewritten in new-id order
         *   bases[id]       first id of id's heavy chain (the heaviest child inherits its
         *                   parent's chain; every other child starts a new one)
         *   base_depths[id] number of chain changes from the root
         *   dist            ends up as the largest base_depths value */
        base_depths[0] = (bases[0] = (ids[0] = 0));
        bases[vertex_cnt] = (ids[vertex_cnt] = vertex_cnt);
        for (at = 1; at--;) {
            unsigned
                id = ids[history[at]],
                base = bases[id++],
                branches = indices[history[at] + 1] - indices[history[at]];

            heap_sort(&descendants[indices[history[at]]], orig_weights, branches);
            memcpy(&history[at], &descendants[indices[history[at]]], branches * sizeof(descendants[0]));

            for (others = (at += branches); branches--; base = id) {
                ids[history[--others]] = id;

                ancestors[id] = ids[orig_ancestors[history[others]]];
                weights[id] = orig_weights[history[others]];
                colors[id] = orig_colors[history[others]];

                bases[id] = base;
                base_depths[id] = base_depths[ancestors[id]] + (base == id);

                if (dist < base_depths[id])
                    dist = base_depths[id];

                id += weights[id];
            }
        }
    }

    /* base_ids[id] = index of id's heavy chain (chains numbered in id order). The sentinel
     * base_ids[vertex_cnt] ends up as the number of chains. */
    unsigned base_ids[vertex_cnt + 1];
    for (base_ids[0] = (others = (at = 0)); others < vertex_cnt; base_ids[others] = base_ids[at] + 1)
        for (at = others; bases[at] == bases[others]; base_ids[others++] = base_ids[at]);

    /* ancestral_bases[c][k] = vertex reached from chain c after k + 1 jumps, each jump going
     * from the current chain's head to that head's parent.
     * NOTE(review): when dist == 0 (every vertex has at most one child, e.g. n == 1 or a
     * path with input vertex 1 at one end) this VLA has zero columns, yet the loop below
     * still writes column 0. That is undefined behaviour; sizing it max(dist, 1) would
     * avoid it -- confirm. */
    unsigned ancestral_bases[base_ids[vertex_cnt]][dist];
    for (ancestors[0] = 0; others--; ancestral_bases[base_ids[others]][0] = ancestors[others]);
    while ((++others + 1) < dist)
        for (at = 0; ++at < base_ids[vertex_cnt];
             ancestral_bases[at][others + 1] = ancestors[bases[ancestral_bases[at][others]]]);

    /* Group the ids by colour, ascending within each group:
     *   members[indexed_colors[c] .. indexed_colors[c + 1]) lists colour c.
     *   colors[vertex_cnt] is the sentinel's (largest) colour id.
     *   indexed_colors[colors[vertex_cnt] + 1] is set equal to the sentinel colour's start. */
    unsigned
        indexed_colors[colors[vertex_cnt] + 2],
        members[vertex_cnt + 1];

    memset(indexed_colors, 0, sizeof(indexed_colors));
    for (at = vertex_cnt + 1; at--; indexed_colors[colors[at]]++);
    for (; ++at < colors[vertex_cnt]; indexed_colors[at + 1] += indexed_colors[at]);
    for (at = vertex_cnt + 1; at--; members[--indexed_colors[colors[at]]] = at);
    indexed_colors[colors[vertex_cnt] + 1] = indexed_colors[colors[vertex_cnt]];

    /* Blocks of `levels` consecutive ids. pairs is calloc'ed with one extra row and column
     * of zeros and then shifted by [1][1].
     * NOTE(review): the shifted view keeps a row stride of block_cnt, so only row -1 is
     * guaranteed zero; column -1 of row r >= 1 aliases the last entry of row r - 1.
     * NOTE(review): the calloc result is neither checked nor freed. */
    unsigned
        levels = floor_log2(vertex_cnt) + 1,
        block_cnt = (vertex_cnt / levels) + 1,
        locations[vertex_cnt + 1],
        overlapping[block_cnt];

    unsigned long (*pairs)[block_cnt][block_cnt] = calloc(
        (1 + block_cnt) * (1 + block_cnt),
        sizeof(pairs[0][0][0])
    );
    pairs = (void *)&pairs[0][1][1];

    /* locations[v] = position of v in members[]. */
    for (at = vertex_cnt + 1; at--; locations[members[at]] = at);

    /* For every colour with at least two vertices (colours are ordered by decreasing
     * frequency, so the first colour with fewer than two vertices ends the loop):
     *   - block_bases[k] is the first member of the colour's k-th per-block run, and
     *     block_bases[cnt] the first member after the colour, so differences of locations[]
     *     give run sizes;
     *   - within-block ordered pairs d * (d - 1) go into overlapping[];
     *   - cross-block products go into pairs[bi][bj] for bi < bj. */
    memset(overlapping, 0, sizeof(overlapping));
    for (at = 0; (indexed_colors[at + 1] - indexed_colors[at]) > 1; at++) {
        others = indexed_colors[at];

        unsigned
            block_bases[indexed_colors[at + 1] - others + 1],
            cnt = 1;

        for (block_bases[0] = members[others]; at == colors[members[++others]]; )
            if ((members[others] / levels) != (block_bases[cnt - 1] / levels))
                block_bases[cnt++] = members[others];

        block_bases[cnt] = members[others];
        for (others = 0; others < cnt; others++) {
            unsigned long density = locations[block_bases[others + 1]] - locations[block_bases[others]];
            overlapping[block_bases[others] / levels] += density * (density - 1);

            unsigned block = others;
            for (; ++block < cnt; pairs[0][block_bases[others] / levels][block_bases[block] / levels]
                += density * (locations[block_bases[block + 1]] - locations[block_bases[block]]));
        }
    }

    /* Prefix sums: overlapping[] and row 0 first, then along every row, then down the
     * columns, so pairs[0][i][j] becomes a 2-D prefix sum. */
    for (at = 0; ++at < block_cnt; overlapping[at] += overlapping[at - 1])
        pairs[0][0][at] += pairs[0][0][at - 1];

    for (at = 0; ++at < block_cnt; )
        for (others = 0; ++others < block_cnt; pairs[0][at][others] += pairs[0][at][others - 1]);

    for (at = 0; ++at < block_cnt; )
        for (others = 0; others < block_cnt; others++)
            pairs[0][at][others] += pairs[0][at - 1][others];

    colored_tree_t *tree = &(colored_tree_t) {
        .members = members,
        .colors = colors,
        .indices = indexed_colors,
        .locations = locations
    };


    /* Per query:
     *   - convert the 1-based endpoints to new ids and order each pair so low <= high;
     *   - swap the two paths when right lies entirely before left;
     *   - split each path into id ranges (a, b): one [chain head, vertex] range per chain
     *     crossed below the common ancestor, then one range starting at the common ancestor;
     *   - sum count_pairs() over every pair of ranges.
     * NOTE(review): %u is given pointers to range_t's int fields.
     * NOTE(review): 32 ranges per path are assumed to suffice (about 2 * log2(n) chain
     * changes) -- confirm for the maximum n. */
    while (query_cnt--) {
        range_t left, right;
        scanf("%u %u %u %u", &left.low, &left.high, &right.low, &right.high);
        /* 1-based -> 0-based for both fields at once (64-bit packd layout). */
        left.packd -= 0x0000000100000001UL;
        right.packd -= 0x0000000100000001UL;

        left.low = ids[left.low];
        left.high = ids[left.high];

        right.low = ids[right.low];
        right.high = ids[right.high];

        if (left.high < left.low)
            left.packd = (left.packd << 32) | (left.packd >> 32);

        if (right.high < right.low)
            right.packd = (right.packd << 32) | (right.packd >> 32);

        if (right.high < left.low) {
            left.packd  ^= right.packd;
            right.packd ^= left.packd;
            left.packd  ^= right.packd;
        }

        struct {
            range_t members[32];
            unsigned cnt;
        }
            a = {.cnt = 0},
            b = {.cnt = 0};

        unsigned common = nearest_common_ancestor(
            dist, base_ids[vertex_cnt], vertex_cnt,
            base_ids, ancestral_bases,
            base_depths, weights,
            left.low, left.high
        );

        for (at = left.low; bases[at] != bases[common]; at = ancestral_bases[base_ids[at]][0])
            a.members[a.cnt++].packd = bases[at] | ((unsigned long)at << 32);

        for (others = left.high; bases[others] != bases[common]; others = ancestral_bases[base_ids[others]][0])
            a.members[a.cnt++].packd = bases[others] | ((unsigned long)others << 32);

        a.members[a.cnt++].packd = common | ((unsigned long)((at != common) ? at : others) << 32);

        common = nearest_common_ancestor(
            dist, base_ids[vertex_cnt], vertex_cnt,
            base_ids, ancestral_bases,
            base_depths, weights,
            right.low, right.high
        );

        for (at = right.low; bases[at] != bases[common]; at = ancestral_bases[base_ids[at]][0])
            b.members[b.cnt++].packd = bases[at] | ((unsigned long)at << 32);

        for (others = right.high; bases[others] != bases[common]; others = ancestral_bases[base_ids[others]][0])
            b.members[b.cnt++].packd = bases[others] | ((unsigned long)others << 32);

        b.members[b.cnt++].packd = common | ((unsigned long)((at != common) ? at : others) << 32);

        unsigned long total = 0;
        for (at = 0; at < a.cnt; at++)
            for (others = 0; others < b.cnt;
                 total += count_pairs(
                     block_cnt, levels, pairs[0], overlapping, tree,
                     a.members[at], b.members[others++]
                 )
            );

        printf("%lu\n", total);
    }

    return 0;
}


Post a Comment

0 Comments