[BE][MPS] Pass `alpha` by reference (#152737) · pytorch/pytorch@792736f · GitHub
[go: up one dir, main page]

Skip to content

Commit 792736f

Browse files
malfet authored and pytorchmergebot committed
[BE][MPS] Pass alpha by reference (#152737)
As it's always a scalar.

Pull Request resolved: #152737
Approved by: https://github.com/dcci
ghstack dependencies: #152663, #152515
1 parent cc28b43 commit 792736f

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

c10/metal/indexing.h

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -196,7 +196,7 @@ kernel void alpha_binary_strided(
196196
device void* output [[buffer(0)]],
197197
constant void* input [[buffer(1)]],
198198
constant void* other [[buffer(2)]],
199-
constant T* alpha [[buffer(3)]],
199+
constant T& alpha [[buffer(3)]],
200200
constant long* sizes [[buffer(4)]],
201201
constant long* output_strides [[buffer(5)]],
202202
constant long* input_strides [[buffer(6)]],
@@ -211,7 +211,7 @@ kernel void alpha_binary_strided(
211211
const auto output_offs = offset_from_coord(pos, output_strides, ndim.x);
212212
const auto a = val_at_offs<T>(input, input_offs);
213213
const auto b = val_at_offs<T>(other, other_offs);
214-
ref_at_offs<result_of<F, T, T, T>>(output, output_offs) = f(a, b, *alpha);
214+
ref_at_offs<result_of<F, T, T, T>>(output, output_offs) = f(a, b, alpha);
215215
}
216216

217217
template <typename T, typename F, typename om_t = opmath_t<T>>
@@ -244,7 +244,7 @@ kernel void alpha_binary_strided_cast(
244244
device void* output [[buffer(0)]],
245245
constant void* input [[buffer(1)]],
246246
constant void* other [[buffer(2)]],
247-
constant T* alpha [[buffer(3)]],
247+
constant T& alpha [[buffer(3)]],
248248
constant long* sizes [[buffer(4)]],
249249
constant long* output_strides [[buffer(5)]],
250250
constant long* input_strides [[buffer(6)]],
@@ -261,7 +261,7 @@ kernel void alpha_binary_strided_cast(
261261
val_at_offs<T>(input, input_offs, static_cast<ScalarType>(ndim_types.y));
262262
const auto b =
263263
val_at_offs<T>(other, other_offs, static_cast<ScalarType>(ndim_types.z));
264-
ref_at_offs<result_of<F, T, T, T>>(output, output_offs) = f(a, b, *alpha);
264+
ref_at_offs<result_of<F, T, T, T>>(output, output_offs) = f(a, b, alpha);
265265
}
266266

267267
template <typename T, typename F, typename om_t = opmath_t<T>>
@@ -280,10 +280,10 @@ kernel void alpha_binary_dense(
280280
device result_of<F, T, T, T>* out [[buffer(0)]],
281281
constant T* input [[buffer(1)]],
282282
constant T* other [[buffer(2)]],
283-
constant T* alpha [[buffer(3)]],
283+
constant T& alpha [[buffer(3)]],
284284
uint tid [[thread_position_in_grid]]) {
285285
F f;
286-
out[tid] = f(input[tid], other[tid], *alpha);
286+
out[tid] = f(input[tid], other[tid], alpha);
287287
}
288288

289289
template <typename T, typename F, typename om_t = T>
@@ -307,15 +307,15 @@ kernel void alpha_binary_dense_cast(
307307
device result_of<F, T, T, T>* out [[buffer(0)]],
308308
constant void* input [[buffer(1)]],
309309
constant void* other [[buffer(2)]],
310-
constant T* alpha [[buffer(3)]],
310+
constant T& alpha [[buffer(3)]],
311311
constant uint4& sizes_types [[buffer(4)]],
312312
uint tid [[thread_position_in_grid]]) {
313313
F f;
314314
const auto a = val_at_offs<T>(
315315
input, tid * sizes_types.x, static_cast<ScalarType>(sizes_types.z));
316316
const auto b = val_at_offs<T>(
317317
other, tid * sizes_types.y, static_cast<ScalarType>(sizes_types.w));
318-
out[tid] = f(a, b, *alpha);
318+
out[tid] = f(a, b, alpha);
319319
}
320320

321321
#define REGISTER_BINARY_OP_(NAME, DTYPEI, DTYPEO, OMT) \
@@ -380,7 +380,7 @@ kernel void alpha_binary_dense_cast(
380380
device void* out, \
381381
constant void* input, \
382382
constant void* other, \
383-
constant DTYPEI* alpha, \
383+
constant DTYPEI& alpha, \
384384
constant long* sizes, \
385385
constant long* output_strides, \
386386
constant long* input_strides, \
@@ -392,7 +392,7 @@ kernel void alpha_binary_dense_cast(
392392
device void* out, \
393393
constant void* input, \
394394
constant void* other, \
395-
constant DTYPEI* alpha, \
395+
constant DTYPEI& alpha, \
396396
constant long* sizes, \
397397
constant long* output_strides, \
398398
constant long* input_strides, \
@@ -406,7 +406,7 @@ kernel void alpha_binary_dense_cast(
406406
out_, \
407407
constant DTYPEI * input_, \
408408
constant DTYPEI * other_, \
409-
constant DTYPEI * alpha, \
409+
constant DTYPEI & alpha, \
410410
uint tid); \
411411
template [[host_name(#NAME "_dense_cast_" #DTYPEI)]] kernel void ::c10:: \
412412
metal::alpha_binary_dense_cast<DTYPEI, NAME##_functor>( \
@@ -415,7 +415,7 @@ kernel void alpha_binary_dense_cast(
415415
out_, \
416416
constant void* input, \
417417
constant void* other, \
418-
constant DTYPEI* alpha, \
418+
constant DTYPEI& alpha, \
419419
constant uint4& sizes_types, \
420420
uint tid)
421421
} // namespace metal

0 commit comments

Comments
 (0)
0