Fix lerp weight type promotion #141117
Conversation
🔗 Helpful links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141117.
Note: links to docs will display an error until the docs builds have completed.
✅ You can merge normally! (4 unrelated failures) As of commit af28e9e with merge base b5655d9. FLAKY: the following jobs failed but were likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from 0f5ea29 to 8b29276.
@zeshengzong thanks for taking a stab at this! lmk when the PR is ready for review -- I've triggered the CI based on your current changes.
Sorry for the late reply -- there is still some small inconsistent behavior I need to confirm, and I will open the PR for review once everything works. About the promotion, I think it currently only works for
Thanks! @janeyx99
Ah, great question -- I am noticing now that this PR would not fix the issue at hand. The issue is referring to the Scalar weight overload (not the Tensor overload). So I would anticipate changes to the CUDA and CPP impls relating to that overload. Thanks for asking for clarification!
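For readers, a quick sketch of the distinction being drawn here (an illustration, not code from the PR; overload names follow the ATen schema, and the pre-PR behavior is as described in the linked issue):

```python
import torch

start = torch.ones(2, 2, dtype=torch.float16)
end = torch.ones(2, 2, dtype=torch.float16)

# Scalar-weight overload (lerp.Scalar): weight is a Python number.
start.lerp(end, 2.2)

# Tensor-weight overload (lerp.Tensor): weight is a tensor. Before this PR,
# a 0-dim float32 weight raised a dtype mismatch error with float16 inputs.
start.lerp(end, torch.tensor(2.2))
```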
Force-pushed from 23a00fc to 2b89afe.
@janeyx99 Hello, please help me trigger CI and review the change when it's available, thanks!
@janeyx99 Hello, please review the new updates, thanks!
CI failures are real
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased |
Force-pushed from 784dd96 to b4c2e5c.
Force-pushed from b4c2e5c to a715af2.
After setting it (see pytorch/aten/src/ATen/TensorIterator.cpp, lines 543 to 547 at aaf5615), it seems no extra logic is needed in the kernel for `lerp` to have CPU scalar tensors alongside CUDA inputs.

Hi @janeyx99, please check whether the current implementation works, thanks!
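A quick check of that claim (a sketch; it assumes a CUDA device is available and this PR's changes are applied):

```python
import torch

if torch.cuda.is_available():
    a = torch.rand(3, device="cuda")
    b = torch.rand(3, device="cuda")
    w = torch.tensor(0.5)  # 0-dim weight tensor on CPU
    # allow_cpu_scalars(true) lets TensorIterator accept the CPU scalar
    # tensor alongside CUDA inputs without extra kernel logic.
    print(torch.lerp(a, b, w))
```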
Thanks for looking into this! This approach looks better :)
One nit, and waiting on CI
aten/src/ATen/native/Lerp.cpp (outdated)

```cpp
bool promote_weight = weight.dim() == 0 && self.dtype() != weight.dtype();
if (!promote_weight) {
```
```diff
- bool promote_weight = weight.dim() == 0 && self.dtype() != weight.dtype();
+ bool promote_weight = weight.dim() == 0;
  if (!promote_weight) {
```
what if we don't duplicate the dtype check?
Originally I wrote it like this:

```cpp
if (weight.dim() != 0) {
  TORCH_CHECK(self.dtype() == weight.dtype(), "expected dtype ", self.dtype(),
              " for `weight` but got dtype ", weight.dtype());
}
build(at::TensorIteratorConfig()
    .allow_cpu_scalars(true)
    .promote_inputs_to_common_dtype(weight.dim() == 0 && self.dtype() != weight.dtype())
```
Introducing the variable `promote_weight` might make it easier for others to grasp the `lerp` promotion rule when reading this code. Their reading experience would be:

- the promotion condition is `weight.dim() == 0 && self.dtype() != weight.dtype()`
- if not promoting, check that the dtypes are equal
- do the weight promotion if the condition matches

In both versions the dtype is checked twice, because we need to avoid promotion in the other cases; setting `.promote_inputs_to_common_dtype(true)` directly would change the `out` behavior.

If the original version (without the `promote_weight` variable) is better, I can change it back. Thanks!
But why don't we just promote inputs to common dtype whenever weight is a scalar (no dtype check)? The previous checks should prevent the other inputs from getting promoted, right?
Hi, after some tests I found that if we have code like this, without checking the dtype when setting `promote_inputs_to_common_dtype`:

```cpp
if (weight.dim() != 0) {
  TORCH_CHECK(self.dtype() == weight.dtype(), "expected dtype ", self.dtype(),
              " for `weight` but got dtype ", weight.dtype());
}
build(at::TensorIteratorConfig()
    .allow_cpu_scalars(true)
    .promote_inputs_to_common_dtype(weight.dim() == 0)
    .add_output(maybe_get_output())
    .add_const_input(self)
    .add_const_input(end)
    .add_const_input(weight));
}
```
then in the case where all inputs are the same type but the `out` param is not, `lerp` behaves differently than before:
```python
import torch
a=torch.tensor([[ 0.5385,  8.4653, -8.5042, -6.7041,  0.9973],
        [-1.5006,  5.3119, -7.8279, -8.0691,  0.9812],
        [ 2.6690,  1.3635, -3.8211,  1.0685,  5.1207],
        [-8.9332, -5.6855,  4.2723,  8.9549, -6.6269],
        [ 8.5845, -4.1670,  6.6996, -2.9766, -7.7093]], device='cuda:0')
b=torch.tensor([ 8.2217,  3.2300,  7.5432, -5.6094, -0.2661], device='cuda:0')
c=torch.rand((), dtype=torch.float, device='cuda:0')
d=torch.tensor([[-1,  5,  7, -4, -7],
        [-5,  3, -5, -5, -4],
        [ 9,  2, -3,  4, -9],
        [ 7, -6,  7,  1,  5],
        [ 2, -5, -2, -1, -8]], dtype=torch.long, device='cuda:0')

# Before the change this raises an error about out=d
>>> torch.lerp(a,b,c,out=d)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Found dtype Long but expected Float

# After the change it passes, and the result looks wrong
>>> torch.lerp(a,b,c,out=d)
tensor([[ 7,  3,  6, -5,  0],
        [ 7,  3,  7, -5,  0],
        [ 8,  3,  7, -5,  0],
        [ 7,  2,  7, -5,  0],
        [ 8,  2,  7, -5,  0]], device='cuda:0')
```
I think it is better to keep the `out` param check behavior unchanged when `dim == 0`, so I added the dtype check when setting `promote_inputs_to_common_dtype`. Thanks!
Ah, thanks for the clarification, and I totally agree with you, but the code currently does not fix the concern completely! For example, the new code would NOT error for the following:
```python
a=torch.tensor([[ 0.5385,  8.4653, -8.5042, -6.7041,  0.9973],
        [-1.5006,  5.3119, -7.8279, -8.0691,  0.9812],
        [ 2.6690,  1.3635, -3.8211,  1.0685,  5.1207],
        [-8.9332, -5.6855,  4.2723,  8.9549, -6.6269],
        [ 8.5845, -4.1670,  6.6996, -2.9766, -7.7093]], device='cuda:0')
b=torch.tensor([ 8.2217,  3.2300,  7.5432, -5.6094, -0.2661], device='cuda:0')
c=torch.rand((), dtype=torch.double, device='cuda:0') # changed this line
d=torch.tensor([[-1,  5,  7, -4, -7],
        [-5,  3, -5, -5, -4],
        [ 9,  2, -3,  4, -9],
        [ 7, -6,  7,  1,  5],
        [ 2, -5, -2, -1, -8]], dtype=torch.long, device='cuda:0')
```
even though it would have errored beforehand.

So it'd be more important to figure out where the code is promoting `out`, and to be even more precise.
Hello, I've removed the repeated `weight.dtype` check, and found that the `enforce_safe_casting_to_output` flag guards `out.dtype`, avoiding the wrong output in the above use case. The check is here:

pytorch/aten/src/ATen/TensorIterator.cpp, lines 508 to 512 at 1ce5338:
```cpp
if (config.enforce_safe_casting_to_output_ && op.is_output && op.current_dtype != common_dtype_) {
  TORCH_CHECK(canCast(common_dtype_, op.current_dtype),
              "result type ", common_dtype_, " can't be cast to the "
              "desired output type ", op.current_dtype);
}
```
It will raise an error:
```python
import torch
a=torch.tensor([[ 0.5385,  8.4653, -8.5042, -6.7041,  0.9973],
        [-1.5006,  5.3119, -7.8279, -8.0691,  0.9812],
        [ 2.6690,  1.3635, -3.8211,  1.0685,  5.1207],
        [-8.9332, -5.6855,  4.2723,  8.9549, -6.6269],
        [ 8.5845, -4.1670,  6.6996, -2.9766, -7.7093]], device='cuda:0')
b=torch.tensor([ 8.2217,  3.2300,  7.5432, -5.6094, -0.2661], device='cuda:0')
c=torch.rand((), dtype=torch.double, device='cuda:0') # changed this line
d=torch.tensor([[-1,  5,  7, -4, -7],
        [-5,  3, -5, -5, -4],
        [ 9,  2, -3,  4, -9],
        [ 7, -6,  7,  1,  5],
        [ 2, -5, -2, -1, -8]], dtype=torch.long, device='cuda:0')
torch.lerp(a,b,c,out=d)
print(d.dtype)
```

```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: result type Float can't be cast to the desired output type Long
```
But in the case where `out` can be safely promoted, the behavior differs from before:
```python
import torch
a=torch.tensor([[ 0.5385,  8.4653, -8.5042, -6.7041,  0.9973],
        [-1.5006,  5.3119, -7.8279, -8.0691,  0.9812],
        [ 2.6690,  1.3635, -3.8211,  1.0685,  5.1207],
        [-8.9332, -5.6855,  4.2723,  8.9549, -6.6269],
        [ 8.5845, -4.1670,  6.6996, -2.9766, -7.7093]], device='cpu', dtype=torch.double)
b=torch.tensor([ 8.2217,  3.2300,  7.5432, -5.6094, -0.2661], device='cpu', dtype=torch.double)
c=torch.rand((), dtype=torch.double, device='cpu')
d=torch.tensor([[-1,  5,  7, -4, -7],
        [-5,  3, -5, -5, -4],
        [ 9,  2, -3,  4, -9],
        [ 7, -6,  7,  1,  5],
        [ 2, -5, -2, -1, -8]], dtype=torch.float, device='cpu')
torch.lerp(a,b,c,out=d)
print(d.dtype)

# Before, this raised an error:
# Traceback (most recent call last):
#   File "<stdin>", line 1, in <module>
# RuntimeError: Found dtype Float but expected Double

# After, no error:
# torch.float32
```
As the `out` usage description here says:

> A "safe copy" is different from PyTorch's regular copy. For operations that do not participate in type promotion the device and dtype of the source and destination tensors must match. For operations that do participate in type promotion the copy can be to a different dtype, but the destination of the copy cannot be a lower "type kind" than the source. PyTorch has four type kinds: boolean, integer, float, and complex, in that order. So, for example, an operation like add (which participates in type promotion) will throw a runtime error if given float inputs but an integer out= tensor.
Since `weight.dim() == 0` now triggers the promotion, I think this change in `out` behavior is consistent with the description.

Please check whether this works, thanks!
Ah it looks like our CI is going through some infra issues, could you rebase as well?
@pytorchbot rebase -b main
@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here.
Force-pushed from c9d483b to af28e9e.
```cpp
build(at::TensorIteratorConfig()
    .allow_cpu_scalars(true)
    .promote_inputs_to_common_dtype(promote_weight)
    .enforce_safe_casting_to_output(promote_weight)
    .cast_common_dtype_to_outputs(promote_weight)
```
what does this do?
Hello,

- `promote_inputs_to_common_dtype` enables type promotion instead of raising an error directly; it's the main flag used to fix the issue.
- `enforce_safe_casting_to_output` checks `out.dtype` against the others, making sure the common dtype `canCast` to `out.dtype` (which guards use cases like `out.dtype=torch.long` being inconsistent with the other params, as above).

  > For operations that do participate in type promotion the copy can be to a different dtype, but the destination of the copy cannot be a lower "type kind" than the source. PyTorch has four type kinds: boolean, integer, float, and complex, in that order.

  pytorch/aten/src/ATen/TensorIterator.cpp, lines 508 to 512 at d95a6ba:
  ```cpp
  if (config.enforce_safe_casting_to_output_ && op.is_output && op.current_dtype != common_dtype_) {
    TORCH_CHECK(canCast(common_dtype_, op.current_dtype),
                "result type ", common_dtype_, " can't be cast to the "
                "desired output type ", op.current_dtype);
  }
  ```
- `cast_common_dtype_to_outputs` works on the CPU device, creating a temp tensor to cast the output:

  pytorch/aten/src/ATen/TensorIterator.cpp, lines 516 to 540 at d95a6ba:
  ```cpp
  if (common_device == kCPU) {
    // Casts to outputs by creating temporaries of the correct dtype (if needed)
    // NB: we skip this on is_meta_, because the temporary allocation here is
    // unnecessary if we aren't going to actually do the compute
    if (config.cast_common_dtype_to_outputs_ && op.is_output && op.current_dtype != common_dtype_ && !is_meta_) {
      TORCH_INTERNAL_ASSERT(op.tensor_base().defined());
      // Marker [Output original_tensor is set]
      // NB: do NOT use set_output here, as the temporary is NOT a true output;
      // op.tensor is the true output and it was pre-provided for us.
      // TODO: The logic for cast_outputs will need to be handled by the
      // structured kernels implementation. What probably should happen
      // is that we pass in the inferred dtype into the out kernel, and
      // then after calling the out kernel, do the conversion (which
      // is cast_outputs here), but integrating this with existing
      // TensorIterator will take a little doing
      op.exchange_tensor(c10::MaybeOwned<TensorBase>::owned(
          at::empty_like(op.tensor(),
                         op.tensor_base().options().dtype(common_dtype_),
                         LEGACY_CONTIGUOUS_MEMORY_FORMAT)));
      if (!names_.empty()) {
        namedinference::propagate_names(op.tensor_base(), names_);
      }
      op.current_dtype = common_dtype_;
      op.target_dtype = common_dtype_;
    }
  ```
Other ops that support promotion set all these flags to `true`; for example `lerp_Scalar` (and `add`, `sub`, `mul`, ... as well) use the macro here:

pytorch/aten/src/ATen/TensorIterator.cpp, lines 979 to 1000 at d95a6ba:
```cpp
#define BINARY_OP_CONFIG()                      \
  TensorIteratorConfig()                        \
    .set_check_mem_overlap(true)                \
    .allow_cpu_scalars(true)                    \
    .promote_inputs_to_common_dtype(true)       \
    .cast_common_dtype_to_outputs(true)         \
    .enforce_safe_casting_to_output(true)       \

void TensorIteratorBase::build_binary_op(const TensorBase& out, const TensorBase& a, const TensorBase& b) {
  build(BINARY_OP_CONFIG()
      .add_owned_output(out)
      .add_owned_const_input(a)
      .add_owned_const_input(b));
}

void TensorIteratorBase::build_borrowing_binary_op(
    const TensorBase& out, const TensorBase& a, const TensorBase& b) {
  build(BINARY_OP_CONFIG()
      .add_output(out)
      .add_const_input(a)
      .add_const_input(b));
}
```
The default value of these flags is `false`; currently they are only enabled when `weight.dim() == 0`. Thanks!
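As an aside, the user-visible effect of `enforce_safe_casting_to_output` can be reproduced with an op that already sets all these flags, e.g. `add` (a small sketch, not code from the PR):

```python
import torch

out = torch.empty(2, dtype=torch.long)
# add participates in type promotion with safe casting enforced, so a
# float result cannot be written into an integer out tensor. This raises:
# RuntimeError: result type Float can't be cast to the desired output type Long
torch.add(torch.tensor([1.5, 2.5]), torch.tensor([0.5, 0.5]), out=out)
```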
Yes! This approach def looks the best of all! Just one more q
Approving workflows to see what CI thinks
thanks for the diligence!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@janeyx99 Thank you for your time and patience! 🎉
Fixes pytorch#140601

Enable `promote_inputs_to_common_dtype` when tensors are not the same dtype when invoking the `lerp` function.

For `lerp_Tensor`:
- Check whether the tensors have the same `dtype`, and enable promotion if not
- Remove the type check assert

For `lerp_Scalar`:
- It already enables `promote_inputs_to_common_dtype` by default, so just remove the type check, making the promotion behavior consistent with `lerp_Tensor`.

`lerp_Scalar` gets its TensorIteratorConfig from here:
https://github.com/pytorch/pytorch/blob/c37185c76ae4068899869e48a8388e78437508e8/aten/src/ATen/TensorIterator.cpp#L979-L985

**Test Result**

The test case in the issue passes:

```python
>>> import torch
>>>
>>> x = torch.ones(2, 2, dtype=torch.float64)
>>> w = torch.ones(2, 2, dtype=torch.float64)
>>> s = torch.tensor(2.2)
>>> x.lerp_(w, s)
tensor([[1., 1.],
        [1., 1.]], dtype=torch.float64)
>>> x = torch.ones(2, 2, dtype=torch.float16)
>>> w = torch.ones(2, 2, dtype=torch.float16)
>>> s = torch.tensor(2.2)
>>> x.lerp_(w, s)
tensor([[1., 1.],
        [1., 1.]], dtype=torch.float16)
```

```bash
$ pytest test/test_binary_ufuncs.py -k 'test_lerp_tensor_type_promotion or test_lerp_scalar_type_promotion'
```



```bash
$ lintrunner
```



Pull Request resolved: pytorch#141117
Approved by: https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
@janeyx99 Does this then fix this? Or does it only fix weight type promotion, without start/end input type promotion? Could it somehow reuse the kernel / type promotion from the op which asked for this?

Also, does lerp support a BoolTensor weight? Then I guess it could decay to doing torch.where.

It's also useful to allow broadcasting between start / end.
@vadimkantorov Your observation is correct that this PR does not allow type promotion between start and end, as they're not scalar tensors. I would think torch.where promotion would be different (as the true vs false branches do not interact with each other and the start and end here do), so further discussion should go in the original issue. And yes, lerp supports bool tensor weights, which does reduce it to a torch.where in essence, though most lerp use cases I'd imagine would not have such a binary weight.
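A small sketch of that equivalence (illustrative only; it uses a float 0/1 weight rather than relying on bool-dtype kernel support):

```python
import torch

start = torch.tensor([0., 10.])
end = torch.tensor([100., 200.])
mask = torch.tensor([True, False])

# lerp(start, end, w) = start + w * (end - start); with w in {0, 1} it
# selects end where the mask is true and start elsewhere, i.e. torch.where.
w = mask.to(start.dtype)
assert torch.allclose(torch.lerp(start, end, w), torch.where(mask, end, start))
```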
Sometimes supporting a constant python scalar as start/end is also useful.

True, but given the amount of corner cases in promotion, I wonder if the torch.where/torch.lerp kernels could be unified or made templated...
How come? Why wouldn't someone just use the python scalars as is and escape a kernel launch? Is it mostly for compile capturing? Unifying
For python scalars I mean usage like so:

```python
foreground_mask = torch.rand(16, 16)
image = torch.randint(0, 256, (16, 16), dtype = torch.uint8)
torch.lerp(image, 255, foreground_mask)
# TypeError: lerp() received an invalid combination of arguments - got (Tensor, int, Tensor), but expected one of:
#  * (Tensor input, Tensor end, Tensor weight, *, Tensor out = None)
#  * (Tensor input, Tensor end, Number weight, *, Tensor out = None)

image1 = torch.randint(0, 256, (16, 16), dtype = torch.uint8)
image2 = torch.randint(0, 256, (16, 16), dtype = torch.uint8)
torch.lerp(image1, image2, 0.5)
# Traceback (most recent call last):
#   File "<stdin>", line 1, in <module>
# RuntimeError: "lerp_kernel_scalar" not implemented for 'Byte'
```
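A workaround under the current API (a sketch: compute in float and cast back, since the scalar-end overload and integral kernels are unavailable):

```python
import torch

foreground_mask = torch.rand(16, 16)
image = torch.randint(0, 256, (16, 16), dtype=torch.uint8)

# Blend toward white (255) by lerping in float32 and casting back to uint8.
white = torch.full_like(image, 255, dtype=torch.float32)
blended = torch.lerp(image.float(), white, foreground_mask).to(torch.uint8)
```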
Oh, I see :( I had assumed they are similar because they both have to read from both passed arguments (named
@janeyx99 I pasted these examples also into the original issue: currently python scalar as

Also, it seems there are no kernels for integral tensors (useful for streamlining lerp / blending for uint8 images, int16 audio, uint16 images).