Skip to content

Commit

Permalink
[WEB] WebGPU support (apache#5545)
Browse files Browse the repository at this point in the history
This PR introduces WebGPU support to tvm.
The WebGPU runtime is directly built in javascript(as WebGPU uses JS as the first class citizen API)
and exposes back to the tvm's runtime via PackedFuncs.

One important note is that `ctx.sync` is not async.
This is due to the fact that WebGPU is a purely async API and we cannot block in the web environment.

So the current best way to use the js api is to wrap things in an async function.
When copy a GPU array to CPU, `await ctx.sync()` need to be called to wait for copy completion.

We use a AsyncIO rpc server to serve the async functions to the clients.
  • Loading branch information
tqchen authored May 9, 2020
1 parent 0c43fa0 commit cdc7ae4
Show file tree
Hide file tree
Showing 35 changed files with 1,459 additions and 179 deletions.
1 change: 1 addition & 0 deletions include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ typedef enum {
kOpenGL = 11,
kDLMicroDev = 13,
kDLHexagon = 14,
kDLWebGPU = 15
// AddExtraTVMType which is not in DLPack here
} TVMDeviceExtType;

Expand Down
1 change: 1 addition & 0 deletions include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ inline const char* DeviceName(int type) {
case kDLROCM: return "rocm";
case kOpenGL: return "opengl";
case kDLExtDev: return "ext_dev";
case kDLWebGPU: return "webgpu";
case kDLMicroDev: return "micro_dev";
case kDLHexagon: return "hexagon";
default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
Expand Down
1 change: 1 addition & 0 deletions python/tvm/_ffi/libinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def find_lib_path(name=None, search_path=None, optional=False):

if os.path.isdir(source_dir):
dll_path.append(os.path.join(source_dir, "web", "dist", "wasm"))
dll_path.append(os.path.join(source_dir, "web", "dist"))

dll_path = [os.path.realpath(x) for x in dll_path]
if search_path is not None:
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/_ffi/runtime_ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ class TVMContext(ctypes.Structure):
12: 'ext_dev',
13: 'micro_dev',
14: 'hexagon',
15: 'webgpu'
}
STR2MASK = {
'llvm': 1,
Expand All @@ -169,6 +170,7 @@ class TVMContext(ctypes.Structure):
'ext_dev': 12,
'micro_dev': 13,
'hexagon': 14,
'webgpu': 15,
}
def __init__(self, device_type, device_id):
super(TVMContext, self).__init__()
Expand Down
1 change: 1 addition & 0 deletions python/tvm/autotvm/tophub.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def _alias(name):
'vtacpu': 'vta',

'metal': 'opencl',
'webgpu': 'opencl',
'vulkan': 'opencl',
'nvptx': 'cuda',
}
Expand Down
1 change: 1 addition & 0 deletions python/tvm/contrib/emcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def create_tvmjs_wasm(output,
objects += [find_lib_path("wasm_runtime.bc")[0]]

objects += [find_lib_path("tvmjs_support.bc")[0]]
objects += [find_lib_path("webgpu_runtime.bc")[0]]

cmd += ["-o", output]
cmd += objects
Expand Down
16 changes: 11 additions & 5 deletions python/tvm/exec/rpc_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,20 @@ def find_example_resource():
curr_path = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
base_path = os.path.abspath(os.path.join(curr_path, "..", "..", ".."))
index_page = os.path.join(base_path, "web", "apps", "browser", "rpc_server.html")
js_files = [
os.path.join(base_path, "web/dist/tvmjs.bundle.js"),
os.path.join(base_path, "web/dist/wasm/tvmjs_runtime.wasi.js")
resource_files = [
os.path.join(base_path, "web", "dist", "tvmjs.bundle.js"),
os.path.join(base_path, "web", "dist", "wasm", "tvmjs_runtime.wasi.js")
]
for fname in [index_page] + js_files:
resource_base = os.path.join(base_path, "web", "dist", "www")
if os.path.isdir(resource_base):
for fname in os.listdir(resource_base):
full_name = os.path.join(resource_base, fname)
if os.path.isfile(full_name):
resource_files.append(full_name)
for fname in [index_page] + resource_files:
if not os.path.exists(fname):
raise RuntimeError("Cannot find %s" % fname)
return index_page, js_files
return index_page, resource_files


def main(args):
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/rpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@ def ext_dev(self, dev_id=0):
"""Construct extension device."""
return self.context(12, dev_id)

def webgpu(self, dev_id=0):
"""Construct WebGPU device."""
return self.context(15, dev_id)


class LocalSession(RPCSession):
"""RPCSession interface backed by local environment.
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/rpc/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def close_pair(self):
def on_close_event(self):
"""on close event"""
assert not self._done
logging.info("RPCProxy:on_close %s ...", self.name())
logging.info("RPCProxy:on_close_event %s ...", self.name())
if self.match_key:
key = self.match_key
if self._proxy._client_pool.get(key, None) == self:
Expand Down Expand Up @@ -158,10 +158,12 @@ def on_message(self, message):
self.on_data(message)

def on_close(self):
logging.info("RPCProxy: on_close %s ...", self.name())
self._close_process = True

if self.forward_proxy:
self.forward_proxy.signal_close()
self.forward_proxy = None
logging.info("%s Close socket..", self.name())
self.on_close_event()


Expand All @@ -187,6 +189,7 @@ def send_data(self, message):
self.on_error(err)

def on_close(self):
logging.info("RPCProxy: on_close %s ...", self.name())
if self.forward_proxy:
self.forward_proxy.signal_close()
self.forward_proxy = None
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/runtime/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def export_library(self,

if self.imported_modules:
if enabled("llvm") and llvm_target_triple:
path_obj = temp.relpath("devc.o")
path_obj = temp.relpath("devc." + object_format)
m = _ffi_api.ModulePackImportsToLLVM(self, is_system_lib, llvm_target_triple)
m.save(path_obj)
files.append(path_obj)
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/runtime/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,22 @@ def hexagon(dev_id=0):
return TVMContext(14, dev_id)


def webgpu(dev_id=0):
"""Construct a webgpu device.
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The created context
"""
return TVMContext(15, dev_id)


cl = opencl
mtl = metal

Expand Down
4 changes: 2 additions & 2 deletions src/runtime/cuda/cuda_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down
9 changes: 5 additions & 4 deletions src/runtime/vulkan/vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -750,8 +750,10 @@ class VulkanModuleNode final : public runtime::ModuleNode {
}
}

std::shared_ptr<VulkanPipeline> GetPipeline(size_t device_id, const std::string& func_name,
size_t num_pack_args) {
std::shared_ptr<VulkanPipeline> GetPipeline(
size_t device_id,
const std::string& func_name,
size_t num_pack_args) {
const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
std::lock_guard<std::mutex> lock(mutex_);
const auto& cp = ecache_[device_id][func_name];
Expand All @@ -776,6 +778,7 @@ class VulkanModuleNode final : public runtime::ModuleNode {
std::vector<VkDescriptorSetLayoutBinding> arg_binding;
std::vector<VkDescriptorUpdateTemplateEntryKHR> arg_template;
uint32_t num_pod = 0, num_buffer = 0;

{
auto fit = fmap_.find(func_name);
CHECK(fit != fmap_.end());
Expand Down Expand Up @@ -931,8 +934,6 @@ class VulkanModuleNode final : public runtime::ModuleNode {
}

private:
// the binary data
std::vector<uint32_t> data_;
// function information table.
std::unordered_map<std::string, VulkanShader> smap_;
// function information table.
Expand Down
21 changes: 18 additions & 3 deletions src/target/spirv/build_vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class SPIRVTools {
spv_context ctx_;
};

runtime::Module BuildSPIRV(IRModule mod, std::string target) {
runtime::Module BuildSPIRV(IRModule mod, std::string target, bool webgpu_restriction) {
using tvm::runtime::Registry;
using tvm::runtime::VulkanShader;

Expand Down Expand Up @@ -98,7 +98,15 @@ runtime::Module BuildSPIRV(IRModule mod, std::string target) {
std::string f_name = global_symbol.value();

VulkanShader shader;
shader.data = cg.BuildFunction(f);
std::string entry = webgpu_restriction ? "main" : f_name;
shader.data = cg.BuildFunction(f, entry);

if (webgpu_restriction) {
for (auto param : f->params) {
CHECK(param.dtype().is_handle())
<< "WebGPU does not yet support non-buffer arguments";
}
}

if (postproc != nullptr) {
TVMByteArray arr;
Expand All @@ -119,7 +127,14 @@ runtime::Module BuildSPIRV(IRModule mod, std::string target) {
}

TVM_REGISTER_GLOBAL("target.build.vulkan")
.set_body_typed(BuildSPIRV);
.set_body_typed([](IRModule mod, std::string target) {
return BuildSPIRV(mod, target, false);
});

TVM_REGISTER_GLOBAL("target.build.webgpu")
.set_body_typed([](IRModule mod, std::string target) {
return BuildSPIRV(mod, target, true);
});

} // namespace codegen
} // namespace tvm
11 changes: 4 additions & 7 deletions src/target/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
namespace tvm {
namespace codegen {

std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const PrimFunc& f) {
std::vector<uint32_t> CodeGenSPIRV::BuildFunction(
const PrimFunc& f,
const std::string& name) {
this->InitFuncState();
CHECK(f->HasNonzeroAttr(tir::attr::kNoAlias))
<< "SPIRV only takes restricted memory model";
Expand Down Expand Up @@ -77,12 +79,7 @@ std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const PrimFunc& f) {
builder_->MakeInst(spv::OpReturn);
builder_->MakeInst(spv::OpFunctionEnd);

auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute";

builder_->CommitKernelFunction(
func_ptr, static_cast<std::string>(global_symbol.value()));
builder_->CommitKernelFunction(func_ptr, name);

return builder_->Finalize();
}
Expand Down
5 changes: 4 additions & 1 deletion src/target/spirv/codegen_spirv.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <vector>
#include <memory>
#include <unordered_map>
#include <string>

#include "ir_builder.h"
#include "../../runtime/thread_storage_scope.h"
Expand All @@ -51,9 +52,11 @@ class CodeGenSPIRV:
/*!
* \brief Compile and add function f to the current module.
* \param f The function to be added.
* \param name The name of the target function.
* \return The final spirv module.
*/
virtual std::vector<uint32_t> BuildFunction(const PrimFunc& f);
virtual std::vector<uint32_t> BuildFunction(const PrimFunc& f,
const std::string& name);
/*!
* \brief Create Value for expression e
* \param e The expression to be created value for.
Expand Down
32 changes: 32 additions & 0 deletions src/target/spirv/intrin_rule_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.fabs")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Exp>);


TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Log>);

Expand All @@ -77,6 +78,37 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.tanh")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Tanh>);

// WebGPU rules.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.floor")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Floor>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.ceil")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Ceil>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.round")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Round>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.trunc")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Trunc>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.fabs")
.set_body(DispatchGLSLPureIntrin<GLSLstd450FAbs>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.exp")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Exp>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.log")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Log>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.sqrt")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Sqrt>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.pow")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Pow>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.tanh")
.set_body(DispatchGLSLPureIntrin<GLSLstd450Tanh>);

} // namespace spirv
} // namespace codegen
} // namespace tvm
Loading

0 comments on commit cdc7ae4

Please sign in to comment.