@@ -2602,7 +2602,6 @@ def is_aligned(x):
 # WIP
 make_fallback(aten._adaptive_avg_pool3d)  # @isuruf
 make_fallback(aten.adaptive_max_pool3d)  # @isuruf
-make_fallback(aten.fractional_max_pool3d)  # @isuruf
 make_fallback(aten._scaled_dot_product_attention_math_for_mps)  # @malfet


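Note: with the fallback gone, Inductor lowers aten.fractional_max_pool3d directly instead of calling back into ATen. A minimal smoke test of the kind this enables (shapes and samples here are illustrative, not taken from the PR):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8, 8)
samples = torch.rand(1, 3, 3)  # one (d, h, w) sample per (batch, channel) plane

def fn(t):
    return F.fractional_max_pool3d(
        t, kernel_size=2, output_size=(4, 4, 4), _random_samples=samples
    )

# compiled output should match eager now that a real lowering exists
torch.testing.assert_close(fn(x), torch.compile(fn)(x))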
@@ -4475,34 +4474,19 @@ def _low_memory_max_pool_with_offsets(
     return result, to_dtype(offsets, torch.int8)


-@register_lowering(
-    prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None
-)
-def _low_memory_max_pool_offsets_to_indices(
-    offsets, kernel_size, input_size, stride, padding, dilation
-):
-    # TODO: Generalize to other max pooling flavors
+def _pool_offsets_to_indices(offsets, kernel_size, input_size, increments_to_index):
     n_dim = len(kernel_size)
     offsets_loader = offsets.make_loader()
-
-    def increments_to_index(dhw_inc, bh):
-        w_in = [ops.index_expr(input_size[d], torch.int64) for d in range(n_dim)]
-        hbase = [
-            ops.index_expr(bh[d] * stride[d] - padding[d], torch.int64)
-            for d in range(n_dim)
-        ]
-        idhw = [
-            hbase[d] + dhw_inc[d] * ops.constant(dilation[d], torch.int64)
-            for d in range(n_dim)
-        ]
-        return inductor_prims._flatten_index(idhw, w_in)
+    window_size = sympy.sympify(functools.reduce(operator.mul, kernel_size))

     def offsets_to_indices(idx):
-        bh = idx[-n_dim:]
         offset = offsets_loader(idx)
-        k_const = [ops.constant(kernel_size[d], torch.int32) for d in range(n_dim)]
-        dhw_inc = inductor_prims._flattened_index_to_nd(offset, k_const)
-        return increments_to_index(dhw_inc, bh)
+        offset_sympy = ops.indirect_indexing(offset, window_size)
+        reduction_idx = inductor_prims._flattened_index_to_nd(offset_sympy, kernel_size)
+        idhw = increments_to_index(idx, reduction_idx)
+        return ops.index_expr(
+            inductor_prims._flatten_index(idhw, input_size[-n_dim:]), torch.int64
+        )

     indices = Pointwise.create(
         device=offsets.get_device(),
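Note: the arithmetic in _pool_offsets_to_indices (decode the flattened window offset into per-dimension increments, let the caller map those to absolute coordinates, then re-flatten against the input size) can be sketched with plain integers. decode and flatten below are hypothetical stand-ins for inductor_prims._flattened_index_to_nd and _flatten_index:

def decode(flat, sizes):
    # flattened offset -> per-dimension increments (row-major)
    out = []
    for s in reversed(sizes):
        out.append(flat % s)
        flat //= s
    return out[::-1]

def flatten(coords, sizes):
    # row-major flattening, the inverse of decode
    flat = 0
    for c, s in zip(coords, sizes):
        flat = flat * s + c
    return flat

kernel_size, stride, padding, dilation = (2, 2), (2, 2), (0, 0), (1, 1)
input_hw, out_pos, offset = (8, 8), (1, 3), 3  # offset 3 in a 2x2 window -> (1, 1)

inc = decode(offset, kernel_size)
coords = [out_pos[d] * stride[d] + inc[d] * dilation[d] - padding[d] for d in range(2)]
assert flatten(coords, input_hw) == 31  # row 3, column 7 of the 8x8 input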
@@ -4513,6 +4497,27 @@ def offsets_to_indices(idx):
     return indices


+@register_lowering(
+    prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None
+)
+def _low_memory_max_pool_offsets_to_indices(
+    offsets, kernel_size, input_size, stride, padding, dilation
+):
+    # TODO: Generalize to other max pooling flavors
+    n_dim = len(kernel_size)
+
+    def increments_to_index(idx, reduction_idx):
+        bh = idx[-n_dim:]
+        return [
+            (bh[i] * stride[i]) + (reduction_idx[i] * dilation[i]) - padding[i]
+            for i in range(n_dim)
+        ]
+
+    return _pool_offsets_to_indices(
+        offsets, kernel_size, input_size, increments_to_index
+    )
+
+
 def _max_pool_with_indices(
     x,
     kernel_size,
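Note: since the low-memory path stores compact int8 window offsets and reconstructs int64 indices through the helper above, the reconstructed indices should agree with what eager max_pool2d reports. A quick cross-check (illustrative shapes):

import torch
import torch.nn.functional as F

def pool(t):
    return F.max_pool2d(t, kernel_size=2, stride=2, return_indices=True)

x = torch.randn(1, 1, 8, 8)
eager_vals, eager_idx = pool(x)
compiled_vals, compiled_idx = torch.compile(pool)(x)
torch.testing.assert_close(eager_idx, compiled_idx)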
@@ -5016,11 +5021,6 @@ def inner_fn_max_idx(idx):
     return rv, ri


-fallback_fractional_max_pool2d = fallback_handler(
-    aten.fractional_max_pool2d.default, add_to_fallback_set=False
-)
-
-
 def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim, ndims):
     out_sz = out_sz[dim]
     in_sz = in_sz[dim]
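Note: the deleted fallback existed because the old lowering unrolled the whole window and bailed out at kernel_h * kernel_w >= 25; the reduction-based lowering in the next hunk replaces that bail-out with config.patch(unroll_reductions_threshold=25), so large kernels compile instead of falling back. A sketch of a check that a 5x5 kernel (exactly the old cutoff) now goes through compile (shapes illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 32, 32)
samples = torch.rand(1, 1, 2)  # one (h, w) sample per (batch, channel) plane

def fn(t):
    return F.fractional_max_pool2d(
        t, kernel_size=5, output_size=(8, 8), _random_samples=samples
    )

torch.testing.assert_close(fn(x), torch.compile(fn)(x))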
@@ -5039,80 +5039,85 @@ def load(prefix, i):
         seq_i = ops.trunc((i_expr + sample) * alpha) - ops.trunc(sample * alpha)
         seq_i = ops.to_dtype(seq_i, torch.int64)
         mask = ops.lt(i_expr, out_sz_expr)
-        return ops.where(mask, seq_i, diff)
+        return ops.indirect_indexing(ops.where(mask, seq_i, diff), sympy.sympify(in_sz))

     return load


 @register_lowering(aten.fractional_max_pool2d)
 def fractional_max_pool2d(x, kernel_size, output_size, random_samples):
-    x.realize_hint()
-    *batch, inp_h, inp_w = x.get_size()
-    kernel_h, kernel_w = kernel_size
-    h_out, w_out = output_size
+    return _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim=2)

-    if kernel_h * kernel_w >= 25:
-        return fallback_fractional_max_pool2d(
-            x, kernel_size, output_size, random_samples
-        )

-    gen_offsets_for_dim = functools.partial(
-        _fractional_pooling_offsets,
-        samples=random_samples,
-        in_sz=[inp_h, inp_w],
-        out_sz=output_size,
-        kernel_sz=kernel_size,
-        ndims=2,
-    )
+@register_lowering(aten.fractional_max_pool3d)
+def fractional_max_pool3d(x, kernel_size, output_size, random_samples):
+    return _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim=3)

-    h_index_fn = gen_offsets_for_dim(dim=0)
-    w_index_fn = gen_offsets_for_dim(dim=1)
-    x_loader = x.make_loader()

-    def fn(idx, return_index):
-        *prefix, bh, bw = idx
+def _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim):
+    x.realize_hint()
+    batch, inp_dhw = x.shape[:-n_dim], x.shape[-n_dim:]

-        h_start_index = ops.indirect_indexing(h_index_fn(prefix, bh), inp_h)
-        w_start_index = ops.indirect_indexing(w_index_fn(prefix, bw), inp_w)
+    with config.patch(unroll_reductions_threshold=25):
+        dhw_index_fn = [
+            _fractional_pooling_offsets(
+                samples=random_samples,
+                in_sz=inp_dhw,
+                out_sz=output_size,
+                kernel_sz=kernel_size,
+                ndims=n_dim,
+                dim=d,
+            )
+            for d in range(n_dim)
+        ]

-        maxval = None
-        maxindex = None
-        for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])):
-            val = x_loader([*prefix, h_start_index + ih, w_start_index + iw])
-            if return_index:
-                index = ops.index_expr(
-                    (h_start_index + ih) * inp_w + w_start_index + iw, torch.int64
-                )
-                if maxindex is None:
-                    maxindex = index
-                else:
-                    maxindex = ops.where(
-                        ops.or_(ops.gt(val, maxval), ops.isnan(val)), index, maxindex
-                    )
-            if maxval is None:
-                maxval = val
-            else:
-                maxval = ops.maximum(val, maxval)
-        if return_index:
-            return maxindex
-        else:
-            return maxval
+        x_loader = x.make_loader()

-    new_size = list(batch) + [h_out, w_out]
-    rv = Pointwise.create(
-        device=x.get_device(),
-        dtype=x.get_dtype(),
-        inner_fn=functools.partial(fn, return_index=False),
-        ranges=new_size,
-    )
+        def fn_inner(idx, reduction_idx):
+            prefix = idx[:-n_dim]
+            return x_loader([*prefix, *increments_to_index(idx, reduction_idx)])

-    ri = Pointwise.create(
-        device=x.get_device(),
-        dtype=torch.int64,
-        inner_fn=functools.partial(fn, return_index=True),
-        ranges=new_size,
-    )
-    return rv, ri
+        def increments_to_index(idx, reduction_idx):
+            prefix = idx[:-n_dim]
+            bdhw = idx[-n_dim:]
+            return [
+                dhw_index_fn[d](prefix, bdhw[d]) + reduction_idx[d]
+                for d in range(n_dim)
+            ]
+
+        new_size = list(batch) + list(output_size)
+        dtype = x.get_dtype()
+        result = Reduction.create(
+            reduction_type="max",
+            input_node=x,
+            device=x.get_device(),
+            dst_dtype=dtype,
+            src_dtype=dtype,
+            inner_fn=fn_inner,
+            ranges=new_size,
+            reduction_ranges=kernel_size,
+        )
+        offsets = Reduction.create(
+            reduction_type="argmax",
+            input_node=x,
+            device=x.get_device(),
+            dst_dtype=torch.int64,
+            src_dtype=dtype,
+            inner_fn=fn_inner,
+            ranges=new_size,
+            reduction_ranges=kernel_size,
+        )
+        if isinstance(result.data.data, Reduction):  # type: ignore[attr-defined]
+            # Only realize if reduction isn't unrolled
+            result.realize()
+        if isinstance(offsets.data.data, Reduction):  # type: ignore[attr-defined]
+            # Only realize if reduction isn't unrolled
+            offsets.realize()
+
+    indices = _pool_offsets_to_indices(
+        offsets, kernel_size, x.shape, increments_to_index
+    )
+    return result, indices


 @register_lowering(aten.upsample_nearest2d_backward.default)
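Note: the window starts that _fractional_pooling_offsets emits follow the usual fractional-pooling recurrence: for a random sample u in [0, 1) and alpha = (in_sz - kernel_sz) / (out_sz - 1), start_i = trunc((i + u) * alpha) - trunc(u * alpha), with the last window pinned at in_sz - kernel_sz. The lowering then takes max and argmax over each window, which is what the paired Reduction.create calls express. A 1-D sketch (fractional_starts and fractional_max_pool_1d are hypothetical helpers mirroring ATen's interval generation as I understand it):

import math

def fractional_starts(in_sz, out_sz, kernel_sz, u):
    alpha = (in_sz - kernel_sz) / (out_sz - 1)
    starts = [
        math.trunc((i + u) * alpha) - math.trunc(u * alpha)
        for i in range(out_sz - 1)
    ]
    starts.append(in_sz - kernel_sz)  # last window pinned to the end
    return starts

def fractional_max_pool_1d(row, kernel_sz, out_sz, u):
    vals, idxs = [], []
    for s in fractional_starts(len(row), out_sz, kernel_sz, u):
        window = row[s : s + kernel_sz]
        m = max(window)                   # the "max" reduction
        vals.append(m)
        idxs.append(s + window.index(m))  # the "argmax", as a flat input index
    return vals, idxs

# starts for u=0.25 are [0, 2, 4, 6]; max/argmax per window:
print(fractional_max_pool_1d([3, 1, 4, 1, 5, 9, 2, 6], kernel_sz=2, out_sz=4, u=0.25))
# -> ([3, 4, 9, 6], [0, 2, 5, 7])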