Enable switching Glow backend (pytorch#3698)
Summary:
Creates a new HostManager for the given Glow backend; graphs fused afterwards will target the HostManager containing devices of the given type.
Previously fused nodes may hold shared references to previously created HostManager instances, so those instances are destroyed only when the PyTorch graphs containing such nodes are destroyed. This prevents invalidating previously fused graphs.
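
Example usage (a minimal sketch based on the Python bindings added in this diff; it assumes Glow was built with the "CPU" backend, which is illustrative only):

    import torch_glow

    # Route subsequently fused subgraphs to a HostManager that owns
    # four devices of the "CPU" backend.
    torch_glow.setGlowBackend("CPU", 4)
    assert torch_glow.getGlowBackendName() == "CPU"
    assert torch_glow.getGlowBackendNumDevices() == 4

    # Graphs fused before this call keep shared references to the old
    # HostManager, so they stay valid until those graphs are destroyed.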

Documentation:
doxygen
Pull Request resolved: pytorch#3698

Test Plan: added unit test

Differential Revision: D18265179

Pulled By: jackm321

fbshipit-source-id: 03404ccde041a0801d7eeef7960e1388f9ebb008
jackm321 authored and facebook-github-bot committed Nov 5, 2019
1 parent 5b89f3d commit 3d1caf9
Showing 4 changed files with 93 additions and 16 deletions.
59 changes: 46 additions & 13 deletions torch_glow/src/PyTorchCommon.cpp
@@ -29,27 +29,60 @@ namespace glow {
 bool GlowCompilePyTorchModule = false;
 
 namespace {
-/// Builds and \returns a HostManager instance.
-std::unique_ptr<runtime::HostManager> buildHostManager() {
-  constexpr size_t numGlowDevices = 1;
-
-  std::vector<std::unique_ptr<runtime::DeviceConfig>> deviceConfigs;
-  for (int i = 0; i < numGlowDevices; i++) {
-    deviceConfigs.push_back(llvm::make_unique<runtime::DeviceConfig>(
-        getPyTorchLoaderSettings().glowBackendName));
-  }
-
-  return llvm::make_unique<runtime::HostManager>(std::move(deviceConfigs));
+/// GlowBackendState stores the currently active Glow HostManager that will
+/// be used to run the subgraphs lowered to Glow. It also contains information
+/// about the number and type of backend devices owned by the HostManager.
+struct GlowBackendState {
+  std::shared_ptr<runtime::HostManager> hostManager;
+  std::string backendName;
+  size_t numDevices = 0;
+};
+
+/// Meyers singleton for GlowBackendState.
+GlowBackendState *getGlowBackendState() {
+  static GlowBackendState state_;
+  return &state_;
 }
 
 } // namespace
 
 /// \returns the HostManager singleton used to run all PyTorch graphs in Glow.
 std::shared_ptr<runtime::HostManager> getHostManager() {
-  static std::shared_ptr<runtime::HostManager> hostManager = buildHostManager();
+  auto hostManager = getGlowBackendState()->hostManager;
+  // If no HostManager has been set, use Glow's Interpreter.
+  if (!hostManager) {
+    setHostManager("Interpreter");
+    hostManager = getGlowBackendState()->hostManager;
+  }
   return hostManager;
 }
 
+const std::string &getBackendName() {
+  return getGlowBackendState()->backendName;
+}
+
+size_t getBackendNumDevices() { return getGlowBackendState()->numDevices; }
+
+void setHostManager(const std::string &backendName, size_t numDevices) {
+  auto *state = getGlowBackendState();
+
+  // Don't create a new identical HostManager.
+  if (state->backendName == backendName && state->numDevices == numDevices) {
+    return;
+  }
+
+  state->backendName = backendName;
+  state->numDevices = numDevices;
+
+  std::vector<std::unique_ptr<runtime::DeviceConfig>> deviceConfigs;
+  for (int i = 0; i < numDevices; i++) {
+    deviceConfigs.push_back(
+        llvm::make_unique<runtime::DeviceConfig>(backendName));
+  }
+
+  state->hostManager =
+      std::make_shared<runtime::HostManager>(std::move(deviceConfigs));
+}
+
 /// Given a Glow ElemKind \p ty, \returns a matching PyTorch ScalarType.
 c10::ScalarType elemKindToScalarType(glow::ElemKind ty) {
   switch (ty) {
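To illustrate the caching behavior of setHostManager above, a small sketch via the Python bindings added later in this diff (getHostManager itself is not exposed to Python; the Interpreter backend is assumed available, as in getHostManager's default path):

    import torch_glow

    # The first call builds a HostManager with one Interpreter device
    # (numDevices defaults to 1 in setHostManager).
    torch_glow.setGlowBackend("Interpreter")

    # An identical request hits the early-return guard in setHostManager,
    # so the existing HostManager is reused rather than rebuilt.
    torch_glow.setGlowBackend("Interpreter")
    assert torch_glow.getGlowBackendNumDevices() == 1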
13 changes: 10 additions & 3 deletions torch_glow/src/PyTorchCommon.h
@@ -44,9 +44,6 @@ struct PyTorchLoaderSettings {
   /// A list of symbols for nodes that will be ignored by the Glow fuser and
   /// thus will not be fused to Glow.
   std::unordered_set<torch::jit::Symbol> opBlacklist;
-
-  /// Name of the Glow backend to use with CachingGraphRunner's HostManager.
-  std::string glowBackendName = "Interpreter";
 };
 
 /// Given a PyTorch ScalarType \p ty, \returns a matching Glow ElemKind.
@@ -62,6 +59,16 @@ PyTorchLoaderSettings &getPyTorchLoaderSettings();
 /// \returns the HostManager singleton used to run all PyTorch graphs in Glow.
 std::shared_ptr<runtime::HostManager> getHostManager();
 
+/// Set the active HostManager to one that owns \p numDevices devices of type
+/// \p backendName.
+void setHostManager(const std::string &backendName, size_t numDevices = 1);
+
+/// \returns the name of the device backend used by the active HostManager.
+const std::string &getBackendName();
+
+/// \returns the number of devices owned by the active HostManager.
+size_t getBackendNumDevices();
+
 /// \returns the PyTorch symbol to be used for the PyTorch node which represents
 /// the subgraph that Glow will compile and run.
 const c10::Symbol &getGlowSymbol();
19 changes: 19 additions & 0 deletions torch_glow/src/binding.cpp
@@ -75,6 +75,25 @@ PYBIND11_MODULE(_torch_glow, m) {
   m.def("clearFusionBlacklist",
         []() { getPyTorchLoaderSettings().opBlacklist.clear(); });
 
+  /// Set the active HostManager to one that owns one device of type \p backendName.
+  m.def("setGlowBackend", [](const std::string &glowBackendName) {
+    setHostManager(glowBackendName);
+  });
+
+  /// Set the active HostManager to one that owns \p numDevices devices of
+  /// type \p backendName.
+  m.def("setGlowBackend",
+        [](const std::string &glowBackendName, size_t numDevices) {
+          setHostManager(glowBackendName, numDevices);
+        });
+
+  /// \returns the name of the device backend used by the active HostManager.
+  m.def("getGlowBackendName", []() { return getBackendName(); });
+
+  /// \returns the number of backend devices owned by the active
+  /// HostManager.
+  m.def("getGlowBackendNumDevices", []() { return getBackendNumDevices(); });
+
   /// Binding wrapper class for TorchGlowTraining and its settings.
   py::class_<TorchGlowTrainingWrapper>(m, "TorchGlowTrainingWrapper")
       .def(py::init())
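The two setGlowBackend overloads above differ only in whether the device count is explicit; a brief sketch of how they resolve from Python (device counts are illustrative):

    import torch_glow

    # One-argument overload: numDevices defaults to 1.
    torch_glow.setGlowBackend("Interpreter")
    assert torch_glow.getGlowBackendNumDevices() == 1

    # Two-argument overload: pick the device count explicitly; changing
    # the count bypasses the no-op guard and rebuilds the HostManager.
    torch_glow.setGlowBackend("Interpreter", 2)
    assert torch_glow.getGlowBackendNumDevices() == 2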
18 changes: 18 additions & 0 deletions torch_glow/tests/functionality/set_glow_backend_test.py
@@ -0,0 +1,18 @@
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import torch_glow
+
+
+def test_set_glow_backend():
+    """Test setting the Glow backend type."""
+
+    backend_name_before = torch_glow.getGlowBackendName()
+    backend_num_devices_before = torch_glow.getGlowBackendNumDevices()
+
+    torch_glow.setGlowBackend("CPU", 4)
+
+    assert torch_glow.getGlowBackendName() == "CPU"
+    assert torch_glow.getGlowBackendNumDevices() == 4
+
+    # Reset everything.
+    torch_glow.setGlowBackend(backend_name_before, backend_num_devices_before)
