[jit] ClassType hashing: hash on compilation_unit as well (pytorch#121928)

Following up on pytorch#121874: it turns out that in our case we're seeing repeated class names that come from different compilation units. Our previous hash function didn't consider the compilation unit, leading to hash collisions (and then memory usage that grows exponentially in the number of copies of this class name).
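
For illustration only, here is a standalone sketch of that idea: key the hash on both the qualified class name and the identity of the owning compilation unit, so same-named classes from different units land in different buckets. The struct names, fields, and the hash-combine step below are simplified stand-ins rather than the actual PyTorch implementation; the real change (in the type_hashing.cpp hunk further down) achieves the same thing with c10::get_hash(name, compilation_unit).

```cpp
// Sketch only: simplified stand-ins for ClassType / CompilationUnit and for
// c10::get_hash, showing why the compilation unit must be part of the hash.
#include <cstddef>
#include <functional>
#include <memory>
#include <string>

struct CompilationUnit {};

struct ClassType {
  std::string qualified_name;              // e.g. "__torch__.MyModuleCUTest"
  std::shared_ptr<CompilationUnit> owner;  // unit that owns this class
};

size_t hashClassType(const ClassType& t) {
  // Hash the qualified name, as before...
  size_t h = std::hash<std::string>()(t.qualified_name);
  // ...and mix in the compilation unit's identity (its pointer value), so
  // 40 loaded copies of "MyModuleCUTest" hash to 40 distinct values.
  size_t cu = std::hash<const CompilationUnit*>()(t.owner.get());
  return h ^ (cu + 0x9e3779b97f4a7c15ULL + (h << 6) + (h >> 2));
}
```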

Differential Revision: [D54916455](https://our.internmc.facebook.com/intern/diff/D54916455)
Pull Request resolved: pytorch#121928
Approved by: https://github.com/eellison
ghstack dependencies: pytorch#121874
davidberard98 authored and pytorchmergebot committed Mar 14, 2024
1 parent 2d9cee2 commit cceabe8
Showing 2 changed files with 54 additions and 1 deletion.
52 changes: 52 additions & 0 deletions test/jit/test_alias_analysis.py
@@ -1,5 +1,6 @@
# Owner(s): ["oncall: jit"]

from torch.testing._internal.common_utils import TemporaryFileName
from torch.testing._internal.jit_utils import JitTestCase
from torch._C import parse_ir
import torch
@@ -91,3 +92,54 @@ def foo2(self, x, y):
inps = list(node.inputs())
self.assertTrue(alias_db.has_writers(inps[1]))
self.assertFalse(alias_db.has_writers(inps[2]))

def test_multiple_compilation_units(self):
# This is a repro of an internal issue we saw.
# Here, we have a large number (40) of modules each with the same name (MyModuleCUTest).
# AliasDB uses some hash tables that hash on types; each of these 40 modules are not
# identical because they have different compilation units, but they have the same name.
# Therefore, if we hash only on the module name (which we previously did), we will have
# hash collisions for all of these module types.
#
# flat_hash_map has very bad performance (exponential) for this hash collision behavior.
# This OOMs prior to the fix.
N = 40

class MultiTmpFile:
def __init__(self, N):
self.N = N
self.ctxs = [TemporaryFileName(mode="w", suffix=".py") for _ in range(N)]

def __enter__(self):
return [x.__enter__() for x in self.ctxs]

def __exit__(self, exc_type, exc_value, traceback):
return [x.__exit__(exc_type, exc_value, traceback) for x in self.ctxs]

class ModuleWrapper(torch.nn.Module):
def __init__(self, module_list):
super().__init__()
self.module_list = module_list

def forward(self, x):
for mod in self.module_list:
x = mod(x)
return x

with MultiTmpFile(N) as fnames:
module_list = torch.nn.ModuleList()
global MyModuleCUTest

class MyModuleCUTest(torch.nn.Module):
def forward(self, x):
return x + 2

for _, fname in enumerate(fnames):
mod = torch.jit.script(MyModuleCUTest())
torch.jit.save(mod, fname)
loaded_mod = torch.jit.load(fname)
module_list.append(loaded_mod)

mod = ModuleWrapper(module_list)
mod = torch.jit.script(mod)
mod(torch.zeros((2, 2)))
3 changes: 2 additions & 1 deletion torch/csrc/jit/ir/type_hashing.cpp
@@ -11,7 +11,8 @@ namespace torch::jit {
 namespace {
 size_t hashType(const Type& type) {
   if (auto named_type = type.castRaw<ClassType>()) {
-    return get_hash(named_type->name().value());
+    return c10::get_hash(
+        named_type->name().value(), named_type->compilation_unit());
   }
   size_t hash = 0;
   for (const auto& containedType : type.containedTypes()) {