@@ -2595,7 +2595,6 @@ def is_aligned(x):
 # WIP
 make_fallback(aten._adaptive_avg_pool3d)  # @isuruf
 make_fallback(aten.adaptive_max_pool3d)  # @isuruf
-make_fallback(aten.fractional_max_pool3d)  # @isuruf


 # 1) Easy
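Note: removing this fallback means Inductor now compiles `aten.fractional_max_pool3d` itself via the lowering added below. A minimal smoke-test sketch of what this enables (shapes and sizes are illustrative, not from this PR):

```python
import torch
import torch.nn.functional as F

def pool(x):
    # With the fallback gone, Inductor generates code for this op
    # instead of dispatching back to the eager ATen kernel.
    return F.fractional_max_pool3d(x, kernel_size=2, output_size=(4, 4, 4))

compiled = torch.compile(pool)
out = compiled(torch.randn(1, 1, 8, 8, 8))
assert out.shape == (1, 1, 4, 4, 4)
```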
@@ -4472,34 +4471,19 @@ def _low_memory_max_pool_with_offsets(
     return result, to_dtype(offsets, torch.int8)


-@register_lowering(
-    prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None
-)
-def _low_memory_max_pool_offsets_to_indices(
-    offsets, kernel_size, input_size, stride, padding, dilation
-):
-    # TODO: Generalize to other max pooling flavors
+def _pool_offsets_to_indices(offsets, kernel_size, input_size, increments_to_index):
     dim = len(kernel_size)
     offsets_loader = offsets.make_loader()
-
-    def increments_to_index(dhw_inc, bh):
-        w_in = [ops.index_expr(input_size[d], torch.int64) for d in range(dim)]
-        hbase = [
-            ops.index_expr(bh[d] * stride[d] - padding[d], torch.int64)
-            for d in range(dim)
-        ]
-        idhw = [
-            hbase[d] + dhw_inc[d] * ops.index_expr(dilation[d], torch.int64)
-            for d in range(dim)
-        ]
-        return inductor_prims._flatten_index(idhw, w_in)
+    window_size = sympy.sympify(functools.reduce(operator.mul, kernel_size))

     def offsets_to_indices(idx):
-        bh = idx[-dim:]
         offset = offsets_loader(idx)
-        k_const = [ops.constant(kernel_size[d], torch.int32) for d in range(dim)]
-        dhw_inc = inductor_prims._flattened_index_to_nd(offset, k_const)
-        return increments_to_index(dhw_inc, bh)
+        offset_sympy = ops.indirect_indexing(offset, window_size)
+        reduction_idx = inductor_prims._flattened_index_to_nd(offset_sympy, kernel_size)
+        idhw = increments_to_index(idx, reduction_idx)
+        return ops.index_expr(
+            inductor_prims._flatten_index(idhw, input_size[-dim:]), torch.int64
+        )

     indices = Pointwise.create(
         device=offsets.get_device(),
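The refactor keeps only the offset-to-index plumbing in `_pool_offsets_to_indices` and lets the caller supply `increments_to_index`; `ops.indirect_indexing` bounds the loaded offset by the window size so the unflatten/flatten steps can happen as symbolic (sympy) arithmetic. A standalone sketch of that arithmetic, with illustrative concrete values (kernel `(2, 3)`, output position `(1, 2)`, stride 2, dilation 1, no padding, 8x8 input plane):

```python
import sympy

offset = sympy.Symbol("offset", nonnegative=True)  # bounded by 2 * 3 via indirect_indexing
kh, kw = offset // 3, offset % 3  # cf. inductor_prims._flattened_index_to_nd
ih = 1 * 2 + kh * 1 - 0  # bh * stride + increment * dilation - padding
iw = 2 * 2 + kw * 1 - 0
flat = ih * 8 + iw  # cf. inductor_prims._flatten_index over input_size[-dim:]
print(sympy.expand(flat))  # 8*floor(offset/3) + Mod(offset, 3) + 20
```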
@@ -4510,6 +4494,27 @@ def offsets_to_indices(idx):
     return indices


+@register_lowering(
+    prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None
+)
+def _low_memory_max_pool_offsets_to_indices(
+    offsets, kernel_size, input_size, stride, padding, dilation
+):
+    # TODO: Generalize to other max pooling flavors
+    dim = len(kernel_size)
+
+    def increments_to_index(idx, reduction_idx):
+        bh = idx[-dim:]
+        return [
+            bh[i] * stride[i] + reduction_idx[i] * dilation[i] - padding[i]
+            for i in range(dim)
+        ]
+
+    return _pool_offsets_to_indices(
+        offsets, kernel_size, input_size, increments_to_index
+    )
+
+
 def _max_pool_with_indices(
     x,
     kernel_size,
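For reference, the two `inductor_prims` helpers used above behave like these hypothetical pure-Python stand-ins, assuming row-major order (last dimension varies fastest):

```python
def flattened_index_to_nd(offset, sizes):
    # Flat offset within a window -> per-dimension increments.
    out = []
    for s in reversed(sizes):
        out.append(offset % s)
        offset //= s
    return out[::-1]

def flatten_index(idx, sizes):
    # Per-dimension coordinates -> flat index.
    flat = 0
    for i, s in zip(idx, sizes):
        flat = flat * s + i
    return flat

# Round trip over a (2, 3) window: offset 4 -> increments [1, 1] -> offset 4.
inc = flattened_index_to_nd(4, (2, 3))
assert inc == [1, 1] and flatten_index(inc, (2, 3)) == 4
```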
@@ -5013,11 +5018,6 @@ def inner_fn_max_idx(idx):
     return rv, ri


-fallback_fractional_max_pool2d = fallback_handler(
-    aten.fractional_max_pool2d.default, add_to_fallback_set=False
-)
-
-
 def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim, ndims):
     out_sz = out_sz[dim]
     in_sz = in_sz[dim]
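The dedicated `fallback_fractional_max_pool2d` handler is no longer needed: large windows are handled by the reduction-based lowering below instead of falling back to eager. For context, a rough plain-Python sketch of the per-output-index start offset that `_fractional_pooling_offsets` computes, assuming the standard fractional max pooling interval formula (`sample` stands for the per-batch random value in `[0, 1)` from `random_samples`):

```python
def pooling_offset(sample, in_sz, out_sz, kernel_sz, i):
    alpha = (in_sz - kernel_sz) / (out_sz - 1)
    if i < out_sz - 1:
        return int((i + sample) * alpha) - int(sample * alpha)
    # The last window is pinned so it ends exactly at the input edge.
    return in_sz - kernel_sz

starts = [pooling_offset(0.5, in_sz=8, out_sz=3, kernel_sz=2, i=i) for i in range(3)]
assert starts == [0, 3, 6]  # every window [s, s + kernel_sz) stays in bounds
```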
@@ -5038,80 +5038,87 @@ def load(prefix, i):
             i_expr,
             ops.index_expr(out_sz - 1, torch.int64),
         )
-        return ops.where(mask, seq_i, ops.index_expr(in_sz - kernel_sz, torch.int64))
+        return ops.indirect_indexing(
+            ops.where(mask, seq_i, ops.index_expr(in_sz - kernel_sz, torch.int64)),
+            sympy.sympify(in_sz),
+        )

     return load


 @register_lowering(aten.fractional_max_pool2d)
 def fractional_max_pool2d(x, kernel_size, output_size, random_samples):
-    x.realize_hint()
-    *batch, inp_h, inp_w = x.get_size()
-    kernel_h, kernel_w = kernel_size
-    h_out, w_out = output_size
+    return _fractional_max_pool(x, kernel_size, output_size, random_samples, dim=2)

-    if kernel_h * kernel_w >= 25:
-        return fallback_fractional_max_pool2d(
-            x, kernel_size, output_size, random_samples
-        )

-    gen_offsets_for_dim = functools.partial(
-        _fractional_pooling_offsets,
-        samples=random_samples,
-        in_sz=[inp_h, inp_w],
-        out_sz=output_size,
-        kernel_sz=kernel_size,
-        ndims=2,
-    )
+@register_lowering(aten.fractional_max_pool3d)
+def fractional_max_pool3d(x, kernel_size, output_size, random_samples):
+    return _fractional_max_pool(x, kernel_size, output_size, random_samples, dim=3)

-    h_index_fn = gen_offsets_for_dim(dim=0)
-    w_index_fn = gen_offsets_for_dim(dim=1)
-    x_loader = x.make_loader()

-    def fn(idx, return_index):
-        *prefix, bh, bw = idx
+def _fractional_max_pool(x, kernel_size, output_size, random_samples, dim):
+    x.realize_hint()
+    batch, inp_dhw = x.shape[:-dim], x.shape[-dim:]

-        h_start_index = ops.indirect_indexing(h_index_fn(prefix, bh), inp_h)
-        w_start_index = ops.indirect_indexing(w_index_fn(prefix, bw), inp_w)
+    with config.patch(unroll_reductions_threshold=25):
+        dhw_index_fn = [
+            _fractional_pooling_offsets(
+                samples=random_samples,
+                in_sz=inp_dhw,
+                out_sz=output_size,
+                kernel_sz=kernel_size,
+                ndims=dim,
+                dim=d,
+            )
+            for d in range(dim)
+        ]

-        maxval = None
-        maxindex = None
-        for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])):
-            val = x_loader([*prefix, h_start_index + ih, w_start_index + iw])
-            if return_index:
-                index = ops.index_expr(
-                    (h_start_index + ih) * inp_w + w_start_index + iw, torch.int64
-                )
-                if maxindex is None:
-                    maxindex = index
-                else:
-                    maxindex = ops.where(
-                        ops.or_(ops.gt(val, maxval), ops.isnan(val)), index, maxindex
-                    )
-            if maxval is None:
-                maxval = val
-            else:
-                maxval = ops.maximum(val, maxval)
-        if return_index:
-            return maxindex
-        else:
-            return maxval
+        x_loader = x.make_loader()

-    new_size = list(batch) + [h_out, w_out]
-    rv = Pointwise.create(
-        device=x.get_device(),
-        dtype=x.get_dtype(),
-        inner_fn=functools.partial(fn, return_index=False),
-        ranges=new_size,
-    )
+        def fn_inner(idx, reduction_idx):
+            prefix = idx[:-dim]
+            return x_loader([*prefix, *increments_to_index(idx, reduction_idx)])

-    ri = Pointwise.create(
-        device=x.get_device(),
-        dtype=torch.int64,
-        inner_fn=functools.partial(fn, return_index=True),
-        ranges=new_size,
-    )
-    return rv, ri
+        def increments_to_index(idx, reduction_idx):
+            prefix = idx[:-dim]
+            bdhw = idx[-dim:]
+            return [
+                dhw_index_fn[d](prefix, bdhw[d]) + reduction_idx[d] for d in range(dim)
+            ]
+
+        new_size = list(batch) + list(output_size)
+        dtype = x.get_dtype()
+        result = Reduction.create(
+            reduction_type="max",
+            input_node=x,
+            device=x.get_device(),
+            dst_dtype=dtype,
+            src_dtype=dtype,
+            inner_fn=fn_inner,
+            ranges=new_size,
+            reduction_ranges=kernel_size,
+        )
+        offsets = Reduction.create(
+            reduction_type="argmax",
+            input_node=x,
+            device=x.get_device(),
+            dst_dtype=torch.int64,
+            src_dtype=dtype,
+            inner_fn=fn_inner,
+            ranges=new_size,
+            reduction_ranges=kernel_size,
+        )
+        if isinstance(result.data.data, Reduction):  # type: ignore[attr-defined]
+            # Only realize if reduction isn't unrolled
+            result.realize()
+        if isinstance(offsets.data.data, Reduction):  # type: ignore[attr-defined]
+            # Only realize if reduction isn't unrolled
+            offsets.realize()
+
+    indices = _pool_offsets_to_indices(
+        offsets, kernel_size, x.shape, increments_to_index
+    )
+    return result, indices


 @register_lowering(aten.upsample_nearest2d_backward.default)
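The rewrite replaces the hand-unrolled max/argmax loop with a pair of `Reduction.create` calls over `reduction_ranges=kernel_size`, sharing one `inner_fn`; the argmax offsets are then converted to flat input indices by the shared `_pool_offsets_to_indices` helper. Per output element, the two reductions compute roughly the following (plain-Python sketch with a 2D window; names are illustrative, and the NaN handling done by the real `max`/`argmax` reductions is elided):

```python
import itertools

def window_max_and_offset(load, starts, kernel_size):
    best_val, best_off = None, 0
    for off, inc in enumerate(itertools.product(*(range(k) for k in kernel_size))):
        val = load([s + d for s, d in zip(starts, inc)])
        if best_val is None or val > best_val:
            best_val, best_off = val, off  # the "max" and "argmax" reductions
    return best_val, best_off

grid = [[3.0, 1.0, 4.0], [1.0, 5.0, 9.0], [2.0, 6.0, 5.0]]
val, off = window_max_and_offset(lambda p: grid[p[0]][p[1]], (1, 1), (2, 2))
assert (val, off) == (9.0, 1)  # offset 1 -> increments (0, 1) -> input (1, 2)
```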