Update on "Fix loading sparse tensors with pinning check in fork context." · pytorch/pytorch@6f719f8 · GitHub


Commit 6f719f8
Update on "Fix loading sparse tensors with pinning check in fork context."
As in the title. Fixes #153143.

cc alexsamardzic nikitaved cpuhrsch amjames bhosmer jcaip andrewkho divyanshk SsnL VitalyFedyunin dzhulgakov

[ghstack-poisoned]
1 parent d8d956e commit 6f719f8
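
For context, here is a minimal sketch of the scenario this stack exercises (not the verbatim repro from gh-153143, which may differ): sparse CPU tensors loaded through DataLoader workers started with the "fork" context while pin_memory is enabled. The clone_collate helper below is hypothetical and mirrors the test's _clone_collate.

import torch
from torch.utils.data import DataLoader

# Defined at module level so it pickles cleanly for worker processes.
def clone_collate(batch):
    return [x.clone() for x in batch]

if __name__ == "__main__":
    # Sparse CPU tensors served by fork-started workers.
    dataset = [torch.randn(5, 5).to_sparse() for _ in range(10)]
    loader = DataLoader(
        dataset,
        batch_size=2,
        num_workers=2,
        multiprocessing_context="fork",        # fork is CPU-only; see the skips below
        pin_memory=torch.cuda.is_available(),  # exercises the pinning check on load
        collate_fn=clone_collate,
    )
    for batch in loader:
        pass  # before the fix, loading sparse tensors here could trip the pinning check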

File tree

1 file changed (+15, -7 lines)

test/test_dataloader.py

Lines changed: 15 additions & 7 deletions
@@ -135,16 +135,22 @@
 
 # The following collate functions are defined globally here for pickle purposes.
 
+
 # collate_fn that returns the batch cloned
 def _clone_collate(b):
     return [x.clone() for x in b]
 
+
 # collate_fn that returns the batch of sparse coo tensors re-constructed & cloned
 def _sparse_coo_collate(b):
     # we'll use constructor prior clone to force sparse tensor
     # invariant checks, required to reproduce gh-153143
-    return [torch.sparse_coo_tensor(x._indices(), x._values(), x.shape, check_invariants=True).clone()
-            for x in b]
+    return [
+        torch.sparse_coo_tensor(
+            x._indices(), x._values(), x.shape, check_invariants=True
+        ).clone()
+        for x in b
+    ]
 
 
 @unittest.skipIf(
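
The reformatted _sparse_coo_collate rebuilds each sparse COO tensor before cloning it; passing check_invariants=True forces the sparse invariant validation needed to reproduce gh-153143. A standalone sketch of the same pattern, outside the DataLoader machinery:

import torch

# Rebuild a sparse COO tensor with invariant checks enabled, then clone it.
# _indices()/_values() access the raw (possibly uncoalesced) internals, as in the test.
x = torch.randn(4, 4).to_sparse()
rebuilt = torch.sparse_coo_tensor(
    x._indices(), x._values(), x.shape, check_invariants=True
).clone()
assert torch.equal(rebuilt.to_dense(), x.to_dense())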
@@ -2902,7 +2908,9 @@ class TestDataLoaderDeviceType(TestCase):
     def test_nested_tensor_multiprocessing(self, device, context):
         # The 'fork' multiprocessing context doesn't work for CUDA so skip it
         if "cuda" in device and context == "fork":
-            self.skipTest(f"{context} multiprocessing context not supported for {device}")
+            self.skipTest(
+                f"{context} multiprocessing context not supported for {device}"
+            )
 
         dataset = [
             torch.nested.nested_tensor([torch.randn(5)], device=device)
@@ -2948,11 +2956,11 @@ def test_nested_tensor_multiprocessing(self, device, context):
     def test_sparse_tensor_multiprocessing(self, device, context):
         # The 'fork' multiprocessing context doesn't work for CUDA so skip it
         if "cuda" in device and context == "fork":
-            self.skipTest(f"{context} multiprocessing context not supported for {device}")
+            self.skipTest(
+                f"{context} multiprocessing context not supported for {device}"
+            )
 
-        dataset = [
-            torch.randn(5, 5).to_sparse().to(device) for _ in range(10)
-        ]
+        dataset = [torch.randn(5, 5).to_sparse().to(device) for _ in range(10)]
 
         pin_memory_settings = [False]
         if device == "cpu" and torch.cuda.is_available():
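
Both tests skip the "fork" context for CUDA devices because a forked child process inherits the parent's CUDA state and cannot re-initialize it. A small sketch of picking a start method accordingly (safe_context is a hypothetical helper, not part of the test):

import multiprocessing as mp

def safe_context(device: str) -> str:
    # CUDA does not survive fork(); use spawn for GPU workloads.
    return "spawn" if "cuda" in device else "fork"

ctx = mp.get_context(safe_context("cuda"))  # -> spawn context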

0 commit comments