8000 [export] Move PT2 constants to torch::_export (#153206) · pytorch/pytorch@f6bdcf8 · GitHub
[go: up one dir, main page]

Skip to content

Commit f6bdcf8

Browse files
angelayifacebook-github-bot
authored andcommitted
[export] Move PT2 constants to torch::_export (#153206)
Summary: Pull Request resolved: #153206 before: `from sigmoid.core.package.pt2_archive_constants_pybind import ...` after: `from torch._C._export.pt2_archive_constants import ...` #buildall Test Plan: `buck2 test //sigmoid/...` https://www.internalfb.com/intern/testinfra/testrun/1970325119807758 Reviewed By: dolpm, zhxchen17 Differential Revision: D74417085
1 parent f66a159 commit f6bdcf8

File tree

10 files changed

+133
-29
lines changed

10 files changed

+133
-29
lines changed

docs/source/export.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -859,3 +859,5 @@ API Reference
859859
.. automodule:: torch.export.experimental
860860
.. automodule:: torch.export.passes
861861
.. autofunction:: torch.export.passes.move_to_device_pass
862+
.. automodule:: torch.export.pt2_archive
863+
.. automodule:: torch.export.pt2_archive.constants

torch/_C/_export.pyi renamed to torch/_C/_export/__init__.pyi

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Defined in torch/csrc/export/pybind.cpp
2-
32
class CppExportedProgram: ...
43

54
def deserialize_exported_program(
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Defined in torch/csrc/export/pt2_archive_constants.h
2+
3+
ARCHIVE_ROOT_NAME: str = ...
4+
ARCHIVE_FORMAT_PATH: str = ...
5+
ARCHIVE_FORMAT_VALUE: str = ...
6+
ARCHIVE_VERSION_PATH: str = ...
7+
ARCHIVE_VERSION_VALUE: str = ...
8+
MODELS_DIR: str = ...
9+
MODELS_FILENAME_FORMAT: str = ...
10+
AOTINDUCTOR_DIR: str = ...
11+
MTIA_DIR: str = ...
12+
WEIGHTS_DIR: str = ...
13+
WEIGHT_FILENAME_PREFIX: str = ...
14+
CONSTANTS_DIR: str = ...
15+
TENSOR_CONSTANT_FILENAME_PREFIX: str = ...
16+
CUSTOM_OBJ_FILENAME_PREFIX: str = ...
17+
SAMPLE_INPUTS_DIR: str = ...
18+
SAMPLE_INPUTS_FILENAME_FORMAT: str = ...
19+
EXTRA_DIR: str = ...
20+
MODULE_INFO_PATH: str = ...
21+
XL_MODEL_WEIGHTS_DIR: str = ...
22+
XL_MODEL_WEIGHTS_PARAM_CONFIG_PATH: str = ...

torch/_inductor/codecache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,11 @@
8989
from torch._utils_internal import log_cache_bypass
9090
from torch.compiler import config as cconfig
9191
from torch.compiler._cache import CacheArtifactManager, CacheArtifactType
92+
from torch.export.pt2_archive.constants import CUSTOM_OBJ_FILENAME_PREFIX
9293
from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv
9394
from torch.utils._ordered_set import OrderedSet
9495

9596
from .output_code import CompiledFxGraph
96-
from .package.pt2_archive_constants import CUSTOM_OBJ_FILENAME_PREFIX
9797
from .remote_cache import create_cache
9898
from .runtime import autotune_cache
9999
from .runtime.autotune_cache import AutotuneCacheBundler

torch/_inductor/package/package.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,13 @@
1414
from torch._inductor import config
1515
from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder
1616
from torch.export._tree_utils import reorder_kwargs
17-
from torch.types import FileLike
18-
19-
from .pt2_archive_constants import (
17+
from torch.export.pt2_archive.constants import (
2018
AOTINDUCTOR_DIR,
21-
ARCHIVE_VERSION,
19+
ARCHIVE_VERSION_VALUE,
2220
CONSTANTS_DIR,
2321
CUSTOM_OBJ_FILENAME_PREFIX,
2422
)
23+
from torch.types import FileLike
2524

2625

2726
log = logging.getLogger(__name__)
@@ -37,7 +36,7 @@ def __enter__(self) -> Self:
3736
self.archive_file = zipfile.ZipFile(
3837
self.archive_path, "w", compression=zipfile.ZIP_STORED
3938
)
40-
self.writestr("version", str(ARCHIVE_VERSION))
39+
self.writestr("version", str(ARCHIVE_VERSION_VALUE))
4140
self.writestr("archive_format", "pt2")
4241
return self
4342

torch/_inductor/package/pt2_archive_constants.py

Lines changed: 0 additions & 16 deletions
This file was deleted.
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#pragma once
2+
3+
#include <array>
4+
#include <string_view>
5+
6+
namespace torch::_export::archive_spec {
7+
8+
#define FORALL_CONSTANTS(DO) \
9+
DO(ARCHIVE_ROOT_NAME, "package") \
10+
/* Archive format */ \
11+
DO(ARCHIVE_FORMAT_PATH, "archive_format") \
12+
DO(ARCHIVE_FORMAT_VALUE, "pt2") \
13+
/* Archive version */ \
14+
DO(ARCHIVE_VERSION_PATH, "archive_version") \
15+
DO(ARCHIVE_VERSION_VALUE, "0") /* Sep.4.2024: This is the initial version of \
16+
the PT2 Archive Spec */ \
17+
/* \
18+
* ######## Note on updating ARCHIVE_VERSION_VALUE ######## \
19+
* When there is a BC breaking change to the PT2 Archive Spec, \
20+
* e.g. deleting a folder, or changing the naming convention of the \
21+
* following fields it would require bumping the ARCHIVE_VERSION_VALUE \
22+
* Archive reader would need corresponding changes to support loading both \
23+
* the current and older versions of the PT2 Archive. \
24+
*/ \
25+
/* Model definitions */ \
26+
DO(MODELS_DIR, "models/") \
27+
DO(MODELS_FILENAME_FORMAT, "models/{}.json") /* {model_name} */ \
28+
/* AOTInductor artifacts */ \
29+
DO(AOTINDUCTOR_DIR, "data/aotinductor/") \
30+
/* MTIA artifacts */ \
31+
DO(MTIA_DIR, "data/mtia") \
32+
/* weights, including parameters and buffers */ \
33+
DO(WEIGHTS_DIR, "data/weights/") \
34+
DO(WEIGHT_FILENAME_PREFIX, "weight_") \
35+
/* constants, including tensor_constants, non-persistent buffers and script \
36+
* objects */ \
37+
DO(CONSTANTS_DIR, "data/constants/") \
38+
DO(TENSOR_CONSTANT_FILENAME_PREFIX, "tensor_") \
39+
DO(CUSTOM_OBJ_FILENAME_PREFIX, "custom_obj_") \
40+
/* example inputs */ \
41+
DO(SAMPLE_INPUTS_DIR, "data/sample_inputs/") \
42+
DO(SAMPLE_INPUTS_FILENAME_FORMAT, \
43+
"data/sample_inputs/{}.pt") /* {model_name} */ \
44+
/* extra folder */ \
45+
DO(EXTRA_DIR, "extra/") \
46+
DO(MODULE_INFO_PATH, "extra/module_info.json") \
47+
/* xl_model_weights, this folder is used for storing per-feature-weights for \
48+
* remote net data in this folder is consume by Predictor, and is not \
49+
* intended to be used by Sigmoid */ \
50+
DO(XL_MODEL_WEIGHTS_DIR, "xl_model_weights/") \
51+
DO(XL_MODEL_WEIGHTS_PARAM_CONFIG_PATH, "xl_model_weights/model_param_config")
52+
53+
#define DEFINE_GLOBAL(NAME, VALUE) \
54+
inline constexpr std::string_view NAME = VALUE;
55+
FORALL_CONSTANTS(DEFINE_GLOBAL)
56+
#undef DEFINE_GLOBAL
57+
58+
#define DEFINE_ENTRY(NAME, VALUE) std::pair(#NAME, VALUE),
59+
inline constexpr std::array kAllConstants{FORALL_CONSTANTS(DEFINE_ENTRY)};
60+
#undef DEFINE_ENTRY
61+
62+
#undef FORALL_CONSTANTS
63+
} // namespace torch::_export::archive_spec

torch/csrc/export/pybind.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <torch/csrc/export/pt2_archive_constants.h>
12
#include <torch/csrc/export/pybind.h>
23
#include <torch/csrc/utils/generated_serialization_types.h>
34
#include <torch/csrc/utils/pybind.h>
@@ -6,17 +7,23 @@ namespace torch::_export {
67

78
void initExportBindings(PyObject* module) {
89
auto rootModule = py::handle(module).cast<py::module>();
9-
auto m = rootModule.def_submodule("_export");
10+
auto exportModule = rootModule.def_submodule("_export");
11+
auto pt2ArchiveModule = exportModule.def_submodule("pt2_archive_constants");
1012

1113
// NOLINTNEXTLINE(bugprone-unused-raii)
12-
py::class_<ExportedProgram>(m, "CppExportedProgram");
14+
py::class_<ExportedProgram>(exportModule, "CppExportedProgram");
1315

14-
m.def("deserialize_exported_program", [](const std::string& serialized) {
15-
return nlohmann::json::parse(serialized).get<ExportedProgram>();
16-
});
16+
exportModule.def(
17+
"deserialize_exported_program", [](const std::string& serialized) {
18+
return nlohmann::json::parse(serialized).get<ExportedProgram>();
19+
});
1720

18-
m.def("serialize_exported_program", [](const ExportedProgram& ep) {
21+
exportModule.def("serialize_exported_program", [](const ExportedProgram& ep) {
1922
return nlohmann::json(ep).dump();
2023
});
24+
25+
for (const auto& entry : torch::_export::archive_spec::kAllConstants) {
26+
pt2ArchiveModule.attr(entry.first) = entry.second;
27+
}
2128
}
2229
} // namespace torch::_export

torch/export/pt2_archive/__init__.py

Whitespace-only changes.

torch/export/pt2_archive/constants.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Defined in torch/csrc/export/pt2_archive_constants.h
2+
from torch._C._export import pt2_archive_constants
3+
4+
5+
AOTINDUCTOR_DIR: str = pt2_archive_constants.AOTINDUCTOR_DIR
6+
ARCHIVE_FORMAT_PATH: str = pt2_archive_constants.ARCHIVE_FORMAT_PATH
7+
ARCHIVE_FORMAT_VALUE: str = pt2_archive_constants.ARCHIVE_FORMAT_VALUE
8+
ARCHIVE_ROOT_NAME: str = pt2_archive_constants.ARCHIVE_ROOT_NAME
9+
ARCHIVE_VERSION_PATH: str = pt2_archive_constants.ARCHIVE_VERSION_PATH
10+
ARCHIVE_VERSION_VALUE: str = pt2_archive_constants.ARCHIVE_VERSION_VALUE
11+
CONSTANTS_DIR: str = pt2_archive_constants.CONSTANTS_DIR
12+
CUSTOM_OBJ_FILENAME_PREFIX: str = pt2_archive_constants.CUSTOM_OBJ_FILENAME_PREFIX
13+
EXTRA_DIR: str = pt2_archive_constants.EXTRA_DIR
14+
MODELS_DIR: str = pt2_archive_constants.MODELS_DIR
15+
MODELS_FILENAME_FORMAT: str = pt2_archive_constants.MODELS_FILENAME_FORMAT
16+
MODULE_INFO_PATH: str = pt2_archive_constants.MODULE_INFO_PATH
17+
MTIA_DIR: str = pt2_archive_constants.MTIA_DIR
18+
SAMPLE_INPUTS_DIR: str = pt2_archive_constants.SAMPLE_INPUTS_DIR
19+
SAMPLE_INPUTS_FILENAME_FORMAT: str = pt2_archive_constants.SAMPLE_INPUTS_FILENAME_FORMAT
20+
TENSOR_CONSTANT_FILENAME_PREFIX: str = (
21+
pt2_archive_constants.TENSOR_CONSTANT_FILENAME_PREFIX
22+
)
23+
WEIGHT_FILENAME_PREFIX: str = pt2_archive_constants.WEIGHT_FILENAME_PREFIX
24+
WEIGHTS_DIR: str = pt2_archive_constants.WEIGHTS_DIR
25+
XL_MODEL_WEIGHTS_DIR: str = pt2_archive_constants.XL_MODEL_WEIGHTS_DIR
26+
XL_MODEL_WEIGHTS_PARAM_CONFIG_PATH: str = (
27+
pt2_archive_constants.XL_MODEL_WEIGHTS_PARAM_CONFIG_PATH
28+
)

0 commit comments

Comments
 (0)
0