Radix sort
Ed’s discrimination package seemed very interesting to me.
I was vaguely aware of sorting algorithms not based on comparison, but I didn’t realize you could achieve such impressive asymptotics with them.
Radix sort seemed quite simple so I wanted to see how well it would perform.
There’s no point in explaining the algorithm here because I doubt I would do a better job than other people on the internet. Here’s a simple explanation with buckets being decimal digits.
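Still, just for flavour, here is a minimal (and deliberately naive) least-significant-digit radix sort with decimal digits as buckets. The names `pass` and `radixSortDecimal` are my own for illustration, and it only handles non-negative numbers:

```haskell
import Data.List (foldl')

-- One pass: distribute the numbers into ten buckets by the digit at
-- position d (0 = least significant), then read the buckets back in
-- order 0..9. The list comprehension keeps insertion order, which is
-- what makes the sort stable.
pass :: [Int] -> Int -> [Int]
pass xs d = concat [ [ x | x <- xs, digit x == b ] | b <- [0 .. 9] ]
  where digit x = (x `div` 10 ^ d) `mod` 10

-- Sort non-negative Ints by running one pass per decimal digit.
radixSortDecimal :: [Int] -> [Int]
radixSortDecimal xs = foldl' pass xs [0 .. digits - 1]
  where digits = length (show (maximum (1 : xs)))
```

The real versions below follow the same shape, only with 256 binary buckets instead of 10 decimal ones.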
I went through four iterations of the algorithm, and I’ll present them in order.
The common idea
All the approaches share the same mold.
- A sort function, [Int] -> [Int] — that’s the interface. It handles initialization and extraction of the final result.
- A gather function that takes the output of the previous step of the algorithm and produces a list-like structure that can be passed to the algorithm again.
- A toList function that takes the final state of the algorithm and produces an actual list.
- Finally, a radix function that does a single step of the sort.
Different approaches use different structures for buckets. The bucket structure consists of two parts, the “bucket holder” and the buckets themselves. Every approach seemed to perform best with 256 buckets.
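With 256 buckets, the “digit” of a number at a given step is just one of its bytes, extracted with a shift and a mask — the same expression all the implementations below use. `byteAt` is a name I made up for illustration:

```haskell
import Data.Bits (shiftR, (.&.))

bucketBits, numBuckets :: Int
bucketBits = 8
numBuckets = 2 ^ bucketBits -- 256 buckets

-- Byte number `offset` (0 = least significant) of x:
-- shift the wanted byte down to the bottom, then mask off the rest.
byteAt :: Int -> Int -> Int
byteAt offset x = shiftR x (offset * bucketBits) .&. (numBuckets - 1)
```

With 64-bit Ints and 8 bits per bucket, the sort needs 64 / 8 = 8 passes.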
The list bucket problem
The algorithm requires that we’re able to insert numbers into buckets and then retrieve them from oldest to newest. This is basically a queue, with one important difference: we first do all the inserting, then all the iterating. This means a simple list is good enough for insertions; we just need to reverse each bucket at the end.
This reverse is what I wanted to optimize away.
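In isolation the trick looks like this: build each bucket with O(1) cons, which leaves it newest-first, then restore insertion order with one O(n) reverse at the end (the helper names here are hypothetical):

```haskell
import Data.List (foldl')

-- Insert each element with (:) — O(1) per element,
-- but the bucket ends up newest-first.
fillBucket :: [a] -> [a]
fillBucket = foldl' (flip (:)) []

-- One O(n) reverse restores oldest-to-newest (FIFO) order.
drainBucket :: [a] -> [a]
drainBucket = reverse
```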
Plain list solution
The first iteration just reversed the list. I used mutable vectors in the ST monad as my “bucket holders”.
Here’s the code.
module Main where

import qualified Data.List
import Control.Monad.ST
import Data.Vector.Mutable (STVector)
import qualified Data.Vector.Mutable as Vec
import Data.Bits
import System.TimeIt
import System.Random

totalBits, numBuckets, bucketBits :: Int
totalBits = 64
numBuckets = 2 ^ bucketBits
bucketBits = 8

bucketList :: [Int]
bucketList = reverse [0 .. numBuckets - 1]

sort :: [Int] -> [Int]
sort original = runST $ do
  initial <- Vec.new numBuckets
  reset initial
  Vec.write initial 0 original -- put the list in bucket 0 so we can
                               -- express the algorithm as repeated iteration
  mapM_ (radix initial) [0 .. (totalBits `div` bucketBits) - 1]
  toList initial

gather :: STVector s [a] -> ST s [a]
gather vec = fmap (reverse . concat) $! mapM (Vec.read vec) bucketList

toList :: STVector s [a] -> ST s [a]
toList = gather

reset :: STVector s [a] -> ST s ()
reset vec = Vec.set vec []

radix :: STVector s [Int] -> Int -> ST s ()
radix vec offset = do
  ll <- gather vec
  reset vec
  mapM_ ins ll
  where
    ins x = do
      let bts = shiftR x (offset * bucketBits) .&. (numBuckets - 1)
      l <- Vec.read vec bts
      Vec.write vec bts (x : l)

main :: IO ()
main = do
  gen <- newStdGen
  let list = take 1000000 $ randoms gen
  print $ sum list
  print "Radix"
  timeIt $ print $ sum $ sort list
  print "Standard"
  timeIt $ print $ sum $ Data.List.sort list
This first iteration already showed great results, beating the default sorting algorithm. I’ll discuss that later. I used sum when testing to force the lists. Here are the times from one measurement (compiled with -O2):
-6508836477411096561
"Radix"
-6508836477411096561
CPU time: 2.61s
"Standard"
-6508836477411096561
CPU time: 3.72s
Now to tackle the actual reversal problem. I thought that since I’m already in the ST monad, why not try implementing my own specialized linked list? (Before this I tried just swapping the list for a Seq; it performed worse.)
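For reference, the Seq experiment mentioned in the parenthesis would replace cons-and-reverse with Data.Sequence’s O(1) right-append. A minimal sketch of that bucket shape (helper names are mine):

```haskell
import Data.Sequence (Seq, (|>))
import qualified Data.Sequence as Seq
import Data.Foldable (toList)

-- Seq appends on the right in O(1), so no final reverse is needed;
-- per the measurements above, its constant factors still lost to
-- plain lists in this workload.
fillBucket :: [a] -> Seq a
fillBucket = foldl (|>) Seq.empty

drainBucket :: Seq a -> [a]
drainBucket = toList
```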
Custom linked lists
My implementation is nothing to write home about. It’s just a node with an STRef that points to the next element.
{-# LANGUAGE ViewPatterns #-}
module LinkedListSpecial where

import Prelude hiding (mapM_)
import Control.Monad.ST
import Data.STRef
import Data.Foldable (mapM_, foldlM)

-- A node is either a stub (a head sentinel with no value)
-- or a value plus a reference to the possible next node.
data LLN s a
  = Stub (STRef s (Maybe (LLN s a)))
  | LLN a (STRef s (Maybe (LLN s a)))

getRef :: LLN s a -> STRef s (Maybe (LLN s a))
getRef (Stub ref)  = ref
getRef (LLN _ ref) = ref

emptyNode :: ST s (LLN s a)
emptyNode = fmap Stub (newSTRef Nothing)

makeNode :: a -> ST s (LLN s a)
makeNode x = fmap (LLN x) $! newSTRef Nothing

-- Append after the given node and return the new last node.
append :: LLN s a -> a -> ST s (LLN s a)
append (getRef -> ref) x = do
  new <- makeNode x
  writeSTRef ref (Just new)
  return new

iter :: (a -> ST s ()) -> LLN s a -> ST s ()
iter f (Stub ref) = do
  next <- readSTRef ref
  mapM_ (iter f) next
iter f (LLN x ref) = do
  f x
  next <- readSTRef ref
  mapM_ (iter f) next

iterAll :: (a -> ST s ()) -> [LLN s a] -> ST s ()
iterAll f = mapM_ (iter f)

-- Build a list, returning its first and last nodes.
fromList :: [a] -> ST s (LLN s a, LLN s a)
fromList xs = do
  f <- emptyNode
  l <- foldlM append f xs
  return (f, l)

collect :: LLN s a -> ST s [a]
collect (Stub ref) = do
  next <- readSTRef ref
  case next of
    Nothing -> return []
    Just n  -> collect n
collect (LLN x ref) = do
  next <- readSTRef ref
  case next of
    Nothing -> return [x]
    Just n  -> do xs <- collect n
                  return $! x : xs

collectAll :: [LLN s a] -> ST s [a]
collectAll = fmap concat . mapM collect
I have to say, the preliminary tests didn’t show great results. Generating a linked list and iterating through it consistently performed worse than building a normal list, reversing it, and then iterating.
Apparently this is because GHC’s runtime is tuned with the expectation that boxed references are mutated rarely (ideally once), and when you invalidate that expectation you pay for it in garbage-collector bookkeeping. Or something along those lines. I think…
In any case, here’s the second iteration.
module Main where

import qualified Data.List
import Control.Monad
import Control.Monad.ST
import Data.Vector.Mutable (STVector)
import qualified Data.Vector.Mutable as Vec
import Data.Bits
import System.TimeIt
import System.Random

import LinkedListSpecial (LLN)
import qualified LinkedListSpecial as LL

totalBits, numBuckets, bucketBits :: Int
totalBits = 64
numBuckets = 2 ^ bucketBits
bucketBits = 8

bucketList :: [Int]
bucketList = [0 .. numBuckets - 1]

sort :: [Int] -> [Int]
sort original = runST $ do
  initial <- Vec.new numBuckets
  reset initial
  LL.fromList original >>= Vec.write initial 0 -- put the list in bucket 0 so we can
                                               -- express the algorithm as repeated iteration
  mapM_ (radix initial) [0 .. (totalBits `div` bucketBits) - 1]
  toList initial

gather :: STVector s (LLN s a, LLN s a) -> ST s [LLN s a]
gather vec = mapM (fmap fst . Vec.read vec) bucketList

toList :: STVector s (LLN s a, LLN s a) -> ST s [a]
toList vec = do
  lls <- gather vec
  LL.collectAll lls

reset :: STVector s (LLN s a, LLN s a) -> ST s ()
reset vec = forM_ [0 .. Vec.length vec - 1] $ \i -> do
  node <- LL.emptyNode
  Vec.write vec i (node, node)

radix :: STVector s (LLN s Int, LLN s Int) -> Int -> ST s ()
radix vec offset = do
  ll <- gather vec
  reset vec
  LL.iterAll ins ll
  where
    ins x = do
      let bts = shiftR x (offset * bucketBits) .&. (numBuckets - 1)
      (f, l) <- Vec.read vec bts
      newL <- LL.append l x
      Vec.write vec bts (f, newL)
And the results:
-4051746686150878325
"Radix"
-4051746686150878325
CPU time: 5.78s
"Standard"
-4051746686150878325
CPU time: 3.88s
sadface
LUCKILY, I was informed that there exists a version of a list that’s optimized for appending: the difference list. The concept is simple. You suspend the modifications as functions, compose them together in whatever order you want, and then run them all at once when you finally want to produce a list. I used a package that provides them.
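To see why that helps, here’s a hand-rolled version of the idea (the real package is more polished; the names here are mine). The list is represented as a function that prepends its contents to whatever comes next, so snoc becomes O(1) function composition and no reverse is ever needed:

```haskell
-- A difference list: a function that prepends the represented
-- list to its argument.
newtype DL a = DL ([a] -> [a])

emptyDL :: DL a
emptyDL = DL id

-- Append one element on the right in O(1): compose with (x :).
snocDL :: DL a -> a -> DL a
snocDL (DL f) x = DL (f . (x :))

-- Concatenate two difference lists in O(1).
appendDL :: DL a -> DL a -> DL a
appendDL (DL f) (DL g) = DL (f . g)

-- Materialize the list once, at the end, by applying to [].
toListDL :: DL a -> [a]
toListDL (DL f) = f []
```

All the suspended conses run in the right order only when toListDL finally applies the composed function to `[]`.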
DLists
Right to the implementation.
module Main where

import Prelude hiding (mapM_)
import qualified Data.List
import Control.Monad hiding (mapM_)
import Control.Applicative
import Control.Monad.ST
import Data.Vector.Mutable (STVector)
import qualified Data.Vector.Mutable as Vec
import Data.Foldable (mapM_)
import Data.Bits
import System.TimeIt
import System.Random

import Data.DList (DList)
import qualified Data.DList as DList

totalBits, numBuckets, bucketBits :: Int
totalBits = 64
numBuckets = 2 ^ bucketBits
bucketBits = 8

bucketList :: [Int]
bucketList = [0 .. numBuckets - 1]

sort :: [Int] -> [Int]
sort original = runST $ do
  initial <- Vec.new numBuckets
  reset initial
  Vec.unsafeWrite initial 0 (DList.fromList original) -- put the list in bucket 0 so we can
                                                      -- express the algorithm as repeated iteration
  mapM_ (radix initial) [0 .. (totalBits `div` bucketBits) - 1]
  toList initial

gather :: STVector s (DList a) -> ST s (DList a)
gather vec = DList.concat <$> mapM (Vec.unsafeRead vec) bucketList

toList :: STVector s (DList a) -> ST s [a]
toList vec = DList.toList <$> gather vec

reset :: STVector s (DList a) -> ST s ()
reset vec = Vec.set vec DList.empty

radix :: STVector s (DList Int) -> Int -> ST s ()
radix vec offset = do
  ll <- gather vec
  reset vec
  mapM_ ins ll
  where
    ins x = do
      let bts = shiftR x (offset * bucketBits) .&. (numBuckets - 1)
      l <- Vec.unsafeRead vec bts
      Vec.unsafeWrite vec bts $ l `DList.snoc` x
And the, pretty amazing, results:
6741553814578555192
"Radix"
6741553814578555192
CPU time: 1.91s
"Standard"
6741553814578555192
CPU time: 3.72s
Twice as fast! Nice!
Finally, I tried ditching the mutable vectors and going fully immutable with IntMaps.
Spoiler alert: it’s another sadface, unfortunately.
Maximum immutability
module Main where

import qualified Data.List
import Data.Foldable (foldl')
import Data.Bits
import System.TimeIt
import System.Random

import Data.DList (DList)
import qualified Data.DList as DList
import Data.IntMap (IntMap)
import qualified Data.IntMap as Map

totalBits, numBuckets, bucketBits :: Int
totalBits = 64
numBuckets = 2 ^ bucketBits
bucketBits = 8

bucketList :: [Int]
bucketList = [0 .. numBuckets - 1]

sort :: [Int] -> [Int]
sort original = toList $ foldl' radix start [0 .. (totalBits `div` bucketBits) - 1]
  where start = Map.insert 0 (DList.fromList original) initial -- put the list in bucket 0 so we can
                                                               -- express the algorithm as repeated iteration

gather :: IntMap (DList a) -> DList a
gather m = DList.concat $! map (m Map.!) bucketList

toList :: IntMap (DList a) -> [a]
toList m = DList.toList $! gather m

initial :: IntMap (DList a)
initial = foldl' (\m i -> Map.insert i DList.empty m) Map.empty bucketList

radix :: IntMap (DList Int) -> Int -> IntMap (DList Int)
radix m offset = foldl' ins initial list
  where
    list = gather m
    ins m' x = Map.adjust (`DList.snoc` x) bts m'
      where bts = shiftR x (offset * bucketBits) .&. (numBuckets - 1)
It’s a definite winner in terms of conciseness, and it really did feel like a breath of fresh air when my functions finally started returning things instead of units. It’s not very performant, though, even if it’s still faster than my linked lists.
1830962452316129604
"Radix"
1830962452316129604
CPU time: 3.97s
"Standard"
1830962452316129604
CPU time: 3.80s
Comparing to Data.List.sort
At first glance it might seem like the default sort function is pretty bad, but there are a couple of very important tradeoffs being made here. Firstly, my sort only works on Ints. Though it could be extended to anything with a Bits instance, that’s still much less general than being able to sort anything that’s in Ord.
Secondly, it has a much larger memory footprint. There’s a lot of allocation happening: sorting a million numbers allocated around 50 megabytes (don’t quote me on the number).
Thirdly, it’s not lazy. Whether you want only the first 10 numbers from the list or all of them, it takes the same amount of time. That isn’t true for Data.List.sort, which is nice and lazy.
All in all I’m still pretty happy that I managed to write something that outperforms the default implementation. I’ll probably come back to this subject later and see if we can implement some other algorithms or improve the above implementations.