From loop invariants to recursion invariants

Until very recently, the major commercial programming languages were based on the idea of update in place. C, C++, Pascal, Java and the like presuppose that the way to solve a programming problem is to have procedures that read and write values to memory locations, whether those locations are directly visible to the programmer (by means of pointers) or not.

Other languages, like Lisp, are based on a different assumption: that a computation is best expressed as a set of functions over immutable values. So, in order to produce any kind of repetitive computation, recursion is essential. This assumption has strong roots in lambda calculus and related mathematical formalisms.

Probably due to the popularity of the procedural paradigm, most programmers find it easier to reason about loops than about recursion. But when it comes to reasoning effectively about a program, the functional model is far superior. This is what I intend to show in this post.

Programming models

The programming model behind procedural code is that of a machine that crunches data for some time and eventually stops. When it stops, we take a look at whatever came up from all this crunching and accept it as the result of the procedure. In some sense, this is the most natural model, since the underlying implementation of the algorithm runs on a real machine that does exactly that.

The problem is that it would be desirable to conduct some kind of formal reasoning about the code at hand, the same kind we use in mathematics. But the machine model relies heavily on mutation of the contents of chunks of memory over time, which is not exactly amenable to mathematical analysis.

A neat solution to this problem is the use of loop invariants (cf. “Introduction to Algorithms”). A loop invariant is a logical predicate that we ascribe to an algorithm: it remains true from the beginning to the end of the execution, across every iteration of every loop, even though those loops may be changing things. Loop invariants have to take time into account (not clock time, but the discrete time defined by the sequence of operations performed by the algorithm) and the values of the variables at each point in time.
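
To make the idea concrete before we get to the main example, here is a minimal sketch of my own (not taken from the book) of an invariant for a trivial summing loop:

def sum(a: Array[Int]): Int = {
  var acc = 0 // initialization: no elements processed yet, so acc == 0
  var i = 0
  // Invariant: at the start of each iteration, acc == a(0) + ... + a(i - 1)
  while (i < a.length) {
    acc += a(i) // maintenance: the invariant now holds for i + 1
    i += 1
  }
  acc // termination: i == a.length, so acc is the sum of the whole array
}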

Binary search – iterative version

Let’s take a simple and common example to illustrate the point: binary search. Given a sorted array and an element, the binary search algorithm returns the index of the element in the array, or -1 if the element is not found. Intuitively, what the algorithm does is: at each iteration, it cuts the search space in half, always looking at the element in the middle. The size of the search space tends to zero as the algorithm runs, so either the element is eventually found or the search space becomes empty and the element can be declared not found. Here is an implementation of the algorithm in Scala:

def search[T <% Ordered[T]](a: Array[T], elem: T): Int = {
  var start = 0
  var end = a.length - 1
  while (start <= end) {
    val mid = (start + end) / 2
    if (elem == a(mid)) return mid          // found it
    else if (elem > a(mid)) start = mid + 1 // discard the lower half
    else end = mid - 1                      // discard the upper half
  }
  -1
}
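
A quick sanity check of the function:

val xs = Array(1, 3, 5, 7, 9)
search(xs, 7) // returns 3, the index of 7
search(xs, 4) // returns -1: 4 is not in the array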

We can prove the intuition behind binary search using loop invariants. The property that remains true throughout its execution is:

The subarray a(start..end) contains elem if, and only if, a contains elem. (P1)

Note that we are abusing Scala notation a bit here when we use a(i..j) to denote the subarray of a from i to j, inclusive on both ends.

Proof

Loop invariants have to be proved in three main sections: initialization (the initial values assigned to the variables satisfy the property), maintenance (if the property holds before an iteration of the loop, it still holds before the next iteration) and termination (when the algorithm stops, what it returns matches the expected result). So, here are the three sections for the iterative version of the binary search algorithm:

Initialization: the initial assignments start = 0 and end = a.length - 1 define an interval that comprises the whole array, so (P1) reduces to “a contains elem if, and only if, a contains elem” and is trivially satisfied.

Maintenance: to prove the maintenance, we have to prove two things: a) that if a contains elem then a(start..end) also contains elem; and b) that if a does not contain elem, neither does a(start..end). It is easy to see that part b) is true, because no matter what values we assign to start or end, they will define a subarray of a, even if that subarray is empty. So if elem is not in a, it cannot be in any subarray of a.

Part a) is slightly more complex. For the loop body to complete an iteration without returning, the equality test elem == a(mid) has to be false (more on this in the next paragraph). In other words, elem is either strictly greater than or strictly less than the element in the middle (recall that the array is sorted). If it is greater, start will be updated to mid + 1, effectively dropping the lower half of the current interval, so that, when the loop condition start <= end is evaluated again, the interval will consist of the upper half. And it is at this point that the loop invariant is maintained: the lower half could not possibly contain elem, so if elem is in the original array, it is also in the interval defined by the new values of start and end, satisfying (P1). For the case where elem is strictly less than a(mid), the proof is analogous.

Termination: there are two possible situations in which the procedure stops: the loop condition start <= end becomes false, or the equality test elem == a(mid) is true. Let’s start with the former. start <= end being false implies an empty search interval. By definition, an empty interval does not contain any element and, in particular, does not contain elem. So the procedure returns -1, which is the correct result in this case. In the latter case, elem == a(mid) means that we have actually found the element we were looking for, and the result is mid.

Even though loop invariant proofs are mathematically precise techniques to show the correctness of an algorithm, they are somewhat cumbersome to work with, for two main reasons:

  1. Time is not readily visible in the code. At one point in time, a set of variables may hold certain values; at the next point, the same variables may hold different ones.
  2. Since variables are mutable, you have to keep track of all the updates those variables might go through, everywhere they are used. And if there is concurrency involved, you also have to keep track of how different threads or processes might interact when updating them, which may be infeasible at scale.

Binary search – recursive version

Now let’s look at a recursive version of the same algorithm:

import scala.annotation.tailrec

def search[T <% Ordered[T]](a: Array[T], elem: T): Int = {
  @tailrec // the recursive call is in tail position, so this compiles to a loop
  def doSearch(start: Int, end: Int): Int =
    if (start > end) -1
    else {
      val mid = (start + end) / 2
      if (elem == a(mid)) mid
      else if (elem > a(mid)) doSearch(mid + 1, end)
      else doSearch(start, mid - 1)
    }
  doSearch(0, a.length - 1)
}

It is the same algorithm, with the same runtime complexity and even the same actual running time. What is different is the programming model. Instead of a machine that changes data in place, we can think of this function as a kind of generator of an immutable sequence of immutable states. The states, in this case, are the tuples (not in the sense of the Scala syntactic construct, but in the more general sense of an ordered list of values) formed by the arguments of each call to doSearch.

For example, suppose you call the recursive function with these parameters:

val a = Array(4, 9, 28, 37, 40, 50, 52, 57, 60, 61, 68, 71, 74, 76, 82, 87, 92, 98)
val elem = 61
search(a, elem)

The sequence of states that comprises the execution of the algorithm for the input above is:

(0, 17) → (9, 17) → (9, 12) → (9, 9)

The arrow represents a recursive call, which yields the next state. And every state in that sequence satisfies property (P1). What we have here, then, is a recursion invariant, that is, an invariant property that is guaranteed to be preserved across recursive calls. The way to prove a recursion invariant is basically the same as for a loop invariant: initialization, maintenance and termination. But instead of thinking in terms of how a loop changes certain variables, we think of states and of the relationship between consecutive states.

In particular, we are interested in the first state (the initialization), the last state (the termination, which can be straightforwardly transformed into the final result) and the inductive step of generating a new state from the current one, assuming that the current one satisfies the invariant. No mutation of variables and no notion of time to worry about; just the sequence of states.

And it gets even better: we don’t even need to think about the sequence itself. It suffices to establish the relationship between input and output. In terms of code, we have to establish the relationship between the parameters start and end in the definition of doSearch and the arguments passed in the two recursive calls, doSearch(mid + 1, end) and doSearch(start, mid - 1).
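
To make this concrete, (P1) can be turned into an executable assertion. Here is a sketch of my own, an instrumented copy of the function above that checks the invariant in every state (the linear scans exist only for the sake of the check, of course):

def searchChecked[T <% Ordered[T]](a: Array[T], elem: T): Int = {
  // "a(lo..hi) contains elem", where lo > hi denotes the empty interval
  def containsIn(lo: Int, hi: Int): Boolean =
    (lo to hi).exists(i => a(i) == elem)

  def doSearch(start: Int, end: Int): Int = {
    // (P1): a(start..end) contains elem iff a contains elem
    assert(containsIn(start, end) == containsIn(0, a.length - 1))
    if (start > end) -1
    else {
      val mid = (start + end) / 2
      if (elem == a(mid)) mid
      else if (elem > a(mid)) doSearch(mid + 1, end)
      else doSearch(start, mid - 1)
    }
  }
  doSearch(0, a.length - 1)
}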

To sum up, the next time you write code to solve some problem, try to think about what property your algorithm keeps throughout its execution. And, if possible, try to develop a (tail) recursive version of it, so that you can prove that it works with much more elegance and simplicity. The key to understanding how an algorithm changes things is to observe what it does not change (more on this in a future post).

Nested types and function composition

In Scala and other typed functional languages – notably Haskell – monads are structures that allow the programmer to take a sequence of computations, each defined for a certain context, and chain them together in order to produce a single result at the end. For example, suppose you have to use an API that returns values inside the context of future executions. Something like this:

def getAddress(user: User): Future[Address]
def getGeolocation(address: Address): Future[LatLong]
def getCurrentWeather(coordinates: LatLong): Future[WeatherDetails]    

Given a certain user, we would like to use this API to find out the current weather conditions at the location where he lives. So we can write a method like:

def usersWeather(user: User): Future[WeatherDetails] = {
  for {
    address <- getAddress(user)
    coordinates <- getGeolocation(address)
    weather <- getCurrentWeather(coordinates)
  } yield weather
}

Quite straightforward. In each generator, the value “inside” the Future is supplied as input to the next method in the sequence. Now suppose you were given a more complex API, in which the computations are wrapped inside a context of future execution which may or may not contain a value. In other words, all methods return Future[Option[T]], for some type T. So, the methods above would become:

def getAddress(user: User): Future[Option[Address]]
def getGeolocation(address: Address): Future[Option[LatLong]]
def getCurrentWeather(coordinates: LatLong): Future[Option[WeatherDetails]]

In this case, it is no longer possible to use our good old friend flatMap. At least not directly, as in the first case: you cannot supply the value inside each Future to the next function in the sequence, because none of these methods accepts an Option[T] as input. One possible solution to this problem is to convert these functions into other functions, lifting only their domain from A to Option[A]:

def convert[A, B](f: A => Future[Option[B]]): Option[A] => Future[Option[B]] = {
  // If there is a value, feed it to f; otherwise short-circuit with a None
  maybeA => maybeA map f getOrElse Future(None)
}

which would allow us to put the functions back in shape to be used in a for comprehension (note that the initial user now has to be wrapped in a Some, since the converted getAddress expects an Option[User]):

for {
  address <- convert(getAddress)(Some(user))
  coordinates <- convert(getGeolocation)(address)
  weather <- convert(getCurrentWeather)(coordinates)
} yield weather

But we are looking for a more elegant and more general solution, one that works for a whole family of types. If, instead of Option, we had List or Either, we would need customized versions of convert for each of those types. With the aid of typeclasses, however, we can write a single function that allows us to compose these kinds of functions in a simple and concise way.

A word about typeclasses

At the beginning of this article, I talked about how monads allow the programmer to chain a sequence of computations inside a context. But what we actually did was use a for comprehension, which is syntactic sugar for calls to the methods map, flatMap, withFilter and foreach; the compiler requires no formal notion of a monad, only the presence of those methods.
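
For instance, the for comprehension inside usersWeather above is rewritten by the compiler into (roughly) these nested calls:

def usersWeather(user: User): Future[WeatherDetails] =
  getAddress(user).flatMap { address =>
    getGeolocation(address).flatMap { coordinates =>
      getCurrentWeather(coordinates).map(weather => weather)
    }
  }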

Scalaz, on the other hand, is an awesome library, written in Scala, that provides a broad array of bona fide typeclasses. In particular, Scalaz provides a Monad[F[_]] typeclass, which is going to be the most important piece in our solution to composition of functions with nested types.
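
Stripped of its full hierarchy (the real trait extends Applicative, Bind and others), scalaz.Monad boils down to these two operations:

// Simplified shape of scalaz.Monad, for illustration only
trait Monad[F[_]] {
  def point[A](a: => A): F[A]                  // wrap a plain value in the context F
  def bind[A, B](fa: F[A])(f: A => F[B]): F[B] // chain a context-producing function
}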

So, let’s get to it. First of all, we need to define an instance of scalaz.Monad for the Future type:

import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global
import scalaz.Monad

// implicit, so that it can be picked up by the machinery defined below
implicit val futureMonad: Monad[Future] = new Monad[Future] {
  override def point[A](a: => A): Future[A] = Future(a)

  override def bind[A, B](fa: Future[A])(f: A => Future[B]): Future[B] =
    fa.flatMap(f)
}

Strictly speaking, this implementation is not a monad, since it violates the law of left identity, which states that futureMonad.point(a) bind f should be equal to f(a). While this equality holds in most cases, there is a special case for which it is false. For the sake of robustness, the designers of Scala chose to implement Future.flatMap so that it does not throw exceptions, even if the function passed to it does. So, in the presence of exceptions, this law is broken. But for most practical purposes this is not important, and we can still think of it as a monad.
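
A small sketch of the violation, using a contrived function of my own that throws synchronously:

val boom: Int => Future[Int] =
  _ => throw new RuntimeException("boom")

// Left side of the law: flatMap captures the exception in a failed Future
val left: Future[Int] = futureMonad.bind(futureMonad.point(1))(boom)

// Right side: the call throws before any Future is created, so the sides differ
// val right: Future[Int] = boom(1)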

The Option type, on the other hand, is far less controversial. Scalaz defines monad instances for it, which can be readily made available to an application with an import, so we are not going to spend time talking about this particular monad. Option is also an instance of the Traverse typeclass. In the Haskell documentation (which inspired much of Scalaz), Traverse is defined as a “class of data structures that can be traversed from left to right, performing an action on each element.” The importance of this will become clear in a bit.
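
In Scalaz 7, for instance, a single import brings both instances into scope:

import scalaz.std.option._ // Monad[Option] and Traverse[Option], among others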

Composing the functions

So, with this typeclass arsenal, we can start solving our composition problem. But first, let’s state clearly what exactly the problem is and what we aim to achieve. Given two functions, f: A => Future[Option[B]] and g: B => Future[Option[C]], we would like to define a higher-order function that takes f and g and produces a new function of type A => Future[Option[C]]. More formally – and more generically – we need a function composeN with the following signature:

def composeN[A, B, C, F[_], G[_]](f: A => F[G[B]], g: B => F[G[C]]): A => F[G[C]]

Taking advantage of the richness of Scala’s type system, let’s start by reasoning about the types and what transformations we would need. Observe that both f and composeN are functions from A to F (of something). This sounds like monad binding. So, let’s assume we have an implicit value of type Monad[F] in scope. The implementation of the method would have the form:

a => fMonad.bind(f(a)) {
  gb => ...
}

The value gb above is of type G[B]. Assuming that we also have an implicit Monad[G] in scope, we could map over gb with the function g:

val gfgc: G[F[G[C]]] = gMonad.map(gb)(g)

which results in a value of type G[F[G[C]]]. Note how the type G appears twice in that type. That’s one time too many, and we need to get rid of the extra G. But before that, let’s make one further assumption: that there is an implicit Traverse[G] in scope. This allows us to use the sequence function (which Scalaz exposes on Applicative, and therefore on our Monad[F]) to swap the outermost G with the F:

val fggc: F[G[G[C]]] = fMonad.sequence(gfgc)

The last transformation is the one that will take us where we want:

val fgc: F[G[C]] = fMonad.map(fggc)(ggc => gMonad.bind(ggc)(identity))

So, the complete definition of the function is as follows:

def composeN[A, B, C, F[_], G[_]](f: A => F[G[B]], g: B => F[G[C]])
    (implicit fMonad: Monad[F], gMonad: Monad[G], gTraverse: Traverse[G]): A => F[G[C]] =
  a => fMonad.bind(f(a)) {
    gb => fMonad.map(fMonad.sequence(gMonad.map(gb)(g))) {
      ggc => gMonad.bind(ggc)(identity)
    }
  }
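
With F = Future and G = Option (that is, with futureMonad and Scalaz’s Option instances in scope), composeN specializes to exactly the shape we need. A hypothetical usage:

// Compose the first two steps of the original API into a single function
val userToCoordinates: User => Future[Option[LatLong]] =
  composeN(getAddress, getGeolocation)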

Finally, in order to provide some syntactic sugar and make the code look cleaner, we can wrap the first function inside a class (I called it ComposeNested, but I’m sure there is a much better name for this) and expose a method to be used as an infix operator:

class ComposeNested[A, B, F[_], G[_]](f: A => F[G[B]])
          (implicit fMonad: Monad[F], gMonad: Monad[G], gTraverse: Traverse[G]) {

  def ->>[C](g: B => F[G[C]]): A => F[G[C]] = composeN(f, g)

  private def composeN[C](f: A => F[G[B]], g: B => F[G[C]]): A => F[G[C]] = ...
}
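
For the infix syntax below to kick in, something must turn a plain function into a ComposeNested. A plausible way to do it (the repository may differ in the details) is an implicit conversion:

implicit def toComposeNested[A, B, F[_], G[_]](f: A => F[G[B]])
    (implicit fMonad: Monad[F], gMonad: Monad[G], gTraverse: Traverse[G]): ComposeNested[A, B, F, G] =
  new ComposeNested(f)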

Back to our original example: how can we compose those three functions, chaining the computation through them in a specified order, preserving their semantics and reducing (or rather hiding) the code noise? Let’s use the operator ->> we just defined:

object Demo extends App {
  val usersWeather = getAddress ->> getGeolocation ->> getCurrentWeather

  val futureWeather = usersWeather(User("John Doe"))

  futureWeather onFailure {
    case e: Exception => 
      println(s"Computation failed with the message: ${e.getMessage}")
  }

  futureWeather onSuccess {
    case result => println(s"Computation result: $result")
  }

  Thread.sleep(500) // crude way to keep the JVM alive until the Future completes
}

The complete implementation can be found in my github repository.