8000 [MPS][BE] Use convenience methods to set args (#145736) · pytorch/pytorch@30dea84 · GitHub
[go: up one dir, main page]

Skip to content

Commit 30dea84

Browse files
malfet authored and pytorchmergebot committed
[MPS][BE] Use convenience methods to set args (#145736)
It's better to call `mtl_setArgs` rather than set arguments one by one with the risk of making a typo.
Also, all interactions with MTLCommandBuffer must be serialized, which is commonly done using dispatch queues.
Pull Request resolved: #145736
Approved by: https://github.com/Skylion007
1 parent 7db20ff commit 30dea84

File tree

1 file changed

+3
-8
lines changed

1 file changed

+3
-8
lines changed

aten/src/ATen/native/mps/operations/LinearAlgebra.mm

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -815,21 +815,16 @@ static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& L
815815
int64_t numBlocks = (N + NB - 1) / NB;
816816

817817
Tensor success = at::empty({B}, input.options().dtype(kInt)).fill_(1);
818-
id<MTLBuffer> successBuffer = getMTLBufferStorage(success);
819818

820819
MTLSize threadGroupSize = MTLSizeMake(256, 1, 1);
821-
id<MTLBuffer> outBuffer = getMTLBufferStorage(out);
822-
id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
823-
[computeEncoder setBuffer:outBuffer offset:0 atIndex:0];
824-
[computeEncoder setBytes:&N length:sizeof(int64_t) atIndex:2];
825-
[computeEncoder setBytes:&NB length:sizeof(int64_t) atIndex:3];
826820

827821
@autoreleasepool {
828822
dispatch_sync_with_rethrow(stream->queue(), ^() {
823+
auto computeEncoder = stream->commandEncoder();
824+
mtl_setArgs(computeEncoder, out, success, N, NB);
829825
for (int64_t k = 0; k < numBlocks; k++) {
830826
[computeEncoder setComputePipelineState:factorDiagonalPSO];
831-
[computeEncoder setBuffer:successBuffer offset:0 atIndex:1];
832-
[computeEncoder setBytes:&k length:sizeof(int64_t) atIndex:4];
827+
mtl_setBytes(computeEncoder, k, 4);
833828
MTLSize gridSize = MTLSizeMake(B, 1, 1);
834829
[computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize];
835830

0 commit comments

Comments
 (0)
0