@@ -1888,15 +1888,10 @@ def _warn_complex_not_supported():
18881888
18891889# There are some types (CPU) which we accept as input but not as
18901890# output.
1891- def unsupported_input_tensor (t : torch .Tensor , parent = None , node = None ):
1891+ def unsupported_input_tensor (t : torch .Tensor , node = None ):
18921892 "Do not support reading or writing to this tensor"
18931893 if t .is_complex ():
18941894 # Complex views are supported with IR ComplexView
1895- if parent and parent .target in (
1896- torch .ops .aten .view .dtype ,
1897- torch .ops .prims .convert_element_type .default ,
1898- ):
1899- return False
19001895 _warn_complex_not_supported ()
19011896 return True
19021897
@@ -1910,11 +1905,12 @@ def unsupported_input_tensor(t: torch.Tensor, parent=None, node=None):
19101905 # allow bitcast, views, memory movement, but not arithmetic
19111906 # TODO: delete once triton adds native support
19121907 return not (
1913- isinstance (parent .target , torch ._ops .OpOverload )
1914- and parent .target
1908+ isinstance (node .target , torch ._ops .OpOverload )
1909+ and node .target
19151910 in (
19161911 aten .view .dtype ,
19171912 aten .cat .default ,
1913+ aten .clone .default ,
19181914 aten ._scaled_mm .default ,
19191915 )
19201916 or (isinstance (node .target , torch ._ops .OpOverload ) and is_view (node .target ))
@@ -1923,9 +1919,15 @@ def unsupported_input_tensor(t: torch.Tensor, parent=None, node=None):
19231919 return False
19241920
19251921
def unsupported_output_tensor(t: torch.Tensor, node=None):
    """Report whether ``t`` cannot be written as an output (reading may still be OK).

    Args:
        t: The tensor (or fake tensor) describing the value being checked.
        node: Optional FX node associated with ``t``. Its ``target`` is used to
            recognize the complex-view ops that are exempt from the complex
            check, and it is forwarded to ``unsupported_input_tensor``.

    Returns:
        ``False`` when ``t`` is complex and ``node`` targets one of the
        supported complex-view ops; ``True`` when ``unsupported_input_tensor``
        rejects the tensor; otherwise the value of
        ``t.is_cpu and config.disable_cpp_codegen``.
    """
    # Ops through which a complex tensor is still accepted as an output.
    # NOTE(review): presumably these are the ops lowered as complex views
    # rather than complex arithmetic -- confirm against the lowering code.
    supported_complex_views = (
        aten.view.dtype,
        torch.ops.prims.convert_element_type.default,
    )
    # This check must run before the input check below, which treats complex
    # tensors as unsupported.
    if node is not None and node.target in supported_complex_views and t.is_complex():
        return False

    # Anything that cannot even be read cannot be written either.
    if unsupported_input_tensor(t, node):
        return True

    # CPU outputs are only unsupported when C++ codegen is disabled.
    return t.is_cpu and config.disable_cpp_codegen
19311933
@@ -1935,36 +1937,39 @@ def fallback_node_due_to_unsupported_type(node: torch.fx.Node, allow_cpu_inputs=
19351937 if node .target is aten .view_as_complex .default :
19361938 return False
19371939
1940+ if node .op == "placeholder" :
1941+ return False
1942+
19381943 # We should be able to remove this special case once `disable_cpp_codegen` is killed.
19391944 if node .target is aten .lift_fresh_copy .default :
19401945 return False
19411946
1942- def check_skip_condition (node , parent , is_output ):
1943- if not isinstance (node , torch .fx .Node ):
1947+ def check_skip_condition (inp_out_node , is_output ):
1948+ if not isinstance (inp_out_node , torch .fx .Node ):
19441949 return False
19451950
1946- if "val" not in node .meta :
1951+ if "val" not in inp_out_node .meta :
19471952 return False
19481953
1949- for meta in pytree .tree_leaves (node .meta ["val" ]):
1954+ for meta in pytree .tree_leaves (inp_out_node .meta ["val" ]):
19501955 if not isinstance (meta , torch ._subclasses .FakeTensor ):
19511956 continue
19521957
19531958 if is_output :
1954- if unsupported_output_tensor (meta , parent , node ):
1959+ if unsupported_output_tensor (meta , node ):
19551960 return True
19561961 else :
1957- if unsupported_input_tensor (meta , parent , node ):
1962+ if unsupported_input_tensor (meta , node ):
19581963 return True
19591964
19601965 return False
19611966
19621967 # only skip codegen if there is a cpu output, not input
19631968 for arg in pytree .arg_tree_leaves (* node .args , ** node .kwargs ):
1964- if check_skip_condition (arg , node , is_output = False ):
1969+ if check_skip_condition (arg , is_output = False ):
19651970 return True
19661971
1967- return check_skip_condition (node , node , is_output = True )
1972+ return check_skip_condition (node , is_output = True )
19681973
19691974
19701975def make_fallback (op , layout_constraint = None , warn = True , override_decomp = False ):
0 commit comments