Skip to content

Commit

Permalink
[WASI] Add WASI-NN host functions.
Browse files Browse the repository at this point in the history
Signed-off-by: YiYing He <[email protected]>
  • Loading branch information
q82419 authored and hydai committed Jun 22, 2022
1 parent 0e0ef59 commit 7eb9ff3
Show file tree
Hide file tree
Showing 12 changed files with 194 additions and 0 deletions.
1 change: 1 addition & 0 deletions include/common/enum.inc
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,7 @@ P(FunctionReferences, "Typed Function References")
#define H Line
H(Wasi)
H(WasmEdge_Process)
H(WasiNN)
#undef H
#endif // UseHostRegistration

Expand Down
18 changes: 18 additions & 0 deletions include/host/wasi_nn/wasinncontext.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: 2019-2022 Second State INC

#pragma once

namespace WasmEdge {
namespace Host {

class WasiNNContext {
public:
WasiNNContext() = default;
~WasiNNContext() = default;
// context for implementing WASI-NN
// Add and implement the context data here.
};

} // namespace Host
} // namespace WasmEdge
59 changes: 59 additions & 0 deletions include/host/wasi_nn/wasinnfunc.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: 2019-2022 Second State INC

#pragma once

#include "common/errcode.h"
#include "host/wasi_nn/wasinncontext.h"
#include "runtime/hostfunc.h"
#include "runtime/instance/memory.h"

namespace WasmEdge {
namespace Host {

template <typename T> class WasiNN : public Runtime::HostFunction<T> {
public:
WasiNN(WasiNNContext &HostCtx) : Runtime::HostFunction<T>(0), Ctx(HostCtx) {}

protected:
WasiNNContext &Ctx;
};

class WasiNNLoad : public WasiNN<WasiNNLoad> {
public:
WasiNNLoad(WasiNNContext &HostCtx) : WasiNN(HostCtx) {}
Expect<uint32_t> body(Runtime::Instance::MemoryInstance *,
uint32_t BuilderPtr, uint32_t BuilderLen,
uint32_t Encoding, uint32_t Target, uint32_t GraphPtr);
};

class WasiNNInitExecCtx : public WasiNN<WasiNNInitExecCtx> {
public:
WasiNNInitExecCtx(WasiNNContext &HostCtx) : WasiNN(HostCtx) {}
Expect<uint32_t> body(Runtime::Instance::MemoryInstance *, uint32_t Graph,
uint32_t ContextPtr);
};

class WasiNNSetInput : public WasiNN<WasiNNSetInput> {
public:
WasiNNSetInput(WasiNNContext &HostCtx) : WasiNN(HostCtx) {}
Expect<uint32_t> body(Runtime::Instance::MemoryInstance *, uint32_t Context,
uint32_t Index, uint32_t TensorPtr);
};

class WasiNNGetOuput : public WasiNN<WasiNNGetOuput> {
public:
WasiNNGetOuput(WasiNNContext &HostCtx) : WasiNN(HostCtx) {}
Expect<uint32_t> body(Runtime::Instance::MemoryInstance *, uint32_t Context,
uint32_t Index, uint32_t OutBuffer,
uint32_t OutBufferMaxSize, uint32_t BytesWrittenPtr);
};

class WasiNNCompute : public WasiNN<WasiNNCompute> {
public:
WasiNNCompute(WasiNNContext &HostCtx) : WasiNN(HostCtx) {}
Expect<uint32_t> body(Runtime::Instance::MemoryInstance *, uint32_t Context);
};

} // namespace Host
} // namespace WasmEdge
21 changes: 21 additions & 0 deletions include/host/wasi_nn/wasinnmodule.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: 2019-2022 Second State INC

#pragma once

#include "host/wasi_nn/wasinncontext.h"
#include "runtime/importobj.h"

namespace WasmEdge {
namespace Host {

class WasiNNModule : public Runtime::ImportObject {
public:
WasiNNModule();

private:
WasiNNContext Ctx;
};

} // namespace Host
} // namespace WasmEdge
1 change: 1 addition & 0 deletions lib/api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ if(WASMEDGE_BUILD_STATIC_LIB)
wasmedge_add_static_lib_component_command(wasmedgeValidator)
wasmedge_add_static_lib_component_command(wasmedgeExecutor)
wasmedge_add_static_lib_component_command(wasmedgeHostModuleWasi)
wasmedge_add_static_lib_component_command(wasmedgeHostModuleWasiNN)
wasmedge_add_static_lib_component_command(wasmedgePlugin)
wasmedge_add_static_lib_component_command(wasmedgeVM)

Expand Down
1 change: 1 addition & 0 deletions lib/host/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
# SPDX-FileCopyrightText: 2019-2022 Second State INC

add_subdirectory(wasi)
add_subdirectory(wasi_nn)
12 changes: 12 additions & 0 deletions lib/host/wasi_nn/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# SPDX-License-Identifier: Apache-2.0

