[DataPipe] Enforcing single valid iterator for IterDataPipes with single DataPipe as output by NivekT · Pull Request #70479 · pytorch/pytorch · GitHub
[DataPipe] Enforcing single valid iterator for IterDataPipes with single DataPipe as output #70479


Closed
wants to merge 19 commits
Changes from all commits

19 commits:
8bd3627
[DataPipe] Enforcing single iterator per IterableWrapperIterDataPipe
NivekT Dec 29, 2021
4dddb35
Update on "[DataPipe] Enforcing single iterator per IterableWrapperIt…
NivekT Dec 29, 2021
57b5c09
Update on "[DataPipe] Enforcing single iterator per IterableWrapperIt…
NivekT Jan 18, 2022
289fe02
Update on "[DataPipe] Enforcing single iterator per IterableWrapperIt…
NivekT Jan 19, 2022
5961fa3
Update on "[DataPipe] Enforcing single iterator per IterableWrapperIt…
NivekT Apr 15, 2022
f97581a
Update on "[DataPipe] Enforcing single iterator per IterableWrapperIt…
NivekT Apr 18, 2022
aa01181
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT Apr 21, 2022
8b1e0d5
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT Apr 28, 2022
c9f3462
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT May 3, 2022
54cc729
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT May 3, 2022
dbf298e
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT May 5, 2022
2c5e0e7
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT May 5, 2022
91512c4
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT May 6, 2022
87bba0f
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT May 9, 2022
3c39423
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT May 12, 2022
7b75f5e
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT May 13, 2022
c77df3e
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT May 16, 2022
9824b8a
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT May 17, 2022
db9862f
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT May 17, 2022
224 changes: 193 additions & 31 deletions test/test_datapipe.py
@@ -235,16 +235,16 @@ def test_api(self):
self.assertTrue(fd.closed)

def test_pickle(self):
f = tempfile.TemporaryFile()
with self.assertRaises(TypeError) as ctx1:
pickle.dumps(f)
with tempfile.TemporaryFile() as f:
with self.assertRaises(TypeError) as ctx1:
pickle.dumps(f)

wrap_f = StreamWrapper(f)
with self.assertRaises(TypeError) as ctx2:
pickle.dumps(wrap_f)
wrap_f = StreamWrapper(f)
with self.assertRaises(TypeError) as ctx2:
pickle.dumps(wrap_f)

# Same exception when pickle
self.assertEqual(str(ctx1.exception), str(ctx2.exception))
# Same exception when pickle
self.assertEqual(str(ctx1.exception), str(ctx2.exception))

fd = TestStreamWrapper._FakeFD("")
wrap_fd = StreamWrapper(fd)
@@ -255,9 +255,9 @@ def test_repr(self):
wrap_fd = StreamWrapper(fd)
self.assertEqual(str(wrap_fd), "StreamWrapper<FakeFD>")

f = tempfile.TemporaryFile()
wrap_f = StreamWrapper(f)
self.assertEqual(str(wrap_f), "StreamWrapper<" + str(f) + ">")
with tempfile.TemporaryFile() as f:
wrap_f = StreamWrapper(f)
self.assertEqual(str(wrap_f), "StreamWrapper<" + str(f) + ">")


class TestIterableDataPipeBasic(TestCase):
@@ -568,7 +568,7 @@ def _fake_add(constant, data):


def _fake_filter_fn(data):
return data >= 5
return True


def _fake_filter_fn_constant(constant, data):
@@ -610,6 +610,7 @@ def _serialization_test_for_single_dp(self, dp, use_dill=False):
_ = next(it)
self._serialization_test_helper(dp, use_dill)
# 3. Testing for serialization after DataPipe is fully read
it = iter(dp)
_ = list(it)
self._serialization_test_helper(dp, use_dill)

@@ -637,12 +638,12 @@ def test_serializable(self):
(dp.iter.Batcher, None, (3, True,), {}),
(dp.iter.Collator, None, (_fake_fn,), {}),
(dp.iter.Concater, None, (dp.iter.IterableWrapper(range(5)),), {}),
(dp.iter.Demultiplexer, None, (2, _fake_filter_fn), {}),
# (dp.iter.Demultiplexer, None, (2, _fake_filter_fn), {}), # Temporarily disabled until next PR
(dp.iter.FileLister, ".", (), {}),
(dp.iter.FileOpener, None, (), {}),
(dp.iter.Filter, None, (_fake_filter_fn,), {}),
(dp.iter.Filter, None, (partial(_fake_filter_fn_constant, 5),), {}),
(dp.iter.Forker, None, (2,), {}),
# (dp.iter.Forker, None, (2,), {}), # Temporarily disabled until next PR
(dp.iter.Grouper, None, (_fake_filter_fn,), {"group_size": 2}),
(dp.iter.IterableWrapper, range(10), (), {}),
(dp.iter.Mapper, None, (_fake_fn,), {}),
@@ -678,7 +679,7 @@ def test_serializable_with_dill(self):
input_dp = dp.iter.IterableWrapper(range(10))
unpicklable_datapipes: List[Tuple[Type[IterDataPipe], Tuple, Dict[str, Any]]] = [
(dp.iter.Collator, (lambda x: x,), {}),
(dp.iter.Demultiplexer, (2, lambda x: x % 2,), {}),
# (dp.iter.Demultiplexer, (2, lambda x: x % 2,), {}), # Temporarily disabled until next PR
(dp.iter.Filter, (lambda x: x >= 5,), {}),
(dp.iter.Grouper, (lambda x: x >= 5,), {}),
(dp.iter.Mapper, (lambda x: x,), {}),
@@ -850,6 +851,10 @@ def test_fork_iterdatapipe(self):
i1 = iter(dp1) # Reset all child DataPipes
self.assertEqual(len(wa), 1)
self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
break
for i, (n1, n2) in enumerate(zip(i1, i2)):
output1.append(n1)
output2.append(n2)
self.assertEqual(list(range(5)) + list(range(10)), output1)
self.assertEqual(list(range(5)) + list(range(10)), output2)

@@ -1054,29 +1059,30 @@ def test_demux_iterdatapipe(self):
traverse(dp2) # This should not raise any error either

def test_map_iterdatapipe(self):
input_dp = dp.iter.IterableWrapper(range(10))
target_length = 10
input_dp = dp.iter.IterableWrapper(range(target_length))

def fn(item, dtype=torch.float, *, sum=False):
data = torch.tensor(item, dtype=dtype)
return data if not sum else data.sum()

# Functional Test: apply to each element correctly
map_dp = input_dp.map(fn)
self.assertEqual(len(input_dp), len(map_dp))
for x, y in zip(map_dp, input_dp):
self.assertEqual(target_length, len(map_dp))
for x, y in zip(map_dp, range(target_length)):
self.assertEqual(x, torch.tensor(y, dtype=torch.float))

# Functional Test: works with partial function
map_dp = input_dp.map(partial(fn, dtype=torch.int, sum=True))
for x, y in zip(map_dp, input_dp):
for x, y in zip(map_dp, range(target_length)):
self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())

# __len__ Test: inherits length from source DataPipe
self.assertEqual(len(input_dp), len(map_dp))
self.assertEqual(target_length, len(map_dp))

input_dp_nl = IDP_NoLen(range(10))
input_dp_nl = IDP_NoLen(range(target_length))
map_dp_nl = input_dp_nl.map(lambda x: x)
for x, y in zip(map_dp_nl, input_dp_nl):
for x, y in zip(map_dp_nl, range(target_length)):
self.assertEqual(x, torch.tensor(y, dtype=torch.float))

# __len__ Test: inherits length from source DataPipe - raises error when invalid
@@ -1228,24 +1234,24 @@ def _collate_fn(batch, default_type=torch.float):

# Functional Test: defaults to the default collate function when a custom one is not specified
collate_dp = input_dp.collate()
for x, y in zip(input_dp, collate_dp):
for x, y in zip(arrs, collate_dp):
self.assertEqual(torch.tensor(x), y)

# Functional Test: custom collate function
collate_dp = input_dp.collate(collate_fn=_collate_fn)
for x, y in zip(input_dp, collate_dp):
for x, y in zip(arrs, collate_dp):
self.assertEqual(torch.tensor(sum(x), dtype=torch.float), y)

# Functional Test: custom, partial collate function
collate_dp = input_dp.collate(partial(_collate_fn, default_type=torch.int))
for x, y in zip(input_dp, collate_dp):
for x, y in zip(arrs, collate_dp):
self.assertEqual(torch.tensor(sum(x), dtype=torch.int), y)

# Reset Test: reset the DataPipe and results are still correct
n_elements_before_reset = 1
res_before_reset, res_after_reset = reset_after_n_next_calls(collate_dp, n_elements_before_reset)
self.assertEqual([torch.tensor(6, dtype=torch.int)], res_before_reset)
for x, y in zip(input_dp, res_after_reset):
for x, y in zip(arrs, res_after_reset):
self.assertEqual(torch.tensor(sum(x), dtype=torch.int), y)

# __len__ Test: __len__ is inherited
@@ -1256,7 +1262,7 @@ def _collate_fn(batch, default_type=torch.float):
collate_dp_nl = input_dp_nl.collate()
with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
len(collate_dp_nl)
for x, y in zip(input_dp_nl, collate_dp_nl):
for x, y in zip(arrs, collate_dp_nl):
self.assertEqual(torch.tensor(x), y)

def test_batch_iterdatapipe(self):
@@ -1306,14 +1312,14 @@ def test_unbatch_iterdatapipe(self):
input_dp = prebatch_dp.batch(3)
unbatch_dp = input_dp.unbatch()
self.assertEqual(len(list(unbatch_dp)), target_length) # __len__ is as expected
for i, res in zip(prebatch_dp, unbatch_dp):
for i, res in zip(range(target_length), unbatch_dp):
self.assertEqual(i, res)

# Functional Test: unbatch works for an input with nested levels
input_dp = dp.iter.IterableWrapper([[0, 1, 2], [3, 4, 5]])
unbatch_dp = input_dp.unbatch()
self.assertEqual(len(list(unbatch_dp)), target_length)
for i, res in zip(prebatch_dp, unbatch_dp):
for i, res in zip(range(target_length), unbatch_dp):
self.assertEqual(i, res)

input_dp = dp.iter.IterableWrapper([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])
@@ -1322,8 +1328,8 @@ def test_unbatch_iterdatapipe(self):
unbatch_dp = input_dp.unbatch()
expected_dp = [[0, 1], [2, 3], [4, 5], [6, 7]]
self.assertEqual(len(list(unbatch_dp)), 4)
for i, res in zip(expected_dp, unbatch_dp):
self.assertEqual(i, res)
for j, res in zip(expected_dp, unbatch_dp):
self.assertEqual(j, res)

# Functional Test: unbatching multiple levels at the same time
unbatch_dp = input_dp.unbatch(unbatch_level=2)
@@ -2290,5 +2296,161 @@ def test_old_dataloader(self):
self.assertEqual(sorted(expected), sorted(items))


class TestIterDataPipeSingletonConstraint(TestCase):

r"""
Each `IterDataPipe` can only have one active iterator. Whenever a new iterator is created, older
iterators are invalidated. These tests aim to ensure `IterDataPipe` follows this behavior.
"""

def _check_single_iterator_invalidation_logic(self, source_dp: IterDataPipe):
r"""
Given an IterDataPipe, verifies that its iterator can be read and reset, and that the creation of
a second iterator invalidates the first one.
"""
it1 = iter(source_dp)
self.assertEqual(list(range(10)), list(it1))
it1 = iter(source_dp)
self.assertEqual(list(range(10)), list(it1)) # A fresh iterator can be read in full again
it1 = iter(source_dp)
self.assertEqual(0, next(it1))
it2 = iter(source_dp) # This should invalidate `it1`
self.assertEqual(0, next(it2)) # Should read from the beginning again
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
next(it1)


def test_iterdatapipe_singleton_generator(self):
r"""
Testing for the case where IterDataPipe's `__iter__` is a generator function.
"""

# Functional Test: Check if invalidation logic is correct
source_dp: IterDataPipe = dp.iter.IterableWrapper(range(10))
self._check_single_iterator_invalidation_logic(source_dp)

# Functional Test: extend the test to a pipeline
dps = source_dp.map(_fake_fn).filter(_fake_filter_fn)
self._check_single_iterator_invalidation_logic(dps)

# Functional Test: multiple simultaneous references to the same DataPipe fails
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
for _ in zip(source_dp, source_dp):
pass

# Functional Test: sequential references work
for _ in zip(list(source_dp), list(source_dp)):
pass

def test_iterdatapipe_singleton_self_next(self):
r"""
Testing for the case where IterDataPipe's `__iter__` returns `self` and there is a `__next__` method.
Note that the following DataPipe is a singleton by default (because `__iter__` returns `self`).
"""
class _CustomIterDP_Self(IterDataPipe):
def __init__(self, iterable):
self.source = iterable
self.iterable = iter(iterable)

def __iter__(self):
self.reset()
return self

def __next__(self):
Contributor:

Also need a buggy case where `__iter__` returns a new object but there is also a `__next__` method.

Contributor Author:

I am adding a buggy case at the bottom. Let me know if it matches what you have in mind.

return next(self.iterable)

def reset(self):
self.iterable = iter(self.source)

# Functional Test: Check that every `__iter__` call returns the same object
source_dp = _CustomIterDP_Self(range(10))
res = list(source_dp)
it = iter(source_dp)
self.assertEqual(res, list(it))

# Functional Test: Check if invalidation logic is correct
source_dp = _CustomIterDP_Self(range(10))
self._check_single_iterator_invalidation_logic(source_dp)
self.assertEqual(1, next(source_dp)) # `source_dp` is still valid and can be read

# Functional Test: extend the test to a pipeline
source_dp = _CustomIterDP_Self(dp.iter.IterableWrapper(range(10)).map(_fake_fn).filter(_fake_filter_fn))
self._check_single_iterator_invalidation_logic(source_dp)
self.assertEqual(1, next(source_dp)) # `source_dp` is still valid and can be read

# Functional Test: multiple simultaneous references to the same DataPipe fails
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
for _ in zip(source_dp, source_dp):
pass

def test_iterdatapipe_singleton_new_object(self):
r"""
Testing for the case where IterDataPipe's `__iter__` is not a generator function and doesn't return `self`,
and there is no `__next__` method.
"""
class _CustomIterDP(IterDataPipe):
def __init__(self, iterable):
self.iterable = iter(iterable)

def __iter__(self): # Note that this doesn't reset
return self.iterable # Intentionally not returning `self`

# Functional Test: Check if invalidation logic is correct
source_dp = _CustomIterDP(range(10))
it1 = iter(source_dp)
self.assertEqual(0, next(it1))
it2 = iter(source_dp)
self.assertEqual(1, next(it2))
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
next(it1)

# Functional Test: extend the test to a pipeline
source_dp = _CustomIterDP(dp.iter.IterableWrapper(range(10)).map(_fake_fn).filter(_fake_filter_fn))
it1 = iter(source_dp)
self.assertEqual(0, next(it1))
it2 = iter(source_dp)
self.assertEqual(1, next(it2))
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
Contributor:

Wait, why is it invalid?

Contributor:

Do we have at least one valid iterator after this?

Contributor Author (@NivekT, May 13, 2022):

In my original proposal, users were allowed to call __iter__ as many times as they wanted. But after our offline discussion, we decided to apply the constraint on __iter__ even when it returns self, because having multiple active iterators referencing the same underlying object can lead to confusion and bugs. Plus, the use case isn't really necessary when you can just reference dp. The code snippet below illustrates what the behavior would have been if it were allowed:

dp = CustomDP(range(10))
next(dp)  # returns 0
it1 = iter(dp)  # returns `self`
next(it1)  # returns 1
it2 = iter(dp)  # returns `self`
next(it2)  # returns 2
next(it1)  # returns 3
next(dp)  # returns 4

In the current implementation, calling next(dp) raises an error if there is more than 1 iterator created from that DataPipe. This is because the ID that tracks whether the iterator is valid (i.e. dp._valid_iterator_id) is tied to the DataPipe (and because __iter__ returns self, the ID is tied to the iterator as well). I think it has to be this way; otherwise, when next(dp) is called, it has no idea what ID to check against.

source_dp = _CustomIterDP_Self(range(10))
it1 = iter(source_dp)
self.assertEqual(0, next(it1))
self.assertEqual(1, next(source_dp))
it2 = iter(source_dp)
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
    next(it1)
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
    next(it2)
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
    next(source_dp)
# In this test case, there is no valid iterator at the end, because `iter(it1)` delegates to `iter(dp)` anyway.

The simplest workaround for users is to re-define dp via dp = CustomDP() or to avoid creating more than one iterator from dp by keeping dp as the variable in use rather than calling iter(dp).
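
For illustration, a minimal sketch of the ID-tracking described above (the attribute name _valid_iterator_id and the IteratorDecorator role come from this discussion; the class below is hypothetical and simplified, not the PR's actual implementation):

class _IteratorDecoratorSketch:
    def __init__(self, datapipe, iterator):
        # Creating a new decorator bumps the DataPipe's valid ID,
        # so any previously created decorator now holds a stale ID.
        datapipe._valid_iterator_id = getattr(datapipe, "_valid_iterator_id", -1) + 1
        self._id = datapipe._valid_iterator_id
        self._datapipe = datapipe
        self._iterator = iterator

    def __iter__(self):
        return self

    def __next__(self):
        # A stale ID means a newer iterator exists, so this one is invalid.
        if self._id != self._datapipe._valid_iterator_id:
            raise RuntimeError("This iterator has been invalidated")
        return next(self._iterator)

With this scheme, wrapping a second iterator over the same DataPipe bumps the ID, and any subsequent next(it1) on the first wrapper raises the error message asserted in the tests.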

Contributor (@ejguan, May 16, 2022):

Here are the proposals you have:

  • For DataPipes returning self in __iter__: we disallow users from calling iter multiple times.
  • For DataPipes using a generator function as __iter__: we make sure a new iterator is created and the previous iterator is invalidated whenever iter is called.

As we are making the iterator of a DataPipe a singleton (the same behavior), we are introducing a new difference in iterator behavior. I think this is not preferable.

Since we have a reset function to help a DataPipe reset its iterator, why don't we rely on reset to clean up all buffers, etc., for DataPipes returning self in __iter__? The IteratorDecorator and the DataPipe itself should then be able to track the iterator_id to invalidate the previous iterator.

Contributor Author:

I can see how that can be confusing.

I think this proposal will be best. Given a DataPipe dp with __next__ and iterators it1 and it2:

  1. Allow unlimited creation of iterators, but only the latest one is valid/active.
  2. Don't place a restriction on next(dp), but check for invalidation when next(it1) is called.
    • We can't really place a restriction on next(dp), because any call to next(it2) gets delegated to next(dp)
      within the IteratorDecorator.
source_dp = _CustomIterDP_Self(range(10))
it1 = iter(source_dp)
self.assertEqual(0, next(it1))
self.assertEqual(1, next(source_dp))
# Only invalidates `it1`, not `source_dp`, since the methods of `it2` depend on `source_dp` remaining valid
it2 = iter(source_dp)
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
    next(it1)
self.assertEqual(0, next(it2))
self.assertEqual(1, next(source_dp))

Please examine this test for the full details:

pytorch/test/test_datapipe.py

Lines 2337 to 2390 in c77df3e

def test_iterdatapipe_singleton_self_next(self):
r"""
Testing for the case where IterDataPipe's `__iter__` returns `self` and there is a `__next__` method
Note that the following DataPipe is a singleton by default (because `__iter__` returns `self`).
"""
class _CustomIterDP_Self(IterDataPipe):
def __init__(self, iterable):
self.source = iterable
self.iterable = iter(iterable)
def __iter__(self):
self.reset()
return self
def __next__(self):
return next(self.iterable)
def reset(self):
self.iterable = iter(self.source)
# Functional Test: Check that every `__iter__` call returns the same object
source_dp = _CustomIterDP_Self(range(10))
res = list(source_dp)
it = iter(source_dp)
self.assertEqual(res, list(it))
# Functional Test: Check if invalidation logic is correct
source_dp = _CustomIterDP_Self(range(10))
it1 = iter(source_dp)
self.assertEqual(0, next(it1))
self.assertEqual(1, next(source_dp))
# Only invalidates `it1`, not `source_dp`, since the methods of `it2` depend on `source_dp` remaining valid
it2 = iter(source_dp)
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
next(it1)
self.assertEqual(0, next(it2))
self.assertEqual(1, next(source_dp))
# Functional Test: extend the test to a pipeline
source_dp = _CustomIterDP_Self(dp.iter.IterableWrapper(range(10)).map(_fake_fn).filter(_fake_filter_fn))
it1 = iter(source_dp)
self.assertEqual(0, next(it1))
self.assertEqual(1, next(source_dp))
# Only invalidates `it1`, not `source_dp`, since the methods of `it2` depend on `source_dp` remaining valid
it2 = iter(source_dp)
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
next(it1)
self.assertEqual(0, next(it2))
self.assertEqual(1, next(source_dp))
# Functional Test: multiple simultaneous references to the same DataPipe fails
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
for _ in zip(source_dp, source_dp):
pass

next(it1)

# Functional Test: multiple simultaneous references to the same DataPipe fails
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
for _ in zip(source_dp, source_dp):
pass

def test_iterdatapipe_singleton_buggy(self):
r"""
Buggy test case where IterDataPipe's `__iter__` returns a new object, but also has
a `__next__` method.
"""
class _CustomIterDP(IterDataPipe):
def __init__(self, iterable):
self.source = iterable
self.iterable = iter(iterable)

def __iter__(self):
return iter(self.source) # Intentionally not returning `self`

def __next__(self):
return next(self.iterable)

# Functional Test: Check if invalidation logic is correct
source_dp = _CustomIterDP(range(10))
self._check_single_iterator_invalidation_logic(source_dp)
self.assertEqual(0, next(source_dp)) # `__next__` is unrelated to `__iter__`

# Functional Test: Special case to show `__next__` is unrelated to `__iter__`
source_dp = _CustomIterDP(range(10))
self.assertEqual(0, next(source_dp))
it1 = iter(source_dp)
self.assertEqual(0, next(it1))
self.assertEqual(1, next(source_dp))
it2 = iter(source_dp) # invalidates `it1`
with self.assertRaisesRegex(RuntimeError, "This iterator has been invalidated"):
next(it1)
self.assertEqual(2, next(source_dp)) # not impacted by the creation of `it2`
self.assertEqual(list(range(10)), list(it2)) # `it2` still works because it is a new object


if __name__ == '__main__':
run_tests()
6 changes: 3 additions & 3 deletions test/test_profiler.py
@@ -153,10 +153,10 @@ def __next__(self):
def get_value(self, idx):
return self.data[idx]

dp1 = IDPIterator()
dp1 = IDPIterator() # The object itself is an iterator
self.assertEqual(5, dp1.get_value(5))
it_dp1 = iter(dp1)
self.assertEqual(5, it_dp1.get_value(5))
it_dp1 = iter(dp1) # This creates the 1st iterator
self.assertEqual(5, it_dp1.get_value(5)) # type: ignore[attr-defined]
self.assertEqual(list(range(10)), list(it_dp1))

class IDPDelegator(torch.utils.data.IterDataPipe):