[DataPipe] Improve inline doc and testing for CollatorIterDataPipe (#… · pytorch/pytorch@ad0cd8a · GitHub
[go: up one dir, main page]

Skip to content

Commit ad0cd8a

Browse files
NivekT facebook-github-bot
authored and committed
[DataPipe] Improve inline doc and testing for CollatorIterDataPipe (#70139)
Summary: Pull Request resolved: #70139 cc VitalyFedyunin ejguan NivekT Test Plan: Imported from OSS Reviewed By: ejguan Differential Revision: D33199107 Pulled By: NivekT fbshipit-source-id: f96d77490998ac9bc3da8d4ff1a9caa08e9e7f27
1 parent 8a91201 commit ad0cd8a

File tree

3 files changed

+34
-9
lines changed

3 files changed

+34
-9
lines changed

test/test_datapipe.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1200,20 +1200,41 @@ def test_collate_datapipe(self):
12001200
arrs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
12011201
input_dp = dp.iter.IterableWrapper(arrs)
12021202

1203-
def _collate_fn(batch):
1204-
return torch.tensor(sum(batch), dtype=torch.float)
1203+
def _collate_fn(batch, default_type=torch.float):
1204+
return torch.tensor(sum(batch), dtype=default_type)
12051205

1206+
# Functional Test: defaults to the default collate function when a custom one is not specified
1207+
collate_dp = input_dp.collate()
1208+
for x, y in zip(input_dp, collate_dp):
1209+
self.assertEqual(torch.tensor(x), y)
1210+
1211+
# Functional Test: custom collate function
12061212
collate_dp = input_dp.collate(collate_fn=_collate_fn)
1213+
for x, y in zip(input_dp, collate_dp):
1214+
self.assertEqual(torch.tensor(sum(x), dtype=torch.float), y)
1215+
1216+
# Functional Test: custom, partial collate function
1217+
collate_dp = input_dp.collate(partial(_collate_fn, default_type=torch.int))
1218+
for x, y in zip(input_dp, collate_dp):
1219+
self.assertEqual(torch.tensor(sum(x), dtype=torch.int), y)
1220+
1221+
# Reset Test: reset the DataPipe and results are still correct
1222+
n_elements_before_reset = 1
1223+
res_before_reset, res_after_reset = reset_after_n_next_calls(collate_dp, n_elements_before_reset)
1224+
self.assertEqual([torch.tensor(6, dtype=torch.int)], res_before_reset)
1225+
for x, y in zip(input_dp, res_after_reset):
1226+
self.assertEqual(torch.tensor(sum(x), dtype=torch.int), y)
1227+
1228+
# __len__ Test: __len__ is inherited
12071229
self.assertEqual(len(input_dp), len(collate_dp))
1208-
for x, y in zip(collate_dp, input_dp):
1209-
self.assertEqual(x, torch.tensor(sum(y), dtype=torch.float))
12101230

1231+
# __len__ Test: verify that it has no valid __len__ when the source doesn't have it
12111232
input_dp_nl = IDP_NoLen(arrs)
12121233
collate_dp_nl = input_dp_nl.collate()
12131234
with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
12141235
len(collate_dp_nl)
1215-
for x, y in zip(collate_dp_nl, input_dp_nl):
1216-
self.assertEqual(x, torch.tensor(y))
1236+
for x, y in zip(input_dp_nl, collate_dp_nl):
1237+
self.assertEqual(torch.tensor(x), y)
12171238

12181239
def test_batch_datapipe(self):
12191240
arrs = list(range(10))

torch/utils/data/datapipes/iter/callable.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,12 @@ def __setstate__(self, state):
145145
class CollatorIterDataPipe(MapperIterDataPipe):
146146
r""":class:`CollatorIterDataPipe`.
147147
148-
Iterable DataPipe to collate samples from datapipe to Tensor(s) by `util_.collate.default_collate`,
149-
or customized Data Structure by collate_fn.
148+
Iterable DataPipe to collate samples from DataPipe to Tensor(s) by a custom collate function,
149+
which defaults to `torch.utils.data.default_collate` if it is not specified.
150+
151+
.. note::
152+
While writing a custom collate function, you can import `torch.utils.data.default_collate` for the
153+
default behavior and `functools.partial` to specify any additional arguments.
150154
151155
Args:
152156
datapipe: Iterable DataPipe being collated

torch/utils/data/dataset.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class IterableDataset(Dataset[T_co], metaclass=_DataPipeMeta):
8585
# Functional form of 'ShardingFilterIterDataPipe'
8686
def sharding_filter(self) -> IterDataPipe: ...
8787
# Functional form of 'ShufflerIterDataPipe'
88-
def shuffle(self, *, buffer_size: int = 10000, unbatch_level: int = 0) -> IterDataPipe: ...
88+
def shuffle(self, *, default: bool = True, buffer_size: int = 10000, unbatch_level: int = 0) -> IterDataPipe: ...
8989
# Functional form of 'UnBatcherIterDataPipe'
9090
def unbatch(self, unbatch_level: int = 1) -> IterDataPipe: ...
9191
# Functional form of 'ZipperIterDataPipe'

0 commit comments

Comments
 (0)
0