Update on "Do not import transformers when import torch._dynamo" · pytorch/pytorch@27d40a8 · GitHub
[go: up one dir, main page]

Skip to content

Commit 27d40a8

Browse files
committed
Update on "Do not import transformers when import torch._dynamo"
Fixes #123954 [ghstack-poisoned]
2 parents fe887a0 + 8c7d415 commit 27d40a8

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

torch/onnx/_internal/fx/patcher.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
import copy
2+
import functools
23
import io
34
from typing import List, Union
4-
import functools
55

66
import torch
77

8+
89
# TODO: Remove after https://github.com/huggingface/safetensors/pull/318
9-
@functools.cache
10+
@functools.lru_cache(None)
1011
def has_safetensors_and_transformers():
1112
try:
1213
# safetensors is not an exporter requirement, but needed for some huggingface models
1314
import safetensors # type: ignore[import] # noqa: F401
14-
import transformers # type: ignore[import]
15+
import transformers # type: ignore[import] # noqa: F401
1516

1617
from safetensors import torch as safetensors_torch # noqa: F401
1718

@@ -66,6 +67,8 @@ def torch_load_wrapper(f, *args, **kwargs):
6667
self.torch_load_wrapper = torch_load_wrapper
6768

6869
if has_safetensors_and_transformers():
70+
import safetensors
71+
import transformers
6972

7073
def safetensors_load_file_wrapper(filename, device="cpu"):
7174
# Record path for later serialization into ONNX proto
@@ -114,6 +117,9 @@ def __enter__(self):
114117
torch.fx._symbolic_trace._wrapped_methods_to_patch = desired_wrapped_methods
115118

116119
if has_safetensors_and_transformers():
120+
import safetensors
121+
import transformers
122+
117123
safetensors.torch.load_file = self.safetensors_torch_load_file_wrapper
118124
transformers.modeling_utils.safe_load_file = (
119125
self.safetensors_torch_load_file_wrapper
@@ -125,6 +131,9 @@ def __exit__(self, exc_type, exc_value, traceback):
125131
self.torch_fx__symbolic_trace__wrapped_methods_to_patch
126132
)
127133
if has_safetensors_and_transformers():
134+
import safetensors
135+
import transformers
136+
128137
safetensors.torch.load_file = self.safetensors_torch_load_file
129138
transformers.modeling_utils.safe_load_file = (
130139
self.transformers_modeling_utils_safe_load_file

0 commit comments

Comments (0)