8000 Tests for #158030 · pytorch/pytorch@b4c684b · GitHub
[go: up one dir, main page]

Skip to content

Commit b4c684b

Browse files
committed
Tests for #158030
Signed-off-by: Edward Z. Yang <ezyang@meta.com> ghstack-source-id: ab9957e Pull-Request: #158033
1 parent 89cd4c9 commit b4c684b

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed

test/distributed/tensor/test_math_ops.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,58 @@ def test_cumsum(self):
759759
self.assertTrue(output_dtensor.placements[0].is_shard(shard_dim))
760760
self.assertEqual(output_dtensor.full_tensor(), output)
761761

762+
@with_comms
763+
def test_conj_complex_dtensor(self):
764+
mesh = self.build_device_mesh()
765+
comm_mode = CommDebugMode()
766+
767+
freqs_cis = torch.randn(
768+
1, 1, dtype=torch.complex64, requires_grad=False, device=self.device_type
769+
)
770+
freqs_cis_dt = distribute_tensor(
771+
freqs_cis, device_mesh=mesh, placements=[Replicate()]
772+
)
773+
774+
local_result = freqs_cis.conj() + 1
775+
with comm_mode:
776+
dtensor_result = freqs_cis_dt.conj() + 1
777+
self.assertEqual(comm_mode.get_total_counts(), 0)
778+
779+
self.assertEqual(local_result, dtensor_result.full_tensor())
780+
781+
@with_comms
782+
def test_rotary_embedding_complex_ops(self):
783+
mesh = self.build_device_mesh()
784+
comm_mode = CommDebugMode()
785+
786+
def apply_rotary_emb(xq, freqs_cis):
787+
xq_ = torch.view_as_complex(xq)
788+
xq_out = torch.view_as_real(xq_ * freqs_cis)
< 8000 /code>
789+
return xq_out
790+
791+
xq = torch.randn(1, 1, 2, requires_grad=True, device=self.device_type)
792+
freqs_cis = torch.randn(
793+
1, 1, dtype=torch.complex64, requires_grad=False, device=self.device_type
794+
)
795+
796+
xq_dt = distribute_tensor(xq, device_mesh=mesh, placements=[Replicate()])
797+
freqs_cis_dt = distribute_tensor(
798+
freqs_cis, device_mesh=mesh, placements=[Replicate()]
799+
)
800+
801+
with comm_mode:
802+
xq_out_dt = apply_rotary_emb(xq_dt, freqs_cis_dt)
803+
xq_out_dt.sum().backward()
804+
self.assertEqual(comm_mode.get_total_counts(), 0)
805+
806+
dtensor_grad = xq_dt.grad.full_tensor()
807+
808+
xq.grad = None
809+
xq_out = apply_rotary_emb(xq, freqs_cis)
810+
xq_out.sum().backward()
811+
812+
self.assertEqual(dtensor_grad, xq.grad)
813+
762814

763815
if __name__ == "__main__":
764816
run_tests()

0 commit comments

Comments
 (0)
0