In this HackerEarth Mojtaba's Trees and Arpa's Queries (March HourStorm) problem solution: Mojtaba has two trees, each with n vertices. Arpa has q queries, each of the form v, u, x, y. Let s be the set of vertices on the path from v to u in the first tree, and p the set of vertices on the path from x to y in the second tree. For each query, Mojtaba has to compute the size of the intersection of s and p. Help him!


HackerEarth Mojtaba's Trees and Arpa's Queries (March HourStorm) problem solution


HackerEarth Mojtaba's Trees and Arpa's Queries (March HourStorm) problem solution.

#include <bits/stdc++.h>
using namespace std;

// Array bound (vertices plus slack) and number of binary-lifting levels;
// 2^maxL exceeds maxN, so any depth difference fits in maxL bits.
constexpr int maxN = 300 * 1000 + 100;
constexpr int maxL = 20;

using pii = pair<int, int>;

// c[t][v]: adjacency list of vertex v in tree t (t = 0 first tree, t = 1 second).
vector<int> c[2][maxN];
// que[b]: pending (signed query id, tree-0 vertex) pairs, consumed by dfs_solve at b.
vector<pii> que[maxN];

// Euler-tour interval [st[v], en[v]) of v in tree 0.
int st[maxN];
int en[maxN];
// Segment tree storing range increments over Euler positions.
int seg[4 * maxN];

// h[t][v]: depth of v in tree t; par[t][v][k]: its 2^k-th ancestor.
int h[2][maxN];
int par[2][maxN][maxL];

// Roots tree t at u and fills depths h[t][*] and the binary-lifting table
// par[t][*][k] (2^k-th ancestor) for every vertex reachable from u.
//
// Precondition: par[t][u][0] and h[t][u] are already set for the start
// vertex. main calls this with u = 0 and relies on the zero-initialised
// globals, so the root is its own parent at depth 0.
//
// Uses an explicit stack instead of recursion. The trees may be path-like
// with up to ~maxN vertices, and one call frame per vertex can overflow the
// default thread stack. Vertices are visited in the same order as the
// former recursive DFS.
void dfs_lca(int t, int u) {
    // A vertex's jump row depends only on its ancestors' rows, and those are
    // always filled before the vertex itself is pushed.
    auto fill_jumps = [t](int v) {
        for (int k = 1; k < maxL; k++)
            par[t][v][k] = par[t][par[t][v][k - 1]][k - 1];
    };

    fill_jumps(u);

    // Each frame: (vertex, index of the next neighbour to examine).
    vector<pair<int, size_t>> stk;
    stk.push_back(make_pair(u, size_t(0)));
    while (!stk.empty()) {
        const int v = stk.back().first;
        const size_t i = stk.back().second;
        if (i == c[t][v].size()) {      // all neighbours handled: "return"
            stk.pop_back();
            continue;
        }
        stk.back().second++;
        const int x = c[t][v][i];
        if (x == par[t][v][0])          // skip the edge back to the parent
            continue;
        par[t][x][0] = v;
        h[t][x] = h[t][v] + 1;
        fill_jumps(x);
        stk.push_back(make_pair(x, size_t(0)));
    }
}

void dfs_time(int u, int p) {
static int ind = 0;
st[u] = ind++;
for( auto x : c[0][u] )
if( x != p )
dfs_time(x, u);
en[u] = ind;
}

// Lowest common ancestor of u and v in tree t, using the depths h[t][*]
// and the binary-lifting table par[t][*][*] filled by dfs_lca.
int get_lca(int t, int u, int v) {
    // Ensure u is at least as deep as v.
    if (h[t][v] > h[t][u])
        swap(u, v);

    // Climb u up to v's depth, one set bit of the depth gap at a time.
    int gap = h[t][u] - h[t][v];
    for (int k = 0; k < maxL && gap != 0; ++k, gap >>= 1)
        if (gap & 1)
            u = par[t][u][k];

    if (u == v)
        return v;

    // Jump both endpoints as far as possible while their ancestors differ;
    // they end up as the two children of the LCA.
    for (int k = maxL - 1; k >= 0; --k) {
        const int up_u = par[t][u][k];
        const int up_v = par[t][v][k];
        if (up_u != up_v) {
            u = up_u;
            v = up_v;
        }
    }
    return par[t][u][0];
}

