Header Ad

HackerEarth Valentina and the Gift Tree problem solution

In this HackerEarth Valentina and the Gift Tree problem solution, Valentina's birthday is coming soon! Her parents love her so much that they are preparing some gifts to give her. Valentina is cute and smart, and her parents are planning to set an interesting task for her.

They prepared a tree (a connected graph without cycles) with one gift in each vertex. Vertices are numbered 1 through N, and the i-th of them contains a gift with value Gi. The value Gi describes Valentina's happiness if she gets the gift from the i-th vertex. All gifts are wrapped, so Valentina doesn't know their values.

Note that Gi could be negative and it would mean that Valentina doesn't like the gift (do you remember when your parents gave you clothes or toys not adequate to your interests?).

Let's consider the following process:
  1. Valentina chooses two vertices, A and B (not necessarily distinct).
  2. She unpacks all gifts on the path connecting A and B and thus she sees their values.
  3. She chooses some part of the path connecting A and B. The chosen part must be connected and can't be empty (hence, it must be a path).
  4. Valentina takes all gifts in the chosen part of the path and her total happiness is equal to the sum of the values of taken gifts.

Valentina is smart and for chosen A and B she will choose a part resulting in the biggest total happiness.

In order to maximize her chance of getting a good bunch of gifts, her parents allow her to ask Q questions, each consisting of two numbers, A and B. For each of Valentina's questions, her parents will tell her the answer: the maximum total happiness for the chosen A and B.

They noticed that answering Valentina's questions is an interesting problem. They want to know if you are smart enough to correctly answer all questions.


HackerEarth Valentina and the Gift Tree problem solution


HackerEarth Valentina and the Gift Tree problem solution.

import java.io.*;
import java.util.*;
import java.math.*;

/**
 * Offline solver: reads a tree with one value per vertex and a list of (A, B) questions,
 * and prints, for every question, the maximum sum of a non-empty contiguous part of the
 * path between A and B.
 *
 * Every question is split into at most two ancestor-to-descendant queries. A single DFS
 * ({@link #solve}) keeps the values of the current root-to-vertex path in an {@link STree}
 * and answers each query when its lower endpoint is visited.
 */
public class Offline implements Runnable {

    static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    static PrintWriter out = new PrintWriter(new BufferedOutputStream(System.out));
    static StringTokenizer st = new StringTokenizer("");

    /**
     * Returns the next whitespace-separated token from standard input.
     *
     * @return the next token, or {@code null} once the input is exhausted
     * @throws RuntimeException wrapping the {@link IOException} if reading fails, so the
     *         real cause is reported instead of being hidden behind a later parse error
     */
    public static String next() {
        while (!st.hasMoreTokens()) {
            String line;
            try {
                line = br.readLine();
            } catch (IOException e) {
                throw new RuntimeException("failed to read from standard input", e);
            }
            if (line == null)
                return null;
            st = new StringTokenizer(line);
        }
        return st.nextToken();
    }

    public static void main(String[] asda) throws Exception {
        // The DFS routines are recursive; run them on a thread with a 256 MB stack.
        new Thread(null, new Offline(), "Offline", 1 << 28).start();
    }

    /** Reads the whole input, answers every question and prints one answer per line. */
    @Override
    public void run() {
        int N = Integer.parseInt(next());

        // Vertices are 1-based in the input and 0-based internally.
        g = newLists(N);
        for (int m = 1; m < N; m++) {
            int a = Integer.parseInt(next()) - 1;
            int b = Integer.parseInt(next()) - 1;
            g[a].add(b);
            g[b].add(a);
        }

        values = new long[N];
        for (int k = 0; k < N; k++) {
            values[k] = Integer.parseInt(next());
        }

        dad = new int[N];
        inTime = new int[N];
        outTime = new int[N];
        time = 0;
        dfs(0, 0);

        int Q = Integer.parseInt(next());
        Question[] questions = new Question[Q];
        processQuery = newLists(N);

        LCA lca = new LCA(g, 0);
        for (int k = 0; k < Q; k++) {
            int a = Integer.parseInt(next()) - 1;
            int b = Integer.parseInt(next()) - 1;
            questions[k] = buildQuestion(lca, a, b);
        }

        stree = new STree(N);
        node2id = new int[N];
        solve(0);

        for (int k = 0; k < Q; k++) {
            out.println(questions[k].answer());
        }

        out.flush();
        System.exit(0);
    }

