In the HackerRank "Simple Game" problem, Big Cat and Little Cat love playing games. Today, they decide to play a Game of Stones, the Kitties are Coming edition. The game's rules are as follows:

  1. The game starts with N stones that are randomly divided into M piles.
  2. The cats move in alternating turns, and Little Cat always moves first.
  3. During a move, a cat picks a pile containing at least 2 stones and splits it into non-empty piles, where the number of new piles is anywhere in the inclusive range from 2 to K.
  4. The first cat to be unable to make a move (e.g., because all piles contain exactly 1 stone) loses the game.
  5. Little Cat is curious about the number of ways in which the stones can be initially arranged so that she is guaranteed to win. Two arrangements of stone piles are considered to be different if they contain different sequences of values. For example, arrangements (2,1,2) and (2,2,1) are different.

Given the values for N, M, and K, find the number of winning configurations for Little Cat and print it modulo $10^9 + 7$.

HackerRank Simple Game problem solution


Problem solution in Java.

import java.util.*;

import java.util.*;

/**
 * HackerRank "Simple Game": counts the ordered ways to arrange N stones into M non-empty piles
 * such that the first player (Little Cat) wins, modulo 1e9+7.
 *
 * <p>Each pile is an independent impartial game, so an arrangement is a first-player win exactly
 * when the XOR (nim-sum) of the piles' Grundy numbers is non-zero (Sprague-Grundy theorem).
 *
 * <p>Bound used below: every Grundy value satisfies g(s) <= s - 1. By induction, a move splits a
 * pile of s stones into parts p_1..p_r with r >= 2; the nim-sum of that option is at most the
 * arithmetic sum of the g(p_i) <= sum(p_i - 1) = s - r <= s - 2, so the mex is at most s - 1.
 */
public class Solution {
    /** Modulus required by the problem statement. */
    private static final long MOD = 1_000_000_007L;

    /**
     * Returns the Grundy number of every pile size {@code 0..maxStones} when one move splits a
     * single pile into between 2 and {@code maxSplit} non-empty piles.
     *
     * @param maxStones largest pile size needed (N)
     * @param maxSplit  K, the maximum number of piles one split may produce
     * @return array {@code g} with {@code g[s]} the Grundy number of a pile of {@code s} stones
     */
    private static int[] grundyNumbers(int maxStones, int maxSplit) {
        int[] grundy = new int[maxStones + 1];
        if (maxSplit < 2) {
            // The split range 2..maxSplit is empty, so no pile ever has a move: all values are 0.
            return grundy;
        }
        for (int size = 2; size <= maxStones; ++size) {
            if (maxSplit == 2) {
                grundy[size] = 1 - (size % 2);
            } else if (maxSplit == 3) {
                grundy[size] = grundyWithTwoOrThreeParts(size, grundy);
            } else {
                // With >= 4 parts, any v in 0..size-2 is reached by (v+1, rest) where rest is one
                // stone, two equal halves, or (1, h, h) -- equal parts cancel in the XOR. Together
                // with the class-level bound this gives g = size - 1.
                grundy[size] = size - 1;
            }
        }
        return grundy;
    }

    /**
     * Grundy number of a pile of {@code size >= 2} stones when it may be split into 2 or 3 piles.
     *
     * @param size   pile size to evaluate
     * @param grundy Grundy numbers already computed for every size below {@code size}
     * @return the mex of the nim-sums of all 2- and 3-way splits
     */
    private static int grundyWithTwoOrThreeParts(int size, int[] grundy) {
        // Every option's nim-sum is <= size - 2 (see class comment), so `size` slots suffice.
        boolean[] reachable = new boolean[size];
        for (int a = 1; 2 * a <= size; ++a) {
            reachable[grundy[a] ^ grundy[size - a]] = true;
        }
        // Unordered triples a <= b <= c with a + b + c == size.
        for (int a = 1; 3 * a <= size; ++a) {
            for (int b = a; 2 * b <= size - a; ++b) {
                reachable[grundy[a] ^ grundy[b] ^ grundy[size - a - b]] = true;
            }
        }
        int mex = 0;
        while (mex < reachable.length && reachable[mex]) {
            ++mex;
        }
        return mex;
    }

