[DataPipe] Make GroupBy serializable with lambda function · pytorch/pytorch@f118532 · GitHub
[go: up one dir, main page]

Skip to content

Commit f118532

Browse files
committed
[DataPipe] Make GroupBy serializable with lambda function
[ghstack-poisoned]
1 parent 7d28578 commit f118532

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

test/test_datapipe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ def test_serializable_with_dill(self):
468468
(dp.iter.Collator, (lambda x: x,), {}),
469469
(dp.iter.Demultiplexer, (2, lambda x: x % 2,), {}),
470470
(dp.iter.Filter, (lambda x: x >= 5,), {}),
471-
# (dp.iter.Grouper, (lambda x: x >= 5,), {}), # TODO: Need custom __getstate__ for Grouper
471+
(dp.iter.Grouper, (lambda x: x >= 5,), {}),
472472
(dp.iter.Mapper, (lambda x: x, ), {}),
473473
]
474474
if HAS_DILL:

torch/utils/data/datapipes/iter/grouping.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
from collections import defaultdict
22

33
from torch.utils.data import IterDataPipe, functional_datapipe, DataChunk
4+
from torch.utils.data.datapipes.utils.common import DILL_AVAILABLE, check_lambda_fn
45
from typing import Any, Callable, DefaultDict, Iterator, List, Optional, Sized, TypeVar
56

7+
if DILL_AVAILABLE:
8+
import dill
9+
dill.extend(use_dill=False)
10+
611
T_co = TypeVar('T_co', covariant=True)
712

813

@@ -157,6 +162,7 @@ def __init__(self,
157162
group_size: Optional[int] = None,
158163
guaranteed_group_size: Optional[int] = None,
159164
drop_remaining: bool = False):
165+
check_lambda_fn(group_key_fn)
160166
self.datapipe = datapipe
161167
self.group_key_fn = group_key_fn
162168
self.buffer_size = buffer_size
@@ -214,3 +220,36 @@ def __iter__(self):
214220
res = buffer_elements.pop(key)
215221
buffer_size -= len(res)
216222
yield self.wrapper_class(res)
223+
224+
def __getstate__(self):
225+
if IterDataPipe.getstate_hook is not None:
226+
return IterDataPipe.getstate_hook(self)
227+
228+
if DILL_AVAILABLE:
229+
dill_function = dill.dumps(self.group_key_fn)
230+
else:
231+
dill_function = self.group_key_fn
232+
state = (
233+
self.datapipe,
234+
dill_function,
235+
self.buffer_size,
236+
self.group_size,
237+
self.guaranteed_group_size,
238+
self.drop_remaining,
239+
)
240+
return state
241+
242+
def __setstate__(self, state):
243+
(
244+
self.datapipe,
245+
dill_function,
246+
self.buffer_size,
247+
self.group_size,
248+
self.guaranteed_group_size,
249+
self.drop_remaining,
250+
) = state
251+
if DILL_AVAILABLE:
252+
self.group_key_fn = dill.loads(dill_function) # type: ignore[assignment]
253+
else:
254+
self.group_key_fn = dill_function # type: ignore[assignment]
255+
self.wrapper_class = DataChunk

0 commit comments

Comments
 (0)
0