pytorch/pytorch@cd52984

Update base for Update on "[NJT] Allow construction of NJT within graph using offsets from inputs"
Creating symbolic nested ints within the graph is difficult. Using unbacked symints should solve the most important(?) cases in the meantime. See #118446

Known gaps:
- creating NJT from intermediate offsets (offsets created within the graph, as opposed to offsets passed in as inputs)
- when the same offsets tensor is also passed in as an input to the graph. We are not smart enough to realize that the offsets from that input are the same, and therefore fail when the sizes are compared ("s0 cannot be compared with u0")

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
1 parent 2668048 commit cd52984

File tree

1 file changed: +12 -3 lines changed

torch/onnx/_internal/fx/patcher.py

Lines changed: 12 additions & 3 deletions
@@ -1,17 +1,18 @@
 import copy
+import functools
 import io
 from typing import List, Union
-import functools

 import torch

+
 # TODO: Remove after https://github.com/huggingface/safetensors/pull/318
-@functools.cache
+@functools.lru_cache(None)
 def has_safetensors_and_transformers():
     try:
         # safetensors is not an exporter requirement, but needed for some huggingface models
         import safetensors  # type: ignore[import]  # noqa: F401
-        import transformers  # type: ignore[import]
+        import transformers  # type: ignore[import]  # noqa: F401

         from safetensors import torch as safetensors_torch  # noqa: F401

@@ -66,6 +67,8 @@ def torch_load_wrapper(f, *args, **kwargs):
         self.torch_load_wrapper = torch_load_wrapper

         if has_safetensors_and_transformers():
+            import safetensors
+            import transformers

             def safetensors_load_file_wrapper(filename, device="cpu"):
                 # Record path for later serialization into ONNX proto
@@ -114,6 +117,9 @@ def __enter__(self):
         torch.fx._symbolic_trace._wrapped_methods_to_patch = desired_wrapped_methods

         if has_safetensors_and_transformers():
+            import safetensors
+            import transformers
+
             safetensors.torch.load_file = self.safetensors_torch_load_file_wrapper
             transformers.modeling_utils.safe_load_file = (
                 self.safetensors_torch_load_file_wrapper
@@ -125,6 +131,9 @@ def __exit__(self, exc_type, exc_value, traceback):
             self.torch_fx__symbolic_trace__wrapped_methods_to_patch
         )
         if has_safetensors_and_transformers():
+            import safetensors
+            import transformers
+
             safetensors.torch.load_file = self.safetensors_torch_load_file
             transformers.modeling_utils.safe_load_file = (
                 self.transformers_modeling_utils_safe_load_file
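The first hunk swaps `@functools.cache` for `@functools.lru_cache(None)`. `functools.cache` only exists on Python 3.9+, and CPython defines it as `lru_cache(maxsize=None)`, so the two decorators memoize identically; `lru_cache(None)` simply keeps the module importable on older interpreters. A minimal sketch of the equivalence (the function name here is a hypothetical stand-in, not from the patch):

```python
import functools

call_count = 0

# functools.cache (Python 3.9+) is defined as lru_cache(maxsize=None),
# so this decorator gives the same unbounded memoization on Python 3.8.
@functools.lru_cache(None)
def has_example_dependency():
    """Stand-in for has_safetensors_and_transformers(); hypothetical."""
    global call_count
    call_count += 1  # count how many times the body actually runs
    return True

has_example_dependency()
has_example_dependency()  # second call is served from the cache; body runs once
```

Because the result is cached unboundedly, the try/except import probe in `has_safetensors_and_transformers()` only pays its cost on the first call.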

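The `__enter__`/`__exit__` hunks follow the usual save/patch/restore monkey-patching pattern: the original callable is stashed in `__init__`, a recording wrapper is swapped in on entry, and the original is put back on exit. A minimal self-contained sketch of that pattern, patching a hypothetical stand-in module rather than the real `safetensors` API:

```python
import types

# Hypothetical stand-in for the patched safetensors module; not the real API.
fake_module = types.SimpleNamespace(load_file=lambda path: f"loaded {path}")

class PatchLoadFile:
    """Save/patch/restore context manager, mirroring the exporter's patcher."""

    def __init__(self):
        self.paths = []
        self.original = fake_module.load_file  # save the original for __exit__

        def wrapper(path):
            self.paths.append(path)  # record the path, then defer to the original
            return self.original(path)

        self.wrapper = wrapper

    def __enter__(self):
        fake_module.load_file = self.wrapper  # patch on entry
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        fake_module.load_file = self.original  # restore even if an error was raised

with PatchLoadFile() as patcher:
    loaded = fake_module.load_file("weights.safetensors")  # goes through the wrapper
```

Re-importing `safetensors` and `transformers` locally inside the `if has_safetensors_and_transformers():` blocks is what lets this patch drop the unconditional module-level imports, so the patcher still loads when those optional dependencies are absent.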