diff --git a/test/providers/provider_cuda.cpp b/test/providers/provider_cuda.cpp index a7e5dbe5a..00d8e9d8f 100644 --- a/test/providers/provider_cuda.cpp +++ b/test/providers/provider_cuda.cpp @@ -383,6 +383,48 @@ TEST_P(umfCUDAProviderTest, cudaProviderNullParams) { EXPECT_EQ(res, UMF_RESULT_ERROR_INVALID_ARGUMENT); } +TEST_P(umfCUDAProviderTest, cudaProviderInvalidCreate) { + CUdevice device; + int ret = get_cuda_device(&device); + ASSERT_EQ(ret, 0); + + CUcontext ctx; + ret = create_context(device, &ctx); + ASSERT_EQ(ret, 0); + + // wrong memory type + umf_cuda_memory_provider_params_handle_t params_wrong_memtype = + create_cuda_prov_params(ctx, device, + static_cast(0xFFFF), 0); + ASSERT_NE(params_wrong_memtype, nullptr); + umf_memory_provider_handle_t provider = nullptr; + umf_result_t umf_result = umfMemoryProviderCreate( + umfCUDAMemoryProviderOps(), params_wrong_memtype, &provider); + ASSERT_EQ(umf_result, UMF_RESULT_ERROR_INVALID_ARGUMENT); + umf_result = umfCUDAMemoryProviderParamsDestroy(params_wrong_memtype); + ASSERT_EQ(umf_result, UMF_RESULT_SUCCESS); + + // wrong context + umf_cuda_memory_provider_params_handle_t params_wrong_ctx = + create_cuda_prov_params(nullptr, device, UMF_MEMORY_TYPE_HOST, 0); + ASSERT_NE(params_wrong_ctx, nullptr); + umf_result = umfMemoryProviderCreate(umfCUDAMemoryProviderOps(), + params_wrong_ctx, &provider); + ASSERT_EQ(umf_result, UMF_RESULT_ERROR_INVALID_ARGUMENT); + umf_result = umfCUDAMemoryProviderParamsDestroy(params_wrong_ctx); + ASSERT_EQ(umf_result, UMF_RESULT_SUCCESS); + + // wrong device + umf_cuda_memory_provider_params_handle_t params_wrong_device = + create_cuda_prov_params(ctx, (CUdevice)-1, UMF_MEMORY_TYPE_HOST, 0); + ASSERT_NE(params_wrong_device, nullptr); + umf_result = umfMemoryProviderCreate(umfCUDAMemoryProviderOps(), + params_wrong_device, &provider); + ASSERT_EQ(umf_result, UMF_RESULT_ERROR_INVALID_ARGUMENT); + umf_result = umfCUDAMemoryProviderParamsDestroy(params_wrong_device); + ASSERT_EQ(umf_result, UMF_RESULT_SUCCESS); +} + TEST_P(umfCUDAProviderTest, multiContext) { CUdevice device; int ret = get_cuda_device(&device);