Ed’s discrimination package seemed very interesting to me. I was vaguely aware of sorting algorithms not based on comparison but I didn’t realize that you could achieve such impressive asymptotics with them. Radix sort seemed quite simple so I wanted to see how well it would perform.

There’s no point in explaining the algorithm here because I doubt I would do a better job than other people on the internet. Here’s a simple explanation with buckets being decimal digits.

I went through four iterations of the algorithm and I’ll present them in order.

The common idea

All four approaches share the same mold. A sort function, [Int] -> [Int], is the interface; it handles initialization and final result extraction. A gather function takes the output of the previous step of the algorithm and produces a list-like structure that can be passed to the algorithm again. A toList function takes the final state of the algorithm and produces an actual list.

Finally, the radix function that does a single step of the sort.

Different approaches use different structures for buckets. The bucket structure consists of two parts: the “bucket holder” and the buckets themselves. All four approaches seemed to perform best with 256 buckets.
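To make the mold concrete, here is a minimal pure sketch of it, using an immutable array as the bucket holder and plain lists as buckets. It is not one of the versions presented below; the names digit, radixPass and sortNonNegative are made up for illustration, and it assumes a 64-bit Int and non-negative input.

```haskell
import Data.Array (accumArray, elems)
import Data.Bits (shiftR, (.&.))
import Data.List (foldl')

bucketBits, numBuckets, numPasses :: Int
bucketBits = 8
numBuckets = 2 ^ bucketBits        -- 256 buckets
numPasses  = 64 `div` bucketBits   -- 8 passes, assuming a 64-bit Int

-- | Bucket index of x in a given pass (pass 0 looks at the lowest 8 bits).
digit :: Int -> Int -> Int
digit pass x = shiftR x (pass * bucketBits) .&. (numBuckets - 1)

-- | One step: distribute the numbers into buckets, then gather them in order.
radixPass :: [Int] -> Int -> [Int]
radixPass xs pass = concatMap reverse (elems buckets)
  where
    -- consing stores each bucket newest-first, hence the reverse above
    buckets = accumArray (flip (:)) [] (0, numBuckets - 1)
                         [(digit pass x, x) | x <- xs]

-- | Repeated iteration over all passes; only valid for non-negative input.
sortNonNegative :: [Int] -> [Int]
sortNonNegative xs = foldl' radixPass xs [0 .. numPasses - 1]
```

Each pass is stable, which is what lets the least significant byte go first: later passes never disturb the order established by earlier ones among equal digits.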

The list bucket problem

The algorithm requires that we’re able to insert numbers into buckets and then retrieve them from oldest to newest. This is a basic queue with one important difference. We first do all the inserting, then do the iteration. This means that a simple list is good enough for insertions. We just need to reverse at the end.

This reverse is what I wanted to optimize away.
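The insert-then-reverse pattern described above, in isolation. This is only a small sketch; insertAll and drain are hypothetical names, not functions from the listings.

```haskell
import Data.List (foldl')

-- | Insert every element by consing. Each insert is O(1), but the newest
-- element ends up at the front of the list.
insertAll :: [a] -> [a]
insertAll = foldl' (flip (:)) []

-- | Read the bucket back oldest-to-newest with a single O(n) reverse.
drain :: [a] -> [a]
drain = reverse
```

Because all inserts happen before any reads, the one reverse per bucket is the only extra cost compared to a real queue.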

Plain list solution

The first iteration just reversed the list. I used a mutable vector in the ST monad as my “bucket holder”.

Here’s the code.

module Main where

import qualified Data.List
import Control.Monad
import Control.Monad.ST
import Data.Vector.Mutable (STVector)
import qualified Data.Vector.Mutable as Vec
import Data.Foldable (foldlM)
import Data.Bits
import System.TimeIt
import System.Random

totalBits, numBuckets, bucketBits :: Int
totalBits = 64
numBuckets = 2^bucketBits
bucketBits = 8

bucketList :: [Int]
bucketList = reverse [0..numBuckets - 1]

sort :: [Int] -> [Int]
sort original = runST $ do
    initial <- Vec.new numBuckets
    reset initial
    Vec.write initial 0 original --put the list in bucket 0 so we can
                                 --express the algorithm as repeated iteration
    mapM_ (radix initial) [0..(totalBits `div` bucketBits) - 1]
    toList initial


gather :: STVector s [a] -> ST s [a]
gather vec = fmap (reverse . concat) $! mapM (Vec.read vec) bucketList

toList :: STVector s [a] -> ST s [a]
toList = gather

reset :: STVector s [a] -> ST s ()
reset vec = Vec.set vec []


radix :: STVector s [Int] -> Int -> ST s ()
radix vec offset = do
    ll <- gather vec
    reset vec
    mapM_ ins ll
    where ins x = do
              let bts = shiftR x (offset * bucketBits) .&. (numBuckets - 1)
              l <- Vec.read vec bts
              Vec.write vec bts (x : l)

main :: IO ()
main = do
    gen <- newStdGen
    let list = take 1000000 $ randoms gen
    print $ sum list
    print "Radix"
    timeIt $ print $ sum $ sort list
    print "Standard"
    timeIt $ print $ sum $ Data.List.sort list

This first iteration already showed great results, beating the default sorting algorithm. I’ll discuss that later.

I used sum when testing to force the lists. Here are the times from one measurement (compiled with -O2):

-6508836477411096561
"Radix"
-6508836477411096561
CPU time:   2.61s
"Standard"
-6508836477411096561
CPU time:   3.72s
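A caveat about this test: sum forces the list but says nothing about its order. The digit expression in ins masks the top byte as an unsigned value, so with signed Ints the negative numbers should end up after the non-negative ones. The sketch below shows a sortedness check and one possible fix, flipping the sign bit before extracting the digit. Both isSorted and signedDigit are hypothetical names, and the code assumes a 64-bit Int with 8-bit buckets.

```haskell
import Data.Bits (shiftR, xor, (.&.))

-- | True when the list is in non-decreasing order.
isSorted :: Ord a => [a] -> Bool
isSorted xs = and (zipWith (<=) xs (drop 1 xs))

-- | Hypothetical replacement for the digit expression in ins.
-- Xor-ing with minBound flips the sign bit, which maps the signed order
-- of Int onto the unsigned order of its bit pattern.
signedDigit :: Int -> Int -> Int
signedDigit pass x = shiftR (x `xor` minBound) (pass * 8) .&. 255
```

Running isSorted on the output of sort for a list with mixed signs would show whether the concern applies to a given version.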

Now to tackle the actual reversal problem. I thought that since I’m already in the ST monad, why not try implementing my own specialized linked lists? (Before this I tried simply swapping the list for a sequence, but it performed worse.)

Custom linked lists

My implementation is nothing to write home about. It’s just a node with a STRef that points to the next element.

{-# LANGUAGE ViewPatterns #-}
module LinkedListSpecial where

import Prelude hiding (mapM_)
import Control.Monad.ST
import Data.STRef
import Data.Foldable (mapM_, foldlM, forM_, foldl')
import System.TimeIt
import qualified Data.DList as DList

data LLN s a = Stub (STRef s (Maybe (LLN s a)))
             | LLN a (STRef s (Maybe (LLN s a)))

getRef :: LLN s a -> STRef s (Maybe (LLN s a))
getRef (Stub ref)  = ref
getRef (LLN _ ref) = ref

emptyNode :: ST s (LLN s a)
emptyNode = fmap Stub (newSTRef Nothing)

makeNode :: a -> ST s (LLN s a)
makeNode x = fmap (LLN x) $! newSTRef Nothing

append :: LLN s a -> a -> ST s (LLN s a)
append (getRef -> ref) x = do
    new <- makeNode x
    writeSTRef ref (Just new)
    return new

iter :: (a -> ST s ()) -> LLN s a -> ST s ()
iter f (Stub ref) = do
    next <- readSTRef ref
    mapM_ (iter f) next
iter f (LLN x ref) = do
    f x
    next <- readSTRef ref
    mapM_ (iter f) next

iterAll :: (a -> ST s ()) -> [LLN s a] -> ST s ()
iterAll f = mapM_ (iter f)

fromList :: [a] -> ST s (LLN s a, LLN s a)
fromList xs = do
    f <- emptyNode
    l <- foldlM append f xs
    return (f, l)

collect :: LLN s a -> ST s [a]
collect (Stub ref) = do
    next <- readSTRef ref
    case next of
        Nothing -> return []
        Just n  -> collect n
collect (LLN x ref) = do
    next <- readSTRef ref
    case next of
        Nothing -> return [x]
        Just n  -> do xs <- collect n
                      return $! x : xs

collectAll :: [LLN s a] -> ST s [a]
collectAll = fmap concat . mapM collect

I have to say, the preliminary tests didn’t show great results. Generating a linked list and iterating through it consistently performed worse than making a normal list, reversing, then iterating.

Apparently it’s because GHC’s runtime is tuned for immutable data: its garbage collector expects boxed references to be written only once, and every mutable reference (such as an STRef) that gets updated later needs extra bookkeeping. At least that’s my understanding, I think…

In any case, here’s the second iteration.

module Main where

import qualified Data.List
import Control.Monad
import Control.Monad.ST
import Data.Vector.Mutable (STVector)
import qualified Data.Vector.Mutable as Vec
import Data.Foldable (foldlM)
import Data.Bits
import System.TimeIt
import System.Random
import LinkedListSpecial (LLN)
import qualified LinkedListSpecial as LL

totalBits, numBuckets, bucketBits :: Int
totalBits = 64
numBuckets = 2^bucketBits
bucketBits = 8

bucketList :: [Int]
bucketList = [0..numBuckets - 1]

sort :: [Int] -> [Int]
sort original = runST $ do
    initial <- Vec.new numBuckets
    reset initial
    LL.fromList original >>= Vec.write initial 0 --put the list in bucket 0 so we can
                                                 --express the algorithm as repeated iteration
    mapM_ (radix initial) [0..(totalBits `div` bucketBits) - 1]
    toList initial


gather :: STVector s (LLN s a, LLN s a) -> ST s [LLN s a]
gather vec = mapM (fmap fst . Vec.read vec) bucketList

toList :: STVector s (LLN s a, LLN s a) -> ST s [a]
toList vec = do
    lls <- gather vec
    LL.collectAll lls

reset :: STVector s (LLN s a, LLN s a) -> ST s ()
reset vec = forM_ [0..Vec.length vec - 1] $ \i -> do
    node <- LL.emptyNode
    Vec.write vec i (node, node)


radix :: STVector s (LLN s Int, LLN s Int) -> Int -> ST s ()
radix vec offset = do
    ll <- gather vec
    reset vec
    LL.iterAll ins ll
    where ins x = do
              let bts = shiftR x (offset * bucketBits) .&. (numBuckets - 1)
              (f, l) <- Vec.read vec bts
              newL <- LL.append l x
              Vec.write vec bts (f, newL)

And the results

-4051746686150878325
"Radix"
-4051746686150878325
CPU time:   5.78s
"Standard"
-4051746686150878325
CPU time:   3.88s

sadface

LUCKILY, I was informed that there exists a version of a list that’s optimized for appending: a difference list. The concept is simple. Instead of building the list directly, you represent each modification as a function, compose those functions in any order you want, and only run them when you finally need an actual list. I used a package that provides them.
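A minimal hand-written difference list shows the idea. This is a sketch for illustration only, not the implementation from the package; the names DL, emptyDL, snocDL and toListDL are made up here.

```haskell
-- | A list represented as a function that prepends its elements
-- to whatever list it is given.
newtype DL a = DL ([a] -> [a])

-- | The empty difference list prepends nothing.
emptyDL :: DL a
emptyDL = DL id

-- | Append one element at the end in O(1): compose one more function.
snocDL :: DL a -> a -> DL a
snocDL (DL f) x = DL (f . (x :))

-- | Run the composed functions once to get an ordinary list.
toListDL :: DL a -> [a]
toListDL (DL f) = f []
```

Since the elements come out in insertion order, no reverse is needed when a bucket is read back.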

DLists

Right to the implementation.

module Main where

import Prelude hiding (mapM_)
import qualified Data.List
import Control.Monad hiding (mapM_)
import Control.Applicative
import Control.Monad.ST
import Data.Vector.Mutable (STVector)
import qualified Data.Vector.Mutable as Vec
import Data.Foldable (foldlM, mapM_)
import Data.Bits
import System.TimeIt
import System.Random
import Data.DList (DList)
import qualified Data.DList as DList

totalBits, numBuckets, bucketBits :: Int
totalBits = 64
numBuckets = 2^bucketBits
bucketBits = 8

bucketList :: [Int]
bucketList = [0..numBuckets - 1]

sort :: [Int] -> [Int]
sort original = runST $ do
    initial <- Vec.new numBuckets
    reset initial
    Vec.unsafeWrite initial 0 (DList.fromList original) --put the list in bucket 0 so we can
                                                        --express the algorithm as repeated iteration
    mapM_ (radix initial) [0..(totalBits `div` bucketBits) - 1]
    toList initial


gather :: STVector s (DList a) -> ST s (DList a)
gather vec = DList.concat <$> mapM (Vec.unsafeRead vec) bucketList

toList :: STVector s (DList a) -> ST s [a]
toList vec = DList.toList <$> gather vec

reset :: STVector s (DList a) -> ST s ()
reset vec = Vec.set vec DList.empty


radix :: STVector s (DList Int) -> Int -> ST s ()
radix vec offset = do
    ll <- gather vec
    reset vec
    mapM_ ins ll
    where ins x = do
              let bts = shiftR x (offset * bucketBits) .&. (numBuckets - 1)
              l <- Vec.unsafeRead vec bts
              Vec.unsafeWrite vec bts $ l `DList.snoc` x

And the pretty amazing results:

6741553814578555192
"Radix"
6741553814578555192
CPU time:   1.91s
"Standard"
6741553814578555192
CPU time:   3.72s

Twice as fast! Nice!

Finally, I tried ditching the mutable vectors and going fully immutable with IntMaps. Spoiler alert: it’s another sadface, unfortunately.

Maximum immutability

module Main where

import qualified Data.List
import Prelude hiding (mapM_)
import Data.Foldable (foldl')
import Data.Bits
import System.TimeIt
import System.Random
import Data.DList (DList)
import qualified Data.DList as DList
import Data.IntMap (IntMap)
import qualified Data.IntMap as Map

totalBits, numBuckets, bucketBits :: Int
totalBits = 64
numBuckets = 2^bucketBits
bucketBits = 8

bucketList :: [Int]
bucketList = [0..numBuckets - 1]

sort :: [Int] -> [Int]
sort original = toList $ foldl' radix start [0..(totalBits `div` bucketBits) - 1]
    where start = Map.insert 0 (DList.fromList original) initial --put the list in bucket 0 so we can
                                                                 --express the algorithm as repeated iteration

gather :: IntMap (DList a) -> DList a
gather m = DList.concat $! map (m Map.!) bucketList

toList :: IntMap (DList a) -> [a]
toList m = DList.toList $! gather m

initial :: IntMap (DList a)
initial = foldl' (\m i -> Map.insert i DList.empty m) Map.empty bucketList

radix :: IntMap (DList Int) -> Int -> IntMap (DList Int)
radix m offset = foldl' ins initial list
    where list = gather m
          ins m' x = Map.adjust (`DList.snoc` x) bts m'
              where bts = shiftR x (offset * bucketBits) .&. (numBuckets - 1)

It’s a definite winner in terms of conciseness, and it really did feel like a breath of fresh air when my functions finally started returning things instead of units. It’s not very performant, though: roughly on par with Data.List.sort in this measurement, but still faster than my linked lists.

1830962452316129604
"Radix"
1830962452316129604
CPU time:   3.97s
"Standard"
1830962452316129604
CPU time:   3.80s

Comparing to Data.List.sort

At first glance it might seem like the default sort function is pretty bad, but there are a couple of very important tradeoffs being made here. Firstly, my sorting only works for Ints. Though it could be extended to work for anything with a Bits instance, it’s still much less general than being able to sort anything that’s in Ord.
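As a rough idea of what such an extension could look like, here is a sketch that sorts any FiniteBits type one bit at a time, using two buckets per pass. It is a simplification of the byte-wise versions above, not a drop-in replacement; radixSortBits is a hypothetical name, and as written it is only correct for unsigned types such as Word8.

```haskell
import Data.Bits (FiniteBits, finiteBitSize, testBit)
import Data.List (foldl', partition)
import Data.Word (Word8)

-- | LSD radix sort with two buckets per pass (one bit at a time).
-- Only correct for unsigned types as written.
radixSortBits :: FiniteBits a => [a] -> [a]
radixSortBits [] = []
radixSortBits xs@(x : _) = foldl' pass xs [0 .. finiteBitSize x - 1]
  where
    -- partition keeps the relative order inside each half (it is stable),
    -- which LSD radix sort relies on
    pass ys i =
      let (ones, zeros) = partition (`testBit` i) ys
      in zeros ++ ones

-- | Small usage example on bytes.
example :: [Word8]
example = radixSortBits [200, 3, 17, 255, 0]
```

The number of passes comes from finiteBitSize, so wider types simply take more passes.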

Secondly, it has a much larger memory footprint. There’s a lot of allocation happening. Sorting a million numbers allocated around 50 megabytes (don’t quote me on the number).

Thirdly, it’s not lazy. Whether you want only the first 10 numbers from the list or all of them, it takes the same amount of time. This isn’t true for Data.List.sort, which is nice and lazy.

All in all I’m still pretty happy that I managed to write something that outperforms the default implementation. I’ll probably come back to this subject later and see if we can implement some other algorithms or improve the above implementations.