    /** Creates an array of {@code n} empty lists. */
    @SuppressWarnings("unchecked") // generic array creation; every slot is filled with a typed list
    private static <T> List<T>[] newLists(int n) {
        List<T>[] lists = new ArrayList[n];
        for (int k = 0; k < n; k++)
            lists[k] = new ArrayList<>();
        return lists;
    }

    /**
     * Splits the question (a, b) into one or two ancestor-to-descendant queries and
     * registers each of them in {@link #processQuery} under its lower endpoint.
     *
     * If one endpoint is the LCA, a single query from the LCA to the other endpoint is
     * enough. Otherwise the first query runs from the LCA down to {@code a}, and the second
     * one from the LCA's child on {@code b}'s side down to {@code b}.
     */
    private Question buildQuestion(LCA lca, int a, int b) {
        int ancestor = lca.getLCA(a, b);

        Query first;
        Query second = null;

        if (a == ancestor) {
            first = new Query(ancestor, b);
        } else if (b == ancestor) {
            first = new Query(ancestor, a);
        } else {
            first = new Query(ancestor, a);
            // getPossibleSons does not say which child belongs to which endpoint.
            for (int to : lca.getPossibleSons(a, b)) {
                if (isAncestor(to, b)) {
                    second = new Query(to, b);
                    break;
                }
            }
            if (second == null) {
                throw new RuntimeException("son not found");
            }
        }

        processQuery[first.son].add(first);
        if (second != null) {
            processQuery[second.son].add(second);
        }
        return new Question(first, second);
    }

    /***********************************************/
    // Adjacency lists of the tree (0-based vertices).
    List<Integer>[] g;
    // Gift value of each vertex.
    long[] values;
    // dad[v] = parent of v in the tree rooted at 0 (the root is its own parent).
    int[] dad;
    // DFS clock used to fill inTime/outTime.
    int time;
    int[] inTime;
    int[] outTime;

    /** Records parent, entry time and exit time of every vertex in the subtree of {@code id}. */
    void dfs(int id, int parent) {
        inTime[id] = time++;
        dad[id] = parent;
        for (int to : g[id]) {
            if (to != parent) {
                dfs(to, id);
            }
        }
        outTime[id] = time;
    }

    /** Returns true when {@code son} lies in the subtree of {@code ancestor} (a vertex is its own ancestor). */
    boolean isAncestor(int ancestor, int son) {
        return inTime[son] >= inTime[ancestor] && outTime[son] <= outTime[ancestor];
    }

    // processQuery[v] = queries whose lower endpoint is v
    List<Query>[] processQuery;

    STree stree;

    // node2id[k] = the id of node k in the stree
    int[] node2id;

    /**
     * DFS that pushes the value of {@code id} onto the segment tree on entry and pops it on
     * exit, so the tree always holds the values on the path from the root to {@code id}.
     */
    void solve(int id) {
        node2id[id] = stree.add(values[id]);

        // answer all queries where id is a son
        for (Query query : processQuery[id]) {
            int a = node2id[query.ancestor];
            int b = node2id[id];
            query.answer = stree.max(a, b);
        }

        for (int to : g[id]) {
            if (dad[id] != to) {
                solve(to);
            }
        }

        stree.remove();
    }
    /***********************************************/

    // Not referenced anywhere in this class; kept for compatibility.
    final static long INF = 100000000L * 50000 * 2;
}

/**
 * A request for the best contiguous segment on the downward tree path that starts at
 * {@code ancestor} and ends at {@code son}.
 */
class Query {
    /** Upper endpoint of the path. */
    int ancestor;

    /** Lower endpoint of the path. */
    int son;

    /** Result of the query; stays {@code null} until something assigns it. */
    Node answer;

    Query(int ancestor, int son) {
        this.son = son;
        this.ancestor = ancestor;
    }
}

// A question is answered from one or two downward queries: `first` always exists,
// `second` is null when a single query already covers the whole path.
class Question {
    Query first, second;

    Question(Query first, Query second) {
        this.second = second;
        this.first = first;
    }

