Skip to content

Commit

Permalink
chapter 2: line function
Browse files Browse the repository at this point in the history
  • Loading branch information
richard-zhang committed Apr 11, 2023
1 parent 7d2030c commit 5279475
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
11 changes: 11 additions & 0 deletions lib/line.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
(* chapter 1 *)
let line x params =
match params with
| [a; b] ->
(a *. x) +. b
| _ ->
failwith "only accept #params = 2"

let line_xs = Tensor.floats [2.0; 1.0; 4.0; 3.0]

let line_ys = Tensor.floats [1.8; 1.2; 4.2; 3.3]
10 changes: 6 additions & 4 deletions lib/tensor.ml
Original file line number Diff line number Diff line change
Expand Up @@ -98,25 +98,27 @@ let%test "plus 3" =
let expected = tensor [floats [10.; 12.; 9.]; floats [13.; 13.; 8.]] in
plus t1 t4 = expected

let rec unary_op r op t =
let rec reduce r op t =
let r1 = rank t in
assert (r <= r1) ;
if r1 = r then op t
else
match t with
| Tensor t1 ->
Tensor (Array.map (unary_op r op) t1)
Tensor (Array.map (reduce r op) t1)
| _ ->
failwith "must be of tensor"

let sum =
unary_op 1 (function
reduce 1 (function
| Scalar _ ->
failwith "must be tensor"
| Tensor x ->
Scalar (Array.fold_left ( +. ) 0. (Array.map get_value x)) )

let%test "sum 1" =
let%test "sum 1" = sum (floats [1.0; 2.0; 3.0]) |> get_value = 6.0

let%test "sum 2" =
let t1 = floats [1.; 2.] in
let t2 = floats [3.; 4.] in
let t3 = floats [5.; 6.] in
Expand Down

0 comments on commit 5279475

Please sign in to comment.