// Records, for query i, the signed combination
//     f(u, x) - f(v, x) - f(u, y) + f(v, y)
// where f(a, b) is evaluated later by dfs_solve while it visits vertex b of
// tree 1. A negative id marks a subtracted term.
void query(int i, int u, int v, int x, int y) {
    auto& at_x = que[x];
    at_x.emplace_back(i, u);
    at_x.emplace_back(-i, v);

    auto& at_y = que[y];
    at_y.emplace_back(-i, u);
    at_y.emplace_back(i, v);
}

int n = 0;            // segment-tree size: main reads the vertex count, then adds 1 for virtual root 0
int ans[maxN] = {};   // ans[i]: answer to the i-th query (1-based)

// Adds qv to every position in [ql, qr) of a segment tree over [0, n).
// There is no lazy propagation: the increment is stored on the canonical
// nodes covering the range, and seg_get sums the stored values along a
// root-to-leaf path.
void seg_add(int ql, int qr, int qv, int xl=0, int xr=n, int ind=1) {
    const bool outside = qr <= xl || xr <= ql;
    if (outside)
        return;

    const bool inside = xl >= ql && qr >= xr;
    if (!inside) {
        const int mid = (xl + xr) / 2;
        seg_add(ql, qr, qv, xl, mid, 2 * ind);
        seg_add(ql, qr, qv, mid, xr, 2 * ind + 1);
        return;
    }
    seg[ind] += qv;
}

// Point query: the value at position qp is the sum of the increments stored
// on every node from the root down to qp's leaf.
int seg_get(int qp, int xl=0, int xr=n, int ind=1) {
    int acc = 0;
    for (;;) {
        acc += seg[ind];
        if (xr - xl == 1)
            return acc;
        const int mid = (xl + xr) / 2;
        if (qp < mid) {
            xr = mid;
            ind = 2 * ind;
        } else {
            xl = mid;
            ind = 2 * ind + 1;
        }
    }
}

void dfs_solve(int u, int p) {
seg_add(st[u], en[u], 1);


for(auto q: que[u]) {
int id = abs(q.first);
int v = seg_get(st[q.second]);
if( q.first < 0 )
ans[id] -= v;
else
ans[id] += v;
}

for( auto x : c[1][u] )
if( x != p )
dfs_solve(x, u);

seg_add(st[u], en[u], -1);
}

int main() {
ios::sync_with_stdio(false);
cin.tie(0);

int q;
cin >> n >> q;

for(int t = 0; t < 2; t++) {
c[t][0].push_back(1);
for(int i = 0; i + 1 < n; i++) {
int u, v;
cin >> u >> v;
c[t][u].push_back(v);
c[t][v].push_back(u);
}
}

n++;

dfs_lca(0, 0);
dfs_lca(1, 0);

dfs_time(0, -1);

for(int i = 1; i <= q; i++) {
int u, v, x, y;
cin >> u >> v >> x >> y;

int w = get_lca(0, u, v);
int z = get_lca(1, x, y);


query(i, u, par[0][w][0], x, par[1][z][0]);
query(i, v, w, x, par[1][z][0]);
query(i, v, w, y, z);
query(i, u, par[0][w][0], y, z);
}

dfs_solve(0, -1);

for(int i = 1; i <= q; i++)
cout << ans[i] << '\n';

return 0;
}

Second solution

#include<bits/stdc++.h>
using namespace std;

constexpr int maxn = 300000 + 17;  // array bound: vertices/queries plus slack
constexpr int lg = 19;             // binary-lifting levels (2^19 > maxn)

// A deferred query piece stored at a tree-0 vertex: the tree-1 path (x, y),
// the query index i, and the sign z of its contribution.
struct Q{
    int x, y, i, z;
};

