8000 [export] Sync aoti schema to schema.py (#148017) · pytorch/pytorch@915b9c8 · GitHub
[go: up one dir, main page]

Skip to content

Commit 915b9c8

zhxchen17pytorchmergebot
authored andcommitted
[export] Sync aoti schema to schema.py (#148017)
Summary: Synchronizing internal AOTI schema to OSS schema.py Test Plan: CI Differential Revision: D70271151 Pull Request resolved: #148017 Approved by: https://github.com/yiming0416
1 parent 871b390 commit 915b9c8

File tree

4 files changed

+240
-5
lines changed

4 files changed

+240
-5
lines changed

torch/_export/serde/export_schema.thrift

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// @generated by update_schema.py
2-
// checksum<<735727a20b699856df15a522b7e0075dae34d441a021be640f447eb245b426c9>>
2+
// checksum<<f36968728ea96d9629b7c5269f5303e5cf23fba341d0221cb364aaf571b94dd6>>
33

44
namespace py3 torch._export
55
namespace cpp2 torch._export.schema
@@ -341,3 +341,21 @@ struct Model {
341341
60: map<string, string> deviceAllocationMap;
342342
70: map<string, string> constantPaths;
343343
}
344+
345+
struct AOTInductorModelPickleData {
346+
1: string library_basename;
347+
2: list<string> input_names;
348+
3: list<string> output_names;
349+
4: optional i64 floating_point_input_dtype;
350+
5: optional i64 floating_point_output_dtype;
351+
6: optional bool aot_inductor_model_is_cpu;
352+
}
353+
354+
struct ExternKernelNode {
355+
10: string name;
356+
20: Node node;
357+
}
358+
359+
struct ExternKernelNodes {
360+
10: list<ExternKernelNode> nodes;
361+
}

torch/_export/serde/schema.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from torch._export.serde.union import _Union
99

1010
# NOTE: Please update this value if any modifications are made to the schema
11-
SCHEMA_VERSION = (8, 6)
11+
SCHEMA_VERSION = (8, 7)
1212
TREESPEC_VERSION = 1
1313

1414

@@ -447,3 +447,40 @@ class Model:
447447
# key is the FQN of constant in exported program (constant tensor or torchbind objs)
448448
# value is the archive path of serialized constants
449449
constantPaths: Annotated[dict[str, str], 70]
450+
451+
#
452+
# The structure is used to serialize instances of AOTInductorModel to pass
453+
# them from the publishing pipeline to the predictor.
454+
#
455+
# All new fields should be marked as optional.
456+
#
457+
@dataclass
458+
class AOTInductorModelPickleData:
459+
# Base name of an associated .so AOTInductor library. Typically looks like:
460+
# "abc.so".
461+
library_basename: Annotated[str, 1]
462+
463+
# AOTInductor engine input names.
464+
input_names: Annotated[list[str], 2]
465+
466+
# AOTInductor engine output names.
467+
output_names: Annotated[list[str], 3]
468+
469+
# These fields tell whether floating point inputs/outputs should be converted to
470+
# a certain type. If None, the dtypes that the AOTInductor engine inferred from the sample
471+
# inputs are used.
472+
floating_point_input_dtype: Annotated[Optional[int], 4] = None
473+
floating_point_output_dtype: Annotated[Optional[int], 5] = None
474+
475+
# Whether AOTInductor runtime is for CPU.
476+
aot_inductor_model_is_cpu: Annotated[Optional[bool], 6] = None
477+
478+
@dataclass
479+
class ExternKernelNode:
480+
# name is not the unique identifier of the node
481+
name: Annotated[str, 10]
482+
node: Annotated[Node, 20]
483+
484+
@dataclass
485+
class ExternKernelNodes:
486+
nodes: Annotated[list[ExternKernelNode], 10]

torch/_export/serde/schema.yaml

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,23 @@
11
# @generated by update_schema.py
2-
# checksum<<4710fa9fbec49ef1c60365f9eccade2b963e259d6aafc42ce1b32a4f2e20a858>>
2+
# checksum<<31c433c768b3f1bb61a5e8f4ceffc40c857bd80cf4fa0fc33fd03fa5ebb6c4d8>>
3+
AOTInductorModelPickleData:
4+
kind: struct
5+
fields:
6+
library_basename:
7+
type: str
8+
input_names:
9+
type: List[str]
10+
output_names:
11+
type: List[str]
12+
floating_point_input_dtype:
13+
type: Optional[int]
14+
default: None
15+
floating_point_output_dtype:
16+
type: Optional[int]
17+
default: None
18+
aot_inductor_model_is_cpu:
19+
type: Optional[bool]
20+
default: None
321
Argument:
422
kind: union
523
fields:
@@ -111,6 +129,18 @@ ExportedProgram:
111129
torch_version:
112130
type: str
113131
default: <=2.4
132+
ExternKernelNode:
133+
kind: struct
134+
fields:
135+
name:
136+
type: str
137+
node:
138+
type: Node
139+
ExternKernelNodes:
140+
kind: struct
141+
fields:
142+
nodes:
143+
type: List[ExternKernelNode]
114144
GradientToParameterSpec:
115145
kind: struct
116146
fields:
@@ -500,5 +530,5 @@ UserOutputSpec:
500530
type: Argument
501531
SCHEMA_VERSION:
502532
- 8
503-
- 6
533+
- 7
504534
TREESPEC_VERSION: 1

torch/csrc/utils/generated_serialization_types.h

Lines changed: 151 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)
0