Thursday, October 5, 2017

Kata - implementing a functional List data structure in Kotlin

I saw an exercise in chapter 3 of the excellent Functional Programming in Scala book, which deals with defining functional data structures and uses the linked list as an example of how to go about developing such a data structure. I wanted to try this sample in Kotlin to see to what extent I could replicate it.

A Scala skeleton of the sample is available in the book's companion code, and my attempt in Kotlin is heavily inspired by (copied from!) the answer key in that repository.


This is what a basic List representation in Kotlin looks like:

sealed class List<out A> {

    abstract val head: A

    abstract val tail: List<A>
}

// A non-empty list: a head element followed by the rest of the list
data class Cons<out T>(override val head: T, override val tail: List<T>) : List<T>()

// The empty list: it has no head or tail, so accessing either one throws
object Nil : List<Nothing>() {
    override val head: Nothing
        get() {
            throw NoSuchElementException("head of an empty list")
        }

    override val tail: List<Nothing>
        get() {
            throw NoSuchElementException("tail of an empty list")
        }
}

The List has been defined as a sealed class, which means that all of its subclasses must be defined in the same file. Because the compiler then knows every possible subclass, this is useful for pattern matching on the type of an instance, something that comes up in most of the functions below.

There are two implementations of this List:
1. Cons, a non-empty list consisting of a head element and a tail List
2. Nil, an empty List

This is already useful in its current form. The following constructs a List and retrieves its tail:

val l1: List<Int> = Cons(1, Cons(2, Cons(3, Cons(4, Nil))))
assertThat(l1.tail).isEqualTo(Cons(2, Cons(3, Cons(4, Nil))))

// Nil can be used as an empty List of any element type
val l2: List<String> = Nil

Pattern Matching with the "when" Expression

Now to implement some methods on List. Since List is a sealed class, it works well with pattern matching; for example, to get the sum of the elements in a List:

fun sum(l: List<Int>): Int {
    return when (l) {
        is Cons -> l.head + sum(l.tail)
        is Nil -> 0
    }
}

The compiler knows that Cons and Nil are the only two possible cases for a List instance, so the "when" is exhaustive without an "else" branch.
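Exhaustiveness is easiest to see when "when" is used as an expression. A minimal sketch, using a trimmed-down copy of the post's types (head and tail live only on Cons here, an assumption to keep the block self-contained) and a hypothetical top-level `length` function: there is no "else" branch, and adding a third subclass of List would make this fail to compile until the new case is handled.

```kotlin
// Trimmed-down copy of the post's types so this block compiles on its own.
sealed class List<out A>
data class Cons<out A>(val head: A, val tail: List<A>) : List<A>()
object Nil : List<Nothing>()

// Hypothetical example: counts the elements of a List.
// `when` used as an expression must be exhaustive; because List is sealed,
// the two branches below cover every case and no `else` is needed.
fun <A> length(l: List<A>): Int = when (l) {
    is Cons -> 1 + length(l.tail) // l is smart-cast to Cons, so l.tail is available
    is Nil -> 0
}
```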

Slightly more complex are "drop", which drops some number of elements from the beginning of the list, and "dropWhile", which takes a predicate and drops elements from the beginning of the list for as long as they match it:

fun drop(n: Int): List<A> {
    return if (n <= 0) this
    else when (this) {
        is Cons -> tail.drop(n - 1)
        is Nil -> Nil
    }
}

val l = list(4, 3, 2, 1)
assertThat(l.drop(2)).isEqualTo(list(2, 1))
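The drop example, like several later ones, calls a `list(...)` helper that the post doesn't show. Here is a minimal sketch, assuming it is a top-level vararg function (the name comes from the examples; the signature and placement are my guess), over a trimmed-down copy of the post's types:

```kotlin
// Trimmed-down copy of the post's types so this block compiles on its own.
sealed class List<out A>
data class Cons<out A>(val head: A, val tail: List<A>) : List<A>()
object Nil : List<Nothing>()

// Hypothetical helper: builds a List whose elements appear in argument order.
fun <A> list(vararg xs: A): List<A> {
    var result: List<A> = Nil
    // Walk the arguments from last to first, prepending each one,
    // so that the first argument ends up as the head.
    for (i in xs.indices.reversed()) {
        result = Cons(xs[i], result)
    }
    return result
}
```

Building from the right means each element is prepended exactly once and no recursion is involved, so it also works for long argument lists.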

fun dropWhile(p: (A) -> Boolean): List<A> {
    return when (this) {
        is Cons -> if (p(this.head)) this.tail.dropWhile(p) else this
        is Nil -> Nil
    }
}

val l = list(1, 2, 3, 5, 8, 13, 21, 34, 55, 89)
assertThat(l.dropWhile { e -> e < 20 }).isEqualTo(list(21, 34, 55, 89))
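Note that dropWhile stops at the first element that fails the predicate; it is not a filter. A small sketch, written as a top-level function over a trimmed-down copy of the post's types (an assumption to keep the block self-contained), makes that visible: a later element that would match the predicate is kept.

```kotlin
// Trimmed-down copy of the post's types so this block compiles on its own.
sealed class List<out A>
data class Cons<out A>(val head: A, val tail: List<A>) : List<A>()
object Nil : List<Nothing>()

// Same logic as the post's dropWhile, as a top-level function:
// keep dropping while the head matches, return the rest as soon as one doesn't.
fun <A> dropWhile(l: List<A>, p: (A) -> Boolean): List<A> = when (l) {
    is Cons -> if (p(l.head)) dropWhile(l.tail, p) else l
    is Nil -> Nil
}
```

With the list (1, 30, 2) and the predicate "less than 20", only the 1 is dropped; the 2 survives because dropping already stopped at 30.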

These show off the power of pattern matching with the "when" expression in Kotlin.

Unsafe Variance!

Now for a wrinkle. Note that List is declared with a type parameter marked "out A". This is called declaration-site variance, and here it makes List covariant in A. Declaration-site variance is explained very well in the Kotlin documentation. Because List is declared this way, I can do something like this:

val l: List<Int> = Cons(1, Cons(2, Nil))
val lAny: List<Any> = l  // allowed because List is covariant in its type parameter
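A small sketch of what covariance buys: the hypothetical `describe` function below takes a `List<Any>`, yet it accepts a `List<Int>` directly, and `Nil`, a `List<Nothing>`, can stand in for an empty list of any element type. It uses a trimmed-down copy of the post's types so the block compiles on its own.

```kotlin
// Trimmed-down copy of the post's types so this block compiles on its own.
sealed class List<out A>
data class Cons<out A>(val head: A, val tail: List<A>) : List<A>()
object Nil : List<Nothing>()

// Hypothetical function that only needs to read elements as Any.
// Because List is declared `out A`, any List<T> is a subtype of List<Any>.
fun describe(l: List<Any>): String = when (l) {
    is Nil -> "Nil"
    is Cons -> "Cons(${l.head}, ${describe(l.tail)})"
}
```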

Now consider an "append" function, which appends another list to this one:

fun append(l: List<@UnsafeVariance A>): List<A> {
    return when (this) {
        is Cons -> Cons(head, tail.append(l))
        is Nil -> l
    }
}

Here a second list is taken as a parameter of the append function. Without the annotation, Kotlin would flag this parameter: a type parameter declared "out" may appear in return positions but not in parameter positions, because a covariant type is safe to produce but not, in general, to consume. Since this List is immutable, append never stores an element into an existing list, so I can get past the check by annotating the parameter's type argument with "@UnsafeVariance".
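A self-contained sketch of the annotated append, as a member of a trimmed-down copy of List (head and tail live only on Cons here, an assumption to keep the block short). It shows why the unsafe variance is harmless for an immutable list: appending a List<String> to a List<Int> viewed as List<Any> simply produces a List<Any>, and the original list is left untouched.

```kotlin
// Trimmed-down copy of the post's types so this block compiles on its own.
sealed class List<out A> {
    // Without @UnsafeVariance the compiler rejects this parameter:
    // A is declared `out`, but here it appears in an "in" (parameter) position.
    fun append(l: List<@UnsafeVariance A>): List<A> = when (this) {
        is Cons -> Cons(head, tail.append(l)) // copies this list's cells, shares `l`
        is Nil -> l
    }
}
data class Cons<out A>(val head: A, val tail: List<A>) : List<A>()
object Nil : List<Nothing>()
```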


Folding operations allow the list to be "folded" into a single result by combining its individual elements one at a time.

Consider foldLeft:

fun <B> foldLeft(z: B, f: (B, A) -> B): B {
    // Local helper that carries the accumulator; tailrec lets the compiler turn it into a loop
    tailrec fun foldLeft(l: List<A>, z: B, f: (B, A) -> B): B {
        return when (l) {
            is Nil -> z
            is Cons -> foldLeft(l.tail, f(z, l.head), f)
        }
    }

    return foldLeft(this, z, f)
}

If a list consists of the elements (2, 3, 5, 8), then foldLeft is equivalent to "f(f(f(f(z, 2), 3), 5), 8)".
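The equivalence can be checked by passing a function that records its own nesting. The sketch below uses a top-level foldLeft with the same shape as the post's local tailrec helper, over a trimmed-down copy of the post's types; folding (2, 3, 5, 8) with a function that formats "f(acc, e)" reproduces the expression above exactly.

```kotlin
// Trimmed-down copy of the post's types so this block compiles on its own.
sealed class List<out A>
data class Cons<out A>(val head: A, val tail: List<A>) : List<A>()
object Nil : List<Nothing>()

// Same shape as the post's inner tailrec helper: the accumulator is updated
// before the recursive call, which is therefore in tail position.
tailrec fun <A, B> foldLeft(l: List<A>, z: B, f: (B, A) -> B): B = when (l) {
    is Nil -> z
    is Cons -> foldLeft(l.tail, f(z, l.head), f)
}
```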

With this higher-order function in place, the sum function can be expressed this way:

val l = Cons(1, Cons(2, Cons(3, Cons(4, Nil))))
assertThat(l.foldLeft(0) { r, e -> r + e }).isEqualTo(10)

foldRight looks like the following in Kotlin:

fun <B> foldRight(z: B, f: (A, B) -> B): B {
    return when (this) {
        is Cons -> f(this.head, tail.foldRight(z, f))
        is Nil -> z
    }
}

If a list consists of the elements (2, 3, 5, 8), then foldRight is equivalent to "f(2, f(3, f(5, f(8, z))))".
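The same trick shows the right-nested shape of foldRight. This is a sketch with a top-level foldRight of the same shape as the post's member function, over a trimmed-down copy of the post's types:

```kotlin
// Trimmed-down copy of the post's types so this block compiles on its own.
sealed class List<out A>
data class Cons<out A>(val head: A, val tail: List<A>) : List<A>()
object Nil : List<Nothing>()

// Same shape as the post's foldRight: the recursive call must finish before f
// can be applied, so it is not a tail call and each element costs a stack frame.
fun <A, B> foldRight(l: List<A>, z: B, f: (A, B) -> B): B = when (l) {
    is Cons -> f(l.head, foldRight(l.tail, z, f))
    is Nil -> z
}
```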

This version of foldRight, though cooler looking, is not tail recursive, so a long enough list can overflow the stack. A more stack-friendly version can be built on the tail-recursive foldLeft defined earlier: reverse the List, then call foldLeft on the result:

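fun reverse(): List<A> {
    // Prepend each element onto the accumulator, so the last element ends up first
    return foldLeft(Nil as List<A>, { b, a -> Cons(a, b) })
}

fun <B> foldRightViaFoldLeft(z: B, f: (A, B) -> B): B {
    // Walking the reversed list from the left visits the elements right to left
    return reverse().foldLeft(z, { b, a -> f(a, b) })
}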
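To see the difference in practice, here is a sketch with reverse and foldRightViaFoldLeft written as top-level functions over a trimmed-down copy of the post's types, plus a hypothetical `rangeList` helper that builds a long list with a loop. Summing 100,000 elements this way involves no deep recursion, whereas the plain recursive foldRight needs one stack frame per element and may overflow on a list this long, depending on the JVM's stack size.

```kotlin
// Trimmed-down copy of the post's types so this block compiles on its own.
sealed class List<out A>
data class Cons<out A>(val head: A, val tail: List<A>) : List<A>()
object Nil : List<Nothing>()

tailrec fun <A, B> foldLeft(l: List<A>, z: B, f: (B, A) -> B): B = when (l) {
    is Nil -> z
    is Cons -> foldLeft(l.tail, f(z, l.head), f)
}

// Prepends each element onto the accumulator, so the last element ends up first.
fun <A> reverse(l: List<A>): List<A> = foldLeft<A, List<A>>(l, Nil) { acc, a -> Cons(a, acc) }

// foldRight expressed as two tail-recursive foldLeft passes (one inside reverse).
fun <A, B> foldRightViaFoldLeft(l: List<A>, z: B, f: (A, B) -> B): B =
    foldLeft(reverse(l), z) { b, a -> f(a, b) }

// Hypothetical demo helper: builds the list 1..n iteratively, so construction is stack-safe too.
fun rangeList(n: Int): List<Int> {
    var result: List<Int> = Nil
    for (i in n downTo 1) result = Cons(i, result)
    return result
}
```

The result has the same right-nested shape as foldRight, only the evaluation strategy differs.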

map and flatMap

map is a function that transforms each element of this list:

fun <B> map(f: (A) -> B): List<B> {
    return when (this) {
        is Cons -> Cons(f(head), tail.map(f))
        is Nil -> Nil
    }
}

An example of using this function is the following:

val l = Cons(1, Cons(2, Cons(3, Nil)))
val l2 = l.map { e -> e.toString() }
assertThat(l2).isEqualTo(Cons("1", Cons("2", Cons("3", Nil))))
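The recursive map is not tail recursive either, so like foldRight it uses one stack frame per element. One way around that, sketched here as a hypothetical `mapStackSafe` over a trimmed-down copy of the post's types, reuses the reverse-then-foldLeft idea from the foldRight section: a first foldLeft builds the mapped elements in reverse order, and a second one reverses them back.

```kotlin
// Trimmed-down copy of the post's types so this block compiles on its own.
sealed class List<out A>
data class Cons<out A>(val head: A, val tail: List<A>) : List<A>()
object Nil : List<Nothing>()

tailrec fun <A, B> foldLeft(l: List<A>, z: B, f: (B, A) -> B): B = when (l) {
    is Nil -> z
    is Cons -> foldLeft(l.tail, f(z, l.head), f)
}

// Hypothetical alternative to map: two tail-recursive passes instead of deep recursion.
// f is still applied to the elements from first to last.
fun <A, B> mapStackSafe(l: List<A>, f: (A) -> B): List<B> {
    // Pass 1: mapped elements, accumulated in reverse order
    val reversedMapped = foldLeft<A, List<B>>(l, Nil) { acc, a -> Cons(f(a), acc) }
    // Pass 2: reverse back into the original order
    return foldLeft<B, List<B>>(reversedMapped, Nil) { acc, b -> Cons(b, acc) }
}
```

The trade-off is an extra pass and an intermediate list in exchange for constant stack depth.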

flatMap is a variation of map in which the transforming function returns a List for each element, and the results are flattened into a single List. It is best demonstrated with an example after the implementation:

fun <B> flatMap(f: (a: A) -> List<@UnsafeVariance B>): List<B> {
    return flatten(map { a -> f(a) })
}

companion object {
    fun <A> flatten(l: List<List<A>>): List<A> {
        // Fold from the right, appending each inner list in front of the already flattened rest
        return l.foldRight(Nil as List<A>, { a, b -> a.append(b) })
    }
}

val l = Cons(1, Cons(2, Cons(3, Nil)))

val l2 = l.flatMap { e -> list(e.toString(), e.toString()) }

assertThat(l2).isEqualTo(
        Cons("1", Cons("1", Cons("2", Cons("2", Cons("3", Cons("3", Nil)))))))
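A self-contained sketch of flatten and flatMap, written as top-level functions over a trimmed-down copy of the post's types (append, foldRight and map are re-declared here in the same shape as the post's member functions). It also shows that inner empty lists simply disappear from the flattened result.

```kotlin
// Trimmed-down copy of the post's types so this block compiles on its own.
sealed class List<out A>
data class Cons<out A>(val head: A, val tail: List<A>) : List<A>()
object Nil : List<Nothing>()

fun <A> append(l1: List<A>, l2: List<A>): List<A> = when (l1) {
    is Cons -> Cons(l1.head, append(l1.tail, l2))
    is Nil -> l2
}

fun <A, B> foldRight(l: List<A>, z: B, f: (A, B) -> B): B = when (l) {
    is Cons -> f(l.head, foldRight(l.tail, z, f))
    is Nil -> z
}

fun <A, B> map(l: List<A>, f: (A) -> B): List<B> = when (l) {
    is Cons -> Cons(f(l.head), map(l.tail, f))
    is Nil -> Nil
}

// Folds from the right, appending each inner list in front of the already flattened rest.
fun <A> flatten(l: List<List<A>>): List<A> =
    foldRight<List<A>, List<A>>(l, Nil) { a, b -> append(a, b) }

// map first (giving a List of Lists), then flatten.
fun <A, B> flatMap(l: List<A>, f: (A) -> List<B>): List<B> = flatten(map(l, f))
```

Because the function passed to flatMap may return Nil, flatMap can also drop elements, for example keeping only the even numbers by returning Cons(e, Nil) for even e and Nil otherwise.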

This covers the basics of implementing a functional list data structure in Kotlin. There were a few rough edges compared to the Scala version, but I think it mostly works. Admittedly the sample can likely be improved a lot; if you have any suggestions on how to improve the code, please send me a PR at my GitHub repo for this sample or leave a comment on this post.
