In this HackerRank Kruskal (MST): Really Special Subtree problem, we are given an undirected, weighted, connected graph and must find its Really Special Subtree. The Really Special Subtree is defined as a subgraph that contains all the nodes of the graph and satisfies:

  1. There is exactly one path from every node to every other node.
  2. The subgraph has the minimum overall weight (sum of all edge weights) among all such subgraphs.
  3. No cycles are formed.

To build the Really Special Subtree, repeatedly pick the remaining edge with the smallest weight and check whether including it would create a cycle. If it would, discard the edge. If several edges of equal weight are available:

  1. Choose the edge that minimizes the sum u + v + wt, where u and v are the edge's endpoints and wt is its weight.
  2. If there is still a collision, choose any of them.

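Print the overall weight of the tree formed using these rules. Because every minimum spanning tree of a graph has the same total weight, the tie-breaking rule does not change the printed answer; this is why the solutions below that use Prim's algorithm print the same result as Kruskal's.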

HackerRank Kruskal (MST): Really Special Subtree problem solution


Problem solution in Python.

from collections import defaultdict
import heapq


def prim(G, S):
    """Return the total weight of a minimum spanning tree grown from ``S``.

    Runs Prim's algorithm on a binary heap with lazy deletion: the edges of
    each newly added vertex are pushed, and heap entries for vertices that
    are already in the tree are skipped when popped.

    Args:
        G: Adjacency mapping ``vertex -> iterable of (weight, neighbour)``,
            as produced by ``build_graph``. Vertices without a key in ``G``
            are treated as having no edges (``G`` is never mutated).
        S: Start vertex.

    Returns:
        The sum of the tree-edge weights. The sum is returned as soon as
        every vertex that has a key in ``G`` is in the tree. Otherwise it is
        the weight of the tree spanning the vertices reachable from ``S``
        (``0`` when ``S`` has no edges).
    """
    visited = set()
    queue = [(0, S)]
    total_cost = 0
    while queue:
        cost, vertex = heapq.heappop(queue)
        if vertex in visited:
            # Stale entry: the vertex already joined through a cheaper edge.
            continue
        total_cost += cost
        visited.add(vertex)
        # Every vertex that appears in the adjacency mapping is in the tree.
        if len(visited) == len(G):
            return total_cost
        # .get() instead of G[vertex]: indexing a defaultdict would insert an
        # empty entry for an edgeless vertex and change len(G) above.
        for next_cost, next_vertex in G.get(vertex, ()):
            if next_vertex not in visited:
                heapq.heappush(queue, (next_cost, next_vertex))
    # Heap exhausted (single isolated vertex, or S's component is done):
    # return the accumulated weight instead of falling through to None.
    return total_cost


def build_graph(M, read_line=input):
    """Read ``M`` weighted undirected edges and return an adjacency mapping.

    Each line holds ``u v w``: an edge between vertices ``u`` and ``v`` of
    weight ``w``. The edge is recorded in both directions.

    Args:
        M: Number of edge lines to read.
        read_line: Zero-argument callable returning the next input line.
            Defaults to the built-in ``input``; pass ``sys.stdin.readline``
            for faster bulk reading, or any other line source.

    Returns:
        ``defaultdict(set)`` mapping each vertex to a set of
        ``(weight, neighbour)`` tuples. Identical parallel edges (same
        endpoints and weight) collapse into a single entry.
    """
    G = defaultdict(set)
    for _ in range(M):
        u, v, w = map(int, read_line().split()[:3])
        G[u].add((w, v))
        G[v].add((w, u))
    return G


def main():
    """Read the graph and start vertex from stdin and print the MST weight."""
    # First line: vertex count (not needed here) and edge count.
    _, edge_count = map(int, input().split())
    graph = build_graph(edge_count)
    # Final line: the vertex Prim's algorithm starts from.
    start = int(input())
    print(prim(graph, start))


if __name__ == "__main__":
    main()



Problem solution in Java.

import java.io.*;
import java.util.*;

public class Solution {
    /** A vertex of the input graph. */
    static private class Node{
        int id;
        // Incident edges; each undirected input edge is stored at both endpoints.
        ArrayList<Edge> edges = new ArrayList<Edge>();
        // Cheapest known edge weight connecting this node to the tree built so far.
        int bestEdge = Integer.MAX_VALUE;
        // Set once the node has been polled from the queue, i.e. has joined the tree.
        boolean inTree = false;

        public Node(int id) {
            this.id = id;
        }
    }

    /** One direction of an undirected weighted edge: the far endpoint and the weight. */
    static private class Edge{
        Node node;
        int weight;

        public Edge(Node node, int weight) {
            this.weight = weight;
            this.node = node;
        }
    }

    /**
     * Reads the graph from standard input and prints the total weight of a
     * minimum spanning tree grown from the given start vertex (Prim's algorithm).
     *
     * Input: "v e", then e lines "a b weight" with 1-based vertex ids, then the
     * start vertex.
     */
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        int v = sc.nextInt();
        int e = sc.nextInt();

        ArrayList<Node> nodes = new ArrayList<Node>();
        nodes.add(null); // slot 0 unused so 1-based ids index directly

        for(int i=0; i<v; i++) {
            nodes.add(new Node(i+1));
        }

        for(int i=0; i<e; i++) {
            Node a = nodes.get(sc.nextInt());
            Node b = nodes.get(sc.nextInt());
            int weight = sc.nextInt();

            a.edges.add(new Edge(b, weight));
            b.edges.add(new Edge(a, weight));
        }

        Node start = nodes.get(sc.nextInt());
        start.bestEdge = 0;

        PriorityQueue<Node> q = new PriorityQueue<Node>(nodes.size(), new Comparator<Node>() {
            public int compare(Node a, Node b) {
                return Integer.compare(a.bestEdge, b.bestEdge);
            }
        });

        for(int i=1; i<=v; i++) {
            q.add(nodes.get(i));
        }

        long sum = 0;

        while(!q.isEmpty()) {
            Node node = q.poll();
            node.inTree = true;
            // A node never reached by any edge keeps MAX_VALUE and adds nothing.
            if(node.bestEdge != Integer.MAX_VALUE) {
                sum += node.bestEdge;
            }

            for(Edge edge : node.edges) {
                Node neighbour = edge.node;
                // O(1) flag check; PriorityQueue.contains is a linear scan per edge.
                if(neighbour.inTree) {
                    continue;
                }
                if(edge.weight < neighbour.bestEdge) {
                    // Decrease-key: take the node out before changing its priority
                    // so the heap is never holding an element with a mutated key.
                    q.remove(neighbour);
                    neighbour.bestEdge = edge.weight;
                    q.add(neighbour);
                }
            }
        }

        System.out.println(sum);
    }
}



Problem solution in C++.

#include <algorithm>
#include <climits>
#include <cmath>
#include <cstdio>
#include <functional>
#include <iostream>
#include <queue>
#include <unordered_set>
#include <utility>
#include <vector>
using namespace std;

// A graph vertex with parallel adjacency arrays: the edge to neighbors[j]
// has weight nei_weights[j]. main() numbers vertices from 1.
struct Node {
    int index;
    std::vector<Node*> neighbors;
    std::vector<int> nei_weights;

    Node(int i) : index(i) {}
};

int main() {
    int N, M;
    cin >> N >> M;
    Node** graph = new Node*[N + 1];
    for (int i = 1; i <= N; i++) {
        graph[i] = new Node(i);
    }
    for (int i = 0; i < M; i++) {
        int x, y, r;
        cin >> x >> y >> r;
        graph[x]->neighbors.push_back(graph[y]);
        graph[y]->neighbors.push_back(graph[x]);
        graph[x]->nei_weights.push_back(r);
        graph[y]->nei_weights.push_back(r);
    }
    int S;
    cin >> S;
    Node* start = graph[S];
    unordered_set<Node*> existing;
    existing.insert(start);
    int total_weights = 0;
    for (int i = 0; i < N - 1; i++) {
        int min_index = -1;
        int min_distance = INT_MAX;
        for (auto it = existing.begin(); it != existing.end(); it++) {
            for (int j = 0; j < (*it)->neighbors.size(); j++) {
                if (existing.find((*it)->neighbors[j]) == existing.end()) {
                    if ((*it)->nei_weights[j] < min_distance) {
                        min_distance = (*it)->nei_weights[j];
                        min_index = (*it)->neighbors[j]->index;
                    }
                }
            }
        }
        existing.insert(graph[min_index]);
        total_weights += min_distance;
    }
    cout << total_weights << endl;
    /* Enter your code here. Read input from STDIN. Print output to STDOUT */   
    return 0;
}



Problem solution in C.

#include <stdio.h>
#include <string.h>
#include <math.h>
#include <stdlib.h>

/* One undirected edge between vertices u and v (0-based in main) with
 * weight w. m is a status flag set by main: 0 = not yet considered,
 * 1 = accepted (joined two components), -1 = rejected (would close a cycle). */
typedef struct  {
    int u, v, m, w;
} Arco;

/* Return the index of the not-yet-considered edge (m == 0) to examine next:
 * smallest weight first, ties broken by the smallest u + w + v, remaining
 * ties by the lowest index. Returns -1 when no edge has m == 0. */
int menorArco( Arco * arcos, int M ) {
    int best = -1;
    int idx;
    for ( idx = 0; idx < M; idx++ ) {
        const Arco * cand = &arcos[idx];
        if ( cand->m != 0 )
            continue;
        if ( best < 0 ) {
            best = idx;
            continue;
        }
        if ( cand->w > arcos[best].w )
            continue;
        /* Here cand->w <= best weight: a strictly lighter edge wins outright,
         * an equally heavy one only with a strictly smaller vertex+weight sum. */
        if ( cand->w < arcos[best].w
             || cand->u + cand->w + cand->v
                < arcos[best].u + arcos[best].w + arcos[best].v )
            best = idx;
    }
    return best;
}

/* Merge component c2 into c1: every entry of vertices[0..N-1] labelled c2 is
 * relabelled c1. Entries with any other label are left untouched. */
void unirComponentes( int * vertices, int N, int c1, int c2 ) {
    int idx = N;
    /* Visit slots from the end; each slot is examined exactly once, so the
     * direction does not affect the result. */
    while ( idx-- > 0 ) {
        if ( vertices[idx] == c2 )
            vertices[idx] = c1;
    }
}

/* qsort comparator: ascending weight, ties broken by the smaller u + v + w
 * (the problem's tie-break rule; the same order menorArco selects in). */
static int compararArcos( const void * a, const void * b ) {
    const Arco * x = (const Arco *)a;
    const Arco * y = (const Arco *)b;
    int sx, sy;
    if ( x->w != y->w )
        return ( x->w < y->w ) ? -1 : 1;
    sx = x->u + x->v + x->w;
    sy = y->u + y->v + y->w;
    if ( sx != sy )
        return ( sx < sy ) ? -1 : 1;
    return 0;
}

/* Kruskal's algorithm. Reads "N M", then M lines "u v w" (1-based vertices),
 * then a start vertex that Kruskal does not need. Prints the total weight of
 * the chosen edges. */
int main() {
    int N, M, S, vs, vl, d;
    int **graph;
    Arco *arcos;
    int *vertices;
    int i, j, k;
    long R;

    if ( scanf("%d%d", &N, &M) != 2 )
        return 1;

    /* Lower-triangular adjacency matrix (row >= column) that keeps only the
     * lightest of any parallel edges; -1 means "no edge". */
    graph = (int **)malloc( N * sizeof(int *) );
    for ( i = 0; i < N; i++ )
        graph[i] = (int *)malloc( N * sizeof(int) );

    for ( i = 0; i < N; i++ )
        for ( j = 0; j < N; j++ )
            graph[i][j] = -1;

    for ( i = 0; i < M; i++ ) {
        if ( scanf("%d%d%d", &vs, &vl, &d) != 3 )
            break;
        vs--; vl--;
        if ( vl > vs ) {   /* normalise so that vs >= vl */
            int t = vs; vs = vl; vl = t;
        }
        if ( graph[vs][vl] == -1 || graph[vs][vl] > d )
            graph[vs][vl] = d;
    }

    /* Flatten the matrix into an edge list. */
    M = 0;
    for ( i = 0; i < N; i++ )
        for ( j = 0; j < N; j++ )
            if ( graph[i][j] >= 0 )
                M++;

    arcos = (Arco *)malloc( M * sizeof(Arco) );
    for ( k = i = 0; i < N; i++ )
        for ( j = 0; j < N; j++ )
            if ( graph[i][j] >= 0 ) {
                arcos[k].u = i;
                arcos[k].v = j;
                arcos[k].w = graph[i][j];
                arcos[k].m = 0;
                k++;
            }

    for ( i = 0; i < N; i++ )
        free( graph[i] );
    free( graph );

    /* vertices[i] is the component label of vertex i; initially all distinct. */
    vertices = (int *)malloc( N * sizeof(int) );
    for ( i = 0; i < N; i++ )
        vertices[i] = i;

    scanf("%d", &S); /* start vertex: not needed by Kruskal */

    /* Sort once (O(M log M)) instead of rescanning every edge for the minimum
     * on each step (O(M^2)). qsort needs a valid pointer even for 0 items. */
    if ( M > 0 )
        qsort( arcos, M, sizeof(Arco), compararArcos );

    R = 0;
    for ( i = 0; i < M; i++ ) {
        int u = arcos[i].u;
        int v = arcos[i].v;
        /* Accept the edge only if it joins two different components. */
        if ( vertices[u] != vertices[v] ) {
            unirComponentes( vertices, N, vertices[u], vertices[v] );
            arcos[i].m = 1;
            R += arcos[i].w;
        } else {
            arcos[i].m = -1;
        }
    }

    printf("%ld", R);

    free( arcos );
    free( vertices );
    return 0;
}
