@@ -759,6 +759,58 @@ def test_cumsum(self):
759759 self .assertTrue (output_dtensor .placements [0 ].is_shard (shard_dim ))
760760 self .assertEqual (output_dtensor .full_tensor (), output )
761761
762+ @with_comms
763+ def test_conj_complex_dtensor (self ):
764+ mesh = self .build_device_mesh ()
765+ comm_mode = CommDebugMode ()
766+
767+ freqs_cis = torch .randn (
768+ 1 , 1 , dtype = torch .complex64 , requires_grad = False , device = self .device_type
769+ )
770+ freqs_cis_dt = distribute_tensor (
771+ freqs_cis , device_mesh = mesh , placements = [Replicate ()]
772+ )
773+
774+ local_result = freqs_cis .conj () + 1
775+ with comm_mode :
776+ dtensor_result = freqs_cis_dt .conj () + 1
777+ self .assertEqual (comm_mode .get_total_counts (), 0 )
778+
779+ self .assertEqual (local_result , dtensor_result .full_tensor ())
780+
781+ @with_comms
782+ def test_rotary_embedding_complex_ops (self ):
783+ mesh = self .build_device_mesh ()
784+ comm_mode = CommDebugMode ()
785+
786+ def apply_rotary_emb (xq , freqs_cis ):
787+ xq_ = torch .view_as_complex (xq )
788+ xq_out = torch .view_as_real (xq_ * freqs_cis )
<
8000
/code>789+ return xq_out
790+
791+ xq = torch .randn (1 , 1 , 2 , requires_grad = True , device = self .device_type )
792+ freqs_cis = torch .randn (
793+ 1 , 1 , dtype = torch .complex64 , requires_grad = False , device = self .device_type
794+ )
795+
796+ xq_dt = distribute_tensor (xq , device_mesh = mesh , placements = [Replicate ()])
797+ freqs_cis_dt = distribute_tensor (
798+ freqs_cis , device_mesh = mesh , placements = [Replicate ()]
799+ )
800+
801+ with comm_mode :
802+ xq_out_dt = apply_rotary_emb (xq_dt , freqs_cis_dt )
803+ xq_out_dt .sum ().backward ()
804+ self .assertEqual (comm_mode .get_total_counts (), 0 )
805+
806+ dtensor_grad = xq_dt .grad .full_tensor ()
807+
808+ xq .grad = None
809+ xq_out = apply_rotary_emb (xq , freqs_cis )
810+ xq_out .sum ().backward ()
811+
812+ self .assertEqual (dtensor_grad , xq .grad )
813+
762814
763815if __name__ == "__main__" :
764816 run_tests ()
0 commit comments