[DataPipe] Enforcing single valid iterator for IterDataPipes without … · pytorch/pytorch@7c52f20 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7c52f20

Browse files
NivekT authored and pytorchmergebot committed
[DataPipe] Enforcing single valid iterator for IterDataPipes without multiple outputs
Pull Request resolved: #70479 Approved by: https://github.com/ejguan
1 parent e0451d8 commit 7c52f20

File tree

4 files changed

+307
-65
lines changed

4 files changed

+307
-65
lines changed

test/test_datapipe.py

Lines changed: 193 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -235,16 +235,16 @@ def test_api(self):
235235
self.assertTrue(fd.closed)
236236

237237
def test_pickle(self):
238-
f = tempfile.TemporaryFile()
239-
with self.assertRaises(TypeError) as ctx1:
240-
pickle.dumps(f)
238+
with tempfile.TemporaryFile() as f:
239+
with self.assertRaises(TypeError) as ctx1:
240+
pickle.dumps(f)
241241

242-
wrap_f = StreamWrapper(f)
243-
with self.assertRaises(TypeError) as ctx2:
244-
pickle.dumps(wrap_f)
242+
wrap_f = StreamWrapper(f)
243+
with self.assertRaises(TypeError) as ctx2:
244+
pickle.dumps(wrap_f)
245245

246-
# Same exception when pickle
247-
self.assertEqual(str(ctx1.exception), str(ctx2.exception))
246+
# Same exception when pickle
247+
self.assertEqual(str(ctx1.exception), str(ctx2.exception))
248248

249249
fd = TestStreamWrapper._FakeFD("")
250250
wrap_fd = StreamWrapper(fd)
@@ -255,9 +255,9 @@ def test_repr(self):
255255
wrap_fd = StreamWrapper(fd)
256256
self.assertEqual(str(wrap_fd), "StreamWrapper<FakeFD>")
257257

258-
f = tempfile.TemporaryFile()
259-
wrap_f = StreamWrapper(f)
260-
self.assertEqual(str(wrap_f), "StreamWrapper<" + str(f) + ">")
258+
with tempfile.TemporaryFile() as f:
259+
wrap_f = StreamWrapper(f)
260+
self.assertEqual(str(wrap_f), "StreamWrapper<" + str(f) + ">")
261261

262262

263263
class TestIterableDataPipeBasic(TestCase):
@@ -568,7 +568,7 @@ def _fake_add(constant, data):
568568

569569

570570
def _fake_filter_fn(data):
571-
return data >= 5
571+
return True
572572

573573

574574
def _fake_filter_fn_constant(constant, data):
@@ -610,6 +610,7 @@ def _serialization_test_for_single_dp(self, dp, use_dill=False):
610610
_ = next(it)
611611
self._serialization_test_helper(dp, use_dill)
612612
# 3. Testing for serialization after DataPipe is fully read
613+
it = iter(dp)
613614
_ = list(it)
614615
self._serialization_test_helper(dp, use_dill)
615616

