In this HackerEarth Minimum distance in Tree problem, you are given a tree and q queries. Each query consists of k_i vertices: v(i,1), ..., v(i,k_i).

Let f_i(u) be the minimum of the distances from u to each v(i,j), for 1 <= j <= k_i. For each query, you have to find the value of max over all vertices u in V of f_i(u).


HackerEarth Minimum distance problem solution


HackerEarth Minimum distance problem solution.

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include <vector>
#include <algorithm>
#include <set>
#include <map>
#include <cmath>
#include <ctime>
#include <functional>
#include <sstream>
#include <fstream>
#include <valarray>
#include <complex>
#include <queue>
#include <cassert>
#include <bitset>
using namespace std;

#ifdef LOCAL
#define debug_flag 1
#else
#define debug_flag 0
#endif

// Streams a pair as "[first, second]".
template <class First, class Second>
std::ostream& operator << (std::ostream& out, const std::pair<First, Second>& value)
{
    out << "[";
    out << value.first;
    out << ", ";
    out << value.second;
    out << "]";
    return out;
}

// Streams a vector as "[e0, e1, ...]"; an empty vector is written as "[]".
template <class T>
std::ostream& operator << (std::ostream& os, const std::vector<T>& v)
{
    os << "[";
    for (std::size_t idx = 0; idx < v.size(); ++idx)
    {
        // Every element except the first is preceded by a separator.
        if (idx != 0)
            os << ", ";
        os << v[idx];
    }
    os << "]";
    return os;
}

// Streams a set in iteration (sorted) order as "[e0, e1, ...]"; empty set gives "[]".
template <class T>
std::ostream& operator << (std::ostream& os, const std::set<T>& v)
{
    os << "[";
    for (auto pos = v.cbegin(); pos != v.cend(); ++pos)
    {
        // Separator goes before every element but the first one.
        if (pos != v.cbegin())
            os << ", ";
        os << *pos;
    }
    os << "]";
    return os;
}

// Streams a multiset in iteration (sorted) order as "[e0, e1, ...]", duplicates included.
template <class T>
std::ostream& operator << (std::ostream& os, const std::multiset<T>& v)
{
    os << "[";
    bool need_separator = false;
    for (const T& item : v)
    {
        if (need_separator)
            os << ", ";
        os << item;
        need_separator = true;
    }
    os << "]";
    return os;
}

// Debug trace helper. When debug_flag is non-zero it stringifies the argument list,
// splits that text on ',' with _split, hands the pieces plus the argument values to
// _print, and then ends the line on cerr; when debug_flag is zero it does nothing.
// Because the text is split on every comma, an argument that itself contains a comma
// (for example a call with two parameters) yields misaligned labels.
// NOTE(review): the named variadic parameter `args...` is a GNU extension rather than
// standard C++ - confirm the target compilers accept it.
#define dbg(args...) { if (debug_flag) { _print(_split(#args, ',').begin(), args); cerr << endl; } else { void(0);} }

// Splits s on the delimiter c, returning the pieces in order.
// Follows getline semantics: an empty input yields no pieces and a trailing
// delimiter does not produce a final empty piece.
vector<string> _split(const string& s, char c) {
    vector<string> pieces;
    istringstream reader(s);
    for (string token; getline(reader, token, c); )
        pieces.push_back(token);
    return pieces;
}

// Recursion terminator: no values are left to print.
void _print(vector<string>::iterator) {}

// Writes one entry per value to cerr, pairing each value with the label that `it`
// points at and advancing `it` for the next value.
//   - A label starting with a letter is printed as "label = value ".
//   - Any other label (e.g. a literal) is printed as "label " and the value is omitted.
// One leading space of the label, if present, is dropped before printing.
template<typename T, typename... Args>
void _print(vector<string>::iterator it, T a, Args... args) {
    // substr's start is 1 when the label begins with a space, 0 otherwise.
    string name = it -> substr((*it)[0] == ' ', it -> length());
    // isalpha() has undefined behaviour for negative char values (e.g. bytes of a
    // non-ASCII label), so the character is converted through unsigned char first.
    if (!name.empty() && isalpha(static_cast<unsigned char>(name[0])))
        cerr << name << " = " << a << " ";
    else
        cerr << name << " ";
    _print(++it, args...);
}

