@@ -196,7 +196,7 @@ kernel void alpha_binary_strided(
196
196
device void * output [[buffer(0 )]],
197
197
constant void* input [[buffer(1 )]],
198
198
constant void* other [[buffer(2 )]],
199
- constant T* alpha [[buffer(3 )]],
199
+ constant T& alpha [[buffer(3 )]],
200
200
constant long* sizes [[buffer(4 )]],
201
201
constant long* output_strides [[buffer(5 )]],
202
202
constant long* input_strides [[buffer(6 )]],
@@ -211,7 +211,7 @@ kernel void alpha_binary_strided(
211
211
const auto output_offs = offset_from_coord (pos, output_strides, ndim.x );
212
212
const auto a = val_at_offs<T>(input, input_offs);
213
213
const auto b = val_at_offs<T>(other, other_offs);
214
- ref_at_offs<result_of<F, T, T, T>>(output, output_offs) = f (a, b, * alpha);
214
+ ref_at_offs<result_of<F, T, T, T>>(output, output_offs) = f (a, b, alpha);
215
215
}
216
216
217
217
template <typename T, typename F, typename om_t = opmath_t <T>>
@@ -244,7 +244,7 @@ kernel void alpha_binary_strided_cast(
244
244
device void * output [[buffer(0 )]],
245
245
constant void* input [[buffer(1 )]],
246
246
constant void* other [[buffer(2 )]],
247
- constant T* alpha [[buffer(3 )]],
247
+ constant T& alpha [[buffer(3 )]],
248
248
constant long* sizes [[buffer(4 )]],
249
249
constant long* output_strides [[buffer(5 )]],
250
250
constant long* input_strides [[buffer(6 )]],
@@ -261,7 +261,7 @@ kernel void alpha_binary_strided_cast(
261
261
val_at_offs<T>(input, input_offs, static_cast <ScalarType>(ndim_types.y ));
262
262
const auto b =
263
263
val_at_offs<T>(other, other_offs, static_cast <ScalarType>(ndim_types.z ));
264
- ref_at_offs<result_of<F, T, T, T>>(output, output_offs) = f (a, b, * alpha);
264
+ ref_at_offs<result_of<F, T, T, T>>(output, output_offs) = f (a, b, alpha);
265
265
}
266
266
267
267
template <typename T, typename F, typename om_t = opmath_t <T>>
@@ -280,10 +280,10 @@ kernel void alpha_binary_dense(
280
280
device result_of<F, T, T, T>* out [[buffer(0 )]],
281
281
constant T* input [[buffer(1 )]],
282
282
constant T* other [[buffer(2 )]],
283
- constant T* alpha [[buffer(3 )]],
283
+ constant T& alpha [[buffer(3 )]],
284
284
uint tid [[thread_position_in_grid]]) {
285
285
F f;
286
- out[tid] = f (input[tid], other[tid], * alpha);
286
+ out[tid] = f (input[tid], other[tid], alpha);
287
287
}
288
288
289
289
template <typename T, typename F, typename om_t = T>
8000
@@ -307,15 +307,15 @@ kernel void alpha_binary_dense_cast(
307
307
device result_of<F, T, T, T>* out [[buffer(0 )]],
308
308
constant void* input [[buffer(1 )]],
309
309
constant void* other [[buffer(2 )]],
310
- constant T* alpha [[buffer(3 )]],
310
+ constant T& alpha [[buffer(3 )]],
311
311
constant uint4& sizes_types [[buffer(4 )]],
312
312
uint tid [[thread_position_in_grid]]) {
313
313
F f;
314
314
const auto a = val_at_offs<T>(
315
315
input, tid * sizes_types.x , static_cast <ScalarType>(sizes_types.z ));
316
316
const auto b = val_at_offs<T>(
317
317
other, tid * sizes_types.y , static_cast <ScalarType>(sizes_types.w ));
318
- out[tid] = f (a, b, * alpha);
318
+ out[tid] = f (a, b, alpha);
319
319
}
320
320
321
321
#define REGISTER_BINARY_OP_ (NAME, DTYPEI, DTYPEO, OMT ) \
@@ -380,7 +380,7 @@ kernel void alpha_binary_dense_cast(
380
380
device void * out, \
381
381
constant void * input, \
382
382
constant void * other, \
383
- constant DTYPEI* alpha, \
383
+ constant DTYPEI& alpha, \
384
384
constant long * sizes, \
385
385
constant long * output_strides, \
386
386
constant long * input_strides, \
@@ -392,7 +392,7 @@ kernel void alpha_binary_dense_cast(
392
392
device void * out, \
393
393
constant void * input, \
394
394
constant void * other, \
395
- constant DTYPEI* alpha, \
395
+ constant DTYPEI& alpha, \
396
396
constant long * sizes, \
397
397
constant long * output_strides, \
398
398
constant long * input_strides, \
@@ -406,7 +406,7 @@ kernel void alpha_binary_dense_cast(
406
406
out_, \
407
407
constant DTYPEI * input_, \
408
408
constant DTYPEI * other_, \
409
- constant DTYPEI * alpha, \
409
+ constant DTYPEI & alpha, \
410
410
uint tid); \
411
411
template [[host_name(#NAME " _dense_cast_" #DTYPEI)]] kernel void ::c10:: \
412
412
metal::alpha_binary_dense_cast<DTYPEI, NAME##_functor>( \
@@ -415,7 +415,7 @@ kernel void alpha_binary_dense_cast(
415
415
out_, \
416
416
constant void * input, \
417
417
constant void * other, \
418
- constant DTYPEI* alpha, \
418
+ constant DTYPEI& alpha, \
419
419
constant uint4& sizes_types, \
420
420
uint tid)
421
421
} // namespace metal
0 commit comments