[MPS][TYPE_PROMOTION] Fix Clamp (#133260) · pytorch/pytorch@26735e7 · GitHub
Commit 26735e7

pytorchbot and qqaatw authored
[MPS][TYPE_PROMOTION] Fix Clamp (#133260)
[MPS][TYPE_PROMOTION] Fix Clamp (#130226)

Summary:
1. Fixed #130201 by adding type promotion.
2. Added proper tests.
3. Found that torch's type promotion differs from numpy's:

```python
import torch
import numpy as np

np.clip(np.array([1], dtype=np.float32), np.array([1], dtype=np.int32), None).dtype  # dtype('float64')
torch.clamp(torch.tensor([1], dtype=torch.float32), torch.tensor([1], dtype=torch.int32)).dtype  # torch.float32
```

~Not sure of the proper way to handle it; it causes the numpy ref tests to fail.~ The reason is explained here, so I'm going to xfail it:
https://github.com/pytorch/pytorch/blob/3c1cf03fde145bdbe1f5ffb81765d076c10b4c04/test/test_ops.py#L260-L264

Pull Request resolved: #130226
Approved by: https://github.com/malfet

(cherry picked from commit 99967e1)
Co-authored-by: Li-Huai (Allan) Lin <qqaatw@gmail.com>
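Concretely, the fix means mixed-dtype clamp on MPS now computes in the promoted result type, instead of assuming the bounds share the input's dtype. A minimal sketch of the behavior this enables (assuming a machine with an MPS device; tensor values are illustrative only):

```python
import torch

# float32 input clamped by an int32 lower bound: after this fix the bound
# is cast to the result type (float32) before the MPS graph clamp is built.
x = torch.tensor([0.5, 1.5, 2.5], device="mps")
lo = torch.tensor([1, 1, 1], dtype=torch.int32, device="mps")
out = torch.clamp(x, min=lo)
print(out.dtype)  # torch.float32
print(out.cpu())  # tensor([1.0000, 1.5000, 2.5000])
```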
1 parent f6fb80b commit 26735e7

File tree

3 files changed: +44 -29 lines changed

aten/src/ATen/native/mps/operations/TensorCompare.mm

Lines changed: 28 additions & 27 deletions

@@ -29,45 +29,42 @@
 static void clamp_mps_graph(CachedGraph* cachedGraph,
                             const Tensor& input_tensor,
-                            const Tensor& min_tensor,
-                            const Tensor& max_tensor) {
-  auto input_dtype = input_tensor.scalar_type();
-  auto min_dtype = cachedGraph->minTensor ? min_tensor.scalar_type() : input_dtype;
-  auto max_dtype = cachedGraph->maxTensor ? max_tensor.scalar_type() : input_dtype;
-
+                            const at::ScalarType min_type,
+                            const at::ScalarType max_type,
+                            const at::ScalarType result_type) {
   MPSGraph* mpsGraph = cachedGraph->graph();
 
   cachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_tensor);
 
   auto minTensor = cachedGraph->minTensor;
   auto maxTensor = cachedGraph->maxTensor;
+  auto inputTensor = cachedGraph->inputTensor;
 
-  if (input_dtype != min_dtype) {
-    minTensor = castMPSTensor(mpsGraph, cachedGraph->minTensor, input_dtype);
+  if (minTensor && min_type != result_type) {
+    minTensor = castMPSTensor(mpsGraph, minTensor, result_type);
+  }
+  if (maxTensor && max_type != result_type) {
+    maxTensor = castMPSTensor(mpsGraph, maxTensor, result_type);
   }
-  if (input_dtype != max_dtype) {
-    maxTensor = castMPSTensor(mpsGraph, cachedGraph->maxTensor, input_dtype);
+  if (input_tensor.scalar_type() != result_type) {
+    inputTensor = castMPSTensor(mpsGraph, inputTensor, result_type);
   }
-  if (c10::isIntegralType(input_dtype, /*includeBool=*/true)) {
+  if (c10::isIntegralType(result_type, /*includeBool=*/true)) {
     if (minTensor && maxTensor) {
-      cachedGraph->outputTensor = [mpsGraph clampWithTensor:cachedGraph->inputTensor
+      cachedGraph->outputTensor = [mpsGraph clampWithTensor:inputTensor
                                              minValueTensor:minTensor
                                              maxValueTensor:maxTensor
                                                        name:nil];
     } else if (maxTensor) {
-      cachedGraph->outputTensor = [mpsGraph minimumWithPrimaryTensor:cachedGraph->inputTensor
-                                                     secondaryTensor:maxTensor
-                                                                name:nil];
+      cachedGraph->outputTensor = [mpsGraph minimumWithPrimaryTensor:inputTensor secondaryTensor:maxTensor name:nil];
     } else if (minTensor) {
-      cachedGraph->outputTensor = [mpsGraph maximumWithPrimaryTensor:cachedGraph->inputTensor
-                                                     secondaryTensor:minTensor
-                                                                name:nil];
+      cachedGraph->outputTensor = [mpsGraph maximumWithPrimaryTensor:inputTensor secondaryTensor:minTensor name:nil];
     }
     return;
   }
   // clampWithTensor doesn't propagate NaN through so simulate it as composition of
   // maximumWithNaNPropagationWithPrimaryTensor and minimumWithNaNPropagationWithPrimaryTensor
-  auto outputTensor = cachedGraph->inputTensor;
+  auto outputTensor = inputTensor;
   if (minTensor) {
     outputTensor = [mpsGraph maximumWithNaNPropagationWithPrimaryTensor:outputTensor
                                                         secondaryTensor:minTensor

@@ -134,6 +131,8 @@ static void clamp_tensor_out_mps(const Tensor& input_t,
   if (output_t.numel() == 0)
     return;
 
+  auto result_type = output_t.scalar_type();
+
   IntArrayRef new_min_shape;
   IntArrayRef new_max_shape;
 

@@ -182,7 +181,7 @@ static void clamp_tensor_out_mps(const Tensor& input_t,
       ;
     }
 
-    clamp_mps_graph(newCachedGraph, input_t, min_opt_tensor, max_opt_tensor);
+    clamp_mps_graph(newCachedGraph, input_t, min_opt_tensor.scalar_type(), max_opt_tensor.scalar_type(), result_type);
   });
 
   bool gatherTensorData = true;

@@ -238,21 +237,23 @@ static void clamp_scalar_out_mps(const Tensor& input_t,
   if (output_t.numel() == 0)
     return;
 
+  auto result_type = output_t.scalar_type();
+
   @autoreleasepool {
     // the optional min/max refs could affect how we build the cached graph
     string key = op_name + (has_min ? ("_min:" + to_string(min_scalar)) : "") +
         (has_max ? ("_max:" + to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       if (has_min)
-        newCachedGraph->minTensor = [mpsGraph
-            constantWithScalar:min_scalar
-                         shape:(mps::getMPSShape(input_t))dataType:(mps::getMPSScalarType(input_t.scalar_type()))];
+        newCachedGraph->minTensor = [mpsGraph constantWithScalar:min_scalar
+                                                           shape:mps::getMPSShape(input_t)
+                                                        dataType:mps::getMPSScalarType(result_type)];
       if (has_max)
-        newCachedGraph->maxTensor = [mpsGraph
-            constantWithScalar:max_scalar
-                         shape:(mps::getMPSShape(input_t))dataType:(mps::getMPSScalarType(input_t.scalar_type()))];
+        newCachedGraph->maxTensor = [mpsGraph constantWithScalar:max_scalar
+                                                           shape:mps::getMPSShape(input_t)
+                                                        dataType:mps::getMPSScalarType(result_type)];
 
-      clamp_mps_graph(newCachedGraph, input_t, input_t, input_t);
+      clamp_mps_graph(newCachedGraph, input_t, result_type, result_type, result_type);
     });
 
     bool gatherTensorData = true;
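The diff above keeps two code paths: integral result types go through MPSGraph's plain clamp/minimum/maximum, while floating-point results use the NaN-propagating pair, since clampWithTensor would silently drop NaNs. A quick eager-mode illustration of the semantics being preserved (runs on any device; values are illustrative):

```python
import torch

# torch.clamp lets a NaN in the input pass through to the output,
# which is what the maximum/minimumWithNaNPropagation composition preserves.
x = torch.tensor([float("nan"), -1.0, 2.0])
print(torch.clamp(x, min=0.0, max=1.0))  # tensor([nan, 0., 1.])
```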

test/test_mps.py

Lines changed: 7 additions & 2 deletions

@@ -12042,8 +12042,13 @@ def test_numpy_ref_mps(self, device, dtype, op):
         # does not support float64 Tensors.
         # A few ops are currently broken on their reference inputs, but not their sample inputs. These should
         # get patched up and this workaround removed.
-        broken_on_ref_inputs = op.name in ['clamp', 'where']
-        inputs = op.reference_inputs(device, dtype) if not broken_on_ref_inputs else op.sample_inputs(device, dtype)
+        broken_on_ref_inputs = op.name in ('where',)
+
+        # TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
+        inputs = (
+            op.reference_inputs(device, dtype, set_seed=False) if not broken_on_ref_inputs
+            else op.sample_inputs(device, dtype, set_seed=False)
+        )
         for sample_input in inputs:
             self.compare_with_reference(op, op.ref, sample_input)
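With the fix in place, 'clamp' no longer needs to be excluded from the reference inputs, so only 'where' remains on the workaround list. As a rough stand-in for what compare_with_reference does for clamp (a hypothetical simplification, not the test harness itself):

```python
import numpy as np
import torch

# Run the torch op and its numpy reference on the same sample, then compare.
sample = torch.randn(5)
torch_out = torch.clamp(sample, min=-0.5, max=0.5)
ref_out = np.clip(sample.numpy(), -0.5, 0.5)
np.testing.assert_allclose(torch_out.numpy(), ref_out, rtol=1e-6, atol=1e-6)
```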

torch/testing/_internal/common_methods_invocations.py

Lines changed: 9 additions & 0 deletions

@@ -6223,13 +6223,17 @@ def error_inputs_flipud(op, device, **kwargs):
 
 def sample_inputs_clamp(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad)
+    make_integral_arg = partial(make_tensor, dtype=torch.int32, device=device, low=None, high=None, requires_grad=False)
     shape = (S, M, S)
 
     yield SampleInput(make_arg(shape), args=(make_arg(shape), make_arg(shape)))
     yield SampleInput(make_arg(shape), args=(make_arg(shape[1:]), make_arg(shape[1:])))
     yield SampleInput(make_arg(shape), args=(make_arg((S, 1, S)),))
     yield SampleInput(make_arg(shape), args=(None, make_arg(shape)))
     yield SampleInput(make_arg(shape), args=(make_arg(shape), None))
+    # test type promotion
+    yield SampleInput(make_arg(shape), args=(make_integral_arg(shape), None))
+    yield SampleInput(make_arg(shape), args=(make_arg(shape), make_integral_arg(shape)))
 
 def reference_inputs_elementwise_ternary(op, device, dtype, requires_grad, *, sample_inputs_func, supports_scalars=False, **kwargs):
     yield from sample_inputs_func(op, device, dtype, requires_grad, **kwargs)

@@ -12666,6 +12670,11 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
                    'TestNNCOpInfo',
                    'test_nnc_correctness',
                    dtypes=(torch.bool,)),
+               # MPS does not support float64, while numpy does internal computations in float64.
+               # See https://github.com/pytorch/pytorch/blob/3c1cf03fde145bdbe1f5ffb81765d076c10b4c04/test/test_ops.py#L260-L264
+               DecorateInfo(unittest.expectedFailure,
+                            'TestCommon',
+                            'test_numpy_ref_mps'),
            )),
     UnaryUfuncInfo('positive',
                    ref=np.positive,
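The new mixed-dtype samples exercise exactly the promotion gap that motivates the xfail: torch keeps the floating input's dtype when a bound is integral, while numpy's np.clip promotes to float64, which MPS cannot represent. A small sketch of the divergence, mirroring the example from the commit message (values illustrative):

```python
import numpy as np
import torch

x = torch.randn(3)                      # float32 input
lo = torch.zeros(3, dtype=torch.int32)  # int32 lower bound

# torch: the result stays float32 under torch's promotion rules.
print(torch.clamp(x, min=lo).dtype)     # torch.float32

# numpy: float32 clipped against int32 promotes to float64, a dtype
# MPS tensors cannot hold, hence the expectedFailure above.
print(np.clip(x.numpy(), lo.numpy(), None).dtype)  # float64
```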