// Alias for a signed 64-bit-or-wider integer.
typedef long long int int64;

// Capacity of every per-vertex array below.
const int N = (int)3e5;
// Number of binary-lifting levels stored per vertex (jumps of 2^0 .. 2^(LOGN-1) edges).
const int LOGN = 20;
// Large sentinel used as the starting value when solve() minimises distances.
const int INF = (int)1e9;

// Number of vertices; read in main().
int n;
// Adjacency lists over 0-based vertices. init_par() erases each vertex's edge to its
// parent, so after init_tree() graph[v] holds only the children of v.
vector<int> graph[N];

// par[v][i]: the ancestor of v that is 2^i edges up, or -1 when there is none.
int par[N][LOGN];

// Next timestamp to be handed out by init_tree0().
int timer;
// Entry / exit timestamps of the DFS in init_tree0(); is_par() compares them.
int t_in[N], t_out[N];
// h[v]: number of edges between v and the root (vertex 0).
int h[N];

// depth[v]: largest number of edges from v down to a vertex of its own subtree.
int depth[N];
// depth_set[v]: one entry depth[c] + 1 for every child c of v.
multiset<int> depth_set[N];

// Per-query scratch: distance from a vertex to the closest query vertex (set in solve()).
int dist_to_nearest[N];

// Binary-lifting tables filled by init_up_dp() / init_down_dp() and read by
// get_up_max() / get_down_max().
int up_dp[N][LOGN];
int down_dp[N][LOGN];

// Roots the subtree at v, whose parent is p (-1 for the root):
// removes the edge back to p from graph[v], fills the binary-lifting row par[v][*],
// and recurses into the remaining neighbours, which are exactly the children of v.
void init_par(int v, int p)
{
    if (p != -1)
    {
        // After this, graph[v] lists children only.
        vector<int>& neighbours = graph[v];
        neighbours.erase(find(neighbours.begin(), neighbours.end(), p));
    }

    par[v][0] = p;
    // A 2^level jump is two 2^(level-1) jumps; stop once a jump leaves the tree.
    for (int level = 1; level < LOGN && par[v][level - 1] != -1; level++)
    {
        const int halfway = par[v][level - 1];
        par[v][level] = par[halfway][level - 1];
    }

    for (int child : graph[v])
    {
        assert(child != p);
        init_par(child, v);
    }
}

// DFS over the rooted tree that records, for every vertex of v's subtree, its depth
// from the root in h[] and its entry / exit timestamps in t_in[] / t_out[].
// cur_h is the depth of v itself.
void init_tree0(int v, int cur_h)
{
    h[v] = cur_h;
    t_in[v] = timer++;

    const int child_h = cur_h + 1;
    for (size_t idx = 0; idx < graph[v].size(); idx++)
        init_tree0(graph[v][idx], child_h);

    t_out[v] = timer++;
}

// Returns true when the [t_in, t_out] interval of a encloses that of b, i.e. when a is
// an ancestor of b or a == b.
bool is_par(int a, int b)
{
    if (t_in[b] < t_in[a])
        return false;
    return t_out[a] >= t_out[b];
}

// Lowest common ancestor of a and b via binary lifting.
// Returns immediately when one vertex is an ancestor of (or equal to) the other.
int get_lca(int a, int b)
{
    if (is_par(a, b))
        return a;
    if (is_par(b, a))
        return b;

    // Lift a as high as possible while it is still NOT an ancestor of b;
    // its parent is then the answer.
    int cur = a;
    for (int level = LOGN - 1; level >= 0; level--)
    {
        const int lifted = par[cur][level];
        if (lifted == -1 || is_par(lifted, b))
            continue;
        cur = lifted;
    }

    return par[cur][0];
}

// Number of edges on the tree path between a and b: each endpoint's climb to
// their lowest common ancestor, added together.
int get_dist(int a, int b)
{
    const int meet = get_lca(a, b);
    const int from_a = h[a] - h[meet];
    const int from_b = h[b] - h[meet];
    return from_a + from_b;
}

