What is a for comprehension?

Just like imperative programming languages, Scala provides a for loop, but the similarities end there. It is primarily used for iterating over collections and extracting items from them. The benefit of for is that it's clear and concise.

Take the example of creating a list of all the indices of a matrix. Using map, we would write something like this in Scala:

def indices(row: Int, col: Int) =
    (0 until row).map(i =>
      (0 until col).map(j => (i, j)))

> indices(3,3)
res0: ... = Vector(Vector((0,0), (0,1), (0,2)), Vector((1,0), (1,1), (1,2)), Vector((2,0), (2,1), (2,2)))

With a for comprehension, we can write:

def indices(row: Int, col: Int) =
    for (i <- (0 until row); j <- (0 until col)) yield (i, j)

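This is clearer, and it abstracts away the map calls, which at this point seem lower level than for. Note that the result is also different in shape: the for version yields a single flat sequence of pairs rather than a vector of vectors. The reason is covered under Type 3 below.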
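As a quick check, the following sketch puts the three versions side by side. The object name `IndicesDemo` and the method names are mine, for illustration only. The for version matches an explicit flatMap over the rows followed by a map over the columns. The original map-inside-map version matches it only after flatten.

```scala
object IndicesDemo {
  // Nested version: map inside map gives a collection of collections
  def nested(row: Int, col: Int) =
    (0 until row).map(i => (0 until col).map(j => (i, j)))

  // For comprehension version: yields a flat sequence of pairs
  def viaFor(row: Int, col: Int) =
    for (i <- 0 until row; j <- 0 until col) yield (i, j)

  // Roughly what the compiler turns the for into: flatMap for the
  // outer generator, map for the last one
  def viaFlatMap(row: Int, col: Int) =
    (0 until row).flatMap(i => (0 until col).map(j => (i, j)))

  def main(args: Array[String]): Unit = {
    println(viaFor(2, 2))                         // Vector((0,0), (0,1), (1,0), (1,1))
    println(viaFor(2, 2) == viaFlatMap(2, 2))     // true
    println(nested(2, 2).flatten == viaFor(2, 2)) // true
  }
}
```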

Syntactic sugar

What we saw above is essentially what happens when for is used, but in the opposite direction. At compile time, the for is rewritten into calls to map (and, for the outer generator, flatMap). Therefore, for is not a new construct in the language, but just a transformation into map and other methods.

The conversion is done recursively, so that each for is converted to a simpler expression (which can be yet another for) until no more for expressions remain.
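For instance, a comprehension with three generators is rewritten one generator at a time. This sketch uses made-up sample lists and a hypothetical object name `RecursiveDemo`. Each intermediate step is still valid Scala and produces the same result. The rule that the outer generator becomes a flatMap is described under Type 3 below.

```scala
object RecursiveDemo {
  // Sample data, chosen only for illustration
  val xs = List(1, 2)
  val ys = List("a", "b")
  val zs = List(true)

  // Three generators in a single for comprehension
  val viaFor = for (x <- xs; y <- ys; z <- zs) yield (x, y, z)

  // Step 1: the first generator becomes a flatMap whose body is a smaller for
  val step1 = xs.flatMap(x => for (y <- ys; z <- zs) yield (x, y, z))

  // Step 2: the same rule applied to the inner for; the last generator
  // becomes a plain map, and no for is left
  val step2 = xs.flatMap(x => ys.flatMap(y => zs.map(z => (x, y, z))))

  def main(args: Array[String]): Unit =
    println(viaFor == step1 && step1 == step2) // true
}
```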

Syntactically, for comprehensions come in 3 types (although 1 and 2 are just special cases of 3):

// Type 1: Iterate over a collection
for (i <- collection) yield i

// Type 2: Iterate over a collection and filter out elements
for (i <- collection1 if somePredicate(i)) yield i

// Type 3: Iterate over some collections, and filter the values using
// an `if` guard.
for (i <- collection1; j <- collection2 if somePredicate(i, j)) yield i.op(j)

// The generators can be spread over multiple lines by using braces instead
// of parentheses. We then don't need the semicolons either
for {
    i <- collection1
    j <- collection2
    if somePredicate(i, j)
    } yield i.op(j)
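To see that the parenthesized and braced forms of Type 3 mean the same thing, here is a small sketch. It assumes concrete `List[Int]` values for `collection1` and `collection2` and a stand-in `somePredicate`. It also uses `i * j` in place of the hypothetical `i.op(j)`. All of these are illustrative choices, not part of the original.

```scala
object SyntaxDemo {
  // Sample data, chosen only for illustration
  val collection1 = List(1, 2, 3)
  val collection2 = List(10, 20)

  // Stand-in predicate: keep pairs whose sum is odd
  def somePredicate(i: Int, j: Int): Boolean = (i + j) % 2 == 1

  // Parentheses: generators separated by semicolons, guard inline
  val withParens =
    for (i <- collection1; j <- collection2 if somePredicate(i, j)) yield i * j

  // Braces: one generator or guard per line, no semicolons needed
  val withBraces = for {
    i <- collection1
    j <- collection2
    if somePredicate(i, j)
  } yield i * j

  def main(args: Array[String]): Unit =
    println(withParens == withBraces) // true
}
```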

These are converted into a combination of map, flatMap and withFilter at compile time. Let's take each type and see how it is converted:

Type 1: A single collection:

Looking closely, iterating over a collection and yielding a value is just another way of mapping over the collection, and the simplest conversion is exactly that:

for (i <- collection) yield i * 2

// converts to:

collection.map(i => i * 2)
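A minimal sketch of the Type 1 rule, assuming `collection` is a `List[Int]` (the sample values and the object name `Type1Demo` are illustrative):

```scala
object Type1Demo {
  // Sample data, chosen only for illustration
  val collection = List(1, 2, 3)

  // A single generator with yield ...
  val viaFor = for (i <- collection) yield i * 2

  // ... is just a map over the collection
  val viaMap = collection.map(i => i * 2)

  def main(args: Array[String]): Unit =
    println(viaFor) // List(2, 4, 6)
}
```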

Type 2: Filtering over a collection:

If a for comprehension contains an if guard, the guard is removed and the generator it follows is wrapped in a call to withFilter:

for (i <- collection1 if somePredicate(i)) yield i

// converts to:

for (i <- collection1.withFilter(j => somePredicate(j))) yield i
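The sketch below walks through the Type 2 rule with a sample `List[Int]` and a hypothetical `somePredicate` that keeps even numbers. It yields `i * 10` rather than `i`, only so that the final map step does visible work. After the guard is moved into withFilter, the remaining single-generator for becomes a map by the Type 1 rule.

```scala
object Type2Demo {
  // Sample data, chosen only for illustration
  val collection1 = List(1, 2, 3, 4, 5, 6)

  // Hypothetical predicate: keep even numbers
  def somePredicate(i: Int): Boolean = i % 2 == 0

  // Original form with an if guard
  val viaFor = for (i <- collection1 if somePredicate(i)) yield i * 10

  // Guard moved onto the generator as withFilter (still a for) ...
  val viaWithFilterFor =
    for (i <- collection1.withFilter(j => somePredicate(j))) yield i * 10

  // ... which by the Type 1 rule becomes a map
  val viaMethods = collection1.withFilter(j => somePredicate(j)).map(i => i * 10)

  def main(args: Array[String]): Unit =
    println(viaMethods) // List(20, 40, 60)
}
```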

The benefit of using withFilter instead of filter is that it:

  1. is lazy: it only wraps the collection and the predicate, and the predicate runs when the following map, flatMap or foreach is evaluated, and
  2. never builds an intermediate filtered collection.

In collection1.filter(pred).map(transform), filter first builds a new filtered collection, and map then builds the result from it. In collection1.withFilter(pred).map(transform), the predicate is applied while map builds the result, so only the final collection is created.
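One way to observe this is to count predicate calls. The following sketch uses a hypothetical counting predicate `isEven` and sample data on a `List`; the names are mine. It records how many times the predicate has run right after withFilter and right after filter.

```scala
object WithFilterDemo {
  // Hypothetical counting predicate, so we can see when it actually runs
  var predicateCalls = 0
  def isEven(i: Int): Boolean = { predicateCalls += 1; i % 2 == 0 }

  // Sample data, chosen only for illustration
  val xs = List(1, 2, 3, 4)

  // Returns (calls right after withFilter, calls after map, result)
  def runWithFilter(): (Int, Int, List[Int]) = {
    predicateCalls = 0
    // withFilter only wraps xs and the predicate; nothing is evaluated yet
    val wrapped = xs.withFilter(isEven)
    val callsBeforeMap = predicateCalls
    // The predicate runs while map builds the single result list
    val result = wrapped.map(_ * 10)
    (callsBeforeMap, predicateCalls, result)
  }

  // Returns (calls right after filter, result)
  def runFilter(): (Int, List[Int]) = {
    predicateCalls = 0
    // filter evaluates the predicate immediately and builds List(2, 4)
    val filtered = xs.filter(isEven)
    val callsAfterFilter = predicateCalls
    (callsAfterFilter, filtered.map(_ * 10))
  }

  def main(args: Array[String]): Unit = {
    println(runWithFilter()) // (0,4,List(20, 40))
    println(runFilter())     // (4,List(20, 40))
  }
}
```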

Type 3: More than one collection:

If there is more than one collection, the outer generator is converted into a flatMap. Its function receives each element of the outer collection and runs the remaining for comprehension over the inner collection, yielding values built from both.

for {
    i <- collection1
    j <- collection2
    } yield i.op(j)

// converts to:

collection1.flatMap(i => for (j <- collection2) yield i.op(j))

We use flatMap because we do not want unnecessary nesting. If we used map here instead, we would get a nested collection with one inner collection per outer element, just like the first indices example, instead of a flat, one-dimensional collection.
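A small sketch makes the difference visible. It assumes two sample `List[Int]` values, uses `i + j` in place of the hypothetical `i.op(j)`, and puts everything in an illustrative object named `FlatMapDemo`.

```scala
object FlatMapDemo {
  // Sample data, chosen only for illustration
  val collection1 = List(1, 2)
  val collection2 = List(10, 20)

  // With map, each outer element produces a whole inner list: List[List[Int]]
  val nested = collection1.map(i => for (j <- collection2) yield i + j)

  // With flatMap, the inner lists are concatenated: List[Int]
  val flat = collection1.flatMap(i => for (j <- collection2) yield i + j)

  // The two-generator for comprehension matches the flatMap version
  val viaFor = for (i <- collection1; j <- collection2) yield i + j

  def main(args: Array[String]): Unit = {
    println(nested) // List(List(11, 21), List(12, 22))
    println(flat)   // List(11, 21, 12, 22)
  }
}
```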

Conversion:

As noted above, each type gets converted into a simpler one.

Let's take an example and convert a for, step by step, into these basic forms:

for {
    i <- (0 until row)
    j <- (0 until col)
    if (i != j)
    } yield (i, j)
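To make this example runnable, the following sketch wraps it in a function so that `row` and `col` become parameters (the names `OffDiagonalDemo` and `offDiagonal` are mine). The guard `i != j` drops the diagonal, so for a 3 x 3 matrix six index pairs remain.

```scala
object OffDiagonalDemo {
  // The example above, with row and col as parameters
  def offDiagonal(row: Int, col: Int) =
    for {
      i <- 0 until row
      j <- 0 until col
      if i != j
    } yield (i, j)

  def main(args: Array[String]): Unit =
    println(offDiagonal(3, 3)) // Vector((0,1), (0,2), (1,0), (1,2), (2,0), (2,1))
}
```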

Since this holds 2 collections, the first step is to convert the 0 until row generator into a flatMap using the Type 3 expansion:

(0 until row).flatMap(i => for (j <- 0 until col if i != j) yield (i, j))

The second step is to convert the if guard into a withFilter call using the Type 2 expansion. Notice that withFilter is applied to the collection the guard follows, and not to any other collection:

(0 until row).flatMap(i =>
    for (j <- (0 until col).withFilter(x => i != x))
        yield (i, j))

Now only one for comprehension remains, and it is of Type 1 (just a single collection), so we change it to a map:

(0 until row).flatMap(i =>
    (0 until col).withFilter(x => i != x).map(j => (i, j)))
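Each intermediate step can be checked to produce the same result as the original comprehension. The sketch below fixes `row = 3` and `col = 3` (my choice, for illustration) inside a hypothetical object named `ConversionSteps`.

```scala
object ConversionSteps {
  // Sample sizes, chosen only for illustration
  val row = 3
  val col = 3

  // The original for comprehension
  val original = for {
    i <- 0 until row
    j <- 0 until col
    if i != j
  } yield (i, j)

  // Step 1: the outer generator becomes flatMap (Type 3)
  val step1 = (0 until row).flatMap(i => for (j <- 0 until col if i != j) yield (i, j))

  // Step 2: the guard becomes withFilter on the inner range (Type 2)
  val step2 = (0 until row).flatMap(i =>
    for (j <- (0 until col).withFilter(x => i != x)) yield (i, j))

  // Step 3: the remaining single generator becomes map (Type 1)
  val step3 = (0 until row).flatMap(i =>
    (0 until col).withFilter(x => i != x).map(j => (i, j)))

  def main(args: Array[String]): Unit =
    println(Seq(step1, step2, step3).forall(_ == original)) // true
}
```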

This is how Scala supports for comprehensions using just the basic methods map, flatMap and withFilter.

One more thing is clear from this: a for comprehension is not lazy by itself, because map and flatMap on strict collections such as List or Vector are not lazy either. But if the underlying collection is lazy, such as an Iterator or a view, then the for comprehension becomes lazy as well.
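The following sketch illustrates this using a hypothetical counting function `square` (an illustrative name) that records how often it runs. Over a `List`, the comprehension computes every element immediately. Over an `Iterator`, nothing is computed until an element is requested. An Iterator was chosen here only because it is a simple lazy collection available in all recent Scala versions.

```scala
object LazinessDemo {
  // Hypothetical counting function, so we can see when elements are computed
  var evaluated = 0
  def square(i: Int): Int = { evaluated += 1; i * i }

  // Strict: over a List, the for (a map) computes every element right away
  def strict(): (List[Int], Int) = {
    evaluated = 0
    val squares = for (i <- List(1, 2, 3)) yield square(i)
    (squares, evaluated)
  }

  // Lazy: over an Iterator, the for (also a map) defers the work
  def lazyCounts(): (Int, Int) = {
    evaluated = 0
    val squares = for (i <- Iterator(1, 2, 3)) yield square(i)
    val before = evaluated
    squares.next() // forces only the first element
    (before, evaluated)
  }

  def main(args: Array[String]): Unit = {
    println(strict())     // (List(1, 4, 9),3)
    println(lazyCounts()) // (0,1)
  }
}
```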

In another post, I'll discuss how any class that defines these 3 methods can be used with the for syntax.