Implement DeviceType.h as header-only · pytorch/pytorch@d725b04 · GitHub
[go: up one dir, main page]

Skip to content

Commit d725b04

Browse files
committed
Implement DeviceType.h as header-only
Summary: Move c10/core/DeviceType.h to a separate torch/csrc/header_only directory. Still keep a copy of c10/core/DeviceType.h for backward compatibility. More header files will be moved as follow-up. CI to guard "header-only-ness" will be added later. ghstack-source-id: f415bf8 Pull Request resolved: #152787
1 parent 2e70eb1 commit d725b04

File tree

7 files changed

+310
-288
lines changed

7 files changed

+310
-288
lines changed

BUILD.bazel

+8
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,14 @@ flatbuffer_cc_library(
670670
out_prefix = "torch/csrc/jit/serialization/",
671671
)
672672

673+
cc_library(
674+
name = "torch_standalone_headers",
675+
hdrs = glob([
676+
"torch/standalone/header_only/**/*.h"
677+
]),
678+
visibility = ["//visibility:public"],
679+
)
680+
673681
cc_library(
674682
name = "torch_headers",
675683
hdrs = if_cuda(

c10/core/DeviceType.cpp

-168
This file was deleted.

c10/core/DeviceType.h

+3-118
Original file line numberDiff line numberDiff line change
@@ -1,123 +1,8 @@
11
#pragma once
22

3-
// This is directly synchronized with caffe2/proto/caffe2.proto, but
4-
// doesn't require me to figure out how to get Protobuf headers into
5-
// ATen/core (which would require a lot more build system hacking.)
6-
// If you modify me, keep me synchronized with that file.
7-
8-
#include <c10/macros/Export.h>
9-
10-
#include <cstddef>
11-
#include <cstdint>
12-
#include <functional>
13-
#include <ostream>
14-
#include <string>
3+
#include <torch/standalone/header_only/core/DeviceType.h>
154

165
namespace c10 {
17-
18-
// These contains all device types that also have a BackendComponent
19-
// and therefore participate in per-backend functionality dispatch keys.
20-
// This is most backends except PrivateUse2 and PrivateUse3
21-
#define C10_FORALL_BACKEND_DEVICE_TYPES(_, extra) \
22-
_(CPU, extra) \
23-
_(CUDA, extra) \
24-
_(HIP, extra) \
25-
_(XLA, extra) \
26-
_(MPS, extra) \
27-
_(IPU, extra) \
28-
_(XPU, extra) \
29-
_(HPU, extra) \
30-
_(VE, extra) \
31-
_(Lazy, extra) \
32-
_(Meta, extra) \
33-
_(MTIA, extra) \
34-
_(PrivateUse1, extra)
35-
36-
enum class DeviceType : int8_t {
37-
CPU = 0,
38-
CUDA = 1, // CUDA.
39-
MKLDNN = 2, // Reserved for explicit MKLDNN
40-
OPENGL = 3, // OpenGL
41-
OPENCL = 4, // OpenCL
42-
IDEEP = 5, // IDEEP.
43-
HIP = 6, // AMD HIP
44-
FPGA = 7, // FPGA
45-
MAIA = 8, // ONNX Runtime / Microsoft
46-
XLA = 9, // XLA / TPU
47-
Vulkan = 10, // Vulkan
48-
Metal = 11, // Metal
49-
XPU = 12, // XPU
50-
MPS = 13, // MPS
51-
Meta = 14, // Meta (tensors with no data)
52-
HPU = 15, // HPU / HABANA
53-
VE = 16, // SX-Aurora / NEC
54-
Lazy = 17, // Lazy Tensors
55-
IPU = 18, // Graphcore IPU
56-
MTIA = 19, // Meta training and inference devices
57-
PrivateUse1 = 20, // PrivateUse1 device
58-
// NB: If you add more devices:
59-
// - Change the implementations of DeviceTypeName and isValidDeviceType
60-
// in DeviceType.cpp
61-
// - Change the number below
62-
COMPILE_TIME_MAX_DEVICE_TYPES = 21,
63-
};
64-
65-
constexpr DeviceType kCPU = DeviceType::CPU;
66-
constexpr DeviceType kCUDA = DeviceType::CUDA;
67-
constexpr DeviceType kHIP = DeviceType::HIP;
68-
constexpr DeviceType kFPGA = DeviceType::FPGA;
69-
constexpr DeviceType kMAIA = DeviceType::MAIA;
70-
constexpr DeviceType kXLA = DeviceType::XLA;
71-
constexpr DeviceType kMPS = DeviceType::MPS;
72-
constexpr DeviceType kMeta = DeviceType::Meta;
73-
constexpr DeviceType kVulkan = DeviceType::Vulkan;
74-
constexpr DeviceType kMetal = DeviceType::Metal;
75-
constexpr DeviceType kXPU = DeviceType::XPU;
76-
constexpr DeviceType kHPU = DeviceType::HPU;
77-
constexpr DeviceType kVE = DeviceType::VE;
78-
constexpr DeviceType kLazy = DeviceType::Lazy;
79-
constexpr DeviceType kIPU = DeviceType::IPU;
80-
constexpr DeviceType kMTIA = DeviceType::MTIA;
81-
constexpr DeviceType kPrivateUse1 = DeviceType::PrivateUse1;
82-
83-
// define explicit int constant
84-
constexpr int COMPILE_TIME_MAX_DEVICE_TYPES =
85-
static_cast<int>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);
86-
87-
static_assert(
88-
COMPILE_TIME_MAX_DEVICE_TYPES <= 21,
89-
"Hey! You seem to be adding a lot of new DeviceTypes. The intent was "
90-
"for this constant to reflect the actual number of DeviceTypes we support "
91-
"in PyTorch; it's important that this number is not too large as we "
92-
"use this to allocate stack arrays in some places in our code. If you "
93-
"are indeed just adding the 20th device type, feel free to change "
94-
"the check to 32; but if you are adding some sort of extensible device "
95-
"types registration, please be aware that you are affecting code that "
96-
"assumes this number is small. Try auditing uses of this constant.");
97-
98-
C10_API std::string DeviceTypeName(DeviceType d, bool lower_case = false);
99-
100-
C10_API bool isValidDeviceType(DeviceType d);
101-
102-
C10_API std::ostream& operator<<(std::ostream& stream, DeviceType type);
103-
104-
C10_API void register_privateuse1_backend(const std::string& backend_name);
105-
C10_API std::string get_privateuse1_backend(bool lower_case = true);
106-
107-
C10_API bool is_privateuse1_backend_registered();
108-
6+
using namespace torch::standalone;
7+
using torch::standalone::DeviceType;
1098
} // namespace c10
110-
111-
namespace std {
112-
template <>
113-
struct hash<c10::DeviceType> {
114-
std::size_t operator()(c10::DeviceType k) const {
115-
return std::hash<int>()(static_cast<int>(k));
116-
}
117-
};
118-
} // namespace std
119-
120-
namespace torch {
121-
// NOLINTNEXTLINE(misc-unused-using-decls)
122-
using c10::DeviceType;
123-
} // namespace torch

c10/core/build.bzl

+1
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def define_targets(rules):
8080
deps = [
8181
":ScalarType",
8282
"//third_party/cpuinfo",
83+
"//:torch_standalone_headers",
8384
"//c10/macros",
8485
"//c10/util:TypeCast",
8586
"//c10/util:base",

caffe2/CMakeLists.txt

+6-2
Original file line numberDiff line numberDiff line change
@@ -1282,7 +1282,8 @@ endif()
12821282
target_include_directories(torch_cpu PRIVATE ${ATen_CPU_INCLUDE})
12831283

12841284
target_include_directories(torch_cpu PRIVATE
1285-
${TORCH_SRC_DIR}/csrc)
1285+
${TORCH_SRC_DIR}/csrc
1286+
${TORCH_SRC_DIR}/standalone)
12861287

12871288
target_include_directories(torch_cpu PRIVATE
12881289
${TORCH_ROOT}/third_party/miniz-3.0.2)
@@ -1301,9 +1302,12 @@ target_include_directories(torch_cpu PRIVATE
13011302
target_include_directories(torch_cpu PRIVATE
13021303
${TORCH_ROOT}/third_party/nlohmann/include)
13031304

1304-
install(DIRECTORY "${TORCH_SRC_DIR}/csrc"
1305+
install(DIRECTORY
1306+
"${TORCH_SRC_DIR}/csrc"
1307+
"${TORCH_SRC_DIR}/standalone"
13051308
DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch
13061309
FILES_MATCHING PATTERN "*.h" PATTERN "*.hpp")
1310+
13071311
install(FILES
13081312
"${TORCH_SRC_DIR}/script.h"
13091313
"${TORCH_SRC_DIR}/extension.h"

torch/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ set(TORCH_PYTHON_INCLUDE_DIRECTORIES
7474
${TORCH_SRC_DIR}/csrc
7575
${TORCH_SRC_DIR}/csrc/api/include
7676
${TORCH_SRC_DIR}/lib
77+
${TORCH_SRC_DIR}/standalone
7778
)
7879

7980
list(APPEND TORCH_PYTHON_INCLUDE_DIRECTORIES ${LIBSHM_SRCDIR})

0 commit comments

Comments
 (0)
0