Fix all RuntimeErrors during weights_only load from being erroneously reported with the weights_only message (pytorch#132349)

Caught in above PR pytorch#127627

Pull Request resolved: pytorch#132349
Approved by: https://github.com/albanD
mikaylagawarecki authored and pytorchmergebot committed Aug 16, 2024
1 parent 0d2be06 commit c8ad5e3
Showing 3 changed files with 21 additions and 18 deletions.
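
The errors being mis-reported are ordinary RuntimeErrors raised while materializing tensors during a weights_only=True load, such as the sparse-tensor validation failure exercised by the updated test. A minimal reproduction sketch, modeled on the test's TensorSerializationSpoofer (its __reduce_ex__ body is not part of this diff, so the corruption below is an assumption):

    import tempfile

    import torch

    class TensorSerializationSpoofer:
        # Assumed reproduction: serialize a sparse tensor whose indices are
        # out of range for its size, so deserialization fails validation.
        def __init__(self, tensor):
            self.tensor = tensor

        def __reduce_ex__(self, proto):
            invalid_indices = self.tensor._indices().clone()
            invalid_indices[0][0] = 3  # out of range for a 3x3 tensor
            return (
                torch._utils._rebuild_sparse_tensor,
                (
                    self.tensor.layout,
                    (invalid_indices, self.tensor._values(), self.tensor.size()),
                ),
            )

    x = torch.zeros(3, 3)
    x[1][1] = 1
    x = x.to_sparse()

    with tempfile.NamedTemporaryFile() as f:
        torch.save({"spoofed": TensorSerializationSpoofer(x)}, f)
        f.seek(0)
        # Raises RuntimeError("size is inconsistent with indices"). Before this
        # commit it was re-raised as an UnpicklingError carrying the weights_only
        # guidance; after it, the original RuntimeError propagates unchanged.
        torch.load(f, weights_only=True)
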
test/test_serialization.py (11 changes: 6 additions & 5 deletions)
@@ -377,11 +377,12 @@ def __reduce_ex__(self, proto):
 
         with tempfile.NamedTemporaryFile() as f:
             torch.save({"spoofed": TensorSerializationSpoofer(x)}, f)
-            f.seek(0)
-            with self.assertRaisesRegex(
-                    RuntimeError,
-                    "size is inconsistent with indices"):
-                y = torch.load(f)
+            for weights_only in (False, True):
+                f.seek(0)
+                with self.assertRaisesRegex(
+                        RuntimeError,
+                        "size is inconsistent with indices"):
+                    y = torch.load(f, weights_only=weights_only)
 
     def _test_serialization_sparse_compressed_invalid(self,
                                                       conversion,
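The loop added above runs the same assertion with weights_only off and on. The same pattern can double as a quick compatibility check for one's own checkpoints; a rough sketch (the helper name is mine, not part of this PR):

    import io

    import torch

    def loads_under_both_modes(obj) -> None:
        # Save once, then confirm the object round-trips with weights_only both
        # disabled and enabled, mirroring the loop added to the test above.
        buf = io.BytesIO()
        torch.save(obj, buf)
        for weights_only in (False, True):
            buf.seek(0)
            torch.load(buf, weights_only=weights_only)

    loads_under_both_modes({"weight": torch.ones(2, 2)})
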
torch/_weights_only_unpickler.py (24 changes: 13 additions & 11 deletions)
@@ -235,15 +235,15 @@ def load(self):
                     module = IMPORT_MAPPING[module]
                 full_path = f"{module}.{name}"
                 if module in _blocklisted_modules:
-                    raise RuntimeError(
+                    raise UnpicklingError(
                         f"Trying to load unsupported GLOBAL {full_path} whose module {module} is blocked."
                     )
                 if full_path in _get_allowed_globals():
                     self.append(_get_allowed_globals()[full_path])
                 elif full_path in _get_user_allowed_globals():
                     self.append(_get_user_allowed_globals()[full_path])
                 else:
-                    raise RuntimeError(
+                    raise UnpicklingError(
                         f"Unsupported global: GLOBAL {full_path} was not an allowed global by default. "
                         f"Please use `torch.serialization.add_safe_globals([{name}])` to allowlist "
                         "this global if you trust this class/function."
@@ -256,15 +256,17 @@ def load(self):
             elif cls in _get_user_allowed_globals().values():
                 self.append(cls.__new__(cls, *args))
             else:
-                raise RuntimeError(f"Trying to instantiate unsupported class {cls}")
+                raise UnpicklingError(
+                    f"Trying to instantiate unsupported class {cls}"
+                )
         elif key[0] == REDUCE[0]:
             args = self.stack.pop()
             func = self.stack[-1]
             if (
                 func not in _get_allowed_globals().values()
                 and func not in _get_user_allowed_globals().values()
             ):
-                raise RuntimeError(
+                raise UnpicklingError(
                     f"Trying to call reduce for unrecognized function {func}"
                 )
             self.stack[-1] = func(*args)
@@ -284,23 +286,23 @@ def load(self):
                     else:
                         inst.__dict__.update(state)
                 else:
-                    raise RuntimeError(
+                    raise UnpicklingError(
                         f"Can only build Tensor, parameter or OrderedDict objects, but got {type(inst)}"
                     )
             # Stack manipulation
             elif key[0] == APPEND[0]:
                 item = self.stack.pop()
                 list_obj = self.stack[-1]
                 if type(list_obj) is not list:
-                    raise RuntimeError(
+                    raise UnpicklingError(
                         f"Can only append to lists, but got {type(list_obj)}"
                     )
                 list_obj.append(item)
             elif key[0] == APPENDS[0]:
                 items = self.pop_mark()
                 list_obj = self.stack[-1]
                 if type(list_obj) is not list:
-                    raise RuntimeError(
+                    raise UnpicklingError(
                         f"Can only extend lists, but got {type(list_obj)}"
                     )
                 list_obj.extend(items)
@@ -350,7 +352,7 @@ def load(self):
             elif key[0] == BINUNICODE[0]:
                 strlen = unpack("<I", read(4))[0]
                 if strlen > maxsize:
-                    raise RuntimeError("String is too long")
+                    raise UnpicklingError("String is too long")
                 strval = str(read(strlen), "utf-8", "surrogatepass")
                 self.append(strval)
             elif key[0] == SHORT_BINSTRING[0]:
@@ -363,15 +365,15 @@ def load(self):
                 pid = self.stack.pop()
                 # Only allow persistent load of storage
                 if type(pid) is not tuple and not type(pid) is not int:
-                    raise RuntimeError(
+                    raise UnpicklingError(
                         f"persistent_load id must be tuple or int, but got {type(pid)}"
                     )
                 if (
                     type(pid) is tuple
                     and len(pid) > 0
                     and torch.serialization._maybe_decode_ascii(pid[0]) != "storage"
                 ):
-                    raise RuntimeError(
+                    raise UnpicklingError(
                         f"Only persistent_load of storage is allowed, but got {pid[0]}"
                     )
                 self.append(self.persistent_load(pid))
@@ -401,7 +403,7 @@ def load(self):
                 rc = self.stack.pop()
                 return rc
             else:
-                raise RuntimeError(f"Unsupported operand {key[0]}")
+                raise UnpicklingError(f"Unsupported operand {key[0]}")
 
     # Return a list of items pushed in the stack after last MARK instruction.
     def pop_mark(self):
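All of the unpickler's own restrictions (blocked modules, disallowed GLOBALs, unsupported classes, oversized strings, non-storage persistent loads) now surface as pickle.UnpicklingError, which is the only type torch.load re-wraps with the weights_only guidance. The add_safe_globals allowlisting mentioned in the error message works roughly as sketched below, using a hypothetical user-defined Config class:

    import pickle
    import tempfile

    import torch

    class Config:
        # A hypothetical user class that is not in the default allowlist.
        def __init__(self, lr: float = 0.1):
            self.lr = lr

    with tempfile.NamedTemporaryFile() as f:
        torch.save({"config": Config()}, f)

        f.seek(0)
        try:
            torch.load(f, weights_only=True)
        except pickle.UnpicklingError as e:
            # The message names the blocked global (here __main__.Config) and
            # suggests torch.serialization.add_safe_globals.
            print(e)

        # Allowlist the class if it is trusted; the load then succeeds.
        torch.serialization.add_safe_globals([Config])
        f.seek(0)
        state = torch.load(f, weights_only=True)
        assert isinstance(state["config"], Config)
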
torch/serialization.py (4 changes: 2 additions & 2 deletions)
@@ -1258,7 +1258,7 @@ def _get_wo_message(message: str) -> str:
                         overall_storage=overall_storage,
                         **pickle_load_args,
                     )
-                except RuntimeError as e:
+                except pickle.UnpicklingError as e:
                     raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
             return _load(
                 opened_zipfile,
@@ -1282,7 +1282,7 @@ def _get_wo_message(message: str) -> str:
                     _weights_only_unpickler,
                     **pickle_load_args,
                 )
-            except RuntimeError as e:
+            except pickle.UnpicklingError as e:
                 raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
         return _legacy_load(
             opened_file, map_location, pickle_module, **pickle_load_args
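For callers, the net effect of the two except changes above is which exception type means what. An illustrative wrapper (not part of the PR) showing how downstream code might tell them apart:

    import pickle

    import torch

    def load_checkpoint(path: str):
        try:
            return torch.load(path, weights_only=True)
        except pickle.UnpicklingError:
            # Genuine weights_only failures (blocked modules, disallowed
            # globals, unsupported classes) arrive here, with _get_wo_message's
            # guidance appended by torch.load.
            raise
        except RuntimeError:
            # After this commit, unrelated RuntimeErrors (e.g. "size is
            # inconsistent with indices" from sparse-tensor validation) keep
            # their original type instead of being re-wrapped with that guidance.
            raise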
