@@ -22,6 +22,7 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig,
         super().__init__(model, model_config, cache_config, backend_config,
                          device)

+        self.supported_model = ['Llama3-8B', 'Llama2-7B', 'Qwen2-7B']
         self.enable_graph = self.check_enable_graph()
         if self.enable_graph:
             import dlinfer.graph
@@ -44,21 +45,20 @@ def check_enable_graph(self):
                 "Graph mode of device_type 'ascend' only supports tp=1 "
                 'for now, fallback to eager mode', RuntimeWarning)
             return False
-        # model support
-        self.supported_model = {
-            'Llama2': 'LlamaConfig',
-            'InternLM2': 'InternLM2Config',
-            'Qwen2': 'Qwen2Config',
-        }
-        is_model_support = True
-        model_config_name = str(type(self.model_config.hf_config).__name__)
-        if model_config_name not in self.supported_model.values():
-            is_model_support = False
-        if not is_model_support:
-            warnings.warn(
-                "Graph mode of device_type 'ascend' only supports models: "
-                f"{', '.join(self.supported_model.keys())} when tp=1 for now",
-                RuntimeWarning)
+
+        warnings.warn(
+            '\n\n'
+            '**********************************************************\n'
+            '  The following models were tested in graph mode of\n'
+            "  device_type 'ascend' when tp=1:\n"
+            f"  {', '.join(self.supported_model)}\n"
+            '  Other LLaMa-like models may work in graph mode, please\n'
+            '  check the result yourself!\n'
+            '  If graph mode does not work correctly with your model,\n'
+            '  please use eager mode instead.\n'
+            '**********************************************************\n\n',
+            RuntimeWarning)
+
         return True

     def patch_kernels_custom_op(self):
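
A minimal standalone sketch (not part of the patch; the model list and banner text are copied from the hunk above purely for illustration) of how a banner-style `RuntimeWarning` like the one added in `check_enable_graph` surfaces, and how it can be captured in a test:

```python
import warnings

# Illustrative values mirroring the patch; in the real class this list
# lives on `self.supported_model`.
supported_model = ['Llama3-8B', 'Llama2-7B', 'Qwen2-7B']

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')  # ensure the warning is not filtered out
    warnings.warn(
        '\n\n'
        '**********************************************************\n'
        '  The following models were tested in graph mode of\n'
        "  device_type 'ascend' when tp=1:\n"
        f"  {', '.join(supported_model)}\n"
        '**********************************************************\n\n',
        RuntimeWarning)

assert issubclass(caught[0].category, RuntimeWarning)
print(caught[0].message)  # prints the multi-line banner
```

Because `warnings.warn` deduplicates by default, the banner is shown once per process unless a filter such as `simplefilter('always')` is installed, which suits a one-time advisory like this.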