// Post-order pass over v's subtree that fills depth[] (edges down to the deepest
// descendant) and depth_set[] (one entry depth[child] + 1 per child).
void init_depth(int v)
{
    int deepest = 0;
    for (int child : graph[v])
    {
        init_depth(child);
        const int through_child = depth[child] + 1;
        depth_set[v].insert(through_child);
        if (through_child > deepest)
            deepest = through_child;
    }
    depth[v] = deepest;
}

// Fills up_dp[u][*] for every vertex u in the subtree of v (parents before children).
//
//   up_dp[u][0] = largest depth_set[parent] entry contributed by a child of the
//                 parent other than u, or 0 when u is the parent's only child.
//   up_dp[u][i] = max(up_dp[u][i-1] + 2^(i-1), up_dp[par[u][i-1]][i-1]) for i >= 1,
//                 i.e. the lower half-jump's value shifted up by the length of the
//                 upper half-jump, combined with the upper half-jump's own value.
//
// Requires par[], depth[] and depth_set[] to be initialised (see init_tree()).
void init_up_dp(int v)
{
    const int p = par[v][0];

    if (p != -1)
    {
        // "Largest entry of depth_set[p] after removing one instance of v's own
        // entry" is read directly from the top of the multiset: if the maximum is
        // v's entry, the answer is the runner-up. This avoids a linear scan plus an
        // erase/insert round trip, and leaves depth_set[p] untouched.
        const multiset<int>& sibling_heights = depth_set[p];
        const int own_entry = depth[v] + 1;
        assert(!sibling_heights.empty());   // v's own entry is always present

        multiset<int>::const_reverse_iterator top = sibling_heights.rbegin();
        if (*top == own_entry)
            ++top;

        up_dp[v][0] = (top == sibling_heights.rend()) ? 0 : *top;
    }

    for (int i = 1; i < LOGN; i++)
    {
        // No ancestor 2^i edges up: this and all higher levels stay unset.
        if (par[v][i] == -1)
            break;

        const int pp = par[v][i - 1];
        const int val1 = up_dp[v][i - 1] + (1 << (i - 1));
        const int val2 = up_dp[pp][i - 1];
        up_dp[v][i] = max(val1, val2);
    }

    for (int to : graph[v])
        init_up_dp(to);
}

// Fills down_dp[u][*] for every vertex u in the subtree of v (parents before children).
//
//   down_dp[u][0] = 1 + largest depth_set[parent] entry contributed by a child of the
//                   parent other than u, or 1 when u is the parent's only child.
//   down_dp[u][i] = max(down_dp[u][i-1], down_dp[par[u][i-1]][i-1] + 2^(i-1)) for
//                   i >= 1, i.e. the lower half-jump's value as is, combined with the
//                   upper half-jump's value shifted by the length of the lower half.
//
// Requires par[], depth[] and depth_set[] to be initialised (see init_tree()).
void init_down_dp(int v)
{
    const int p = par[v][0];

    if (p != -1)
    {
        // "Largest entry of depth_set[p] after removing one instance of v's own
        // entry" is read directly from the top of the multiset: if the maximum is
        // v's entry, the answer is the runner-up. This avoids a linear scan plus an
        // erase/insert round trip, and leaves depth_set[p] untouched.
        const multiset<int>& sibling_heights = depth_set[p];
        const int own_entry = depth[v] + 1;
        assert(!sibling_heights.empty());   // v's own entry is always present

        multiset<int>::const_reverse_iterator top = sibling_heights.rbegin();
        if (*top == own_entry)
            ++top;

        down_dp[v][0] = (top == sibling_heights.rend()) ? 1 : *top + 1;
    }

    for (int i = 1; i < LOGN; i++)
    {
        // No ancestor 2^i edges up: this and all higher levels stay unset.
        if (par[v][i] == -1)
            break;

        const int pp = par[v][i - 1];
        const int val1 = down_dp[v][i - 1];
        const int val2 = down_dp[pp][i - 1] + (1 << (i - 1));
        down_dp[v][i] = max(val1, val2);
    }

    for (int to : graph[v])
        init_down_dp(to);
}

