In this HackerRank Travelling Salesman in a Grid problem, the travelling salesman has a map containing m×n squares. He starts from the top-left corner, visits every cell exactly once, and returns to his initial position (top left). Moving between different pairs of neighbouring squares may take different amounts of time. Two squares are adjacent if they share a common edge, and the time taken to go from square a to square b is the same as the time taken to go from b to a. Can you figure out the shortest time in which the salesman can visit every cell and get back to his initial position?

HackerRank Travelling Salesman in a Grid problem solution


Problem solution in Python.

#!/bin/python3

import os
import sys

#
# There is no tspGrid function: the top-level script below reads the grid from stdin and prints the minimum tour cost.
#
INF = 10 ** 9

# Every plug slot of a DP state holds one of these values:
#   True  -> a path end whose partner end lies further right in the state tuple,
#   False -> a path end whose partner end lies further left,
#   None  -> no path crosses the slot.
# The named pairs (TT == (True, True), NF == (None, False), ...) are the possible
# contents of the two slots a cell reads.  A dedicated name is used for the value
# tuple so that `m` only ever means the number of grid rows.
_PLUG_VALUES = True, False, None
TT, TF, TN, FT, FF, FN, NT, NF, NN = ((i, j) for i in _PLUG_VALUES for j in _PLUG_VALUES)

# Grid size, then `row` (m lines; row[i][j] is the cost between cells (i, j) and (i, j + 1))
# and `column` (m - 1 lines; column[i][j] is the cost between cells (i, j) and (i + 1, j)).
m, n = (int(token) for token in input().split())
row = [[int(token) for token in input().split()] for _ in range(m)]
column = [[int(token) for token in input().split()] for _ in range(m - 1)]

def minimize(t, v):
    """Record v as the cost of state t in `current` unless a cheaper cost is stored.

    A state with no entry yet counts as costing INF, so the stored value never
    exceeds INF.  `current` is only mutated, never rebound, so no `global` is needed.
    """
    best_so_far = current.get(t, INF)
    current[t] = v if v < best_so_far else best_so_far

# If m and n are both odd the grid has an odd number of cells.  No cycle in a grid
# (a bipartite graph) can cover an odd number of cells, so the answer is 0.
if m & n & 1:
    ans = 0
else:
    ans = INF
    # A state is a tuple of m + n plug slots.  Each slot is True (a path end whose
    # partner lies further right in the tuple), False (partner further left) or
    # None (empty).  `current` maps state -> cheapest cost of the edges paid so far.
    previous, current = {}, {NN[:1] * (m + n): 0}
    for i in range(m):
        for j in range(n):
            # For cell (i, j) with k = m + j - 1 - i: slot k holds the plug entering
            # from the left and slot k + 1 the plug entering from above.  After the
            # cell, slot k holds the plug leaving downwards and slot k + 1 the plug
            # leaving to the right.  These are the "from above" slot of (i + 1, j)
            # and the "from the left" slot of (i, j + 1), so no shifting is needed.
            previous, current, k = current, {}, m + j - 1 - i
            for state, value in previous.items():
                l, x, r = state[:k], state[k: k + 2], state[k + 2:]
                if x == NN:
                    # No incoming plug: start a new path piece leaving down (True)
                    # and right (False); only possible away from the last row/column.
                    if i + 1 < m and j + 1 < n:
                        minimize(l + TF + r, value)
                elif x == NT or x == NF:
                    # Plug from above only: pay for that edge, then continue the
                    # path to the right (state unchanged) or downwards (swapped).
                    value += column[i - 1][j]
                    if j + 1 < n:
                        minimize(state, value)
                    if i + 1 < m:
                        minimize(l + x[::-1] + r, value)
                elif x == FN or x == TN:
                    # Plug from the left only: pay for that edge, then continue the
                    # path to the right (swapped) or downwards (state unchanged).
                    value += row[i][j - 1]
                    if j + 1 < n:
                        minimize(l + x[::-1] + r, value)
                    if i + 1 < m:
                        minimize(state, value)
                else:
                    # Both plugs present: pay for both edges; the two ends are joined here.
                    value += row[i][j - 1] + column[i - 1][j]
                    if x == TF:
                        # Both ends of the same path meet: this closes the cycle,
                        # which is accepted only at the very last cell.
                        if i + 1 == m and j + 1 == n:
                            ans = min(ans, value)
                    elif x == FT:
                        # Ends of two different paths meet: merge them; their outer
                        # ends already form a matching True/False pair.
                        minimize(l + NN + r, value)
                    elif x == TT:
                        # Two True ends: scan r with bracket counting for the False
                        # end matching the upper plug; it becomes a True end.
                        count = 1
                        index = -1
                        while count:
                            index += 1
                            count += 1 if r[index] == TT[0] else -1 if r[index] == FF[0] else 0
                        minimize(l + NN + r[:index] + TT[:1] + r[index + 1:], value)
                    else:
                        # Two False ends: scan l backwards for the True end matching
                        # the left plug; it becomes a False end.
                        count = -1
                        index = k
                        while count:
                            index -= 1
                            count += 1 if l[index] == TT[0] else -1 if l[index] == FF[0] else 0
                        minimize(l[:index] + FF[:1] + l[index + 1:] + NN + r, value)
