13 February 2009

An exercise with arrows in Scala

I'm still a bit fascinated by the so-called "computing models" that you get when programming with monads or arrows.

That fascination comes from the fact that the things I use every day can be abstracted into useful entities. Of course, programmers know well that functions can be described formally with parameters, types, return values and so on. Given the definition of functions, you can even compose them and create a useful function out of two other ones.

But I was pretty amazed, when reading John Hughes's paper on Arrows, to see that there are lots of properties and combinators that you can define for your basic computing blocks, and that these combinators, defined at an abstract level, can then be used to structure the concrete and dirty work of your everyday functions.

The Challenge

A few weeks ago, I was intrigued by a small challenge brought up by Debasish Ghosh: translating a Haskell snippet to Scala. The clusterBy function is defined in Haskell by:

clusterBy :: Ord b => (a -> b) -> [a] -> [[a]]
clusterBy f = M.elems . reverse . M.fromListWith (++) . map (f &&& return)

Basically, it takes a list of elements and returns a list of lists, where the elements are grouped according to the result of applying a function f to them: if 2 elements have the same result with f, they end up in the same list. For example:

clusterBy length ["ant", "part", "all"] => [["all", "ant"], ["part"]]

That example looked interesting to me for 2 reasons:
  • that's not the kind of thing that you program in one line with Java
  • it uses Arrows and I remembered that some code for Arrows in Scala was available
So I set out to rewrite the clusterBy function in Scala, using Arrows.
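
Before bringing Arrows in, it may help to pin down what clusterBy computes. Here is a sketch of the same grouping in plain Scala, with no Arrows at all. It is only an illustration: it assumes an Ordering on the keys (mirroring Haskell's Ord b constraint) and sorts the groups by key, which is what M.elems does on a Data.Map.

```scala
object PlainClusterBy {
  // Groups the elements of l by the result of f. Elements are prepended to
  // their group, so each group lists them in reverse order of appearance
  // (like fromListWith (++)). Groups come out sorted by key, like M.elems.
  def clusterBy[A, B: Ordering](l: List[A])(f: A => B): List[List[A]] =
    l.foldLeft(Map.empty[B, List[A]]) { (groups, x) =>
      val key = f(x)
      groups.updated(key, x :: groups.getOrElse(key, Nil))
    }.toList
      .sortBy(_._1)
      .map(_._2)

  def main(args: Array[String]): Unit =
    println(clusterBy(List("ant", "part", "all"))(_.length)) // List(List(all, ant), List(part))
}
```

The whole computation is hidden inside one fold here; the Arrow version below tries to make each step visible instead.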

The solution

Here's my solution (partial here, read more at the end):

def clusterBy[A, B](l: List[A])(f: A => B) =
  (listArr(f &&& identity) >>> arr(groupByFirst) >>> arr(second) >>> arr(reverse))(l)

And the clusterBy function in action as a specs expectation:

clusterBy(List("ant", "part", "all"))(_.length) must beEqualTo(List(List("all", "ant"), List("part")))

Let's dissect the clusterBy function to understand what exactly Arrows are bringing to the table. 

"Lifting" a function to an Arrow (what the h...?)

First of all, for our understanding, we can remove some noise:

arr(groupByFirst) is just taking a function, groupByFirst, and "lifting" it to make it an Arrow object which can play nicely with other arrows using the >>> operator.
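
To make "lifting" concrete, here is a minimal, hypothetical Arrow type over plain Scala functions. The names Arrow, arr and run are assumptions for illustration, not the API of the Scala Arrows code I used: arr wraps a function, and the wrapped value gains combinators such as >>>.

```scala
object LiftingDemo {
  // A hypothetical "enhanced function": it simply wraps a plain function A => B
  final case class Arrow[A, B](run: A => B) {
    // Sequence this arrow with another one: our result is fed to `next`
    def >>>[C](next: Arrow[B, C]): Arrow[A, C] = Arrow(run andThen next.run)
  }

  // Lifting: turn an ordinary function into an Arrow
  def arr[A, B](f: A => B): Arrow[A, B] = Arrow(f)

  // Two ordinary functions, lifted and then sequenced
  val lengthThenDouble: Arrow[String, Int] =
    arr((s: String) => s.length) >>> arr((n: Int) => n * 2)

  def main(args: Array[String]): Unit =
    println(lengthThenDouble.run("part")) // 8
}
```

In this sketch a plain function has no >>> method of its own, so it has to go through arr first (or through an implicit conversion, which is what the Arrows object provides).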

There are implicit definitions in the Arrows object that convert functions to Arrows, but unfortunately type inference didn't let me remove the arr(...) calls. If I could, I would get something like:

def clusterBy[A, B](l: List[A])(f: A => B) =
 (listArr(f &&& identity) >>> groupByFirst >>> second >>> reverse)(l)

>>>, the "sequence" operator

How should we read that expression now? 

The >>> operator reads very easily. It says:

f1 >>> f2 => take the result of f1 and give it to f2.

It is the equivalent of the composition operator, except that it reads the other way around, from left to right: with composition we would write "f2 . f1". The >>> operator (arguably) follows the natural reading flow (well, in English at least...).
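
Scala's standard library already offers both reading directions on plain functions: andThen reads left to right like >>>, and compose reads right to left like Haskell's ".". A small illustration (the function names are made up for the example):

```scala
object SequenceDemo {
  val length: String => Int = _.length
  val double: Int => Int = _ * 2

  // "length >>> double": take the result of length and give it to double
  val leftToRight: String => Int = length andThen double

  // "double . length" in Haskell notation: the same function, read right to left
  val rightToLeft: String => Int = double compose length

  def main(args: Array[String]): Unit =
    println((leftToRight("part"), rightToLeft("part"))) // (8,8)
}
```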

&&&, the "branching" operator

The next important operator is &&&:

f1 &&& f2 => a function which takes an input and returns a pair with the results of both f1 and f2.

So (f &&& identity) in our example simply takes an element and creates a pair containing the result of the application of f and the element itself:

(f(x), x)
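
The behavior of &&& can be sketched on plain Scala functions with an implicit class. FanOutOps is a hypothetical name and this is not the Arrows library's implementation, just a function with the same behavior: one input, two functions, a pair of results.

```scala
object FanOutDemo {
  implicit class FanOutOps[A, B](f: A => B) {
    // f &&& g: feed the same input to both functions and pair the results
    def &&&[C](g: A => C): A => (B, C) = a => (f(a), g(a))
  }

  val length: String => Int = _.length

  // (f &&& identity) with f = length
  val keyed: String => (Int, String) = length &&& identity

  def main(args: Array[String]): Unit =
    println(keyed("ant")) // (3,ant)
}
```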

But what we want to do is apply this function (an Arrow, more precisely) to the list of elements that is our input. That's exactly what listArr does! From a function taking a single element, it creates an Arrow (read "enhanced function") taking a list.
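
On plain functions, lifting a per-element function to a function over lists amounts to mapping it over the list. Here is a hypothetical stand-in for listArr (the real listArr returns an Arrow; this sketch simply returns a function):

```scala
object ListArrDemo {
  // Hypothetical stand-in for listArr: apply f to every element of a list
  def listArr[A, B](f: A => B): List[A] => List[B] = _.map(f)

  // (f &&& identity) written out by hand, with f = length
  val keyed: String => (Int, String) = s => (s.length, s)

  val keyedList: List[String] => List[(Int, String)] = listArr(keyed)

  def main(args: Array[String]): Unit =
    println(keyedList(List("ant", "part", "all"))) // List((3,ant), (4,part), (3,all))
}
```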

Reading the whole expression in detail

With all that knowledge, let's read again the definition now with our example as an illustration. 

The input data is:

List("ant", "part", "all")

"take the list of elements and return (f(x), x) for each element":
listArr(f &&& identity)
=> List((3, "ant"), (4, "part"), (3, "all"))

"then group by the first element", i.e. create a list of pairs where the first element is a result k of f and the second element is the list of all the elements y such that f(y) == k:
>>> groupByFirst
=> List((4, List("part")), (3, List("all", "ant")))

"then take the second element"
>>> second
=> List(List("part"), List("all", "ant"))
Note that the "second" function defined here just takes the second element of a pair. This means that the >>> operator is smart enough to know that we're operating on lists of pairs and not on a single pair!

"then reverse the list"
>>> reverse
=> List(List("all", "ant"), List("part"))
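
The whole reading can be replayed with plain Scala functions, with andThen playing the role of >>>. This is a sketch under a few assumptions: &&& and listArr are re-implemented here on plain functions; groupByFirst uses an association list instead of a HashMap, prepending new keys so that the order of the groups is deterministic (4 before 3 on this input) rather than depending on HashMap iteration order; and second and reverse are written to work on whole lists. Type arguments are spelled out to keep inference simple.

```scala
object ClusterByPipeline {
  implicit class FanOutOps[A, B](f: A => B) {
    // f &&& g: pair the results of f and g on the same input
    def &&&[C](g: A => C): A => (B, C) = a => (f(a), g(a))
  }

  // Apply a per-element function to a whole list
  def listArr[A, B](f: A => B): List[A] => List[B] = _.map(f)

  // Group values by key in an association list. New keys and new values are
  // both prepended, so the example input gives List((4, List(part)), (3, List(all, ant)))
  def groupByFirst[K, V](pairs: List[(K, V)]): List[(K, List[V])] =
    pairs.foldLeft(List.empty[(K, List[V])]) { case (groups, (k, v)) =>
      if (groups.exists(_._1 == k))
        groups.map { case (gk, vs) => if (gk == k) (gk, v :: vs) else (gk, vs) }
      else (k, List(v)) :: groups
    }

  // Keep the second component of every pair in a list
  def second[A, B](pairs: List[(A, B)]): List[B] = pairs.map(_._2)

  def reverse[T](l: List[T]): List[T] = l.reverse

  def clusterBy[A, B](l: List[A])(f: A => B): List[List[A]] = {
    val pipeline: List[A] => List[List[A]] =
      listArr(f &&& identity)
        .andThen(groupByFirst[B, A])
        .andThen(second[B, List[A]])
        .andThen(reverse[List[A]])
    pipeline(l)
  }

  def main(args: Array[String]): Unit =
    println(clusterBy(List("ant", "part", "all"))(_.length)) // List(List(all, ant), List(part))
}
```

With this groupByFirst, the final reverse turns "groups in reverse order of first appearance" back into "groups in order of first appearance".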

The hidden stuff under the carpet

To be able to create this nice one-liner in Scala, I had to create special-purpose functions: groupByFirst, reverse, second.

First of all, you can argue that the real meat of the clusterBy function is actually the groupByFirst function:
// Groups (key, value) pairs by key. Values are prepended, so each group
// lists its values in reverse order of appearance.
def groupByFirst[A, B] = new Function1[List[(A, B)], List[(A, List[B])]] {
  def apply(l: List[(A, B)]) =
    l.foldLeft(Map.empty[A, List[B]]) { (res, cur) =>
      val (key, value) = cur
      res.updated(key, value :: res.getOrElse(key, Nil))
    }.toList
}

And actually, if you look at Debasish's blog, there are full Scala solutions that take about as much space as the groupByFirst function alone. On the other hand, I find that the Arrows notation shows the process flow between "elementary" functions pretty well.
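
For comparison, Scala 2.13 and later provide groupMap in the standard collections, which covers most of groupByFirst in a single call (this method did not exist in 2009). A sketch, with two caveats: groupMap keeps the values in their original order (ant before all), and it returns a Map whose iteration order is unspecified, so the keys are sorted here to get a stable result.

```scala
object GroupMapDemo {
  // Group pairs by their first component, keeping the second components
  def groupByFirst[A: Ordering, B](pairs: List[(A, B)]): List[(A, List[B])] =
    pairs.groupMap(_._1)(_._2).toList.sortBy(_._1)

  def main(args: Array[String]): Unit =
    println(groupByFirst(List((3, "ant"), (4, "part"), (3, "all")))) // List((3,List(ant, all)), (4,List(part)))
}
```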

You may then wonder why I couldn't simply "detach" the reverse operation from the List class, maybe like this:

def reverseList[T](l: List[T]) = l.reverse
def reverse[T] = reverseList _

However, this just doesn't work: type inference doesn't seem to be constrained enough when the functions are sequenced with the >>> operator. The only way I found to make it work was to define a Function1 object:
def reverse[T] = new Function1[List[T], List[T]] {
  def apply(l: List[T]) = l.reverse
}
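
In current Scala, a Function1 with the same type can be written as a function literal, and a method can be turned into a properly typed function by giving the type argument explicitly. This is a sketch on plain functions only; it assumes (without having checked it against the 2009 compiler or the Arrows library) that the problem with reverseList _ is that T is not fixed by anything at the point of eta-expansion.

```scala
object ReverseDemo {
  def reverseList[T](l: List[T]): List[T] = l.reverse

  // Eta-expansion with an explicit type argument: List[T] => List[T]
  def reverseEta[T]: List[T] => List[T] = reverseList[T]

  // A function literal with the same type as the Function1 object
  def reverse[T]: List[T] => List[T] = _.reverse

  // Sequencing with andThen, giving the type argument at the use site
  val reverseThenHead: List[Int] => Option[Int] = reverse[Int] andThen (_.headOption)

  def main(args: Array[String]): Unit =
    println(reverseThenHead(List(1, 2, 3))) // Some(3)
}
```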

This last issue is perhaps the biggest drawback of using Arrows with Scala. So now I can propose my own challenge to the Scala community: can you find a better way? ;-)