// Builds every precomputed table for the tree rooted at vertex 0.
// The order matters: each pass reads what the previous ones produced.
void init_tree()
{
    // Start with "no ancestor" everywhere, then root the tree; init_par also
    // removes each vertex's edge to its parent from graph[].
    for (int v = 0; v < N; v++)
        fill(par[v], par[v] + LOGN, -1);
    init_par(0, -1);

    // h[], t_in[], t_out[]
    init_tree0(0, 0);

    // depth[] and depth_set[]
    init_depth(0);

    // up_dp[][] and down_dp[][]
    init_up_dp(0);
    init_down_dp(0);
}

// Climbs from a towards b with binary lifting, taking every jump whose target is still
// inside b's subtree (b itself included). Each jump taken contributes
// down_dp[current][level] plus the distance already climbed from the starting vertex;
// the largest contribution is returned (0 when no jump is taken).
int get_down_max(int a, int b)
{
    const int start = a;
    int best = 0;

    for (int level = LOGN - 1; level >= 0; level--)
    {
        const int target = par[a][level];
        if (target != -1 && is_par(b, target))
        {
            const int candidate = down_dp[a][level] + get_dist(a, start);
            best = max(best, candidate);
            a = target;
        }
    }

    return best;
}

// Climbs from a towards b with binary lifting, taking every jump whose target is still
// inside b's subtree (b itself included). Each jump taken contributes
// up_dp[current][level] plus the distance from the jump's target to b;
// the largest contribution is returned (0 when no jump is taken).
int get_up_max(int a, int b)
{
    int best = 0;

    for (int level = LOGN - 1; level >= 0; level--)
    {
        const int target = par[a][level];
        if (target != -1 && is_par(b, target))
        {
            const int candidate = up_dp[a][level] + get_dist(target, b);
            best = max(best, candidate);
            a = target;
        }
    }

    return best;
}

// Returns the ancestor of v that is exactly delta edges above it.
// delta must lie in [0, h[v]]; this is asserted.
int go_up(int v, int delta)
{
    assert(delta >= 0);
    assert(delta <= h[v]);

    int current = v;
    int remaining = delta;
    // Take the largest power-of-two jump that still fits, from big to small.
    for (int level = LOGN - 1; level >= 0; level--)
    {
        const int step = 1 << level;
        if (remaining < step)
            continue;
        current = par[current][level];
        remaining -= step;
    }
    return current;
}

// NOTE(review): not read or written anywhere in the visible source - looks like a
// leftover; confirm before removing.
int IT;

