Header Ad

HackerEarth Repair this tree problem solution

In the HackerEarth "Repair this tree" problem, you are given a rooted tree T of n nodes and a graph G of n nodes. Initially, graph G has no edges. There are two types of queries:
  1. 1 x y w: Add an edge of weight w between nodes x and y in G
  2. 2 v p: Consider that the edge between v and its parent in T is deleted, which decomposes T into 2 connected components. You need to find the sum of the weights of all edges x in G such that:
     - The weight of x is divisible by p.
     - If an edge were added to T between the same endpoints as x, we would again get a tree (that is, x connects the two components).
     Note that the edge is deleted only for that particular query.

HackerEarth Repair this tree problem solution


HackerEarth Repair this tree problem solution.

#include <bits/stdc++.h>
using namespace std;
// 64-bit alias used for the weight sums stored in the segment trees.
typedef long long ll;
// maxn: capacity of the per-node arrays; lg: number of levels in the ancestor
// table par; maxv: capacity of the per-value tables prs and seg.
const int maxn = 2e5 + 14, lg = 18, maxv = 1e6 + 14;
// n, q: read from input in main.
// par[0][v] is v's parent (0-based); main fills par[k][v] = par[k-1][par[k-1][v]].
// h[v]: 1 + h[parent of v], with the root (node 0) left at 0.
// st[v], en[v]: set by dfs; [st[v], en[v]) is the range of visit indices handed
// out while v's subtree was being traversed.
int n, q, par[lg][maxn], h[maxn], st[maxn], en[maxn];
// prs[x]: filled by init with the primes dividing x.
// g[v]: children of v, appended in input order by main.
vector<int> prs[maxv], g[maxn];
// Sum segment tree node over positions [l, r); children are created on demand.
// seg[p] is the root of the tree that accumulates weights for value p.
struct Node{
    Node *L, *R;   // left / right child, null until this node is first split
    ll iman;       // sum of everything added inside this node's range

    Node() : L(nullptr), R(nullptr), iman(0) {}

    // Creates both children the first time this node has to be split.
    void arpa(){
        if(!L){
            L = new Node();
            R = new Node();
        }
    }

    // Adds x at position p. [l, r) is the range covered by this node; the
    // default r is the global n as it stands when the call is made.
    void add(int p, int x, int l = 0, int r = n){
        if(r - l == 1){
            iman += x;
            return;
        }
        arpa();
        const int mid = (l + r) >> 1;
        if(p >= mid)
            R->add(p, x, mid, r);
        else
            L->add(p, x, l, mid);
        // Recompute from the children instead of adding x directly.
        iman = L->iman + R->iman;
    }

