8000 Update on "Fix loading sparse tensors with pinning check in fork cont… · pytorch/pytorch@6f719f8 · GitHub
[go: up one dir, main page]

Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 6f719f8

Browse files
committed
Update on "Fix loading sparse tensors with pinning check in fork context."
As in the title. Fixes #153143 cc alexsamardzic nikitaved cpuhrsch amjames bhosmer jcaip andrewkho divyanshk SsnL VitalyFedyunin dzhulgakov [ghstack-poisoned]
1 parent d8d956e commit 6f719f8

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

test/test_dataloader.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -135,16 +135,22 @@
135135

136136
# The following collate functions are defined globally here for pickle purposes.
137137

138+
138139
# collate_fn that returns the batch cloned
139140
def _clone_collate(b):
140141
return [x.clone() for x in b]
141142

143+
142144
# collate_fn that returns the batch of sparse coo tensors re-constructed & cloned
143145
def _sparse_coo_collate(b):
144146
# we'll use constructor prior clone to force sparse tensor
145147
# invariant checks, required to reproduce gh-153143
146-
return [torch.sparse_coo_tensor(x._indices(), x._values(), x.shape, check_invariants=True).clone()
147-
for x in b]
148+
return [
149+
torch.sparse_coo_tensor(
150+
x._indices(), x._values(), x.shape, check_invariants=True
151+
).clone()
152+
for x in b
153+
]
148154

149155

150156
@unittest.skipIf(
@@ -2902,7 +2908,9 @@ class TestDataLoaderDeviceType(TestCase):
29022908
def test_nested_tensor_multiprocessing(self, device, context):
29032909
# The 'fork' multiprocessing context doesn't work for CUDA so skip it
29042910
if "cuda" in device and context == "fork":
2905-
self.skipTest(f"{context} multiprocessing context not supported for {device}")
2911+
self.skipTest(
2912+
f"{context} multiprocessing context not supported for {device}"
2913+
)
29062914

29072915
dataset = [
29082916
torch.nested.nested_tensor([torch.randn(5)], device=device)
@@ -2948,11 +2956,11 @@ def test_nested_tensor_multiprocessing(self, device, context):
29482956
def test_sparse_tensor_multiprocessing(self, device, context):
29492957
# The 'fork' multiprocessing context doesn't work for CUDA so skip it
29502958
if "cuda" in device and context == "fork":
2951-
self.skipTest(f"{context} multiprocessing context not supported for {device}")
2959+
self.skipTest(
2960+
f"{context} multiprocessing context not supported for {device}"
2961+
)
29522962

2953-
dataset = [
2954-
torch.randn(5, 5).to_sparse().to(device) for _ in range(10)
2955-
]
2963+
dataset = [torch.randn(5, 5).to_sparse().to(device) for _ in range(10)]
29562964

29572965
pin_memory_settings = [False]
29582966
if device == "cpu" and torch.cuda.is_available():

0 commit comments

Comments
 (0)
0