# Prints INF (10 ** 9) if no tour was found in the DP branch.
print(ans)

{"mode":"full","isActive":false}


Problem solution in Java.

import java.io.*;
import java.util.*;

public class Solution {

  /** Powers of three: {@code three[i] == 3^i} for {@code 0 <= i <= n}; filled by solve(). */
  static int[] three;

  /** Returns the base-3 digit at position {@code i} of the encoded profile {@code s}. */
  static int bit(int s, int i) {
    return s / three[i] % 3;
  }

  // Capacity for profiles of at most 11 ternary digits, i.e. n <= 10 columns: 177147 == 3^11,
  // and 5798 is the number of balanced digit strings of length 11 (Motzkin number M_11),
  // which is exactly the set dfs() enumerates.
  static final int NS = 5798;
  static int ns = 0;
  static int[] mapping = new int[177147];
  static int[] states = new int[NS];

  /**
   * Enumerates every balanced profile and gives it a dense index in {@code states}/{@code mapping}.
   *
   * <p>Digit 0 means "no plug", 1 is a path end whose partner lies at a higher digit position,
   * 2 is a path end whose partner lies at a lower position. Digits are appended from the most
   * significant position downwards, so {@code x} counts 2s still waiting for their matching 1.
   *
   * @param k number of digits still to append
   * @param x number of unmatched 2s appended so far
   * @param s digits appended so far, as a base-3 number
   */
  static void dfs(int k, int x, int s) {
    if (k == 0) {
      if (x == 0) {
        mapping[s] = ns;
        states[ns++] = s;
      }
      return;
    }
    dfs(k - 1, x, 3 * s);
    if (x > 0) {
      dfs(k - 1, x - 1, 3 * s + 1);
    }
    dfs(k - 1, x + 1, 3 * s + 2);
  }

  /** Number of grid columns; profiles have {@code n + 1} digits. */
  static int n;

  /** DP row being filled for the current cell, indexed by profile index. */
  static int[] cur;

  /**
   * Writes the outgoing plugs {@code g} (digit j = down plug = g % 3, digit j+1 = right plug =
   * g / 3) into profile {@code s} and relaxes {@code cur} with cost {@code opt}.
   *
   * <p>At the last column a right plug would leave the grid, so such profiles are dropped.
   * The profile is then shifted up one digit, so that digit 0 becomes the empty left plug of the
   * next row.
   */
  static void tr(int j, int s, int g, int opt) {
    s -= three[j] * bit(s, j) + three[j + 1] * bit(s, j + 1);
    s += three[j] * g;
    if (j == n - 1) {
      if (bit(s, n) > 0) {
        return;
      }
      s *= 3;
    }
    cur[mapping[s]] = Math.min(cur[mapping[s]], opt);
  }

