@@ -2600,7 +2600,6 @@ def is_aligned(x):
 # WIP
 make_fallback(aten._adaptive_avg_pool3d)  # @isuruf
 make_fallback(aten.adaptive_max_pool3d)  # @isuruf
-make_fallback(aten.fractional_max_pool3d)  # @isuruf
 make_fallback(aten._scaled_dot_product_attention_math_for_mps)  # @malfet


@@ -4472,34 +4471,27 @@ def _low_memory_max_pool_with_offsets(
     return result, to_dtype(offsets, torch.int8)


-@register_lowering(
-    prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None
-)
-def _low_memory_max_pool_offsets_to_indices(
-    offsets, kernel_size, input_size, stride, padding, dilation
-):
-    # TODO: Generalize to other max pooling flavors
+def _pool_offsets_to_indices(
+    offsets: TensorBox,
+    kernel_size: Sequence[Union[int, torch.SymInt]],
+    input_size: Sequence[Union[int, torch.SymInt]],
+    increments_to_index: Callable[
+        [Sequence[Union[int, torch.SymInt]], Sequence[Union[int, torch.SymInt]]],
+        torch._inductor.virtualized.OpsValue,
+    ],
+) -> TensorBox:
     n_dim = len(kernel_size)
     offsets_loader = offsets.make_loader()
-
-    def increments_to_index(dhw_inc, bh):
-        w_in = [ops.index_expr(input_size[d], torch.int64) for d in range(n_dim)]
-        hbase = [
-            ops.index_expr(bh[d] * stride[d] - padding[d], torch.int64)
-            for d in range(n_dim)
-        ]
-        idhw = [
-            hbase[d] + dhw_inc[d] * ops.constant(dilation[d], torch.int64)
-            for d in range(n_dim)
-        ]
-        return inductor_prims._flatten_index(idhw, w_in)
+    window_size = sympy.sympify(functools.reduce(operator.mul, kernel_size))

     def offsets_to_indices(idx):
-        bh = idx[-n_dim:]
         offset = offsets_loader(idx)
-        k_const = [ops.constant(kernel_size[d], torch.int32) for d in range(n_dim)]
-        dhw_inc = inductor_prims._flattened_index_to_nd(offset, k_const)
-        return increments_to_index(dhw_inc, bh)
+        offset_sympy = ops.indirect_indexing(offset, window_size)
+        reduction_idx = inductor_prims._flattened_index_to_nd(offset_sympy, kernel_size)
+        idhw = increments_to_index(idx, reduction_idx)
+        return ops.index_expr(
+            inductor_prims._flatten_index(idhw, input_size[-n_dim:]), torch.int64
+        )

     indices = Pointwise.create(
         device=offsets.get_device(),
@@ -4510,6 +4502,27 @@ def offsets_to_indices(idx):
     return indices


+@register_lowering(
+    prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None
+)
+def _low_memory_max_pool_offsets_to_indices(
+    offsets, kernel_size, input_size, stride, padding, dilation
+):
+    # TODO: Generalize to other max pooling flavors
+    n_dim = len(kernel_size)
+
+    def increments_to_index(idx, reduction_idx):
+        bh = idx[-n_dim:]
+        return [
+            (bh[i] * stride[i]) + (reduction_idx[i] * dilation[i]) - padding[i]
+            for i in range(n_dim)
+        ]
+
+    return _pool_offsets_to_indices(
+        offsets, kernel_size, input_size, increments_to_index
+    )
+
+
 def _max_pool_with_indices(
     x,
     kernel_size,
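
The refactor above separates the pooling-specific arithmetic (stride, padding, dilation) from the generic decoding of stored argmax offsets. For intuition, here is a minimal standalone sketch, in plain Python integers rather than Inductor IR, of the mixed-radix divmod that inductor_prims._flattened_index_to_nd and inductor_prims._flatten_index are built around (the helper names below are illustrative, not the actual prims):

def flattened_index_to_nd(offset, sizes):
    # Peel off one coordinate per dimension, innermost (fastest-varying) first.
    coords = []
    for size in reversed(sizes):
        coords.append(offset % size)
        offset //= size
    return list(reversed(coords))

def flatten_index(coords, sizes):
    # Inverse: recombine per-dimension coordinates into a flat row-major offset.
    flat = 0
    for coord, size in zip(coords, sizes):
        flat = flat * size + coord
    return flat

assert flattened_index_to_nd(11, [2, 3, 2]) == [1, 2, 1]  # 11 == (1 * 3 + 2) * 2 + 1
assert flatten_index([1, 2, 1], [2, 3, 2]) == 11

In _pool_offsets_to_indices the same decomposition runs on symbolic expressions: the argmax offset (a flat position within the kernel window) is decoded with kernel_size as the radix, mapped to input coordinates by the caller-supplied increments_to_index, and re-flattened over the input's trailing n_dim sizes to produce the indices tensor.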
@@ -5013,11 +5026,6 @@ def inner_fn_max_idx(idx):
     return rv, ri


-fallback_fractional_max_pool2d = fallback_handler(
-    aten.fractional_max_pool2d.default, add_to_fallback_set=False
-)
-
-
 def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim, ndims):
     out_sz = out_sz[dim]
     in_sz = in_sz[dim]
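
For reference, _fractional_pooling_offsets computes the standard fractional max pooling start-index sequence (the trunc/where pattern visible in the next hunk). A minimal plain-Python sketch, assuming alpha = (in_sz - kernel_sz) / (out_sz - 1) and a random sample u in [0, 1) as in the surrounding code:

import math

def fractional_offsets(in_sz, out_sz, kernel_sz, u):
    # Start position of each pooling window; the last window is pinned to the
    # right edge so every window [start, start + kernel_sz) stays in bounds.
    alpha = (in_sz - kernel_sz) / (out_sz - 1)
    return [
        int(math.trunc((i + u) * alpha) - math.trunc(u * alpha))
        if i < out_sz - 1
        else in_sz - kernel_sz
        for i in range(out_sz)
    ]

print(fractional_offsets(in_sz=9, out_sz=4, kernel_sz=2, u=0.37))  # [0, 3, 5, 7]

The change below wraps the computed offset in ops.indirect_indexing so it is bounds-checked against in_sz before being used to index the input.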
@@ -5036,80 +5044,85 @@ def load(prefix, i):
         seq_i = ops.trunc((i_expr + sample) * alpha) - ops.trunc(sample * alpha)
         seq_i = ops.to_dtype(seq_i, torch.int64)
         mask = ops.lt(i_expr, out_sz_expr)
-        return ops.where(mask, seq_i, diff)
+        return ops.indirect_indexing(ops.where(mask, seq_i, diff), sympy.sympify(in_sz))

     return load


 @register_lowering(aten.fractional_max_pool2d)
 def fractional_max_pool2d(x, kernel_size, output_size, random_samples):
-    x.realize_hint()
-    *batch, inp_h, inp_w = x.get_size()
-    kernel_h, kernel_w = kernel_size
-    h_out, w_out = output_size
+    return _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim=2)

-    if kernel_h * kernel_w >= 25:
-        return fallback_fractional_max_pool2d(
-            x, kernel_size, output_size, random_samples
-        )

-    gen_offsets_for_dim = functools.partial(
-        _fractional_pooling_offsets,
-        samples=random_samples,
-        in_sz=[inp_h, inp_w],
-        out_sz=output_size,
-        kernel_sz=kernel_size,
-        ndims=2,
-    )
+@register_lowering(aten.fractional_max_pool3d)
+def fractional_max_pool3d(x, kernel_size, output_size, random_samples):
+    return _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim=3)

-    h_index_fn = gen_offsets_for_dim(dim=0)
-    w_index_fn = gen_offsets_for_dim(dim=1)
-    x_loader = x.make_loader()

-    def fn(idx, return_index):
-        *prefix, bh, bw = idx
+def _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim):
+    x.realize_hint()
+    batch, inp_dhw = x.shape[:-n_dim], x.shape[-n_dim:]

-        h_start_index = ops.indirect_indexing(h_index_fn(prefix, bh), inp_h)
-        w_start_index = ops.indirect_indexing(w_index_fn(prefix, bw), inp_w)
+    with config.patch(unroll_reductions_threshold=25):
+        dhw_index_fn = [
+            _fractional_pooling_offsets(
+                samples=random_samples,
+                in_sz=inp_dhw,
+                out_sz=output_size,
+                kernel_sz=kernel_size,
+                ndims=n_dim,
+                dim=d,
+            )
+            for d in range(n_dim)
+        ]

-        maxval = None
-        maxindex = None
-        for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])):
-            val = x_loader([*prefix, h_start_index + ih, w_start_index + iw])
-            if return_index:
-                index = ops.index_expr(
-                    (h_start_index + ih) * inp_w + w_start_index + iw, torch.int64
-                )
-                if maxindex is None:
-                    maxindex = index
-                else:
-                    maxindex = ops.where(
-                        ops.or_(ops.gt(val, maxval), ops.isnan(val)), index, maxindex
-                    )
-            if maxval is None:
-                maxval = val
-            else:
-                maxval = ops.maximum(val, maxval)
-        if return_index:
-            return maxindex
-        else:
-            return maxval
+        x_loader = x.make_loader()

-    new_size = list(batch) + [h_out, w_out]
-    rv = Pointwise.create(
-        device=x.get_device(),
-        dtype=x.get_dtype(),
-        inner_fn=functools.partial(fn, return_index=False),
-        ranges=new_size,
-    )
+        def fn_inner(idx, reduction_idx):
+            prefix = idx[:-n_dim]
+            return x_loader([*prefix, *increments_to_index(idx, reduction_idx)])

-    ri = Pointwise.create(
-        device=x.get_device(),
-        dtype=torch.int64,
-        inner_fn=functools.partial(fn, return_index=True),
-        ranges=new_size,
-    )
-    return rv, ri
+        def increments_to_index(idx, reduction_idx):
+            prefix = idx[:-n_dim]
+            bdhw = idx[-n_dim:]
+            return [
+                dhw_index_fn[d](prefix, bdhw[d]) + reduction_idx[d]
+                for d in range(n_dim)
+            ]
+
+        new_size = list(batch) + list(output_size)
+        dtype = x.get_dtype()
+        result = Reduction.create(
+            reduction_type="max",
+            input_node=x,
+            device=x.get_device(),
+            dst_dtype=dtype,
+            src_dtype=dtype,
+            inner_fn=fn_inner,
+            ranges=new_size,
+            reduction_ranges=kernel_size,
+        )
+        offsets = Reduction.create(
+            reduction_type="argmax",
+            input_node=x,
+            device=x.get_device(),
+            dst_dtype=torch.int64,
+            src_dtype=dtype,
+            inner_fn=fn_inner,
+            ranges=new_size,
+            reduction_ranges=kernel_size,
+        )
+        if isinstance(result.data.data, Reduction):  # type: ignore[attr-defined]
+            # Only realize if reduction isn't unrolled
+            result.realize()
+        if isinstance(offsets.data.data, Reduction):  # type: ignore[attr-defined]
+            # Only realize if reduction isn't unrolled
+            offsets.realize()
+
+    indices = _pool_offsets_to_indices(
+        offsets, kernel_size, x.shape, increments_to_index
+    )
+    return result, indices


 @register_lowering(aten.upsample_nearest2d_backward.default)
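
As a quick end-to-end sanity check of the new lowering (shapes and values below are illustrative, not taken from this PR's tests), eager and compiled results can be compared; with the make_fallback registration removed, Inductor now generates code for fractional_max_pool3d instead of dispatching back to the ATen kernel:

import torch
import torch.nn.functional as F

def pool(x, samples):
    return F.fractional_max_pool3d(
        x,
        kernel_size=2,
        output_size=(4, 4, 4),
        return_indices=True,
        _random_samples=samples,
    )

x = torch.randn(1, 3, 8, 8, 8)
samples = torch.rand(1, 3, 3)  # one (d, h, w) sample per (batch, channel)
eager_vals, eager_idx = pool(x, samples)
compiled_vals, compiled_idx = torch.compile(pool)(x, samples)
torch.testing.assert_close(eager_vals, compiled_vals)
torch.testing.assert_close(eager_idx, compiled_idx)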