Skip to content

Commit

Permalink
make ProcessException pickleable (pytorch#70118)
Browse files Browse the repository at this point in the history
Summary:
Fixes pytorch#70116

Happy to add tests if you let me know the best place to put them.

cc VitalyFedyunin

Pull Request resolved: pytorch#70118

Reviewed By: malfet

Differential Revision: D33255899

Pulled By: ejguan

fbshipit-source-id: 41d495374182eb28bb8bb421e890eca3bddc077b
  • Loading branch information
epwalsh authored and facebook-github-bot committed Dec 30, 2021
1 parent 9c742be commit 14d3d29
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
11 changes: 11 additions & 0 deletions test/test_multiprocessing_spawn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Owner(s): ["module: multiprocessing"]

import os
import pickle
import random
import signal
import sys
Expand Down Expand Up @@ -218,5 +219,15 @@ def test_process_exited(self):
class ForkTest(TestCase, _TestMultiProcessing):
start_method = 'fork'


class ErrorTest(TestCase):
def test_errors_pickleable(self):
for error in (
mp.ProcessRaisedException("Oh no!", 1, 1),
mp.ProcessExitedException("Oh no!", 1, 1, 1),
):
pickle.loads(pickle.dumps(error))


if __name__ == '__main__':
run_tests()
10 changes: 10 additions & 0 deletions torch/multiprocessing/spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@ class ProcessException(Exception):

def __init__(self, msg: str, error_index: int, pid: int):
super().__init__(msg)
self.msg = msg
self.error_index = error_index
self.pid = pid

def __reduce__(self):
return type(self), (self.msg, self.error_index, self.pid)


class ProcessRaisedException(ProcessException):
"""
Expand Down Expand Up @@ -47,6 +51,12 @@ def __init__(
self.exit_code = exit_code
self.signal_name = signal_name

def __reduce__(self):
return (
type(self),
(self.msg, self.error_index, self.pid, self.exit_code, self.signal_name),
)


def _wrap(fn, i, args, error_queue):
# prctl(2) is a Linux specific system call.
Expand Down

0 comments on commit 14d3d29

Please sign in to comment.