  /**
   * Returns the minimum cost of a cycle visiting every cell of an {@code m x cols} grid once.
   *
   * @param m number of rows
   * @param cols number of columns (at most 10, see the table sizes above)
   * @param horizontal {@code horizontal[i][j]} is the cost of the edge (i, j)-(i, j+1), j < cols-1
   * @param vertical {@code vertical[i][j]} is the cost of the edge (i, j)-(i+1, j), i < m-1
   * @return the minimum tour cost; 0 when both dimensions are odd (no tour exists); otherwise
   *     {@code Integer.MAX_VALUE / 2} if the DP finds no tour
   */
  static int solve(int m, int cols, int[][] horizontal, int[][] vertical) {
    if (cols % 2 > 0 && m % 2 > 0) {
      return 0;
    }
    n = cols;

    // Pad with zero-cost edges leaving the grid. Plugs through them are created by the DP but
    // never survive: right-border plugs are dropped in tr(), and bottom-border plugs keep the
    // final profile non-empty.
    int[][] h = new int[m][n];
    int[][] v = new int[m][n];
    for (int i = 0; i < m; i++) {
      for (int j = 0; j + 1 < n; j++) {
        h[i][j] = horizontal[i][j];
      }
    }
    for (int i = 0; i + 1 < m; i++) {
      for (int j = 0; j < n; j++) {
        v[i][j] = vertical[i][j];
      }
    }

    three = new int[n + 1];
    three[0] = 1;
    for (int i = 0; i < n; i++) {
      three[i + 1] = three[i] * 3;
    }
    ns = 0; // reset so that solve() can be called more than once
    dfs(n + 1, 0, 0);

    int[][] tr4 = new int[ns][n];
    int[][] tr8 = new int[ns][n];
    buildMergeTables(tr4, tr8);

    final int unreachable = Integer.MAX_VALUE / 2;
    int[][] dp = new int[2][ns];
    int[] pre = dp[0];
    cur = dp[1];
    Arrays.fill(cur, 0, ns, unreachable);
    cur[mapping[0]] = 0;
    for (int i = 0; i < m; i++) {
      for (int j = 0; j < n; j++) {
        int[] tmp = pre;
        pre = cur;
        cur = tmp;
        Arrays.fill(cur, 0, ns, unreachable);
        for (int si = 0; si < ns; si++) {
          int s = states[si];
          // Digit j is the plug entering from the left, digit j+1 the plug entering from above.
          int g = bit(s, j) + bit(s, j + 1) * 3;
          int opt = pre[si];
          switch (g) {
            case 0: // no incoming plug: start a piece leaving down (1) and right (2)
              tr(j, s, 1 + 3 * 2, opt + v[i][j] + h[i][j]);
              break;
            case 1:
            case 3: // a single 1-plug: continue it down or right
              tr(j, s, 1, opt + v[i][j]);
              tr(j, s, 3, opt + h[i][j]);
              break;
            case 2:
            case 6: // a single 2-plug: continue it down or right
              tr(j, s, 2, opt + v[i][j]);
              tr(j, s, 6, opt + h[i][j]);
              break;
            case 5: // left 2, up 1: two different pieces merge
              tr(j, s, 0, opt);
              break;
            case 4: // both 1: join; the partner of the upper 1 becomes a 1
              tr(j, tr4[si][j], 0, opt);
              break;
            case 8: // both 2: join; the partner of the left 2 becomes a 2
              tr(j, tr8[si][j], 0, opt);
              break;
            case 7: // left 1, up 2: the loop closes, allowed only at the last cell
              if (i == m - 1 && j == n - 1) {
                tr(j, s, 0, opt);
              }
              break;
            default:
              break;
          }
        }
      }
    }
    return cur[mapping[0]];
  }

  /**
   * Precomputes, for every profile and column {@code i}, the profile produced by joining two
   * equal plugs at digits i and i+1. Both digits are cleared. For 1,1 the 2 matching the 1 at
   * i+1 turns into a 1; for 2,2 the 1 matching the 2 at i turns into a 2.
   */
  private static void buildMergeTables(int[][] tr4, int[][] tr8) {
    for (int si = 0; si < ns; si++) {
      int s = states[si];
      for (int i = 0; i < n; i++) {
        int g = bit(s, i) + 3 * bit(s, i + 1);
        if (g == 4) {
          int c = 0;
          for (int j = i + 1; ; j++) {
            int b = bit(s, j);
            if (b == 1) c++;
            if (b == 2) c--;
            if (c == 0) {
              tr4[si][i] = s - three[i] * g - three[j]; // 1122 -> 0012
              break;
            }
          }
        }
        if (g == 8) {
          int c = 0;
          for (int j = i; ; j--) {
            int b = bit(s, j);
            if (b == 1) c++;
            if (b == 2) c--;
            if (c == 0) {
              tr8[si][i] = s - three[i] * g + three[j]; // 1122 -> 1200
              break;
            }
          }
        }
      }
    }
  }