    /**
     * Returns the best segment sum for the whole question.
     *
     * With two queries, the first result is reversed so that the two pieces can be
     * concatenated end to end before the maximum is read off.
     */
    long answer() {
        if (second == null) {
            return first.answer.max;
        }
        Node joined = STree.merge(first.answer.reverse(), second.answer);
        return joined.max;
    }
}


/**
 * Segment tree used as a stack: values are appended at the first free leaf and removed
 * from the last occupied one. Each node stores the prefix/suffix/total/best-segment
 * summary of the leaves below it; empty slots are represented by {@code null}.
 */
class STree {
    // Number of leaves (a power of two, at least 2); leaf i lives at data[N + i].
    int N;
    Node[] data;
    // Number of values currently stored, i.e. the index of the first free leaf.
    int nextIndex;

    /** Creates an empty tree able to hold at least {@code size} values. */
    STree(int size) {
        N = 2;
        while (N < size) {
            N <<= 1;
        }

        data = new Node[N << 1];
        nextIndex = 0;
    }

    /**
     * Adds a value to the tree in the first available position.
     *
     * @return the position (0-based leaf index) the value was stored at
     * @throws ArrayIndexOutOfBoundsException if all {@code N} leaves are already in use
     */
    int add(long value) {
        int leaf = nextIndex + N;
        data[leaf] = new Node(value);
        rebuildAncestors(leaf);
        return nextIndex++;
    }

    /**
     * Removes the last item in the tree.
     *
     * @throws IllegalStateException if the tree is empty; without this check the call
     *         would null out an internal node and leave {@code nextIndex} at -1
     */
    void remove() {
        if (nextIndex == 0) {
            throw new IllegalStateException("remove() called on an empty STree");
        }
        int leaf = --nextIndex + N;
        data[leaf] = null;
        rebuildAncestors(leaf);
    }

    /** Recomputes every node on the path from the parent of {@code leaf} up to the root. */
    private void rebuildAncestors(int leaf) {
        for (int id = leaf >> 1; id != 0; id >>= 1) {
            int son = id << 1;
            data[id] = merge(data[son], data[son + 1]);
        }
    }

    /**
     * The max contiguous subarray in the interval [L, R] (both inclusive, 0-based).
     *
     * @return the merged summary of positions L..R, or {@code null} if the range holds no value
     */
    Node max(int L, int R) {
        L += N;
        R += N;

        // `left` accumulates pieces from the left border inward, `right` from the right border.
        Node left = null, right = null;
        while (L <= R) {
            if (L == R) {
                left = merge(left, data[L]);
                break;
            }

            if (L % 2 == 1) left = merge(left, data[L++]);
            if (R % 2 == 0) right = merge(data[R--], right);

            L >>= 1;
            R >>= 1;
        }
        return merge(left, right);
    }

    /**
     * Concatenates two summaries, {@code left} followed by {@code right}.
     * A {@code null} argument stands for an empty range, so the other argument is returned as is.
     */
    static Node merge(Node left, Node right) {
        if (right == null)
            return left;
        if (left == null)
            return right;

        Node node = new Node();
        node.prefixSum = Math.max(left.prefixSum, left.totalSum + right.prefixSum);
        node.suffixSum = Math.max(right.suffixSum, right.totalSum + left.suffixSum);
        node.totalSum = left.totalSum + right.totalSum;

        // Best segment: entirely in one half, or crossing the boundary.
        node.max = Math.max(left.suffixSum + right.prefixSum, left.max);
        node.max = Math.max(node.max, right.max);

        node.max = Math.max(node.max, node.suffixSum);
        node.max = Math.max(node.max, node.prefixSum);

        return node;
    }
}
/**
 * Summary of a sequence of values: best suffix sum, best prefix sum, total sum and best
 * contiguous-segment sum.
 */
class Node {
    long suffixSum, prefixSum, totalSum, max;

    /** Summary of the one-element sequence {@code [x]}: all four fields equal {@code x}. */
    Node(long x) {
        max = x;
        totalSum = x;
        prefixSum = x;
        suffixSum = x;
    }

    /** Creates a node with every field set to zero. */
    Node() {
        this(0L);
    }

    /**
     * Returns a new node describing the same sequence read backwards: prefix and suffix
     * swap places, total and max are unchanged. This node is not modified.
     */
    Node reverse() {
        Node flipped = new Node(totalSum);
        flipped.max = max;
        flipped.prefixSum = suffixSum;
        flipped.suffixSum = prefixSum;
        return flipped;
    }

