Fix lerp weight type promotion by zeshengzong · Pull Request #141117 · pytorch/pytorch · GitHub

Fix lerp weight type promotion #141117


Closed
zeshengzong wants to merge 8 commits

Conversation

zeshengzong
Contributor
@zeshengzong zeshengzong commented Nov 20, 2024

Fixes #140601

Enable promote_inputs_to_common_dtype when the tensors do not have the same dtype when invoking the lerp function.

For lerp_Tensor

  • Check whether the tensors have the same dtype; enable promotion if not
  • Remove the type check assert

For lerp_Scalar

  • promote_inputs_to_common_dtype already seems to be enabled by default, so just remove the type check. Make sure the promotion behavior is consistent with lerp_Tensor

lerp_Scalar gets its TensorIteratorConfig from here: https://github.com/pytorch/pytorch/blob/c37185c76ae4068899869e48a8388e78437508e8/aten/src/ATen/TensorIterator.cpp#L979-L985

#define BINARY_OP_CONFIG()                \
  TensorIteratorConfig()                  \
    .set_check_mem_overlap(true)          \
    .allow_cpu_scalars(true)              \
    .promote_inputs_to_common_dtype(true) \
    .cast_common_dtype_to_outputs(true)   \
    .enforce_safe_casting_to_output(true)

Test Result
The test case in the issue passes:

>>> import torch
>>> 
>>> x = torch.ones(2, 2, dtype=torch.float64)
>>> w = torch.ones(2, 2, dtype=torch.float64)
>>> s = torch.tensor(2.2) 
>>> x.lerp_(w, s)
tensor([[1., 1.],
        [1., 1.]], dtype=torch.float64)

>>> x = torch.ones(2, 2, dtype=torch.float16)
>>> w = torch.ones(2, 2, dtype=torch.float16)
>>> s = torch.tensor(2.2) 
>>> x.lerp_(w, s)
tensor([[1., 1.],
        [1., 1.]], dtype=torch.float16)
$ pytest test/test_binary_ufuncs.py -k 'test_lerp_tensor_type_promotion or test_lerp_scalar_type_promotion'

![image](https://github.com/user-attachments/assets/288a5294-a9ee-47f3-bbf7-d4ff986f3ba8)

$ lintrunner

![image](https://github.com/user-attachments/assets/d469836f-5c49-4d89-a2fd-379cad4db3af)

cc @janeyx99

pytorch-bot bot commented Nov 20, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141117

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (4 Unrelated Failures)

As of commit af28e9e with merge base b5655d9:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@janeyx99 janeyx99 self-requested a review December 6, 2024 19:58
@janeyx99 janeyx99 added the triaged, release notes: python_frontend, and topic: improvements labels Dec 6, 2024
@janeyx99
Contributor
janeyx99 commented Dec 6, 2024

@zeshengzong thanks for taking a stab at this! lmk when the PR is ready for review--i've triggered the CI based on your current changes.

@zeshengzong
Contributor Author

Sorry for the late reply, there is still some small inconsistent behavior I need to confirm; I will open it for review once everything works.

About the promotion: I think it currently only works for promoting the weight param to the same type as the input, but not the other way around. Does this make sense, or should all tensors be promoted to a common type?

torch.lerp(input, end, weight, *, out=None)

Thanks! @janeyx99

@janeyx99
Contributor

Ah great question--I am noticing now that this PR would not fix the issue at hand. The issue is referring to the Scalar weight overload (not the Tensor overload). So I would anticipate changes to the CUDA and CPP impls relating to lerp_kernel_scalar_weight in the codebase and no changes to the Tensor weight signatures.

Thanks for asking for clarification!

@zeshengzong zeshengzong marked this pull request as ready for review December 13, 2024 07:25
@zeshengzong
Contributor Author

@janeyx99 Hello, please help me trigger CI and review the change when available, thanks!

@zeshengzong
Contributor Author

@janeyx99 Hello, please review new updates, thanks!

@janeyx99
Contributor
janeyx99 commented Jan 2, 2025

CI failures are real

@zeshengzong
Contributor Author

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased fix/aten/lerp onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout fix/aten/lerp && git pull --rebase)

@zeshengzong
Contributor Author
zeshengzong commented Jan 8, 2025

There were recent changes to enable lerp to have CPU scalar tensors with CUDA. I would expect this change to reuse that path as much as possible.

After setting promote_inputs_to_common_dtype, the weight scalar tensor promotion is handled by the TensorIteratorBase::compute_types logic:

if (config.promote_inputs_to_common_dtype_ && !op.is_output && op.current_dtype != common_dtype_) {
  op.exchange_tensor(c10::MaybeOwned<TensorBase>::owned(op.tensor().to(common_dtype_)));
  op.current_dtype = common_dtype_;
  op.target_dtype = common_dtype_;
}

So it seems no extra logic is needed in the kernel for lerp to handle CPU scalar tensors with CUDA.
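
As a quick sanity check (a sketch only, assuming a CUDA device is available; the tensors and values are made up), the promotion path also covers a CPU scalar weight tensor paired with CUDA inputs:

import torch

a = torch.rand(3, 3, device='cuda')          # float32 CUDA input
b = torch.rand(3, 3, device='cuda')          # float32 CUDA end
w = torch.tensor(0.25, dtype=torch.float64)  # 0-dim CPU weight with a different dtype

# compute_types casts w to the common dtype, so the CUDA kernel sees matching dtypes
a.lerp_(b, w)                                # no dtype error; a stays float32 on CUDA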

Hi @janeyx99, please check whether the current implementation works, thanks!

Contributor
@janeyx99 janeyx99 left a comment


Thanks for looking into this! This approach looks better :)

One nit, and waiting on CI

Comment on lines 19 to 20
bool promote_weight = weight.dim() == 0 && self.dtype() != weight.dtype();
if (!promote_weight) {
Contributor


Suggested change
- bool promote_weight = weight.dim() == 0 && self.dtype() != weight.dtype();
- if (!promote_weight) {
+ bool promote_weight = weight.dim() == 0
+ if (!promote_weight) {

what if we don't duplicate the dtype check?

Contributor Author
@zeshengzong zeshengzong Jan 10, 2025


Originally I wrote it like this:

  if (weight.dim() != 0) {
      TORCH_CHECK(self.dtype() == weight.dtype(), "expected dtype ", self.dtype(),
                " for `weight` but got dtype ", weight.dtype());
  }
  build(at::TensorIteratorConfig()
        .allow_cpu_scalars(true)
        .promote_inputs_to_common_dtype(weight.dim() == 0 && self.dtype() != weight.dtype())
        .add_output(maybe_get_output())
        .add_const_input(self)
        .add_const_input(end)
        .add_const_input(weight));

Introducing a promote_weight variable might make it easier for others to grasp the lerp promotion rule when reading this code; the reading flow is roughly:

the promotion condition is weight.dim() == 0 && self.dtype() != weight.dtype()
if not promoting:
    check that the dtypes are equal
promote weight if the condition holds

In both versions the dtype is checked twice, because we need to avoid promotion in the other cases; setting .promote_inputs_to_common_dtype(true) directly would change the out= behavior.

If the original version (without the promote_weight variable) is better, I can change it back. Thanks!

Contributor


But why don’t we just promote inputs to common dtype whenever weight is a scalar (no dtype check)?

the previous checks should prevent the other inputs from getting promoted, right?

Contributor Author


Hi, after some tests I found that if we write the code like this, without the dtype check when setting promote_inputs_to_common_dtype:

  if (weight.dim() != 0) {
    TORCH_CHECK(self.dtype() == weight.dtype(), "expected dtype ", self.dtype(),
                " for `weight` but got dtype ", weight.dtype());
  }
  build(at::TensorIteratorConfig()
        .allow_cpu_scalars(true)
        .promote_inputs_to_common_dtype(weight.dim() == 0)
        .add_output(maybe_get_output())
        .add_const_input(self)
        .add_const_input(end)
        .add_const_input(weight));
}

then in the case where all inputs have the same type but the out param does not, lerp behaves differently than before:

import torch
a=torch.tensor([[ 0.5385,  8.4653, -8.5042, -6.7041,  0.9973],
        [-1.5006,  5.3119, -7.8279, -8.0691,  0.9812],
        [ 2.6690,  1.3635, -3.8211,  1.0685,  5.1207],
        [-8.9332, -5.6855,  4.2723,  8.9549, -6.6269],
        [ 8.5845, -4.1670,  6.6996, -2.9766, -7.7093]], device='cuda:0')
b=torch.tensor([ 8.2217,  3.2300,  7.5432, -5.6094, -0.2661], device='cuda:0')
c=torch.rand((), dtype=torch.float, device='cuda:0')

d=torch.tensor([[-1,  5,  7, -4, -7],
        [-5,  3, -5, -5, -4],
        [ 9,  2, -3,  4, -9],
        [ 7, -6,  7,  1,  5],
        [ 2, -5, -2, -1, -8]], dtype=torch.long, device='cuda:0')
        
# Before the change this raises an error about out=d
>>> torch.lerp(a,b,c,out=d)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Found dtype Long but expected Float

# After the change it passes, but the result looks wrong
>>> torch.lerp(a,b,c,out=d)
tensor([[ 7,  3,  6, -5,  0],
        [ 7,  3,  7, -5,  0],
        [ 8,  3,  7, -5,  0],
        [ 7,  2,  7, -5,  0],
        [ 8,  2,  7, -5,  0]], device='cuda:0')

I think it is better to keep the out param check behavior unchanged when dim == 0, so I added the dtype check when setting promote_inputs_to_common_dtype. Thanks!

Contributor


Ah, thanks for the clarification, and I totally agree with you, but the code currently does not fix the concern completely! For example, the new code would NOT error for the following:

a=torch.tensor([[ 0.5385,  8.4653, -8.5042, -6.7041,  0.9973],
        [-1.5006,  5.3119, -7.8279, -8.0691,  0.9812],
        [ 2.6690,  1.3635, -3.8211,  1.0685,  5.1207],
        [-8.9332, -5.6855,  4.2723,  8.9549, -6.6269],
        [ 8.5845, -4.1670,  6.6996, -2.9766, -7.7093]], device='cuda:0')
b=torch.tensor([ 8.2217,  3.2300,  7.5432, -5.6094, -0.2661], device='cuda:0')
c=torch.rand((), dtype=torch.double, device='cuda:0')  # changed this line

d=torch.tensor([[-1,  5,  7, -4, -7],
        [-5,  3, -5, -5, -4],
        [ 9,  2, -3,  4, -9],
        [ 7, -6,  7,  1,  5],
        [ 2, -5, -2, -1, -8]], dtype=torch.long, device='cuda:0')

even though it would have errored beforehand.

So it'd be more important to figure out where the code is promoting out, and to be even more precise.

Contributor Author


Hello, I've removed the repeated weight.dtype check and found that the enforce_safe_casting_to_output flag guards out.dtype and avoids the wrong output in the above use case. The check code is here:

if (config.enforce_safe_casting_to_output_ && op.is_output && op.current_dtype != common_dtype_) {
  TORCH_CHECK(canCast(common_dtype_, op.current_dtype),
              "result type ", common_dtype_, " can't be cast to the "
              "desired output type ", op.current_dtype);
}

It will raise an error:

import torch
a=torch.tensor([[ 0.5385,  8.4653, -8.5042, -6.7041,  0.9973],
        [-1.5006,  5.3119, -7.8279, -8.0691,  0.9812],
        [ 2.6690,  1.3635, -3.8211,  1.0685,  5.1207],
        [-8.9332, -5.6855,  4.2723,  8.9549, -6.6269],
        [ 8.5845, -4.1670,  6.6996, -2.9766, -7.7093]], device='cuda:0')
b=torch.tensor([ 8.2217,  3.2300,  7.5432, -5.6094, -0.2661], device='cuda:0')
c=torch.rand((), dtype=torch.double, device='cuda:0')  # changed this line

d=torch.tensor([[-1,  5,  7, -4, -7],
        [-5,  3, -5, -5, -4],
        [ 9,  2, -3,  4, -9],
        [ 7, -6,  7,  1,  5],
        [ 2, -5, -2, -1, -8]], dtype=torch.long, device='cuda:0')

torch.lerp(a,b,c,out=d)
print(d.dtype)

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: result type Float can't be cast to the desired output type Long

But in the case where out can be safely promoted, the behavior differs from before:

import torch
a=torch.tensor([[ 0.5385,  8.4653, -8.5042, -6.7041,  0.9973],
        [-1.5006,  5.3119, -7.8279, -8.0691,  0.9812],
        [ 2.6690,  1.3635, -3.8211,  1.0685,  5.1207],
        [-8.9332, -5.6855,  4.2723,  8.9549, -6.6269],
        [ 8.5845, -4.1670,  6.6996, -2.9766, -7.7093]], device='cpu', dtype=torch.double)
b=torch.tensor([ 8.2217,  3.2300,  7.5432, -5.6094, -0.2661], device='cpu', dtype=torch.double)
c=torch.rand((), dtype=torch.double, device='cpu')


d=torch.tensor([[-1,  5,  7, -4, -7],
        [-5,  3, -5, -5, -4],
        [ 9,  2, -3,  4, -9],
        [ 7, -6,  7,  1,  5],
        [ 2, -5, -2, -1, -8]], dtype=torch.float, device='cpu')

torch.lerp(a,b,c,out=d)
print(d.dtype)

# Before, raise error
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Found dtype Float but expected Double

# After, no error
torch.float32

As the out= usage description says here:

A "safe copy" is different from PyTorch's regular copy. For operations that do not participate in type promotion the device and dtype of the source and destination tensors must match. For operations that do participate in type promotion the copy can be to a different dtype, but the destination of the copy cannot be a lower "type kind" than the source. PyTorch has four type kinds: boolean, integer, float, and complex, in that order. So, for example, an operation like add (which participates in type promotion) will throw a runtime error if given float inputs but an integer out= tensor.

Since weight.dim() == 0 now triggers promotion, I think this change in out= behavior is consistent with the description.
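
For reference (not part of the patch, just an illustration of the same canCast rule from Python), torch.can_cast follows the type-kind ordering described above:

import torch

print(torch.can_cast(torch.double, torch.float))  # True: both are floats, safe under type promotion
print(torch.can_cast(torch.float, torch.long))    # False: a float result can't be cast to an integer out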

Please check whether this works, thanks!

@janeyx99
Contributor
janeyx99 commented Jan 9, 2025

Ah it looks like our CI is going through some infra issues, could you rebase as well?

@zeshengzong
Contributor Author

@pytorchbot rebase -b main

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here

zeshengzong and others added 3 commits January 21, 2025 15:19
Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
build(at::TensorIteratorConfig()
      .allow_cpu_scalars(true)
      .promote_inputs_to_common_dtype(promote_weight)
      .enforce_safe_casting_to_output(promote_weight)
      .cast_common_dtype_to_outputs(promote_weight)
Contributor


what does this do?

Contributor Author


Hello,

  • promote_inputs_to_common_dtype enables type promotion instead of raising an error directly; it is the main flag used to fix the issue.

  • enforce_safe_casting_to_output checks out.dtype against the common dtype, making sure the result canCast to out.dtype (which guards use cases like out.dtype=torch.long being inconsistent with the other params, as above).

For operations that do participate in type promotion the copy can be to a different dtype, but the destination of the copy cannot be a lower "type kind" than the source. PyTorch has four type kinds: boolean, integer, float, and complex, in that order.

if (config.enforce_safe_casting_to_output_ && op.is_output && op.current_dtype != common_dtype_) {
  TORCH_CHECK(canCast(common_dtype_, op.current_dtype),
              "result type ", common_dtype_, " can't be cast to the "
              "desired output type ", op.current_dtype);
}

  • cast_common_dtype_to_outputs, on the CPU device, creates a temp tensor of the common dtype and casts the result back to the output:

if (common_device == kCPU) {
  // Casts to outputs by creating temporaries of the correct dtype (if needed)
  // NB: we skip this on is_meta_, because the temporary allocation here is
  // unnecessary if we aren't going to actually do the compute
  if (config.cast_common_dtype_to_outputs_ && op.is_output && op.current_dtype != common_dtype_ && !is_meta_) {
    TORCH_INTERNAL_ASSERT(op.tensor_base().defined());
    // Marker [Output original_tensor is set]
    // NB: do NOT use set_output here, as the temporary is NOT a true output;
    // op.tensor is the true output and it was pre-provided for us.
    // TODO: The logic for cast_outputs will need to be handled by the
    // structured kernels implementation. What probably should happen
    // is that we pass in the inferred dtype into the out kernel, and
    // then after calling the out kernel, do the conversion (which
    // is cast_outputs here), but integrating this with existing
    // TensorIterator will take a little doing
    op.exchange_tensor(c10::MaybeOwned<TensorBase>::owned(
        at::empty_like(op.tensor(),
                       op.tensor_base().options().dtype(common_dtype_),
                       LEGACY_CONTIGUOUS_MEMORY_FORMAT)));
    if (!names_.empty()) {
      namedinference::propagate_names(op.tensor_base(), names_);
    }
    op.current_dtype = common_dtype_;
    op.target_dtype = common_dtype_;
  }

Other ops that support promotion set all of these flags to true; for example lerp_Scalar (and add, sub, mul, ... as well) use the macro here:

#define BINARY_OP_CONFIG()                \
  TensorIteratorConfig()                  \
    .set_check_mem_overlap(true)          \
    .allow_cpu_scalars(true)              \
    .promote_inputs_to_common_dtype(true) \
    .cast_common_dtype_to_outputs(true)   \
    .enforce_safe_casting_to_output(true)

void TensorIteratorBase::build_binary_op(const TensorBase& out, const TensorBase& a, const TensorBase& b) {
  build(BINARY_OP_CONFIG()
        .add_owned_output(out)
        .add_owned_const_input(a)
        .add_owned_const_input(b));
}

void TensorIteratorBase::build_borrowing_binary_op(
    const TensorBase& out, const TensorBase& a, const TensorBase& b) {
  build(BINARY_OP_CONFIG()
        .add_output(out)
        .add_const_input(a)
        .add_const_input(b));
}

The default value of these flags is false; currently we only enable them when weight.dim() == 0. Thanks!
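
To summarize the intended behavior (a sketch with made-up shapes, assuming the final version of the patch): promotion only applies to a 0-dim weight, while a full-size weight with a mismatched dtype still hits the TORCH_CHECK:

import torch

x = torch.ones(2, 2, dtype=torch.float16)
end = torch.ones(2, 2, dtype=torch.float16)

w0 = torch.tensor(2.2)                       # 0-dim float32 weight: promoted, no error
x.lerp_(end, w0)                             # x stays float16, matching the test result above

w1 = torch.ones(2, 2, dtype=torch.float32)   # full-size weight with a different dtype
try:
    x.lerp_(end, w1)                         # dim != 0, so the kept dtype check still raises
except RuntimeError as e:
    print(e)                                 # expected dtype Half for `weight` but got dtype Float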

Contributor
@janeyx99 janeyx99 left a comment


Yes! This approach def looks the best of all! Just one more q


@janeyx99
Contributor

Approving workflows to see what CI thinks

Contributor
@janeyx99 janeyx99 left a comment


thanks for the diligence!

@janeyx99
Contributor

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label Jan 23, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@zeshengzong
Contributor Author

@janeyx99 Thank you for your time and patience! 🎉

pytorchmergebot pushed a commit to AnantGulati/pytorch that referenced this pull request Jan 24, 2025

Pull Request resolved: pytorch#141117
Approved by: https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
@vadimkantorov
Contributor
vadimkantorov commented Apr 25, 2025

@janeyx99 Does this then fix this?

Or does it only fix weight type promotion without start/end input type promotion?

Could it somehow reuse the kernel / type promotion from torch.where (or even unify the two)? The difference is that torch.where only accepts a bool mask and torch.lerp a float [0; 1] mask

which asked for torch.where-like type promotion, including accepting python scalars as one of the bounds and also working with uint8/integral inputs (useful for interpolating between two uint8 images, e.g. with a matting bool/float mask) and still producing uint8 image outputs.

Also, does lerp support a BoolTensor weight? Then I guess it could decay to doing torch.where

It's also useful to allow broadcasting between start / end

@janeyx99
Contributor

@vadimkantorov Your observation is correct that this PR does not allow type promotion between start and end, as they're not scalar tensors. I would think torch.where promotion would be different (as the true vs false branches do not interact with each other and the start and end here do), so further discussion should go in the original issue.

And yes, lerp supports bool tensor weights, which does reduce it to a torch.where in essence, though most lerp use cases I'd imagine would not have such a binary weight.
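
As an illustration of that equivalence (a sketch, not from the thread; it casts the mask to the input dtype rather than passing a bool tensor directly): with a {0, 1}-valued weight, lerp(a, b, w) = a + w * (b - a) picks b where the mask is 1 and a elsewhere, matching torch.where(mask, b, a):

import torch

a = torch.randn(4, 4)
b = torch.randn(4, 4)
mask = torch.rand(4, 4) > 0.5

out_lerp = torch.lerp(a, b, mask.to(a.dtype))
out_where = torch.where(mask, b, a)
print(torch.allclose(out_lerp, out_where))  # True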

@vadimkantorov
Contributor

as they're not scalar tensors

Sometimes supporting constant python scalar as start/end is also useful.

I'd imagine would not have such a binary weight.

True, but given the amount of corner cases in promotion, I wonder if the torch.where/torch.lerp kernels could be unified or made templated...

@janeyx99
Contributor

Sometimes supporting constant python scalar as start/end is also useful.

How come? Why wouldn't someone just use the python scalars as is and escape a kernel launch? Is it mostly for compile capturing?

Unifying lerp and where doesn't seem to make that much sense to me as their underlying implementations are quite different and their main similarity is just that they don't fit under our classic binary or unary op buckets.

@vadimkantorov
Contributor
vadimkantorov commented Apr 25, 2025

For python scalars I mean usage like so:

foreground_mask = torch.rand(16, 16)
image = torch.randint(0, 256, (16, 16), dtype = torch.uint8)
torch.lerp(image, 255, foreground_mask)

# TypeError: lerp() received an invalid combination of arguments - got (Tensor, int, Tensor), but expected one of:
#  * (Tensor input, Tensor end, Tensor weight, *, Tensor out = None)
#  * (Tensor input, Tensor end, Number weight, *, Tensor out = None)

image1 = torch.randint(0, 256, (16, 16), dtype = torch.uint8)
image2 = torch.randint(0, 256, (16, 16), dtype = torch.uint8)
torch.lerp(image1, image2, 0.5)
# Traceback (most recent call last):
#   File "<stdin>", line 1, in <module>
# RuntimeError: "lerp_kernel_scalar" not implemented for 'Byte'

as their underlying implementations are quite different

Oh, I see :( I had assumed they are similar because they both have to read from both passed arguments (named input/end for torch.lerp and input/other torch.where) and then modulate the output given the read mask (named weight for torch.lerp and condition for torch.where)

@vadimkantorov
Contributor
vadimkantorov commented Apr 26, 2025

@janeyx99 pasted these examples also into:

currently a python scalar as input / end is not supported, as the only overloads are

 * (Tensor input, Tensor end, Tensor weight, *, Tensor out = None)
 * (Tensor input, Tensor end, Number weight, *, Tensor out = None)

also, it seems there are no kernels for integral tensors (useful for streamlining lerp / blending of uint8 images, int16 audio, uint16 images)

Labels
ciflow/trunk, Merged, open source, release notes: python_frontend, topic: improvements, triaged
Development

Successfully merging this pull request may close these issues.

lerp_ doesn't correctly type promote
5 participants