Skip to content

Commit

Permalink
Make it possible to use None for nullable scalar types.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Apr 25, 2023
1 parent 3ef1035 commit 856f432
Show file tree
Hide file tree
Showing 4 changed files with 325 additions and 186 deletions.
9 changes: 7 additions & 2 deletions gen/gen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ module Func = struct
| TensorOptions (* Tensor kind and device *)
| Scalar
| ScalarType
| ScalarTypeOption
| Device
| String
| Layout
Expand Down Expand Up @@ -169,7 +170,7 @@ module Func = struct
| "const at::itensorlistref &" | "at::tensorlist" -> Some TensorList
| "at::device" -> Some Device
| "const at::scalar &" | "at::scalar" -> Some Scalar
| "at::scalartype" -> Some ScalarType
| "at::scalartype" -> if is_nullable then Some ScalarTypeOption else Some ScalarType
| "c10::string_view" -> Some String
| "at::layout" -> Some (if is_nullable then LayoutOption else Layout)
| _ -> None
Expand All @@ -195,6 +196,7 @@ module Func = struct
| Tensor -> "tensor"
| TensorOption -> "tensor"
| ScalarType -> "int"
| ScalarTypeOption -> "int"
| Device -> "int"
| Scalar -> "scalar"
| Layout | LayoutOption -> "int8_t"
Expand Down Expand Up @@ -256,6 +258,7 @@ module Func = struct
arg_name
arg_name
| ScalarType -> Printf.sprintf "at::ScalarType(%s)" arg_name
| ScalarTypeOption -> Printf.sprintf "%s < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(%s))" arg_name arg_name
| Device -> Printf.sprintf "device_of_int(%s)" arg_name
| _ -> arg_name)
|> String.concat ~sep:", "
Expand Down Expand Up @@ -310,7 +313,7 @@ module Func = struct
| Tensor -> single_param "*mut C_tensor"
| TensorOption -> single_param "*mut C_tensor"
| Scalar -> single_param "*mut C_scalar"
| ScalarType -> single_param "c_int"
| ScalarType | ScalarTypeOption-> single_param "c_int"
| Device -> single_param "c_int"
| String -> Printf.sprintf "%s_ptr: *const u8, %s_len: c_int" an an
| IntList | IntListOption ->
Expand Down Expand Up @@ -397,6 +400,7 @@ module Func = struct
| DoubleOption -> "impl Into<Option<f64>>"
| Scalar -> "S"
| ScalarType -> "Kind"
| ScalarTypeOption -> "impl Into<Option<Kind>>"
| Device -> "Device"
in
Printf.sprintf "%s: %s" (rust_name arg.arg_name) rust_arg_type)
Expand Down Expand Up @@ -450,6 +454,7 @@ module Func = struct
| Scalar -> Printf.sprintf "%s.into().c_scalar" name
| Bool -> Printf.sprintf "if %s { 1 } else { 0 }" name
| ScalarType -> Printf.sprintf "%s.c_int()" name
| ScalarTypeOption -> Printf.sprintf "%s.into().map_or(-1, |s| s.c_int())" name
| Device -> Printf.sprintf "%s.c_int()" name
| TensorOptions -> Printf.sprintf "%s.0.c_int(), %s.1.c_int()" name name
| Int64Option -> Printf.sprintf "%s.unwrap_or(0i64), %s.is_none() as i8" name name
Expand Down
Loading

0 comments on commit 856f432

Please sign in to comment.