@@ -637,12 +638,12 @@ def test_serializable(self):
637638
(dp.iter.Batcher, None, (3, True,), {}),
638639
(dp.iter.Collator, None, (_fake_fn,), {}),
639640
(dp.iter.Concater, None, (dp.iter.IterableWrapper(range(5)),), {}),
640-
(dp.iter.Demultiplexer, None, (2, _fake_filter_fn), {}),
641+
# (dp.iter.Demultiplexer, None, (2, _fake_filter_fn), {}), # Temporarily disabled until next PR
641642
(dp.iter.FileLister, ".", (), {}),
642643
(dp.iter.FileOpener, None, (), {}),
643644
(dp.iter.Filter, None, (_fake_filter_fn,), {}),
644645
(dp.iter.Filter, None, (partial(_fake_filter_fn_constant, 5),), {}),
645-
(dp.iter.Forker, None, (2,), {}),
646+
# (dp.iter.Forker, None, (2,), {}), # Temporarily disabled until next PR
646647
(dp.iter.Grouper, None, (_fake_filter_fn,), {"group_size": 2}),
647648
(dp.iter.IterableWrapper, range(10), (), {}),
648649
(dp.iter.Mapper, None, (_fake_fn,), {}),
@@ -678,7 +679,7 @@ def test_serializable_with_dill(self):
678679
input_dp = dp.iter.IterableWrapper(range(10))
679680
unpicklable_datapipes: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [
680681
(dp.iter.Collator, (lambda x: x,), {}),
681-
(dp.iter.Demultiplexer, (2, lambda x: x % 2,), {}),
682+
# (dp.iter.Demultiplexer, (2, lambda x: x % 2,), {}), # Temporarily disabled until next PR
682683
(dp.iter.Filter, (lambda x: x >= 5,), {}),
683684
(dp.iter.Grouper, (lambda x: x >= 5,), {}),
684685
(dp.iter.Mapper, (lambda x: x,), {}),
@@ -850,6 +851,10 @@ def test_fork_iterdatapipe(self):
850851
i1 = iter(dp1) # Reset both all child DataPipe
851852
self.assertEqual(len(wa), 1)
852853
self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
854+
break
855+
for i, (n1, n2) in enumerate(zip(i1, i2)):
856+
output1.append(n1)
857+
output2.append(n2)
853858
self.assertEqual(list(range(5)) + list(range(10)), output1)
854859
self.assertEqual(list(range(5)) + list(range(10)), output2)
855860

@@ -1054,29 +1059,30 @@ def test_demux_iterdatapipe(self):
10541059
traverse(dp2) # This should not raise any error either
10551060

10561061
def test_map_iterdatapipe(self):
1057-
input_dp = dp.iter.IterableWrapper(range(10))
1062+
target_length = 10
1063+
input_dp = dp.iter.IterableWrapper(range(target_length))
10581064

10591065
def fn(item, dtype=torch.float, *, sum=False):
10601066
data = torch.tensor(item, dtype=dtype)
10611067
return data if not sum else data.sum()
10621068

10631069
# Functional Test: apply to each element correctly
10641070
map_dp = input_dp.map(fn)
1065-
self.assertEqual(len(input_dp), len(map_dp))
1066-
for x, y in zip(map_dp, input_dp):
1071+
self.assertEqual(target_length, len(map_dp))
1072+
for x, y in zip(map_dp, range(target_length)):
10671073
self.assertEqual(x, torch.tensor(y, dtype=torch.float))
10681074

10691075
# Functional Test: works with partial function
10701076
map_dp = input_dp.map(partial(fn, dtype=torch.int, sum=True))
1071-
for x, y in zip(map_dp, input_dp):
1077+
for x, y in zip(map_dp, range(target_length)):
10721078
self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())
10731079

10741080
# __len__ Test: inherits length from source DataPipe
1075-
self.assertEqual(len(input_dp), len(map_dp))
1081+
self.assertEqual(target_length, len(map_dp))
10761082

1077-
input_dp_nl = IDP_NoLen(range(10))
1083+
input_dp_nl = IDP_NoLen(range(target_length))
10781084
map_dp_nl = input_dp_nl.map(lambda x: x)
1079-
for x, y in zip(map_dp_nl, input_dp_nl):
1085+
for x, y in zip(map_dp_nl, range(target_length)):
10801086
self.assertEqual(x, torch.tensor(y, dtype=torch.float))
10811087

10821088
# __len__ Test: inherits length from source DataPipe - raises error when invalid
@@ -1228,24 +1234,24 @@ def _collate_fn(batch, default_type=torch.float):
12281234

12291235
# Functional Test: defaults to the default collate function when a custom one is not specified
12301236
collate_dp = input_dp.collate()
1231-
for x, y in zip(input_dp, collate_dp):
1237+
for x, y in zip(arrs, collate_dp):
12321238
self.assertEqual(torch.tensor(x), y)
12331239