    @Override
    public String toString() {
        return String.format("p %d, s %d. t %d. m %d", prefixSum, suffixSum, totalSum, max);
    }
}


/**
 * Lowest-common-ancestor queries on a rooted tree using binary lifting.
 * Vertices are the indices {@code 0 .. g.length - 1} of the adjacency array.
 */
class LCA {
    int N;
    // Number of levels in the jump table; 2^logN > N.
    int logN;
    // dep[v] = depth of v (the root has depth 0).
    int[] dep;
    // go[v][k] = the 2^k-th ancestor of v; the root is its own ancestor at every level.
    int[][] go;
    List<Integer>[] g;

    /**
     * Builds the jump table for the tree {@code g} rooted at {@code root}.
     *
     * @param g    adjacency lists of the tree
     * @param root index of the root vertex; any vertex may be used
     */
    LCA(List<Integer> g[], int root) {
        this.g = g;
        N = g.length;
        logN = 1;
        while ((1 << logN) <= N)
            logN++;

        dep = new int[N];
        go = new int[N][logN];

        // The root is its own parent. Using a fixed vertex such as 0 here would make the
        // DFS skip the edge root-0 whenever root != 0.
        dfs(root, 0, root);

        // Prepare for LCA queries. Every vertex, including 0, needs its row filled.
        for (int k = 1; k < logN; ++k) {
            for (int i = 0; i < N; ++i) {
                go[i][k] = go[go[i][k - 1]][k - 1];
            }
        }
    }

    /** Fills depth and direct parent for {@code id} and everything below it. */
    private void dfs(int id, int depth, int dad) {
        go[id][0] = dad;
        dep[id] = depth;

        for (int to : g[id]) {
            if (to != dad)
                dfs(to, depth + 1, id);
        }
    }

    /** Returns the ancestor of {@code node} that is {@code steps} levels above it. */
    private int lift(int node, int steps) {
        for (int bit = 0; steps != 0; ++bit, steps >>= 1) {
            if ((steps & 1) != 0) {
                node = go[node][bit];
            }
        }
        return node;
    }

    /** Returns the lowest common ancestor of {@code u} and {@code v}. */
    int getLCA(int u, int v) {
        int[] top = getPossibleSons(u, v);
        // Both entries coincide exactly when one vertex is an ancestor of the other.
        if (top[0] == top[1])
            return top[0];
        return go[top[0]][0];
    }

    /**
     * Lifts both vertices to the same depth and then as high as possible while they stay
     * different.
     *
     * When neither vertex is an ancestor of the other, the result holds the two children
     * of the LCA that lie on the paths to {@code u} and {@code v}. The deeper vertex's side
     * comes first, so the order does not necessarily follow the argument order. When one
     * vertex is an ancestor of the other, both entries equal that ancestor.
     */
    int[] getPossibleSons(int u, int v) {
        if (dep[u] < dep[v]) {
            int aux = u;
            u = v;
            v = aux;
        }

        u = lift(u, dep[u] - dep[v]);

        for (int i = logN - 1; i >= 0; --i) {
            if (go[u][i] != go[v][i]) {
                u = go[u][i];
                v = go[v][i];
            }
        }
        return new int[] {u, v};
    }
}

Second solution

#include<bits/stdc++.h>
using namespace std;
#define FOR(i,a,b) for(int i = (a); i <= (b); ++i)
#define RI(i,n) FOR(i,1,(n))
#define REP(i,n) FOR(i,0,(n)-1)
#define pb push_back
typedef long long ll;
const int inf = 1e9 + 5;
const int K = 19;
const int nax = 1 << K;

// w[v]: adjacency list of vertex v, filled from the input edges in main().
vector<int> w[nax];
// par[v][j]: ancestor table for binary lifting (par[v][0] is set in one(), the higher
// levels in main()); val[v]: value read for vertex v; h[v]: depth of v, with h[1] = 1.
int par[nax][K], val[nax], h[nax];

