|
1 | 1 | from collections import defaultdict
|
2 | 2 |
|
3 | 3 | from torch.utils.data import IterDataPipe, functional_datapipe, DataChunk
|
| 4 | +from torch.utils.data.datapipes.utils.common import DILL_AVAILABLE, check_lambda_fn |
4 | 5 | from typing import Any, Callable, DefaultDict, Iterator, List, Optional, Sized, TypeVar
|
5 | 6 |
|
| 7 | +if DILL_AVAILABLE: |
| 8 | + import dill |
| 9 | + dill.extend(use_dill=False) |
| 10 | + |
6 | 11 | T_co = TypeVar('T_co', covariant=True)
|
7 | 12 |
|
8 | 13 |
|
@@ -157,6 +162,7 @@ def __init__(self,
|
157 | 162 | group_size: Optional[int] = None,
|
158 | 163 | guaranteed_group_size: Optional[int] = None,
|
159 | 164 | drop_remaining: bool = False):
|
| 165 | + check_lambda_fn(group_key_fn) |
160 | 166 | self.datapipe = datapipe
|
161 | 167 | self.group_key_fn = group_key_fn
|
162 | 168 | self.buffer_size = buffer_size
|
@@ -214,3 +220,36 @@ def __iter__(self):
|
214 | 220 | res = buffer_elements.pop(key)
|
215 | 221 | buffer_size -= len(res)
|
216 | 222 | yield self.wrapper_class(res)
|
| 223 | + |
| 224 | + def __getstate__(self): |
| 225 | + if IterDataPipe.getstate_hook is not None: |
| 226 | + return IterDataPipe.getstate_hook(self) |
| 227 | + |
| 228 | + if DILL_AVAILABLE: |
| 229 | + dill_function = dill.dumps(self.group_key_fn) |
| 230 | + else: |
| 231 | + dill_function = self.group_key_fn |
| 232 | + state = ( |
| 233 | + self.datapipe, |
| 234 | + dill_function, |
| 235 | + self.buffer_size, |
| 236 | + self.group_size, |
| 237 | + self.guaranteed_group_size, |
| 238 | + self.drop_remaining, |
| 239 | + ) |
| 240 | + return state |
| 241 | + |
| 242 | + def __setstate__(self, state): |
| 243 | + ( |
| 244 | + self.datapipe, |
| 245 | + dill_function, |
| 246 | + self.buffer_size, |
| 247 | + self.group_size, |
| 248 | + self.guaranteed_group_size, |
| 249 | + self.drop_remaining, |
| 250 | + ) = state |
| 251 | + if DILL_AVAILABLE: |
| 252 | + self.group_key_fn = dill.loads(dill_function) # type: ignore[assignment] |
| 253 | + else: |
| 254 | + self.group_key_fn = dill_function # type: ignore[assignment] |
| 255 | + self.wrapper_class = DataChunk |
0 commit comments