Skip to content

Commit

Permalink
fix: defn compiler now works with axon (#10)
Browse files Browse the repository at this point in the history
* fix: defn compiler for axon

* fix: defn compiler

* feat: use new flag

* refactor: used_inputs as a mapset

* chore: update deps

* chore: use github deps
  • Loading branch information
polvalente authored Sep 5, 2024
1 parent dd01dec commit da1cf45
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ nx_iree-*.tar

/priv/iree-compile
/priv/iree-runtime
/priv/lbnx_iree.so
/priv/libnx_iree.so
34 changes: 34 additions & 0 deletions axon.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
Mix.install([
{:axon, github: "elixir-nx/axon", branch: "main"},
{:nx_iree, path: "."},
{:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
{:exla, github: "elixir-nx/nx", sparse: "exla", override: true}
], system_env: %{"NX_IREE_PREFER_PRECOMPILED" => false})

NxIREE.list_drivers() |> IO.inspect(label: "drivers")

{:ok, [dev | _]} = NxIREE.list_devices("metal")

flags = ["--iree-hal-target-backends=metal-spirv", "--iree-input-type=stablehlo_xla", "--iree-execution-model=async-internal"]
Nx.Defn.default_options(compiler: NxIREE.Compiler, iree_compiler_flags: flags, iree_runtime_options: [device: dev])

model =
Axon.input("x", shape: {nil, 3})
|> Axon.dense(8, activation: :relu)
|> Axon.dense(1, activation: :relu)

Nx.Defn.default_options(compiler: NxIREE.Compiler, iree_compiler_flags: flags, iree_runtime_options: [device: dev])
# Nx.Defn.default_options(compiler: EXLA, iree_compiler_flags: flags, iree_runtime_options: [device: dev])

template = %{"x" => Nx.template({10, 3}, :f32)}

{init_fn, predict_fn} = Axon.build(model, [])
init_params = Nx.Defn.jit_apply(init_fn, [template, Axon.ModelState.new(Axon.ModelState.empty())])

IO.puts("\n\n\n======= BEGIN predict_compiled_fn =======\n\n\n")
predict_compiled_fn = Nx.Defn.compile(predict_fn, [init_params, template])
IO.puts("\n\n\n======= END predict_compiled_fn =======\n\n\n")

IO.puts("\n\n\n======= BEGIN predict_compiled_fn CALL =======\n\n\n")
predict_compiled_fn.(init_params, Nx.iota({10, 3}, type: :f32)) |> dbg()
IO.puts("\n\n\n======= END predict_compiled_fn CALL =======\n\n\n")
8 changes: 8 additions & 0 deletions c_src/nx_iree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,14 @@ DECLARE_NIF(read_buffer_nif) {
return error(env, "invalid num_bytes");
}

std::cout << "num_bytes input: " << num_bytes << std::endl;

if (num_bytes == -1) {
num_bytes = (*input)->size;
}

std::cout << "num_bytes actual: " << num_bytes << std::endl;

ErlNifBinary binary;

if (!enif_alloc_binary(num_bytes, &binary)) {
Expand Down
11 changes: 8 additions & 3 deletions lib/nx_iree.ex
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ defmodule NxIREE do
{:ok, tmpfile} = create_temp_file(mlir_module)

compiler_path = Path.join(:code.priv_dir(:nx_iree), "iree-compile")
IO.puts(mlir_module)

try do
{output, 0} =
Expand Down Expand Up @@ -71,8 +70,14 @@ defmodule NxIREE do

input_refs =
Enum.map(inputs, fn
%Nx.Tensor{data: %NxIREE.Tensor{ref: ref}} -> ref
t -> NxIREE.VM.allocate_buffer(t, device_ref)
%Nx.Tensor{data: %NxIREE.Tensor{ref: ref}} ->
ref

fun when is_function(fun, 0) ->
NxIREE.VM.allocate_buffer(fun.(), device_ref)

t ->
NxIREE.VM.allocate_buffer(t, device_ref)
end)

instance_ref = NxIREE.VM.get_instance()
Expand Down
27 changes: 20 additions & 7 deletions lib/nx_iree/compiler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ defmodule NxIREE.Compiler do
@behaviour Nx.Defn.Compiler

@impl true
def __compile__(key, vars, fun, opts) do
output_container = fun.(vars)

def __compile__(_key, vars, fun, opts) do
{iree_compiler_flags, opts} = Keyword.pop(opts, :iree_compiler_flags, nil)
{iree_runtime_options, opts} = Keyword.pop(opts, :iree_runtime_options, [])
{output_mode, opts} = Keyword.pop(opts, :output_mode, nil)
Expand All @@ -26,20 +24,22 @@ defmodule NxIREE.Compiler do
raise "missing :iree_compiler_flags option"
end

mlir_module = EXLA.to_mlir_module(key, vars, opts)
%{mlir_module: mlir_module, output_container: output_container, used_inputs: used_inputs} =
EXLA.to_mlir_module(fun, vars, Keyword.put(opts, :within_defn_compiler, true))

bytecode = NxIREE.compile(mlir_module, iree_compiler_flags)

if output_mode == :bytecode do
throw({:bytecode, %{bytecode: bytecode, output_container: output_container}})
else
fn [inputs] ->
filtered_inputs =
filter_inputs_by_indices(inputs, used_inputs)

{:ok, results} =
NxIREE.call(
bytecode,
Enum.map(inputs, fn f ->
f.()
end),
filtered_inputs,
iree_runtime_options
)

Expand Down Expand Up @@ -68,4 +68,17 @@ defmodule NxIREE.Compiler do

@impl true
defdelegate __to_backend__(opts), to: EXLA.Defn

defp filter_inputs_by_indices(args, inputs) do
filter_by_indices_list(args, 0, Enum.sort(inputs), fn x, _ -> x end)
end

defp filter_by_indices_list([var | vars], i, [i | inputs], callback),
do: [callback.(var, i) | filter_by_indices_list(vars, i + 1, inputs, callback)]

defp filter_by_indices_list([_var | vars], i, inputs, callback),
do: filter_by_indices_list(vars, i + 1, inputs, callback)

defp filter_by_indices_list([], _i, [], _callback),
do: []
end
5 changes: 4 additions & 1 deletion lib/nx_iree/vm.ex
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ defmodule NxIREE.VM do
def allocate_buffer(binary, device_ref, shape, type) when is_binary(binary) do
element_type = to_iree_type(type)

NxIREE.Native.allocate_buffer(binary, device_ref, Tuple.to_list(shape), element_type)
{:ok, buffer_ref} =
NxIREE.Native.allocate_buffer(binary, device_ref, Tuple.to_list(shape), element_type)

buffer_ref
end

def read_buffer(%NxIREE.Tensor{} = t) do
Expand Down
4 changes: 2 additions & 2 deletions mix.lock
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
%{
"complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
"elixir_make": {:hex, :elixir_make, "0.8.4", "4960a03ce79081dee8fe119d80ad372c4e7badb84c493cc75983f9d3bc8bde0f", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "6e7f1d619b5f61dfabd0a20aa268e575572b542ac31723293a4c1a567d5ef040"},
"exla": {:git, "https://github.com/elixir-nx/nx.git", "7a3d7cd87efc9811fb8c86ec0b0b245e99bf7c6d", [sparse: "exla"]},
"exla": {:git, "https://github.com/elixir-nx/nx.git", "ad28ea754dc2780b0b0726a062c46a58c588dc31", [sparse: "exla"]},
"nimble_pool": {:hex, :nimble_pool, "1.1.0", "bf9c29fbdcba3564a8b800d1eeb5a3c58f36e1e11d7b7fb2e084a643f645f06b", [:mix], [], "hexpm", "af2e4e6b34197db81f7aad230c1118eac993acc0dae6bc83bac0126d4ae0813a"},
"nx": {:git, "https://github.com/elixir-nx/nx.git", "7a3d7cd87efc9811fb8c86ec0b0b245e99bf7c6d", [sparse: "nx"]},
"nx": {:git, "https://github.com/elixir-nx/nx.git", "ad28ea754dc2780b0b0726a062c46a58c588dc31", [sparse: "nx"]},
"telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
"xla": {:hex, :xla, "0.8.0", "fef314d085dd3ee16a0816c095239938f80769150e15db16dfaa435553d7cb16", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "739c61c8d93b97e12ba0369d10e76130224c208f1a76ad293e3581f056833e57"},
}

0 comments on commit da1cf45

Please sign in to comment.