  /** Reads {@code rows} lines of {@code width} whitespace-separated integers. */
  private static int[][] readMatrix(BufferedReader br, int rows, int width) throws IOException {
    int[][] a = new int[Math.max(rows, 0)][Math.max(width, 0)];
    for (int i = 0; i < rows; i++) {
      StringTokenizer st = new StringTokenizer(br.readLine());
      for (int j = 0; j < width; j++) {
        a[i][j] = Integer.parseInt(st.nextToken());
      }
    }
    return a;
  }

  public static void main(String[] args) throws IOException {
    try (BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")))) {
      StringTokenizer st = new StringTokenizer(br.readLine());
      int m = Integer.parseInt(st.nextToken());
      int cols = Integer.parseInt(st.nextToken());

      int answer;
      if (cols % 2 > 0 && m % 2 > 0) {
        // No tour can exist; the cost lines are not needed.
        answer = 0;
      } else {
        int[][] horizontal = readMatrix(br, m, cols - 1);
        int[][] vertical = readMatrix(br, m - 1, cols);
        answer = solve(m, cols, horizontal, vertical);
      }
      bw.write(String.valueOf(answer));
      bw.newLine();
    }
  }
}

{"mode":"full","isActive":false}


Problem solution in C++.

#include <iostream>
#include <vector>
#include <set>
#include <map>
#include <unordered_map>
#include <algorithm>
#include <ctime>

using namespace std;

// 64-bit signed integer used for packed DP states and path costs.
using ll = long long;

// Union-find over a fixed number of slots. Every slot starts inactive (parent -1) and joins the
// structure through makeSet(). Merges use union by rank, lookups compress paths, and
// elem[root] counts the members of root's set.
class DisjointSet{
public:
    DisjointSet( int size )
        : rank( size, 0 ), p( size, -1 ), elem( size, 0 )
    {
    }

    // Activates slot x as a singleton set.
    void makeSet( int x ){
        elem[x] = 1;
        rank[x] = 0;
        p[x] = x;
    }

    // True when both x and y have been activated by makeSet().
    bool isSet( int x, int y ) {
        return p[x] >= 0 && p[y] >= 0;
    }

    // Merges the sets of x and y; returns false when they already were the same set.
    bool Union( int x, int y ){
        const int rootX = findSet( x );
        const int rootY = findSet( y );
        if ( rootX == rootY ) return false;
        link( rootX, rootY );
        return true;
    }

    // Returns the root of x's set and re-points every slot on the way directly at it.
    int findSet( int x ){
        int root = x;
        while ( p[root] != root ) root = p[root];
        while ( p[x] != root ) {
            const int up = p[x];
            p[x] = root;
            x = up;
        }
        return root;
    }

    bool isSameSet( int x, int y ){
        return findSet( x ) == findSet( y );
    }

    vector<int> rank, p, elem;

    // Hangs the lower-rank root under the other; on a tie y becomes the root and its rank grows.
    void link ( int x, int y ){
        if ( rank[x] <= rank[y] ) {
            p[x] = y;
            elem[y] += elem[x];
            if ( rank[x] == rank[y] ) ++rank[y];
        } else {
            p[y] = x;
            elem[x] += elem[y];
        }
    }
};

// Reports whether bit b (0 = least significant) of x is set.
inline bool bit(ll x, ll b)
{
    const ll shifted = x >> b;
    return (shifted & 1LL) != 0;
}

// Memo of Next() results keyed only by the packed state `cur`. `r` is not part of the key, so the
// memo must be cleared whenever `r` may change (Count() and Solve() both clear it first).
static unordered_map<ll,vector<ll>> memo;

