add set/restore context in CUDA provider free() by bratpiorka · Pull Request #1049 · oneapi-src/unified-memory-framework · GitHub
add set/restore context in CUDA provider free() #1049

Merged
merged 2 commits on Jan 23, 2025
13 changes: 13 additions & 0 deletions src/provider/provider_cuda.c
@@ -433,6 +433,14 @@ static umf_result_t cu_memory_provider_free(void *provider, void *ptr,
 
     cu_memory_provider_t *cu_provider = (cu_memory_provider_t *)provider;
 
+    // Remember current context and set the one from the provider
+    CUcontext restore_ctx = NULL;
+    umf_result_t umf_result = set_context(cu_provider->context, &restore_ctx);
+    if (umf_result != UMF_RESULT_SUCCESS) {
+        LOG_ERR("Failed to set CUDA context, ret = %d", umf_result);
+        return umf_result;
+    }
+
     CUresult cu_result = CUDA_SUCCESS;
     switch (cu_provider->memory_type) {
     case UMF_MEMORY_TYPE_HOST: {
@@ -451,6 +459,11 @@ static umf_result_t cu_memory_provider_free(void *provider, void *ptr,
         return UMF_RESULT_ERROR_UNKNOWN;
     }
 
+    umf_result = set_context(restore_ctx, &restore_ctx);
+    if (umf_result != UMF_RESULT_SUCCESS) {
+        LOG_ERR("Failed to restore CUDA context, ret = %d", umf_result);
+    }
+
     return cu2umf_result(cu_result);
 }
 
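The provider-side set_context() called above is not part of this hunk. Below is a minimal sketch of the save/set/restore pattern it implements, assuming it mirrors the test helper exported from cuda_helpers.cpp further down; the direct CUDA Driver API calls and the name set_context_sketch are illustrative, the real provider dispatches through its own function table and maps the CUresult to umf_result_t (e.g. via cu2umf_result()).

#include <cuda.h>
#include <stddef.h>

// Illustrative only: remember the caller's context, then make required_ctx current.
static CUresult set_context_sketch(CUcontext required_ctx,
                                   CUcontext *restore_ctx) {
    // Query the context currently bound to the calling thread.
    CUcontext current_ctx = NULL;
    CUresult cu_result = cuCtxGetCurrent(&current_ctx);
    if (cu_result != CUDA_SUCCESS) {
        return cu_result;
    }

    // Hand the previous context back so the caller can restore it later.
    if (restore_ctx != NULL) {
        *restore_ctx = current_ctx;
    }

    // Switch only if the required context is not already current.
    if (current_ctx != required_ctx) {
        cu_result = cuCtxSetCurrent(required_ctx);
    }

    return cu_result;
}

In cu_memory_provider_free() this runs twice: once to make the provider's context current before the driver free call in the switch, and once afterwards with the saved handle to put the caller's context back.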
9 changes: 6 additions & 3 deletions test/providers/cuda_helpers.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2024 Intel Corporation
+ * Copyright (C) 2024-2025 Intel Corporation
  *
  * Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT.
  * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@@ -251,15 +251,18 @@ int InitCUDAOps() {
 }
 #endif // USE_DLOPEN
 
-static CUresult set_context(CUcontext required_ctx, CUcontext *restore_ctx) {
+CUresult set_context(CUcontext required_ctx, CUcontext *restore_ctx) {
     CUcontext current_ctx = NULL;
     CUresult cu_result = libcu_ops.cuCtxGetCurrent(&current_ctx);
     if (cu_result != CUDA_SUCCESS) {
        fprintf(stderr, "cuCtxGetCurrent() failed.\n");
        return cu_result;
     }
 
-    *restore_ctx = current_ctx;
+    if (restore_ctx != NULL) {
+        *restore_ctx = current_ctx;
+    }
+
     if (current_ctx != required_ctx) {
        cu_result = libcu_ops.cuCtxSetCurrent(required_ctx);
        if (cu_result != CUDA_SUCCESS) {
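With set_context() now exported from the test helpers (see the declaration added to cuda_helpers.h below), callers have two modes. A hypothetical usage sketch, not part of the PR, assuming a valid context ctx obtained from create_context() and that the helper is linked in:

#include <cuda.h>
#include <stddef.h>

// Declaration copied from the addition to cuda_helpers.h in this PR.
CUresult set_context(CUcontext required_ctx, CUcontext *restore_ctx);

static int demo_set_context(CUcontext ctx) {
    // Mode 1: remember whatever context the caller had, then switch to ctx.
    CUcontext saved = NULL;
    if (set_context(ctx, &saved) != CUDA_SUCCESS) {
        return -1;
    }

    // ... driver calls that must run under ctx ...

    // Mode 2: pass NULL when the previous context does not need to be saved,
    // as the multiContext test does with set_context(ctx2, NULL).
    if (set_context(saved, NULL) != CUDA_SUCCESS) {
        return -1;
    }
    return 0;
}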
4 changes: 3 additions & 1 deletion test/providers/cuda_helpers.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2024 Intel Corporation
+ * Copyright (C) 2024-2025 Intel Corporation
  *
  * Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT.
  * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
@@ -30,6 +30,8 @@ int get_cuda_device(CUdevice *device);
 
 int create_context(CUdevice device, CUcontext *context);
 
+CUresult set_context(CUcontext required_ctx, CUcontext *restore_ctx);
+
 int destroy_context(CUcontext context);
 
 int cuda_fill(CUcontext context, CUdevice device, void *ptr, size_t size,
67 changes: 66 additions & 1 deletion test/providers/provider_cuda.cpp
@@ -1,4 +1,4 @@
-// Copyright (C) 2024 Intel Corporation
+// Copyright (C) 2024-2025 Intel Corporation
 // Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
@@ -315,6 +315,71 @@ TEST_P(umfCUDAProviderTest, cudaProviderNullParams) {
     EXPECT_EQ(res, UMF_RESULT_ERROR_INVALID_ARGUMENT);
 }
 
+TEST_P(umfCUDAProviderTest, multiContext) {
+    CUdevice device;
+    int ret = get_cuda_device(&device);
+    ASSERT_EQ(ret, 0);
+
+    // create two CUDA contexts and two providers
+    CUcontext ctx1, ctx2;
+    ret = create_context(device, &ctx1);
+    ASSERT_EQ(ret, 0);
+    ret = create_context(device, &ctx2);
+    ASSERT_EQ(ret, 0);
+
+    cuda_params_unique_handle_t params1 =
+        create_cuda_prov_params(ctx1, device, UMF_MEMORY_TYPE_HOST);
+    ASSERT_NE(params1, nullptr);
+    umf_memory_provider_handle_t provider1;
+    umf_result_t umf_result = umfMemoryProviderCreate(
+        umfCUDAMemoryProviderOps(), params1.get(), &provider1);
+    ASSERT_EQ(umf_result, UMF_RESULT_SUCCESS);
+    ASSERT_NE(provider1, nullptr);
+
+    cuda_params_unique_handle_t params2 =
+        create_cuda_prov_params(ctx2, device, UMF_MEMORY_TYPE_HOST);
+    ASSERT_NE(params2, nullptr);
+    umf_memory_provider_handle_t provider2;
+    umf_result = umfMemoryProviderCreate(umfCUDAMemoryProviderOps(),
+                                         params2.get(), &provider2);
+    ASSERT_EQ(umf_result, UMF_RESULT_SUCCESS);
+    ASSERT_NE(provider2, nullptr);
+
+    // use the providers
+    // allocate from 1, then from 2, then free 1, then free 2
+    void *ptr1, *ptr2;
+    const int size = 128;
+    // NOTE: we use ctx1 here
+    umf_result = umfMemoryProviderAlloc(provider1, size, 0, &ptr1);
+    ASSERT_EQ(umf_result, UMF_RESULT_SUCCESS);
+    ASSERT_NE(ptr1, nullptr);
+
+    // NOTE: we use ctx2 here
+    umf_result = umfMemoryProviderAlloc(provider2, size, 0, &ptr2);
+    ASSERT_EQ(umf_result, UMF_RESULT_SUCCESS);
+    ASSERT_NE(ptr2, nullptr);
+
+    // even if we change the context, we should be able to free the memory
+    ret = set_context(ctx2, NULL);
+    ASSERT_EQ(ret, 0);
+    // free memory from ctx1
+    umf_result = umfMemoryProviderFree(provider1, ptr1, size);
+    ASSERT_EQ(umf_result, UMF_RESULT_SUCCESS);
+
+    ret = set_context(ctx1, NULL);
+    ASSERT_EQ(ret, 0);
+    umf_result = umfMemoryProviderFree(provider2, ptr2, size);
+    ASSERT_EQ(umf_result, UMF_RESULT_SUCCESS);
+
+    // cleanup
+    umfMemoryProviderDestroy(provider2);
+    umfMemoryProviderDestroy(provider1);
+    ret = destroy_context(ctx1);
+    ASSERT_EQ(ret, 0);
+    ret = destroy_context(ctx2);
+    ASSERT_EQ(ret, 0);
+}
+
 // TODO add tests that mixes CUDA Memory Provider and Disjoint Pool
 
 CUDATestHelper cudaTestHelper;