My son recently picked up a copy of Rubik’s Race, a kids’ board game that’s essentially a slide puzzle turned into a two-player showdown. You first generate a random target pattern by rolling nine six-coloured dice in a little container, and then each player slides their tiles around to match the pattern. Sometimes the target pattern is invalid – when it has 5 or more of the same colour – but how often exactly?

The fact that you have to re-roll the dice occasionally is hardly a big inconvenience, but it’s a fun exercise to think about the chance of it happening. We tried to guess the rate and came up with a broad range: somewhere between 1% and 20%.

Random sampling

One simple way to estimate the rate is to (virtually) roll the dice a lot of times and count how often a re-roll is needed.

import Control.Monad
import Data.List
import System.Environment
import System.Random
import Text.Printf

data Colour = Red | Green | Blue | Yellow | White | Orange
  deriving (Eq, Ord, Enum, Bounded)

-- Roll a single die.
die :: IO Colour
die = toEnum <$> randomRIO (0,5)

-- Roll all the dice together.
roll :: IO [Colour]
roll = replicateM 9 die

-- "Categorise" a roll: the number of dice showing each colour that
-- appears, in ascending order. The last element is therefore the
-- count for the most common colour.
categorise :: [Colour] -> [Int]
categorise = sort . map length . group . sort

-- Roll the dice N times and count the results.
main :: IO ()
main = do
  n <- read . head <$> getArgs
  observations <- replicateM n (last . categorise <$> roll)
  let groups = (map (\xs -> (head xs, length xs)) . group . sort) observations
  forM_ groups $ \(x,y) -> printf "%d,%d\n" x y

Running this with 10,000 rolls gives these results:

2,1586
3,5518
4,2391
5,453
6,45
7,7

So we can see that 505 times out of 10,000 we would have to re-roll, about 5% of the time. More than half of the time, the most common colour will be showing on exactly 3 dice.
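How much should we trust that 5%? As a rough sketch (assuming the rolls are independent, and using the usual standard-error formula for a sampled proportion; the names `standardError` and `trialsFor` are mine, not part of the program above), we can estimate the noise in the sample:

```haskell
-- Standard error of a proportion p estimated from n independent trials.
standardError :: Double -> Int -> Double
standardError p n = sqrt (p * (1 - p) / fromIntegral n)

-- Roughly how many trials are needed to bring the standard error
-- down to a given target?
trialsFor :: Double -> Double -> Int
trialsFor p target = ceiling (p * (1 - p) / (target * target))
```

With p = 0.0505 and n = 10000, `standardError` comes out at roughly 0.002, so the estimate is only good to within a few tenths of a percentage point. Getting the standard error down to 0.0001 would take `trialsFor 0.0505 0.0001` rolls, which is nearly five million.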

Exhaustive counting

Can we get a more precise answer? We could of course sample more, but in this case there are not that many combinations in total – only 6^9, or about 10 million – so we can simply try them all.

import Control.Monad
import Data.List
import Text.Printf

data Colour = Red | Green | Blue | Yellow | White | Orange
  deriving (Eq, Ord, Enum, Bounded)

-- "Categorise" a roll: the number of dice showing each colour that
-- appears, in ascending order. The last element is therefore the
-- count for the most common colour.
categorise :: [Colour] -> [Int]
categorise = sort . map length . group . sort

-- All possible 9-die rolls.
allRolls :: [[Colour]]
allRolls = replicateM 9 [minBound..maxBound]

main :: IO ()
main = do
  let observations = (last . categorise <$> allRolls)
  let groups = (map (\xs -> (head xs, length xs)) . group . sort) observations
  forM_ groups $ \(x,y) -> printf "%d,%d\n" x y

And we get the exact answers:

2,1587600
3,5628000
4,2320920
5,472500
6,63000
7,5400
8,270
9,6

So exactly 541176/10077696 times we’ll have to re-roll (about 5.37%).
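For the record, the numerator is just the sum of the rows for 5 or more dice in the table above, and the denominator is the total number of rolls:

```latex
\[
  \frac{472500 + 63000 + 5400 + 270 + 6}{6^{9}}
  = \frac{541176}{10077696}
  \approx 0.0537
\]
```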

The only problem remaining with this program is that it’s slow: on my laptop it took about 22 seconds to run. How can we make it faster?

Faster

One way to make it faster is to observe that the colours are interchangeable, so we’re separately counting patterns that look superficially different but are actually the same. (For example, the pattern with 6 reds, 2 greens and 1 blue is spiritually the same as the one with 6 greens, 2 yellows and 1 orange.) We can avoid some of the repetition by forcing the first die to come up, say, Red.

-- All possible 9-die rolls that start with a Red.
allRolls :: [[Colour]]
allRolls = map (Red:) (replicateM 8 [minBound..maxBound])

main :: IO ()
main = do
  let observations = (last . categorise <$> allRolls)
  let groups = (map (\xs -> (head xs, length xs)) . group . sort) observations
  forM_ groups $ \(x,y) -> printf "%d,%d\n" x (y*6)

Then we multiply the resulting counts by 6 to get the full totals. As expected, the running time is now roughly 1/6 of what it was: about 3.5 seconds.

Why is this correct? Think of it this way: we’ve counted all the different patterns that start with a red. But there’s nothing special about red; if we repeated the exercise using blue instead we’d get exactly the same counts. So the counts are the same for every starting colour, and we just have to add the six of them up to get the true counts.
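If you’d rather check the symmetry argument than take it on trust, here is a small sketch that does so on a cut-down game with 5 dice instead of 9, so that it runs instantly. The helper names (`largest`, `tally`, `countsStartingWith`, `countsAll`) are made up for this illustration.

```haskell
import Control.Monad (replicateM)
import Data.List (group, sort)

data Colour = Red | Green | Blue | Yellow | White | Orange
  deriving (Eq, Ord, Enum, Bounded)

-- Size of the largest group of equal colours in a roll.
largest :: [Colour] -> Int
largest = maximum . map length . group . sort

-- Turn a list of values into sorted (value, occurrences) pairs.
tally :: [Int] -> [(Int, Int)]
tally = map (\xs -> (head xs, length xs)) . group . sort

-- Histogram of `largest` over all n-die rolls whose first die shows c.
-- Assumes n >= 1.
countsStartingWith :: Int -> Colour -> [(Int, Int)]
countsStartingWith n c =
  tally [ largest (c : rest) | rest <- replicateM (n - 1) [minBound .. maxBound] ]

-- Histogram of `largest` over all n-die rolls, with no restriction.
countsAll :: Int -> [(Int, Int)]
countsAll n = tally (map largest (replicateM n [minBound .. maxBound]))
```

Evaluating `map (countsStartingWith 5) [minBound .. maxBound]` in GHCi should give six identical histograms, and multiplying the counts in any one of them by 6 should reproduce `countsAll 5`.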

A direct approach

So far we haven’t used any real mathematical knowledge other than counting. (A bit of programming knowledge and spare CPU time can make up for a lot of mathematical ignorance.) If we just want to know how often we’d have to re-roll – without getting the distribution of “number of the most commonly occurring colour” – we can apply some high-school probability theory.

import Data.Ratio

factorial :: Integer -> Integer
factorial n = product [1..n]

choose :: Integer -> Integer -> Integer
n `choose` r = factorial n `div` (factorial r * factorial (n-r))

prob :: Rational -> Integer -> Integer -> Rational
prob p n t = p^n * (1-p)^(t-n) * fromIntegral (t `choose` n)

rerollProb :: Rational
rerollProb = 6 * sum [ prob (1%6) n 9 | n <- [5..9] ]

main :: IO ()
main = print rerollProb

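That is, we compute the probability of getting 5, 6, 7, 8 or 9 “red” faces from a roll of 9 dice, then add those up. We multiply the final result by 6 because the same calculation applies to each of the six colours, and at most one colour can show on 5 or more of the 9 dice, so the six cases never overlap.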

The answer is 22549 % 419904, which is the same as before in reduced form.
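To tie this back to the exhaustive table, the same binomial reasoning can be sketched in terms of counts rather than probabilities. This assumes k is at least 5, so that only one colour can be the most common one; `rollsWithMax` and `rerollCounts` are names invented here, and `choose` is a slightly different formulation of the one above.

```haskell
-- Binomial coefficient "n choose r".
choose :: Integer -> Integer -> Integer
n `choose` r = product [n - r + 1 .. n] `div` product [1 .. r]

-- Number of 9-die rolls in which some colour shows on exactly k dice.
-- Only valid for k >= 5: pick the colour (6 ways), pick which k dice
-- show it, and let each remaining die show any of the other 5 colours.
rollsWithMax :: Integer -> Integer
rollsWithMax k = 6 * (9 `choose` k) * 5 ^ (9 - k)

-- The counts for every k that forces a re-roll.
rerollCounts :: [(Integer, Integer)]
rerollCounts = [ (k, rollsWithMax k) | k <- [5 .. 9] ]
```

`rerollCounts` evaluates to `[(5,472500),(6,63000),(7,5400),(8,270),(9,6)]`, matching the last five rows of the exhaustive output, and the counts sum to 541176.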

Dynamic programming

But what if you really, truly, want the distribution of the number of the most commonly occurring colour, and you’re not willing to wait a few seconds to get it? Another approach is dynamic programming:

import Control.Monad
import Control.Monad.Trans.State
import Data.Function
import Data.List
import Data.Map (Map)
import qualified Data.Map as M
import Data.Ord
import Text.Printf

-- A sequence of colour counts.
-- E.g. if we've rolled [Red,Red,Green,Red,Blue], the pattern is [1,1,3]
--   (1 of one colour, 1 of another colour, and 3 of another colour)
-- Always kept in ascending order.
type Pattern = [Integer]

-- From a given pattern, what other patterns can we reach with a
-- single die roll?
-- Always returns a list of length 6, representing the six possible
-- outcomes of the roll.  (Note that two rolls may lead to the same
-- result.)
extend :: Pattern -> [Pattern]
extend [] = replicate 6 [1]
extend xs =
    -- Either you roll a new colour...
    (replicate (6-length xs) (1:xs))
    ++
    -- or you get another occurrence of an existing colour.
    ([ sort (inc i xs) | i <- [0..length xs-1]])


-- Increment an element in a list.
-- Inefficient use of list as array.
inc :: Int -> [Integer] -> [Integer]
inc i xs = take i xs ++ [(xs!!i) + 1] ++ drop (i+1) xs

-- Execute the dynamic programming algorithm.
-- A breadth-first traversal of the graph defined by "extend", rooted
-- at the empty pattern [].  Stops when it gets to 9 rolls.
-- The result is the Map in the State, which maps each pattern
-- to the number of ways there are to reach it.
loop :: [Pattern] -> State (Map Pattern Integer) ()
loop [] = pure ()
loop (p:ps) = do
  -- How many ways are there to get to p?
  ways <- (M.! p) <$> get
  -- Where do we go from p?
  let extensions = filter ((<=9) . sum) (extend p)
  -- For each out-going edge...
  forM_ extensions $ \e -> do
    -- ... there are now "ways" more ways to get there.
    modify (\m -> M.insertWith (+) e ways m)
  -- Continue the traversal.
  loop (nub (ps ++ extensions))

main :: IO ()
main = do
  let -- Get the mapping.
      m = execState (loop [[]]) (M.fromList [([], 1)])
      -- Keep only the 9-die patterns.
      pairs1 = filter ((==9) . sum . fst) (M.toList m)
      -- Classify them by their most commonly occurring colour (the
      -- last number in the pattern).
      pairs2 = map (\(p,n) -> (last p,n)) pairs1
      -- Group them by their classification.
      groups = groupBy ((==) `on` fst) $ sortBy (comparing fst) $ pairs2
      -- Sum up the counts in each group.
      counts = map (\g -> (fst (head g), sum (map snd g))) groups
  forM_ counts $ \(x,y) -> printf "%d,%d\n" x y

This gives the same result as the exhaustive counting from earlier, but in the more acceptable running time of a couple of milliseconds.
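One way to see why this is so much faster: the traversal never looks at individual rolls, only at patterns, and a finished pattern is just a way of splitting 9 into at most 6 positive parts. Here is a sketch that enumerates those splits directly (`patternsOf` is a hypothetical helper, not part of the program above):

```haskell
-- All ways to write n as an ascending list of at most k positive integers.
patternsOf :: Int -> Int -> [[Int]]
patternsOf n k = go n k 1
  where
    -- go m parts lo: split m into at most `parts` pieces, each >= lo,
    -- listed in ascending order.
    go :: Int -> Int -> Int -> [[Int]]
    go 0 _ _ = [[]]
    go _ 0 _ = []
    go m parts lo =
      [ x : rest | x <- [lo .. m], rest <- go (m - x) (parts - 1) x ]
```

`length (patternsOf 9 6)` is 26, so the final layer of the traversal holds just 26 patterns, compared with the roughly 10 million rolls that the exhaustive version examines.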