|
33 | 33 | logger = logging.getLogger(__name__)
|
34 | 34 |
|
35 | 35 |
|
36 |
| -@dataclasses.dataclass(frozen=True) |
| 36 | +@dataclasses.dataclass |
37 | 37 | class OnnxDecompMeta:
|
38 | 38 | """A wrapper of onnx-script function with additional metadata.
|
39 | 39 |
|
40 | 40 | onnx_function: The onnx-script function from torchlib.
|
41 | 41 | fx_target: The PyTorch node callable target.
|
| 42 | + signature: The ONNX signature of the function. When None, the signature is inferred. |
42 | 43 | is_custom: Whether the function is a custom function.
|
43 | 44 | is_complex: Whether the function is a function that handles complex valued inputs.
|
44 | 45 | device: The device the function is registered to. If None, it is registered to all devices.
|
| 46 | + skip_signature_inference: Whether to skip signature inference for the function. |
45 | 47 | """
|
46 | 48 |
|
47 | 49 | onnx_function: Callable
|
48 | 50 | fx_target: TorchOp
|
| 51 | + signature: _schemas.OpSignature | None |
49 | 52 | is_custom: bool = False
|
50 | 53 | is_complex: bool = False
|
51 | 54 | device: Literal["cuda", "cpu"] | str | None = None # noqa: PYI051
|
| 55 | + skip_signature_inference: bool = False |
| 56 | + |
| 57 | + def __post_init__(self) -> None: |
| 58 | + if self.signature is None and not self.skip_signature_inference: |
| 59 | + try: |
| 60 | + if isinstance(self.onnx_function, onnxscript.OnnxFunction): |
| 61 | + signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined] |
| 62 | + self.onnx_function, |
| 63 | + self.onnx_function.function_ir.domain, |
| 64 | + self.onnx_function.name, |
| 65 | + opset_version=self.onnx_function.opset.version, |
| 66 | + ) |
| 67 | + else: |
| 68 | + signature = _schemas.OpSignature.from_function( |
| 69 | + self.onnx_function, "__traced", self.onnx_function.__name__ |
| 70 | + ) |
| 71 | + except Exception as e: |
| 72 | + # Log an warning if the op is custom. Raise exception for builtin ops. |
| 73 | + if not self.is_custom: |
| 74 | + raise |
| 75 | + else: |
| 76 | + # When the function is targeting an HOP, for example, it will accept |
| 77 | + # functions as arguments and fail to generate an ONNX signature. |
| 78 | + # In this case we set signature to None and dispatch to this function always. |
| 79 | + logger.warning( |
| 80 | + "Failed to infer the signature for function '%s' because '%s'" |
| 81 | + "All nodes targeting `%s` will be dispatched to this function", |
| 82 | + self.onnx_function, |
| 83 | + e, |
| 84 | + self.fx_target, |
| 85 | + ) |
| 86 | + else: |
| 87 | + self.signature = signature |
| 88 | + self.onnx_function._pt_onnx_signature = signature # type: ignore[attr-defined] |
52 | 89 |
|
53 | 90 |
|
54 | 91 | def _get_overload(qualified_name: str) -> torch._ops.OpOverload | None:
|
@@ -120,57 +157,34 @@ def from_torchlib(cls) -> ONNXRegistry:
|
120 | 157 | torchlib_registry: The torchlib registry to use for populating the registry.
|
121 | 158 | """
|
122 | 159 | registry = cls()
|
| 160 | + for meta in _torchlib_registry.get_torchlib_ops(): |
| 161 | + registry._register(meta.fx_target, meta) |
123 | 162 |
|
| 163 | + # TODO(justinchuby): Remove this once torchlib is migrated to PyTorch |
124 | 164 | torchlib_ops = onnxscript_apis.get_torchlib_ops()
|
125 | 165 |
|
126 |
| - for meta in torchlib_ops: |
127 |
| - qualified_name = meta.qualified_name |
128 |
| - overload_func = meta.function |
129 |
| - domain = meta.domain |
130 |
| - name = meta.name |
| 166 | + for torchlib_meta in torchlib_ops: |
| 167 | + qualified_name = torchlib_meta.qualified_name |
| 168 | + overload_func = torchlib_meta.function |
131 | 169 | try:
|
132 | 170 | # NOTE: This is heavily guarded with try-except because we don't want
|
133 | 171 | # to fail the entire registry population if one function fails.
|
134 | 172 | target = _get_overload(qualified_name)
|
135 | 173 | if target is None:
|
136 | 174 | continue
|
137 | 175 |
|
138 |
| - if isinstance(overload_func, onnxscript.OnnxFunction): |
139 |
| - opset_version = overload_func.opset.version |
140 |
| - else: |
141 |
| - opset_version = 1 |
142 |
| - |
143 |
| - overload_func.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined] |
144 |
| - overload_func, |
145 |
| - domain, |
146 |
| - name, |
147 |
| - opset_version=opset_version, |
148 |
| - ) |
149 |
| - onnx_decomposition = OnnxDecompMeta( |
| 176 | + meta = OnnxDecompMeta( |
150 | 177 | onnx_function=overload_func,
|
151 | 178 | fx_target=target,
|
| 179 | + signature=None, |
152 | 180 | is_custom=False,
|
153 |
| - is_complex=meta.is_complex, |
| 181 | + is_complex=torchlib_meta.is_complex, |
154 | 182 | )
|
155 |
| - registry._register(target, onnx_decomposition) |
| 183 | + registry._register(target, meta) |
156 | 184 | except Exception:
|
157 | 185 | logger.exception("Failed to register '%s'. Skipped", qualified_name)
|
158 | 186 | continue
|
159 | 187 |
|
160 |
| - # Gather ops from the internal torchlib registry |
161 |
| - # TODO(justinchuby): Make this the main registry after torchlib is migrated to PyTorch |
162 |
| - # Trigger registration |
163 |
| - from torch.onnx._internal.exporter._torchlib import ops |
164 |
| - |
165 |
| - del ops |
166 |
| - for target, implementations in _torchlib_registry.registry.items(): # type: ignore[assignment] |
167 |
| - for impl in implementations: |
168 |
| - onnx_decomposition = OnnxDecompMeta( |
169 |
| - onnx_function=impl, |
170 |
| - fx_target=target, # type: ignore[arg-type] |
171 |
| - ) |
172 |
| - registry._register(target, onnx_decomposition) # type: ignore[arg-type] |
173 |
| - |
174 | 188 | return registry
|
175 | 189 |
|
176 | 190 | def _register(
|
@@ -209,32 +223,23 @@ def register_op(
|
209 | 223 | function: The onnx-script function to register.
|
210 | 224 | is_complex: Whether the function is a function that handles complex valued inputs.
|
211 | 225 | """
|
212 |
| - if not hasattr(function, "signature"): |
213 |
| - try: |
214 |
| - # TODO(justinchuby): Use the op_signature attribute when onnxscript is updated in CI |
215 |
| - if isinstance(function, onnxscript.OnnxFunction): |
216 |
| - function.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined] |
217 |
| - function, |
218 |
| - function.function_ir.domain, |
219 |
| - function.name, |
220 |
| - opset_version=function.opset.version, |
221 |
| - ) |
222 |
| - else: |
223 |
| - function.signature = _schemas.OpSignature.from_function( # type: ignore[attr-defined] |
224 |
| - function, "__custom", function.__name__ |
225 |
| - ) |
226 |
| - except Exception: |
227 |
| - logger.exception( |
228 |
| - "Failed to infer the signature for function '%s'", function |
229 |
| - ) |
| 226 | + if isinstance(target, torch._ops.OpOverloadPacket): |
| 227 | + raise TypeError( |
| 228 | + f"Target '{target}' should be provided as an OpOverload instead of an " |
| 229 | + "OpOverloadPacket. You can get the default overload with " |
| 230 | + "<op>.default" |
| 231 | + ) |
230 | 232 |
|
231 |
| - onnx_decomposition = OnnxDecompMeta( |
232 |
| - onnx_function=function, |
233 |
| - fx_target=target, |
234 |
| - is_custom=True, |
235 |
| - is_complex=is_complex, |
| 233 | + self._register( |
| 234 | + target, |
| 235 | + OnnxDecompMeta( |
| 236 | + onnx_function=function, |
| 237 | + fx_target=target, |
| 238 | + signature=None, |
| 239 | + is_custom=True, |
| 240 | + is_complex=is_complex, |
| 241 | + ), |
236 | 242 | )
|
237 |
| - self._register(target, onnx_decomposition) |
238 | 243 |
|
239 | 244 | def get_decomps(self, target: TorchOp) -> list[OnnxDecompMeta]:
|
240 | 245 | """Returns a list of OnnxDecompMeta for the given op: torch.ops.<namespace>.<op_name>.<overload>.
|
|
0 commit comments