Header Ad

HackerEarth Random subsets on a tree problem solution

In this HackerEarth Random subsets on a tree problem, you are given a tree that is rooted at node 1 and contains $N$ nodes. Each node $u$ has a value assigned to it, denoted $\mathrm{Val}(u)$. A subset of nodes is selected uniformly at random: every one of the $2^N$ subsets (including the empty one) is equally likely, i.e. each node is included independently with probability $1/2$.

The value of the subset is defined as follows:

Let $L$ be the lowest common ancestor of the selected nodes. If $\mathrm{Val}(L) > \mathrm{Val}(u)$ for every node $u$ in the subset, then the score of the subset is $L$ (the index of that node). Otherwise, the score is 1. The empty subset has score 0.

Determine the expected score of the selected subset. The answer can be written as a fraction $P/Q$.

Print the answer as $P \cdot Q^{-1}$ modulo $10^9 + 7$.


HackerEarth Random subsets on a tree problem solution


HackerEarth Random subsets on a tree problem solution.

#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;

// Competitive-programming shorthands used throughout this file.
// Type aliases (as macros).
#define ll long long
#define db long double
#define ii pair<int,int>
#define vi vector<int>
// Pair member / container shorthands.
#define fi first
#define se second
#define sz(a) (int)(a).size()
#define all(a) (a).begin(),(a).end()
#define pb push_back
#define mp make_pair
// Loop helpers: FN iterates i over [0, n), FEN over [1, n], rep over [a, b),
// repv over [a, b) in descending order. rep/repv do not parenthesize a and b,
// so pass simple expressions.
#define FN(i, n) for (int i = 0; i < (int)(n); ++i)
#define FEN(i,n) for (int i = 1;i <= (int)(n); ++i)
#define rep(i,a,b) for(int i=a;i<b;i++)
#define repv(i,a,b) for(int i=b-1;i>=a;i--)
// Fills every byte of array A with val (memset semantics, so only byte patterns like 0 / -1 are meaningful for int arrays).
#define SET(A, val) memset(A, val, sizeof(A))
// GNU policy-based order-statistics set of distinct ints (not used in the code shown here).
typedef tree<int ,null_type,less<int>,rb_tree_tag,tree_order_statistics_node_update>ordered_set ;
// order_of_key(val): returns the number of stored values strictly less than val.
// find_by_order(k): returns an iterator to the k-th smallest element (0-based).
// Debug tracing: trace(a, b, ...) prints "a : <a> |  b : <b> ..." to stderr,
// one line per call. TRACE is defined unconditionally right here; if that line
// is removed, trace(...) expands to nothing.
// The helper is called trace_print_ instead of __f: identifiers that contain a
// double underscore are reserved for the implementation in C++.
// Limitation: argument names are split at ',', so with two or more arguments an
// argument expression that itself contains a comma (e.g. f(x, y)) is mislabelled.
#define TRACE
#ifdef TRACE
#define trace(...) trace_print_(#__VA_ARGS__, __VA_ARGS__)
template <typename Arg1>
void trace_print_(const char* name, Arg1&& arg1){
    std::cerr << name << " : " << arg1 << std::endl;
}
template <typename Arg1, typename... Args>
void trace_print_(const char* names, Arg1&& arg1, Args&&... args){
    // Start the search at names + 1 so the label of the first argument is never empty.
    const char* comma = std::strchr(names + 1, ',');
    std::cerr.write(names, comma - names) << " : " << arg1 << " | ";
    trace_print_(comma + 1, args...);
}
#else
#define trace(...)
#endif
// N bounds every per-node / per-value array; mod is the prime all counts are reduced by.
const int N = 200005, mod = 1000000007;