12341240
# Functional Test: custom collate function
12351241
collate_dp = input_dp.collate(collate_fn=_collate_fn)
1236-
for x, y in zip(input_dp, collate_dp):
1242+
for x, y in zip(arrs, collate_dp):
12371243
self.assertEqual(torch.tensor(sum(x), dtype=torch.float), y)
12381244

12391245
# Functional Test: custom, partial collate function
12401246
collate_dp = input_dp.collate(partial(_collate_fn, default_type=torch.int))
1241-
for x, y in zip(input_dp, collate_dp):
1247+
for x, y in zip(arrs, collate_dp):
12421248
self.assertEqual(torch.tensor(sum(x), dtype=torch.int), y)
12431249

12441250
# Reset Test: reset the DataPipe and results are still correct
12451251
n_elements_before_reset = 1
12461252
res_before_reset, res_after_reset = reset_after_n_next_calls(collate_dp, n_elements_before_reset)
12471253
self.assertEqual([torch.tensor(6, dtype=torch.int)], res_before_reset)
1248-
for x, y in zip(input_dp, res_after_reset):
1254+
for x, y in zip(arrs, res_after_reset):
12491255
self.assertEqual(torch.tensor(sum(x), dtype=torch.int), y)
12501256

12511257
# __len__ Test: __len__ is inherited
@@ -1256,7 +1262,7 @@ def _collate_fn(batch, default_type=torch.float):
12561262
collate_dp_nl = input_dp_nl.collate()
12571263
with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
12581264
len(collate_dp_nl)
1259-
for x, y in zip(input_dp_nl, collate_dp_nl):
1265+
for x, y in zip(arrs, collate_dp_nl):
12601266
self.assertEqual(torch.tensor(x), y)
12611267

12621268
def test_batch_iterdatapipe(self):
@@ -1306,14 +1312,14 @@ def test_unbatch_iterdatapipe(self):
13061312
input_dp = prebatch_dp.batch(3)
13071313
unbatch_dp = input_dp.unbatch()
13081314
self.assertEqual(len(list(unbatch_dp)), target_length) # __len__ is as expected
1309-
for i, res in zip(prebatch_dp, unbatch_dp):
1315+
for i, res in zip(range(target_length), unbatch_dp):
13101316
self.assertEqual(i, res)
13111317

13121318
# Functional Test: unbatch works for an input with nested levels
13131319
input_dp = dp.iter.IterableWrapper([[0, 1, 2], [3, 4, 5]])
13141320
unbatch_dp = input_dp.unbatch()
13151321
self.assertEqual(len(list(unbatch_dp)), target_length)
1316-
for i, res in zip(prebatch_dp, unbatch_dp):
1322+
for i, res in zip(range(target_length), unbatch_dp):
13171323
self.assertEqual(i, res)
13181324