int n, q;
// par[t][k][v]: 2^k-th ancestor of v in tree t (the root 0 is its own parent).
int par[2][lg][maxn];
// [st[v], ft[v]): Euler-tour interval of v in tree 1.
int st[maxn], ft[maxn];
int h[2][maxn];    // depth of each vertex in tree t
int ans[maxn];     // answer per query (0-based)
int iman[maxn];    // Fenwick tree storage (1-based positions)
vector<int> g[2][maxn];
vector<Q> assign[maxn];
// Roots tree `id` at v (default 0) and sets par[id][0][*] (direct parent)
// and h[id][*] (depth) for every vertex reachable from v. main fills the
// higher jump levels afterwards.
//
// A neighbour equal to par[id][0][cur] is skipped as the parent edge. For
// the default root 0 this compares against the zero-initialised entry 0,
// which no neighbour of 0 equals, because a tree has no self-loops.
//
// Uses an explicit stack rather than recursion, so path-shaped trees with
// ~maxn vertices cannot overflow the call stack. The visiting order is
// unchanged.
void make_par(int id, int v = 0){
    vector<pair<int, size_t>> stk;   // (vertex, next neighbour index)
    stk.push_back(make_pair(v, size_t(0)));
    while(!stk.empty()){
        const int cur = stk.back().first;
        const size_t i = stk.back().second;
        if(i == g[id][cur].size()){
            stk.pop_back();
            continue;
        }
        stk.back().second++;
        const int u = g[id][cur][i];
        if(u == par[id][0][cur])
            continue;
        par[id][0][u] = cur;
        h[id][u] = h[id][cur] + 1;
        stk.push_back(make_pair(u, size_t(0)));
    }
}
void get_st(int v = 0){
static int time = 0;
st[v] = time++;
for(auto u : g[1][v])
if(u != par[1][0][v])
get_st(u);
ft[v] = time;
}
// LCA of v and u in tree `id`, using the lifting table par[id][k][*] and
// the depths h[id][*].
int lca(int id, int v, int u){
    // Make u the deeper (or equally deep) endpoint.
    if(h[id][u] < h[id][v])
        swap(u, v);

    // Lift u by the remaining depth gap, one bit per level.
    for(int i = 0; i < lg; ++i){
        const int gap = h[id][u] - h[id][v];
        if((gap >> i) & 1)
            u = par[id][i][u];
    }
    if(u == v)
        return v;

    // Climb both endpoints while their 2^i-th ancestors still differ.
    for(int i = lg - 1; i >= 0; --i){
        const int pv = par[id][i][v];
        const int pu = par[id][i][u];
        if(pv != pu){
            v = pv;
            u = pu;
        }
    }
    return par[id][0][v];
}
// Fenwick prefix sum over positions 0..p (stored 1-based in iman). Used with
// the range-add majid overload, this yields the value at point p.
int hamid(int p){
    int total = 0;
    for(int idx = p + 1; idx != 0; idx -= idx & -idx)
        total += iman[idx];
    return total;
}
// Fenwick point update: adds v at position p (0-based, stored at p + 1).
void majid(int p, int v){
    int idx = p + 1;
    while(idx < maxn){
        iman[idx] += v;
        idx += idx & -idx;
    }
}
// Range update through a difference array: as seen by the prefix sum
// hamid, every position in [l, r) gains v.
void majid(int l, int r, int v){
    majid(l, v);     // start of the range
    majid(r, -v);    // cancel just past the end
}
void dfs(int v = 0){
majid(st[v], ft[v], +1);
for(auto q : assign[v]){
int p = lca(1, q.x, q.y);
ans[q.i] += q.z * (hamid(st[q.x]) + hamid(st[q.y]) - hamid(st[p]) - (p ? hamid(st[ par[1][0][p] ]) : 0));
}
for(auto u : g[0][v])
if(u != par[0][0][v])
dfs(u);
majid(st[v], ft[v], -1);
}
// Reads both trees (1-based input, converted to 0-based) and the queries,
// splits each query's tree-0 path into four signed root-path pieces, answers
// them with one walk over tree 0, and prints the answers.
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> q;
    for(int tree = 0; tree < 2; ++tree){
        for(int e = 0; e + 1 < n; ++e){   // n - 1 edges per tree
            int a, b;
            cin >> a >> b;
            --a;
            --b;
            g[tree][a].push_back(b);
            g[tree][b].push_back(a);
        }
    }

    // Root both trees at 0, then build the binary-lifting tables level by level.
    for(int tree = 0; tree < 2; ++tree){
        make_par(tree);
        for(int k = 1; k < lg; ++k)
            for(int x = 0; x < n; ++x)
                par[tree][k][x] = par[tree][k - 1][ par[tree][k - 1][x] ];
    }

    for(int i = 0; i < q; ++i){
        int v, u, x, y;
        cin >> v >> u >> x >> y;
        --v; --u; --x; --y;
        const int p = lca(0, v, u);
        // path(v,u) = root(v) + root(u) - root(p) - root(parent(p))
        assign[v].push_back({x, y, i, +1});
        assign[u].push_back({x, y, i, +1});
        assign[p].push_back({x, y, i, -1});
        if(p != 0)
            assign[ par[0][0][p] ].push_back({x, y, i, -1});
    }

    get_st();
    dfs();

    for(int i = 0; i < q; ++i)
        cout << ans[i] << '\n';
    return 0;
}