@@ -187,6 +187,55 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
187187 f .write (f"{ kernel_name } | { args_str } | { grid !r} \n " )
188188
189189
def check_autotune_cache(
    configs, filename, inductor_meta
) -> tuple[list[Config], Optional[AutotuneCache], dict[str, Any]]:
    """
    Consult the autotune cache for ``configs`` and report what happened.

    Returns a ``(configs, autotune_cache, autotune_cache_info)`` tuple:

    * ``configs`` is the input list, or a one-element list holding the cached
      best config when the cache lookup hits.
    * ``autotune_cache`` is whatever ``AutotuneCache.create`` returned, or
      ``None`` when the cache was not consulted at all.
    * ``autotune_cache_info`` is a dict of metadata describing the outcome
      (``"hit"``, ``"miss"``, ``"only 1 config"`` or ``"force_disabled"``).
    """
    cache_info = {}
    force_disabled = inductor_meta.get("force_disable_caches", False)

    # The cache is only consulted when it is not force-disabled, a filename is
    # available, there is something to choose between (several configs, or
    # coordinate descent tuning is on), and TRITON_INTERPRET is not "1".
    should_consult_cache = (
        not force_disabled
        and filename is not None
        and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning"))
        and os.environ.get("TRITON_INTERPRET", "0") != "1"
    )

    if not should_consult_cache:
        if len(configs) == 1:
            cache_info["autotune_cache_state"] = "only 1 config"
            cache_info["only_config"] = triton_config_to_hashable(configs[0])

        # Checked last so that "force_disabled" wins over "only 1 config".
        if force_disabled:
            cache_info["autotune_cache_state"] = "force_disabled"
            log.debug("autotune caching is disabled by config.force_disable_caches")

        return configs, None, cache_info

    cache = AutotuneCache.create(inductor_meta, filename, hash_configs(configs))
    if not cache:
        # No usable cache object: nothing is recorded in the metadata.
        return configs, cache, cache_info

    cached_best = cache.read_best(inductor_meta, configs)
    if cached_best:
        cache_info["best_config"] = triton_config_to_hashable(cached_best)
        cache_info["autotune_cache_state"] = "hit"
        return [cached_best], cache, cache_info

    cache_info["autotune_cache_state"] = "miss"
    cache_info["num_configs"] = len(configs)
    if inductor_meta.get("coordinate_descent_tuning"):
        cache_info["coordesc_tuning"] = True
        if len(configs) == 1:
            # This is the config coordinate descent tuning starts from; it is
            # not the same as the config finally chosen (i.e. only_config,
            # best_config).
            cache_info["coordesc_tuning_start_config"] = triton_config_to_hashable(
                configs[0]
            )

    return configs, cache, cache_info
237+
238+
190239class CachingAutotuner (KernelInterface ):
191240 """
192241 Simplified version of Triton autotuner that has no invalidation
@@ -298,6 +347,39 @@ def is_statically_launchable(self):
298347 isinstance (x , StaticTritonCompileResult ) for x in self .compile_results
299348 )
300349
def recheck_autotune_cache(self, reload_kernel_from_src) -> None:
    """
    Re-consult the autotune cache after a statically launchable autotuner has
    been loaded from cache, since an earlier run may have recorded a best
    config in the meantime.

    Stores the lookup metadata on ``self.autotune_cache_info``. On a cache
    hit, ``self.compile_results`` is narrowed to the single result matching
    the best config. If no existing result matches and the best config was
    found by coordinate descent, that config is compiled now (calling
    ``reload_kernel_from_src`` first when ``self.fn.fn`` is ``None``).
    """
    assert self.is_statically_launchable()

    compiled_configs = [entry.config for entry in self.compile_results]
    cached_configs, _unused_cache, self.autotune_cache_info = check_autotune_cache(
        compiled_configs, self.filename, self.inductor_meta
    )

    # A hit shows up as the lookup collapsing several configs down to one.
    cache_hit = len(cached_configs) == 1 and len(compiled_configs) > 1
    if not cache_hit:
        return

    winner = cached_configs[0]
    winner_key = triton_config_to_hashable(winner)

    # Reuse the already-compiled result for the winning config when we have it.
    already_compiled = next(
        (
            entry
            for entry in self.compile_results
            if triton_config_to_hashable(entry.config) == winner_key
        ),
        None,
    )
    if already_compiled is not None:
        self.compile_results = [already_compiled]
        return

    # The winner is not among our compile results, most likely because
    # coordesc found it after the cache had already been saved.
    if not winner.found_by_coordesc:
        return

    with dynamo_timed("CachingAutotuner.slow_precompile_config"):
        if self.fn.fn is None:
            self.fn = reload_kernel_from_src().fn
        self.compile_results = [self._precompile_config(winner)]
382+
301383 def set_compile_info (
302384 self , compile_id : Optional [CompileId ], is_backward : bool
303385 ) -> None :
@@ -1713,47 +1795,9 @@ def cached_autotune(
17131795 assert len (configs ) == 1 or filename
17141796 inductor_meta = {} if inductor_meta is None else inductor_meta
17151797
1716- disabled = inductor_meta .get ("force_disable_caches" , False )
1717-
1718- # on disk caching logic and/or remote caching
1719- autotune_cache = None
1720- autotune_cache_info = {}
1721- if (
1722- not disabled
1723- and filename is not None
1724- and (len (configs ) > 1 or inductor_meta .get ("coordinate_descent_tuning" ))
1725- and not os .environ .get ("TRITON_INTERPRET" , "0" ) == "1"
1726- ):
1727- configs_hash = hash_configs (configs )
1728-
1729- autotune_cache = AutotuneCache .create (inductor_meta , filename , configs_hash )
1730- if autotune_cache :
1731- if best_config := autotune_cache .read_best (inductor_meta , configs ):
1732- configs = [best_config ]
1733- autotune_cache_info ["best_config" ] = triton_config_to_hashable (
1734- best_config
1735- )
1736- autotune_cache_info ["autotune_cache_state" ] = "hit"
1737- else :
1738- autotune_cache_info ["autotune_cache_state" ] = "miss"
1739- autotune_cache_info ["num_configs" ] = len (configs )
1740- if inductor_meta .get ("coordinate_descent_tuning" ):
1741- autotune_cache_info ["coordesc_tuning" ] = True
1742- if len (configs ) == 1 :
1743- # This is the config that coordinate descent tuning started at, which
1744- # is not the same as the final config chosen (i.e. only_config, best_config)
1745- autotune_cache_info ["coordesc_tuning_start_config" ] = (
1746- triton_config_to_hashable (configs [0 ])
1747- )
1748- else :
1749- if len (configs ) == 1 :
1750- autotune_cache_info ["autotune_cache_state" ] = "only 1 config"
1751- autotune_cache_info ["only_config" ] = triton_config_to_hashable (configs [0 ])
1752-
1753- if disabled :
1754- autotune_cache_info ["autotune_cache_state" ] = "force_disabled"
1755- log .debug ("autotune caching is disabled by config.force_disable_caches" )
1756-
1798+ configs , autotune_cache , autotune_cache_info = check_autotune_cache (
1799+ configs , filename , inductor_meta
1800+ )
17571801 mutated_arg_names = inductor_meta .pop ("mutated_arg_names" , ())
17581802 optimize_mem = inductor_meta .pop ("optimize_mem" , True )
17591803
0 commit comments