[DataPipe] Enforcing single valid iterator for IterDataPipes with single DataPipe as output by NivekT · Pull Request #70479 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[DataPipe] Enforcing single valid iterator for IterDataPipes with single DataPipe as output #70479

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
8bd3627
[DataPipe] Enforcing single iterator per IterableWrapperIterDataPipe
NivekT Dec 29, 2021
4dddb35
Update on "[DataPipe] Enforcing single iterator per IterableWrapperIt…
NivekT Dec 29, 2021
57b5c09
Update on "[DataPipe] Enforcing single iterator per IterableWrapperIt…
NivekT Jan 18, 2022
289fe02
Update on "[DataPipe] Enforcing single iterator per IterableWrapperIt…
NivekT Jan 19, 2022
5961fa3
Update on "[DataPipe] Enforcing single iterator per IterableWrapperIt…
NivekT Apr 15, 2022
f97581a
Update on "[DataPipe] Enforcing single iterator per IterableWrapperIt…
NivekT Apr 18, 2022
aa01181
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT Apr 21, 2022
8b1e0d5
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT Apr 28, 2022
c9f3462
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT May 3, 2022
54cc729
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT May 3, 2022
dbf298e
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT May 5, 2022
2c5e0e7
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT May 5, 2022
91512c4
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT May 6, 2022
87bba0f
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT May 9, 2022
3c39423
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT May 12, 2022
7b75f5e
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT May 13, 2022
c77df3e
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT May 16, 2022
9824b8a
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT May 17, 2022
db9862f
Update on "[DataPipe] Enforcing single valid iterator for IterDataPip…
NivekT May 17, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Update on "[DataPipe] Enforcing single iterator per IterableWrapperIt…
…erDataPipe"

Differential Revision: [D33344609](https://our.internmc.facebook.com/intern/diff/D33344609)

[ghstack-poisoned]
  • Loading branch information
NivekT committed Dec 29, 2021
commit 4dddb35d80bcd30800cf733ee9872087706907ed
49 changes: 24 additions & 25 deletions test/test_datapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,27 +805,26 @@ def test_fork_iterdatapipe(self):
dp0 = input_dp.fork(num_instances=1)
self.assertEqual(list(range(10)), list(dp0))

# Test Case: making sure all child DataPipe shares the same reference
input_dp.reset()
# Functional Test: making sure all child DataPipe shares the same reference
dp1, dp2, dp3 = input_dp.fork(num_instances=3)
self.assertTrue(all(n1 is n2 and n1 is n3 for n1, n2, n3 in zip(dp1, dp2, dp3)))

# Test Case: one child DataPipe yields all value at a time
# Functional Test: one child DataPipe yields all value at a time
input_dp.reset()
output1, output2, output3 = list(dp1), list(dp2), list(dp3)
self.assertEqual(list(range(10)), output1)
self.assertEqual(list(range(10)), output2)
self.assertEqual(list(range(10)), output3)

# Test Case: two child DataPipes yield value together
# Functional Test: two child DataPipes yield value together
input_dp.reset()
dp1, dp2 = input_dp.fork(num_instances=2)
output = []
for n1, n2 in zip(dp1, dp2):
output.append((n1, n2))
self.assertEqual([(i, i) for i in range(10)], output)

# Test Case: one child DataPipe yields all value first, but buffer_size = 5 being too small
# Functional Test: one child DataPipe yields all value first, but buffer_size = 5 being too small
input_dp.reset()
dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=5)
it1 = iter(dp1)
Expand All @@ -846,15 +845,15 @@ def test_fork_iterdatapipe(self):
for d1, d2 in zip(l1, l2):
self.assertEqual(d1, d2)

# Test Case: two child DataPipes yield value together with buffer size 1
# Functional Test: two child DataPipes yield value together with buffer size 1
input_dp.reset()
dp1, dp2 = input_dp.fork(num_instances=2, buffer_size=1)
output = []
for n1, n2 in zip(dp1, dp2):
output.append((n1, n2))
self.assertEqual([(i, i) for i in range(10)], output)

