diff --git a/test/test_datapipe.py b/test/test_datapipe.py index c0cd54dae055ea..fb7a6dfb9618e4 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -129,6 +129,21 @@ def reset_after_n_next_calls(datapipe: Union[IterDataPipe[T_co], MapDataPipe[T_c return res_before_reset, list(datapipe) +def snapshot_test_helper(source_datapipe: IterDataPipe, n_elements_to_advance: int) -> IterDataPipe: + """ + Given an IterDataPipe and an integer `n`, advance the source_datapipe `n` times, + then get the state and snapshot from the DataPipe, recreates it by passing the + state and snapshot into the given new_datapipe and returns it. + """ + p = pickle.dumps(source_datapipe) + for _ in range(n_elements_to_advance): + next(source_datapipe) + snapshot = source_datapipe.save_snapshot() + new_datapipe = pickle.loads(p) + new_datapipe.restore_snapshot(snapshot) + return new_datapipe + + def odd_or_even(x: int) -> int: return x % 2 @@ -828,6 +843,13 @@ def test_iterable_wrapper_datapipe(self): # __len__ Test: inherits length from sequence self.assertEqual(len(input_ls), len(input_dp)) + # Snapshot Test: + input_dp = dp.iter.IterableWrapper(input_ls) + n_elements_to_advance = 5 + new_dp = snapshot_test_helper(input_dp, n_elements_to_advance) + for old_ele, new_ele in zip(input_dp, new_dp): + self.assertEqual(old_ele, new_ele) + def test_concat_iterdatapipe(self): input_dp1 = dp.iter.IterableWrapper(range(10)) input_dp2 = dp.iter.IterableWrapper(range(5))