13191325
input_dp = dp.iter.IterableWrapper([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
@@ -1322,8 +1328,8 @@ def test_unbatch_iterdatapipe(self):
13221328
unbatch_dp = input_dp.unbatch()
13231329
expected_dp = [[0, 1], [2, 3], [4, 5], [6, 7]]
13241330
self.assertEqual(len(list(unbatch_dp)), 4)
1325-
for i, res in zip(expected_dp, unbatch_dp):
1326-
self.assertEqual(i, res)
1331+
for j, res in zip(expected_dp, unbatch_dp):
1332+
self.assertEqual(j, res)
13271333

13281334
# Functional Test: unbatching multiple levels at the same time
13291335
unbatch_dp = input_dp.unbatch(unbatch_level=2)
@@ -2290,5 +2296,161 @@ def test_old_dataloader(self):
22902296
self.assertEqual(sorted(expected), sorted(items))
22912297

22922298

2299+
class TestIterDataPipeSingletonConstraint(TestCase):
2300+
2301+
r"""
2302+
Each `IterDataPipe` can only have one active iterator. Whenever a new iterator is created, older
2303+
iterators are invalidated. These tests aim to ensure `IterDataPipe` follows this behavior.
2304+
"""
2305+
2306+
def _check_single_iterator_invalidation_logic(self, source_dp: IterDataPipe):
2307+
r"""
2308+
Given a IterDataPipe, verifies that the iterator can be read, reset, and the creation of
2309+
a second iterator invalidates the first one.
2310+
"""
2311+
it1 = iter(source_dp)
2312+
self.assertEqual(list(range(10)), list(it1))
2313+
it1 = iter(source_dp)
2314+
self.assertEqual(list(range(10)), list(it1)) # A fresh iterator can be read in full again
2315+
it1 = iter(source_dp)
2316+
self.assertEqual(0, next(it1))
2317+
it2 = iter(source_dp) # This should invalidate `it1`
2318+
self.assertEqual(0, next(it2)) # Should read from the beginning again
2319+
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
2320+
next(it1)
2321+
2322+
2323+
def test_iterdatapipe_singleton_generator(self):
2324+
r"""
2325+
Testing for the case where IterDataPipe's `__iter__` is a generator function.
2326+
"""
2327+
2328+
# Functional Test: Check if invalidation logic is correct
2329+
source_dp: IterDataPipe = dp.iter.IterableWrapper(range(10))
2330+
self._check_single_iterator_invalidation_logic(source_dp)
2331+
2332+
# Functional Test: extend the test to a pipeline
2333+
dps = source_dp.map(_fake_fn).filter(_fake_filter_fn)
2334+
self._check_single_iterator_invalidation_logic(dps)
2335+
2336+
# Functional Test: multiple simultaneous references to the same DataPipe fails
2337+
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
2338+
for _ in zip(source_dp, source_dp):
2339+
pass
2340+
2341+
# Function Test: sequential references work
2342+
for _ in zip(list(source_dp), list(source_dp)):
2343+
pass
2344+
2345+
def test_iterdatapipe_singleton_self_next(self):
2346+
r"""
2347+
Testing for the case where IterDataPipe's `__iter__` returns `self` and there is a `__next__` method
2348+
Note that the following DataPipe by is singleton by default (because `__iter__` returns `self`).
2349+
"""
2350+
class _CustomIterDP_Self(IterDataPipe):
2351+
def __init__(self, iterable):
2352+
self.source = iterable
2353+
self.iterable = iter(iterable)
2354+
2355+
def __iter__(self):
2356+
self.reset()
2357+
return self
2358+
2359+
def __next__(self):
2360+
return next(self.iterable)
2361+
2362+
def reset(self):
2363+
self.iterable = iter(self.source)
2364+
2365+
# Functional Test: Check that every `__iter__` call returns the same object
2366+
source_dp = _CustomIterDP_Self(range(10))
2367+
res = list(source_dp)
2368+
it = iter(source_dp)
2369+
self.assertEqual(res, list(it))
2370+
2371+
# Functional Test: Check if invalidation logic is correct
2372+
source_dp = _CustomIterDP_Self(range(10))
2373+
self._check_single_iterator_invalidation_logic(source_dp)
2374+
self.assertEqual(1, next(source_dp)) # `source_dp` is still valid and can be read
2375+
2376+
# Functional Test: extend the test to a pipeline
2377+
source_dp = _CustomIterDP_Self(dp.iter.IterableWrapper(range(10)).map(_fake_fn).filter(_fake_filter_fn))
2378+
self._check_single_iterator_invalidation_logic(source_dp)
2379+
self.assertEqual(1, next(source_dp)) # `source_dp` is still valid and can be read
2380+
2381+
# Functional Test: multiple simultaneous references to the same DataPipe fails
2382+
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
2383+
for _ in zip(source_dp, source_dp):
2384+
pass
2385+
2386+
def test_iterdatapipe_singleton_new_object(self):
2387+
r"""
2388+
Testing for the case where IterDataPipe's `__iter__` isn't a generator nor returns `self`,
2389+
and there isn't a `__next__` method.
2390+
"""
2391+
class _CustomIterDP(IterDataPipe):
2392+
def __init__(self, iterable):
2393+
self.iterable = iter(iterable)
2394+
2395+
def __iter__(self): # Note that this doesn't reset
2396+
return self.iterable # Intentionally not returning `self`
2397+
2398+
# Functional Test: Check if invalidation logic is correct
2399+
source_dp = _CustomIterDP(range(10))
2400+
it1 = iter(source_dp)
2401+
self.assertEqual(0, next(it1))
2402+
it2 = iter(source_dp)
2403+
self.assertEqual(1, next(it2))
2404+
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
2405+
next(it1)
2406+
2407+
# Functional Test: extend the test to a pipeline
2408+
source_dp = _CustomIterDP(dp.iter.IterableWrapper(range(10)).map(_fake_fn).filter(_fake_filter_fn))
2409+
it1 = iter(source_dp)
2410+
self.assertEqual(0, next(it1))
2411+
it2 = iter(source_dp)
2412+
self.assertEqual(1, next(it2))
2413+
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
2414+
next(it1)
2415+
2416+
# Functional Test: multiple simultaneous references to the same DataPipe fails
2417+
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
2418+
for _ in zip(source_dp, source_dp):
2419+
pass
2420+
2421+
def test_iterdatapipe_singleton_buggy(self):
2422+
r"""
2423+
Buggy test case case where IterDataPipe's `__iter__` returns a new object, but also has
2424+
a `__next__` method.
2425+
"""
2426+
class _CustomIterDP(IterDataPipe):
2427+
def __init__(self, iterable):
2428+
self.source = iterable
2429+
self.iterable = iter(iterable)
2430+
2431+
def __iter__(self):
2432+
return iter(self.source) # Intentionally not returning `self`
2433+
2434+
def __next__(self):
2435+
return next(self.iterable)
2436+
2437+
# Functional Test: Check if invalidation logic is correct
2438+
source_dp = _CustomIterDP(range(10))
2439+
self._check_single_iterator_invalidation_logic(source_dp)
2440+
self.assertEqual(0, next(source_dp)) # `__next__` is unrelated with `__iter__`
2441+
2442+
# Functional Test: Special case to show `__next__` is unrelated with `__iter__`
2443+
source_dp = _CustomIterDP(range(10))
2444+
self.assertEqual(0, next(source_dp))
2445+
it1 = iter(source_dp)
2446+
self.assertEqual(0, next(it1))
2447+
self.assertEqual(1, next(source_dp))
2448+
it2 = iter(source_dp) # invalidates both `it1`
2449+
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
2450+
next(it1)
2451+
self.assertEqual(2, next(source_dp)) # not impacted by the creation of `it2`
2452+
self.assertEqual(list(range(10)), list(it2)) # `it2` still works because it is a new object
2453+
2454+
22932455
if __name__ == '__main__':
22942456
run_tests()

test/test_profiler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,10 +153,10 @@ def __next__(self):
153153
def get_value(self, idx):
154154
return self.data[idx]
155155

156-
dp1 = IDPIterator()
156+
dp1 = IDPIterator() # The object itself is an iterator
157157
self.assertEqual(5, dp1.get_value(5))
158-
it_dp1 = iter(dp1)
159-
self.assertEqual(5, it_dp1.get_value(5))
158+
it_dp1 = iter(dp1) # This creates the 1st iterator
159+
self.assertEqual(5, it_dp1.get_value(5)) # type: ignore[attr-defined]
160160
self.assertEqual(list(range(10)), list(it_dp1))
161161

162162
class IDPDelegator(torch.utils.data.IterDataPipe):

0 commit comments

Comments (0)