@@ -333,6 +333,7 @@ def __init__(self, arch: MODEL_ARCH, n_blocks: int):
333
333
tensor_name = tensor_names .get (tensor )
334
334
if tensor_name is None :
335
335
continue
336
+ mapping [tensor_name ] = (tensor , tensor_name )
336
337
for key in keys :
337
338
mapping [key ] = (tensor , tensor_name )
338
339
for bid in range (n_blocks ):
@@ -341,11 +342,12 @@ def __init__(self, arch: MODEL_ARCH, n_blocks: int):
341
342
if tensor_name is None :
342
343
continue
343
344
tensor_name = tensor_name .format (bid = bid )
345
+ mapping [tensor_name ] = (tensor , tensor_name )
344
346
for key in keys :
345
347
key = key .format (bid = bid )
346
348
mapping [key ] = (tensor , tensor_name )
347
349
348
- def get_type_and_name (self , key : str , try_suffixes : Sequence [str ]) -> tuple [MODEL_TENSOR , str ] | None :
350
+ def get_type_and_name (self , key : str , try_suffixes : Sequence [str ] = () ) -> tuple [MODEL_TENSOR , str ] | None :
349
351
result = self .mapping .get (key )
350
352
if result is not None :
351
353
return result
@@ -356,13 +358,13 @@ def get_type_and_name(self, key: str, try_suffixes: Sequence[str]) -> tuple[MODE
356
358
return (result [0 ], result [1 ] + suffix )
357
359
return None
358
360
359
- def get_name (self , key : str , try_suffixes : Sequence [str ]) -> str | None :
361
+ def get_name (self , key : str , try_suffixes : Sequence [str ] = () ) -> str | None :
360
362
result = self .get_type_and_name (key , try_suffixes = try_suffixes )
361
363
if result is None :
362
364
return None
363
365
return result [1 ]
364
366
365
- def get_type (self , key : str , try_suffixes : Sequence [str ]) -> MODEL_TENSOR | None :
367
+ def get_type (self , key : str , try_suffixes : Sequence [str ] = () ) -> MODEL_TENSOR | None :
366
368
result = self .get_type_and_name (key , try_suffixes = try_suffixes )
367
369
if result is None :
368
370
return None
0 commit comments