For Comprehensions in Scala
What is a for comprehension?
Just like imperative programming languages, Scala provides a for-loop, but the similarities end there. It is primarily used for iterating over collections and extracting items from them. The benefit of for is that it's clear and concise.
Take the example of creating a list of all the indices of a matrix. Without for, we would do something like this in Scala:
def indices(row: Int, col: Int) =
  (0 until row).map(i =>
    (0 until col).map(j => (i, j)))
> indices(3,3)
res0: ... = Vector(Vector((0,0), (0,1), (0,2)), Vector((1,0), (1,1), (1,2)), Vector((2,0), (2,1), (2,2)))
With for comprehensions, we can do something like:
def indices(row: Int, col: Int) =
  for (i <- (0 until row); j <- (0 until col)) yield (i, j)
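This is both clearer and abstracts away the details of the map calls, which at this point seem lower-level than for. There is one difference: the for version returns a flat sequence of pairs rather than a vector of vectors. The Type 3 section below explains why.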
Syntactic sugar
What we saw above is essentially what happens when for is used, but in the opposite direction. At compile time, the for is changed into calls to map and related methods. Therefore, for is not a new construct in the language. It is just a transformation into map and other methods.
The conversion is done recursively. Each for is converted into a simpler expression, which can be yet another for, until no more for expressions remain.
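To make this rewrite concrete, here is a minimal sketch that puts the indices function from above next to a hand-written expansion in terms of map and flatMap. The object and method names are hypothetical. The second method was written by hand to mirror the translation described here; it is not actual compiler output. Its outer generator becomes flatMap, which is why the result is flat (see Type 3 below).

```scala
object IndicesDesugared {
  // The for comprehension from the introduction.
  def indicesFor(row: Int, col: Int): IndexedSeq[(Int, Int)] =
    for (i <- 0 until row; j <- 0 until col) yield (i, j)

  // Hand-written equivalent: the outer generator becomes flatMap,
  // the last generator becomes map.
  def indicesManual(row: Int, col: Int): IndexedSeq[(Int, Int)] =
    (0 until row).flatMap(i => (0 until col).map(j => (i, j)))

  def main(args: Array[String]): Unit = {
    println(indicesFor(2, 2))                        // Vector((0,0), (0,1), (1,0), (1,1))
    println(indicesFor(3, 3) == indicesManual(3, 3)) // true
  }
}
```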
Syntactically, for comprehensions come in 3 types, although types 1 and 2 are really just special cases of type 3:
// Type 1: Iterate over a collection
for (i <- collection) yield i
// Type 2: Iterate over a collection and filter out elements
for (i <- collection1 if somePredicate(i)) yield i
// Type 3: Iterate over some collections, and filter the values using
// an `if` guard.
for (i <- collection1; j <- collection2 if somePredicate(i, j)) yield i.op(j)
// The iterations can be spread over multiple lines by using braces instead
// of parentheses. We also don't need the semicolons.
for {
i <- collection1
j <- collection2
if somePredicate(i, j)
} yield i.op(j)
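The schematic forms above become runnable once concrete values are plugged in. In this sketch, collection1, collection2 and the predicates are arbitrary choices for illustration. A pair (i, j) stands in for the hypothetical i.op(j).

```scala
object ForTypes {
  val collection1 = List(1, 2, 3)
  val collection2 = List(1, 2, 3)

  // Type 1: iterate over a single collection
  val type1 = for (i <- collection1) yield i               // List(1, 2, 3)

  // Type 2: iterate and filter with an if guard
  val type2 = for (i <- collection1 if i % 2 == 1) yield i  // List(1, 3)

  // Type 3: two generators plus a guard that uses both values
  val type3 = for (i <- collection1; j <- collection2 if i < j) yield (i, j)

  // The same Type 3 comprehension written with braces and no semicolons
  val type3Braces = for {
    i <- collection1
    j <- collection2
    if i < j
  } yield (i, j)

  def main(args: Array[String]): Unit = {
    println(type3)                // List((1,2), (1,3), (2,3))
    println(type3 == type3Braces) // true
  }
}
```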
These are converted into combinations of map, flatMap and withFilter at compile time. Let’s take each type and see how it is converted:
Type 1: A single collection
Looked at closely, iterating over a collection and yielding a value is just another way of mapping over that collection. The simplest conversion is exactly that:
for (i <- collection) yield i * 2
// converts to:
collection.map(i => i * 2)
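Here is a quick check of this equivalence. It uses an arbitrary Vector(1, 2, 3) as the collection, and the object name is hypothetical:

```scala
object Type1Demo {
  val collection = Vector(1, 2, 3)

  val viaFor = for (i <- collection) yield i * 2
  // The form the for above is rewritten into
  val viaMap = collection.map(i => i * 2)

  def main(args: Array[String]): Unit = {
    println(viaFor)           // Vector(2, 4, 6)
    println(viaFor == viaMap) // true
  }
}
```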
Type 2: Filtering over a collection
If a for comprehension contains an if guard, the guard is converted into a call to withFilter. This leaves a simpler expression:
for (i <- collection1 if somePredicate(i)) yield i
// converts to:
for (i <- collection1.withFilter(j => somePredicate(j))) yield i
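Here is a sketch of the same steps with a concrete list and a made-up somePredicate that keeps odd numbers. It goes one step further by applying the Type 1 rule to the for that remains:

```scala
object Type2Demo {
  val collection1 = List(1, 2, 3, 4, 5)
  // Hypothetical predicate, used only for this example.
  def somePredicate(i: Int): Boolean = i % 2 == 1

  val withGuard  = for (i <- collection1 if somePredicate(i)) yield i
  // After the Type 2 step: the guard has moved into withFilter.
  val afterType2 = for (i <- collection1.withFilter(j => somePredicate(j))) yield i
  // After the Type 1 step as well: no for is left.
  val noFor      = collection1.withFilter(j => somePredicate(j)).map(i => i)

  def main(args: Array[String]): Unit =
    println(List(withGuard, afterType2, noFor)) // List(List(1, 3, 5), List(1, 3, 5), List(1, 3, 5))
}
```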
The benefit of using withFilter is that it:
- is lazy: the predicate is not called until map, flatMap or foreach is called on the result, and
- builds no intermediate collection.
In collection1.filter(pred).map(transform), filter builds a new, filtered collection, and map then builds a second collection from it. By contrast, collection1.withFilter(pred).map(transform) applies pred and transform in a single pass. It creates only the final collection and does its work only when map is called.
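The difference shows up when each call is logged. This is a small tracing sketch, and the helper names are hypothetical. The call order in the comments is what the standard library's List produces. filter runs the predicate over the whole list before map starts. withFilter does nothing until map is called, and then interleaves the predicate and the transform in one pass.

```scala
import scala.collection.mutable.ListBuffer

object WithFilterDemo {
  // Logs predicate and transform calls for filter followed by map.
  def traceFilter(xs: List[Int]): List[String] = {
    val log = ListBuffer.empty[String]
    xs.filter { x => log += s"pred $x"; x != 2 } // builds an intermediate List
      .map { x => log += s"map $x"; x * 10 }     // builds the final List
    log.toList
  }

  // Same pipeline, but with withFilter instead of filter.
  def traceWithFilter(xs: List[Int]): List[String] = {
    val log = ListBuffer.empty[String]
    val filtered = xs.withFilter { x => log += s"pred $x"; x != 2 }
    log += "withFilter returned"                   // no predicate call has happened yet
    filtered.map { x => log += s"map $x"; x * 10 } // one pass, only the final List is built
    log.toList
  }

  def main(args: Array[String]): Unit = {
    println(traceFilter(List(1, 2, 3)))
    // List(pred 1, pred 2, pred 3, map 1, map 3)
    println(traceWithFilter(List(1, 2, 3)))
    // List(withFilter returned, pred 1, map 1, pred 2, pred 3, map 3)
  }
}
```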
Type 3: More than one collection
If there is more than one collection, the outer collection is converted into a call to flatMap. Each of its elements is passed to the inner for comprehension, which yields values combining both collections.
for {
i <- collection1
j <- collection2
} yield i.op(j)
// converts to:
collection1.flatMap(i => for (j <- collection2) yield i.op(j))
We use flatMap because we do not want unnecessary nesting of the collection. In the code above, if we used map instead, we would get a nested collection instead of a flat, one-dimensional collection. This is exactly what happened with the nested maps in the indices example at the start.
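Here is a minimal sketch of the nesting problem, using two small example lists with arbitrary values:

```scala
object FlatMapVsMap {
  val collection1 = List(1, 2)
  val collection2 = List("a", "b")

  // map on the outer collection keeps one inner List per element
  val nested = collection1.map(i => collection2.map(j => s"$i$j"))
  // flatMap concatenates the inner Lists into a single List
  val flat = collection1.flatMap(i => collection2.map(j => s"$i$j"))
  // The for comprehension that expands to the flatMap version
  val viaFor = for {
    i <- collection1
    j <- collection2
  } yield s"$i$j"

  def main(args: Array[String]): Unit = {
    println(nested)         // List(List(1a, 1b), List(2a, 2b))
    println(flat)           // List(1a, 1b, 2a, 2b)
    println(flat == viaFor) // true
  }
}
```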
Conversion
As noted above, each type gets converted into a simpler type:
- Type 1 removes the for comprehension.
- Type 2 removes the if guard.
- Type 3 removes one of the two collections.
Let’s take an example and convert a for into these more basic forms:
for {
i <- (0 until row)
j <- (0 until col)
if (i != j)
} yield (i, j)
Since this holds 2 collections, the first step is to convert the 0 until row range into a flatMap by Type 3 expansion:
(0 until row).flatMap(i => for (j <- 0 until col if i != j) yield (i, j))
The second step is to convert the if guard into a withFilter expression using Type 2 expansion. Notice that withFilter is applied to the collection that had the guard, and not to any other collection:
(0 until row).flatMap(i =>
  for (j <- (0 until col).withFilter(x => i != x))
    yield (i, j))
Now we have only one for comprehension remaining. It is of Type 1 (just a single collection), so we change it into a map.
(0 until row).flatMap(i =>
  (0 until col).withFilter(x => i != x).map(j => (i, j)))
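As a sanity check, each stage of the expansion can be written as a separate method and compared. This is a sketch: the method names are hypothetical, and (3, 3) is an arbitrary matrix size.

```scala
object DesugarSteps {
  // The original comprehension.
  def original(row: Int, col: Int): IndexedSeq[(Int, Int)] =
    for {
      i <- 0 until row
      j <- 0 until col
      if i != j
    } yield (i, j)

  // Step 1 (Type 3): the outer generator becomes flatMap.
  def step1(row: Int, col: Int): IndexedSeq[(Int, Int)] =
    (0 until row).flatMap(i => for (j <- 0 until col if i != j) yield (i, j))

  // Step 2 (Type 2): the guard moves into withFilter on the inner range.
  def step2(row: Int, col: Int): IndexedSeq[(Int, Int)] =
    (0 until row).flatMap(i =>
      for (j <- (0 until col).withFilter(x => i != x)) yield (i, j))

  // Step 3 (Type 1): the remaining single-generator for becomes map.
  def step3(row: Int, col: Int): IndexedSeq[(Int, Int)] =
    (0 until row).flatMap(i =>
      (0 until col).withFilter(x => i != x).map(j => (i, j)))

  def main(args: Array[String]): Unit = {
    val all = List(original(3, 3), step1(3, 3), step2(3, 3), step3(3, 3))
    println(all.head)          // Vector((0,1), (0,2), (1,0), (1,2), (2,0), (2,1))
    println(all.distinct.size) // 1, because every step gives the same result
  }
}
```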
This is how Scala supports for comprehensions using just the primitive methods map, flatMap and withFilter.
One more thing that is clear from this is that a for comprehension is not lazy by itself. It is only as lazy as the map and flatMap of the collection it iterates over, and on strict collections those run immediately. But if the underlying collection is lazy, then the for becomes lazy as well.
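Here is a small sketch of both cases. The same for runs over a strict List and over a view of it, with a counter recording how many times the yielded expression has run. A view is just one example of a lazy collection, and the helper names are hypothetical.

```scala
object LazinessDemo {
  // Returns (calls after the strict for, calls after the lazy for,
  //          calls after forcing the view, forced result).
  def run(): (Int, Int, Int, List[Int]) = {
    var calls = 0
    def double(x: Int): Int = { calls += 1; x * 2 }

    // Over a strict List, map runs immediately.
    val strict = for (x <- List(1, 2, 3)) yield double(x)
    val afterStrict = calls

    // Over a view, the same for only records the work to be done.
    calls = 0
    val deferred = for (x <- List(1, 2, 3).view) yield double(x)
    val afterLazy = calls

    // Forcing the view runs double once per element.
    val forced = deferred.toList
    val afterForce = calls
    (afterStrict, afterLazy, afterForce, forced)
  }

  def main(args: Array[String]): Unit =
    println(run()) // (3,0,3,List(2, 4, 6))
}
```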
In another post, I’ll discuss how any class that defines these 3 methods can support iteration with the for syntax.