diff --git a/torch/utils/data/datapipes/_hook_iterator.py b/torch/utils/data/datapipes/_hook_iterator.py index 24b45413e5c910..0d4213d051c13c 100644 --- a/torch/utils/data/datapipes/_hook_iterator.py +++ b/torch/utils/data/datapipes/_hook_iterator.py @@ -171,7 +171,8 @@ def wrap_generator(*args, **kwargs): @functools.wraps(next_func) def wrap_next(*args, **kwargs): if torch.autograd._profiler_enabled(): - return next_func(*args, **kwargs) + with profiler_record_fn_context(): + return next_func(*args, **kwargs) else: return next_func(*args, **kwargs)