8000 [DataPipe] Snapshotting prototype · pytorch/pytorch@43588ad · GitHub
[go: up one dir, main page]

Skip to content

Commit 43588ad

Browse files
committed
[DataPipe] Snapshotting prototype
ghstack-source-id: e8973a1 Pull Request resolved: #70216
1 parent 03b32d0 commit 43588ad

File tree

9 files changed

+140
-3
lines changed

9 files changed

+140
-3
lines changed

test/test_datapipe.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -881,6 +881,16 @@ def test_concat_iterdatapipe(self):
881881

882882
self.assertEqual(list(concat_dp), list(range(10)) + list(range(5)))
883883

884+
# Snapshot Test:
885+
reset_dps()
886+
concat_dp = input_dp1.concat(input_dp2)
887+
blank_dp = dp.iter.IterableWrapper([])
888+
n_elements_to_advance = 5
889+
new_dp = blank_dp.concat(blank_dp)
890+
snapshot_test_helper(concat_dp, new_dp, n_elements_to_advance)
891+
for old_ele, new_ele in zip(concat_dp, new_dp):
892+
self.assertEqual(old_ele, new_ele)
893+
884894
def test_fork_iterdatapipe(self):
885895
input_dp = dp.iter.IterableWrapper(range(10))
886896

@@ -1227,6 +1237,13 @@ def fn(item, dtype=torch.float, *, sum=False):
12271237
self.assertEqual(list(range(n_elements_before_reset)), res_before_reset)
12281238
self.assertEqual(list(range(10)), res_after_reset)
12291239

1240+
# Snapshot Test: DataPipe restores properly
1241+
n_elements_to_advance = 5
1242+
new_dp = dp.iter.Mapper(datapipe=dp.iter.IterableWrapper([]), fn=partial(fn, dtype=torch.int, sum=True))
1243+
snapshot_test_helper(map_dp, new_dp, n_elements_to_advance)
1244+
for old_ele, new_ele in zip(map_dp, new_dp):
1245+
self.assertEqual(old_ele, new_ele)
1246+
12301247
@suppress_warnings # Suppress warning for lambda fn
12311248
def test_map_tuple_list_with_col_iterdatapipe(self):
12321249
def fn_11(d):
@@ -1399,6 +1416,14 @@ def _collate_fn(batch, default_type=torch.float):
13991416
for x, y in zip(arrs, collate_dp_nl):
14001417
self.assertEqual(torch.tensor(x), y)
14011418

1419+
# Snapshot Test: DataPipe restores properly
1420+
collate_dp = input_dp.collate(collate_fn=_collate_fn)
1421+
n_elements_to_advance = 1
1422+
new_dp = dp.iter.Collator(datapipe=dp.iter.IterableWrapper([]), collate_fn=_collate_fn)
1423+
snapshot_test_helper(collate_dp, new_dp, n_elements_to_advance)
1424+
for old_ele, new_ele in zip(collate_dp, new_dp):
1425+
self.assertEqual(old_ele, new_ele)
1426+
14021427
def test_batch_iterdatapipe(self):
14031428
arrs = list(range(10))
14041429
input_dp = dp.iter.IterableWrapper(arrs)
@@ -1542,6 +1567,13 @@ def _mul_filter_fn(a, b):
15421567
self.assertEqual(list(range(5, 10))[:n_elements_before_reset], res_before_reset)
15431568
self.assertEqual(list(range(5, 10)), res_after_reset)
15441569

1570+
# Snapshot Test: DataPipe restores properly
1571+
n_elements_to_advance = 5
1572+
new_dp = dp.iter.Filter(datapipe=dp.iter.IterableWrapper([]), filter_fn=partial(_filter_fn, val=5, clip=True))
1573+
snapshot_test_helper(filter_dp, new_dp, n_elements_to_advance)
1574+
for old_ele, new_ele in zip(filter_dp, new_dp):
1575+
self.assertEqual(old_ele, new_ele)
1576+
15451577
def test_sampler_iterdatapipe(self):
15461578
input_dp = dp.iter.IterableWrapper(range(10))
15471579
# Default SequentialSampler

torch/utils/data/datapipes/iter/combinatorics.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,16 @@ def __len__(self) -> int:
5050
return len(self.sampler)
5151
raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
5252

53+
def save_snapshot(self):
54+
# TODO: Do poor man's snapshotting by default, and show a warning
55+
# Unless the sampler has a save_snapshot and restore_snapshot attribute
56+
# TODO: 1. Should this DataPipe have a buffer?
57+
# TODO: 2. It really depends on the specific sampler (potentially need poor man's sampling
58+
pass
59+
60+
def restore_snapshot(self, snapshot=None):
61+
pass
62+
5363

5464
@functional_datapipe('shuffle')
5565
class ShufflerIterDataPipe(IterDataPipe[T_co]):

torch/utils/data/datapipes/iter/combining.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,31 @@ def __len__(self) -> int:
6363
self.length = -1
6464
return len(self)
6565

66+
def save_snapshot(self):
67+
# TODO: This should be a pass if you assume the previous DPs' snapshots are correct
68+
# Otherwise, you need a buffer to save the outputs from `for data in dp:` in case it crashes
69+
# halfway through (e.g. 4 DPs, first 2 succeed, 3rd fails). Code might be like:
70+
# try:
71+
# for data in dp:
72+
# temp.append(data)
73+
# except SomeException:
74+
# # Remember which failed, start again from that one
75+
# # If no exception, proceed as normal
76+
pass
77+
78+
def restore_snapshot(self, snapshot=None):
79+
pass
80+
81+
# def __getstate__(self):
82+
# if IterDataPipe.getstate_hook is not None:
83+
# return IterDataPipe.getstate_hook(self)
84+
# state = (self.datapipes, self.length)
85+
# return state
86+
#
87+
# def __setstate__(self, state):
88+
# (self.datapipes, self.length) = state
89+
90+
6691

