Fork me on GitHub

Code Snippet:

/// Number of bits needed to write x, derived from BigInteger.ToByteArray
/// (little-endian two's-complement bytes): all bytes below the most significant
/// one count 8 bits each, plus the position of the highest set bit in the top byte.
/// Returns 0 for x = 0. NOTE(review): for negative x this counts bits of the
/// two's-complement encoding, not of |x|.
let bitlength (x:bigint) =
  let bytes = x.ToByteArray()
  let top = bytes.Length - 1
  // shift the top byte right until it is empty, counting one bit per shift
  let rec countBits acc (b: byte) =
    if b = 0uy then acc
    else countBits (acc + 1) (b >>> 1)
  countBits (top * 8) bytes.[top]

/// Splits x at bit position m into (high, low) such that
/// x = (high <<< m) + low, where high = x >>> m and low holds the bits below m.
let split (x:bigint) m =
  let high = x >>> m
  let low = x - (high <<< m)
  (high, low)

/// Multiplies two big integers using Karatsuba's divide-and-conquer scheme.
/// If either operand has at most 1024 bits (per `bitlength`), it falls back to
/// the built-in product. Otherwise both operands are split at half the larger
/// bit length, and the result is assembled from three recursive sub-products
/// instead of four.
let karatsuba x y =
  let threshold = 1 <<< 10
  let isSmall (v: bigint) = bitlength v <= threshold
  let rec go (a: bigint, b: bigint) =
    if isSmall a || isSmall b then
      a * b
    else
      let half = (max (bitlength a) (bitlength b)) >>> 1
      let aHigh, aLow = split a half
      let bHigh, bLow = split b half
      let low = go (aLow, bLow)
      let mid = go (aLow + aHigh, bLow + bHigh)
      let high = go (aHigh, bHigh)
      // mid - low - high leaves only the cross terms aLow*bHigh + aHigh*bLow
      (high <<< (2 * half)) + ((mid - low - high) <<< half) + low
  go (x, y)

/// Computes the n-th Fibonacci number (F(0) = 0, F(1) = 1) by plain iteration.
/// It performs n big-integer additions in a tail-recursive loop over two
/// accumulators that hold consecutive Fibonacci numbers.
/// Raises ArgumentException if n is negative. Previously a negative n counted
/// away from the base cases and the loop effectively never terminated.
let fib n =
  if n < 0 then invalidArg "n" "n must be non-negative"
  // invariant: after j steps, (a, b) = (F(j), F(j+1)); stop once n steps are done
  let rec loop a b = function
    | 0 -> a
    | i -> loop b (a + b) (i - 1)
  loop 0I 1I n

/// Computes the n-th Fibonacci number (F(0) = 0, F(1) = 1) by fast doubling,
/// using O(log n) big-integer multiplications.
/// The recursion halves n on each step and builds a continuation. Unwinding it
/// applies the doubling step once per bit of n, starting from (F(0), F(1)).
/// Raises ArgumentException if n is negative. `i >>> 1` is an arithmetic shift,
/// so a negative i never reaches 0 and the recursion would not terminate.
let fibfast n =
  if n < 0 then invalidArg "n" "n must be non-negative"
  // given (x, y) = (F(k), F(k+1)) with k = i >>> 1, return (F(i), F(i+1))
  let inline inner x y i =
    let a = x * (2I * y - x)   // F(2k)   = F(k) * (2F(k+1) - F(k))
    let b = y * y + x * x      // F(2k+1) = F(k+1)^2 + F(k)^2
    if i % 2 = 0 then (a, b) else (b, a + b)
  let rec fibfast' cont = function
    | 0 -> cont (0I, 1I)
    | i -> fibfast' (fun (x, y) -> cont (inner x y i)) (i >>> 1)
  fibfast' id n |> fst

/// Computes the n-th Fibonacci number by fast doubling, like `fibfast`, but
/// performs the large products with `karatsuba`.
/// Raises ArgumentException if n is negative. `i >>> 1` is an arithmetic shift,
/// so a negative i never reaches 0 and the recursion would not terminate.
let fibfastkarat n =
  if n < 0 then invalidArg "n" "n must be non-negative"
  // given (x, y) = (F(k), F(k+1)) with k = i >>> 1, return (F(i), F(i+1))
  let inline inner x y i =
    // F(2k) = F(k) * (2F(k+1) - F(k)); doubling by the constant 2 would always
    // hit karatsuba's base case, so a plain product is used for it
    let a = karatsuba x (2I * y - x)
    // F(2k+1) = F(k+1)^2 + F(k)^2
    let b = karatsuba y y + karatsuba x x
    if i % 2 = 0 then (a, b) else (b, a + b)
  let rec fibfastkarat' cont = function
    | 0 -> cont (0I, 1I)
    | i -> fibfastkarat' (fun (x, y) -> cont (inner x y i)) (i >>> 1)
  fibfastkarat' id n |> fst

Code output:

> val bitlength : x:bigint -> int
> val split : x:bigint -> m:int32 -> bigint * System.Numerics.BigInteger
> val karatsuba : x:bigint -> y:bigint -> System.Numerics.BigInteger
> val fib : n:int -> System.Numerics.BigInteger
> val fibfast : n:int -> System.Numerics.BigInteger
> val fibfastkarat : n:int -> System.Numerics.BigInteger

Correctness test (all three implementations must agree on F(10^6)):

// True when the iterative, fast-doubling and Karatsuba-based versions all
// produce the same F(10^6). All three results are computed before comparing.
let correctness =
  let n = 1000000
  let byIteration = fib n
  let byDoubling = fibfast n
  let byKaratsuba = fibfastkarat n
  byIteration = byDoubling && byIteration = byKaratsuba;;

Correctness output:

> val correctness : bool = true

Performance test:

/// Runs f once and returns its result paired with the elapsed wall-clock time
/// in whole milliseconds (Stopwatch.ElapsedMilliseconds), converted to float.
let duration f =
  let watch = System.Diagnostics.Stopwatch.StartNew()
  let result = f ()
  result, float watch.ElapsedMilliseconds

(fun () -> fib 1000000)          |> duration |> snd;;  // iterative, ms
(fun () -> fibfast 1000000)      |> duration |> snd;;  // fast doubling, ms
(fun () -> fibfastkarat 1000000) |> duration |> snd;;  // doubling + Karatsuba, ms

Performance output (elapsed wall-clock milliseconds for n = 10^6):

> val duration : f:(unit -> 'a) -> 'a * float
> val it : float = 198480.0
> val it : float = 4623.0
> val it : float = 1082.0

References:

comments powered by Disqus