8000 Fix DTensor handling of conjugate bit. · pytorch/pytorch@89cd4c9 · GitHub
[go: up one dir, main page]

Skip to content

Commit 89cd4c9

Browse files
committed
Fix DTensor handling of conjugate bit.
Fixes #130646 specifically for DTensor Fixes pytorch/torchtitan#267 Signed-off-by: Edward Z. Yang <ezyang@meta.com> ghstack-source-id: 5b5124e Pull-Request: #158030
1 parent aab949a commit 89cd4c9

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

torch/distributed/tensor/_api.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,13 @@ def __new__(
269269
# new method instruct wrapper tensor from local_tensor and add
270270
# placement spec, it does not do actual distribution
271271
assert spec.tensor_meta is not None, "TensorMeta should not be None!"
272+
extra_dispatch_keys = torch._C.DispatchKeySet.from_raw_repr(0)
273+
if torch._C._dispatch_keys(local_tensor).has(torch._C.DispatchKey.Conjugate):
274+
extra_dispatch_keys = extra_dispatch_keys.add(
275+
torch._C.DispatchKey.Conjugate
276+
)
277+
if torch._C._dispatch_keys(local_tensor).has(torch._C.DispatchKey.Negative):
278+
extra_dispatch_keys = extra_dispatch_keys.add(torch._C.DispatchKey.Negative)
272279
r = torch.Tensor._make_wrapper_subclass(
273280
cls,
274281
spec.tensor_meta.shape,
@@ -277,6 +284,7 @@ def __new__(
277284
device=local_tensor.device,
278285
layout=local_tensor.layout,
279286
requires_grad=requires_grad,
287+
_extra_dispatch_keys=extra_dispatch_keys,
280288
)
281289

282290
r._spec = spec

0 commit comments

Comments
 (0)
0