# Test Case: make sure logic related to slowest_ptr is working properly
# Functional Test: make sure logic related to slowest_ptr is working properly
input_dp.reset()
dp1, dp2, dp3 = input_dp.fork(num_instances=3)
output1, output2 , output3 = [], [], []
Expand All @@ -868,7 +867,7 @@ def test_fork_iterdatapipe(self):
self.assertEqual(list(range(5)), output2)
self.assertEqual(list(range(10)), output3)

# Test Case: DataPipe doesn't reset if this pipe hasn't been read
# Reset Test: DataPipe doesn't reset if this pipe hasn't been read
input_dp.reset()
dp1, dp2 = input_dp.fork(num_instances=2)
i1, i2 = iter(dp1), iter(dp2)
Expand All @@ -880,7 +879,7 @@ def test_fork_iterdatapipe(self):
i1 = iter(dp1) # Doesn't reset because i1 hasn't been read
self.assertEqual(list(range(10)), output2)

# Test Case: DataPipe reset when some of it have been read
# Reset Test: DataPipe reset when some of it have been read
input_dp.reset()
dp1, dp2 = input_dp.fork(num_instances=2)
i1, i2 = iter(dp1), iter(dp2)
Expand All @@ -898,7 +897,7 @@ def test_fork_iterdatapipe(self):
self.assertEqual(list(range(5)) + list(range(10)), output1)
self.assertEqual(list(range(5)) + list(range(10)), output2)

# Test Case: DataPipe reset, even when some other child DataPipes are not read
# Reset Test: DataPipe reset, even when some other child DataPipes are not read
input_dp.reset()
dp1, dp2, dp3 = input_dp.fork(num_instances=3)
output1, output2 = list(dp1), list(dp2)
Expand All @@ -923,7 +922,7 @@ def test_fork_iterdatapipe(self):
break
self.assertEqual(list(range(10)), list(dp3)) # dp3 has to read from the start again

# Test Case: Each DataPipe inherits the source datapipe's length
# __len__ Test: Each DataPipe inherits the source datapipe's length
input_dp.reset()
dp1, dp2, dp3 = input_dp.fork(num_instances=3)
self.assertEqual(len(input_dp), len(dp1))
Expand Down Expand Up @@ -984,15 +983,15 @@ def test_demux_iterdatapipe(self):
self.assertEqual(list(range(0, 10, 2)), output1)
self.assertEqual(list(range(1, 10, 2)), output2)

# Test Case: split into 2 DataPipes and output them together
# Functional Test: split into 2 DataPipes and output them together
input_dp.reset()
dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
output = []
for n1, n2 in zip(dp1, dp2):
output.append((n1, n2))
self.assertEqual([(i, i + 1) for i in range(0, 10, 2)], output)

# Test Case: values of the same classification are lumped together, and buffer_size = 3 being too small
# Functional Test: values of the same classification are lumped together, and buffer_size = 3 being too small
input_dp.reset()
dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: 0 if x >= 5 else 1, buffer_size=4)
it1 = iter(dp1)
Expand All @@ -1001,14 +1000,14 @@ def test_demux_iterdatapipe(self):
with self.assertRaises(BufferError):
list(dp2)

# Test Case: values of the same classification are lumped together, and buffer_size = 5 is just enough
# Functional Test: values of the same classification are lumped together, and buffer_size = 5 is just enough
input_dp.reset()
dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: 0 if x >= 5 else 1, buffer_size=5)
output1, output2 = list(dp1), list(dp2)
self.assertEqual(list(range(5, 10)), output1)
self.assertEqual(list(range(0, 5)), output2)

