[MPSHooks] Release pending command encoder · pytorch/pytorch@1e83dd9 · GitHub

Commit 1e83dd9
[MPSHooks] Release pending command encoder

Release any pending command encoder before returning a command buffer, as subsequent callers are very likely to allocate their own encoder, which results in the following runtime error:

```
tryCoalescingPreviousComputeCommandEncoderWithConfig:nextEncoderClass:]:1090: failed assertion `A command encoder is already encoding to this command buffer'
```

Fixes #163721

ghstack-source-id: b214852
Pull Request resolved: #164093

1 parent 6ba83e0 commit 1e83dd9
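For context, the failure mode looks like the following hypothetical extension snippet (a sketch, not part of this commit; the function name is made up, while `torch::mps::get_command_buffer`, `torch::mps::get_dispatch_queue`, and `torch::mps::commit` are the same public APIs used in the test change below). The MPS backend keeps a compute encoder open on the default stream to coalesce kernels, so an extension that fetched the command buffer and opened its own encoder hit the Metal assertion:

```objc
#include <Metal/Metal.h>
#include <torch/extension.h>
#include <torch/mps.h>

// Hypothetical repro sketch (assumed setup, not from this commit).
static void open_own_encoder() {
  dispatch_sync(torch::mps::get_dispatch_queue(), ^() {
    auto commandBuffer = torch::mps::get_command_buffer();
    // Before this fix, PyTorch's coalesced compute encoder could still be
    // attached to this buffer, so the next line failed the Metal assertion
    // "A command encoder is already encoding to this command buffer".
    id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
    [encoder endEncoding];
    torch::mps::commit();
  });
}
```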

File tree

3 files changed: +39 -1 lines changed

aten/src/ATen/mps/MPSHooks.mm (4 additions, 1 deletion)

```diff
@@ -70,7 +70,10 @@
 }
 
 void* MPSHooks::getCommandBuffer() const {
-  return at::mps::getDefaultMPSStream()->commandBuffer();
+  auto stream = at::mps::getDefaultMPSStream();
+  // Release the pending computeCommandEncoder, as extensions are likely to allocate a new one
+  stream->endKernelCoalescing();
+  return stream->commandBuffer();
 }
 
 void* MPSHooks::getDispatchQueue() const {
```
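A minimal sketch of the contract the hook now provides, assuming (as the fix implies) that `endKernelCoalescing()` closes any compute encoder still open on the stream; the names are the backend-internal ones from the diff above:

```objc
#include <ATen/mps/MPSStream.h>

// Sketch of what getCommandBuffer() now guarantees (assumed include path).
static void* command_buffer_without_open_encoder() {
  auto stream = at::mps::getDefaultMPSStream();
  stream->endKernelCoalescing();  // ends any pending computeCommandEncoder
  // The returned buffer has no active encoder, so callers can open their own.
  return stream->commandBuffer();
}
```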

test/cpp_extensions/mps_extension.mm (29 additions, 0 deletions)

```diff
@@ -13,6 +13,11 @@ kernel void add_arrays(device const float* inA,
 {
     result[index] = inA[index] + inB[index];
 }
+
+kernel void add_one(device float* data,
+                    uint index [[thread_position_in_grid]]) {
+    data[index] += 1.0;
+}
 )MPS_ADD_ARRAYS");
 
 at::Tensor get_cpu_add_output(at::Tensor & cpu_input1, at::Tensor & cpu_input2) {
@@ -50,7 +55,31 @@ kernel void add_arrays(device const float* inA,
   return mps_output;
 }
 
+void mps_add_one_new_encoder(const at::Tensor& input) {
+  using namespace at::native::mps;
+  TORCH_CHECK(input.is_mps());
+  TORCH_CHECK(input.numel() > 0);
+
+  @autoreleasepool {
+    auto kernelPSO = lib.getPipelineStateForFunc("add_one");
+    auto serialQueue = torch::mps::get_dispatch_queue();
+
+    dispatch_sync(serialQueue, ^(){
+      auto commandBuffer = torch::mps::get_command_buffer();
+      // Start a compute pass.
+      auto computeEncoder = [commandBuffer computeCommandEncoder];
+      TORCH_CHECK(computeEncoder, "Failed to create compute command encoder");
+      [computeEncoder setComputePipelineState: kernelPSO];
+      mtl_setArgs(computeEncoder, input);
+      mtl_dispatch1DJob(computeEncoder, kernelPSO, input.numel());
+      [computeEncoder endEncoding];
+      torch::mps::commit();
+    });
+  }
+}
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("get_cpu_add_output", &get_cpu_add_output);
   m.def("get_mps_add_output", &get_mps_add_output);
+  m.def("mps_add_one_new_context", &mps_add_one_new_encoder);
 }
```
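Note that the C++ function `mps_add_one_new_encoder` is exposed to Python as `mps_add_one_new_context`, which is the name the test below calls. The function deliberately opens its own compute encoder on the buffer returned by `torch::mps::get_command_buffer()`, exactly the pattern that used to trip the assertion before the hook change above.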

test/test_cpp_extensions_jit.py (6 additions, 0 deletions)

```diff
@@ -220,6 +220,12 @@ def test_mps_extension(self):
 
         self.assertEqual(cpu_output, mps_output.to("cpu"))
 
+        # Regression test for https://github.com/pytorch/pytorch/issues/163721
+        lib = torch.mps.compile_shader("void kernel noop(device float *x) {}")
+        lib.noop(mps_output)
+        module.mps_add_one_new_context(mps_output)
+        self.assertEqual(cpu_output + 1.0, mps_output.to("cpu"))
+
     def _run_jit_cuda_archflags(self, flags, expected):
         # Compile an extension with given `flags`
         def _check_cuobjdump_output(expected_values, is_ptx=False):
```
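The ordering in the test matters: `lib.noop(mps_output)` dispatches through PyTorch's own MPS shader path and leaves a coalesced compute encoder pending on the stream, recreating the state from issue #163721, so the extension call that follows would have hit the "already encoding" assertion without the fix.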
