[DataPipe] Count number of successful yields for IterDataPipe by NivekT · Pull Request #79657 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[DataPipe] Count number of successful yields for IterDataPipe #79657

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 116 additions & 1 deletion test/test_datapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def tearDown(self):

def test_listdirfiles_iterable_datapipe(self):
temp_dir = self.temp_dir.name
datapipe = dp.iter.FileLister(temp_dir, '')
datapipe: IterDataPipe = dp.iter.FileLister(temp_dir, '')

count = 0
for pathname in datapipe:
Expand Down Expand Up @@ -2640,5 +2640,120 @@ def test_iterdatapipe_singleton_constraint_multiple_outputs(self):
next(it1)
self.assertEqual(1, next(it3))

class TestIterDataPipeCountSampleYielded(TestCase):
    r"""
    Verify that ``IterDataPipe._number_of_samples_yielded`` tracks the number of
    successfully yielded samples across full reads, partial reads, resets, and
    mid-iteration exceptions, for every flavor of ``__iter__``/``__next__``.
    """

    def _yield_count_test_helper(self, datapipe, n_expected_samples):
        # Full pass: the counter must match the number of elements produced.
        full_results = list(datapipe)
        self.assertEqual(len(full_results), datapipe._number_of_samples_yielded)

        # Partial pass: read exactly `n_expected_samples` elements, then stop.
        partial_it = iter(datapipe)
        for _ in range(n_expected_samples):
            next(partial_it)
        self.assertEqual(n_expected_samples, datapipe._number_of_samples_yielded)

        # Reset: a fresh iterator restarts the counter and still drains fully.
        fresh_it = iter(datapipe)  # reset the DataPipe
        self.assertEqual(len(list(fresh_it)), datapipe._number_of_samples_yielded)

    def test_iterdatapipe_sample_yielded_generator_function(self):
        # `__iter__` is a generator function
        source_dp: IterDataPipe = dp.iter.IterableWrapper(range(10))
        self._yield_count_test_helper(source_dp, n_expected_samples=5)

    def test_iterdatapipe_sample_yielded_generator_function_exception(self):
        # `__iter__` is a custom generator function that raises part-way through
        class _CustomGeneratorFnDataPipe(IterDataPipe):
            # Yields three elements, then raises a RuntimeError.
            def __iter__(self):
                for value in (0, 1, 2):
                    yield value
                raise RuntimeError("Custom test error after yielding 3 elements")

        source_dp: IterDataPipe = _CustomGeneratorFnDataPipe()

        # The three successful yields must be counted despite the exception.
        with self.assertRaisesRegex(RuntimeError, "Custom test error after yielding 3 elements"):
            list(source_dp)
        self.assertEqual(3, source_dp._number_of_samples_yielded)

        # After a reset, the same failure and the same count must reproduce.
        fresh_it = iter(source_dp)  # reset the DataPipe
        with self.assertRaisesRegex(RuntimeError, "Custom test error after yielding 3 elements"):
            list(fresh_it)
        self.assertEqual(3, source_dp._number_of_samples_yielded)

    def test_iterdatapipe_sample_yielded_return_self(self):
        class _CustomGeneratorDataPipe(IterDataPipe):
            # `__iter__` is not a generator function: it returns a stored iterator.
            def __init__(self):
                self.source = iter(range(10))

            def __iter__(self):
                return self.source

            def reset(self):
                self.source = iter(range(10))

        self._yield_count_test_helper(_CustomGeneratorDataPipe(), n_expected_samples=5)

    def test_iterdatapipe_sample_yielded_next(self):
        class _CustomNextDataPipe(IterDataPipe):
            # `__iter__` returns `self` and the class defines `__next__`.
            def __init__(self):
                self.source = iter(range(10))

            def __iter__(self):
                return self

            def __next__(self):
                return next(self.source)

            def reset(self):
                self.source = iter(range(10))

        self._yield_count_test_helper(_CustomNextDataPipe(), n_expected_samples=5)

    def test_iterdatapipe_sample_yielded_next_exception(self):
        class _CustomNextDataPipe(IterDataPipe):
            # `__iter__` returns `self`; `__next__` fails after three successes.
            def __init__(self):
                self.source = iter(range(10))
                self.count = 0

            def __iter__(self):
                return self

            def __next__(self):
                if self.count == 3:
                    raise RuntimeError("Custom test error after yielding 3 elements")
                self.count += 1
                return next(self.source)

            def reset(self):
                self.count = 0
                self.source = iter(range(10))

        source_dp: IterDataPipe = _CustomNextDataPipe()

        # The three successful `__next__` calls must be counted despite the error.
        with self.assertRaisesRegex(RuntimeError, "Custom test error after yielding 3 elements"):
            list(source_dp)
        self.assertEqual(3, source_dp._number_of_samples_yielded)

        # After a reset, the same failure and the same count must reproduce.
        fresh_it = iter(source_dp)  # reset the DataPipe
        with self.assertRaisesRegex(RuntimeError, "Custom test error after yielding 3 elements"):
            list(fresh_it)
        self.assertEqual(3, source_dp._number_of_samples_yielded)

# Entry point: allows running this test file directly (`python test_datapipe.py`).
if __name__ == '__main__':
    run_tests()
