Higher Order Functions (HOFs)

  • Functional languages treat functions as first-class values: like any other value, a function can be passed as a parameter and returned as a result.

  • This provides a flexible way to compose programs.

  • Functions that take other functions as parameters or return functions as results are called higher-order functions.
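Here is a small Scala sketch showing both kinds of higher-order function. The names `HofBasics`, `twice` and `adder` are hypothetical and exist only for this illustration. `twice` takes a function and returns a function. `adder` returns a function that remembers its argument `n`.

```scala
// Hypothetical example object illustrating higher-order functions.
object HofBasics {
  // Takes a function as a parameter AND returns a function as its result:
  // the returned function applies f two times in a row.
  def twice(f: Int => Int): Int => Int =
    x => f(f(x))

  // Returns a function as its result; the returned function remembers n.
  def adder(n: Int): Int => Int =
    x => x + n

  def main(args: Array[String]): Unit = {
    val addThree = adder(3)          // a function value, stored like any other value
    println(addThree(10))            // prints 13
    println(twice(addThree)(10))     // prints 16, i.e. (10 + 3) + 3
  }
}
```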

In [1]:
def sumInts(a: Int, b: Int): Int = 
    if (a > b) 0 else a + sumInts(a + 1, b)

def cube(x: Int): Int = 
    x * x * x

def fact(x: Int): Int = 
    if (x == 0) 1 else x * fact(x-1)

def sumCubes(a: Int, b: Int): Int =
    if (a > b) 0 else cube(a) + sumCubes(a + 1, b)

def sumFactorials(a: Int, b: Int): Int =
    if (a > b) 0 else fact(a) + sumFactorials(a + 1, b)

sumInts(3,6)
sumCubes(3,6)
sumFactorials(3,6)
Out[1]:
defined function sumInts
defined function cube
defined function fact
defined function sumCubes
defined function sumFactorials
res0_5: Int = 18
res0_6: Int = 432
res0_7: Int = 870

The three functions differ only in the function applied to each term (identity, cube, factorial). Can we factor that function out and reduce all three to a single function?

In [2]:
def sum(f: Int => Int, a: Int, b: Int): Int =
    if (a > b) 0 else f(a) + sum(f, a+1, b)

def sumInts(a: Int, b: Int): Int = 
    sum(id, a, b)

def sumCubes(a: Int, b: Int): Int = 
    sum(cube, a, b)

def sumFactorials(a: Int, b: Int): Int = 
    sum(fact, a, b)

// where id is the identity function, x => x
def id(x: Int): Int = 
    x

def cube(x: Int): Int = 
    x * x * x

def fact(x: Int): Int = 
    if (x == 0) 1 else x * fact(x - 1)

sumInts(3,6)
sumCubes(3,6)
sumFactorials(3,6)
Out[2]:
defined function sum
defined function sumInts
defined function sumCubes
defined function sumFactorials
defined function id
defined function cube
defined function fact
res1_7: Int = 18
res1_8: Int = 432
res1_9: Int = 870

The type A => B is the type of a function that takes an argument of type A and returns a result of type B.

So, Int => Int is the type of functions that map integers to integers.
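To make the notation concrete, here are a few function values annotated with their types. This is a sketch, and the object and value names are hypothetical. In Scala, `A => B` is shorthand for the standard trait `Function1[A, B]`, and `(A, B) => C` is shorthand for `Function2[A, B, C]`.

```scala
// Hypothetical examples of function types.
object FunctionTypes {
  // Int => Int: takes one Int, returns an Int
  val square: Int => Int = x => x * x

  // The same type written out as the underlying trait
  val squareAgain: Function1[Int, Int] = square

  // (Int, Int) => Int: takes two Ints, returns an Int
  val add: (Int, Int) => Int = (x, y) => x + y

  // (Int => Int) => Int: takes a function, returns an Int
  val applyToTen: (Int => Int) => Int = f => f(10)

  // Int => Boolean: the result type need not match the argument type
  val isEven: Int => Boolean = x => x % 2 == 0
}
```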

Anonymous Functions

  • Passing functions as parameters leads to the creation of many small functions.

  • It is tedious to have to define (using def) and name these functions.

  • Can we not have function literals just like String literals?

e.g.

def str = "abc";
println(str);

// vs

println("abc")
  • Anonymous functions are exactly that: function literals, written without giving the function a name.
In [3]:
// Examples of anonymous functions

(x: Int) => x * x * x

(x: Int, y: Int) => x + y

// parameters to the left of => and body of function to the right of =>

def sumInts(a: Int, b: Int) = 
    sum(x => x, a, b)

def sumCubes(a: Int, b: Int) = 
    sum(x => x * x * x, a, b)

sumInts(3,6)
sumCubes(3,6)
Out[3]:
res2_0: Int => Int = ammonite.$sess.cmd2$Helper$$Lambda$1855/0x00000008014a4840@16b78a8a
res2_1: (Int, Int) => Int = ammonite.$sess.cmd2$Helper$$Lambda$1856/0x00000008014a5040@36fc630b
defined function sumInts
defined function sumCubes
res2_4: Int = 18
res2_5: Int = 432

So, in general, the following anonymous function

(x1: T1, x2: T2, ..., xn: Tn) => E

can be expressed using def as follows:

{ 
    def f(x1: T1, x2: T2, ..., xn: Tn) = 
      E; 
    f 
}

where f is a fresh name that is not already used in the program.
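To check the equivalence on a concrete case, the sketch below builds the cube function both ways. The names are hypothetical. The block ends with the method name `f`, and Scala turns it into a function value because the expected type `Int => Int` is stated.

```scala
// Hypothetical example: an anonymous function vs. the equivalent def-block.
object AnonymousVsDef {
  // Written as a function literal
  val cubeLiteral: Int => Int = (x: Int) => x * x * x

  // Written with a local def; the block's last expression, f, is converted
  // to a function value because the expected type Int => Int is known.
  val cubeViaDef: Int => Int = {
    def f(x: Int): Int = x * x * x
    f
  }
}
```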

Products and Factorials

In [7]:
def sum(a: Int, b: Int): Int = 
    if (a > b) 0 else a + sum(a+1,b)

def product(a: Int, b: Int): Int = 
    if (a > b) 1 else a * product(a+1,b)

def factorial(n: Int): Int = 
    product(1,n)

object SumProduct {
    def operate(f: (Int, Int)=>Int, ident: Int, a: Int, b: Int): Int =
        if (a > b) ident else f(a, operate(f, ident, a+1,b))

    def sum(a: Int, b: Int): Int =
        operate((x, y)=>x+y, 0, a, b)

    def product(a: Int, b: Int): Int =
        operate((x, y)=>x*y, 1, a, b)

    def main(args: Array[String]): Unit = {
        println(sum(1, 6))
        println(product(1, 6))
    }
}

SumProduct.main(Array())
21
720
Out[7]:
defined function sum
defined function product
defined function factorial
defined object SumProduct
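The pattern in `operate` can be taken one step further. If each integer is also mapped through a function before the results are combined, it subsumes the earlier `sum(f, a, b)` as well. The sketch below uses a hypothetical name, `mapReduce`. It shows one possible generalization and is not part of the original notes.

```scala
// Hypothetical generalization of `operate` that also maps each term.
object MapReduce {
  // Apply f to every integer in [a, b] and combine the results with `combine`;
  // `zero` is the result for an empty range (a > b).
  def mapReduce(f: Int => Int, combine: (Int, Int) => Int, zero: Int, a: Int, b: Int): Int =
    if (a > b) zero
    else combine(f(a), mapReduce(f, combine, zero, a + 1, b))

  def sum(f: Int => Int, a: Int, b: Int): Int =
    mapReduce(f, (x, y) => x + y, 0, a, b)

  def product(f: Int => Int, a: Int, b: Int): Int =
    mapReduce(f, (x, y) => x * y, 1, a, b)

  def factorial(n: Int): Int =
    product(x => x, 1, n)
}
```

The `zero` argument must be the identity element of `combine`: 0 for addition and 1 for multiplication. That is why an empty range such as `factorial(0)`, which computes `product(x => x, 1, 0)`, correctly yields 1.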

Currying

Motivation:

Recall the following definitions:

def sum(f: Int => Int, a: Int, b: Int): Int =
    if (a > b) 0 else f(a) + sum(f, a+1, b)

def sumInts(a: Int, b: Int): Int = 
    sum(x => x, a, b)

def sumCubes(a: Int, b: Int): Int = 
    sum(x => x*x*x, a, b)

def sumFactorials(a: Int, b: Int): Int = 
    sum(fact, a, b)
  • The parameters a and b are passed on to sum() unchanged. Can we get rid of them?
def sum(f: Int => Int): (Int, Int) => Int = {
  def sumF(a: Int, b: Int): Int = 
      if (a > b) 0 else f(a) + sumF(a+1,b)
  sumF
}

sum is a function that returns another function, sumF.

sumF applies the function parameter f to every integer from a to b and sums the results.

Currying Continued

With the new definition of sum, we can define

def sumInts = sum(x => x)

def sumCubes = sum(x => x*x*x)

def sumFactorials = sum(fact)

and use them as follows:

sumCubes(1,10) + sumFactorials(1,5)

We can even avoid the middlemen sumInts, sumCubes, etc.:

sum(cube)(1,10)

sum(cube) returns the sum-of-cubes function, which is then applied to the arguments (1,10).

Function applications associate to the left:

sum(cube)(1,10) = (sum(cube))(1,10)
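The following self-contained sketch collects the definitions from this section and checks the examples above. It is wrapped in a hypothetical object, `SumReturnsFunction`. `sum` returns a function, and that function is then applied to `(a, b)`.

```scala
// Hypothetical wrapper collecting this section's definitions.
object SumReturnsFunction {
  // sum takes f and returns a new function of (a, b)
  def sum(f: Int => Int): (Int, Int) => Int = {
    def sumF(a: Int, b: Int): Int =
      if (a > b) 0 else f(a) + sumF(a + 1, b)
    sumF
  }

  def cube(x: Int): Int = x * x * x
  def fact(x: Int): Int = if (x == 0) 1 else x * fact(x - 1)

  val sumCubes: (Int, Int) => Int = sum(cube)
  val sumFactorials: (Int, Int) => Int = sum(fact)

  // sumCubes(1, 10) == 3025 (that is, 55 * 55)
  // sumFactorials(1, 5) == 1 + 2 + 6 + 24 + 120 == 153
}
```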

Currying Continued

Scala provides special syntax for this (the following is equivalent to the definition with the nested sumF):

def sum(f: Int => Int)(a: Int, b: Int): Int =
    if (a > b) 0 else f(a) + sum(f)(a+1,b)
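The sketch below shows the multiple-parameter-list form in use. The object and value names are hypothetical. If you supply only the first parameter list and state an expected function type, the result is a function. In Scala 2 you can instead write `sum(cube) _`, while Scala 3 performs this conversion automatically.

```scala
// Hypothetical example of Scala's multiple-parameter-list syntax.
object MultipleParamLists {
  // Curried definition: one parameter list for f, another for (a, b)
  def sum(f: Int => Int)(a: Int, b: Int): Int =
    if (a > b) 0 else f(a) + sum(f)(a + 1, b)

  // Supplying only the first list; the expected type (Int, Int) => Int
  // tells the compiler to produce a function value.
  val sumSquares: (Int, Int) => Int = sum(x => x * x)
}
```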

In general,

def f(args1)...(argsn) = 
    E

where n > 1, is equivalent to

def f(args1)...(args(n-1)) = { 
    def g(argsn) = E; 
    g
}

where g is a fresh function name. Or, even shorter:

def f(args1)...(args(n-1)) = 
    (argsn => E)

Repeating this n times, we get

def f = 
    (args1 => (args2 => ... (argsn => E) ...))

This process is called currying, after the logician Haskell Curry.
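The Scala standard library can perform this transformation for functions of two arguments. `Function2` has a `curried` method, and `Function.uncurried` converts back. The sketch below compares `curried` with a curried function written by hand. The object and value names are hypothetical.

```scala
// Hypothetical example of currying with the standard library.
object CurryingStdlib {
  // An uncurried function of two arguments
  val add: (Int, Int) => Int = (x, y) => x + y

  // Function2#curried gives Int => (Int => Int)
  val addCurried: Int => Int => Int = add.curried

  // The same curried function written by hand, one argument per arrow
  val addByHand: Int => Int => Int = x => y => x + y

  // Function.uncurried turns the curried form back into a two-argument function
  val addAgain: (Int, Int) => Int = Function.uncurried(addCurried)
}
```

Applying a curried function to its first argument alone, as in `CurryingStdlib.addCurried(1)`, yields a function that adds 1.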

Determine the type

Given

def sum(f: Int => Int)(a: Int, b: Int): Int =
    ...

what is the type of sum?

Answer:

(Int => Int) => (Int, Int) => Int

Since function types associate to the right, this can be rewritten as

(Int => Int) => ((Int, Int) => Int)
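One way to confirm this type is to write `sum` as a function value and let the compiler check the annotation. In the sketch below, `SumType` and `sumValue` are hypothetical names. The nested lambda `f => (a, b) => sum(f)(a, b)` mirrors the nesting in the type.

```scala
// Hypothetical example: sum as a value of type (Int => Int) => ((Int, Int) => Int).
object SumType {
  def sum(f: Int => Int)(a: Int, b: Int): Int =
    if (a > b) 0 else f(a) + sum(f)(a + 1, b)

  // Because => associates to the right, this annotation could equally be
  // written (Int => Int) => (Int, Int) => Int.
  val sumValue: (Int => Int) => ((Int, Int) => Int) =
    f => (a, b) => sum(f)(a, b)
}
```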

Scala program to compute Fixpoints

A number x is called a fixed point of a function f if

f(x) = x

Fixed points are a very useful concept in computer science! For some functions, we can locate the fixed point by starting with an initial estimate and then applying f repeatedly:

x, f(x), f(f(x)), f(f(f(x))), …

until the value no longer changes (or the change is sufficiently small).

In [5]:
val tolerance = 0.0001

def isCloseEnough(x: Double, y: Double) = 
    scala.math.abs((x-y)/x) < tolerance

def fixpoint(f: Double => Double)(firstGuess: Double): Double = {
    def iterate(guess: Double): Double = {
        val next = f(guess)
        // println(next)  // uncomment to trace each successive guess
        if (isCloseEnough(guess, next)) next else iterate(next)
    }
    iterate(firstGuess)
}
Out[5]:
tolerance: Double = 1.0E-4
defined function isCloseEnough
defined function fixpoint
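An easy way to try `fixpoint` is on a function whose fixed point is known. The cosine function has a single fixed point near 0.739, the solution of cos(x) = x, and iterating it from 1.0 converges to that value. The sketch below repeats the definitions so that it stands alone. It adds a hypothetical `trace` flag in place of the commented-out `println`. It also marks `iterate` with `@tailrec`, so the compiler confirms that the loop runs in constant stack space.

```scala
// Hypothetical standalone copy of fixpoint, with an optional trace flag.
object FixpointDemo {
  val tolerance = 0.0001

  def isCloseEnough(x: Double, y: Double): Boolean =
    math.abs((x - y) / x) < tolerance

  def fixpoint(f: Double => Double, trace: Boolean = false)(firstGuess: Double): Double = {
    @scala.annotation.tailrec
    def iterate(guess: Double): Double = {
      val next = f(guess)
      if (trace) println(next)   // print each guess when tracing is on
      if (isCloseEnough(guess, next)) next else iterate(next)
    }
    iterate(firstGuess)
  }
}
```

`FixpointDemo.fixpoint(math.cos, trace = true)(1.0)` prints each successive guess on its way to roughly 0.739.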

sqrt function as a fixpoint

sqrt(x) can be expressed in terms of a fixpoint as follows:

sqrt(x)

= the number y such that y * y = x

= the number y such that y = x/y

= a fixed point of the function (y => x/y)

def sqrt(x: Double) = 
    fixpoint(y => x/y)(1.0)

So, sqrt(2) (i.e. x = 2) would be computed as the following sequence of guesses, y:

1.0, 2.0, 1.0, ...

There is a problem! The guesses oscillate between 1.0 and 2.0 forever, so the iteration never converges.
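The oscillation is easy to reproduce. The sketch below uses the standard `Iterator.iterate` to list the first few guesses of the undamped iteration; the object and method names are hypothetical. For any nonzero x, the guesses alternate between 1.0 and x, because x / 1.0 = x and x / x = 1.0.

```scala
// Hypothetical helper: the first n guesses of the undamped iteration y => x / y.
object Oscillation {
  def guesses(x: Double, n: Int): List[Double] =
    Iterator.iterate(1.0)(y => x / y).take(n).toList

  // guesses(2.0, 5) == List(1.0, 2.0, 1.0, 2.0, 1.0)
}
```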

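To fix this, use the function (y => (y + x/y)/2), which averages the current guess y with the undamped next guess x/y. It has the same fixed point, since y = (y + x/y)/2 holds exactly when y = x/y.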

In [8]:
def sqrt(x: Double) = 
    fixpoint(y => (y + x/y)/2)(1.0)

sqrt(5)
Out[8]:
defined function sqrt
res7_1: Double = 2.236067977499978
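For comparison, the sketch below lists the first few damped guesses for x = 2; the names are hypothetical. Instead of oscillating, the guesses settle quickly: 1.0, 1.5, about 1.4167, and so on toward √2. This particular averaged update coincides with Newton's method for square roots, which explains the fast convergence.

```scala
// Hypothetical helper: the first n guesses of the damped iteration y => (y + x / y) / 2.
object DampedGuesses {
  def guesses(x: Double, n: Int): List[Double] =
    Iterator.iterate(1.0)(y => (y + x / y) / 2).take(n).toList

  // guesses(2.0, 2) == List(1.0, 1.5); later guesses approach math.sqrt(2.0)
}
```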

The iterative algorithm converges to a solution by averaging each guess with the value f produces from it. This technique of stabilizing by averaging can itself be captured in a (higher-order) function!

def averageDamp(f: Double => Double)(x: Double) = 
    (x + f(x)) / 2

def sqrt(x: Double) = 
    fixpoint(averageDamp(y => x/y))(1.0)

This expresses the algorithm precisely: sqrt(x) is the fixed point of the average-damped version of y => x/y.

In [7]:
def averageDamp(f: Double => Double)(x: Double) = 
    (x + f(x)) / 2

def sqrt(x: Double) = 
    fixpoint(averageDamp(y => x/y))(1.0)

sqrt(2)
Out[7]:
defined function averageDamp
defined function sqrt
res6_2: Double = 1.4142135623746899
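The same building blocks also work beyond square roots. As an illustration not taken from these notes, the cube root of x satisfies y = x / (y * y), and average damping that function is enough to make the iteration converge. The sketch below repeats `fixpoint` and `averageDamp` so that it runs on its own. `CubeRoot` and `cubeRoot` are hypothetical names.

```scala
// Hypothetical extension of the averageDamp technique to cube roots.
object CubeRoot {
  val tolerance = 0.0001

  def isCloseEnough(x: Double, y: Double): Boolean =
    math.abs((x - y) / x) < tolerance

  def fixpoint(f: Double => Double)(firstGuess: Double): Double = {
    @scala.annotation.tailrec
    def iterate(guess: Double): Double = {
      val next = f(guess)
      if (isCloseEnough(guess, next)) next else iterate(next)
    }
    iterate(firstGuess)
  }

  def averageDamp(f: Double => Double)(x: Double): Double =
    (x + f(x)) / 2

  // The cube root of x is a fixed point of y => x / (y * y);
  // without damping this iteration does not settle.
  def cubeRoot(x: Double): Double =
    fixpoint(averageDamp(y => x / (y * y)))(1.0)
}
```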