|
29 | 29 |
|
30 | 30 | static void clamp_mps_graph(CachedGraph* cachedGraph,
|
31 | 31 | const Tensor& input_tensor,
|
32 |
| - const Tensor& min_tensor, |
33 |
| - const Tensor& max_tensor) { |
34 |
| - auto input_dtype = input_tensor.scalar_type(); |
35 |
| - auto min_dtype = cachedGraph->minTensor ? min_tensor.scalar_type() : input_dtype; |
36 |
| - auto max_dtype = cachedGraph->maxTensor ? max_tensor.scalar_type() : input_dtype; |
37 |
| - |
| 32 | + const at::ScalarType min_type, |
| 33 | + const at::ScalarType max_type, |
| 34 | + const at::ScalarType result_type) { |
38 | 35 | MPSGraph* mpsGraph = cachedGraph->graph();
|
39 | 36 |
|
40 | 37 | cachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_tensor);
|
41 | 38 |
|
42 | 39 | auto minTensor = cachedGraph->minTensor;
|
43 | 40 | auto maxTensor = cachedGraph->maxTensor;
|
| 41 | + auto inputTensor = cachedGraph->inputTensor; |
44 | 42 |
|
45 |
| - if (input_dtype != min_dtype) { |
46 |
| - minTensor = castMPSTensor(mpsGraph, cachedGraph->minTensor, input_dtype); |
| 43 | + if (minTensor && min_type != result_type) { |
| 44 | + minTensor = castMPSTensor(mpsGraph, minTensor, result_type); |
| 45 | + } |
| 46 | + if (maxTensor && max_type != result_type) { |
| 47 | + maxTensor = castMPSTensor(mpsGraph, maxTensor, result_type); |
47 | 48 | }
|
48 |
| - if (input_dtype != max_dtype) { |
49 |
| - maxTensor = castMPSTensor(mpsGraph, cachedGraph->maxTensor, input_dtype); |
| 49 | + if (input_tensor.scalar_type() != result_type) { |
| 50 | + inputTensor = castMPSTensor(mpsGraph, inputTensor, result_type); |
50 | 51 | }
|
51 |
| - if (c10::isIntegralType(input_dtype, /*includeBool=*/true)) { |
| 52 | + if (c10::isIntegralType(result_type, /*includeBool=*/true)) { |
52 | 53 | if (minTensor && maxTensor) {
|
53 |
| - cachedGraph->outputTensor = [mpsGraph clampWithTensor:cachedGraph->inputTensor |
| 54 | + cachedGraph->outputTensor = [mpsGraph clampWithTensor:inputTensor |
54 | 55 | minValueTensor:minTensor
|
55 | 56 | maxValueTensor:maxTensor
|
56 | 57 | name:nil];
|
57 | 58 | } else if (maxTensor) {
|
58 |
| - cachedGraph->outputTensor = [mpsGraph minimumWithPrimaryTensor:cachedGraph->inputTensor |
59 |
| - secondaryTensor:maxTensor |
60 |
| - name:nil]; |
| 59 | + cachedGraph->outputTensor = [mpsGraph minimumWithPrimaryTensor:inputTensor secondaryTensor:maxTensor name:nil]; |
61 | 60 | } else if (minTensor) {
|
62 |
| - cachedGraph->outputTensor = [mpsGraph maximumWithPrimaryTensor:cachedGraph->inputTensor |
63 |
| - secondaryTensor:minTensor |
64 |
| - name:nil]; |
| 61 | + cachedGraph->outputTensor = [mpsGraph maximumWithPrimaryTensor:inputTensor secondaryTensor:minTensor name:nil]; |
65 | 62 | }
|
66 | 63 | return;
|
67 | 64 | }
|
68 | 65 | // clampWithTensor doesn't propagate NaN through so simulate it as composition of
|
69 | 66 | // maximumWithNaNPropagationWithPrimaryTensor and minimumWithNaNPropagationWithPrimaryTensor
|
70 |
| - auto outputTensor = cachedGraph->inputTensor; |
| 67 | + auto outputTensor = inputTensor; |
71 | 68 | if (minTensor) {
|
72 | 69 | outputTensor = [mpsGraph maximumWithNaNPropagationWithPrimaryTensor:outputTensor
|
73 | 70 | secondaryTensor:minTensor
|
@@ -134,6 +131,8 @@ static void clamp_tensor_out_mps(const Tensor& input_t,
|
134 | 131 | if (output_t.numel() == 0)
|
135 | 132 | return;
|
136 | 133 |
|
| 134 | + auto result_type = output_t.scalar_type(); |
| 135 | + |
137 | 136 | IntArrayRef new_min_shape;
|
138 | 137 | IntArrayRef new_max_shape;
|
139 | 138 |
|
@@ -182,7 +181,7 @@ static void clamp_tensor_out_mps(const Tensor& input_t,
|
182 | 181 | ;
|
183 | 182 | }
|
184 | 183 |
|
185 |
| - clamp_mps_graph(newCachedGraph, input_t, min_opt_tensor, max_opt_tensor); |
| 184 | + clamp_mps_graph(newCachedGraph, input_t, min_opt_tensor.scalar_type(), max_opt_tensor.scalar_type(), result_type); |
186 | 185 | });
|
187 | 186 |
|
188 | 187 | bool gatherTensorData = true;
|
@@ -238,21 +237,23 @@ static void clamp_scalar_out_mps(const Tensor& input_t,
|
238 | 237 | if (output_t.numel() == 0)
|
239 | 238 | return;
|
240 | 239 |
|
| 240 | + auto result_type = output_t.scalar_type(); |
| 241 | + |
241 | 242 | @autoreleasepool {
|
242 | 243 | // the optional min/max refs could affect how we build the cached graph
|
243 | 244 | string key = op_name + (has_min ? ("_min:" + to_string(min_scalar)) : "") +
|
244 | 245 | (has_max ? ("_max:" + to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
|
245 | 246 | auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
246 | 247 | if (has_min)
|
247 |
| - newCachedGraph->minTensor = [mpsGraph |
248 |
| - constantWithScalar:min_scalar |
249 |
| - shape:(mps::getMPSShape(input_t))dataType:(mps::getMPSScalarType(input_t.scalar_type()))]; |
| 248 | + newCachedGraph->minTensor = [mpsGraph constantWithScalar:min_scalar |
| 249 | + shape:mps::getMPSShape(input_t) |
| 250 | + dataType:mps::getMPSScalarType(result_type)]; |
250 | 251 | if (has_max)
|
251 |
| - newCachedGraph->maxTensor = [mpsGraph |
252 |
| - constantWithScalar:max_scalar |
253 |
| - shape:(mps::getMPSShape(input_t))dataType:(mps::getMPSScalarType(input_t.scalar_type()))]; |
| 252 | + newCachedGraph->maxTensor = [mpsGraph constantWithScalar:max_scalar |
| 253 | + shape:mps::getMPSShape(input_t) |
| 254 | + dataType:mps::getMPSScalarType(result_type)]; |
254 | 255 |
|
255 |
| - clamp_mps_graph(newCachedGraph, input_t, input_t, input_t); |
| 256 | + clamp_mps_graph(newCachedGraph, input_t, result_type, result_type, result_type); |
256 | 257 | });
|
257 | 258 |
|
258 | 259 | bool gatherTensorData = true;
|
|
0 commit comments