Update base for Update on "Add unsigned integer dtypes to PyTorch" · pytorch/pytorch@56648ba · GitHub
[go: up one dir, main page]

Skip to content

Commit 56648ba

Browse files
committed
Update base for Update on "Add unsigned integer dtypes to PyTorch"
The dtypes are very useless right now (not even fill works), but it makes torch.uint16, uint32 and uint64 available as a dtype. Towards #58734 Signed-off-by: Edward Z. Yang <ezyang@meta.com> [ghstack-poisoned]
1 parent 94d2239 commit 56648ba

File tree

6 files changed

+216
-41
lines changed

6 files changed

+216
-41
lines changed

aten/src/ATen/Dispatch_v2.h

Lines changed: 170 additions & 0 deletions
Large diffs are not rendered by default.

aten/src/ATen/native/Scalar.cpp

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
22
#include <ATen/core/Tensor.h>
3-
#include <ATen/Dispatch.h>
3+
#include <ATen/Dispatch_v2.h>
44

55
#ifndef AT_PER_OPERATOR_HEADERS
66
#include <ATen/Functions.h>
@@ -27,33 +27,24 @@ Scalar item(const Tensor& self) {
2727
}
2828
}
2929

30+
#define AT_SD_BASE_TYPES AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), kComplexHalf, kHalf, kBool, kBFloat16
3031
#if !defined(C10_MOBILE)
31-
#define _AT_DISPATCH_SD_TYPES(TYPE, NAME, ...) \
32-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND8( \
33-
kComplexHalf, \
34-
kHalf, \
35-
kBool, \
36-
kBFloat16, \
37-
kFloat8_e5m2, \
38-
kFloat8_e5m2fnuz, \
39-
kFloat8_e4m3fn, \
40-
kFloat8_e4m3fnuz, \
41-
TYPE, \
42-
NAME, \
43-
__VA_ARGS__)
32+
#define AT_SD_TYPES AT_EXPAND(AT_SD_BASE_TYPES), AT_EXPAND(AT_FLOAT8_TYPES)
4433
#else
45-
#define _AT_DISPATCH_SD_TYPES(TYPE, NAME, ...) \
46-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
47-
kComplexHalf, kHalf, kBool, kBFloat16, \
48-
TYPE, NAME, __VA_ARGS__)
34+
#define AT_SD_TYPES AT_EXPAND(AT_SD_BASE_TYPES)
4935
#endif
5036

5137
Scalar _local_scalar_dense_cpu(const Tensor& self) {
5238
Scalar r;
53-
_AT_DISPATCH_SD_TYPES(self.scalar_type(), "_local_scalar_dense_cpu", [&] {
54-
scalar_t value = *self.data_ptr<scalar_t>();
55-
r = Scalar(value);
56-
});
39+
AT_DISPATCH_V2(
40+
self.scalar_type(),
41+
"_local_scalar_dense_cpu",
42+
AT_WRAP([&] {
43+
scalar_t value = *self.data_ptr<scalar_t>();
44+
r = Scalar(value);
45+
}),
46+
AT_EXPAND(AT_SD_TYPES)
47+
);
5748
return r;
5849
}
5950

aten/src/ATen/native/cpu/FillKernel.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#define TORCH_ASSERT_NO_OPERATORS
2-
#include <ATen/Dispatch.h>
2+
#include <ATen/Dispatch_v2.h>
33
#include <ATen/Parallel.h>
44
#include <ATen/cpu/vec/vec.h>
55
#include <ATen/cpu/vec/functional.h>
@@ -44,13 +44,16 @@ void fill_kernel(TensorIterator& iter, const Scalar& value_scalar) {
4444
} else if (iter.dtype() == ScalarType::ComplexHalf) {
4545
fill_non_native_type<c10::complex<at::Half>>(iter, value_scalar);
4646
} else {
47-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Bool, iter.dtype(), "fill_cpu", [&]() {
48-
scalar_t value = value_scalar.to<scalar_t>();
49-
cpu_kernel_vec(
50-
iter,
51-
[=]() -> scalar_t { return value; },
52-
[=]() { return Vectorized<scalar_t>(value); });
53-
});
47+
AT_DISPATCH_V2(
48+
iter.dtype(), "fill_cpu", AT_WRAP([&]() {
49+
scalar_t value = value_scalar.to<scalar_t>();
50+
cpu_kernel_vec(
51+
iter,
52+
[=]() -> scalar_t { return value; },
53+
[=]() { return Vectorized<scalar_t>(value); });
54+
}),
55+
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kBool
56+
);
5457
}
5558
}
5659

aten/src/ATen/native/transformers/sdp_utils_cpp.h

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <c10/core/SymInt.h>
1616
#include <c10/core/SymFloat.h>
1717
#include <c10/util/string_view.h>
18+
#include <c10/util/Array.h>
1819
#include <cmath>
1920
#include <cstdint>
2021
#include <functional>
@@ -58,12 +59,7 @@ inline c10::SymFloat calculate_scale(
5859
return c10::SymFloat(softmax_scale);
5960
}
6061

61-
// This helper function creates a constexpr std::array
62-
// From a compile time list of values
63-
template <typename V, typename... T>
64-
inline constexpr auto array_of(T&&... t) -> std::array<V, sizeof...(T)> {
65-
return {{std::forward<T>(t)...}};
66-
}
62+
using c10::array_of;
6763

6864
inline bool input_requires_grad(sdp_params const& params) {
6965
const bool any_inputs_require_grad = params.query.requires_grad() ||

c10/core/ScalarType.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <c10/core/ScalarType.h>
2+
#include <c10/util/Array.h>
23
#include <array>
34

45
namespace c10 {
@@ -20,10 +21,8 @@ constexpr auto b1 = ScalarType::Bool;
2021
constexpr auto bf = ScalarType::BFloat16;
2122
constexpr auto ud = ScalarType::Undefined;
2223

23-
constexpr int64_t NUM_PROMOTE_TYPES = 20;
24-
25-
constexpr std::array<ScalarType, NUM_PROMOTE_TYPES> index2dtype =
26-
{u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, bf};
24+
constexpr auto index2dtype = array_of<
25+
c10::ScalarType>(u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, bf);
2726

2827
constexpr std::array<int64_t, static_cast<size_t>(ScalarType::NumOptions)>
2928
calculate_dtype2index() {
@@ -83,7 +82,7 @@ ScalarType promoteTypes(ScalarType a, ScalarType b) {
8382
// This table axes must be consistent with index2dtype
8483
// clang-format off
8584
static constexpr std::
86-
array<std::array<ScalarType, NUM_PROMOTE_TYPES>, NUM_PROMOTE_TYPES>
85+
array<std::array<ScalarType, index2dtype.size()>, index2dtype.size()>
8786
_promoteTypesLookup = {{
8887
/* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 bf*/
8988
/* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, bf},

c10/util/Array.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#include <array>
2+
#include <utility>
3+
4+
namespace c10 {
5+
6+
// This helper function creates a constexpr std::array
7+
// From a compile time list of values, without requiring you to explicitly
8+
// write out the length.
9+
//
10+
// See also https://stackoverflow.com/a/26351760/23845
11+
template <typename V, typename... T>
12+
inline constexpr auto array_of(T&&... t) -> std::array<V, sizeof...(T)> {
13+
return {{std::forward<T>(t)...}};
14+
}
15+
16+
} // namespace c10

0 commit comments

Comments
 (0)
0