1
1
import copy
2
+ import functools
2
3
import io
3
4
from typing import List , Union
4
- import functools
5
5
6
6
import torch
7
7
8
+
8
9
# TODO: Remove after https://github.com/huggingface/safetensors/pull/318
9
- @functools .cache
10
+ @functools .lru_cache ( None )
10
11
def has_safetensors_and_transformers ():
11
12
try :
12
13
# safetensors is not an exporter requirement, but needed for some huggingface models
13
14
import safetensors # type: ignore[import] # noqa: F401
14
- import transformers # type: ignore[import]
15
+ import transformers # type: ignore[import] # noqa: F401
15
16
16
17
from safetensors import torch as safetensors_torch # noqa: F401
17
18
@@ -66,6 +67,8 @@ def torch_load_wrapper(f, *args, **kwargs):
66
67
self .torch_load_wrapper = torch_load_wrapper
67
68
68
69
if has_safetensors_and_transformers ():
70
+ import safetensors
71
+ import transformers
69
72
70
73
def safetensors_load_file_wrapper (filename , device = "cpu" ):
71
74
# Record path for later serialization into ONNX proto
@@ -114,6 +117,9 @@ def __enter__(self):
114
117
torch .fx ._symbolic_trace ._wrapped_methods_to_patch = desired_wrapped_methods
115
118
116
119
if has_safetensors_and_transformers ():
120
+ import safetensors
121
+ import transformers
122
+
117
123
safetensors .torch .load_file = self .safetensors_torch_load_file_wrapper
118
124
transformers .modeling_utils .safe_load_file = (
119
125
self .safetensors_torch_load_file_wrapper
@@ -125,6 +131,9 @@ def __exit__(self, exc_type, exc_value, traceback):
125
131
self .torch_fx__symbolic_trace__wrapped_methods_to_patch
126
132
)
127
133
if has_safetensors_and_transformers ():
134
+ import safetensors
135
+ import transformers
136
+
128
137
safetensors .torch .load_file = self .safetensors_torch_load_file
129
138
transformers .modeling_utils .safe_load_file = (
130
139
self .transformers_modeling_utils_safe_load_file
0 commit comments