@@ -187,6 +187,55 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
        f.write(f"{kernel_name} | {args_str} | {grid!r}\n")


+def check_autotune_cache(
+    configs, filename, inductor_meta
+) -> tuple[list[Config], Optional[AutotuneCache], dict[str, Any]]:
+    """
+    Given a list of configs, check the autotune cache and return metadata.
+    """
+    autotune_cache = None
+    autotune_cache_info = {}
+    disabled = inductor_meta.get("force_disable_caches", False)
+    if (
+        not disabled
+        and filename is not None
+        and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning"))
+        and not os.environ.get("TRITON_INTERPRET", "0") == "1"
+    ):
+        configs_hash = hash_configs(configs)
+
+        autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash)
+        if autotune_cache:
+            if best_config := autotune_cache.read_best(inductor_meta, configs):
+                configs = [best_config]
+                autotune_cache_info["best_config"] = triton_config_to_hashable(
+                    best_config
+                )
+                autotune_cache_info["autotune_cache_state"] = "hit"
+
+            else:
+                autotune_cache_info["autotune_cache_state"] = "miss"
+                autotune_cache_info["num_configs"] = len(configs)
+                if inductor_meta.get("coordinate_descent_tuning"):
+                    autotune_cache_info["coordesc_tuning"] = True
+                    if len(configs) == 1:
+                        # This is the config that coordinate descent tuning started at, which
+                        # is not the same as the final config chosen (i.e. only_config, best_config)
+                        autotune_cache_info["coordesc_tuning_start_config"] = (
+                            triton_config_to_hashable(configs[0])
+                        )
+    else:
+        if len(configs) == 1:
+            autotune_cache_info["autotune_cache_state"] = "only 1 config"
+            autotune_cache_info["only_config"] = triton_config_to_hashable(configs[0])
+
+    if disabled:
+        autotune_cache_info["autotune_cache_state"] = "force_disabled"
+        log.debug("autotune caching is disabled by config.force_disable_caches")
+
+    return configs, autotune_cache, autotune_cache_info
+
+
class CachingAutotuner(KernelInterface):
    """
    Simplified version of Triton autotuner that has no invalidation
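For reference, a minimal sketch of how a caller might consume the new helper; the example configs, the placeholder filename, and the import are illustrative assumptions, not part of this change:

    # Sketch only: the configs and filename below are made-up placeholders.
    from triton import Config

    example_configs = [
        Config({"XBLOCK": 64}, num_warps=4, num_stages=2),
        Config({"XBLOCK": 128}, num_warps=8, num_stages=2),
    ]
    inductor_meta = {"coordinate_descent_tuning": False}

    configs, autotune_cache, info = check_autotune_cache(
        example_configs, "example_kernel.py", inductor_meta
    )
    if info.get("autotune_cache_state") == "hit":
        # A previously recorded best config was found; the list is narrowed to it.
        assert len(configs) == 1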
@@ -298,6 +347,39 @@ def is_statically_launchable(self):
            isinstance(x, StaticTritonCompileResult) for x in self.compile_results
        )

+    def recheck_autotune_cache(self, reload_kernel_from_src) -> None:
+        """
+        On a cache load of a statically launchable autotuner, we need to recheck the
+        autotune cache, since a best config could have been found by a previous run.
+        """
+        assert self.is_statically_launchable()
+
+        configs = [result.config for result in self.compile_results]
+
+        (cached_configs, _, autotune_cache_info) = check_autotune_cache(
+            configs, self.filename, self.inductor_meta
+        )
+        self.autotune_cache_info = autotune_cache_info
+        # I.e. there was an autotune cache hit
+        if len(cached_configs) == 1 and len(configs) > 1:
+            best_config = cached_configs[0]
+            # Grab the best compiled config, if it's in the list of available ones
+            best_config_hash = triton_config_to_hashable(best_config)
+
+            for compile_result in self.compile_results:
+                if triton_config_to_hashable(compile_result.config) == best_config_hash:
+                    self.compile_results = [compile_result]
+                    return
+
+            # If the best config isn't in our list of compile results,
+            # it's likely because it was found by coordesc after the cache
+            # was already saved
+            if best_config.found_by_coordesc:
+                with dynamo_timed("CachingAutotuner.slow_precompile_config"):
+                    if self.fn.fn is None:
+                        self.fn = reload_kernel_from_src().fn
+                    self.compile_results = [self._precompile_config(best_config)]
+
    def set_compile_info(
        self, compile_id: Optional[CompileId], is_backward: bool
    ) -> None:
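A sketch of where this recheck would fit when a statically launchable autotuner is restored from a cache; `load_cached_autotuner` and `reload_kernel_from_src` below are hypothetical stand-ins for whatever the caller uses to restore the kernel and its source:

    # Sketch only: the loader names are hypothetical.
    autotuner = load_cached_autotuner()  # yields a CachingAutotuner
    if autotuner.is_statically_launchable():
        # A best config may have been written by an earlier run after this
        # autotuner was serialized, so consult the autotune cache again.
        autotuner.recheck_autotune_cache(reload_kernel_from_src)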
@@ -1713,47 +1795,9 @@ def cached_autotune(
    assert len(configs) == 1 or filename
    inductor_meta = {} if inductor_meta is None else inductor_meta

-    disabled = inductor_meta.get("force_disable_caches", False)
-
-    # on disk caching logic and/or remote caching
-    autotune_cache = None
-    autotune_cache_info = {}
-    if (
-        not disabled
-        and filename is not None
-        and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning"))
-        and not os.environ.get("TRITON_INTERPRET", "0") == "1"
-    ):
-        configs_hash = hash_configs(configs)
-
-        autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash)
-        if autotune_cache:
-            if best_config := autotune_cache.read_best(inductor_meta, configs):
-                configs = [best_config]
-                autotune_cache_info["best_config"] = triton_config_to_hashable(
-                    best_config
-                )
-                autotune_cache_info["autotune_cache_state"] = "hit"
-            else:
-                autotune_cache_info["autotune_cache_state"] = "miss"
-                autotune_cache_info["num_configs"] = len(configs)
-                if inductor_meta.get("coordinate_descent_tuning"):
-                    autotune_cache_info["coordesc_tuning"] = True
-                    if len(configs) == 1:
-                        # This is the config that coordinate descent tuning started at, which
-                        # is not the same as the final config chosen (i.e. only_config, best_config)
-                        autotune_cache_info["coordesc_tuning_start_config"] = (
-                            triton_config_to_hashable(configs[0])
-                        )
-    else:
-        if len(configs) == 1:
-            autotune_cache_info["autotune_cache_state"] = "only 1 config"
-            autotune_cache_info["only_config"] = triton_config_to_hashable(configs[0])
-
-    if disabled:
-        autotune_cache_info["autotune_cache_state"] = "force_disabled"
-        log.debug("autotune caching is disabled by config.force_disable_caches")
-
+    configs, autotune_cache, autotune_cache_info = check_autotune_cache(
+        configs, filename, inductor_meta
+    )
    mutated_arg_names = inductor_meta.pop("mutated_arg_names", ())
    optimize_mem = inductor_meta.pop("optimize_mem", True)
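One observable path through the shared helper worth calling out is the force-disable case; a sketch, where some_configs, some_filename, and the inductor_meta literal are made-up placeholders:

    # Sketch only: when caching is force-disabled, no cache object is created and
    # the full config list is returned unchanged for in-process benchmarking.
    meta = {"force_disable_caches": True}
    configs, cache, info = check_autotune_cache(some_configs, some_filename, meta)
    assert cache is None
    assert info["autotune_cache_state"] == "force_disabled"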