Update base for Update on "[NJT] Allow construction of NJT within graph using offsets from inputs" · pytorch/pytorch@2668048

Commit 2668048

Update base for Update on "[NJT] Allow construction of NJT within graph using offsets from inputs"

Creating symbolic nested ints within the graph is difficult. Using unbacked symints should solve the most important(?) cases in the meantime. See #118446

Known gaps:
- Creating an NJT from intermediate offsets (offsets created within the graph, as opposed to offsets passed in as inputs).
- When the same offsets tensor is also passed in as an input to the graph: we are not smart enough to realize that the offsets from that input are the same, and therefore fail when the sizes are compared ("s0 cannot be compared with u0").

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
1 parent 53d8ecf commit 2668048
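For context on the feature this stack targets, a minimal sketch (not part of this commit) of constructing an NJT inside a compiled graph from offsets that arrive as a graph input. It assumes the public torch.nested.nested_tensor_from_jagged API and torch.compile; the function name and shapes are illustrative only:

    # Illustrative sketch: build an NJT inside a compiled graph from offsets
    # that are passed in as a graph input (the case unbacked symints are
    # meant to cover).
    import torch

    @torch.compile(fullgraph=True)
    def f(values, offsets):
        # offsets is a graph input; the nested int derived from it inside
        # the graph is represented with unbacked symints during tracing.
        njt = torch.nested.nested_tensor_from_jagged(values, offsets)
        return njt.sin()

    values = torch.randn(10, 3)            # 10 total rows across the batch
    offsets = torch.tensor([0, 2, 5, 10])  # 3 variable-length sequences
    out = f(values, offsets)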

File tree

1 file changed (+15, -11 lines)

torch/onnx/_internal/fx/patcher.py

Lines changed: 15 additions & 11 deletions
@@ -1,19 +1,23 @@
 import copy
 import io
 from typing import List, Union
+import functools

 import torch

 # TODO: Remove after https://github.com/huggingface/safetensors/pull/318
-try:
-    # safetensors is not an exporter requirement, but needed for some huggingface models
-    import safetensors  # type: ignore[import]  # noqa: F401
-    import transformers  # type: ignore[import]
-    from safetensors import torch as safetensors_torch  # noqa: F401
+@functools.cache
+def has_safetensors_and_transformers():
+    try:
+        # safetensors is not an exporter requirement, but needed for some huggingface models
+        import safetensors  # type: ignore[import]  # noqa: F401
+        import transformers  # type: ignore[import]

-    has_safetensors_and_transformers = True
-except ImportError:
-    has_safetensors_and_transformers = False
+        from safetensors import torch as safetensors_torch  # noqa: F401
+
+        return True
+    except ImportError:
+        return False


 class ONNXTorchPatcher:
@@ -61,7 +65,7 @@ def torch_load_wrapper(f, *args, **kwargs):
         # Wrapper or modified version of torch functions.
         self.torch_load_wrapper = torch_load_wrapper

-        if has_safetensors_and_transformers:
+        if has_safetensors_and_transformers():

             def safetensors_load_file_wrapper(filename, device="cpu"):
                 # Record path for later serialization into ONNX proto
@@ -109,7 +113,7 @@ def __enter__(self):
         desired_wrapped_methods.append((torch.Tensor, "__getitem__"))
         torch.fx._symbolic_trace._wrapped_methods_to_patch = desired_wrapped_methods

-        if has_safetensors_and_transformers:
+        if has_safetensors_and_transformers():
             safetensors.torch.load_file = self.safetensors_torch_load_file_wrapper
             transformers.modeling_utils.safe_load_file = (
                 self.safetensors_torch_load_file_wrapper
@@ -120,7 +124,7 @@ def __exit__(self, exc_type, exc_value, traceback):
         torch.fx._symbolic_trace._wrapped_methods_to_patch = (
             self.torch_fx__symbolic_trace__wrapped_methods_to_patch
        )
-        if has_safetensors_and_transformers:
+        if has_safetensors_and_transformers():
             safetensors.torch.load_file = self.safetensors_torch_load_file
             transformers.modeling_utils.safe_load_file = (
                 self.transformers_modeling_utils_safe_load_file
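For readers skimming the diff: a standalone sketch of the functools.cache pattern adopted above. The helper name here is hypothetical; the point is that the optional imports run at most once, on the first call, rather than at module import time:

    import functools

    @functools.cache
    def has_optional_deps() -> bool:
        # Hypothetical stand-in for has_safetensors_and_transformers(); the
        # optional imports are attempted only when the helper is first called.
        try:
            import safetensors  # noqa: F401
            import transformers  # noqa: F401

            return True
        except ImportError:
            return False

    has_optional_deps()                    # first call: imports run (cache miss)
    has_optional_deps()                    # later calls: cached bool (cache hit)
    print(has_optional_deps.cache_info())  # e.g. CacheInfo(hits=1, misses=1, maxsize=None, currsize=1)

Besides deferring the import cost, making the check a cached callable means merely importing the patcher module no longer pulls in safetensors or transformers as a side effect.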