42 changes: 31 additions & 11 deletions torch/utils/data/datapipes/_hook_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,26 +94,40 @@ def profiler_record_fn_context():
return torch.autograd.profiler.record_function(profile_name)

class IteratorDecorator:
    r"""
    Wrap an iterator and intercept its `__next__` method. This decorator is applied
    to DataPipes whose `__iter__` method is NOT a generator function. Such an
    `__iter__` commonly returns `self`, but it does not have to.
    """
    def __init__(self, iterator, source_dp, iterator_id, has_next_method):
        self.iterator = iterator
        self.source_dp = source_dp
        self.iterator_id = iterator_id
        # Snapshot the profiler state once at iterator creation.
        self._profiler_enabled = torch.autograd._profiler_enabled()
        # True when `__iter__` returned the DataPipe itself AND the DataPipe defines
        # `__next__` — in that case the wrapped `__next__` already counts the samples,
        # so counting here would double-count.
        self.self_and_has_next_method = self.iterator is self.source_dp and has_next_method

    def __iter__(self):
        return self

    def _get_next(self):
        r"""
        Fetch the next element: validate the iterator first, then count the sample
        only if the DataPipe's own `__next__` wrapper is not already counting it.
        """
        _check_iterator_valid(self.source_dp, self.iterator_id)
        value = next(self.iterator)
        if not self.self_and_has_next_method:
            self.source_dp._number_of_samples_yielded += 1
        return value

    def __next__(self):
        # TODO: Add try-except to in-place reduce traceback from the Exception
        # See: https://github.com/pytorch/data/issues/284
        if not self._profiler_enabled:
            # Decided against using `contextlib.nullcontext` for performance reasons
            return self._get_next()
        with profiler_record_fn_context():
            return self._get_next()

    def __getattr__(self, name):
        # Delegate everything else to the wrapped iterator.
        return getattr(self.iterator, name)
Expand All @@ -136,6 +150,7 @@ def wrap_generator(*args, **kwargs):
response = gen.send(None)

while True:
datapipe._number_of_samples_yielded += 1
request = yield response
# Pass through here every time `__next__` is called
if _profiler_enabled:
Expand Down Expand Up @@ -172,21 +187,26 @@ def wrap_generator(*args, **kwargs):
def wrap_next(*args, **kwargs):
    """Run the original `__next__` (under the profiler context when profiling
    is enabled) and count the sample only after it returns successfully."""
    self_dp = args[0]
    if torch.autograd._profiler_enabled():
        with profiler_record_fn_context():
            sample = next_func(*args, **kwargs)
    else:
        sample = next_func(*args, **kwargs)
    # Reached only when `next_func` did not raise, so this yield succeeded.
    self_dp._number_of_samples_yielded += 1
    return sample

namespace['__next__'] = wrap_next

# Note that if the `__next__` and `__iter__` do something completely unrelated? It may cause issue but
# the user will be violating the iterator protocol
# Note that if `__next__` and `__iter__` do something completely unrelated, it may cause issues, but
# in that case the user will be violating the iterator protocol. Potential issues:
# 1. The valid iterator ID may not be updated or checked properly
# 2. The number of samples yielded will be miscounted

# Whether or not `__next__` exists, `__iter__` needs a wrapper so the number of
# valid iterators can be tracked.
@functools.wraps(func)
def wrap_iter(*args, **kwargs):
    """Call the original `__iter__`, register a fresh iterator ID for the
    DataPipe, and hand back the decorated iterator."""
    raw_iterator = func(*args, **kwargs)
    source_dp = args[0]
    # Each created iterator gets its own ID so stale iterators can be invalidated.
    new_iterator_id = _set_datapipe_valid_iterator_id(source_dp)
    return IteratorDecorator(raw_iterator, source_dp, new_iterator_id, '__next__' in namespace)

namespace['__iter__'] = wrap_iter
1 change: 1 addition & 0 deletions torch/utils/data/datapipes/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ def conditional_reset(*args, **kwargs):
if datapipe._restored is True:
datapipe._restored = False
else:
datapipe._number_of_samples_yielded = 0
reset_func(*args, **kwargs)

namespace['reset'] = conditional_reset
Expand Down
1 change: 1 addition & 0 deletions torch/utils/data/datapipes/datapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ class IterDataPipe(IterableDataset[T_co], metaclass=_IterDataPipeMeta):
str_hook: Optional[Callable] = None
repr_hook: Optional[Callable] = None
_valid_iterator_id: Optional[int] = None
_number_of_samples_yielded: int = 0
_restored: bool = False

def __getattr__(self, attribute_name):
Expand Down
2 changes: 2 additions & 0 deletions torch/utils/data/datapipes/datapipe.pyi.in
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class IterDataPipe(IterableDataset[T_co], metaclass=_IterDataPipeMeta):
getstate_hook: Optional[Callable] = ...
str_hook: Optional[Callable] = ...
repr_hook: Optional[Callable] = ...
_number_of_samples_yielded: int = ...
_restored: bool = False
def __getattr__(self, attribute_name: Any): ...
@classmethod
def register_function(cls, function_name: Any, function: Any) -> None: ...
Expand Down
0