@@ -706,16 +706,6 @@ def triton_compute_type(dtype: torch.dtype) -> str:
     return triton_type(upcast_compute_type(dtype))
 
 
-def _get_primitive_bitwidth(dtype: torch.dtype) -> int:
-    """Number of bits of triton_compute_type()"""
-    dtype = upcast_compute_type(dtype)
-    itemsize = getattr(dtype, "itemsize", None)
-    if itemsize:
-        return itemsize * 8
-    else:
-        return -1
-
-
 def triton_store_type(dtype: torch.dtype) -> str:
     """Convert torch.dtype to triton type, with fix for storing tl.bool"""
     if dtype == torch.bool:
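For reference, the deleted helper reported the width of the upcast compute type (so fp16/bf16 counted as 32 bits when `config.triton.codegen_upcast_to_fp32` is enabled), whereas the replacement path in the next hunk compares the stored dtypes' `itemsize` directly. A minimal illustration using only `torch.dtype.itemsize`:

```python
import torch

# Width of the upcast compute type vs. the stored type: bf16 is stored in
# 16 bits but (with codegen_upcast_to_fp32) computed in fp32, i.e. 32 bits.
assert torch.bfloat16.itemsize * 8 == 16
assert torch.float32.itemsize * 8 == 32

# What a bitcast actually requires is equal storage width of source and
# destination, which the new assert checks via itemsize.
assert torch.bfloat16.itemsize == torch.int16.itemsize
```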
@@ -887,30 +877,20 @@ def _get_min_elements_per_thread(
 
     @staticmethod
     def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype):
-        triton_dtype = triton_compute_type(dtype)
+        assert src_dtype.itemsize == dtype.itemsize
         # We may promote float16 or bfloat16 to float32 and cause the
         # bitwidth of dtype to be different from the input tensor (i.e. float32).
         # In such a case, we will have to convert the input tensor to
         # its src_type, perform bitcast, and then convert the bit-casted
         # tensor back to float to ensure we use values with the right precision.
-        if (
-            src_dtype in (torch.float16, torch.bfloat16)
-            and config.triton.codegen_upcast_to_fp32
-        ):
-            triton_src_dtype = str(src_dtype).split(".")[-1]
-            cast_x = f"{x}.to(tl.{triton_src_dtype})"
-            if dtype in (torch.float16, torch.bfloat16):
-                triton_type_name = str(dtype).split(".")[-1]
-                triton_dtype = f"tl.{triton_type_name}"
-            cast_x = f"{cast_x}.to({triton_dtype}, bitcast=True)"
-            if dtype in (torch.float16, torch.bfloat16):
-                return f"{cast_x}.to(tl.float32)"
-            return cast_x
-        else:
-            src_dtype_bitwidth = _get_primitive_bitwidth(src_dtype)
-            target_dtype_bitwidth = _get_primitive_bitwidth(dtype)
-            bitcast = "True" if src_dtype_bitwidth == target_dtype_bitwidth else "False"
-            return f"{x}.to({triton_dtype}, bitcast={bitcast})"
+        if x.dtype != src_dtype:
+            x = f"{x}.to({triton_type(src_dtype)})"
+
+        out = f"{x}.to({triton_type(dtype)}, bitcast=True)"
+        if upcast_compute_type(dtype) != dtype:
+            out = f"{out}.to({triton_type(upcast_compute_type(dtype))})"
+
+        return out
 
     @staticmethod
     def _shaped_constant(value, dtype, shape):
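The rewritten `to_dtype_bitcast` unifies the old two branches: downcast to the stored `src_dtype` if the value was promoted for compute, bitcast, then upcast again when the destination is itself a reduced-precision float. Below is a self-contained sketch of that string construction; the simplified `triton_type` / `upcast_compute_type` stand-ins and the explicit `x_dtype` parameter (the real method reads it from the CSE variable's `.dtype`) are assumptions for illustration only.

```python
import torch

def triton_type(dtype: torch.dtype) -> str:
    # Simplified stand-in: map e.g. torch.bfloat16 -> "tl.bfloat16".
    return f"tl.{str(dtype).split('.')[-1]}"

def upcast_compute_type(dtype: torch.dtype) -> torch.dtype:
    # Simplified stand-in: assume fp16/bf16 are computed in fp32.
    return torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype

def to_dtype_bitcast(x: str, x_dtype: torch.dtype, dtype: torch.dtype, src_dtype: torch.dtype) -> str:
    # Bitcasting only makes sense between types of equal storage width.
    assert src_dtype.itemsize == dtype.itemsize
    # Undo the fp32 promotion (if any) so the bits match src_dtype's layout.
    if x_dtype != src_dtype:
        x = f"{x}.to({triton_type(src_dtype)})"
    out = f"{x}.to({triton_type(dtype)}, bitcast=True)"
    # Promote the reinterpreted value back to compute precision if needed.
    if upcast_compute_type(dtype) != dtype:
        out = f"{out}.to({triton_type(upcast_compute_type(dtype))})"
    return out

# bf16 values promoted to fp32 for compute, reinterpreted as int16:
print(to_dtype_bitcast("tmp0", torch.float32, torch.int16, torch.bfloat16))
# tmp0.to(tl.bfloat16).to(tl.int16, bitcast=True)

# int16 values reinterpreted as bf16; the result is upcast back to fp32:
print(to_dtype_bitcast("tmp1", torch.int16, torch.bfloat16, torch.int16))
# tmp1.to(tl.bfloat16, bitcast=True).to(tl.float32)
```

Either way the emitted expression is a single chain of `.to(...)` casts with `bitcast=True`, which is why the bitwidth bookkeeping in `_get_primitive_bitwidth` is no longer needed.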