    /**
     * Counts ordered arrangements of {@code stones} into {@code initialPilesNum} non-empty piles
     * whose nim-sum is non-zero, modulo {@link #MOD}.
     *
     * @param stones          N, total number of stones
     * @param initialPilesNum M, number of piles in the initial arrangement
     * @param splitPilesNum   K, maximum number of piles a single move may produce
     * @return the number of first-player-winning arrangements modulo 1e9+7
     */
    private static long solve(int stones, int initialPilesNum, int splitPilesNum) {
        if (stones < 1 || initialPilesNum < 1 || initialPilesNum > stones) {
            return 0;
        }
        int[] grundy = grundyNumbers(stones, splitPilesNum);
        int maxGrundy = 0;
        for (int g : grundy) {
            maxGrundy = Math.max(maxGrundy, g);
        }
        // A power of two above every Grundy value also bounds every XOR of them.
        int width = maxGrundy == 0 ? 1 : Integer.highestOneBit(maxGrundy) << 1;

        // ways[used][x]: arrangements of the piles placed so far using `used` stones with
        // nim-sum x. Only two layers are kept, and every addition is reduced mod MOD so the
        // counts (which grow like C(N-1, M-1)) never overflow a long.
        long[][] ways = new long[stones + 1][width];
        ways[0][0] = 1;
        for (int placed = 0; placed < initialPilesNum; ++placed) {
            int pilesAfterThis = initialPilesNum - placed - 1;
            long[][] next = new long[stones + 1][width];
            for (int used = placed; used <= stones; ++used) {
                // Leave at least one stone for each later pile; the last pile takes the rest.
                int maxSize = stones - used - pilesAfterThis;
                if (maxSize < 1) {
                    continue;
                }
                int minSize = pilesAfterThis == 0 ? maxSize : 1;
                long[] row = ways[used];
                for (int x = 0; x < width; ++x) {
                    long count = row[x];
                    if (count == 0) {
                        continue;
                    }
                    for (int size = minSize; size <= maxSize; ++size) {
                        long[] target = next[used + size];
                        int nimSum = x ^ grundy[size];
                        long updated = target[nimSum] + count;
                        target[nimSum] = updated >= MOD ? updated - MOD : updated;
                    }
                }
            }
            ways = next;
        }

        long result = 0;
        for (int x = 1; x < width; ++x) {
            result = (result + ways[stones][x]) % MOD;
        }
        return result;
    }

    /** Reads "N M K" (whitespace separated) from standard input and prints the answer. */
    public static void main(String[] params) {
        try (Scanner scanner = new Scanner(System.in)) {
            int stones = scanner.nextInt();
            int piles = scanner.nextInt();
            int maxSplit = scanner.nextInt();
            System.out.println(solve(stones, piles, maxSplit));
        }
    }
}


Problem solution in C++.

#include <iostream>
#include <numeric>
#include <type_traits>
using namespace std;

// FOR(i, a, b): i runs over [a, b); i has b's type with reference and cv-qualifiers removed.
#define FOR(i, a, b) for (remove_cv<remove_reference<decltype(b)>::type>::type i = (a); i < (b); i++)
// REP(i, n): i runs over [0, n).
#define REP(i, n) FOR(i, 0, n)
// REP1(i, n): i runs over [1, n] (inclusive upper bound).
#define REP1(i, n) for (remove_cv<remove_reference<decltype(n)>::type>::type i = 1; i <= (n); i++)

// N: array bound; main() indexes nim[]/mx[] by stone counts up to n, so n must not exceed N.
// MOD: modulus applied to the arrangement counts.
const long N = 600, MOD = 1000000007;
// nim[i]: Grundy number of a single pile of i stones (filled in main()).
// mx[i]: a nim-sum reachable with i stones; main() uses it as the loop bound over nim-sums.
// mex[v]: timestamp marking v as reachable for the pile size currently being evaluated.
// s[p][j][k]: rolling DP layer (p = piles placed & 1): arrangements using j stones with nim-sum k.
int nim[N+1], mx[N+1], mex[N], s[2][N+1][N];