// Answers one query: reads k followed by k 1-based vertex ids from stdin and prints
// a single integer and a newline.
// Relies on the tables built by init_tree(). depth_set[] is modified temporarily in
// the middle of the function and restored before the function returns.
void solve()
{
int k;
scanf("%d", &k);

// v_list keeps the query vertices exactly as given (converted to 0-based);
// v_set starts as the same vertices without duplicates.
set<int> v_set;
vector<int> v_list(k);
for (int i = 0; i < k; i++)
{
int v;
scanf("%d", &v);
v--;
v_list[i] = v;
v_set.insert(v);
}

// Close v_set under pairwise LCA: replace it with the set of all pairwise LCAs until
// it stops changing. get_lca(v, v) == v, so no vertex is ever dropped.
while (true)
{
set<int> new_v_set;
for (int v1 : v_set)
for (int v2 : v_set)
new_v_set.insert(get_lca(v1, v2));
if (new_v_set == v_set)
break;
v_set = new_v_set;
}

//dbg(k, v_list, v_set);

// For every vertex of the closed set: distance to the closest original query vertex.
for (int v : v_set)
{
dist_to_nearest[v] = INF;
for (int u : v_list)
dist_to_nearest[v] = min(dist_to_nearest[v], get_dist(v, u));

//dbg(v, dist_to_nearest[v]);
}

// Collect pairs (a, b) of v_set vertices where b is a proper ancestor of a and no
// other v_set vertex lies on the path between them.
vector<pair<int, int> > ab_pairs;

for (int a : v_set)
{
for (int b : v_set)
{
if (a == b || !is_par(b, a))
continue;

bool has_node_on_path = false;
for (int v : v_set)
if (v != a && v != b && is_par(b, v) && is_par(v, a))
has_node_on_path = true;

if (has_node_on_path)
continue;

ab_pairs.emplace_back(a, b);
}
}

//dbg(ab_pairs.size());

int ans = 0;

// Part 1 ("go up"): v_set vertices that have no other v_set vertex as an ancestor.
// Candidates: the climb from v to vertex 0 summarised by get_down_max, and the
// distance h[v] from v to vertex 0 itself, each added to dist_to_nearest[v].
for (int v : v_set)
{
bool is_root = true;
for (int u : v_set)
if (u != v && is_par(u, v))
is_root = false;

if (!is_root)
continue;

ans = max(ans, get_down_max(v, 0) + dist_to_nearest[v]);
ans = max(ans, dist_to_nearest[v] + h[v]);
}

// Part 2 ("go down"): for every pair (a, b), c is the ancestor of a one edge below b,
// i.e. the child of b on the path to a. Temporarily remove that child's entry from
// depth_set[b] so only child branches not leading to a paired vertex remain.

for (auto ab_p : ab_pairs)
{
int a = ab_p.first;
int b = ab_p.second;
int c = go_up(a, h[a] - h[b] - 1);
depth_set[b].erase(depth_set[b].find(depth[c] + 1));
}

// Candidate per v_set vertex: the deepest remaining child branch (or 0 when none is
// left) plus dist_to_nearest[v].
for (int v : v_set)
{
if (!depth_set[v].empty())
ans = max(ans, *depth_set[v].rbegin() + dist_to_nearest[v]);
else
ans = max(ans, dist_to_nearest[v]);
}

// Put the removed entries back so depth_set[] is unchanged for the next query.
for (auto ab_p : ab_pairs)
{
int a = ab_p.first;
int b = ab_p.second;
int c = go_up(a, h[a] - h[b] - 1);
depth_set[b].insert(depth[c] + 1);
}

// Part 3 ("go path"): vertices hanging off the path between a and b of each pair.
for (auto ab_p : ab_pairs)
{
int a = ab_p.first;
int b = ab_p.second;

// Lift c from a towards b (never reaching b or anything above it) for as long as
// going through a is not worse than going through b for the jump target.
int c = a;
for (int i = LOGN - 1; i >= 0; i--)
{
int new_c = par[c][i];

if (new_c == -1 || is_par(new_c, b))
continue;

// Both get_dist calls are O(1) here: new_c is an ancestor of a and a descendant of b,
// so get_lca returns from one of its initial is_par checks.
int dist1 = dist_to_nearest[a] + get_dist(a, new_c);
int dist2 = dist_to_nearest[b] + get_dist(b, new_c);

if (dist1 <= dist2)
c = new_c;
}

// The path is split into [a, c] (measured from a) and (c, b] (measured from b).
int val1 = get_down_max(a, c) + dist_to_nearest[a];
int val2 = 0;

// b2 is the ancestor of c one edge below b; the upper part is non-empty only when
// c has not already reached b2. The extra 1 is the edge from b2 to b.
int b2 = go_up(c, h[c] - h[b] - 1);
if (c != b2)
val2 = get_up_max(c, b2) + 1 + dist_to_nearest[b];

//dbg(a, b, b2, c, val1, val2);

ans = max(ans, val1);
ans = max(ans, val2);
}

printf("%d\n", ans);
}

// Entry point: reads n and q, then n - 1 edges given as 1-based endpoint pairs,
// builds the precomputed tables and answers the q queries one by one.
// With LOCAL defined, stdin is redirected to input.txt.
int main()
{
#ifdef LOCAL
    freopen ("input.txt", "r", stdin);
#endif

    int q;
    scanf("%d%d", &n, &q);

    for (int edges_left = n - 1; edges_left > 0; edges_left--)
    {
        int u, w;
        scanf("%d%d", &u, &w);
        // Store the undirected edge with 0-based endpoints.
        graph[u - 1].push_back(w - 1);
        graph[w - 1].push_back(u - 1);
    }

    init_tree();

    for (int queries_left = q; queries_left > 0; queries_left--)
        solve();

    return 0;
}