// Modular addition: returns (x + y) mod `mod` as a value in [0, mod),
// provided x + y lies in (-mod, 2 * mod) and does not overflow int.
// Negative y is allowed, so add(a, -b) doubles as modular subtraction.
int add(int x,int y)
{
    const int s = x + y;
    if (s >= mod)
        return s - mod;
    return s < 0 ? s + mod : s;
}
// Modular multiplication: forms x * y in 64 bits and reduces it by `mod`
// only when the product is at least `mod`. For non-negative operands the
// result lies in [0, mod).
int mult(int x,int y)
{
    const long long product = static_cast<long long>(x) * y;
    return static_cast<int>(product < mod ? product : product % mod);
}
// Binary exponentiation: returns x^y mod `mod` for y >= 0 (x^0 == 1).
// Used with y = mod - 2 to obtain a modular inverse via Fermat's little theorem.
int pow1(int x,int y)
{
    int result = 1;
    for (int base = x, e = y; e != 0; e >>= 1)
    {
        if (e & 1)
            result = mult(result, base);
        base = mult(base, base);
    }
    return result;
}
vi v[N];       // adjacency lists of the tree (nodes 1..n)
int child[N];  // subtree size of each node (filled by dfs1)
int st[N];     // Euler-tour entry index of each node (dfs1)
int en[N];     // one past the last Euler-tour index of the node's subtree (dfs1)
int timer;     // next unused Euler-tour index
int rev[N];    // Euler-tour index -> node
int val[N];    // node values read in main
int tot[N];    // number of non-empty subsets whose LCA is the node (dfs1)
int bit[N];    // Fenwick tree of counts indexed by node value
int ans[N];    // ans[i]: number of subsets scoring i (main later scales by 2^-n)
int curr[N];   // per-node count of LCA-u subsets satisfying the value condition (dfs)
int pw[N];     // pw[i] = 2^i mod p
void dfs1(int u,int par=-1)
{
child[u]=1;
st[u]=timer; rev[timer]=u;
timer++;
int cnt=0;
for(int v1:v[u])
{
if(v1==par) continue;
cnt++;
dfs1(v1,u);
child[u]+=child[v1];
tot[u]=add(tot[u],-pw[child[v1]]);
}
en[u]=timer;
tot[u]=add(tot[u],pw[child[u]]);
tot[u]=add(tot[u],-1);
tot[u]=add(tot[u],cnt);
}
// Fenwick-tree prefix sum: returns bit[1] + ... + bit[x] (0 when x == 0).
int query(int x)
{
    int total = 0;
    for (int i = x; i != 0; i -= i & -i)
        total += bit[i];
    return total;
}
// Fenwick-tree point update: adds c at position x. Requires 1 <= x < N;
// x == 0 would never advance and loop forever.
void update(int x,int c)
{
    for (int i = x; i < N; i += i & -i)
        bit[i] += c;
}
// Second pass: heavy-child ("small-to-large") traversal that, for every node u, computes
// curr[u] = number of non-empty subsets S of subtree(u) \ {u} whose LCA is exactly u
//           and in which every node w satisfies val[w] < val[u].
// It then adds curr[u] to ans[u] (those subsets score u) and tot[u] - curr[u] to ans[1]
// (every other subset whose LCA is u scores 1).
// `bit` is a Fenwick tree of counts indexed directly by node value. The keep logic
// relies on `bit` being empty on entry; on return it holds the values of all of
// subtree(u) when keep is true and is empty again when keep is false.
// Relies on child, st, en, rev and tot having been filled by dfs1 (main calls dfs1 first).
// NOTE(review): assumes every val lies in [1, N) since it is used as a Fenwick index;
// recursion depth equals the tree height — confirm the stack limit for path-like trees.
void dfs(int u,int par,bool keep)
{
// Find the heavy child (largest subtree); bigchild stays -1 when u is a leaf.
int bigchild=-1,mx=0,cnt=0;
for(int v1:v[u])
if(v1!=par && child[v1]>mx)
mx=child[v1],bigchild=v1;
// Solve each light child with keep=0 so `bit` is empty again afterwards;
// cnt counts u's children (the heavy child is added to it below).
for(int v1:v[u])
if(v1!=par && v1!=bigchild)
dfs(v1,u,0),cnt++;
// sum: number of nodes in subtree(u) \ {u} whose value is below val[u].
int sum=0;
if(bigchild!=-1)
{
cnt++;
// Solve the heavy child last and keep its values in `bit`; at this point `bit`
// holds exactly the heavy child's subtree.
dfs(bigchild,u,1);
const int tmp=query(val[u]-1);
// Remove the 2^tmp subsets lying entirely inside the heavy child's subtree
// (the empty set is re-added through cnt below).
curr[u]=add(curr[u],-pw[tmp]);
sum+=tmp;
}
// Re-insert every light child's subtree (its contiguous Euler-tour range
// [st, en)) into `bit`, counting nodes below val[u] and removing the subsets
// confined to that child, as for the heavy child above.
for(int v1:v[u])
{
if(v1==par || v1==bigchild) continue;
int tmp=0;
rep(v2,st[v1],en[v1])
{
update(val[rev[v2]],1);
if(val[rev[v2]]<val[u]) tmp++;
}
curr[u]=add(curr[u],-pw[tmp]); sum+=tmp;
}
// Finish curr[u] = (2^sum - 1) - sum over children c of (2^{tmp_c} - 1):
// subsets of qualifying nodes that span at least two child subtrees.
curr[u]=add(curr[u],pw[sum]);
curr[u]=add(curr[u],-1);
curr[u]=add(curr[u],cnt);
ans[u]=add(ans[u],curr[u]);
ans[1]=add(ans[1],add(tot[u],-curr[u]));
// Insert u itself so its ancestors see it (u was excluded from its own counts above).
update(val[u],1);
// Light child: clear the whole subtree of u from `bit` before returning.
if(!keep)
{
rep(j,st[u],en[u])
update(val[rev[j]],-1);
}
}
int main()
{
std::ios::sync_with_stdio(false);
cin.tie(NULL) ; cout.tie(NULL) ;
pw[0]=1;
rep(i,1,N) pw[i]=mult(pw[i-1],2);
int n;
cin>>n;
rep(i,1,n)
{
int x,y;
cin>>x>>y;
v[x].pb(y); v[y].pb(x);
}
rep(i,1,n+1) cin>>val[i];
dfs1(1,-1);
dfs(1,-1,0);
int den=pw[n]; den=pow1(den,mod-2);
rep(i,1,n+1) ans[i]=mult(ans[i],den);
int ret=0;
rep(i,1,n+1) ret=add(ret,mult(i,ans[i]));
cout<<ret<<endl;
return 0 ;
}

Post a Comment

0 Comments