// Reads n (stones), m (piles) and x (maximum parts per split) and prints the
// number of ordered m-pile arrangements of n stones whose nim-sum is non-zero
// (first player wins), modulo MOD.
int main()
{
  long m, n, x;
  if (!(cin >> n >> m >> x))
    return 1;

  // nim[i] = Grundy number of a single pile of i stones; nim[0] = nim[1] = 0
  // because the global arrays are zero-initialised.
  if (x == 2) {
    // Only 2-way splits: even piles get 1, odd piles 0.
    FOR(i, 2, n+1)
      nim[i] = (i & 1) ^ 1;
  } else if (x == 3) {
    // mex over all 2-way and 3-way splits; `tick` timestamps mex[] so it does
    // not need clearing between pile sizes. Only values below i are recorded,
    // which keeps the index inside mex[].
    long tick = 0, t;
    FOR(i, 2, n+1) {
      tick++;
      FOR(j, 1, i)
        if ((t = nim[j] ^ nim[i-j]) < i)
          mex[t] = tick;
      FOR(j, 1, i)
        FOR(k, 1, i-j)
          if ((t = nim[j] ^ nim[k] ^ nim[i-j-k]) < i)
            mex[t] = tick;
      while (mex[nim[i]] == tick)
        nim[i]++;
    }
  } else if (x >= 4) {
    FOR(i, 2, n+1)
      nim[i] = i-1;
  }
  // x < 2: a move must split a pile into 2..x parts, an empty range, so no
  // pile ever has a move and every nim[] entry stays 0 (the answer is 0).

  // mx[i]: a reachable nim-sum for i stones, used below as the upper bound of
  // the nim-sum loop at sum i.
  // NOTE(review): mx[i] is built greedily from mx[i-j]; it is assumed to
  // dominate every reachable nim-sum (clear for x == 2 and x >= 4) -- TODO
  // confirm for x == 3.
  FOR(i, 2, n+1) {
    mx[i] = 0;
    REP1(j, i)
      mx[i] = max(mx[i], mx[i-j] ^ nim[j]);
  }

  // s[cur][j][k]: arrangements of the i piles placed so far using j stones
  // with nim-sum k (mod MOD); s[nxt] receives the layer with one more pile.
  s[0][0][0] = 1;
  REP(i, m) {
    int cur = i & 1, nxt = (i + 1) & 1;
    REP(j, n+1)
      fill_n(s[nxt][j], N, 0);
    REP(j, n)
      REP(k, mx[j]+1)
        if (s[cur][j][k])
          REP1(jj, n-j) {
            long kk = k ^ nim[jj];
            if (kk < N) {
              int& t = s[nxt][j+jj][kk];
              t += s[cur][j][k];
              if (t >= MOD)
                t -= MOD;
            }
          }
  }
  // Sum every non-zero nim-sum; 0LL keeps the accumulator 64-bit on all platforms.
  cout << accumulate(s[m&1][n]+1, s[m&1][n]+N, 0LL) % MOD << endl;
}


Problem solution in C.

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/* Modulus applied to all counts. */
#define MOD 1000000007
/* Number of nim-sum values tracked by tt[] and by the solve_aux bitsets;
   every value stored there must be below 600. */
#define MAX_NIM 600
/* Words per nim-sum bitset: bit j lives in word j/63 at position j%63
   (only 63 bits of each long long are used). */
#define MAX_NIM_SIZE (MAX_NIM/63+1)
/* solve(x): computes nim[x-1], the Grundy number of a pile of x stones. */
void solve(int x);
/* solve_aux(x,y,a): bitset `a` of nim-sums reachable by splitting x stones into at most y parts. */
void solve_aux(int x,int y,long long *a);
int min(int x,int y);
/* solve1(n,m): number of ordered compositions of n into m positive parts, modulo MOD. */
long long solve1(int n,int m);
/* solve2(n,m,o): compositions of n into m parts whose nim-sum of nim[part-1] equals o, modulo MOD. */
long long solve2(int n,int m,int o);
/* K: maximum parts per split (read in main). nim[i]: Grundy number of a pile of i+1 stones.
   tt: scratch "value reachable" marks used by solve() to take the mex. */
int K,nim[600],tt[MAX_NIM];
/* Memo tables, initialised to -1 ("not computed") in main:
   dp[x-1][y-1]  -> solve_aux bitsets; dp1[n-1][m-1] -> solve1; dp2[n-1][m-1][o] -> solve2.
   NOTE(review): the second dimension of dp1/dp2 is 10, so M must be <= 10 (and N <= 600);
   main does not check this -- confirm against the problem constraints. */
long long dp[600][600][MAX_NIM_SIZE],dp1[600][10],dp2[600][10][600];

/*
 * Reads N (stones), M (piles) and K (maximum parts per split), fills nim[]
 * with per-pile Grundy numbers, and prints the number of ordered M-pile
 * arrangements of N stones with non-zero nim-sum (first player wins):
 * all compositions minus those whose nim-sum is 0, modulo MOD.
 */
int main()
{
  int N,M,i;
  long long total,total1;
  memset(dp,-1,sizeof(dp));
  memset(dp1,-1,sizeof(dp1));
  memset(dp2,-1,sizeof(dp2));
  if(scanf("%d%d%d",&N,&M,&K)!=3)
    return 1;
  if(K>=4)
    for(i=0;i<N;i++)
      nim[i]=i;
  else if(K>=2)
  {
    nim[0]=0;
    for(i=2;i<=N;i++)
      solve(i);
  }
  /* K < 2: a move must split into 2..K parts, an empty range, so no pile has
     a move; nim[] keeps its zero static initialisation and the answer is 0.
     solve() is skipped because it would call solve_aux with y == K-1 == 0. */
  total=solve1(N,M);
  total1=solve2(N,M,0);
  printf("%lld\n",(total-total1+MOD)%MOD);
  return 0;
}
/*
 * Computes nim[x-1], the Grundy number of a pile of x stones: for every
 * choice of a leading part of `head` stones, the nim-sums of the remaining
 * x-head stones come from solve_aux (at most K-1 parts); each combined value
 * is marked in tt[], and the smallest unmarked value is the mex.
 */
