Higher Order Functions (HOFs)

  • Functional languages treat functions as first-class values, i.e., like any other value, a function can be passed as a parameter and returned as a result.

  • This provides a flexible way to compose programs.

  • Functions that take other functions as parameters or return functions as results are called higher-order functions.
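As a small illustration, the function below takes a function as a parameter and returns a new function as its result, so it is higher-order in both senses. This is only a sketch: `twice`, `addThree` and `addSix` are hypothetical names not used elsewhere in these notes.

```scala
// twice takes a function f and returns a new function that applies f two times
def twice(f: Int => Int): Int => Int =
  x => f(f(x))

// an ordinary named function, passed as an argument below
def addThree(x: Int): Int = x + 3

// addSix is a function value returned as the result of twice
val addSix = twice(addThree)

println(addSix(10))  // prints 16
```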

In [1]:
def sumInts(a: Int, b: Int): Int = 
    if (a > b) 0 else a + sumInts(a + 1, b)

def cube(x: Int): Int = 
    x * x * x

def fact(x: Int): Int = 
    if (x == 0) 1 else x * fact(x-1)

def sumCubes(a: Int, b: Int): Int =
    if (a > b) 0 else cube(a) + sumCubes(a + 1, b)

def sumFactorials(a: Int, b: Int): Int =
    if (a > b) 0 else fact(a) + sumFactorials(a + 1, b)

sumInts(3, 6)
sumCubes(3, 6)
sumFactorials(3, 6)

defined function sumInts
defined function cube
defined function fact
defined function sumCubes
defined function sumFactorials
res0_5: Int = 18
res0_6: Int = 432
res0_7: Int = 870

Can we factor out the function applied to each term (identity, cube, factorial) and reduce all of these to a single function?

In [2]:
def sum(f: Int => Int, a: Int, b: Int): Int =
    if (a > b) 0 else f(a) + sum(f, a+1, b)

def sumInts(a: Int, b: Int): Int = 
    sum(id, a, b)

def sumCubes(a: Int, b: Int): Int = 
    sum(cube, a, b)

def sumFactorials(a: Int, b: Int): Int = 
    sum(fact, a, b)

// where id is the identity function (lambda x. x)
def id(x: Int): Int = 
    x

def cube(x: Int): Int = 
    x * x * x

def fact(x: Int): Int = 
    if (x == 0) 1 else x * fact(x - 1)

sumInts(3, 6)
sumCubes(3, 6)
sumFactorials(3, 6)

defined function sum
defined function sumInts
defined function sumCubes
defined function sumFactorials
defined function id
defined function cube
defined function fact
res1_7: Int = 18
res1_8: Int = 432
res1_9: Int = 870

The type A => B is the type of a function that takes an argument of type A and returns a result of type B.

So, Int => Int is the type of functions that map integers to integers.
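To make the notation concrete, here is a sketch of function values written with explicit type annotations. The names `square`, `add` and `applyTo` are illustrative only.

```scala
// A value of type Int => Int: takes an Int, returns an Int
val square: Int => Int = x => x * x

// A value of type (Int, Int) => Int: takes two Ints, returns an Int
val add: (Int, Int) => Int = (x, y) => x + y

// A method whose first parameter has the function type Int => Int
def applyTo(f: Int => Int, n: Int): Int = f(n)

println(applyTo(square, 7))  // prints 49
println(add(2, 3))           // prints 5
```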

Anonymous Functions

  • Passing functions as parameters leads to the creation of many small functions.

  • It is tedious to have to define (using def) and name these functions.

  • Can we not have function literals, just as we have String literals?


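def str = "abc"   // defining and naming a String before using it

// vs. writing the literal "abc" directly wherever a String is needed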

  • Anonymous functions are function literals: they let us write a function without having to name it.
In [3]:
// Examples of anonymous functions

(x: Int) => x * x * x

(x: Int, y: Int) => x + y

// parameters go to the left of =>, the body of the function to the right of =>

def sumInts(a: Int, b: Int) = 
    sum(x => x, a, b)

def sumCubes(a: Int, b: Int) = 
    sum(x => x * x * x, a, b)

sumInts(3, 6)
sumCubes(3, 6)

res2_0: Int => Int = ammonite.$sess.cmd2$Helper$$Lambda$1855/0x00000008014a4840@16b78a8a
res2_1: (Int, Int) => Int = ammonite.$sess.cmd2$Helper$$Lambda$1856/0x00000008014a5040@36fc630b
defined function sumInts
defined function sumCubes
res2_4: Int = 18
res2_5: Int = 432

So, in general, the anonymous function

(x1: T1, x2: T2, ..., xn: Tn) => E

can be expressed using def as follows:

    { def f(x1: T1, x2: T2, ..., xn: Tn) = E; f }

where f is a fresh name that is not used elsewhere in the program.
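The equivalence between an anonymous function and a locally defined, named method can be checked on a concrete case. In this sketch `viaLiteral` and `viaDef` are hypothetical names; `f` is a fresh local name, and the expected type Int => Int lets Scala convert the local method f into a function value.

```scala
// Anonymous function written directly as a literal
val viaLiteral: Int => Int = (x: Int) => x * x * x

// The same function obtained by defining a local named method and returning it
val viaDef: Int => Int = { def f(x: Int): Int = x * x * x; f }

println(viaLiteral(3))  // prints 27
println(viaDef(3))      // prints 27
```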

Products and Factorials

In [7]:
def sum(a: Int, b: Int): Int = 
    if (a > b) 0 else a + sum(a+1,b)

def product(a: Int, b: Int): Int = 
    if (a > b) 1 else a * product(a+1,b)

def factorial(n: Int): Int = 
    product(1, n)

object SumProduct {
    def operate(f: (Int, Int)=>Int, ident: Int, a: Int, b: Int): Int =
        if (a > b) ident else f(a, operate(f, ident, a+1,b))

    def sum(a: Int, b: Int): Int =
        operate((x, y)=>x+y, 0, a, b)

    def product(a: Int, b: Int): Int =
        operate((x, y)=>x*y, 1, a, b)

    def main(args: Array[String]): Unit = {
        println(sum(1, 6))      // 21
        println(product(1, 6))  // 720
    }
}

defined function sum
defined function product
defined function factorial
defined object SumProduct
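operate combines the integers a..b themselves. A natural next step, sketched here on the assumption that we want to merge it with the earlier sum(f, a, b), is to also apply a function f to each integer before combining. The names `mapReduce`, `sumOf`, `productOf` and the parameter order are our own choice, not part of the original code.

```scala
// Apply f to every integer in [a, b] and fold the results with combine,
// starting from the identity element zero (like ident in operate)
def mapReduce(f: Int => Int, combine: (Int, Int) => Int, zero: Int, a: Int, b: Int): Int =
  if (a > b) zero
  else combine(f(a), mapReduce(f, combine, zero, a + 1, b))

def sumOf(f: Int => Int, a: Int, b: Int): Int =
  mapReduce(f, (x, y) => x + y, 0, a, b)

def productOf(f: Int => Int, a: Int, b: Int): Int =
  mapReduce(f, (x, y) => x * y, 1, a, b)

// factorial is the product of the identity function over [1, n]
def factorial(n: Int): Int = productOf(x => x, 1, n)

println(sumOf(x => x, 1, 6))      // prints 21
println(productOf(x => x, 1, 6))  // prints 720
println(factorial(5))             // prints 120
```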



Recall the following definitions:

def sum(f: Int => Int, a: Int, b: Int): Int =
    if (a > b) 0 else f(a) + sum(f, a+1, b)

def sumInts(a: Int, b: Int): Int = 
    sum(x => x, a, b)

def sumCubes(a: Int, b: Int): Int = 
    sum(x => x*x*x, a, b)

def sumFactorials(a: Int, b: Int): Int = 
    sum(fact, a, b)
  • The parameters a and b are passed on to sum unchanged. Can we get rid of them?

def sum(f: Int => Int): (Int, Int) => Int = {
  def sumF(a: Int, b: Int): Int = 
      if (a > b) 0 else f(a) + sumF(a+1,b)
  sumF
}

sum is a function that returns another function, sumF.

sumF applies the function parameter f to each integer from a to b and sums the results.
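A complete version of this nested definition can be exercised as follows. This is a sketch; `cube` is redefined here so the block runs on its own.

```scala
// sum takes only the function f and returns the local function sumF,
// which still expects the interval bounds a and b
def sum(f: Int => Int): (Int, Int) => Int = {
  def sumF(a: Int, b: Int): Int =
    if (a > b) 0 else f(a) + sumF(a + 1, b)
  sumF
}

def cube(x: Int): Int = x * x * x

val sumInts = sum(x => x)  // a function of type (Int, Int) => Int
val sumCubes = sum(cube)

println(sumInts(3, 6))   // prints 18
println(sumCubes(3, 6))  // prints 432
```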

Currying Continued

With the new definition of sum, we can define

def sumInts = sum(x => x)

def sumCubes = sum(x => x*x*x)

def sumFactorials = sum(fact)

and use them as follows:

sumCubes(1,10) + sumFactorials(1,5)

We can even avoid the middlemen sumInts, sumCubes, etc.:

sum(cube)(1,10)

sum(cube) returns the sum-of-cubes function, and this function is then applied to the arguments (1,10).

Function applications associate to the left:

sum(cube)(1,10) = (sum(cube))(1,10)
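The left-associative reading can be made explicit by naming the intermediate function. In this sketch `sumOfCubes` is a hypothetical name for the value returned by sum(cube); the sum of the cubes from 1 to 10 is 3025.

```scala
def sum(f: Int => Int): (Int, Int) => Int = {
  def sumF(a: Int, b: Int): Int =
    if (a > b) 0 else f(a) + sumF(a + 1, b)
  sumF
}

def cube(x: Int): Int = x * x * x

// Step 1: sum(cube) evaluates to a function of type (Int, Int) => Int
val sumOfCubes: (Int, Int) => Int = sum(cube)

// Step 2: that function is applied to (1, 10)
println(sumOfCubes(1, 10))  // prints 3025

// The same two steps written as one expression
println(sum(cube)(1, 10))   // prints 3025
```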

Currying Continued

Scala has special syntax for this. The following definition is equivalent to the one using the nested sumF:

def sum(f: Int => Int)(a: Int, b: Int): Int =
    if (a > b) 0 else f(a) + sum(f)(a+1,b)
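With multiple parameter lists, supplying only the first list still yields a function. In this sketch the value `sumCubes` carries an explicit type annotation, which makes Scala turn the partially applied method into a function value.

```scala
// Curried definition using two parameter lists
def sum(f: Int => Int)(a: Int, b: Int): Int =
  if (a > b) 0 else f(a) + sum(f)(a + 1, b)

def cube(x: Int): Int = x * x * x

// Supplying only the first parameter list; the expected function type
// converts the partially applied method into a function value
val sumCubes: (Int, Int) => Int = sum(cube)

println(sum(cube)(1, 10))  // both argument lists at once: prints 3025
println(sumCubes(1, 10))   // prints 3025
```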

In general, a definition

def f(args1)...(argsn) = E

where n > 1, is equivalent to

def f(args1)...(argsn-1) = { 
    def g(argsn) = E; g
}

where g is a fresh function name, or, even shorter:

def f(args1)...(argsn-1) = 
    (argsn => E)

Repeating this n times, we get

def f = 
    (args1 => (args2 => ... (argsn => E) ... ))

This process is referred to as “Currying”.
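Here is a small sketch of the same transformation on a three-argument function; `add3`, `add3Curried`, `plus` and `plusCurried` are illustrative names. The standard library's `curried` method on a two-argument function value performs the analogous conversion.

```scala
// Multiple parameter lists: def f(args1)(args2)(args3) = E
def add3(x: Int)(y: Int)(z: Int): Int = x + y + z

// The fully curried form: a chain of one-argument anonymous functions
val add3Curried: Int => Int => Int => Int =
  x => y => z => x + y + z

// Function2#curried turns (Int, Int) => Int into Int => Int => Int
val plus: (Int, Int) => Int = (x, y) => x + y
val plusCurried: Int => Int => Int = plus.curried

println(add3(1)(2)(3))         // prints 6
println(add3Curried(1)(2)(3))  // prints 6
println(plusCurried(4)(5))     // prints 9
```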

Determine the type


def sum(f: Int => Int)(a: Int, b: Int): Int =
    if (a > b) 0 else f(a) + sum(f)(a+1,b)

What is the type of sum?


(Int => Int) => (Int, Int) => Int

Since function types associate to the right, this can be rewritten as

(Int => Int) => ((Int, Int) => Int)
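One way to see this type in code is to build a value that has exactly that type by supplying sum's parameter lists one at a time. This is a sketch; `sumAsValue` and `sumSquares` are hypothetical names.

```scala
def sum(f: Int => Int)(a: Int, b: Int): Int =
  if (a > b) 0 else f(a) + sum(f)(a + 1, b)

// The parentheses on the right of the outer => are optional,
// because function types associate to the right
val sumAsValue: (Int => Int) => ((Int, Int) => Int) =
  f => (a, b) => sum(f)(a, b)

// Applying it one argument list at a time
val sumSquares: (Int, Int) => Int = sumAsValue(x => x * x)
println(sumSquares(1, 3))  // prints 14
```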

Scala program to compute Fixpoints

A number x is called a fixed point of a function f if

f(x) = x

Fixed points are a very useful concept in computer science! For some functions, we can locate a fixed point by starting with an initial estimate and then applying f repeatedly:

x, f(x), f(f(x)), f(f(f(x))), …

until the value no longer varies (or the change is sufficiently small).
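The sequence x, f(x), f(f(x)), ... can be generated lazily with the standard library's `Iterator.iterate`. As an illustration (our own choice of function, not from these notes), take f(x) = 1 + x/2, whose fixed point is 2 since 1 + 2/2 = 2:

```scala
// Lazily produce x, f(x), f(f(x)), ... and keep the first five estimates
val f: Double => Double = x => 1 + x / 2
val estimates = Iterator.iterate(1.0)(f).take(5).toList

println(estimates)  // prints List(1.0, 1.5, 1.75, 1.875, 1.9375)
```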

In [5]:
val tolerance = 0.0001

def isCloseEnough(x: Double, y: Double) = 
    scala.math.abs((x-y)/x) < tolerance

def fixpoint(f: Double => Double)(firstGuess: Double): Double = {
    def iterate(guess: Double): Double = {
        val next = f(guess)
        if (isCloseEnough(guess, next)) next else iterate(next)
    }
    iterate(firstGuess)
}

tolerance: Double = 1.0E-4
defined function isCloseEnough
defined function fixpoint
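Applied to the illustrative function f(x) = 1 + x/2 (fixed point 2), fixpoint stops once two successive estimates differ by less than the relative tolerance. This is a sketch; the definitions are repeated so that the block runs on its own, and `approx` is a hypothetical name.

```scala
val tolerance = 0.0001

// relative difference between two successive estimates
def isCloseEnough(x: Double, y: Double): Boolean =
  math.abs((x - y) / x) < tolerance

def fixpoint(f: Double => Double)(firstGuess: Double): Double = {
  def iterate(guess: Double): Double = {
    val next = f(guess)
    if (isCloseEnough(guess, next)) next else iterate(next)
  }
  iterate(firstGuess)
}

val approx = fixpoint(x => 1 + x / 2)(1.0)
println(approx)  // close to the exact fixed point 2.0
```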

sqrt function as a fixpoint

sqrt(x) can be expressed in terms of a fixpoint as follows:

sqrt(x) = the number y such that y * y = x

        = the number y such that y = x/y

        = a fixpoint of the function (y => x/y)

def sqrt(x: Double) = 
    fixpoint(y => x/y)(1.0)

So, sqrt(2) (i.e. x = 2) would be computed as the following sequence of guesses y:

1.0, 2.0, 1.0, 2.0, ...

There is a problem! The guesses oscillate between 1.0 and 2.0 forever, so the iteration never terminates.

To fix this, use the function (y => (y + x/y)/2), which takes the average of the current guess y and the next guess x/y.
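The oscillation, and the effect of averaging, can be observed by listing the first few guesses for x = 2 with the standard library's `Iterator.iterate`. This is a sketch; `undamped` and `damped` are illustrative names.

```scala
val x = 2.0

// Guesses produced by y => x / y: they bounce between 1.0 and 2.0 forever
val undamped = Iterator.iterate(1.0)(y => x / y).take(6).toList

// Guesses produced by averaging y and x / y: they settle near sqrt(2)
val damped = Iterator.iterate(1.0)(y => (y + x / y) / 2).take(6).toList

println(undamped)  // prints List(1.0, 2.0, 1.0, 2.0, 1.0, 2.0)
println(damped)
```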

In [8]:
def sqrt(x: Double) = 
    fixpoint(y => (y + x/y)/2)(1.0)

sqrt(5)

defined function sqrt
res7_1: Double = 2.236067977499978

This version converges because it averages successive values. The technique of stabilizing an iteration by averaging can itself be generalized into a function!

def averageDamp(f: Double => Double)(x: Double) = 
    (x + f(x)) / 2

def sqrt(x: Double) = 
    fixpoint(averageDamp(y => x/y))(1.0)

This expresses the algorithm precisely: the square root of x is a fixpoint of the average-damped function y => x/y.

In [7]:
def averageDamp(f: Double => Double)(x: Double) = 
    (x + f(x)) / 2

def sqrt(x: Double) = 
    fixpoint(averageDamp(y => x/y))(1.0)

sqrt(2)

defined function averageDamp
defined function sqrt
res6_2: Double = 1.4142135623746899
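Because averageDamp is an ordinary higher-order function, it can be reused beyond square roots. As a sketch that goes slightly beyond the original notes (`cubeRoot` is a hypothetical name), the cube root of x is a fixed point of y => x / (y * y), and average damping makes that iteration converge as well. The earlier definitions are repeated so that the block runs on its own.

```scala
val tolerance = 0.0001

def isCloseEnough(x: Double, y: Double): Boolean =
  math.abs((x - y) / x) < tolerance

def fixpoint(f: Double => Double)(firstGuess: Double): Double = {
  def iterate(guess: Double): Double = {
    val next = f(guess)
    if (isCloseEnough(guess, next)) next else iterate(next)
  }
  iterate(firstGuess)
}

// average the input with f's output to damp oscillations
def averageDamp(f: Double => Double)(x: Double): Double = (x + f(x)) / 2

def sqrt(x: Double): Double = fixpoint(averageDamp(y => x / y))(1.0)

// y * y * y = x  is equivalent to  y = x / (y * y) for y != 0
def cubeRoot(x: Double): Double = fixpoint(averageDamp(y => x / (y * y)))(1.0)

println(sqrt(2))       // approximately 1.41421
println(cubeRoot(27))  // approximately 3.0
```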