Skip to content

Commit

Permalink
[FIX] Fix RPC for the VM (apache#7810)
Browse files Browse the repository at this point in the history
* [FIX] Fix RPC for the VM
  • Loading branch information
tkonolige authored Apr 15, 2021
1 parent 1ebfafd commit f57830b
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 21 deletions.
2 changes: 1 addition & 1 deletion python/tvm/rpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def connect(url, port, key="", session_timeout=0, session_constructor_args=None)
Additional key to match server
session_timeout : float, optional
The duration of the session, allows server to kill
The duration of the session in seconds, allows server to kill
the connection when duration is longer than this value.
When duration is zero, it means the request must always be kept alive.
Expand Down
39 changes: 37 additions & 2 deletions python/tvm/runtime/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from tvm._ffi import base as _base
from .object import Object
from . import _ffi_api, container
from ..rpc.base import RPC_SESS_MASK


def _convert(arg, cargs):
Expand Down Expand Up @@ -341,6 +342,9 @@ def __init__(self, exe, device, memory_cfg=None):
self._exec = exe
self._init = self.module["init"]
self._invoke = self.module["invoke"]
self._invoke_stateful = self.module["invoke_stateful"]
self._get_output = self.module["get_output"]
self._get_num_outputs = self.module["get_num_outputs"]
self._set_input = self.module["set_input"]
self._setup_device(device, memory_cfg)

Expand All @@ -356,7 +360,7 @@ def _setup_device(self, dev, memory_cfg):
devs = [dev]

# CPU is required for executing shape functions
if not any(c.device_type == tvm.cpu().device_type for c in devs):
if not any(c.device_type % RPC_SESS_MASK == tvm.cpu().device_type for c in devs):
devs.append(tvm.cpu())

default_alloc_type = VirtualMachine.POOLED_ALLOCATOR
Expand All @@ -374,7 +378,7 @@ def _setup_device(self, dev, memory_cfg):
)
init_args = []
for device in devs:
init_args.append(device.device_type)
init_args.append(device.device_type % RPC_SESS_MASK)
init_args.append(device.device_id)
alloc_type = memory_cfg[device] if device in memory_cfg else default_alloc_type
init_args.append(alloc_type)
Expand Down Expand Up @@ -455,3 +459,34 @@ def run(self, *args, **kwargs):
The output.
"""
return self.invoke("main", *args, **kwargs)

def invoke_stateful(self, func_name, *args, **kwargs):
    """Run a VM function, discarding whatever it returns.

    This entry point exists for use over RPC: per the original note, an
    ADT object currently cannot be returned over RPC, so the result is
    left inside the VM instead of being handed back. Retrieve it
    afterwards with :py:func:`get_outputs`.

    Parameters
    ----------
    func_name : str
        The name of the function.

    args : list[tvm.runtime.NDArray] or list[np.ndarray]
        The arguments to the function.

    kwargs : dict of str to tvm.runtime.NDArray or np.ndarray
        Named arguments to the function.
    """
    # Inputs are only (re)bound when the caller actually supplied some;
    # otherwise the function is invoked with no new set_input call.
    inputs_given = bool(args) or bool(kwargs)
    if inputs_given:
        self.set_input(func_name, *args, **kwargs)

    self._invoke_stateful(func_name)

def get_outputs(self):
    """Collect the outputs left behind by :py:func:`invoke_stateful`.

    Returns
    -------
    outputs : List[NDArray]
        One entry per output index, in index order.
    """
    outputs = []
    # The output count is queried once, then each output is fetched by index.
    for index in range(self._get_num_outputs()):
        outputs.append(self._get_output(index))
    return outputs
25 changes: 13 additions & 12 deletions src/runtime/rpc/rpc_endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -477,16 +477,17 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
TVMArgs args = RecvPackedSeq();

this->SwitchToState(kWaitForAsyncCallback);
GetServingSession()->AsyncCallFunc(reinterpret_cast<void*>(call_handle), args.values,
args.type_codes, args.size(),
[this](RPCCode status, TVMArgs args) {
if (status == RPCCode::kException) {
this->ReturnException(args.values[0].v_str);
} else {
this->ReturnPackedSeq(args);
}
this->SwitchToState(kRecvPacketNumBytes);
});
GetServingSession()->AsyncCallFunc(
reinterpret_cast<void*>(call_handle), args.values, args.type_codes, args.size(),
[this](RPCCode status, TVMArgs args) {
if (status == RPCCode::kException) {
this->ReturnException(args.values[0].v_str);
} else {
ValidateArguments(args.values, args.type_codes, args.size());
this->ReturnPackedSeq(args);
}
this->SwitchToState(kRecvPacketNumBytes);
});
}

void HandleInitServer() {
Expand Down Expand Up @@ -637,7 +638,7 @@ RPCCode RPCEndpoint::HandleUntilReturnEvent(bool client_mode, RPCSession::FEncod
if (handler_->CanCleanShutdown()) {
return RPCCode::kShutdown;
} else {
LOG(FATAL) << "Channel closes before we get neded bytes";
LOG(FATAL) << "Channel closes before we get needed bytes";
}
}
}
Expand Down Expand Up @@ -794,7 +795,7 @@ void RPCEndpoint::CallFunc(RPCSession::PackedFuncHandle h, const TVMValue* arg_v
handler_->SendPackedSeq(arg_values, arg_type_codes, num_args, true);

code = HandleUntilReturnEvent(true, encode_return);
ICHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
ICHECK(code == RPCCode::kReturn) << "code=" << RPCCodeToString(code);
}

void RPCEndpoint::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes) {
Expand Down
32 changes: 30 additions & 2 deletions src/runtime/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,21 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
*rv = Invoke(func, func_args);
}
});
} else if (name == "invoke_stateful") {
// TODO(tkonolige, jroesch, tqchen): invoke_stateful and get_output are a
// stop-gap measure to allow using vm over a remote connection.
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
PackedFunc invoke = GetFunction("invoke", sptr_to_self);
TVMRetValue rv_;
invoke.CallPacked(args, &rv_);
});
} else if (name == "get_output") {
return TypedPackedFunc<NDArray(int64_t)>([this](int64_t index) {
return Downcast<NDArray>(Downcast<ADT>(this->return_register_)[index]);
});
} else if (name == "get_num_outputs") {
return TypedPackedFunc<int64_t(void)>(
[this]() -> int64_t { return Downcast<ADT>(this->return_register_).size(); });
} else if (name == "init") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
ICHECK_EQ(args.size() % 3, 0);
Expand Down Expand Up @@ -165,8 +180,21 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
for (int i = 1; i < args.size(); ++i) {
Index device_type = vm_func.params_device_type[i - 1];
Device dev = GetDevice(device_type);
ObjectRef obj = CopyTo(args[i], dev);
func_args[i - 1] = obj;

if (args[i].type_code() == kTVMDLTensorHandle) {
// Automatically convert input DLTensors to NDArray
DLTensor* tensor = args[i];
std::vector<int64_t> shape;
for (int64_t i = 0; i < tensor->ndim; i++) {
shape.push_back(tensor->shape[i]);
}
NDArray ary = NDArray::Empty(shape, tensor->dtype, dev);
ary.CopyFrom(tensor);
func_args[i - 1] = ary;
} else {
ObjectRef obj = CopyTo(args[i], dev);
func_args[i - 1] = obj;
}
}
inputs_.erase(func_name);
inputs_.emplace(func_name, func_args);
Expand Down
17 changes: 14 additions & 3 deletions tests/python/relay/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
import numpy as np
import pytest
import time

import tvm
from tvm import runtime
Expand Down Expand Up @@ -823,8 +824,12 @@ def test_vm_rpc():
path = temp.relpath("vm_library.so")
vm_exec.mod.export_library(path)

# Use LocalRPC for testing.
remote = rpc.LocalSession()
# Use local rpc server for testing.
# Server must use popen so it doesn't inherit the current process state. It
# will crash otherwise.
server = rpc.Server("localhost", port=9120, use_popen=True)
time.sleep(2)
remote = rpc.connect(server.host, server.port, session_timeout=10)

# Upload the serialized Executable.
remote.upload(path)
Expand All @@ -837,10 +842,16 @@ def test_vm_rpc():
np_input = np.random.uniform(size=(10, 1)).astype("float32")
input_tensor = tvm.nd.array(np_input, ctx)
# Invoke its "main" function.
out = vm_factory.invoke("main", [input_tensor])
out = vm_factory.invoke("main", input_tensor)
# Check the result.
np.testing.assert_allclose(out.asnumpy(), np_input + np_input)

# delete tensors before the server shuts down so we don't throw errors.
del input_tensor
del out

server.terminate()


if __name__ == "__main__":
pytest.main([__file__])
2 changes: 1 addition & 1 deletion tests/python/unittest/test_runtime_profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_vm(target, dev):
vm = profiler_vm.VirtualMachineProfiler(exe, dev)

data = np.random.rand(1, 1, 28, 28).astype("float32")
report = vm.profile([data], func_name="main")
report = vm.profile(data, func_name="main")
assert "fused_nn_softmax" in report
assert "Total time" in report

Expand Down
1 change: 1 addition & 0 deletions tests/scripts/task_python_integration.sh
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ if python -c "import tvm; from tvm.relay.op.contrib.ethosn import ethosn_availab
fi
run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-contrib tests/python/contrib

# forked is needed because the global registry gets contaminated
TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm;cuda}" \
run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-relay tests/python/relay

Expand Down

0 comments on commit f57830b

Please sign in to comment.