void solve(int x)
{
  long long reach[MAX_NIM_SIZE];
  int head, v;

  memset(tt, 0, sizeof(tt));
  for (head = 1; head < x; head++)
  {
    int rest = x - head;
    solve_aux(rest, min(rest, K - 1), reach);
    for (v = 0; v < MAX_NIM; v++)
    {
      if (reach[v / 63] & (1LL << (v % 63)))
        tt[nim[head - 1] ^ v] = 1;
    }
  }

  /* First value never reached is the mex. */
  for (v = 0; v < MAX_NIM; v++)
  {
    if (!tt[v])
    {
      nim[x - 1] = v;
      return;
    }
  }
}
/*
 * Fills bitset `a` (MAX_NIM_SIZE words, 63 bits used per word) with every
 * nim-sum reachable by splitting x stones into between 1 and y non-empty
 * parts. x == 0 or y < 1 yields the empty set.
 *
 * Memo dp[x-1][y-1] (word 0 == -1 means "not computed") stores the nim-sums
 * of splits into 2..y parts when y > 1, and the single-part value when
 * y == 1. The single-part value nim[x-1] is OR-ed into `a` on both the
 * cache-hit and the cache-miss path, so the result does not depend on
 * whether this (x, y) was seen before.
 */
void solve_aux(int x,int y,long long *a)
{
  int i,j,sum;
  long long b[MAX_NIM_SIZE];
  if(!x || y<1)
  {
    /* Nothing to split, or no part allowed: empty set. The y < 1 guard also
       keeps dp[x-1][y-1] in bounds (e.g. K == 1 gives y == 0). */
    for(i=0;i<MAX_NIM_SIZE;i++)
      a[i]=0;
    return;
  }
  if(dp[x-1][y-1][0]!=-1)
  {
    for(i=0;i<MAX_NIM_SIZE;i++)
      a[i]=dp[x-1][y-1][i];
    a[nim[x-1]/63]|=(1LL<<(nim[x-1]%63));
    return;
  }
  for(i=0;i<MAX_NIM_SIZE;i++)
    a[i]=dp[x-1][y-1][i]=0;
  if(y==1)
  {
    a[nim[x-1]/63]|=(1LL<<(nim[x-1]%63));
    dp[x-1][y-1][nim[x-1]/63]|=(1LL<<(nim[x-1]%63));
    return;
  }
  /* Leading part of i stones (i < x), remainder split into 1..y-1 parts. */
  for(i=1;i<x;i++)
  {
    solve_aux(x-i,min(x-i,y-1),b);
    for(j=0;j<MAX_NIM;j++)
      if(b[j/63]&(1LL<<(j%63)))
      {
        sum=(nim[i-1]^j);
        a[sum/63]|=(1LL<<(sum%63));
        dp[x-1][y-1][sum/63]|=(1LL<<(sum%63));
      }
  }
  /* Single-part option, matching what the cache-hit path returns. */
  a[nim[x-1]/63]|=(1LL<<(nim[x-1]%63));
  return;
}
int min(int x,int y)
{
  /* Smaller of the two arguments (x when they are equal). */
  if (y < x)
    return y;
  return x;
}
/*
 * Number of ordered ways to write n as a sum of m positive parts, modulo MOD.
 * Returns 0 when n or m is 0. Memoised in dp1[n-1][m-1] (-1 = not computed).
 */
long long solve1(int n,int m)
{
  long long *memo;
  long long acc;
  int first;

  if (n == 0 || m == 0)
    return 0;
  memo = &dp1[n - 1][m - 1];
  if (*memo != -1)
    return *memo;
  if (m == 1)
    return *memo = 1;
  /* Choose the first part, recurse on the remaining m-1 parts. */
  acc = 0;
  for (first = 1; first < n; first++)
    acc = (acc + solve1(n - first, m - 1)) % MOD;
  return *memo = acc;
}
/*
 * Number of ordered ways to write n as a sum of m positive parts p_1..p_m
 * with nim[p_1-1] ^ ... ^ nim[p_m-1] == o, modulo MOD. Returns 0 when n or m
 * is 0. Memoised in dp2[n-1][m-1][o] (-1 = not computed).
 */
long long solve2(int n,int m,int o)
{
  long long *memo;
  long long acc;
  int first;

  if (n == 0 || m == 0)
    return 0;
  memo = &dp2[n - 1][m - 1][o];
  if (*memo != -1)
    return *memo;
  if (m == 1)
    return *memo = (nim[n - 1] == o) ? 1 : 0;
  /* Fix the first part; the rest must XOR to o ^ nim[first-1]. */
  acc = 0;
  for (first = 1; first < n; first++)
    acc = (acc + solve2(n - first, m - 1, o ^ nim[first - 1])) % MOD;
  return *memo = acc;
}