# Test Case: values of the same classification are lumped together, and unlimited buffer
# Functional Test: values of the same classification are lumped together, and unlimited buffer
input_dp.reset()
with warnings.catch_warnings(record=True) as wa:
dp1, dp2 = input_dp.demux(
Expand All @@ -1022,15 +1021,15 @@ def test_demux_iterdatapipe(self):
self.assertEqual(list(range(5, 10)), output1)
self.assertEqual(list(range(0, 5)), output2)

# Test Case: classifer returns a value outside of [0, num_instance - 1]
# Functional Test: classifier returns a value outside of [0, num_instance - 1]
input_dp.reset()
dp0 = input_dp.demux(num_instances=1, classifier_fn=lambda x: x % 2)
it = iter(dp0[0])
with self.assertRaises(ValueError):
next(it)
next(it)

# Test Case: DataPipe doesn't reset when it has not been read
# Reset Test: DataPipe doesn't reset when it has not been read
input_dp.reset()
dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
i1 = iter(dp1)
Expand All @@ -1042,7 +1041,7 @@ def test_demux_iterdatapipe(self):
i1 = iter(dp1)
self.assertEqual(list(range(1, 10, 2)), output2)

# Test Case: DataPipe reset when some of it has been read
# Reset Test: DataPipe reset when some of it has been read
input_dp.reset()
dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
output1, output2 = [], []
Expand All @@ -1062,7 +1061,7 @@ def test_demux_iterdatapipe(self):
self.assertEqual([0, 2, 4] + list(range(0, 10, 2)), output1)
self.assertEqual([1, 3, 5] + list(range(1, 10, 2)), output2)

# Test Case: DataPipe reset, even when not all child DataPipes are exhausted
# Reset Test: DataPipe reset, even when not all child DataPipes are exhausted
input_dp.reset()
dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
output1 = list(dp1)
Expand All @@ -1086,22 +1085,22 @@ def test_demux_iterdatapipe(self):
output2 = list(dp2) # output2 has to read from beginning again
self.assertEqual(list(range(1, 10, 2)), output2)

# Test Case: drop_none = True
# Functional Test: drop_none = True
input_dp.reset()
dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2 if x % 5 != 0 else None,
drop_none=True)
self.assertEqual([2, 4, 6, 8], list(dp1))
self.assertEqual([1, 3, 7, 9], list(dp2))

# Test Case: drop_none = False
# Functional Test: drop_none = False
input_dp.reset()
dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2 if x % 5 != 0 else None,
drop_none=False)
it1 = iter(dp1)
with self.assertRaises(ValueError):
next(it1)

# Test Case: __len__ not implemented
# __len__ Test: __len__ not implemented
input_dp.reset()
dp1, dp2 = input_dp.demux(num_instances=2, classifier_fn=lambda x: x % 2)
with self.assertRaises(TypeError):
Expand Down Expand Up @@ -1133,7 +1132,7 @@ def fn(item, dtype=torch.float, *, sum=False):
# Functional Test: works with partial function
input_dp.reset()
map_dp = input_dp.map(partial(fn, dtype=torch.int, sum=True))
for x, y in zip(map_dp, range(10)):
for x, y in zip(map_dp, range(10)):
self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum())

# __len__ Test: inherits length from source DataPipe
Expand Down Expand Up @@ -1401,8 +1400,8 @@ def test_unbatch_iterdatapipe(self):
expected_dp = [[0, 1], [2, 3], [4, 5], [6, 7]]
self.assertEqual(len(list(unbatch_dp)), 4)
input_dp.reset()
for i, res in zip(expected_dp, unbatch_dp):
self.assertEqual(i, res)
for i2, res2 in zip(expected_dp, unbatch_dp):
self.assertEqual(i2, res2)

input_dp.reset()
unbatch_dp = input_dp.unbatch(unbatch_level=2)
Expand Down
2 changes: 1 addition & 1 deletion torch/utils/data/datapipes/iter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __next__(self):
if self.iter is None:
self._create_iterator()
self.state_counter += 1
return next(self.iter)
return next(self.iter) # type:ignore[arg-type]

def reset(self) -> None:
self.iter = None
Expand Down
You are viewing a condensed version of this merge commit. You can view the full changes here.