wasmedge_add_library(wasmedgeHostModuleWasiNN
wasinnfunc.cpp
wasinnmodule.cpp
)

target_link_libraries(wasmedgeHostModuleWasiNN
PUBLIC
wasmedgeCommon
wasmedgeSystem
)
52 changes: 52 additions & 0 deletions lib/host/wasi_nn/wasinnfunc.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: 2019-2022 Second State INC

#include "host/wasi_nn/wasinnfunc.h"
#include "common/errcode.h"
#include "runtime/hostfunc.h"
#include "runtime/instance/memory.h"

namespace WasmEdge {
namespace Host {

Expect<uint32_t> WasiNNLoad::body(Runtime::Instance::MemoryInstance *MemInst
[[maybe_unused]],
uint32_t BuilderPtr [[maybe_unused]],
uint32_t BuilderLen [[maybe_unused]],
uint32_t Encoding [[maybe_unused]],
uint32_t Target [[maybe_unused]],
uint32_t GraphPtr [[maybe_unused]]) {
return 0;
}

Expect<uint32_t> WasiNNInitExecCtx::body(
Runtime::Instance::MemoryInstance *MemInst [[maybe_unused]],
uint32_t Graph [[maybe_unused]], uint32_t ContextPtr [[maybe_unused]]) {
return 0;
}

Expect<uint32_t> WasiNNSetInput::body(Runtime::Instance::MemoryInstance *MemInst
[[maybe_unused]],
uint32_t Context [[maybe_unused]],
uint32_t Index [[maybe_unused]],
uint32_t TensorPtr [[maybe_unused]]) {
return 0;
}

Expect<uint32_t> WasiNNGetOuput::body(
Runtime::Instance::MemoryInstance *MemInst [[maybe_unused]],
uint32_t Context [[maybe_unused]], uint32_t Index [[maybe_unused]],
uint32_t OutBuffer [[maybe_unused]],
uint32_t OutBufferMaxSize [[maybe_unused]],
uint32_t BytesWrittenPtr [[maybe_unused]]) {
return 0;
}

Expect<uint32_t> WasiNNCompute::body(Runtime::Instance::MemoryInstance *MemInst
[[maybe_unused]],
uint32_t Context [[maybe_unused]]) {
return 0;
}

} // namespace Host
} // namespace WasmEdge
20 changes: 20 additions & 0 deletions lib/host/wasi_nn/wasinnmodule.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: 2019-2022 Second State INC

#include "host/wasi_nn/wasinnmodule.h"
#include "host/wasi_nn/wasinnfunc.h"

namespace WasmEdge {
namespace Host {

WasiNNModule::WasiNNModule() : ImportObject("wasi_ephemeral_nn") {
addHostFunc("load", std::make_unique<WasiNNLoad>(Ctx));
addHostFunc("init_execution_context",
std::make_unique<WasiNNInitExecCtx>(Ctx));
addHostFunc("set_input", std::make_unique<WasiNNSetInput>(Ctx));
addHostFunc("get_output", std::make_unique<WasiNNGetOuput>(Ctx));
addHostFunc("compute", std::make_unique<WasiNNCompute>(Ctx));
}

} // namespace Host
} // namespace WasmEdge
1 change: 1 addition & 0 deletions lib/vm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ target_link_libraries(wasmedgeVM
wasmedgeValidator
wasmedgeExecutor
wasmedgeHostModuleWasi
wasmedgeHostModuleWasiNN
)
7 changes: 7 additions & 0 deletions lib/vm/vm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "vm/async.h"

#include "host/wasi/wasimodule.h"
#include "host/wasi_nn/wasinnmodule.h"
#include "plugin/plugin.h"

namespace WasmEdge {
Expand Down Expand Up @@ -64,6 +65,12 @@ void VM::unsafeInitVM() {
std::move(ModObj));
}
}
if (Conf.hasHostRegistration(HostRegistration::WasiNN)) {
std::unique_ptr<Runtime::ImportObject> WasiNNMod =
std::make_unique<Host::WasiNNModule>();
ExecutorEngine.registerModule(StoreRef, *WasiNNMod.get());
ImpObjs.insert({HostRegistration::WasiNN, std::move(WasiNNMod)});
}
}

Expect<void> VM::unsafeRegisterModule(std::string_view Name,
Expand Down
1 change: 1 addition & 0 deletions tools/wasmedge/wasmedger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ int main(int Argc, const char *Argv[]) {

Conf.addHostRegistration(WasmEdge::HostRegistration::Wasi);
Conf.addHostRegistration(WasmEdge::HostRegistration::WasmEdge_Process);
Conf.addHostRegistration(WasmEdge::HostRegistration::WasiNN);
const auto InputPath = std::filesystem::absolute(SoName.value());
WasmEdge::VM::VM VM(Conf);

Expand Down

0 comments on commit 7eb9ff3

Please sign in to comment.