8000 add submodules to sys.modules so their attributes can be pickled (#53… · pytorch/pytorch@fbf9745 · GitHub
[go: up one dir, main page]

Skip to content

Commit fbf9745

Browse files
mattipfacebook-github-bot
authored andcommitted
add submodules to sys.modules so their attributes can be pickled (#53107)
Summary: Fixes #38137 As mentioned in the issue, this is a workaround for [python issue 43367](https://bugs.python.org/issue43367). There are a number of other places where `sys.modules` is modified, if something changes in python perhaps those should be reviewed as well. Pull Request resolved: #53107 Reviewed By: zou3519 Differential Revision: D26753571 Pulled By: ezyang fbshipit-source-id: 2bda03bab39ff9ca58ce4bc13befe021da91b9c4
1 parent aa603cb commit fbf9745

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

test/test_nn.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14926,6 +14926,14 @@ def test_invalid_functions(self):
1492614926
with self.assertRaisesRegex(ValueError, 'uninitialized parameter'):
1492714927
param + param
1492814928

14929+
class TestFunctionalPickle(TestCase):
14930+
14931+
# issue gh-38137
14932+
def test_pickle_softsign(self):
14933+
# Make sure it does not throw an exception
14934+
s = pickle.dumps(F.softsign)
14935+
14936+
1492914937
instantiate_device_type_tests(TestNNDeviceType, globals())
1493014938

1493114939
if __name__ == '__main__':

torch/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,19 @@ def _load_global_deps():
230230
if name[0] != '_' and
231231
not name.endswith('Base')]
232232

233+
if not TYPE_CHECKING:
234+
# issue 38137 and python issue 43367. Submodules of a C extension are
235+
# non-standard, and attributes of those submodules cannot be pickled since
236+
# pickle expect to be able to import them as "from _C.sub import attr"
237+
# which fails with "_C is not a package
238+
for attr in dir(_C):
239+
candidate = getattr(_C, attr)
240+
if type(candidate) is type(_C):
241+
# submodule
242+
if f'torch._C.{attr}' not in sys.modules:
243+
sys.modules[f'torch._C.{attr}'] = candidate
244+
245+
233246
################################################################################
234247
# Define basic utilities
235248
################################################################################

0 commit comments

Comments
 (0)
0