    // Returns the sum of the values stored at positions in [s, e).
    ll get(int s, int e, int l = 0, int r = n){
        if(s <= l && r <= e)
            return iman;
        if(e <= l || r <= s)
            return 0;
        const int mid = (l + r) >> 1;
        ll total = 0;
        // A missing child means nothing was ever added below it.
        if(L)
            total += L->get(s, e, l, mid);
        if(R)
            total += R->get(s, e, mid, r);
        return total;
    }
} seg[maxv];
// Fills prs[x] with the distinct primes dividing x, in increasing order, for
// every 2 <= x < maxv.
//
// The previous version kept a ~1 MB `bool isp[maxv]` on the stack, which can
// overflow small default stacks. The flag array is not needed: when the outer
// loop reaches i, prs[i] is still empty exactly when no smaller prime divides
// i, i.e. when i is prime. Relies on prs starting out empty (it is a global).
void init(){
    for(int i = 2; i < maxv; i++){
        if(!prs[i].empty())
            continue;  // some smaller prime already divides i: composite
        for(int j = i; j < maxv; j += i)
            prs[j].push_back(i);
    }
}
void dfs(int v = 0){
static int t = 0;
st[v] = t++;
for(auto u : g[v])
dfs(u);
en[v] = t;
}
// Lowest common ancestor of v and u, using the ancestor table par and the
// depths h.
int lca(int v, int u){
    // Keep u as the deeper (or equally deep) node.
    if(h[u] < h[v])
        swap(u, v);
    // Lift u by the depth difference, one set bit at a time.
    const int gap = h[u] - h[v];
    for(int bit = 0; bit < lg; ++bit)
        if((gap >> bit) & 1)
            u = par[bit][u];
    if(u == v)
        return v;
    // Jump both nodes as high as possible while their ancestors still differ;
    // they end up as two distinct children of the answer.
    for(int bit = lg - 1; bit >= 0; --bit){
        if(par[bit][v] == par[bit][u])
            continue;
        v = par[bit][v];
        u = par[bit][u];
    }
    return par[0][v];
}
int main(){
ios::sync_with_stdio(0), cin.tie(0);
init();
cin >> n >> q;
for(int i = 1; i < n; i++){
cin >> par[0][i];
par[0][i]--;
h[i] = 1 + h[ par[0][i] ];
g[ par[0][i] ].push_back(i);
}
dfs();
for(int k = 1; k < lg; k++)
for(int i = 0; i < n; i++)
par[k][i] = par[k - 1][ par[k - 1][i] ];
int lastAns = 0;
while(q--){
int t, v, u, w;
cin >> t;
if(t == 1){
cin >> v >> u >> w;
v--;
u--;
v = (v + lastAns) % n;
u = (u + lastAns) % n;
for(auto p : prs[w]){
seg[p].add(st[v], +w);
seg[p].add(st[u], +w);
seg[p].add(st[lca(v, u)], -2 * w);
}
}
else{
cin >> v >> w;
v--;
v = 1 + (v + lastAns) % (n - 1);
cout << (lastAns = seg[w].get(st[v], en[v])) << '\n';
}
}
}

Second solution

#include<bits/stdc++.h>
using namespace std;
// From here on every `int` in this solution expands to `long long`; main is
// declared with int32_t for that reason.
#define int long long
// Capacity of the per-node arrays, and the fixed index range [0, maxn) covered
// by every Segment tree.
const int maxn = 2e5 + 100;
// Capacity of the per-value tables segment[] and primes[].
const int maxp = 1e6 + 100;
// Forward declarations; both functions are defined after main.
void debug();
vector<int> get_primes(int x);

struct Segment{
int root;
vector<int> left, right, val;

Segment(){
root = 0;
left.push_back(0);
right.push_back(0);
val.push_back(0);
}

int get(int l, int r, int st, int en, int id){
if(l <= st && en <= r)
return val[id];
if(r <= st || en <= l)
return 0;
int mid = (st + en) >> 1;
return get(l, r, st, mid, left[id]) +
get(l, r, mid, en, right[id]);
}

int add(int ind, int x, int st, int en, int id){
val.push_back(val[id]);
left.push_back(left[id]);
right.push_back(right[id]);
id = left.size() - 1;
val[id] += x;

if(en - st < 2)
return id;
int mid = (st + en) >> 1;
if(ind < mid) {
int hlp = add(ind, x, st, mid, left[id]);
left[id] = hlp;
}
else {
int hlp = add(ind, x, mid, en, right[id]);
right[id] = hlp;
}
return id;
}

void add(int ind, int x){
root = add(ind, x, 0, maxn, root);
}

int get(int l, int r){
return get(l, r, 0, maxn, root);
}

void print(){
cerr << "root = " << root << endl;
for(int i = 0; i < left.size(); i++){
cerr << i << ": " << left[i] << ' ' << right[i] << ' ' << val[i] << endl;
}
cerr << endl;
for(int i = 0; i < 10; i++)
cerr << get(i, i + 1) << ' ';
cerr << endl;
}
}segment[maxp];
// Number of tree nodes; read in main and used by tree::initialize.
int n;