// Returns every packed state that may follow `cur` in the column-by-column DP.
//
// A packed state has two fields. The low 32 bits are an r-bit mask `w` of the current column.
// From bit 32 upward, 3 bits per position j hold a component label for that position; labels
// are only meaningful where the mask bit is set.
// NOTE(review): main() calls Solve(m-1, n-1, ...), so a set bit presumably marks a unit square
// lying inside the tour -- confirm.
//
// A candidate next mask `i` is kept only if:
//   * for every j in [0, r], the windows (bit j-1, bit j) of `w` and of `i` (bit -1 read as 0)
//     are not both 00, not both 11, and not the crossed pairs 01/10 or 10/01;
//   * joining set bits of `i` at adjacent positions never unites two slots that are already in
//     the same set (that would close a cycle);
//   * every component of `w` continues into some set bit of `i`.
// The emitted state is (labels of i's components, numbered by first appearance) << 32, plus i.
// `first` is not used. The returned reference points into `memo`.
vector<ll>& Next(int r, ll cur, bool first)
{
    if ( memo.count(cur) ) return memo[cur];

    // w: mask of the current column; g: 3-bit component labels per position.
    ll w = cur&((1LL<<32)-1), g = cur>>32;
    vector<ll> next;
    for ( int i = 0; i < (1<<r); i++ ) {
        // Local window check on (bit j-1, bit j) of w and of i, with a zero bit below position 0.
        ll a = w<<1, b = i<<1;
        bool valid = true;
        for ( int j = 0; j <= r; j++ ) {
            ll a0 = a&3, b0 = b&3;
            if ( a0==0&&b0==0 || a0==3&&b0==3 || a0==1&&b0==2 || a0==2&&b0==1 ) {
                valid = false;
                break;
            }
            a>>=1;
            b>>=1;
        }
        if ( !valid ) {
            continue;
        }

        // Slots 0..r-1 stand for the bits of w, slots r..2r-1 for the bits of i.
        // Bits set at the same position in both masks are joined.
        DisjointSet d(2*r);
        for ( int j = 0; j < r; j++ ) {
            if ( bit(w,j) ) d.makeSet(j);
            if ( bit(i,j) ) d.makeSet(j+r);
            if ( bit(w,j) && bit(i,j) ) {
                d.Union(j,j+r);
            }
        }

        // Bits of w carrying the same component label are already connected.
        for ( int j = 0; j < r; j++ ) {
            if ( !bit(w,j) ) continue;
            for ( int k = j+1; k < r; k++ ) {
                if ( !bit(w,k) ) continue;
                if ( ((g>>(3*j))&7) == ((g>>(3*k))&7) ) {
                    d.Union(j,k);
                }
            }
        }

        // Join adjacent set bits of i; a join inside one set would close a cycle.
        for ( int j = 1; j < r; j++ ) {
            if ( bit(i,j-1) && bit(i,j) ) {
                if ( !d.Union(j+r-1, j+r) ) {
                    valid = false;
                    break;
                }
            }
        }
        if ( !valid ) {
            continue;
        }

        // Every component of w must reach some set bit of i, otherwise it is cut off.
        bool checked[32] = {0};
        for ( int j = 0; j < r; j++ ) {
            if ( !bit(w,j) || checked[d.findSet(j)] ) continue;
            bool connedted = false;
            for ( int k = 0; k < r; k++ ) {
                if ( !bit(i,k) ) continue;
                if ( d.findSet(j) == d.findSet(k+r) ) {
                    checked[d.findSet(j)] = true;
                    connedted = true;
                    break;
                }
            }
            if ( !connedted ) {
                valid = false;
                break;
            }
        }
        if ( !valid ) {
            continue;
        }

        // Relabel i's components 0, 1, 2, ... in order of first appearance (lowest j first).
        ll gn[16] = {0};
        int ig = 0;
        map<ll,ll> mp;
        for ( int j = 0; j < r; j++ ) {
            if ( !bit(i,j) ) continue;
            if ( mp.count(d.findSet(j+r)) ) {
                gn[j] = mp[d.findSet(j+r)];
            } else {
                mp[d.findSet(j+r)] = ig;
                gn[j] = ig;
                ig++;
            }
        }

        // Pack labels 3 bits each (position 0 lowest) above the 32-bit mask.
        ll t = 0;
        for ( int j = r-1; j >= 0; j-- ) {
            t <<= 3;
            t += gn[j];
        }
        next.push_back( (t<<32) + i );
    }

    return memo[cur] = next;
}

