8000 Update · pytorch/pytorch@23a00fc · GitHub
[go: up one dir, main page]

Skip to content

Commit 23a00fc

Browse files
committed
Update
1 parent 8b29276 commit 23a00fc

File tree

2 files changed

+22
-7
lines changed

2 files changed

+22
-7
lines changed

aten/src/ATen/native/Lerp.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,15 @@ TORCH_META_FUNC(lerp_Tensor)(
1616
const Tensor& self, const Tensor& end, const Tensor& weight) {
1717
TORCH_CHECK(self.dtype() == end.dtype(), "expected dtype ", self.dtype(),
1818
" for `end` but got dtype ", end.dtype());
19-
auto weight_ = self.dtype() == weight.dtype() ? weight : weight.to(self.dtype());
19+
20+
auto weight_ = weight;
21+
if (self.dtype() != weight.dtype()) {
22+
auto promote_type = c10::promoteTypes(self.scalar_type(), weight.scalar_type());
23+
TORCH_CHECK(promote_type == self.scalar_type(), "Unable to promote `input` dtype to ", promote_type,
24+
", change `weight` dtype ", weight.dtype(), " same as `input` dtype ", self.dtype());
25+
weight_ = weight.to(promote_type);
26+
}
27+
2028
build(at::TensorIteratorConfig()
2129
.add_output(maybe_get_output())
2230
.add_const_input(self)

test/test_binary_ufuncs.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3518,15 +3518,22 @@ def test_lerp_lowp_cpu(self, device, dtype):
35183518

35193519
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
35203520
def test_lerp_weight_type_promotion(self, device, dtype):
3521-
shape = (5, 5)
3522-
start = torch.randn(shape, device=device, dtype=dtype)
3523-
end = torch.randn(shape, device=device, dtype=dtype)
3524-
weight = torch.randn(shape, device=device, dtype=torch.float)
3521+
start = make_tensor((5, 5), dtype=dtype, device=device, low=1, high=100)
3522+
end = make_tensor((5, 5), dtype=dtype, device=device, low=1, high=100)
3523+
weight = make_tensor((5, 5), dtype=dtype, device=device, low=1, high=100)
35253524

3526-
actual = torch.lerp(start, end, weight)
3527-
expected = start + weight * (end - start)
3525+
actual = torch.lerp(start, end, weight.to(torch.float))
3526+
expected = torch.lerp(start, end, weight)
35283527
self.assertEqual(expected, actual)
35293528

3529+
@dtypes(torch.int, torch.long, torch.bfloat16, torch.float16, torch.float)
3530+
def test_lerp_weight_type_error(self, device, dtype):
3531+
x = torch.ones(2, 2, device=device, dtype=dtype)
3532+
w = torch.ones(2, 2, device=device, dtype=dtype)
3533+
s = torch.tensor(2.2, device=device, dtype=torch.double)
3534+
3535+
with self.assertRaisesRegex(RuntimeError, "Unable to promote `input` dtype"):
3536+
torch.lerp(x, w, s)
35303537

35313538
def _test_logaddexp(self, device, dtype, base2):
35323539
if base2:

0 commit comments

Comments
 (0)
0