(* Higher-order functions -- Part 2
   Returning functions as results *)

(* curry : (('a * 'b) -> 'c) -> ('a -> 'b -> 'c)
   This is equivalent to (('a * 'b) -> 'c) -> 'a -> 'b -> 'c
   Note: arrows are right-associative.
   Turns a function on pairs into one that takes its arguments one at a time. *)
fun curry f = (fn x => fn y => f (x, y))

(* uncurry : ('a -> 'b -> 'c) -> (('a * 'b) -> 'c)
   Inverse of curry: turns a function that takes its arguments one at a time
   into one that takes them as a pair. *)
fun uncurry f = (fn (x, y) => f x y)

(* swap : ('a * 'b -> 'c) -> ('b * 'a -> 'c)
   Returns a function that expects the components of its pair argument in the
   opposite order. *)
fun swap f = (fn (x, y) => f (y, x))

(* ---------------------------------------------------------- *)

(* repeated (f, n) = f' such that f' = f^n, i.e. f composed with itself n times
   (f^0 is the identity function).
   val repeated = fn : ('a -> 'a) * int -> 'a -> 'a
   Raises Domain if n < 0: a negative count has no meaning here and would
   otherwise recurse forever once the result is applied. *)
fun repeated (f, n : int) =
    if n < 0 then raise Domain
    else if n = 0 then (fn x => x)
    else (fn x => repeated (f, n - 1) (f x))

(* Exercise: think about how you can write a function repeat which, given f and
   n, *generates* the function (fn x => f (f (... (f x)))) up front. Note that
   repeated above unfolds its recursion again every time its result is applied. *)

(* ---------------------------------------------------------- *)

(* derivative : df/dx = lim(e -> 0) (f(x + e) - f(x)) / e
   Approximately computing derivatives: df/dx ~ (f(x + dx) - f(x)) / dx
   where dx is small.
   deriv : (real -> real) * real -> (real -> real) *)
fun deriv (f, dx) = (fn x => (f (x + dx) - f x) / dx)

(* -------------------------------------------------------------------- *)
(* Partial evaluation, simple staged computation example, demonstrating
   efficiency gain: *)

(* plusSq : int -> int -> int
   plusSq = fn x => fn y => x * x + y * y *)
fun plusSq x y = x * x + y * y

(* plus3 : int -> int *)
val plus3 = (plusSq 3)
(* Note the resulting function is plus3 (y) = 3 * 3 + y * y.
   SML will not evaluate 3 * 3 to 9 and produce plus3 (y) = 9 + y * y!
   The function closure shields the function body, and the body can only be
   evaluated when we pass an argument for y. *)

(* val horriblecomputation : int -> int
   Deliberately expensive: count (large) evaluates ackermann (y, 4) large + 1
   times, and every result but the one from count (0) is multiplied by 0.
   NOTE(review): y = |x| mod 3 + 2 ranges over 2..4. For |x| mod 3 = 2 the
   calls ackermann (4, _) will not finish in practice (ackermann (4, 2) is far
   beyond the int range), so only use inputs with |x| mod 3 <> 2, such as the
   10 used below. *)
fun horriblecomputation (x : int) : int =
    let
      fun ackermann (0 : int, n : int) : int = n + 1
        | ackermann (m, 0) = ackermann (m - 1, 1)
        | ackermann (m, n) = ackermann (m - 1, ackermann (m, n - 1))
      val y = Int.abs (x) mod 3 + 2
      (* Re-evaluates ackermann (y, 4) on every step purely to burn time; the
         multiplication by 0 keeps the final value unchanged. *)
      fun count (0) = ackermann (y, 4)
        | count (n) = count (n - 1) + 0 * ackermann (y, 4)
      val large = 1000
    in
      ackermann (y, 1) * ackermann (y, 2) * ackermann (y, 3) * count (large)
    end;

(* Unstaged uncurried version: *)
(* val f1 : int * int -> int *)
fun f1 (x : int, y : int) : int =
    let val z = horriblecomputation (x) in z + y end;

(* The horrible computation is performed each time: *)
val r1 = map (fn x => f1 (10, x)) [5, 2, 18];

(* Unstaged curried version: *)
(* val f2 : int -> int -> int *)
fun f2 (x : int) (y : int) : int =
    let val z = horriblecomputation (x) in z + y end;

(* val f2' : int -> int *)
val f2' = f2 10;
(* NOTE: 10 is only substituted into the body for x, yielding
     fn y => let val z = horriblecomputation (10) in z + y end
   The closure / function shields its body from evaluation! *)

(* Hence, the horrible computation is again performed each time: *)
val r2 = map f2' [5, 2, 18, 22, 32];

(* Staged curried version: note how we now create the function (fn y => ...)
   only after the horrible computation has been done. *)
(* val f3 : int -> int -> int *)
fun f3 (x : int) : int -> int =
    let
      (* THE COMPUTATION HAPPENS HERE; ITS RESULT WILL BE SUBSTITUTED INTO THE
         BODY OF THE LET-EXPRESSION! *)
      val z = horriblecomputation (x)
    in
      (fn y => z + y)
    end;

(* The horrible computation is performed once, during the declaration of f3': *)
(* val f3' : int -> int *)
val f3' = f3 10;
val r3 = map f3' [5, 2, 18, 22, 32];

(* NOTE: f2 and f3 have exactly the same type and describe the same functional
   relationship between domain and range values, but their efficiencies are
   vastly different. *)

(* genInputs a b = [a, a + 1, ..., b], or [] when a > b.
   Both bounds are inclusive, so genInputs 1 100 yields 100 inputs.
   The list is built from the top bound downward, so no reversal is needed. *)
fun genInputs a b =
    let
      fun gen hi acc = if hi < a then acc else gen (hi - 1) (hi :: acc)
    in
      gen b []
    end;

(* Observe the difference in runtime ....
   (uncomment the declarations below to try it) *)
(*
val r100_slow = map f2' (genInputs 1 100);
val r100_fast = map f3' (genInputs 1 100);
val r500_fast = map f3' (genInputs 1 500);
val r1000_fast = map f3' (genInputs 1 1000);
(* If you have time, you can also try to run *)
val r500_slow = map f2' (genInputs 1 500);
val r1000_slow = map f2' (genInputs 1 1000);
*)