 #pragma once
 
-// This is directly synchronized with caffe2/proto/caffe2.proto, but
-// doesn't require me to figure out how to get Protobuf headers into
-// ATen/core (which would require a lot more build system hacking.)
-// If you modify me, keep me synchronized with that file.
-
-#include <c10/macros/Export.h>
-
-#include <cstddef>
-#include <cstdint>
-#include <functional>
-#include <ostream>
-#include <string>
+#include <torch/standalone/header_only/core/DeviceType.h>
 
 namespace c10 {
-
-// These contain all device types that also have a BackendComponent
-// and therefore participate in per-backend functionality dispatch keys.
-// This is most backends except PrivateUse2 and PrivateUse3.
-#define C10_FORALL_BACKEND_DEVICE_TYPES(_, extra) \
-  _(CPU, extra)                                   \
-  _(CUDA, extra)                                  \
-  _(HIP, extra)                                   \
-  _(XLA, extra)                                   \
-  _(MPS, extra)                                   \
-  _(IPU, extra)                                   \
-  _(XPU, extra)                                   \
-  _(HPU, extra)                                   \
-  _(VE, extra)                                    \
-  _(Lazy, extra)                                  \
-  _(Meta, extra)                                  \
-  _(MTIA, extra)                                  \
-  _(PrivateUse1, extra)
-
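Note on the X-macro removed above: a caller supplies a two-argument macro plus an extra token, and the list stamps it out once per dispatch-participating backend. A minimal sketch of an expansion (DEFINE_CONSTANT is an illustrative name, not from this file):

#define DEFINE_CONSTANT(t, prefix) \
  constexpr DeviceType prefix##t = DeviceType::t;
C10_FORALL_BACKEND_DEVICE_TYPES(DEFINE_CONSTANT, k)
#undef DEFINE_CONSTANT
// Expands to: constexpr DeviceType kCPU = DeviceType::CPU; and so on,
// once for each of the 13 backends listed in the macro.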
-enum class DeviceType : int8_t {
-  CPU = 0,
-  CUDA = 1, // CUDA.
-  MKLDNN = 2, // Reserved for explicit MKLDNN
-  OPENGL = 3, // OpenGL
-  OPENCL = 4, // OpenCL
-  IDEEP = 5, // IDEEP.
-  HIP = 6, // AMD HIP
-  FPGA = 7, // FPGA
-  MAIA = 8, // ONNX Runtime / Microsoft
-  XLA = 9, // XLA / TPU
-  Vulkan = 10, // Vulkan
-  Metal = 11, // Metal
-  XPU = 12, // XPU
-  MPS = 13, // MPS
-  Meta = 14, // Meta (tensors with no data)
-  HPU = 15, // HPU / HABANA
-  VE = 16, // SX-Aurora / NEC
-  Lazy = 17, // Lazy Tensors
-  IPU = 18, // Graphcore IPU
-  MTIA = 19, // Meta training and inference devices
-  PrivateUse1 = 20, // PrivateUse1 device
-  // NB: If you add more devices:
-  //  - Change the implementations of DeviceTypeName and isValidDeviceType
-  //    in DeviceType.cpp
-  //  - Change the number below
-  COMPILE_TIME_MAX_DEVICE_TYPES = 21,
-};
-
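The NB comment at the end of the enum is the maintenance contract: every new enumerator needs matching cases in DeviceType.cpp. A hedged sketch of the shape of that code (the real implementation lives in DeviceType.cpp and additionally handles lower_case and the registered PrivateUse1 name):

#include <string>

std::string DeviceTypeNameSketch(c10::DeviceType d) {
  switch (d) {
    case c10::DeviceType::CPU:
      return "CPU";
    case c10::DeviceType::CUDA:
      return "CUDA";
    // ... one case per enumerator above ...
    default:
      return "UNKNOWN_DEVICE"; // sketch only; the real code errors here
  }
}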
-constexpr DeviceType kCPU = DeviceType::CPU;
-constexpr DeviceType kCUDA = DeviceType::CUDA;
-constexpr DeviceType kHIP = DeviceType::HIP;
-constexpr DeviceType kFPGA = DeviceType::FPGA;
-constexpr DeviceType kMAIA = DeviceType::MAIA;
-constexpr DeviceType kXLA = DeviceType::XLA;
-constexpr DeviceType kMPS = DeviceType::MPS;
-constexpr DeviceType kMeta = DeviceType::Meta;
-constexpr DeviceType kVulkan = DeviceType::Vulkan;
-constexpr DeviceType kMetal = DeviceType::Metal;
-constexpr DeviceType kXPU = DeviceType::XPU;
-constexpr DeviceType kHPU = DeviceType::HPU;
-constexpr DeviceType kVE = DeviceType::VE;
-constexpr DeviceType kLazy = DeviceType::Lazy;
-constexpr DeviceType kIPU = DeviceType::IPU;
-constexpr DeviceType kMTIA = DeviceType::MTIA;
-constexpr DeviceType kPrivateUse1 = DeviceType::PrivateUse1;
-
-// define explicit int constant
-constexpr int COMPILE_TIME_MAX_DEVICE_TYPES =
-    static_cast<int>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);
-
-static_assert(
-    COMPILE_TIME_MAX_DEVICE_TYPES <= 21,
-    "Hey! You seem to be adding a lot of new DeviceTypes. The intent was "
-    "for this constant to reflect the actual number of DeviceTypes we support "
-    "in PyTorch; it's important that this number is not too large as we "
-    "use this to allocate stack arrays in some places in our code. If you "
-    "are indeed just adding the 20th device type, feel free to change "
-    "the check to 32; but if you are adding some sort of extensible device "
-    "types registration, please be aware that you are affecting code that "
-    "assumes this number is small. Try auditing uses of this constant.");
-
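The static_assert's warning about stack arrays is concrete: the constant sizes fixed per-device-type tables. An illustrative pattern (not from this file) of the kind of code it protects:

#include <array>

// One slot per possible DeviceType; this grows if the constant grows.
std::array<bool, c10::COMPILE_TIME_MAX_DEVICE_TYPES> device_seen{};

void mark_cuda_seen() {
  device_seen[static_cast<size_t>(c10::DeviceType::CUDA)] = true;
}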
-C10_API std::string DeviceTypeName(DeviceType d, bool lower_case = false);
-
-C10_API bool isValidDeviceType(DeviceType d);
-
-C10_API std::ostream& operator<<(std::ostream& stream, DeviceType type);
-
-C10_API void register_privateuse1_backend(const std::string& backend_name);
-C10_API std::string get_privateuse1_backend(bool lower_case = true);
-
-C10_API bool is_privateuse1_backend_registered();
-
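The three PrivateUse1 declarations above are the hook for out-of-tree backends. A hedged usage sketch ("my_accelerator" is a made-up example name):

void setup_backend() {
  // Claim the PrivateUse1 slot for this backend.
  c10::register_privateuse1_backend("my_accelerator");
  bool ok = c10::is_privateuse1_backend_registered(); // true after the call
  // get_privateuse1_backend() should now report "my_accelerator".
  (void)ok;
}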
+using namespace torch::standalone;
+using torch::standalone::DeviceType;
 } // namespace c10
-
-namespace std {
-template <>
-struct hash<c10::DeviceType> {
-  std::size_t operator()(c10::DeviceType k) const {
-    return std::hash<int>()(static_cast<int>(k));
-  }
-};
-} // namespace std
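The std::hash specialization being removed is what let DeviceType key the standard unordered containers, e.g.:

#include <unordered_map>

void count_devices() {
  // Compiles only because std::hash<c10::DeviceType> is specialized.
  std::unordered_map<c10::DeviceType, int> per_device;
  per_device[c10::DeviceType::CUDA] += 1;
}

After this diff that specialization presumably lives in (or is provided via) the standalone header; otherwise code like the above would stop compiling.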
-
-namespace torch {
-// NOLINTNEXTLINE(misc-unused-using-decls)
-using c10::DeviceType;
-} // namespace torch
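Net effect of the + side: c10::DeviceType now names the standalone enum through the using-declaration, so existing call sites keep compiling unchanged. A minimal check, assuming the standalone header defines the same enumerators:

#include <type_traits>

static_assert(
    std::is_same_v<c10::DeviceType, torch::standalone::DeviceType>,
    "c10::DeviceType re-exports torch::standalone::DeviceType");
constexpr c10::DeviceType d = c10::DeviceType::CUDA; // resolves through the alias

Moving the enum into torch/standalone and re-exporting it under c10 keeps header-only consumers decoupled from the c10 build, which is consistent with the removed comment's complaint about Protobuf and build-system entanglement.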