// Counts the accepted sequences of c column states produced by Next(). r and c are swapped
// first if r > c. A final state is accepted when every 2-bit window of its mask is non-zero and
// all of its set bits carry component label 0 -- the same test Solve() applies.
// NOTE(review): presumably this is the number of tours; it is not used by main().
ll Count(int r, int c)
{
    memo.clear();
    if ( c < r ) swap( r, c );

    map<ll,ll> ways;
    ways[0] = 1;
    for ( int col = 0; col < c; ++col ) {
        map<ll,ll> nextWays;
        for ( map<ll,ll>::const_iterator it = ways.begin(); it != ways.end(); ++it ) {
            const vector<ll>& successors = Next( r, it->first, col == 0 );
            for ( size_t s = 0; s < successors.size(); ++s ) {
                nextWays[successors[s]] += it->second;
            }
        }
        ways.swap( nextWays );
    }

    ll total = 0;
    for ( map<ll,ll>::const_iterator it = ways.begin(); it != ways.end(); ++it ) {
        ll window = ( it->first & ( (1LL<<32) - 1 ) ) << 1;
        bool ok = true;
        for ( int j = 0; j <= r && ok; ++j, window >>= 1 ) {
            ok = ( window & 3 ) != 0;
        }
        if ( !ok ) continue;
        if ( ( it->first >> 32 ) != 0 ) continue;  // more than one component remains
        total += it->second;
    }
    return total;
}

// Minimum total cost over every accepted sequence of c column states produced by Next().
// main() passes r = m-1 and c = n-1.
//   * A step into column `col` adds cv[j][col] for each bit j that differs between the two
//     states.
//   * It also adds ch[j][col] for each j in [0, r] where bits j-1 and j of the new state differ.
//   * An accepted final state additionally pays cv[j].back() for each set bit j.
// A final state is accepted when every 2-bit window of its mask is non-zero and all of its set
// bits carry component label 0. Returns 0 when no final state is accepted.
ll Solve(int r, int c, vector<vector<ll>>& ch, vector<vector<ll>>& cv)
{
    memo.clear();

    map<ll,ll> best;
    best[0] = 0;
    for ( int col = 0; col < c; ++col ) {
        map<ll,ll> nextBest;
        for ( map<ll,ll>::const_iterator it = best.begin(); it != best.end(); ++it ) {
            const ll prev = it->first;
            for ( ll succ : Next( r, prev, col == 0 ) ) {
                ll cost = it->second;
                for ( int j = 0; j < r; ++j ) {
                    if ( bit( prev, j ) != bit( succ, j ) ) cost += cv[j][col];
                }
                const ll padded = succ << 1;
                for ( int j = 0; j <= r; ++j ) {
                    if ( bit( padded, j ) != bit( padded, j + 1 ) ) cost += ch[j][col];
                }

                map<ll,ll>::iterator slot = nextBest.find( succ );
                if ( slot == nextBest.end() ) {
                    nextBest.insert( make_pair( succ, cost ) );
                } else if ( slot->second > cost ) {
                    slot->second = cost;
                }
            }
        }
        best.swap( nextBest );
    }

    auto everyWindowNonEmpty = [r]( ll state ) -> bool {
        ll window = ( state & ( (1LL<<32) - 1 ) ) << 1;
        for ( int j = 0; j <= r; ++j, window >>= 1 ) {
            if ( ( window & 3 ) == 0 ) return false;
        }
        return true;
    };

    ll ans = -1;
    for ( map<ll,ll>::const_iterator it = best.begin(); it != best.end(); ++it ) {
        if ( !everyWindowNonEmpty( it->first ) ) continue;
        if ( ( it->first >> 32 ) != 0 ) continue;  // more than one component remains

        ll cost = it->second;
        for ( int j = 0; j < r; ++j ) {
            if ( bit( it->first, j ) ) cost += cv[j].back();
        }
        if ( ans < 0 || cost < ans ) ans = cost;
    }
    return ans < 0 ? 0 : ans;
}

int main()
{
    int n = 0, m = 0;
    cin >> m >> n;

    vector<vector<ll>> ch, cv;
    for ( int i = 0; i < m; i++ ) {
        vector<ll> v(n-1,0);
        for ( int j = 0; j < n-1; j++ ) {
            cin >> v[j];
        }
        ch.push_back(v);
    }
    for ( int i = 0; i < m-1; i++ ) {
        vector<ll> v(n,0);
        for ( int j = 0; j < n; j++ ) {
            cin >> v[j];
        }
        cv.push_back(v);
    }

    cout << Solve(m-1,n-1,ch,cv) << endl;
    return 0;
}

{"mode":"full","isActive":false}