6792
@functional_datapipe('fork')
6893
class ForkerIterDataPipe(IterDataPipe):
@@ -267,6 +292,13 @@ def _check_valid_iterator_id(self, iterator_id) -> bool:
267292
"""
268293
return iterator_id == self._valid_iterator_id and iterator_id == self.main_datapipe._valid_iterator_id
269294

295+
def save_snapshot(self):
296+
# TODO: Save buffer depending on self.main_datapipe type
297+
pass
298+
299+
def restore_snapshot(self, snapshot=None):
300+
pass
301+
270302

271303
@functional_datapipe('demux')
272304
class DemultiplexerIterDataPipe(IterDataPipe):
@@ -499,6 +531,12 @@ def __setstate__(self, state):
499531
def __del__(self):
500532
self.buffer.clear()
501533

534+
def save_snapshot(self):
535+
pass
536+
537+
def restore_snapshot(self, snapshot=None):
538+
pass
539+
502540

503541
@functional_datapipe('zip')
504542
class ZipperIterDataPipe(IterDataPipe[Tuple[T_co]]):
@@ -550,3 +588,18 @@ def __len__(self) -> int:
550588
else:
551589
self.length = -1
552590
return len(self)
591+
592+
def save_snapshot(self):
593+
pass
594+
595+
def restore_snapshot(self, snapshot=None):
596+
pass
597+
598+
# def __getstate__(self):
599+
# if IterDataPipe.getstate_hook is not None:
600+
# return IterDataPipe.getstate_hook(self)
601+
# state = (self.datapipes, self.length)
602+
# return state
603+
#
604+
# def __setstate__(self, state):
605+
# (self.datapipes, self.length) = state

torch/utils/data/datapipes/iter/fileopener.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ def __len__(self):
7171
raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
7272
return self.length
7373

74+
def save_snapshot(self):
75+
pass # Do nothing if previous DataPipe is properly restored
76+
77+
def restore_snapshot(self, snapshot=None):
78+
pass
79+
7480

7581
class FileLoaderIterDataPipe(IterDataPipe[Tuple[str, IOBase]]):
7682

torch/utils/data/datapipes/iter/grouping.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@ def __len__(self):
4848
(1 if (self.instance_id < len(self.source_datapipe) % self.num_of_instances) else 0)
4949
raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
5050

51+
def save_snapshot(self):
52+
pass # Do nothing if previous DataPipe is properly restored
53+
54+
def restore_snapshot(self, snapshot=None):
55+
pass
56+
5157

5258
@functional_datapipe('batch')
5359
class BatcherIterDataPipe(IterDataPipe[DataChunk]):
@@ -111,6 +117,14 @@ def __len__(self) -> int:
111117
return self.length
112118
raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
113119

120+
def save_snapshot(self):
121+
# TODO: We may want to potentially change `batch` to `current_batch` and save it
122+
# in case some error occurs while __iter__ is midway through?
123+
pass # Do nothing if previous DataPipe is properly restored
124+
125+
def restore_snapshot(self, snapshot=None):
126+
pass
127+
114128

115129
@functional_datapipe('unbatch')
116130
class UnBatcherIterDataPipe(IterDataPipe):
@@ -165,6 +179,14 @@ def _dive(self, element, unbatch_level):
165179
else:
166180
raise IndexError(f"unbatch_level {self.unbatch_level} exceeds the depth of the DataPipe")
167181

182+
def save_snapshot(self):
183+
# TODO: Poor man's snapshotting
184+
# Need to know what the last known element is in __iter__ and how many have yielded since
185+
pass
186+
187+
def restore_snapshot(self, snapshot=None):
188+
pass
189+
168190

169191
@functional_datapipe('groupby')
170192
class GrouperIterDataPipe(IterDataPipe[DataChunk]):

torch/utils/data/datapipes/iter/routeddecoder.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,15 @@ def __iter__(self) -> Iterator[Tuple[str, Any]]:
5757
for data in self.datapipe:
5858
pathname = data[0]
5959
result = self.decoder(data)
60-
yield (pathname, result[pathname])
60+
yield pathname, result[pathname]
6161

6262
def __len__(self) -> int:
6363
if isinstance(self.datapipe, Sized):
6464
return len(self.datapipe)
6565
raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
66+
67+
def save_snapshot(self):
68+
pass
69+
70+
def restore_snapshot(self, snapshot=None):
71+
pass

torch/utils/data/datapipes/iter/streamreader.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,11 @@ def __iter__(self):
3434
if not d:
3535
stream.close()
3636
break
37-
yield (furl, d)
37+
yield furl, d
38+
39+
def save_snapshot(self):
40+
# TODO: Remember last stream url and stream, and how many chunks you have read so far
41+
pass
42+
43+
def restore_snapshot(self, snapshot=None):
44+
pass

torch/utils/data/datapipes/utils/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def get_file_pathnames_from_root(
8686
non_deterministic: bool = False) -> Iterable[str]:
8787

8888
# print out an error message and raise the error out
89-
def onerror(err : OSError):
89+
def onerror(err: OSError):
9090
warnings.warn(err.filename + " : " + err.strerror)
9191
raise err
9292

torch/utils/data/sampler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
T_co = TypeVar('T_co', covariant=True)
1616

1717

18+
# TODO: Reimplement all Samplers as DataPipes
1819
class Sampler(Generic[T_co]):
1920
r"""Base class for all Samplers.
2021

0 commit comments

Comments
 (0)
0