From 8bd3627acab960aef015243f5efde1b6e65367dd Mon Sep 17 00:00:00 2001
From: Kevin Tse
Date: Tue, 28 Dec 2021 19:34:18 -0500
Subject: [PATCH] [DataPipe] Enforcing single iterator per
 IterableWrapperIterDataPipe

[ghstack-poisoned]
---
 test/test_datapipe.py                    | 153 +++++++++++++++++++----
 torch/utils/data/datapipes/iter/utils.py |  32 ++++-
 2 files changed, 155 insertions(+), 30 deletions(-)

diff --git a/test/test_datapipe.py b/test/test_datapipe.py
index 069960150baf9a..f811ed422210a3 100644
--- a/test/test_datapipe.py
+++ b/test/test_datapipe.py
@@ -20,6 +20,7 @@ from typing import (
     Any,
     Awaitable,
+    Callable,
     Dict,
     Generic,
     Iterator,
@@ -113,7 +114,8 @@ def create_temp_dir_and_files():
 
 
 def reset_after_n_next_calls(datapipe: Union[IterDataPipe[T_co], MapDataPipe[T_co]],
-                             n: int) -> Tuple[List[T_co], List[T_co]]:
+                             n: int,
+                             custom_reset: Optional[Callable] = None) -> Tuple[List[T_co], List[T_co]]:
     """
     Given a DataPipe and integer n, iterate the DataPipe for n elements and store the elements into a list
     Then, reset the DataPipe and return a tuple of two lists
@@ -124,6 +126,24 @@ def reset_after_n_next_calls(datapipe: Union[IterDataPipe[T_co], MapDataPipe[T_c
     res_before_reset = []
     for _ in range(n):
         res_before_reset.append(next(it))
+    if custom_reset is not None:
+        custom_reset()
+    return res_before_reset, list(datapipe)
+
+
+def reset_after_n_next_calls_with_method(
+        datapipe: Union[IterDataPipe[T_co], MapDataPipe[T_co]],
+        n: int) -> Tuple[List[T_co], List[T_co]]:
+    """
+    Given a DataPipe and integer n, iterate the DataPipe for n elements and store the elements into a list
+    Then, reset the DataPipe via its ``reset`` method and return a tuple of two lists
+    1. A list of elements yielded before the reset
+    2. A list of all elements of the DataPipe after the reset
+    """
+    res_before_reset = []
+    for _ in range(n):
+        res_before_reset.append(next(datapipe))
+    datapipe.reset()
     return res_before_reset, list(datapipe)
@@ -429,6 +449,7 @@ def test_demux_mux_datapipe(self):
         n1, n2 = numbers_dp.demux(2, lambda x: x % 2)
         self.assertEqual([0, 2, 4, 6, 8, 10, 12], list(n1))
         self.assertEqual([1, 3, 5, 7, 9], list(n2))
+        numbers_dp.reset()
         n = n1.mux(n2)
         self.assertEqual(source_numbers, list(n))
@@ -713,8 +734,10 @@ def test_iterable_wrapper_datapipe(self):
         self.assertEqual(input_ls, list(input_dp))
 
         # Functional Test: deep copy by default when an iterator is initialized (first element is read)
+        input_dp = dp.iter.IterableWrapper(input_ls)
         it = iter(input_dp)
-        self.assertEqual(0, next(it))  # The deep copy only happens when the first element is read
+        self.assertEqual(0, next(it))  # The deep copy only happens
+        # when the first element is read
         input_ls.append(50)
         self.assertEqual(list(range(1, 10)), list(it))
@@ -728,7 +751,7 @@ def test_iterable_wrapper_datapipe(self):
         input_ls = list(range(10))
         input_dp = dp.iter.IterableWrapper(input_ls)
         n_elements_before_reset = 5
-        res_before_reset, res_after_reset = reset_after_n_next_calls(input_dp, n_elements_before_reset)
+        res_before_reset, res_after_reset = reset_after_n_next_calls_with_method(input_dp, n_elements_before_reset)
         self.assertEqual(input_ls[:n_elements_before_reset], res_before_reset)
         self.assertEqual(input_ls, res_after_reset)
@@ -748,13 +771,19 @@ def test_concat_iterdatapipe(self):
             dp.iter.Concater(input_dp1, ())  # type: ignore[arg-type]
 
         # Functional Test: Concatenate DataPipes as expected
+        concat_dp = input_dp1.concat(input_dp2)
         self.assertEqual(len(concat_dp), 15)
         self.assertEqual(list(concat_dp), list(range(10)) + list(range(5)))
 
         # Reset Test: reset the DataPipe
+        def reset_dps():
+            input_dp1.reset()
+            input_dp2.reset()
+
+        reset_dps()
         n_elements_before_reset = 5
-        res_before_reset, res_after_reset = reset_after_n_next_calls(concat_dp, n_elements_before_reset)
+        res_before_reset, res_after_reset = reset_after_n_next_calls(concat_dp, n_elements_before_reset, reset_dps)
         self.assertEqual(list(range(5)), res_before_reset)
         self.assertEqual(list(range(10)) + list(range(5)), res_after_reset)
@@ -763,7 +792,7 @@ def test_concat_iterdatapipe(self):
         concat_dp = input_dp1.concat(input_dp_nl)
         with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
             len(concat_dp)
-
+        reset_dps()
         self.assertEqual(list(concat_dp), list(range(10)) + list(range(5)))
 
     def test_fork_iterdatapipe(self):
@@ -773,19 +802,22 @@ def test_fork_iterdatapipe(self):
             input_dp.fork(num_instances=0)
 
         dp0 = input_dp.fork(num_instances=1)
-        self.assertEqual(dp0, input_dp)
+        self.assertEqual(list(range(10)), list(dp0))
 
         # Test Case: making sure all child DataPipe shares the same reference
+        input_dp.reset()
         dp1, dp2, dp3 = input_dp.fork(num_instances=3)
         self.assertTrue(all(n1 is n2 and n1 is n3 for n1, n2, n3 in zip(dp1, dp2, dp3)))
 
         # Test Case: one child DataPipe yields all value at a time
+        input_dp.reset()
         output1, output2, output3 = list(dp1), list(dp2), list(dp3)
         self.assertEqual(list(range(10)), output1)
         self.assertEqual(list(range(10)), output2)
         self.assertEqual(list(range(10)), output3)
 
         # Test Case: two child DataPipes yield value together
+        input_dp.reset()
         dp1, dp2 = input_dp.fork(num_instances=2)
         output = []
         for n1, n2 in zip(dp1, dp2):
@@ -793,6 +825,7 @@ def test_fork_iterdatapipe(self):
         self.assertEqual([(i, i) for i in range(10)], output)
 
         # Test Case: one child DataPipe yields all value first, but buffer_size = 5 being too small
+        input_dp.reset()
         dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=5)
         it1 = iter(dp1)
         for _ in range(5):
             next(it1)
         with self.assertRaises(BufferError):
             next(it1)
@@ -807,11 +840,13 @@ def test_fork_iterdatapipe(self):
             dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=-1)
             self.assertEqual(len(wa), 1)
             self.assertRegex(str(wa[0].message), r"Unlimited buffer size is set")
+        input_dp.reset()
         l1, l2 = list(dp1), list(dp2)
         for d1, d2 in zip(l1, l2):
             self.assertEqual(d1, d2)
 
         # Test Case: two child DataPipes yield value together with buffer size 1
+        input_dp.reset()
         dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=1)
         output = []
         for n1, n2 in zip(dp1, dp2):
@@ -819,6 +854,7 @@ def test_fork_iterdatapipe(self):
         self.assertEqual([(i, i) for i in range(10)], output)
 
         # Test Case: make sure logic related to slowest_ptr is working properly
+        input_dp.reset()
         dp1, dp2, dp3 = input_dp.fork(num_instances=3)
         output1, output2 , output3 = [], [], []
         for i, (n1, n2) in enumerate(zip(dp1, dp2)):
@@ -832,8 +868,10 @@ def test_fork_iterdatapipe(self):
         self.assertEqual(list(range(10)), output3)
 
         # Test Case: DataPipe doesn't reset if this pipe hasn't been read
+        input_dp.reset()
         dp1, dp2 = input_dp.fork(num_instances=2)
         i1, i2 = iter(dp1), iter(dp2)
+        input_dp.reset()
         output2 = []
         for i, n2 in enumerate(i2):
             output2.append(n2)
@@ -842,14 +880,17 @@ def test_fork_iterdatapipe(self):
         self.assertEqual(list(range(10)), output2)
 
         # Test Case: DataPipe reset when some of it have been read
+        input_dp.reset()
         dp1, dp2 = input_dp.fork(num_instances=2)
         i1, i2 = iter(dp1), iter(dp2)
         output1, output2 = [], []
+        input_dp.reset()
         for i, (n1, n2) in enumerate(zip(i1, i2)):
             output1.append(n1)
             output2.append(n2)
             if i == 4:
                 with warnings.catch_warnings(record=True) as wa:
+                    input_dp.reset()
                     i1 = iter(dp1)  # Reset both all child DataPipe
                    self.assertEqual(len(wa), 1)
                    self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
@@ -857,11 +898,13 @@ def test_fork_iterdatapipe(self):
         self.assertEqual(list(range(5)) + list(range(10)), output2)
 
         # Test Case: DataPipe reset, even when some other child DataPipes are not read
+        input_dp.reset()
         dp1, dp2, dp3 = input_dp.fork(num_instances=3)
         output1, output2 = list(dp1), list(dp2)
         self.assertEqual(list(range(10)), output1)
         self.assertEqual(list(range(10)), output2)
         with warnings.catch_warnings(record=True) as wa:
+            input_dp.reset()
             self.assertEqual(list(range(10)), list(dp1))  # Resets even though dp3 has not been read
             self.assertEqual(len(wa), 1)
             self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
@@ -870,6 +913,7 @@ def test_fork_iterdatapipe(self):
             output3.append(n3)
             if i == 4:
                 with warnings.catch_warnings(record=True) as wa:
+                    input_dp.reset()
                     output1 = list(dp1)  # Resets even though dp3 is only partially read
                     self.assertEqual(len(wa), 1)
                     self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
@@ -879,12 +923,14 @@ def test_fork_iterdatapipe(self):
         self.assertEqual(list(range(10)), list(dp3))  # dp3 has to read from the start again
 
         # Test Case: Each DataPipe inherits the source datapipe's length
+        input_dp.reset()
         dp1, dp2, dp3 = input_dp.fork(num_instances=3)
         self.assertEqual(len(input_dp), len(dp1))
         self.assertEqual(len(input_dp), len(dp2))
         self.assertEqual(len(input_dp), len(dp3))
 
         # Pickle Test:
+        input_dp.reset()
         dp1, dp2, dp3 = input_dp.fork(num_instances=3)
         traverse(dp1)  # This should not raise any error
         for _ in zip(dp1, dp2, dp3):
@@ -916,7 +962,7 @@ def test_mux_iterdatapipe(self):
         input_dp2 = dp.iter.IterableWrapper([])
         output_dp = input_dp1.mux(input_dp2)
         self.assertEqual(len(input_dp1), len(output_dp))
-        self.assertEqual(list(input_dp1), list(output_dp))
+        self.assertEqual([0, 1, 2, 3], list(output_dp))
 
         # __len__ Test: raises TypeError when __len__ is called and an input doesn't have __len__
         input_dp1 = dp.iter.IterableWrapper(range(10))
@@ -938,6 +984,7 @@ def test_demux_iterdatapipe(self):
         self.assertEqual(list(range(1, 10, 2)), output2)
 
         # Test Case: split into 2 DataPipes and output them together
+        input_dp.reset()
         dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
         output = []
         for n1, n2 in zip(dp1, dp2):
@@ -945,6 +992,7 @@ def test_demux_iterdatapipe(self):
         self.assertEqual([(i, i + 1) for i in range(0, 10, 2)], output)
 
         # Test Case: values of the same classification are lumped together, and buffer_size = 3 being too small
+        input_dp.reset()
         dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: 0 if x >= 5 else 1, buffer_size=4)
         it1 = iter(dp1)
         with self.assertRaises(BufferError):
             next(it1)
         with self.assertRaises(BufferError):
             list(dp2)
@@ -953,12 +1001,14 @@ def test_demux_iterdatapipe(self):
 
         # Test Case: values of the same classification are lumped together, and buffer_size = 5 is just enough
+        input_dp.reset()
         dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: 0 if x >= 5 else 1, buffer_size=5)
         output1, output2 = list(dp1), list(dp2)
         self.assertEqual(list(range(5, 10)), output1)
         self.assertEqual(list(range(0, 5)), output2)
 
         # Test Case: values of the same classification are lumped together, and unlimited buffer
+        input_dp.reset()
         with warnings.catch_warnings(record=True) as wa:
             dp1, dp2 = input_dp.demux(
                 num_instances=2,
                 classifier_fn=lambda x: 0 if x >= 5 else 1,
                 buffer_size=-1
             )
             self.assertEqual(len(wa), 1)
             self.assertRegex(str(wa[0].message), r"Unlimited buffer size is set")
         output1, output2 = list(dp1), list(dp2)
         self.assertEqual(list(range(5, 10)), output1)
         self.assertEqual(list(range(0, 5)), output2)
 
         # Test Case: classifer returns a value outside of [0, num_instance - 1]
+        input_dp.reset()
         dp0 = input_dp.demux(num_instances=1, classifier_fn=lambda x: x % 2)
         it = iter(dp0[0])
         with self.assertRaises(ValueError):
             next(it)
@@ -979,6 +1030,7 @@ def test_demux_iterdatapipe(self):
             next(it)
 
         # Test Case: DataPipe doesn't reset when it has not been read
+        input_dp.reset()
         dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
         i1 = iter(dp1)
         output2 = []
@@ -990,6 +1042,7 @@ def test_demux_iterdatapipe(self):
         self.assertEqual(list(range(1, 10, 2)), output2)
 
         # Test Case: DataPipe reset when some of it has been read
+        input_dp.reset()
         dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
         output1, output2 = [], []
@@ -997,6 +1050,7 @@ def test_demux_iterdatapipe(self):
         for n1, n2 in zip(dp1, dp2):
             output1.append(n1)
             output2.append(n2)
             if n1 == 4:
                 break
+        input_dp.reset()
         with warnings.catch_warnings(record=True) as wa:
             i1 = iter(dp1)  # Reset all child DataPipes
             self.assertEqual(len(wa), 1)
@@ -1008,9 +1062,11 @@ def test_demux_iterdatapipe(self):
         self.assertEqual([1, 3, 5] + list(range(1, 10, 2)), output2)
 
         # Test Case: DataPipe reset, even when not all child DataPipes are exhausted
+        input_dp.reset()
         dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
         output1 = list(dp1)
         self.assertEqual(list(range(0, 10, 2)), output1)
+        input_dp.reset()
         with warnings.catch_warnings(record=True) as wa:
             self.assertEqual(list(range(0, 10, 2)), list(dp1))  # Reset even when dp2 is not read
             self.assertEqual(len(wa), 1)
@@ -1021,6 +1077,7 @@ def test_demux_iterdatapipe(self):
             if i == 1:
                 self.assertEqual(list(range(1, 5, 2)), output2)
                 with warnings.catch_warnings(record=True) as wa:
+                    input_dp.reset()
                     self.assertEqual(list(range(0, 10, 2)), list(dp1))  # Can reset even when dp2 is partially read
                     self.assertEqual(len(wa), 1)
                     self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
@@ -1029,12 +1086,14 @@ def test_demux_iterdatapipe(self):
         self.assertEqual(list(range(1, 10, 2)), output2)
 
         # Test Case: drop_none = True
+        input_dp.reset()
         dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2 if x % 5 != 0 else None,
                                   drop_none=True)
         self.assertEqual([2, 4, 6, 8], list(dp1))
         self.assertEqual([1, 3, 7, 9], list(dp2))
 
         # Test Case: drop_none = False
+        input_dp.reset()
         dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2 if x % 5 != 0 else None,
                                   drop_none=False)
         it1 = iter(dp1)
         with self.assertRaises(ValueError):
@@ -1042,6 +1101,7 @@ def test_demux_iterdatapipe(self):
             next(it1)
 
         # Test Case: __len__ not implemented
+        input_dp.reset()
         dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
         with self.assertRaises(TypeError):
             len(dp1)  # It is not implemented as we do not know length for each child in advance
@@ -1049,6 +1109,7 @@ def test_demux_iterdatapipe(self):
             len(dp2)
 
         # Pickle Test:
+        input_dp.reset()
         dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=odd_or_even)
         traverse(dp1)  # This should not raise any error
         for _ in zip(dp1, dp2):
@@ -1065,16 +1126,18 @@ def fn(item, dtype=torch.float, *, sum=False):
 
         # Functional Test: apply to each element correctly
         map_dp = input_dp.map(fn)
         self.assertEqual(len(input_dp), len(map_dp))
-        for x, y in zip(map_dp, input_dp):
+        for x, y in zip(map_dp, range(10)):
             self.assertEqual(x, torch.tensor(y, dtype=torch.float))
 
         # Functional Test: works with partial function
+        input_dp.reset()
         map_dp = input_dp.map(partial(fn, dtype=torch.int, sum=True))
-        for x, y in zip(map_dp, input_dp):
+        for x, y in zip(map_dp, range(10)):
             self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())
 
         # __len__ Test: inherits length from source DataPipe
-        self.assertEqual(len(input_dp), len(map_dp))
+        input_dp.reset()
+        self.assertEqual(10, len(map_dp))
 
         input_dp_nl = IDP_NoLen(range(10))
         map_dp_nl = input_dp_nl.map(lambda x: x)
@@ -1086,8 +1149,9 @@ def fn(item, dtype=torch.float, *, sum=False):
             len(map_dp_nl)
 
         # Reset Test: DataPipe resets properly
+        input_dp.reset()
         n_elements_before_reset = 5
-        res_before_reset, res_after_reset = reset_after_n_next_calls(map_dp, n_elements_before_reset)
+        res_before_reset, res_after_reset = reset_after_n_next_calls(map_dp, n_elements_before_reset, input_dp.reset)
         self.assertEqual(list(range(n_elements_before_reset)), res_before_reset)
         self.assertEqual(list(range(10)), res_after_reset)
@@ -1110,9 +1174,14 @@ def _helper(ref_fn, fn, input_col=None, output_col=None):
             datapipe = dp.iter.IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
             res_dp = datapipe.map(fn, input_col, output_col)
             ref_dp = datapipe.map(ref_fn)
-            self.assertEqual(list(res_dp), list(ref_dp))
+            expected_ls = list(res_dp)
+            datapipe.reset()
+            self.assertEqual(expected_ls, list(ref_dp))
             # Reset
-            self.assertEqual(list(res_dp), list(ref_dp))
+            datapipe.reset()
+            expected_ls = list(res_dp)
+            datapipe.reset()
+            self.assertEqual(expected_ls, list(ref_dp))
 
         # Replacing with one input column and default output column
         _helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1)
@@ -1181,9 +1250,14 @@ def _helper(ref_fn, fn, input_col=None, output_col=None):
             )
             res_dp = datapipe.map(fn, input_col, output_col)
             ref_dp = datapipe.map(ref_fn)
-            self.assertEqual(list(res_dp), list(ref_dp))
+            expected_ls = list(res_dp)
+            datapipe.reset()
+            self.assertEqual(expected_ls, list(ref_dp))
             # Reset
-            self.assertEqual(list(res_dp), list(ref_dp))
+            datapipe.reset()
+            expected_ls = list(res_dp)
+            datapipe.reset()
+            self.assertEqual(expected_ls, list(ref_dp))
 
         # Replacing with one input column and default output column
         _helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y")
@@ -1227,24 +1301,28 @@ def _collate_fn(batch, default_type=torch.float):
 
         # Functional Test: defaults to the default collate function when a custom one is not specified
         collate_dp = input_dp.collate()
-        for x, y in zip(input_dp, collate_dp):
+        for x, y in zip(arrs, collate_dp):
             self.assertEqual(torch.tensor(x), y)
 
         # Functional Test: custom collate function
+        input_dp.reset()
         collate_dp = input_dp.collate(collate_fn=_collate_fn)
-        for x, y in zip(input_dp, collate_dp):
+        for x, y in zip(arrs, collate_dp):
             self.assertEqual(torch.tensor(sum(x), dtype=torch.float), y)
 
         # Functional Test: custom, partial collate function
+        input_dp.reset()
        collate_dp = input_dp.collate(partial(_collate_fn, default_type=torch.int))
-        for x, y in zip(input_dp, collate_dp):
+        for x, y in zip(arrs, collate_dp):
             self.assertEqual(torch.tensor(sum(x), dtype=torch.int), y)
 
         # Reset Test: reset the DataPipe and results are still correct
+        input_dp.reset()
         n_elements_before_reset = 1
-        res_before_reset, res_after_reset = reset_after_n_next_calls(collate_dp, n_elements_before_reset)
+        res_before_reset, res_after_reset = \
+            reset_after_n_next_calls(collate_dp, n_elements_before_reset, input_dp.reset)
         self.assertEqual([torch.tensor(6, dtype=torch.int)], res_before_reset)
-        for x, y in zip(input_dp, res_after_reset):
+        for x, y in zip(arrs, res_after_reset):
             self.assertEqual(torch.tensor(sum(x), dtype=torch.int), y)
 
         # __len__ Test: __len__ is inherited
@@ -1273,16 +1351,19 @@ def test_batch_iterdatapipe(self):
             self.assertEqual(batch, arrs[i * bs: i * bs + len(batch)])
 
         # Functional Test: Drop the last batch when specified
+        input_dp.reset()
         bs = 4
         batch_dp = input_dp.batch(batch_size=bs, drop_last=True)
         for i, batch in enumerate(batch_dp):
             self.assertEqual(batch, arrs[i * bs: i * bs + len(batch)])
 
         # __len__ test: verifying that the overall length and of each batch is correct
+        input_dp.reset()
         for i, batch in enumerate(batch_dp):
             self.assertEqual(len(batch), bs)
 
         # __len__ Test: the length is missing if the source DataPipe doesn't have length
+        input_dp.reset()
         self.assertEqual(len(batch_dp), 2)
         input_dp_nl = IDP_NoLen(range(10))
         batch_dp_nl = input_dp_nl.batch(batch_size=2)
         with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
             len(batch_dp_nl)
 
         # Reset Test: Ensures that the DataPipe can properly reset
+        input_dp.reset()
         n_elements_before_reset = 1
-        res_before_reset, res_after_reset = reset_after_n_next_calls(batch_dp, n_elements_before_reset)
+        res_before_reset, res_after_reset = reset_after_n_next_calls(batch_dp, n_elements_before_reset, input_dp.reset)
         self.assertEqual([[0, 1, 2, 3]], res_before_reset)
         self.assertEqual([[0, 1, 2, 3], [4, 5, 6, 7]], res_after_reset)
@@ -1301,32 +1383,38 @@ def test_unbatch_iterdatapipe(self):
         input_dp = prebatch_dp.batch(3)
         unbatch_dp = input_dp.unbatch()
-        self.assertEqual(len(list(unbatch_dp)), target_length)
-        for i, res in zip(prebatch_dp, unbatch_dp):
+        self.assertEqual(target_length, len(list(unbatch_dp)))
+        prebatch_dp.reset()
+        for i, res in zip(range(target_length), unbatch_dp):
             self.assertEqual(i, res)
 
         input_dp = dp.iter.IterableWrapper([[0, 1, 2], [3, 4, 5]])
         unbatch_dp = input_dp.unbatch()
         self.assertEqual(len(list(unbatch_dp)), target_length)
-        for i, res in zip(prebatch_dp, unbatch_dp):
+        input_dp.reset()
+        for i, res in zip(range(target_length), unbatch_dp):
             self.assertEqual(i, res)
 
         input_dp = dp.iter.IterableWrapper([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
-        unbatch_dp = input_dp.unbatch()
         expected_dp = [[0, 1], [2, 3], [4, 5], [6, 7]]
         self.assertEqual(len(list(unbatch_dp)), 4)
+        input_dp.reset()
         for i, res in zip(expected_dp, unbatch_dp):
             self.assertEqual(i, res)
 
+        input_dp.reset()
         unbatch_dp = input_dp.unbatch(unbatch_level=2)
         expected_dp2 = [0, 1, 2, 3, 4, 5, 6, 7]
         self.assertEqual(len(list(unbatch_dp)), 8)
+        input_dp.reset()
         for i, res in zip(expected_dp2, unbatch_dp):
             self.assertEqual(i, res)
 
+        input_dp.reset()
         unbatch_dp = input_dp.unbatch(unbatch_level=-1)
         self.assertEqual(len(list(unbatch_dp)), 8)
+        input_dp.reset()
         for i, res in zip(expected_dp2, unbatch_dp):
             self.assertEqual(i, res)
@@ -1338,6 +1426,7 @@ def test_unbatch_iterdatapipe(self):
         with self.assertRaises(IndexError):
             unbatch_dp = input_dp.unbatch(unbatch_level=5)
+            input_dp.reset()
             for i in unbatch_dp:
                 print(i)
@@ -1400,6 +1489,7 @@ def _filter_fn(data, val, clip=False):
             self.assertEqual(data, exp)
 
         # Functional Test: filter works with partial function with keyword args
+        input_ds.reset()
         filter_dp = input_ds.filter(partial(_filter_fn, val=5, clip=True))
         for data, exp in zip(filter_dp, range(5, 10)):
             self.assertEqual(data, exp)
@@ -1408,6 +1498,7 @@ def _non_bool_fn(data):
             return 1
 
         # Functional Test: filter function must return bool
+        input_ds.reset()
         filter_dp = input_ds.filter(filter_fn=_non_bool_fn)
         with self.assertRaises(ValueError):
             temp = list(filter_dp)
@@ -1417,9 +1508,10 @@ def _non_bool_fn(data):
             len(filter_dp)
 
         # Reset Test: DataPipe resets correctly
+        input_ds.reset()
         filter_dp = input_ds.filter(partial(_filter_fn, val=5, clip=True))
         n_elements_before_reset = 3
-        res_before_reset, res_after_reset = reset_after_n_next_calls(filter_dp, n_elements_before_reset)
+        res_before_reset, res_after_reset = reset_after_n_next_calls(filter_dp, n_elements_before_reset, input_ds.reset)
         self.assertEqual(list(range(5, 10))[:n_elements_before_reset], res_before_reset)
         self.assertEqual(list(range(5, 10)), res_after_reset)
@@ -1446,10 +1538,12 @@ def test_shuffle_iterdatapipe(self):
         with self.assertRaises(AssertionError):
             shuffle_dp = input_ds.shuffle(buffer_size=0)
 
+        input_ds.reset()
         for bs in (5, 20, 25):
             shuffle_dp = input_ds.shuffle(buffer_size=bs)
             self.assertEqual(len(shuffle_dp), len(input_ds))
 
+            input_ds.reset()
             random.seed(123)
             res = list(shuffle_dp)
             self.assertEqual(sorted(res), exp)
 
         # Test Deterministic
         for num_workers in (0, 1):
             random.seed(123)
+            input_ds.reset()
             dl = DataLoader(shuffle_dp, num_workers=num_workers, worker_init_fn=_worker_init_fn, shuffle=True)
             dl_res = list(dl)
             self.assertEqual(res, dl_res)
@@ -1475,10 +1570,14 @@ def test_zip_iterdatapipe(self):
 
         exp = list((i, i) for i in range(5))
         self.assertEqual(list(zipped_dp), exp)
 
-        zipped_dp = dp.iter.Zipper(dp.iter.IterableWrapper(range(10)), dp.iter.IterableWrapper(range(5)))
+        dp1 = dp.iter.IterableWrapper(range(10))
+        dp2 = dp.iter.IterableWrapper(range(5))
+        zipped_dp = dp.iter.Zipper(dp1, dp2)
         self.assertEqual(len(zipped_dp), 5)
         self.assertEqual(list(zipped_dp), exp)
 
         # Reset
+        dp1.reset()
+        dp2.reset()
         self.assertEqual(list(zipped_dp), exp)
 
diff --git a/torch/utils/data/datapipes/iter/utils.py b/torch/utils/data/datapipes/iter/utils.py
index 657d4934e509be..283a3564d77344 100644
--- a/torch/utils/data/datapipes/iter/utils.py
+++ b/torch/utils/data/datapipes/iter/utils.py
@@ -18,12 +18,28 @@ class IterableWrapperIterDataPipe(IterDataPipe):
     that data pipeline doesn't contain any in-place operations over
     the iterable instance, in order to prevent data inconsistency
     across iterations.
+
+    .. note::
+        DataLoader always materializes the wrapped iterable during serialization (i.e., whenever ``__getstate__`` is called).
     """
     def __init__(self, iterable, deepcopy=True):
         self.iterable = iterable
         self.deepcopy = deepcopy
+        self.state_counter = 0
+        self.iter = None
+        self.been_reset = False
 
     def __iter__(self):
+        if self.been_reset:
+            self.been_reset = False
+            return self.iter
+        elif self.iter is None:
+            self._create_iterator()
+            return self.iter
+        else:
+            raise RuntimeError(f"Only one iterator can exist for each {type(self).__name__} at a time.")
+
+    def _create_iterator(self) -> None:
         source_data = self.iterable
         if self.deepcopy:
             try:
@@ -34,11 +50,21 @@ def __iter__(self):
             # yield modified items.
             except TypeError:
                 warnings.warn(
-                    "The input iterable can not be deepcopied, "
+                    "The input iterable cannot be deep copied, "
                     "please be aware of in-place modification would affect source data."
                 )
-        for data in source_data:
-            yield data
+        self.iter = iter(source_data)
+
+    def __next__(self):
+        if self.iter is None:
+            self._create_iterator()
+        self.state_counter += 1
+        return next(self.iter)
+
+    def reset(self) -> None:
+        self.iter = None
+        self._create_iterator()
+        self.been_reset = True
 
     def __len__(self):
         return len(self.iterable)
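
---

Usage note (illustrative; not part of the patch): the sketch below assumes the patched
IterableWrapperIterDataPipe from the diff above, in which a wrapper owns at most one live
iterator at a time and ``reset()`` grants a fresh pass. The import follows the test file's
``import torch.utils.data.datapipes as dp`` convention.

    import torch.utils.data.datapipes as dp

    source_dp = dp.iter.IterableWrapper(range(5))

    it1 = iter(source_dp)   # First iterator; creating it deep-copies the input by default
    assert next(it1) == 0

    try:
        iter(source_dp)     # A second concurrent iterator is rejected by __iter__
    except RuntimeError as err:
        print(err)          # Only one iterator can exist for each IterableWrapperIterDataPipe at a time.

    source_dp.reset()       # An explicit reset prepares a fresh iterator...
    assert list(source_dp) == [0, 1, 2, 3, 4]   # ...and iteration restarts from the beginning

This mirrors how the updated tests call ``input_dp.reset()`` between test cases instead of
implicitly re-iterating the same wrapper.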