Commit 8592fa9 (pytorch/pytorch)

Merge branch 'master' into mps-binops-dtype-precedence

2 parents: f86d924 + 089203f

31 files changed: +657 −263 lines

aten/src/ATen/native/BinaryOps.cpp

Lines changed: 27 additions & 18 deletions

@@ -102,6 +102,10 @@ TORCH_META_FUNC(special_hermite_polynomial_h) (const Tensor& self, const Tensor&
   build_borrowing_binary_float_op(maybe_get_output(), self, n);
 }
 
+TORCH_META_FUNC(special_hermite_polynomial_he) (const Tensor& self, const Tensor& n) {
+  build_borrowing_binary_float_op(maybe_get_output(), self, n);
+}
+
 TORCH_META_FUNC2(copysign, Tensor) (
   const Tensor& self, const Tensor& other
 ) {

@@ -291,6 +295,7 @@ DEFINE_DISPATCH(zeta_stub);
 DEFINE_DISPATCH(chebyshev_polynomial_t_stub);
 DEFINE_DISPATCH(chebyshev_polynomial_u_stub);
 DEFINE_DISPATCH(hermite_polynomial_h_stub);
+DEFINE_DISPATCH(hermite_polynomial_he_stub);
 
 TORCH_IMPL_FUNC(sub_out) (
   const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& result

@@ -349,6 +354,10 @@ TORCH_IMPL_FUNC(special_hermite_polynomial_h_out) (const Tensor& self, const Ten
   hermite_polynomial_h_stub(device_type(), *this);
 }
 
+TORCH_IMPL_FUNC(special_hermite_polynomial_he_out) (const Tensor& self, const Tensor& n, const Tensor& result) {
+  hermite_polynomial_he_stub(device_type(), *this);
+}
+
 TORCH_IMPL_FUNC(tanh_backward_out) (const Tensor& grad_output, const Tensor& output, const Tensor& result) {
   tanh_backward_stub(device_type(), *this);
 }

@@ -457,6 +466,22 @@ Tensor& special_hermite_polynomial_h_out(const Tensor& self, const Scalar& n, Te
   return at::special_hermite_polynomial_h_out(result, self, wrapped_scalar_tensor(n));
 }
 
+Tensor special_hermite_polynomial_he(const Scalar& x, const Tensor& n) {
+  return at::special_hermite_polynomial_he(wrapped_scalar_tensor(x), n);
+}
+
+Tensor special_hermite_polynomial_he(const Tensor& x, const Scalar& n) {
+  return at::special_hermite_polynomial_he(x, wrapped_scalar_tensor(n));
+}
+
+Tensor& special_hermite_polynomial_he_out(const Scalar& self, const Tensor& n, Tensor& result) {
+  return at::special_hermite_polynomial_he_out(result, wrapped_scalar_tensor(self), n);
+}
+
+Tensor& special_hermite_polynomial_he_out(const Tensor& self, const Scalar& n, Tensor& result) {
+  return at::special_hermite_polynomial_he_out(result, self, wrapped_scalar_tensor(n));
+}
+
 Tensor& special_gammainc_out(const Tensor& self, const Tensor& other, Tensor& result) {
   return at::igamma_out(result, self, other);
 }

@@ -649,34 +674,18 @@ Tensor& true_divide_(Tensor& self, const Scalar& divisor) {
 }
 
 Tensor& floor_divide_out(const Tensor& self, const Tensor& other, Tensor& result) {
-  TORCH_WARN_ONCE(
-    "floor_divide is deprecated, and will be removed in a future version of pytorch. "
-    "It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). "
-    "This results in incorrect rounding for negative values.\n"
-    "To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), "
-    "or for actual floor division, use torch.div(a, b, rounding_mode='floor')."
-  );
-  // FIXME: Not actually doing floor division (#43874)
   auto iter = TensorIterator::binary_op(result, self, other);
-  div_trunc_stub(iter.device_type(), iter);
+  div_floor_stub(iter.device_type(), iter);
   if (!result.defined()) {
     result = iter.output();
   }
   return result;
 }
 
 Tensor floor_divide(const Tensor& self, const Tensor& other) {
-  TORCH_WARN_ONCE(
-    "floor_divide is deprecated, and will be removed in a future version of pytorch. "
-    "It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). "
-    "This results in incorrect rounding for negative values.\n"
-    "To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), "
-    "or for actual floor division, use torch.div(a, b, rounding_mode='floor')."
-  );
-  // FIXME: Not actually doing floor division (#43874)
   Tensor result;
   auto iter = TensorIterator::binary_op(result, self, other);
-  div_trunc_stub(iter.device_type(), iter);
+  div_floor_stub(iter.device_type(), iter);
   return iter.output();
 }
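The functional change above is the switch from div_trunc_stub to div_floor_stub: floor_divide now performs true floor division instead of truncation, which only differs when the quotient is negative and fractional. A minimal standalone C++ sketch of the two rounding modes (illustrative arithmetic only, not PyTorch code):

#include <cmath>
#include <cstdio>

int main() {
  double a = -7.0, b = 2.0;              // a / b == -3.5
  double trunc_div = std::trunc(a / b);  // old behavior: round toward zero      -> -3
  double floor_div = std::floor(a / b);  // new behavior: round toward -infinity -> -4
  std::printf("trunc: %g, floor: %g\n", trunc_div, floor_div);
  return 0;
}

For non-negative operands the two modes agree, so only negative inputs observe the change.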

aten/src/ATen/native/BinaryOps.h

Lines changed: 1 addition & 0 deletions

@@ -104,5 +104,6 @@ DECLARE_DISPATCH(structured_binary_fn, zeta_stub);
 DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_t_stub);
 DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_u_stub);
 DECLARE_DISPATCH(structured_binary_fn, hermite_polynomial_h_stub);
+DECLARE_DISPATCH(structured_binary_fn, hermite_polynomial_he_stub);
 
 }} // namespace at::native
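The DECLARE_DISPATCH / DEFINE_DISPATCH / REGISTER_DISPATCH trio connects the device-agnostic operator to per-backend kernels. Very loosely, and as a hypothetical analogue rather than ATen's actual machinery, the pattern is a global slot that each backend fills in during static initialization and that the operator calls through at run time:

#include <cstdio>
#include <stdexcept>

// Hypothetical stand-in for a dispatch stub: a callable slot per backend.
struct BinaryStub {
  void (*cpu_impl)() = nullptr;
  void operator()() const {
    if (cpu_impl == nullptr) {
      throw std::runtime_error("no kernel registered for this stub");
    }
    cpu_impl();
  }
};

// DECLARE_DISPATCH / DEFINE_DISPATCH analogue: one stub object per operator.
BinaryStub hermite_polynomial_he_stub_analogue;

// REGISTER_DISPATCH analogue: the CPU backend installs its kernel at startup.
void hermite_polynomial_he_kernel_analogue() {
  std::puts("CPU kernel runs");
}
struct RegisterCpuKernel {
  RegisterCpuKernel() {
    hermite_polynomial_he_stub_analogue.cpu_impl = &hermite_polynomial_he_kernel_analogue;
  }
} register_cpu_kernel;

int main() {
  hermite_polynomial_he_stub_analogue();  // dispatches to whatever was registered
  return 0;
}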

aten/src/ATen/native/Math.h

Lines changed: 32 additions & 0 deletions

@@ -2301,4 +2301,36 @@ static inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) {
   return hermite_polynomial_h_forward(x, static_cast<int64_t>(n));
 } // hermite_polynomial_h_forward(T x, T n)
 
+template<typename T>
+static inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, int64_t n) {
+  if (n < 0) {
+    return T(0.0);
+  }
+
+  if (n == 0) {
+    return T(1.0);
+  }
+
+  if (n == 1) {
+    return x;
+  }
+
+  T p = T(1.0);
+  T q = x;
+  T r;
+
+  for (int64_t k = 1; k < n; k++) {
+    r = x * q - k * p;
+    p = q;
+    q = r;
+  }
+
+  return r;
+} // hermite_polynomial_he_forward(T x, int64_t n)
+
+template<typename T, bool is_cuda=false>
+static inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, T n) {
+  return hermite_polynomial_he_forward(x, static_cast<std::int64_t>(n));
+} // hermite_polynomial_he_forward(T x, T n)
+
 C10_CLANG_DIAGNOSTIC_POP()
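The loop above is the three-term recurrence for the probabilists' Hermite polynomials, He_{k+1}(x) = x·He_k(x) − k·He_{k−1}(x), starting from He_0(x) = 1 and He_1(x) = x. A small standalone C++ check of the same recurrence against the closed forms He_2(x) = x^2 − 1 and He_3(x) = x^3 − 3x (illustrative only, independent of the ATen template):

#include <cassert>
#include <cmath>
#include <cstdint>

// Same recurrence as hermite_polynomial_he_forward, written as a free function.
double hermite_he(double x, std::int64_t n) {
  if (n < 0) return 0.0;
  if (n == 0) return 1.0;
  if (n == 1) return x;
  double p = 1.0;  // He_{k-1}
  double q = x;    // He_k
  double r = 0.0;  // He_{k+1}
  for (std::int64_t k = 1; k < n; k++) {
    r = x * q - static_cast<double>(k) * p;
    p = q;
    q = r;
  }
  return r;
}

int main() {
  const double x = 1.5;
  assert(std::abs(hermite_he(x, 2) - (x * x - 1.0)) < 1e-12);          // He_2(x) = x^2 - 1
  assert(std::abs(hermite_he(x, 3) - (x * x * x - 3.0 * x)) < 1e-12);  // He_3(x) = x^3 - 3x
  return 0;
}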

aten/src/ATen/native/cpu/BinaryOpsKernel.cpp

Lines changed: 9 additions & 0 deletions

@@ -1134,6 +1134,14 @@ void hermite_polynomial_h_kernel(TensorIteratorBase& iterator) {
   });
 } // hermite_polynomial_h_kernel(TensorIteratorBase& iterator)
 
+void hermite_polynomial_he_kernel(TensorIteratorBase& iterator) {
+  AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "hermite_polynomial_he_cpu", [&]() {
+    cpu_kernel(iterator, [](scalar_t x, scalar_t n) -> scalar_t {
+      return hermite_polynomial_he_forward(x, n);
+    });
+  });
+} // hermite_polynomial_he_kernel
+
 } // namespace
 
 REGISTER_DISPATCH(add_clamp_stub, &add_clamp_kernel);

@@ -1184,6 +1192,7 @@ REGISTER_DISPATCH(zeta_stub, &zeta_kernel);
 REGISTER_DISPATCH(chebyshev_polynomial_t_stub, &chebyshev_polynomial_t_kernel);
 REGISTER_DISPATCH(chebyshev_polynomial_u_stub, &chebyshev_polynomial_u_kernel);
 REGISTER_DISPATCH(hermite_polynomial_h_stub, &hermite_polynomial_h_kernel);
+REGISTER_DISPATCH(hermite_polynomial_he_stub, &hermite_polynomial_he_kernel);
 
 } // namespace native
 } // namespace at
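AT_DISPATCH_FLOATING_TYPES picks scalar_t (float or double) from the iterator's common dtype, and cpu_kernel then applies the lambda to every pair of input elements. As a simplified, hypothetical stand-in for the TensorIterator machinery (broadcasting, strides, vectorization and type promotion are all omitted), the CPU path reduces to an elementwise loop:

#include <cstddef>
#include <vector>

// Hypothetical stand-in: apply a binary scalar op over two equally sized,
// contiguous inputs. The real cpu_kernel works through TensorIterator instead.
template <typename scalar_t, typename Op>
std::vector<scalar_t> elementwise_binary(const std::vector<scalar_t>& x,
                                         const std::vector<scalar_t>& n,
                                         Op op) {
  std::vector<scalar_t> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    out[i] = op(x[i], n[i]);
  }
  return out;
}

int main() {
  std::vector<double> x = {0.5, 1.0, 2.0};
  std::vector<double> n = {2.0, 2.0, 2.0};
  // In the real kernel the lambda would call hermite_polynomial_he_forward(x, n).
  auto out = elementwise_binary<double>(x, n, [](double a, double b) { return a * b; });
  (void)out;
  return 0;
}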

aten/src/ATen/native/cuda/Math.cuh

Lines changed: 34 additions & 0 deletions

@@ -1395,6 +1395,40 @@ const auto hermite_polynomial_h_string = jiterator_stringify(
   } // hermite_polynomial_h_forward(T x, T n)
 ); // hermite_polynomial_h_string
 
+const auto hermite_polynomial_he_string = jiterator_stringify(
+  template<typename T>
+  T hermite_polynomial_he_forward(T x, int64_t n) {
+    if (n < 0) {
+      return T(0.0);
+    }
+
+    if (n == 0) {
+      return T(1.0);
+    }
+
+    if (n == 1) {
+      return x;
+    }
+
+    T p = T(1.0);
+    T q = x;
+    T r;
+
+    for (int64_t k = 1; k < n; k++) {
+      r = x * q - k * p;
+      p = q;
+      q = r;
+    }
+
+    return r;
+  } // hermite_polynomial_he_forward(T x, int64_t n)
+
+  template<typename T>
+  T hermite_polynomial_he_forward(T x, T n) {
+    return hermite_polynomial_he_forward(x, static_cast<int64_t>(n));
+  } // hermite_polynomial_he_forward(T x, T n)
+); // hermite_polynomial_he_string
+
 #else // !AT_USE_JITERATOR() -- kernels must be precompiled
 
 template <typename scalar_t>
Lines changed: 33 additions & 0 deletions (new file)

@@ -0,0 +1,33 @@
+#define TORCH_ASSERT_NO_OPERATORS
+
+#include <ATen/Dispatch.h>
+#include <ATen/native/cuda/JitLoops.cuh>
+#include <ATen/native/cuda/Loops.cuh>
+#include <ATen/native/BinaryOps.h>
+#include <ATen/native/Math.h>
+#include <ATen/native/cuda/Math.cuh>
+#include <ATen/native/cuda/jit_utils.h>
+
+namespace at {
+namespace native {
+namespace {
+const char hermite_polynomial_he_name[] = "hermite_polynomial_he_forward";
+
+void hermite_polynomial_he_kernel_cuda(TensorIteratorBase& iterator) {
+#if AT_USE_JITERATOR()
+  AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "hermite_polynomial_he_cuda", [&]() {
+    opmath_jitted_gpu_kernel_with_scalars<hermite_polynomial_he_name, scalar_t, scalar_t>(iterator, hermite_polynomial_he_string);
+  });
+#else
+  AT_DISPATCH_FLOATING_TYPES(iterator.common_dtype(), "hermite_polynomial_he_cuda", [&]() {
+    gpu_kernel_with_scalars(iterator, []GPU_LAMBDA(scalar_t x, scalar_t n) -> scalar_t {
+      return hermite_polynomial_he_forward<scalar_t, true>(x, n);
+    });
+  });
+#endif
+} // hermite_polynomial_he_kernel_cuda
+} // namespace (anonymous)
+
+REGISTER_DISPATCH(hermite_polynomial_he_stub, &hermite_polynomial_he_kernel_cuda);
+} // namespace native
+} // namespace at

aten/src/ATen/native/mps/OperationUtils.h

Lines changed: 13 additions & 0 deletions

@@ -74,9 +74,22 @@ class Placeholder {
     return _value == nullptr;
   }
 
+  void allocateViewTensor(const at::Tensor& src)
+  {
+    assert (!_viewOutput.numel());
+    _viewOutput = at::native::empty_mps(
+                      src.sizes(),
+                      src.scalar_type(),
+                      c10::nullopt,
+                      kMPS,
+                      c10::nullopt,
+                      c10::nullopt);
+  }
+
 private:
   MPSGraphTensor* _placeholder;
   MPSGraphTensorData* _value;
+  Tensor _viewOutput;
 };
 
 void resize_tensor(Tensor* output);
