8000 [DataPipe] Enforcing single iterator per IterableWrapperIterDataPipe … · pytorch/pytorch@9d730cb · GitHub
[go: up one dir, main page]

Skip to content

Commit 9d730cb

Browse files
NivekTfacebook-github-bot
authored andcommitted
[DataPipe] Enforcing single iterator per IterableWrapperIterDataPipe (#70479)
Summary: Pull Request resolved: #70479 Test Plan: Imported from OSS Reviewed By: b0noI, cpuhrsch Differential Revision: D33344609 fbshipit-source-id: c3eaee4b684890fc5dd1f0a2c6d04e718c236a7b
1 parent 7baea4a commit 9d730cb

File tree

4 files changed

+307
-65
lines changed

4 files changed

+307
-65
lines changed

test/test_datapipe.py

Lines changed: 193 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -235,16 +235,16 @@ def test_api(self):
235235
self.assertTrue(fd.closed)
236236

237237
def test_pickle(self):
238-
f = tempfile.TemporaryFile()
239-
with self.assertRaises(TypeError) as ctx1:
240-
pickle.dumps(f)
238+
with tempfile.TemporaryFile() as f:
239+
with self.assertRaises(TypeError) as ctx1:
240+
pickle.dumps(f)
241241

242-
wrap_f = StreamWrapper(f)
243-
with self.assertRaises(TypeError) as ctx2:
244-
pickle.dumps(wrap_f)
242+
wrap_f = StreamWrapper(f)
243+
with self.assertRaises(TypeError) as ctx2:
244+
pickle.dumps(wrap_f)
245245

246-
# Same exception when pickle
247-
self.assertEqual(str(ctx1.exception), str(ctx2.exception))
246+
# Same exception when pickle
247+
self.assertEqual(str(ctx1.exception), str(ctx2.exception))
248248

249249
fd = TestStreamWrapper._FakeFD("")
250250
wrap_fd = StreamWrapper(fd)
@@ -255,9 +255,9 @@ def test_repr(self):
255255
wrap_fd = StreamWrapper(fd)
256256
self.assertEqual(str(wrap_fd), "StreamWrapper<FakeFD>")
257257

258-
f = tempfile.TemporaryFile()
259-
wrap_f = StreamWrapper(f)
260-
self.assertEqual(str(wrap_f), "StreamWrapper<" + str(f) + ">")
258+
with tempfile.TemporaryFile() as f:
259+
wrap_f = StreamWrapper(f)
260+
self.assertEqual(str(wrap_f), "StreamWrapper<" + str(f) + ">")
261261

262262

263263
class TestIterableDataPipeBasic(TestCase):
@@ -568,7 +568,7 @@ def _fake_add(constant, data):
568568

569569

570570
def _fake_filter_fn(data):
571-
return data >= 5
571+
return True
572572

573573

574574
def _fake_filter_fn_constant(constant, data):
@@ -610,6 +610,7 @@ def _serialization_test_for_single_dp(self, dp, use_dill=False):
610610
_ = next(it)
611611
self._serialization_test_helper(dp, use_dill)
612612
# 3. Testing for serialization after DataPipe is fully read
613+
it = iter(dp)
613614
_ = list(it)
614615
self._serialization_test_helper(dp, use_dill)
615616

@@ -637,12 +638,12 @@ def test_serializable(self):
637638
(dp.iter.Batcher, None, (3, True,), {}),
638639
(dp.iter.Collator, None, (_fake_fn,), {}),
639640
(dp.iter.Concater, None, (dp.iter.IterableWrapper(range(5)),), {}),
640-
(dp.iter.Demultiplexer, None, (2, _fake_filter_fn), {}),
641+
# (dp.iter.Demultiplexer, None, (2, _fake_filter_fn), {}), # Temporarily disabled until next PR
641642
(dp.iter.FileLister, ".", (), {}),
642643
(dp.iter.FileOpener, None, (), {}),
643644
(dp.iter.Filter, None, (_fake_filter_fn,), {}),
644645
(dp.iter.Filter, None, (partial(_fake_filter_fn_constant, 5),), {}),
645-
(dp.iter.Forker, None, (2,), {}),
646+
# (dp.iter.Forker, None, (2,), {}), # Temporarily disabled until next PR
646647
(dp.iter.Grouper, None, (_fake_filter_fn,), {"group_size": 2}),
647648
(dp.iter.IterableWrapper, range(10), (), {}),
648649
(dp.iter.Mapper, None, (_fake_fn,), {}),
@@ -678,7 +679,7 @@ def test_serializable_with_dill(self):
678679
input_dp = dp.iter.IterableWrapper(range(10))
679680
unpicklable_datapipes: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [
680681
(dp.iter.Collator, (lambda x: x,), {}),
681-
(dp.iter.Demultiplexer, (2, lambda x: x % 2,), {}),
682+
# (dp.iter.Demultiplexer, (2, lambda x: x % 2,), {}), # Temporarily disabled until next PR
682683
(dp.iter.Filter, (lambda x: x >= 5,), {}),
683684
(dp.iter.Grouper, (lambda x: x >= 5,), {}),
684685
(dp.iter.Mapper, (lambda x: x,), {}),
@@ -850,6 +851,10 @@ def test_fork_iterdatapipe(self):
850851
i1 = iter(dp1) # Reset both all child DataPipe
851852
self.assertEqual(len(wa), 1)
852853
self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
854+
break
855+
for i, (n1, n2) in enumerate(zip(i1, i2)):
856+
output1.append(n1)
857+
output2.append(n2)
853858
self.assertEqual(list(range(5)) + list(range(10)), output1)
854859
self.assertEqual(list(range(5)) + list(range(10)), output2)
855860

@@ -1054,29 +1059,30 @@ def test_demux_iterdatapipe(self):
10541059
traverse(dp2) # This should not raise any error either
10551060

10561061
def test_map_iterdatapipe(self):
1057-
input_dp = dp.iter.IterableWrapper(range(10))
1062+
target_length = 10
1063+
input_dp = dp.iter.IterableWrapper(range(target_length))
10581064

10591065
def fn(item, dtype=torch.float, *, sum=False):
10601066
data = torch.tensor(item, dtype=dtype)
10611067
return data if not sum else data.sum()
10621068

10631069
# Functional Test: apply to each element correctly
10641070
map_dp = input_dp.map(fn)
1065-
self.assertEqual(len(input_dp), len(map_dp))
1066-
for x, y in zip(map_dp, input_dp):
1071+
self.assertEqual(target_length, len(map_dp))
1072+
for x, y in zip(map_dp, range(target_length)):
10671073
self.assertEqual(x, torch.tensor(y, dtype=torch.float))
10681074

10691075
# Functional Test: works with partial function
10701076
map_dp = input_dp.map(partial(fn, dtype=torch.int, sum=True))
1071-
for x, y in zip(map_dp, input_dp):
1077+
for x, y in zip(map_dp, range(target_length)):
10721078
self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())
10731079

10741080
# __len__ Test: inherits length from source DataPipe
1075-
self.assertEqual(len(input_dp), len(map_dp))
1081+
self.assertEqual(target_length, len(map_dp))
10761082

1077-
input_dp_nl = IDP_NoLen(range(10))
1083+
input_dp_nl = IDP_NoLen(range(target_length))
10781084
map_dp_nl = input_dp_nl.map(lambda x: x)
1079-
for x, y in zip(map_dp_nl, input_dp_nl):
1085+
for x, y in zip(map_dp_nl, range(target_length)):
10801086
self.assertEqual(x, torch.tensor(y, dtype=torch.float))
10811087

10821088
# __len__ Test: inherits length from source DataPipe - raises error when invalid
@@ -1228,24 +1234,24 @@ def _collate_fn(batch, default_type=torch.float):
12281234

12291235
# Functional Test: defaults to the default collate function when a custom one is not specified
12301236
collate_dp = input_dp.collate()
1231-
for x, y in zip(input_dp, collate_dp):
1237+
for x, y in zip(arrs, collate_dp):
12321238
self.assertEqual(torch.tensor(x), y)
12331239

12341240
# Functional Test: custom collate function
12351241
collate_dp = input_dp.collate(collate_fn=_collate_fn)
1236-
for x, y in zip(input_dp, collate_dp):
1242+
for x, y in zip(arrs, collate_dp):
12371243
self.assertEqual(torch.tensor(sum(x), dtype=torch.float), y)
12381244

12391245
# Functional Test: custom, partial collate function
12401246
collate_dp = input_dp.collate(partial(_collate_fn, default_type=torch.int))
1241-
for x, y in zip(input_dp, collate_dp):
1247+
for x, y in zip(arrs, collate_dp):
12421248
self.assertEqual(torch.tensor(sum(x), dtype=torch.int), y)
12431249

12441250
# Reset Test: reset the DataPipe and results are still correct
12451251
n_elements_before_reset = 1
12461252
res_before_reset, res_after_reset = reset_after_n_next_calls(collate_dp, n_elements_before_reset)
12471253
self.assertEqual([torch.tensor(6, dtype=torch.int)], res_before_reset)
1248-
for x, y in zip(input_dp, res_after_reset):
1254+
for x, y in zip(arrs, res_after_reset):
12491255
self.assertEqual(torch.tensor(sum(x), dtype=torch.int), y)
12501256

12511257
# __len__ Test: __len__ is inherited
@@ -1256,7 +1262,7 @@ def _collate_fn(batch, default_type=torch.float):
12561262
collate_dp_nl = input_dp_nl.collate()
12571263
with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
12581264
len(collate_dp_nl)
1259-
for x, y in zip(input_dp_nl, collate_dp_nl):
1265+
for x, y in zip(arrs, collate_dp_nl):
12601266
self.assertEqual(torch.tensor(x), y)
12611267

12621268
def test_batch_iterdatapipe(self):
@@ -1306,14 +1312,14 @@ def test_unbatch_iterdatapipe(self):
13061312
input_dp = prebatch_dp.batch(3)
13071313
unbatch_dp = input_dp.unbatch()
13081314
self.assertEqual(len(list(unbatch_dp)), target_length) # __len__ is as expected
1309-
for i, res in zip(prebatch_dp, unbatch_dp):
1315+
for i, res in zip(range(target_length), unbatch_dp):
13101316
self.assertEqual(i, res)
13111317

13121318
# Functional Test: unbatch works for an input with nested levels
13131319
input_dp = dp.iter.IterableWrapper([[0, 1, 2], [3, 4, 5]])
13141320
unbatch_dp = input_dp.unbatch()
13151321
self.assertEqual(len(list(unbatch_dp)), target_length)
1316-
for i, res in zip(prebatch_dp, unbatch_dp):
1322+
for i, res in zip(range(target_length), unbatch_dp):
13171323
self.assertEqual(i, res)
13181324

13191325
input_dp = dp.iter.IterableWrapper([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
@@ -1322,8 +1328,8 @@ def test_unbatch_iterdatapipe(self):
13221328
unbatch_dp = input_dp.unbatch()
13231329
expected_dp = [[0, 1], [2, 3], [4, 5], [6, 7]]
13241330
self.assertEqual(len(list(unbatch_dp)), 4)
1325-
for i, res in zip(expected_dp, unbatch_dp):
1326-
self.assertEqual(i, res)
1331+
for j, res in zip(expected_dp, unbatch_dp):
1332+
self.assertEqual(j, res)
13271333

13281334
# Functional Test: unbatching multiple levels at the same time
13291335
unbatch_dp = input_dp.unbatch(unbatch_level=2)
@@ -2289,5 +2295,161 @@ def test_old_dataloader(self):
22892295
self.assertEqual(sorted(expected), sorted(items))
22902296

22912297

2298+
class TestIterDataPipeSingletonConstraint(TestCase):
2299+
2300+
r"""
2301+
Each `IterDataPipe` can only have one active iterator. Whenever a new iterator is created, older
2302+
iterators are invalidated. These tests aim to ensure `IterDataPipe` follows this behavior.
2303+
"""
2304+
2305+
def _check_single_iterator_invalidation_logic(self, source_dp: IterDataPipe):
2306+
r"""
2307+
Given a IterDataPipe, verifies that the iterator can be read, reset, and the creation of
2308+
a second iterator invalidates the first one.
2309+
"""
2310+
it1 = iter(source_dp)
2311+
self.assertEqual(list(range(10)), list(it1))
2312+
it1 = iter(source_dp)
2313+
self.assertEqual(list(range(10)), list(it1)) # A fresh iterator can be read in full again
2314+
it1 = iter(source_dp)
2315+
self.assertEqual(0, next(it1))
2316+
it2 = iter(source_dp) # This should invalidate `it1`
2317+
self.assertEqual(0, next(it2)) # Should read from the beginning again
2318+
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
2319+
next(it1)
2320+
2321+
2322+
def test_iterdatapipe_singleton_generator(self):
2323+
r"""
2324+
Testing for the case where IterDataPipe's `__iter__` is a generator function.
2325+
"""
2326+
2327+
# Functional Test: Check if invalidation logic is correct
2328+
source_dp: IterDataPipe = dp.iter.IterableWrapper(range(10))
2329+
self._check_single_iterator_invalidation_logic(source_dp)
2330+
2331+
# Functional Test: extend the test to a pipeline
2332+
dps = source_dp.map(_fake_fn).filter(_fake_filter_fn)
2333+
self._check_single_iterator_invalidation_logic(dps)
2334+
2335+
# Functional Test: multiple simultaneous references to the same DataPipe fails
2336+
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
2337+
for _ in zip(source_dp, source_dp):
2338+
pass
2339+
2340+
# Function Test: sequential references work
2341+
for _ in zip(list(source_dp), list(source_dp)):
2342+
pass
2343+
2344+
def test_iterdatapipe_singleton_self_next(self):
2345+
r"""
2346+
Testing for the case where IterDataPipe's `__iter__` returns `self` and there is a `__next__` method
2347+
Note that the following DataPipe by is singleton by default (because `__iter__` returns `self`).
2348+
"""
2349+
class _CustomIterDP_Self(IterDataPipe):
2350+
def __init__(self, iterable):
2351+
self.source = iterable
2352+
self.iterable = iter(iterable)
2353+
2354+
def __iter__(self):
2355+
self.reset()
2356+
return self
2357+
2358+
def __next__(self):
2359+
return next(self.iterable)
2360+
2361+
def reset(self):
2362+
self.iterable = iter(self.source)
2363+
2364+
# Functional Test: Check that every `__iter__` call returns the same object
2365+
source_dp = _CustomIterDP_Self(range(10))
2366+
res = list(source_dp)
2367+
it = iter(source_dp)
2368+
self.assertEqual(res, list(it))
2369+
2370+
# Functional Test: Check if invalidation logic is correct
2371+
source_dp = _CustomIterDP_Self(range(10))
2372+
self._check_single_iterator_invalidation_logic(source_dp)
2373+
self.assertEqual(1, next(source_dp)) # `source_dp` is still valid and can be read
2374+
2375+
# Functional Test: extend the test to a pipeline
2376+
source_dp = _CustomIterDP_Self(dp.iter.IterableWrapper(range(10)).map(_fake_fn).filter(_fake_filter_fn))
2377+
self._check_single_iterator_invalidation_logic(source_dp)
2378+
self.assertEqual(1, next(source_dp)) # `source_dp` is still valid and can be read
2379+
2380+
# Functional Test: multiple simultaneous references to the same DataPipe fails
2381+
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
2382+
for _ in zip(source_dp, source_dp):
2383+
pass
2384+
2385+
def test_iterdatapipe_singleton_new_object(self):
2386+
r"""
2387+
Testing for the case where IterDataPipe's `__iter__` isn't a generator nor returns `self`,
2388+
and there isn't a `__next__` method.
2389+
"""
2390+
class _CustomIterDP(IterDataPipe):
2391+
def __init__(self, iterable):
2392+
self.iterable = iter(iterable)
2393+
2394+
def __iter__(self): # Note that this doesn't reset
2395+
return self.iterable # Intentionally not returning `self`
2396+
2397+
# Functional Test: Check if invalidation logic is correct
2398+
source_dp = _CustomIterDP(range(10))
2399+
it1 = iter(source_dp)
2400+
self.assertEqual(0, next(it1))
2401+
it2 = iter(source_dp)
2402+
self.assertEqual(1, next(it2))
2403+
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
2404+
next(it1)
2405+
2406+
# Functional Test: extend the test to a pipeline
2407+
source_dp = _CustomIterDP(dp.iter.IterableWrapper(range(10)).map(_fake_fn).filter(_fake_filter_fn))
2408+
it1 = iter(source_dp)
2409+
self.assertEqual(0, next(it1))
2410+
it2 = iter(source_dp)
2411+
self.assertEqual(1, next(it2))
2412+
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
2413+
next(it1)
2414+
2415+
# Functional Test: multiple simultaneous references to the same DataPipe fails
2416+
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
2417+
for _ in zip(source_dp, source_dp):
2418+
pass
2419+
2420+
def test_iterdatapipe_singleton_buggy(self):
2421+
r"""
2422+
Buggy test case case where IterDataPipe's `__iter__` returns a new object, but also has
2423+
a 10000 `__next__` method.
2424+
"""
2425+
class _CustomIterDP(IterDataPipe):
2426+
def __init__(self, iterable):
2427+
self.source = iterable
2428+
self.iterable = iter(iterable)
2429+
2430+
def __iter__(self):
2431+
return iter(self.source) # Intentionally not returning `self`
2432+
2433+
def __next__(self):
2434+
return next(self.iterable)
2435+
2436+
# Functional Test: Check if invalidation logic is correct
2437+
source_dp = _CustomIterDP(range(10))
2438+
self._check_single_iterator_invalidation_logic(source_dp)
2439+
self.assertEqual(0, next(source_dp)) # `__next__` is unrelated with `__iter__`
2440+
2441+
# Functional Test: Special case to show `__next__` is unrelated with `__iter__`
2442+
source_dp = _CustomIterDP(range(10))
2443+
self.assertEqual(0, next(source_dp))
2444+
it1 = iter(source_dp)
2445+
self.assertEqual(0, next(it1))
2446+
self.assertEqual(1, next(source_dp))
2447+
it2 = iter(source_dp) # invalidates both `it1`
2448+
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
2449+
next(it1)
2450+
self.assertEqual(2, next(source_dp)) # not impacted by the creation of `it2`
2451+
self.assertEqual(list(range(10)), list(it2)) # `it2` still works because it is a new object
2452+
2453+
22922454
if __name__ == '__main__':
22932455
run_tests()

test/test_profiler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,10 +153,10 @@ def __next__(self):
153153
def get_value(self, idx):
154154
return self.data[idx]
155155

156-
dp1 = IDPIterator()
156+
dp1 = IDPIterator() # The object itself is an iterator
157157
self.assertEqual(5, dp1.get_value(5))
158-
it_dp1 = iter(dp1)
159-
self.assertEqual(5, it_dp1.get_value(5))
158+
it_dp1 = iter(dp1) # This creates the 1st iterator
159+
self.assertEqual(5, it_dp1.get_value(5)) # type: ignore[attr-defined]
160160
self.assertEqual(list(range(10)), list(it_dp1))
161161

162162
class IDPDelegator(torch.utils.data.IterDataPipe):

0 commit comments

Comments
 (0)
0