namespace tree{
const int maxL = 20;
vector<int> vc[maxn];
int st[maxn], en[maxn], last_index;
int par[maxL][maxn];
int h[maxn];

int get_par(int v, int p){
for(int i = 0; i < maxL; i++)
if((p >> i) & 1)
v = par[i][v];
return v;
}

int LCA(int u, int v){
if(h[u] < h[v])
swap(u, v);
//h[u] >= h[v]
u = get_par(u, h[u] - h[v]);
if(u == v)
return v;
for(int i = maxL - 1; i >= 0; i--)
if(par[i][u] != par[i][v]){
u = par[i][u];
v = par[i][v];
}
return par[0][v];
}

void dfs(int v){
st[v] = last_index++;
for(int u: vc[v]){
h[u] = h[v] + 1;
dfs(u);
}
en[v] = last_index;
}

void initialize(){
dfs(0);
for(int i = 1; i < maxL; i++){
for(int v = 0; v < n; v++)
par[i][v] = par[i - 1][par[i - 1][v]];
}
}

}
// Forward declaration; fills the primes[] table and is defined after main.
void pre_process();

int32_t main(){
pre_process();
//debug();
int q;
cin >> n >> q;
for(int i = 1; i < n; i++){
int p;
cin >> p;
p--;
tree::par[0][i] = p;
tree::vc[p].push_back(i);
}

tree::initialize();
int lastAns = 0;
while(q--) {
int type;
cin >> type;
if(type == 1){
int x, y, w;
cin >> x >> y >> w;
x--, y--;
x = (x + lastAns) % n;
y = (y + lastAns) % n;
int par = tree::LCA(x, y);
//cerr << " par : " << par << endl;
vector<int> primes = get_primes(w);
for(int p: primes){
Segment &seg = segment[p];

//cerr << " add " << p << ' ' << tree::st[par] << ' ' << tree::st[x] << ' ' << tree::st[y] << endl;
seg.add(tree::st[par], -2*w);
seg.add(tree::st[x], +w);
seg.add(tree::st[y], +w);
}
} else if(type == 2){
int v, p;
cin >> v >> p;
v--;
v = 1 + (v + lastAns) % (n - 1);
//cerr << " get " << p << ' ' << tree::st[v] << ' ' << tree::en[v] << endl;
cout << (lastAns = segment[p].get(tree::st[v], tree::en[v])) << '\n';
}
}
}

// primes[x]: the distinct primes dividing x, in increasing order, once
// pre_process() has run. Entries below 2 stay empty.
vector<int> primes[maxp];

// Sieve that fills primes[]. A number whose list is still empty when the scan
// reaches it has no smaller prime divisor, so it is prime and gets appended to
// the list of each of its multiples (itself included).
void pre_process(){
    for(int candidate = 2; candidate < maxp; ++candidate){
        if(!primes[candidate].empty())
            continue;
        int multiple = candidate;
        while(multiple < maxp){
            primes[multiple].push_back(candidate);
            multiple += candidate;
        }
    }
}


// Returns a copy of the prime-divisor list stored for x in primes[].
// x is used as an index without a range check.
vector<int> get_primes(int x){
    const vector<int> &divisors = primes[x];
    return divisors;
}


// Interactive driver for exercising a standalone Segment from stdin.
//
// Reads n, then commands until the input ends or fails to parse:
//   1 ind x  - add x at 1-based position ind
//   2 l r    - print the sum over 1-based positions l..r
//   3        - dump the tree via Segment::print()
//   other    - print the first n positions
//
// The loop used to be `while(true)` with unchecked reads, so at end of input
// it spun forever on an indeterminate `type`. Every read is now checked.
void debug(){
    int n;  // local size for this driver; intentionally separate from the global n
    if(!(cin >> n))
        return;
    Segment seg;
    int type;
    while(cin >> type){
        if(type == 1){
            int ind, x;
            if(!(cin >> ind >> x))
                break;
            ind--;
            seg.add(ind, x);
        } else if(type == 2){
            int l, r;
            if(!(cin >> l >> r))
                break;
            l--;  // 1-based inclusive [l, r] becomes 0-based half-open [l-1, r)
            cout << seg.get(l, r) << endl;
        } else if(type == 3){
            seg.print();
        } else {
            for(int i = 0; i < n; i++)
                cout << seg.get(i, i + 1) << ' ';
            cout << endl;
        }
    }
}

Post a Comment

0 Comments