// Lowest common ancestor of a and b, computed with the par[][] jump table and depths h[].
int find_lca(int a, int b) {
    // Make a the deeper (or equally deep) vertex.
    if (h[a] < h[b]) {
        swap(a, b);
    }

    // Raise a until it is level with b: take every jump that does not overshoot.
    for (int bit = K; bit-- > 0; ) {
        const int up = par[a][bit];
        if (h[up] >= h[b]) {
            a = up;
        }
    }
    if (a == b) {
        return a;
    }

    // Raise both in lockstep for as long as their ancestors still differ.
    for (int bit = K; bit-- > 0; ) {
        const int up_a = par[a][bit];
        const int up_b = par[b][bit];
        if (up_a != up_b) {
            a = up_a;
            b = up_b;
        }
    }

    assert(a != b);
    assert(par[a][0] == par[b][0]);
    return par[a][0];
}

// First DFS: assigns the depth h[] and the direct parent par[][0] of every vertex below a.
void one(int a) {
    for (int child : w[a]) {
        if (child == par[a][0]) {
            continue; // do not walk back up to the parent
        }
        par[child][0] = a;
        h[child] = h[a] + 1;
        one(child);
    }
}

// Summary of a sequence of values: best prefix sum (left), best suffix sum (right),
// best contiguous-segment sum (inside) and the sum of all elements (total).
// A default-constructed interval has every field equal to -inf.
struct interval {
    ll left, right, inside, total;

    interval(ll x = -inf) : left(x), right(x), inside(x), total(x) {}

    // Concatenation: *this followed by b.
    interval operator * (interval b) {
        interval ans;
        ans.total = total + b.total;
        ans.inside = max(right + b.left, max(inside, b.inside));
        ans.right = max(right + b.total, b.right);
        ans.left = max(total + b.left, left);
        return ans;
    }
} tr[2 * nax];

// Combines the leaves of tr from position `from` up to the last leaf, in left-to-right order.
interval merge(int from) {
    int node = nax + from;
    interval ans = tr[node];
    while (node > 1) {
        // A left child has its right sibling's whole range directly after what is
        // already accumulated, so append it.
        if (node % 2 == 0) {
            ans = ans * tr[node + 1];
        }
        node /= 2;
    }
    return ans;
}

// query[v]: ids of the questions that have v as an endpoint (main() pushes every
// question id under both of its endpoints).
vector<int> query[nax];
// ANSWER[id]: intervals collected by two() for question id, one per pushed endpoint.
vector<interval> ANSWER[nax];
// lca[id]: lowest common ancestor of the two endpoints of question id.
int lca[nax];

// Second DFS: stores val[a] in the segment-tree leaf whose index is a's depth and, for
// every question with endpoint a, records the combined interval starting at the LCA's
// depth slot.
void two(int a) {
static int sz = 1; // slot for the vertex being entered; the assert ties it to the depth h[a]
assert(sz == h[a]);
tr[nax + sz] = interval(val[a]);
// Recompute every ancestor of the leaf that was just written.
for(int i = (nax + sz) / 2; i > 0; i /= 2)
tr[i] = tr[2*i] * tr[2*i+1];
++sz;

// merge() combines all leaves from the LCA's slot to the end of the array.
for(int id : query[a]) ANSWER[id].pb(merge(h[lca[id]]));

for(int b : w[a]) if(b != par[a][0]) two(b);

// sz now points one slot past a's own, i.e. the slot a's children wrote to; reset it
// to the default interval and rebuild its ancestors before stepping back.
tr[nax + sz] = interval();
for(int i = (nax + sz) / 2; i > 0; i /= 2)
tr[i] = tr[2*i] * tr[2*i+1];
--sz;
}

int main() {
int n;
scanf("%d", &n);
REP(i, n - 1) {
int a, b;
scanf("%d%d", &a, &b);
w[a].pb(b);
w[b].pb(a);
}
h[1] = 1;
one(1);
RI(j, K-1) RI(a, n) par[a][j] = par[ par[a][j-1] ][j-1];
RI(i, n) scanf("%d", &val[i]);
int q;
scanf("%d", &q);
RI(i, q) {
int a, b;
scanf("%d%d", &a, &b);
lca[i] = find_lca(a, b);
query[a].pb(i);
query[b].pb(i);
}
two(1);
RI(a, q) {
ll r = max(ANSWER[a][0].inside, ANSWER[a][1].inside);
r = max(r, ANSWER[a][0].left + ANSWER[a][1].left - val[lca[a]]);
printf("%lld\n", r);
}
return 0;
}

Post a Comment

0 Comments