🚀 The feature
Feature request: expose a `pin_memory_map` parameter on the `PinMemory` node, defaulting to the current behavior (`pin_memory` from torch):
```python
from typing import Callable, TypeVar

import torch
from torch.utils.data._utils.pin_memory import pin_memory
from torchdata.nodes import BaseNode

T = TypeVar("T")


class PinMemory(BaseNode[T]):
    def __init__(
        self,
        source: BaseNode[T],
        pin_memory_device: str = "",
        snapshot_frequency: int = 1,
        pin_memory_map: Callable[[T, torch.device | None], T] = pin_memory,
    ): ...
```
The same parameter and default would need to appear in `_pin_memory_loop`, where it would override the function called at torchdata/nodes/pin_memory.py line 81 (commit dbf04a9):

```python
item = pin_memory(item, device)
```
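For illustration, here is a minimal sketch of how the parameter might be threaded through the loop; the signature below is hypothetical and heavily simplified, not the actual torchdata code:

```python
from typing import Any, Callable, Iterable, Iterator, Optional

import torch


def _pin_memory_loop(
    source: Iterable[Any],  # hypothetical, simplified stand-in for the real loop's inputs
    device: Optional[torch.device],
    pin_memory_map: Callable[[Any, Optional[torch.device]], Any],
) -> Iterator[Any]:
    for item in source:
        # The injected callable replaces the hard-coded pin_memory(item, device)
        yield pin_memory_map(item, device)
```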
Motivation, pitch
The current PyTorch implementation of `pin_memory` only partially allows custom objects to implement a "pin memory interface":
https://github.com/pytorch/pytorch/blob/50d4698ac8c12ad8399773aa157d25316c7c345e/torch/utils/data/_utils/pin_memory.py#L108
Note that the device is not passed when `pin_memory` is called on the object. Passing it through would allow objects to implement their own `def pin_memory(self, device: torch.device | None = None) -> Self`, which the `PinMemory` node would then use. So one could pass a map, e.g.:
```python
from typing import Any, Protocol, Self, runtime_checkable

import torch
from torch.utils.data._utils.pin_memory import pin_memory


@runtime_checkable
class SupportsPinMemory(Protocol):
    def pin_memory(self, device: torch.device | None = None) -> Self: ...


def pin_memory_custom(data: Any, device: torch.device | None) -> Any:
    # Objects implementing the protocol pin themselves, with the device passed through
    if isinstance(data, SupportsPinMemory):
        return data.pin_memory(device=device)
    # Otherwise default to pytorch pin memory
    return pin_memory(data, device)


node = PinMemory(source=other_node, pin_memory_map=pin_memory_custom)
```
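A custom type could then opt in, for example (the `Batch` container here is hypothetical, purely for illustration):

```python
from dataclasses import dataclass

import torch


@dataclass
class Batch:
    # Hypothetical user-defined container, not part of torchdata
    images: torch.Tensor
    labels: torch.Tensor

    def pin_memory(self, device: torch.device | None = None) -> "Batch":
        # Pin each constituent tensor into page-locked host memory; the device
        # parameter is available here for types that need it
        return Batch(
            images=self.images.pin_memory(),
            labels=self.labels.pin_memory(),
        )
```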
Alternatives
Of course this can be done now with a custom `Mapper`, but my understanding is that the reimplemented `_pin_memory_loop` used by the `PinMemory` node plays nicely with the rest of the nodes in a pipeline without consuming all CPU cores.
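For reference, a sketch of that workaround, assuming `torchdata.nodes.Mapper` applies a callable to each item and reusing the `pin_memory_custom` helper from above:

```python
from functools import partial

from torchdata.nodes import Mapper

# Bind the target device up front, since the map function receives a single item
pinned = Mapper(source=other_node, map_fn=partial(pin_memory_custom, device=None))
```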
Additional context
No response