
Probability Monads from scratch in 100 lines of Haskell

The final code is on Github.

I recently spent some time trying to learn Haskell when I stumbled across the concept of probabilistic functional programming, originally based on this PFP library. Thanks to Haskell's syntax sugar for monad computations, expressing probability distributions as monads lets you implement something like a mini probabilistic programming language in less than 100 lines of code. I thought that was pretty neat, so I decided to write this post, which walks through a simple implementation from scratch.

The result is not comparable to a real probabilistic programming language in terms of performance or expressiveness, but I think it's a great learning tool for building an intuitive understanding of how probability distributions can be manipulated. If you are not super familiar with Haskell or monads, you should still be able to get the general idea, though some of the terse syntax may look a bit confusing.

Defining Distributions #

Let’s start by defining a probability distribution. In this post, we’ll only work with discrete distributions. People have extended the same approach to work with continuous distributions, but that’s beyond the scope of this post. As our representation, we use a simple list of (value, probability) tuples, where value could be anything (a generic). For example, a fair coin flip could be represented as [(Heads, 0.5), (Tails, 0.5)] or [(1, 0.5), (0, 0.5)].

-- | We use a Double as the representation for probabilities
type Prob = Double

-- | A distribution is represented as a list of possible values
-- and their probabilities
newtype Dist a = Dist [(a, Prob)]

-- | Helper function to access the inner list wrapped by the distribution
unpackDist :: Dist a -> [(a, Prob)]
unpackDist (Dist xs) = xs

A valid probability distribution must adhere to two laws:

  • All probabilities must be greater than or equal to 0
  • The sum of all probabilities must be 1

We are not checking these laws when we create a Distribution. As long as our operations preserve these laws, we only need to check them when initially creating distributions, which usually happens through helper functions that we'll define later.
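
If we wanted to make these laws explicit, a check could look roughly like the sketch below. This is a hypothetical helper (validD is not part of the final code); the small tolerance accounts for floating point error.

-- | Hypothetical sanity check: probabilities are non-negative and sum to 1
validD :: Dist a -> Bool
validD (Dist xs) = all ((>= 0) . snd) xs && abs (sum (map snd xs) - 1.0) < 1e-9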

Note that nothing prevents our Distribution from having duplicate values. For example, both of

  • [(Heads, 0.5), (Tails, 0.5)]
  • [(Heads, 0.3), (Heads, 0.2), (Tails, 0.5)]

are valid representations of the same distribution. There’s nothing inherently wrong with this in terms of correctness, but more elements mean extra computational cost. And when we print our distribution, we don’t want to see duplicates. So let’s define a helper function that combines equal values by summing up their probabilities. We do this by converting the values to a Map, combining them via (+) and then converting back to a list. Probably not the most efficient way, but we don’t care about performance for our little experiment.

import qualified Data.Map as M

-- | Combines outcomes that occur multiple times
squishD :: (Ord a) => Dist a -> Dist a
squishD (Dist xs) = Dist $ M.toList $ M.fromListWith (+) xs

We also need a helper function to normalize probabilities to 1. That will come in handy when creating distributions later. As a convention, we'll suffix functions that act on distributions with D and functions that act directly on a list of values and probabilities with P. Because a Distribution just wraps a list, the two are essentially the same, but sometimes we need to manipulate the inner list directly.

-- | Sum all probabilities in the given list
sumP :: [(a, Prob)] -> Prob
sumP = sum . map snd

-- | Normalize the probabilities to 1.0
normP :: [(a, Prob)] -> [(a, Prob)]
normP xs = [(x, p / q) | let q = sumP xs, (x, p) <- xs]

Before we start working with distributions, we need a way to print them. The following code implements a Show instance for Distribution, which allows us to convert it to a string. It looks a bit complicated, but most of the code calculates the padding so that we get nicely aligned output. Note that we call our squishD function here to make sure we don't see duplicate values.

import Text.Printf (printf)

instance (Show a, Ord a) => Show (Dist a) where
  show d = concatMap showRow $ (unpackDist . squishD) d
    where
      showRow (elem, prob) = padded elem ++ " | " ++ printf "%.4f" prob ++ "\n"
      padded elem = replicate (maxElemLen - (length . show) elem) ' ' ++ show elem
      maxElemLen = maximum $ map (length . show . fst) (unpackDist d)

Finally, let’s define an Event as a function that maps an outcome to a boolean value. For example, the event that a roll of a fair 6-sided die is even has a probability of 50%.

-- | An event maps an outcome to a truth value
type Event a = a -> Bool

To compute the probability of an event, we sum the probabilities for the values where the event function evaluates to true.

-- | Evaluate the probability for the given event
evalD :: Event a -> Dist a -> Prob
evalD p = sumP . filter (p . fst) . unpackDist

Let’s look at some examples of what we can do with what we’ve implemented so far. A few helper functions for creating distributions:

-- | Create a uniform distribution (the (,1.0) tuple section requires the TupleSections extension)
uniform :: [a] -> Dist a
uniform xs = Dist . normP $ map (,1.0) xs

-- | A fair n-sided die
die :: Int -> Dist Int
die n = uniform [1 .. n]

-- | A coin that lands on x with probability f and y with probability 1-f
coin :: Prob -> a -> a -> Dist a
coin f x y
  | f < 0.0 || f > 1.0 = error "f must be between 0 and 1"
  | otherwise = Dist [(x, f), (y, 1 - f)]

-- A 6-sided die
λ> die 6
1 | 0.1667
2 | 0.1667
3 | 0.1667
4 | 0.1667
5 | 0.1667
6 | 0.1667

-- A coin with probability 0.3 of coming up heads
λ> coin 0.3 True False
False | 0.7000
 True | 0.3000
 
-- Probability of rolling an even number with a 5-sided die
λ> evalD even $ die 5
0.4

Functors and Marginalizing #

With the basic definitions in place, let’s move on to something more interesting. Before we make Distribution a Monad, we need to make it a Functor and Applicative. You’ve probably seen the idea behind functors in other programming languages. A functor needs to implement the fmap function, which can also be used through the <$> operator. If you’ve ever used map in other languages, that’s the same thing. In Haskell, map is specific to lists, while fmap is the generalized version that can be applied to all functors. For example, List has an instance of functor that allows us to apply a function to each element. You can think of a functor as a kind of container, and fmap allows us to lift a function over the structure and apply it to values inside the container. For our distribution this means mapping existing values to new values without touching their probabilities.

-- We apply the given function to each value in the distribution
instance Functor Dist where
  fmap f (Dist xs) = Dist $ [(f x, p) | (x, p) <- xs]

That doesn’t seem terribly useful at first, but by combining it with our squishD function, we’ve just implemented Marginalization, or summing over variables we don’t care about. For example, if we have a distribution of two variables (a, b), we can fmap the (a,b) tuples to a and we’ll end up with the marginal distribution:

sample =
  Dist
    [ ((0, 0), 0.1),
      ((0, 1), 0.2),
      ((1, 0), 0.3),
      ((1, 1), 0.4)
    ]


λ> sample
(0,0) | 0.1000
(0,1) | 0.2000
(1,0) | 0.3000
(1,1) | 0.4000

-- Map the tuple to the first value (sum over the second)
λ> fst <$> sample
0 | 0.3000
1 | 0.7000

-- Map the tuple to the second value (sum over the first)
λ>  snd <$> sample
0 | 0.4000
1 | 0.6000

Combining Distributions with Applicative Functors #

A common operation is to combine distributions. Let’s say we perform 2 independent coin flips and we want to know the joint distribution of this experiment. That is, we want to implement the Chain Rule

  • In the general case, P(x,y) = P(x|y)P(y)
  • If x and y are independent, P(x,y) = P(x)P(y)

fmap doesn’t really help us here, because it only lifts a function over a single structure. What we want is a common use case for Applicative, which is also a requirement for being a Monad. Lying somewhere between functors and monads, applicative functors can be difficult to grasp at first. It took a while, and looking at some examples, before it clicked for me. I recommend checking out the Haskell Typeclassopedia entry and the original paper for a deeper dive. Because it’s a bit out of scope, I will only give a brief explanation here that doesn’t quite do it justice.

The main thing that Applicative gives us is the <*> operator (pronounced app). While Functor allowed us to apply a plain function to a distribution of values, an applicative functor allows us to apply a distribution of functions to a distribution of values.

How is that related to combining distributions? It turns out that this is a special case of what we can do with Applicative. To understand this, it’s best to look at an example. But first we need to implement an instance of Applicative for Distribution, just like we did with Functor:

instance Applicative Dist where
  -- pure :: a -> Dist a
  pure x = Dist [(x, 1.0)]

  -- (<*>) :: Dist (a -> b) -> Dist a -> Dist b
  (Dist fs) <*> (Dist xs) = Dist $ do
    (x, px) <- xs
    (f, pf) <- fs
    return (f x, px * pf)

To implement <*> we create a nested loop over both distributions, the one with the functions and the one with the values. For each combination we multiply the probabilities and apply the function to the value. We also need to implement the pure function, which converts any value into a Distribution by assigning it a probability of 1. Here’s how we can use this:

-- Joint distribution of rolling a 5-sided die and a 4-sided die
λ> (,) <$> (die 5) <*> (die 4)
(1,1) | 0.0500
(1,2) | 0.0500
(1,3) | 0.0500
(1,4) | 0.0500
(2,1) | 0.0500
(2,2) | 0.0500
(2,3) | 0.0500
(2,4) | 0.0500
(3,1) | 0.0500
(3,2) | 0.0500
(3,3) | 0.0500
(3,4) | 0.0500
(4,1) | 0.0500
(4,2) | 0.0500
(4,3) | 0.0500
(4,4) | 0.0500
(5,1) | 0.0500
(5,2) | 0.0500
(5,3) | 0.0500
(5,4) | 0.0500

To understand (,) <$> (die 5) <*> (die 4), let’s look at it piece by piece:

  • (,) is the tuple constructor. It takes two arguments. For example, (,) 1 2 evaluates to (1, 2)
  • (,) <$> (die 5) fmaps the tuple constructor over the (die 5) Distribution. Because the constructor takes two arguments, but we only provide one (each outcome in the die 5 distribution), this creates a Distribution of partially applied functions, where each function is waiting for one more argument to create a tuple.
  • ... <*> (die 4) applies the distribution of partially applied functions from above to the (die 4) distribution and yields a distribution of tuples. This is where we use Applicative; see the sketch after this list.
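
To make the intermediate pieces visible, here is a rough sketch with the types spelled out. The names pairWith and joint are illustrative and not part of the post's code.

-- A distribution of partially applied tuple constructors, one per outcome of die 5
pairWith :: Dist (Int -> (Int, Int))
pairWith = (,) <$> die 5

-- Applying that distribution of functions to die 4 yields the joint distribution
joint :: Dist (Int, Int)
joint = pairWith <*> die 4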

This pattern is quite common, and a nicer way to write the above is liftA2 (,) (die 5) (die 4), where liftA2 is a function provided by the standard Control.Applicative module. Another example is the sum of rolling two dice:

λ> liftA2 (+) (die 6) (die 6)
 2 | 0.0278
 3 | 0.0556
 4 | 0.0833
 5 | 0.1111
 6 | 0.1389
 7 | 0.1667
 8 | 0.1389
 9 | 0.1111
10 | 0.0833
11 | 0.0556
12 | 0.0278

We can also implement a Binomial Distribution using this approach. This distribution describes the total number of successes in a series of n independent experiments, e.g. the total number of heads when flipping a coin n times. To do this, we create a list of n coin flip distributions where a success corresponds to 1 and a failure to 0. Then we fold (reduce) over this list using our Applicative instance and the (+) operator, just like we did in the sum of rolling two dice above:

-- | Binomial distribution with n experiments and success probability p
binom :: Int -> Prob -> Dist Int
binom n p = foldl1 (\x y -> squishD (liftA2 (+) x y)) $ replicate n (coin p 1 0)

For efficiency reasons we need the squishD in there. Otherwise the number of values in our final distribution grows exponentially, since each fold step combines every outcome of the accumulated distribution with every outcome of the next coin flip.
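
To make the blow-up concrete, here is a rough sketch. The sizeD helper is hypothetical and not part of the post's code; it just counts the entries in a distribution.

-- Count the number of (value, probability) entries in a distribution
sizeD :: Dist a -> Int
sizeD = length . unpackDist

-- Without squishing, ten coin flips keep 2^10 = 1024 entries:
--   λ> sizeD $ foldl1 (liftA2 (+)) (replicate 10 (coin 0.3 1 0))
--   1024
-- Squishing after every step keeps only the distinct sums 0..10:
--   λ> sizeD $ binom 10 0.3
--   11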

λ> binom 10 0.3
 0 | 0.0282
 1 | 0.1211
 2 | 0.2335
 3 | 0.2668
 4 | 0.2001
 5 | 0.1029
 6 | 0.0368
 7 | 0.0090
 8 | 0.0014
 9 | 0.0001
10 | 0.0000

Of course this is a rather inefficient way to implement a binomial distribution, but it’s simple and intuitive! The code mirrors exactly what a binomial distribution means: summing up the successes of a number of coin flips.

Monadic Distributions, finally! #

With our Distribution being a functor and applicative functor, we can now finally make it a Monad. Monads give us some additional power on top of Applicative and allow us to use Haskell’s do-notation, which we’ll see later. More power you say? Yep. With Applicative there was no way to express dependencies. For example, if we roll two dice, there was no way to make the second die roll a function of the first. That’s the general story with applicative functors and monads for other types as well. Monads allow us to chain computations. In the context of probability distributions this means we can work with dependent distributions, while Applicative only lets us work with independent distributions.

To make our distribution a monad, we need to implement the bind operator >>=. This operator takes a Distribution and a function from values to Distributions, and returns a new Distribution. Another (in my opinion more intuitive) way to think about this operation is as a combination of fmap and join. The canonical example here is a list. If you map a function that returns a list over a list of values, you get a list of lists, and then you flatten (or join) this nested list to get a single list. In some languages, this operation is called flatMap. The bind operator does something very similar, and one can be expressed in terms of the other. The key is that we’re defining how we combine nested distributions so that we get a single “flat” distribution in the end.
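
To make the fmap-plus-join view concrete, here is a rough sketch of what a join for distributions could look like. The joinD helper is illustrative and not part of the final code.

-- Flatten a distribution of distributions by multiplying outer and inner probabilities
joinD :: Dist (Dist a) -> Dist a
joinD (Dist ds) = Dist [(x, p * q) | (d, p) <- ds, (x, q) <- unpackDist d]

-- With this helper, bind could equivalently be written as:
--   d >>= f  =  joinD (fmap f d)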

instance Monad Dist where
  -- (>>=) :: Dist a -> (a -> Dist b) -> Dist b
  (Dist xs) >>= f = Dist $ do
    (x, p) <- xs
    (y, p') <- unpackDist (f x)
    return (y, p * p')

Our instance of Monad looks very similar to what we did with Applicative in that we have a nested loop. But note that we feed the values into the function that returns another Distribution. This is what allows us to create a new distribution that depends on the values of another distribution. With this, we can now express experiments in a generative language. For example, here is an experiment where we roll a die and flip a different coin depending on whether we rolled a 6 or not. Note that we are able to use Haskell’s do-notation because we’ve made our distribution a monad. Under the hood, do-notation is just syntax sugar around the bind operator >>= that we’ve implemented.

conditionalCoin = do
  number <- die 6
  if number == 6
    then coin 0.5 1 0
    else coin 0.1 1 0
    
λ> conditionalCoin
0 | 0.8333
1 | 0.1667
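
Just to show what the do-block desugars to, here is roughly the same experiment written with an explicit bind. The name conditionalCoin' is only illustrative.

conditionalCoin' :: Dist Int
conditionalCoin' =
  die 6 >>= \number ->
    if number == 6
      then coin 0.5 1 0
      else coin 0.1 1 0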

Conditioning #

An operation that we haven’t discussed so far is conditioning on an event. Often, we want to create a new conditional probability distribution by introducing some evidence we’ve observed. This operation is also the idea behind Bayes’ rule. We can think of the conditioning operation as two steps:

  • Remove all outcomes (elements in our distribution) that conflict with the evidence. This gives us a new unnormalized probability space.
  • Renormalize all probabilities in this new space.

Wikipedia has a nice visualization of this idea.

-- Condition a distribution on an event
condD :: (a -> Bool) -> Dist a -> Dist a
condD f (Dist xs) = Dist . normP $ filter (f . fst) xs

We can express the example from Wikipedia as follows:

Suppose that somebody secretly rolls two fair six-sided dice, and we wish to compute the probability that the face-up value of the first one is 2, given the information that their sum is no greater than 5.

λ> evalD ((==2) . fst) $ condD ((<= 5) . uncurry (+)) $ liftA2 (,) (die 6) (die 6)
0.29999999999

This matches a manual calculation: of the ten equally likely outcomes with a sum of at most 5, three have a first value of 2, so the probability is 3/10 (the trailing digits are just floating point noise).

Example - Bayes’ Rule #

No probability tutorial would be complete without the medical exam Bayes rule example! You’ve probably seen a variation of this question in a textbook before. It goes something like this:

Jo has taken a test for a disease. The result of the test is either positive or negative and the test is 95% reliable: in 95% of cases of people who really have the disease, a positive result is returned, and in 95% of cases of people who do not have the disease, a negative result is obtained. 1% of people of Jo’s age and background have the disease. Jo took the test, and the result is positive. What is the probability that Jo has the disease?

With our Monad abstraction that’s easy to solve now. First we create the joint distribution (hasDisease, testPositive) using do-notation, then we condition on a positive test result (the second element in the tuple). Finally, we evaluate the probability of the hasDisease event, which is the first element in the tuple.

bayesMedicalTest = evalD fst . condD snd $ do
  hasDisease <- coin 0.01 True False
  testPositive <-
    if hasDisease
      then coin 0.95 True False
      else coin 0.05 True False
  return (hasDisease, testPositive)

λ> bayesMedicalTest
0.16101694915254236

The probability that Jo has the disease is only ~16% despite the test being 95% reliable. This matches Bayes’ rule applied by hand: (0.01 × 0.95) / (0.01 × 0.95 + 0.99 × 0.05) ≈ 0.161.

You can imagine writing much more complex generative processes using do-notation. It works like a mini-DSL that multiplies probabilities behind the scenes.
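
For instance, here is a small sketch of a two-stage process where the second roll depends on the first. It is illustrative only and not part of the post's final code.

-- Roll a 4-sided die, then roll a die with that many sides, and sum both rolls
twoStage :: Dist Int
twoStage = do
  n <- die 4
  m <- die n
  return (n + m)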

A side note on list monads #

In Haskell, the List monad is used to express nondeterministic computations without associated probabilities. What we have done here can be seen as nothing more than a small extension of the list monad by adding probabilities to it. We could remove the probabilities and model our computations as a list of samples that are drawn from the distribution we want. For example, a coin flip with probability 0.2 of coming up heads could be expanded into [True, False, False, False, False] instead of [(True, 0.2), (False, 0.8)]. You could imagine taking this to the limit to model more complex distributions. We could then use the list of samples to perform computations, using exactly the same monadic code we wrote above, and we would get back a new list of samples that we could use to estimate the joint distribution.
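
As a rough sketch of that idea using the plain list monad (the names here are illustrative, not part of the post's code):

-- Approximate a coin with P(heads) = 0.2 by a list of samples
coinSamples :: [Bool]
coinSamples = [True, False, False, False, False]

-- The list monad enumerates every combination of samples
twoFlips :: [(Bool, Bool)]
twoFlips = do
  a <- coinSamples
  b <- coinSamples
  return (a, b)

-- Counting tuples in twoFlips approximates the joint distribution,
-- e.g. (True, True) appears in 1 of 25 entries, matching 0.2 * 0.2.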

Closing thoughts #

We’ve only scratched the surface of the kind of abstractions we can build by modeling probability distributions as monads. Because Haskell provides us with powerful tools and abstractions that work with any monad, we can use them to build significantly more complex models than we’ve done in this post. For example, we could extend our code to model Bayesian Networks or Markov Decision Processes, or use sampling instead of computing full joint distributions. If you want to learn more, here are a few resources: