# HackerRank Tree Pruning problem solution

In the HackerRank Tree Pruning problem, a tree $T$ has $N$ vertices numbered from $1$ to $N$ and is rooted at vertex $1$. Each vertex $i$ has an integer weight $w_i$ associated with it, and $T$'s total weight is the sum of the weights of its vertices. A single remove operation removes the subtree rooted at some arbitrary vertex $u$ from $T$.

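Given $T$, perform up to $K$ remove operations so that the total weight of the remaining vertices in $T$ is maximal, then print that maximal total weight on a new line. Note that removing the subtree rooted at vertex $1$ deletes the whole tree and leaves a total weight of $0$, so whenever $K \ge 1$ the answer is never negative.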

## Problem solution in Python.

```python
#!/bin/python3

import os
import sys

#
# Complete the treePrunning function below.
#

from collections import defaultdict

INF = -(1e15)


def _merge_child(dpc, dpn, k):
    """Fold one child's table ``dpn`` into its parent's partial table ``dpc``.

    Entry ``i + j`` of the result combines ``i`` removals already spent under
    the parent with ``j`` removals inside the child. Entry ``i + 1`` also covers
    removing the child's whole subtree with a single operation. Reachable
    entries form a prefix of each table, so every scan stops at the first
    ``INF``.

    :param dpc: parent's table so far (length ``k + 1``).
    :param dpn: finished table of one child (length ``k + 1``).
    :param k: maximum number of remove operations.
    :return: a new merged table of length ``k + 1``.
    """
    dptmp = [INF] * (k + 1)
    for i in range(k + 1):
        if dpc[i] == INF:
            break
        for j in range(0, k - i + 1):
            if dpn[j] == INF:
                break
            dptmp[i + j] = max(dptmp[i + j], dpc[i] + dpn[j])
        if i + 1 <= k:
            # Drop the child's entire subtree: one more operation, no weight.
            dptmp[i + 1] = max(dptmp[i + 1], dpc[i])
    return dptmp


def dfs(x, f, g, k, weights):
    """Return the pruning DP table for the subtree rooted at ``x``.

    ``table[i]`` is the largest total weight that can remain in the subtree of
    ``x`` when ``x`` itself is kept and exactly ``i`` remove operations are
    spent strictly below it. Counts that cannot be reached hold ``INF``.

    The walk uses an explicit stack instead of recursion, so deep trees (for
    example a path of 10^5 vertices) do not hit Python's recursion limit. Each
    child's table is merged into its parent as soon as the child is finished,
    which keeps memory proportional to the current root-to-vertex path.

    :param x: root of the subtree (0-based vertex index).
    :param f: parent of ``x``; it is skipped when walking neighbours (-1 for none).
    :param g: mapping vertex -> list of neighbouring vertices.
    :param k: maximum number of remove operations.
    :param weights: vertex weights indexed by vertex.
    :return: list of length ``k + 1``.
    """
    def frame(v, parent):
        table = [INF] * (k + 1)
        table[0] = weights[v]
        # [vertex, parent, unexplored-neighbour iterator, partial table]
        return [v, parent, iter(g[v]), table]

    stack = [frame(x, f)]
    while True:
        v, parent, neighbours, table = stack[-1]
        child = next((n for n in neighbours if n != parent), None)
        if child is not None:
            stack.append(frame(child, v))
            continue
        # All children of v are merged; hand its table to the parent frame.
        stack.pop()
        if not stack:
            return table
        stack[-1][3] = _merge_child(stack[-1][3], table, k)

def treePrunning(k, weights, edges):
    """Return the largest total weight left after at most ``k`` removals.

    :param k: maximum number of remove operations.
    :param weights: vertex weights; ``weights[i]`` belongs to vertex ``i + 1``.
    :param edges: ``n - 1`` pairs ``(u, v)`` of 1-based vertex numbers.
    :return: the best achievable total weight of the remaining tree.
    """
    g = defaultdict(list)
    for u, v in edges:
        g[u - 1].append(v - 1)
        g[v - 1].append(u - 1)

    # dpn[i]: best weight with the root kept and exactly i removals below it.
    dpn = dfs(0, -1, g, k, weights)
    if k == 0:
        # No operation is available, so the tree cannot be emptied either.
        return dpn[0]
    # With at least one operation the whole tree may be removed, leaving 0.
    return max(max(dpn), 0)

if __name__ == '__main__':
    # ``with`` closes the output file even if reading the input or solving raises.
    with open(os.environ['OUTPUT_PATH'], 'w') as fptr:
        nk = input().split()

        n = int(nk[0])

        k = int(nk[1])

        weights = list(map(int, input().rstrip().split()))

        # n - 1 edge lines, each "u v" with 1-based vertex numbers.
        tree = [list(map(int, input().rstrip().split())) for _ in range(n - 1)]

        result = treePrunning(k, weights, tree)

        fptr.write(str(result) + '\n')
```

## Problem solution in Java.

```java
import java.util.*;
import java.io.*;

/**
 * Tree Pruning: reads a tree rooted at vertex 1 with vertex weights and prints
 * the largest total weight that can remain after at most K subtree removals.
 */
public class Main {
    /** Stack size for the solver thread; the recursive DFS can be N frames deep. */
    private static final long SOLVER_STACK_BYTES = 1L << 28;

    public static void main(String[] args) throws InterruptedException {
        final FastScanner in = new FastScanner(System.in);
        final PrintWriter out = new PrintWriter(System.out);
        // A path-shaped tree recurses once per vertex, which can overflow the
        // default thread stack, so solve on a thread with a larger one.
        Thread solver = new Thread(null, new Runnable() {
            @Override
            public void run() {
                new Main().run(in, out);
            }
        }, "solver", SOLVER_STACK_BYTES);
        solver.start();
        solver.join();
        out.close();
    }

    int n;
    int K;
    int[] w;
    /** adj.get(v): neighbours of vertex v (0-based). */
    List<List<Integer>> adj;

    /** Reads "N K", the N weights and N-1 one-based edges, then prints the answer. */
    void run(FastScanner in, PrintWriter out) {
        n = in.nextInt();
        K = in.nextInt();
        w = new int[n];
        for (int i = 0; i < n; i++) w[i] = in.nextInt();
        initAdjacency();
        for (int i = 1; i < n; i++) {
            int u = in.nextInt() - 1;
            int v = in.nextInt() - 1;
            addEdge(u, v);
        }
        out.println(best());
    }

    /**
     * Solves one instance without any I/O.
     *
     * @param k       maximum number of remove operations
     * @param weights vertex weights; index 0 is vertex 1 (the root)
     * @param edges   N-1 edges as one-based {u, v} pairs
     * @return the largest total weight that can remain
     */
    static long solve(int k, int[] weights, int[][] edges) {
        Main m = new Main();
        m.n = weights.length;
        m.K = k;
        m.w = weights;
        m.initAdjacency();
        for (int[] e : edges) m.addEdge(e[0] - 1, e[1] - 1);
        return m.best();
    }

    /** Creates an empty neighbour list for each of the n vertices. */
    private void initAdjacency() {
        adj = new ArrayList<List<Integer>>(n);
        for (int i = 0; i < n; i++) adj.add(new ArrayList<Integer>());
    }

    /** Records the undirected edge (u, v), both 0-based. */
    private void addEdge(int u, int v) {
        adj.get(u).add(v);
        adj.get(v).add(u);
    }

    /** Maximum of the root's table over 0..K removals. */
    long best() {
        long[] dp = go(0, -1);
        long max = Long.MIN_VALUE;
        for (long value : dp) max = Math.max(max, value);
        return max;
    }

    /**
     * Returns dp where dp[k] is the largest weight that can remain in u's
     * subtree when exactly k remove operations are spent inside it
     * (Long.MIN_VALUE if k operations cannot be used). dp[1] also covers
     * removing u's whole subtree.
     */
    long[] go(int u, int p) {
        long[][] dp = new long[2][K + 1];
        for (long[] d : dp) Arrays.fill(d, Long.MIN_VALUE);
        int flip = 0;
        dp[0][0] = w[u];

        for (int v : adj.get(u)) {
            if (v == p) continue;
            Arrays.fill(dp[flip ^ 1], Long.MIN_VALUE);
            long[] childDp = go(v, u);
            // Max-plus convolution. Reachable entries form a prefix of each
            // table, so both loops stop at the first unreachable one.
            for (int k = 0; k <= K && dp[flip][k] != Long.MIN_VALUE; k++) {
                for (int pk = 0; pk + k <= K && childDp[pk] != Long.MIN_VALUE; pk++) {
                    dp[flip ^ 1][pk + k] = Math.max(dp[flip ^ 1][pk + k],
                            dp[flip][k] + childDp[pk]);
                }
            }
            flip ^= 1;
        }
        // Removing u's whole subtree costs one operation and leaves weight 0.
        // With K == 0 that option does not exist (and index 1 is out of range).
        if (K >= 1) dp[flip][1] = Math.max(dp[flip][1], 0);
        return dp[flip];
    }

    /** Whitespace-separated token reader over a buffered input stream. */
    static class FastScanner {
        private final BufferedReader br;
        private StringTokenizer st;

        public FastScanner(InputStream in) {
            br = new BufferedReader(new InputStreamReader(in));
            st = null;
        }

        /** Returns the next token; throws if the input ends or cannot be read. */
        String next() {
            while (st == null || !st.hasMoreElements()) {
                String line;
                try {
                    line = br.readLine();
                } catch (IOException e) {
                    throw new RuntimeException("failed to read input", e);
                }
                if (line == null) {
                    throw new NoSuchElementException("unexpected end of input");
                }
                st = new StringTokenizer(line);
            }
            return st.nextToken();
        }

        int nextInt() {
            return Integer.parseInt(next());
        }

        long nextLong() {
            return Long.parseLong(next());
        }
    }
}
```

## Problem solution in C++.

```cpp
#include <cmath>
#include <cstdio>
#include <vector>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdlib>
#include <climits>
using namespace std;
typedef long long ll;

// Capacity limits: vertices are numbered 1..N (row 0 unused) and K < MAXK.
const int MAXN = 100005;
const int MAXK = 205;
int wgt[MAXN];                          // wgt[v]: weight of vertex v
// dp[v][j]: largest weight that can remain in v's subtree using at most j
// remove operations; dfs writes entries j = 0..min(K, subtree size of v).
ll dp[MAXN][MAXK];
std::vector< std::vector<int> > net;    // net[v]: neighbours of vertex v
int N, K;

// Fills dp[cur][0..min(K, size)] for the subtree of `cur` entered from `par`
// and returns that subtree's vertex count.
//
// Each merge is bounded by min(K, size) of the two parts instead of by K.
// A vertex with many small children therefore no longer pays O(K^2) per
// child; this is the standard bound that brings tree knapsack to O(N*K).
// NOTE(review): recursion depth equals the tree height (up to N); this relies
// on the process stack being large enough for path-shaped trees.
int dfs(int cur, int par){
    ll *best = dp[cur];    // best[j]: `cur` kept, at most j removals below it
    best[0] = wgt[cur];
    int subtreeSize = 1;   // vertices of cur's subtree seen so far
    int lim = 0;           // best[0..lim] is valid; larger j would equal best[lim]
    const std::vector<int> &adj = net[cur];
    for (size_t e = 0; e < adj.size(); ++e){
        const int u = adj[e];
        if (u == par) continue;
        const int childSize = dfs(u, cur);
        subtreeSize += childSize;
        const int childLim = std::min(K, childSize);
        const int newLim = std::min(K, lim + childLim);
        // Max-plus convolution of best[0..lim] with dp[u][0..childLim].
        // j runs downward so every best[l] read (l <= j) still holds the value
        // from before this child; best[j] is only overwritten after its scan.
        for (int j = newLim; j >= 0; --j){
            const int lo = std::max(0, j - childLim);
            const int hi = std::min(j, lim);
            ll cand = best[lo] + dp[u][j - lo];
            for (int l = lo + 1; l <= hi; ++l)
                cand = std::max(cand, best[l] + dp[u][j - l]);
            best[j] = cand;
        }
        lim = newLim;
    }
    // Operations beyond what the children can absorb change nothing...
    const int fullLim = std::min(K, subtreeSize);
    for (int j = lim + 1; j <= fullLim; ++j) best[j] = best[lim];
    // ...except that one of them may remove cur's whole subtree (weight 0).
    for (int j = 1; j <= fullLim; ++j) best[j] = std::max(best[j], 0LL);
    return subtreeSize;
}

int main() {
/* Enter your code here. Read input from STDIN. Print output to STDOUT */
cin>>N>>K;
for(int i = 1; i <= N; i++) cin>>wgt[i];
net.resize(N+1);
for(int i = 0; i < N-1; i++){
int u,v;
cin>>u>>v;
net[u].push_back(v);
net[v].push_back(u);
}
for(int i = 1; i<= N; i++){
for(int j = 0; j<=K; j++){
dp[i][j] = -1*INT_MAX;
}
}
dfs(1,0);
ll ans = dp[1][0];
for(int i = 1; i <= K; i++) ans = max(ans, dp[1][i]);
cout<<ans<<endl;
return 0;
}
```

## Problem solution in C.

```c
#include<stdio.h>
#include<stdlib.h>
/* Adjacency-list cell: `x` is the neighbouring vertex, `next` the rest of the list.
 * (Tag renamed from `_node`: file-scope identifiers starting with `_` are reserved.) */
typedef struct node_s
{
    int x;
    struct node_s *next;
} node;
/*
 * a[v]     : weight of vertex v (0-based), as read from the input.
 * b[p]     : weight of the vertex visited p-th in dfs pre-order.
 * size[p]  : vertex count of the subtree whose root is the p-th pre-order vertex.
 * trace[v] : set to 1 once dfs has visited vertex v.
 * NN       : number of vertices visited so far (next free pre-order slot).
 */
int a[100000], b[100000], size[100000], trace[100000] = {0}, NN = 0;
/* dp[i][j] (filled by main): best weight over the first i pre-order vertices
 * with at most j subtree removals. */
long long dp[100001][201];
node *table[100000] = {0};   /* table[v]: head of v's adjacency list */
/* Pushes vertex `to` onto the front of `from`'s adjacency list; aborts if
 * the allocation fails instead of dereferencing NULL. */
static void push_neighbor(int from, int to)
{
    node *t = (node*)malloc(sizeof(node));
    if (t == NULL)
    {
        fprintf(stderr, "out of memory\n");
        exit(EXIT_FAILURE);
    }
    t -> x = to;
    t -> next = table[from];
    table[from] = t;
}
/* Adds the undirected edge (x, y): y is prepended to x's list, then x to y's. */
void insert_edge(int x, int y)
{
    push_neighbor(x, y);
    push_neighbor(y, x);
}
/*
 * Pre-order DFS from vertex x over the not-yet-visited vertices.
 * Each visited vertex v is marked in trace[v] and gets the next pre-order slot
 * p = NN. Its weight a[v] is copied to b[p], and once all of its descendants
 * are visited size[p] is set to their count plus one. Neighbours are taken in
 * adjacency-list order, exactly as the recursive formulation would. An
 * explicit stack replaces recursion, so a path of 10^5 vertices cannot
 * overflow the call stack.
 */
void dfs(int x)
{
    /* For the vertex at depth d of the current path:
     * slot[d]      - its pre-order slot,
     * next_edge[d] - the next adjacency cell still to examine. */
    static int slot[100000];
    static node *next_edge[100000];
    int depth;
    int y;
    node *t;

    trace[x] = 1;
    slot[0] = NN;
    next_edge[0] = table[x];
    b[NN++] = a[x];
    depth = 1;
    while (depth > 0)
    {
        t = next_edge[depth - 1];
        while (t && trace[t -> x])
            t = t -> next;
        if (t)
        {
            /* Descend into the first unvisited neighbour. */
            next_edge[depth - 1] = t -> next;
            y = t -> x;
            trace[y] = 1;
            slot[depth] = NN;
            next_edge[depth] = table[y];
            b[NN++] = a[y];
            depth++;
        }
        else
        {
            /* All neighbours done: the subtree occupies slots slot..NN-1. */
            depth--;
            size[slot[depth]] = NN - slot[depth];
        }
    }
    return;
}
/* Returns the larger of x and y (y when they are equal). */
long long max(long long x, long long y)
{
    if (x > y)
        return x;
    return y;
}
/*
 * Reads N, K, the N weights and N-1 one-based edges, then prints the largest
 * total weight left after at most K subtree removals.
 *
 * After dfs(0) the tree is in pre-order, so every subtree is the contiguous
 * run b[p .. p+size[p]-1]. dp[i][j] is the best weight of the first i
 * pre-order vertices using at most j removals. Vertex i-1 is either kept
 * (dp[i-1][j] + b[i-1] -> dp[i][j]) or its whole subtree is skipped with one
 * removal (dp[i-1][j] -> dp[i-1+size[i-1]][j+1]).
 */
int main()
{
    int N, K, x, y, i, j, end;
    long long sum;
    /* The static tables hold at most 100000 vertices and K <= 200. */
    if (scanf("%d%d", &N, &K) != 2 || N < 1 || N > 100000 || K < 0 || K > 200)
    {
        fprintf(stderr, "invalid N or K\n");
        return 1;
    }
    for( i = 0 ; i < N ; i++ )
    {
        if (scanf("%d", a + i) != 1)
        {
            fprintf(stderr, "missing weight\n");
            return 1;
        }
    }
    for( i = 0 ; i < N - 1 ; i++ )
    {
        if (scanf("%d%d", &x, &y) != 2 || x < 1 || x > N || y < 1 || y > N)
        {
            fprintf(stderr, "invalid edge\n");
            return 1;
        }
        insert_edge(x-1, y-1);
    }
    dfs(0);
    for( i = 0 ; i <= K ; i++ )
        dp[0][i] = 0;
    /* Keeping every vertex is always possible: start from the prefix sums. */
    for( i = 1, sum = 0 ; i <= N ; i++ )
    {
        sum += b[i-1];
        for( j = 0 ; j <= K ; j++ )
            dp[i][j] = sum;
    }
    for( i = 1 ; i <= N ; i++ )
    {
        end = i + size[i-1] - 1;   /* dp row just past vertex i-1's subtree */
        for( j = 0 ; j <= K ; j++ )
        {
            if( j != K )
                dp[end][j+1] = max(dp[end][j+1], dp[i-1][j]);
            dp[i][j] = max(dp[i][j], dp[i-1][j] + b[i-1]);
        }
    }
    printf("%lld\n", dp[N][K]);
    return 0;
}
```