|
135 | 135 |
|
136 | 136 | # The following collate functions are defined globally here for pickle purposes.
|
137 | 137 |
|
| 138 | + |
138 | 139 | # collate_fn that returns the batch cloned
|
139 | 140 | def _clone_collate(b):
|
140 | 141 | return [x.clone() for x in b]
|
141 | 142 |
|
| 143 | + |
142 | 144 | # collate_fn that returns the batch of sparse coo tensors re-constructed & cloned
|
143 | 145 | def _sparse_coo_collate(b):
|
144 | 146 | # we'll use constructor prior clone to force sparse tensor
|
145 | 147 | # invariant checks, required to reproduce gh-153143
|
146 |
| - return [torch.sparse_coo_tensor(x._indices(), x._values(), x.shape, check_invariants=True).clone() |
147 |
| - for x in b] |
| 148 | + return [ |
| 149 | + torch.sparse_coo_tensor( |
| 150 | + x._indices(), x._values(), x.shape, check_invariants=True |
| 151 | + ).clone() |
| 152 | + for x in b |
| 153 | + ] |
148 | 154 |
|
149 | 155 |
|
150 | 156 | @unittest.skipIf(
|
@@ -2902,7 +2908,9 @@ class TestDataLoaderDeviceType(TestCase):
|
2902 | 2908 | def test_nested_tensor_multiprocessing(self, device, context):
|
2903 | 2909 | # The 'fork' multiprocessing context doesn't work for CUDA so skip it
|
2904 | 2910 | if "cuda" in device and context == "fork":
|
2905 |
| - self.skipTest(f"{context} multiprocessing context not supported for {device}") |
| 2911 | + self.skipTest( |
| 2912 | + f"{context} multiprocessing context not supported for {device}" |
| 2913 | + ) |
2906 | 2914 |
|
2907 | 2915 | dataset = [
|
2908 | 2916 | torch.nested.nested_tensor([torch.randn(5)], device=device)
|
@@ -2948,11 +2956,11 @@ def test_nested_tensor_multiprocessing(self, device, context):
|
2948 | 2956 | def test_sparse_tensor_multiprocessing(self, device, context):
|
2949 | 2957 | # The 'fork' multiprocessing context doesn't work for CUDA so skip it
|
2950 | 2958 | if "cuda" in device and context == "fork":
|
2951 |
| - self.skipTest(f"{context} multiprocessing context not supported for {device}") |
| 2959 | + self.skipTest( |
| 2960 | + f"{context} multiprocessing context not supported for {device}" |
| 2961 | + ) |
2952 | 2962 |
|
2953 |
| - dataset = [ |
2954 |
| - torch.randn(5, 5).to_sparse().to(device) for _ in range(10) |
2955 |
| - ] |
| 2963 | + dataset = [torch.randn(5, 5).to_sparse().to(device) for _ in range(10)] |
2956 | 2964 |
|
2957 | 2965 | pin_memory_settings = [False]
|
2958 | 2966 | if device == "cpu" and torch.cuda.is_available():
|
|
0 commit comments