diff --git a/CMakeLists.txt b/CMakeLists.txt index c3330ce31b..a18510f7c5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,7 +4,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception cmake_minimum_required(VERSION 3.20.0 FATAL_ERROR) -project(unified-runtime VERSION 0.11.0) +project(unified-runtime VERSION 0.11.10) # Check if unified runtime is built as a standalone project. if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR OR UR_STANDALONE_BUILD) diff --git a/source/adapters/cuda/command_buffer.cpp b/source/adapters/cuda/command_buffer.cpp index b60d2944b1..353b0ebc41 100644 --- a/source/adapters/cuda/command_buffer.cpp +++ b/source/adapters/cuda/command_buffer.cpp @@ -52,7 +52,7 @@ commandHandleReleaseInternal(ur_exp_command_buffer_command_handle_t Command) { // of the `ur_event_t` object doesn't free the underlying CuEvent_t object and // we need to do it manually ourselves. if (Command->SignalNode) { - CUevent SignalEvent; + CUevent SignalEvent{}; UR_CHECK_ERROR( cuGraphEventRecordNodeGetEvent(Command->SignalNode, &SignalEvent)); UR_CHECK_ERROR(cuEventDestroy(SignalEvent)); @@ -85,7 +85,7 @@ ur_exp_command_buffer_handle_t_::~ur_exp_command_buffer_handle_t_() { std::unique_ptr ur_exp_command_buffer_handle_t_::addSignalNode(CUgraphNode DepNode, CUgraphNode &SignalNode) { - CUevent Event; + CUevent Event{}; UR_CHECK_ERROR(cuEventCreate(&Event, CU_EVENT_DEFAULT)); UR_CHECK_ERROR( cuGraphAddEventRecordNode(&SignalNode, CudaGraph, &DepNode, 1, Event)); @@ -1433,7 +1433,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateSignalEventExp( return UR_RESULT_ERROR_INVALID_OPERATION; } - CUevent SignalEvent; + CUevent SignalEvent{}; UR_CHECK_ERROR(cuGraphEventRecordNodeGetEvent(SignalNode, &SignalEvent)); if (phEvent) { diff --git a/source/adapters/cuda/device.cpp b/source/adapters/cuda/device.cpp index d8916ccedd..ef306abf82 100644 --- a/source/adapters/cuda/device.cpp +++ b/source/adapters/cuda/device.cpp @@ -1238,7 +1238,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle( ur_result_t UR_APICALL urDeviceGetGlobalTimestamps(ur_device_handle_t hDevice, uint64_t *pDeviceTimestamp, uint64_t *pHostTimestamp) { - CUevent Event; + CUevent Event{}; ScopedContext Active(hDevice); if (pDeviceTimestamp) { diff --git a/source/adapters/cuda/platform.cpp b/source/adapters/cuda/platform.cpp index 7ce0bba9e7..cad809a1e5 100644 --- a/source/adapters/cuda/platform.cpp +++ b/source/adapters/cuda/platform.cpp @@ -84,7 +84,7 @@ urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries, UR_CHECK_ERROR(cuDevicePrimaryCtxRetain(&Context, Device)); ScopedContext Active(Context); // Set native ctx as active - CUevent EvBase; + CUevent EvBase{}; UR_CHECK_ERROR(cuEventCreate(&EvBase, CU_EVENT_DEFAULT)); // Use default stream to record base event counter diff --git a/source/adapters/cuda/tensor_map.cpp b/source/adapters/cuda/tensor_map.cpp index da8e4f8f8c..1730b79d41 100644 --- a/source/adapters/cuda/tensor_map.cpp +++ b/source/adapters/cuda/tensor_map.cpp @@ -13,6 +13,24 @@ #include "context.hpp" +#if CUDA_VERSION < 12000 +UR_APIEXPORT ur_result_t UR_APICALL urTensorMapEncodeIm2ColExp( + ur_device_handle_t, ur_exp_tensor_map_data_type_flags_t, uint32_t, void *, + const uint64_t *, const uint64_t *, const int *, const int *, uint32_t, + uint32_t, const uint32_t *, ur_exp_tensor_map_interleave_flags_t, + ur_exp_tensor_map_swizzle_flags_t, ur_exp_tensor_map_l2_promotion_flags_t, + ur_exp_tensor_map_oob_fill_flags_t, ur_exp_tensor_map_handle_t *) { + return 
UR_RESULT_ERROR_UNSUPPORTED_FEATURE; +} +UR_APIEXPORT ur_result_t UR_APICALL urTensorMapEncodeTiledExp( + ur_device_handle_t, ur_exp_tensor_map_data_type_flags_t, uint32_t, void *, + const uint64_t *, const uint64_t *, const uint32_t *, const uint32_t *, + ur_exp_tensor_map_interleave_flags_t, ur_exp_tensor_map_swizzle_flags_t, + ur_exp_tensor_map_l2_promotion_flags_t, ur_exp_tensor_map_oob_fill_flags_t, + ur_exp_tensor_map_handle_t *) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; +} +#else struct ur_exp_tensor_map_handle_t_ { CUtensorMap Map; }; @@ -140,3 +158,4 @@ UR_APIEXPORT ur_result_t UR_APICALL urTensorMapEncodeTiledExp( } return UR_RESULT_SUCCESS; } +#endif diff --git a/source/adapters/level_zero/adapter.cpp b/source/adapters/level_zero/adapter.cpp index 7dff6bcf14..68aa852595 100644 --- a/source/adapters/level_zero/adapter.cpp +++ b/source/adapters/level_zero/adapter.cpp @@ -9,6 +9,7 @@ //===----------------------------------------------------------------------===// #include "adapter.hpp" +#include "common.hpp" #include "ur_level_zero.hpp" #include @@ -162,7 +163,7 @@ ur_result_t initPlatforms(PlatformVec &platforms, ZE2UR_CALL(zeDriverGet, (&ZeDriverGetCount, ZeDriverGetHandles.data())); } if (ZeDriverGetCount == 0 && GlobalAdapter->ZeInitDriversCount == 0) { - logger::debug("\nNo Valid L0 Drivers found.\n"); + logger::error("\nNo Valid L0 Drivers found.\n"); return UR_RESULT_SUCCESS; } @@ -376,7 +377,9 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() static_cast(L0InitFlags)); GlobalAdapter->ZeInitResult = ZE_CALL_NOCHECK(zeInit, (L0InitFlags)); if (GlobalAdapter->ZeInitResult != ZE_RESULT_SUCCESS) { - logger::debug("\nzeInit failed with {}\n", GlobalAdapter->ZeInitResult); + const char *ErrorString = "Unknown"; + zeParseError(GlobalAdapter->ZeInitResult, ErrorString); + logger::error("\nzeInit failed with {}\n", ErrorString); } bool useInitDrivers = false; @@ -422,8 +425,9 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() if (GlobalAdapter->ZeInitDriversResult == ZE_RESULT_SUCCESS) { GlobalAdapter->InitDriversSupported = true; } else { - logger::debug("\nzeInitDrivers failed with {}\n", - GlobalAdapter->ZeInitDriversResult); + const char *ErrorString = "Unknown"; + zeParseError(GlobalAdapter->ZeInitDriversResult, ErrorString); + logger::error("\nzeInitDrivers failed with {}\n", ErrorString); } } } @@ -441,6 +445,7 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() // Absorb the ZE_RESULT_ERROR_UNINITIALIZED and just return 0 Platforms. if (*GlobalAdapter->ZeResult == ZE_RESULT_ERROR_UNINITIALIZED) { + logger::error("Level Zero Uninitialized\n"); result = std::move(platforms); return; } diff --git a/source/adapters/level_zero/command_buffer.cpp b/source/adapters/level_zero/command_buffer.cpp index 5ae19092a6..7065c8e167 100644 --- a/source/adapters/level_zero/command_buffer.cpp +++ b/source/adapters/level_zero/command_buffer.cpp @@ -26,14 +26,9 @@ namespace { // given Context and Device. bool checkImmediateAppendSupport(ur_context_handle_t Context, ur_device_handle_t Device) { - // TODO The L0 driver is not reporting this extension yet. Once it does, - // switch to using the variable zeDriverImmediateCommandListAppendFound. - // Minimum version that supports zeCommandListImmediateAppendCommandListsExp. 
-  constexpr uint32_t MinDriverVersion = 30898;
   bool DriverSupportsImmediateAppend =
-      Context->getPlatform()->isDriverVersionNewerOrSimilar(1, 3,
-                                                            MinDriverVersion);
+      Context->getPlatform()->ZeCommandListImmediateAppendExt.Supported;
 
   // If this environment variable is:
   //   - Set to 1: the immediate append path will always be enabled as long the
@@ -58,10 +53,8 @@ bool checkImmediateAppendSupport(ur_context_handle_t Context,
   if (EnableAppendPath && !DriverSupportsImmediateAppend) {
     logger::error("{} is set but "
                   "the current driver does not support the "
-                  "zeCommandListImmediateAppendCommandListsExp entrypoint. A "
-                  "driver version of at least {} is required to use the "
-                  "immediate append path.",
-                  AppendEnvVarName, MinDriverVersion);
+                  "zeCommandListImmediateAppendCommandListsExp entrypoint.",
+                  AppendEnvVarName);
     std::abort();
   }
 
@@ -894,28 +887,31 @@ urCommandBufferFinalizeExp(ur_exp_command_buffer_handle_t CommandBuffer) {
 /**
  * Sets the kernel arguments for a kernel command that will be appended to the
  * command buffer.
- * @param[in] CommandBuffer The CommandBuffer where the command will be
+ * @param[in] Device The Device associated with the command-buffer where the
+ * kernel command will be appended.
+ * @param[in,out] PendingArguments Arguments stored in the ur_kernel_handle_t
+ * object to be set on the \p ZeKernel object.
+ * @param[in] ZeKernel The handle to the Level-Zero kernel that will be
  * appended.
- * @param[in] Kernel The handle to the kernel that will be appended.
  * @return UR_RESULT_SUCCESS or an error code on failure
  */
-ur_result_t
-setKernelPendingArguments(ur_exp_command_buffer_handle_t CommandBuffer,
-                          ur_kernel_handle_t Kernel) {
-
+ur_result_t setKernelPendingArguments(
+    ur_device_handle_t Device,
+    std::vector &PendingArguments,
+    ze_kernel_handle_t ZeKernel) {
   // If there are any pending arguments set them now.
-  for (auto &Arg : Kernel->PendingArguments) {
+  for (auto &Arg : PendingArguments) {
     // The ArgValue may be a NULL pointer in which case a NULL value is used for
     // the kernel argument declared as a pointer to global or constant memory.
char **ZeHandlePtr = nullptr; if (Arg.Value) { - UR_CALL(Arg.Value->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode, - CommandBuffer->Device, nullptr, 0u)); + UR_CALL(Arg.Value->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode, Device, + nullptr, 0u)); } ZE2UR_CALL(zeKernelSetArgumentValue, - (Kernel->ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr)); + (ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr)); } - Kernel->PendingArguments.clear(); + PendingArguments.clear(); return UR_RESULT_SUCCESS; } @@ -951,34 +947,54 @@ createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer, ZE_MUTABLE_COMMAND_EXP_FLAG_GLOBAL_OFFSET; auto Platform = CommandBuffer->Context->getPlatform(); + auto ZeDevice = CommandBuffer->Device->ZeDevice; + ze_command_list_handle_t ZeCommandList = + CommandBuffer->ZeComputeCommandListTranslated; + if (Platform->ZeMutableCmdListExt.LoaderExtension) { + ZeCommandList = CommandBuffer->ZeComputeCommandList; + } + if (NumKernelAlternatives > 0) { ZeMutableCommandDesc.flags |= ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION; - std::vector TranslatedKernelHandles( - NumKernelAlternatives + 1, nullptr); + std::vector KernelHandles(NumKernelAlternatives + 1, + nullptr); - // Translate main kernel first - ZE2UR_CALL(zelLoaderTranslateHandle, - (ZEL_HANDLE_KERNEL, Kernel->ZeKernel, - (void **)&TranslatedKernelHandles[0])); + ze_kernel_handle_t ZeMainKernel{}; + UR_CALL(getZeKernel(ZeDevice, Kernel, &ZeMainKernel)); - for (size_t i = 0; i < NumKernelAlternatives; i++) { + if (Platform->ZeMutableCmdListExt.LoaderExtension) { + KernelHandles[0] = ZeMainKernel; + } else { + // If the L0 loader is not aware of the MCL extension, the main kernel + // handle needs to be translated. ZE2UR_CALL(zelLoaderTranslateHandle, - (ZEL_HANDLE_KERNEL, KernelAlternatives[i]->ZeKernel, - (void **)&TranslatedKernelHandles[i + 1])); + (ZEL_HANDLE_KERNEL, ZeMainKernel, (void **)&KernelHandles[0])); + } + + for (size_t i = 0; i < NumKernelAlternatives; i++) { + ze_kernel_handle_t ZeAltKernel{}; + UR_CALL(getZeKernel(ZeDevice, KernelAlternatives[i], &ZeAltKernel)); + + if (Platform->ZeMutableCmdListExt.LoaderExtension) { + KernelHandles[i + 1] = ZeAltKernel; + } else { + // If the L0 loader is not aware of the MCL extension, the kernel + // alternatives need to be translated. 
+ ZE2UR_CALL(zelLoaderTranslateHandle, (ZEL_HANDLE_KERNEL, ZeAltKernel, + (void **)&KernelHandles[i + 1])); + } } ZE2UR_CALL(Platform->ZeMutableCmdListExt .zexCommandListGetNextCommandIdWithKernelsExp, - (CommandBuffer->ZeComputeCommandListTranslated, - &ZeMutableCommandDesc, NumKernelAlternatives + 1, - TranslatedKernelHandles.data(), &CommandId)); + (ZeCommandList, &ZeMutableCommandDesc, NumKernelAlternatives + 1, + KernelHandles.data(), &CommandId)); } else { ZE2UR_CALL(Platform->ZeMutableCmdListExt.zexCommandListGetNextCommandIdExp, - (CommandBuffer->ZeComputeCommandListTranslated, - &ZeMutableCommandDesc, &CommandId)); + (ZeCommandList, &ZeMutableCommandDesc, &CommandId)); } DEBUG_LOG(CommandId); @@ -1022,23 +1038,28 @@ ur_result_t urCommandBufferAppendKernelLaunchExp( std::scoped_lock Lock( Kernel->Mutex, Kernel->Program->Mutex, CommandBuffer->Mutex); + auto Device = CommandBuffer->Device; + ze_kernel_handle_t ZeKernel{}; + UR_CALL(getZeKernel(Device->ZeDevice, Kernel, &ZeKernel)); + if (GlobalWorkOffset != NULL) { - UR_CALL(setKernelGlobalOffset(CommandBuffer->Context, Kernel->ZeKernel, - WorkDim, GlobalWorkOffset)); + UR_CALL(setKernelGlobalOffset(CommandBuffer->Context, ZeKernel, WorkDim, + GlobalWorkOffset)); } // If there are any pending arguments set them now. if (!Kernel->PendingArguments.empty()) { - UR_CALL(setKernelPendingArguments(CommandBuffer, Kernel)); + UR_CALL( + setKernelPendingArguments(Device, Kernel->PendingArguments, ZeKernel)); } ze_group_count_t ZeThreadGroupDimensions{1, 1, 1}; uint32_t WG[3]; - UR_CALL(calculateKernelWorkDimensions(Kernel->ZeKernel, CommandBuffer->Device, + UR_CALL(calculateKernelWorkDimensions(ZeKernel, Device, ZeThreadGroupDimensions, WG, WorkDim, GlobalWorkSize, LocalWorkSize)); - ZE2UR_CALL(zeKernelSetGroupSize, (Kernel->ZeKernel, WG[0], WG[1], WG[2])); + ZE2UR_CALL(zeKernelSetGroupSize, (ZeKernel, WG[0], WG[1], WG[2])); CommandBuffer->KernelsList.push_back(Kernel); for (size_t i = 0; i < NumKernelAlternatives; i++) { @@ -1063,7 +1084,7 @@ ur_result_t urCommandBufferAppendKernelLaunchExp( SyncPointWaitList, false, RetSyncPoint, ZeEventList, ZeLaunchEvent)); ZE2UR_CALL(zeCommandListAppendLaunchKernel, - (CommandBuffer->ZeComputeCommandList, Kernel->ZeKernel, + (CommandBuffer->ZeComputeCommandList, ZeKernel, &ZeThreadGroupDimensions, ZeLaunchEvent, ZeEventList.size(), getPointerFromVector(ZeEventList))); @@ -1552,7 +1573,10 @@ ur_result_t enqueueImmediateAppendPath( ur_event_handle_t *Event, ur_command_list_ptr_t CommandListHelper, bool DoProfiling) { + ur_platform_handle_t Platform = CommandBuffer->Context->getPlatform(); + assert(CommandListHelper->second.IsImmediate); + assert(Platform->ZeCommandListImmediateAppendExt.Supported); _ur_ze_event_list_t UrZeEventList; if (NumEventsInWaitList) { @@ -1570,7 +1594,8 @@ ur_result_t enqueueImmediateAppendPath( nullptr /*ForcedCmdQueue*/)); assert(ZeCopyEngineImmediateListHelper->second.IsImmediate); - ZE2UR_CALL(zeCommandListImmediateAppendCommandListsExp, + ZE2UR_CALL(Platform->ZeCommandListImmediateAppendExt + .zeCommandListImmediateAppendCommandListsExp, (ZeCopyEngineImmediateListHelper->first, 1, &CommandBuffer->ZeCopyCommandList, nullptr, UrZeEventList.Length, UrZeEventList.ZeEventList)); @@ -1582,7 +1607,8 @@ ur_result_t enqueueImmediateAppendPath( ze_event_handle_t &EventToSignal = DoProfiling ? 
CommandBuffer->ComputeFinishedEvent->ZeEvent : (*Event)->ZeEvent; - ZE2UR_CALL(zeCommandListImmediateAppendCommandListsExp, + ZE2UR_CALL(Platform->ZeCommandListImmediateAppendExt + .zeCommandListImmediateAppendCommandListsExp, (CommandListHelper->first, 1, &CommandBuffer->ZeComputeCommandList, EventToSignal, WaitList.Length, WaitList.ZeEventList)); @@ -1599,7 +1625,8 @@ ur_result_t enqueueImmediateAppendPath( (CommandListHelper->first, CommandBuffer->ExecutionFinishedEvent->ZeEvent, 0, nullptr)); - ZE2UR_CALL(zeCommandListImmediateAppendCommandListsExp, + ZE2UR_CALL(Platform->ZeCommandListImmediateAppendExt + .zeCommandListImmediateAppendCommandListsExp, (CommandListHelper->first, 1, &CommandBuffer->ZeCommandListResetEvents, nullptr, 0, nullptr)); } @@ -1836,6 +1863,7 @@ ur_result_t updateKernelCommand( const auto CommandBuffer = Command->CommandBuffer; const void *NextDesc = nullptr; auto Platform = CommandBuffer->Context->getPlatform(); + auto ZeDevice = CommandBuffer->Device->ZeDevice; uint32_t Dim = CommandDesc->newWorkDim; size_t *NewGlobalWorkOffset = CommandDesc->pNewGlobalWorkOffset; @@ -1844,16 +1872,24 @@ ur_result_t updateKernelCommand( // Kernel handle must be updated first for a given CommandId if required ur_kernel_handle_t NewKernel = CommandDesc->hNewKernel; + if (NewKernel && Command->Kernel != NewKernel) { - ze_kernel_handle_t ZeKernelTranslated = nullptr; - ZE2UR_CALL( - zelLoaderTranslateHandle, - (ZEL_HANDLE_KERNEL, NewKernel->ZeKernel, (void **)&ZeKernelTranslated)); + ze_kernel_handle_t KernelHandle{}; + ze_kernel_handle_t ZeNewKernel{}; + UR_CALL(getZeKernel(ZeDevice, NewKernel, &ZeNewKernel)); + + ze_command_list_handle_t ZeCommandList = + CommandBuffer->ZeComputeCommandList; + KernelHandle = ZeNewKernel; + if (!Platform->ZeMutableCmdListExt.LoaderExtension) { + ZeCommandList = CommandBuffer->ZeComputeCommandListTranslated; + ZE2UR_CALL(zelLoaderTranslateHandle, + (ZEL_HANDLE_KERNEL, ZeNewKernel, (void **)&KernelHandle)); + } ZE2UR_CALL(Platform->ZeMutableCmdListExt .zexCommandListUpdateMutableCommandKernelsExp, - (CommandBuffer->ZeComputeCommandListTranslated, 1, - &Command->CommandId, &ZeKernelTranslated)); + (ZeCommandList, 1, &Command->CommandId, &KernelHandle)); // Set current kernel to be the new kernel Command->Kernel = NewKernel; } @@ -1905,10 +1941,13 @@ ur_result_t updateKernelCommand( // by the driver for the kernel. 
bool UpdateWGSize = NewLocalWorkSize == nullptr; + ze_kernel_handle_t ZeKernel{}; + UR_CALL(getZeKernel(ZeDevice, Command->Kernel, &ZeKernel)); + uint32_t WG[3]; - UR_CALL(calculateKernelWorkDimensions( - Command->Kernel->ZeKernel, CommandBuffer->Device, - ZeThreadGroupDimensions, WG, Dim, NewGlobalWorkSize, NewLocalWorkSize)); + UR_CALL(calculateKernelWorkDimensions(ZeKernel, CommandBuffer->Device, + ZeThreadGroupDimensions, WG, Dim, + NewGlobalWorkSize, NewLocalWorkSize)); auto MutableGroupCountDesc = std::make_unique>(); @@ -2056,9 +2095,15 @@ ur_result_t updateKernelCommand( MutableCommandDesc.pNext = NextDesc; MutableCommandDesc.flags = 0; + ze_command_list_handle_t ZeCommandList = + CommandBuffer->ZeComputeCommandListTranslated; + if (Platform->ZeMutableCmdListExt.LoaderExtension) { + ZeCommandList = CommandBuffer->ZeComputeCommandList; + } + ZE2UR_CALL( Platform->ZeMutableCmdListExt.zexCommandListUpdateMutableCommandsExp, - (CommandBuffer->ZeComputeCommandListTranslated, &MutableCommandDesc)); + (ZeCommandList, &MutableCommandDesc)); return UR_RESULT_SUCCESS; } diff --git a/source/adapters/level_zero/common.cpp b/source/adapters/level_zero/common.cpp index 3b3f59e055..e13afc179f 100644 --- a/source/adapters/level_zero/common.cpp +++ b/source/adapters/level_zero/common.cpp @@ -88,7 +88,7 @@ ZeUSMImportExtension ZeUSMImport; std::map *ZeCallCount = nullptr; -inline void zeParseError(ze_result_t ZeError, const char *&ErrorString) { +void zeParseError(ze_result_t ZeError, const char *&ErrorString) { switch (ZeError) { #define ZE_ERRCASE(ERR) \ case ERR: \ diff --git a/source/adapters/level_zero/common.hpp b/source/adapters/level_zero/common.hpp index 8a93993752..09d144df82 100644 --- a/source/adapters/level_zero/common.hpp +++ b/source/adapters/level_zero/common.hpp @@ -340,6 +340,9 @@ bool setEnvVar(const char *name, const char *value); // Map Level Zero runtime error code to UR error code. ur_result_t ze2urResult(ze_result_t ZeResult); +// Parse Level Zero error code and return the error string. +void zeParseError(ze_result_t ZeError, const char *&ErrorString); + // Trace a call to Level-Zero RT #define ZE2UR_CALL(ZeName, ZeArgs) \ { \ diff --git a/source/adapters/level_zero/context.cpp b/source/adapters/level_zero/context.cpp index faa16d48dd..08f4762b6d 100644 --- a/source/adapters/level_zero/context.cpp +++ b/source/adapters/level_zero/context.cpp @@ -422,6 +422,7 @@ ur_result_t ur_context_handle_t_::finalize() { for (auto &EventCache : EventCaches) { for (auto &Event : EventCache) { auto ZeResult = ZE_CALL_NOCHECK(zeEventDestroy, (Event->ZeEvent)); + Event->ZeEvent = nullptr; // Gracefully handle the case that L0 was already unloaded. 
if (ZeResult && ZeResult != ZE_RESULT_ERROR_UNINITIALIZED) return ze2urResult(ZeResult); @@ -532,6 +533,13 @@ ur_result_t ur_context_handle_t_::getFreeSlotInExistingOrNewPool( if (*ZePool == nullptr) { ze_event_pool_counter_based_exp_desc_t counterBasedExt = { ZE_STRUCTURE_TYPE_COUNTER_BASED_EVENT_POOL_EXP_DESC, nullptr, 0}; + + ze_intel_event_sync_mode_exp_desc_t eventSyncMode = { + ZE_INTEL_STRUCTURE_TYPE_EVENT_SYNC_MODE_EXP_DESC, nullptr, 0}; + eventSyncMode.syncModeFlags = + ZE_INTEL_EVENT_SYNC_MODE_EXP_FLAG_LOW_POWER_WAIT | + ZE_INTEL_EVENT_SYNC_MODE_EXP_FLAG_SIGNAL_INTERRUPT; + ZeStruct ZeEventPoolDesc; ZeEventPoolDesc.count = MaxNumEventsPerPool; ZeEventPoolDesc.flags = 0; @@ -551,14 +559,11 @@ ur_result_t ur_context_handle_t_::getFreeSlotInExistingOrNewPool( } logger::debug("ze_event_pool_desc_t counter based flags set to: {}", counterBasedExt.flags); + if (InterruptBasedEventEnabled) { + counterBasedExt.pNext = &eventSyncMode; + } ZeEventPoolDesc.pNext = &counterBasedExt; - } - if (InterruptBasedEventEnabled) { - ze_intel_event_sync_mode_exp_desc_t eventSyncMode = { - ZE_INTEL_STRUCTURE_TYPE_EVENT_SYNC_MODE_EXP_DESC, nullptr, 0}; - eventSyncMode.syncModeFlags = - ZE_INTEL_EVENT_SYNC_MODE_EXP_FLAG_LOW_POWER_WAIT | - ZE_INTEL_EVENT_SYNC_MODE_EXP_FLAG_SIGNAL_INTERRUPT; + } else if (InterruptBasedEventEnabled) { ZeEventPoolDesc.pNext = &eventSyncMode; } diff --git a/source/adapters/level_zero/device.cpp b/source/adapters/level_zero/device.cpp index b7422fe2cc..f5f767ea81 100644 --- a/source/adapters/level_zero/device.cpp +++ b/source/adapters/level_zero/device.cpp @@ -654,9 +654,15 @@ ur_result_t urDeviceGetInfo( return ReturnValue(Device->ZeDeviceProperties->physicalEUSimdWidth / 4); case UR_DEVICE_INFO_NATIVE_VECTOR_WIDTH_DOUBLE: case UR_DEVICE_INFO_PREFERRED_VECTOR_WIDTH_DOUBLE: + // Must return 0 for *vector_width_double* if the device does not have fp64. + if (!(Device->ZeDeviceModuleProperties->flags & ZE_DEVICE_MODULE_FLAG_FP64)) + return ReturnValue(uint32_t{0}); return ReturnValue(Device->ZeDeviceProperties->physicalEUSimdWidth / 8); case UR_DEVICE_INFO_NATIVE_VECTOR_WIDTH_HALF: case UR_DEVICE_INFO_PREFERRED_VECTOR_WIDTH_HALF: + // Must return 0 for *vector_width_half* if the device does not have fp16. + if (!(Device->ZeDeviceModuleProperties->flags & ZE_DEVICE_MODULE_FLAG_FP16)) + return ReturnValue(uint32_t{0}); return ReturnValue(Device->ZeDeviceProperties->physicalEUSimdWidth / 2); case UR_DEVICE_INFO_MAX_NUM_SUB_GROUPS: { // Max_num_sub_Groups = maxTotalGroupSize/min(set of subGroupSizes); @@ -1484,12 +1490,23 @@ ur_device_handle_t_::useImmediateCommandLists() { bool isDG2OrNewer = this->isIntelDG2OrNewer(); bool isDG2SupportedDriver = this->Platform->isDriverVersionNewerOrSimilar(1, 5, 30820); - if ((isDG2SupportedDriver && isDG2OrNewer) || isPVC()) { + // Disable immediate command lists for DG2 devices on Windows due to driver + // limitations. 
+ bool isLinux = true; +#ifdef _WIN32 + isLinux = false; +#endif + if ((isDG2SupportedDriver && isDG2OrNewer && isLinux) || isPVC() || + isNewerThanIntelDG2()) { return PerQueue; } else { return NotUsed; } } + + logger::info("NOTE: L0 Immediate CommandList Setting: {}", + ImmediateCommandlistsSetting); + switch (ImmediateCommandlistsSetting) { case 0: return NotUsed; diff --git a/source/adapters/level_zero/device.hpp b/source/adapters/level_zero/device.hpp index fb4c519c34..d8f9082af0 100644 --- a/source/adapters/level_zero/device.hpp +++ b/source/adapters/level_zero/device.hpp @@ -196,6 +196,11 @@ struct ur_device_handle_t_ : _ur_object { ZeDeviceIpVersionExt->ipVersion >= 0x030dc000); } + bool isNewerThanIntelDG2() { + return (ZeDeviceProperties->vendorId == 0x8086 && + ZeDeviceIpVersionExt->ipVersion >= 0x030f0000); + } + bool isIntegrated() { return (ZeDeviceProperties->flags & ZE_DEVICE_PROPERTY_FLAG_INTEGRATED); } diff --git a/source/adapters/level_zero/event.cpp b/source/adapters/level_zero/event.cpp index c1e93483b8..32153689bd 100644 --- a/source/adapters/level_zero/event.cpp +++ b/source/adapters/level_zero/event.cpp @@ -791,7 +791,7 @@ urEventWait(uint32_t NumEvents, ///< [in] number of events in the event list // ur_event_handle_t_ *Event = ur_cast(e); if (!Event->hasExternalRefs()) - die("urEventsWait must not be called for an internal event"); + die("urEventWait must not be called for an internal event"); ze_event_handle_t ZeHostVisibleEvent; if (auto Res = Event->getOrCreateHostVisibleEvent(ZeHostVisibleEvent)) @@ -881,7 +881,14 @@ ur_result_t urEventRelease(ur_event_handle_t Event ///< [in] handle of the event object ) { Event->RefCountExternal--; + bool isEventsWaitCompleted = + Event->CommandType == UR_COMMAND_EVENTS_WAIT && Event->Completed; UR_CALL(urEventReleaseInternal(Event)); + // If this is a Completed Event Wait Out Event, then we need to cleanup the + // event at user release and not at the time of completion. + if (isEventsWaitCompleted) { + UR_CALL(CleanupCompletedEvent((Event), false, false)); + } return UR_RESULT_SUCCESS; } @@ -955,7 +962,6 @@ ur_result_t urEventCreateWithNativeHandle( UREvent = new ur_event_handle_t_(ZeEvent, nullptr /* ZeEventPool */, Context, UR_EXT_COMMAND_TYPE_USER, Properties->isNativeHandleOwned); - UREvent->RefCountExternal++; } catch (const std::bad_alloc &) { @@ -1048,6 +1054,26 @@ ur_result_t ur_event_handle_t_::getOrCreateHostVisibleEvent( return UR_RESULT_SUCCESS; } +/** + * @brief Destructor for the ur_event_handle_t_ class. + * + * This destructor is responsible for cleaning up the event handle when the + * object is destroyed. It checks if the event (`ZeEvent`) is valid and if the + * event has been completed (`Completed`). If both conditions are met, it + * further checks if the associated queue (`UrQueue`) is valid and if it is not + * set to discard events. If all conditions are satisfied, it calls + * `zeEventDestroy` to destroy the event. + * + * This ensures that resources are properly released and avoids potential memory + * leaks or resource mismanagement. 
+ */ +ur_event_handle_t_::~ur_event_handle_t_() { + if (this->ZeEvent && this->Completed) { + if (this->UrQueue && !this->UrQueue->isDiscardEvents()) + ZE_CALL_NOCHECK(zeEventDestroy, (this->ZeEvent)); + } +} + ur_result_t urEventReleaseInternal(ur_event_handle_t Event) { if (!Event->RefCount.decrementAndTest()) return UR_RESULT_SUCCESS; @@ -1070,6 +1096,7 @@ ur_result_t urEventReleaseInternal(ur_event_handle_t Event) { if (Event->OwnNativeHandle) { if (DisableEventsCaching) { auto ZeResult = ZE_CALL_NOCHECK(zeEventDestroy, (Event->ZeEvent)); + Event->ZeEvent = nullptr; // Gracefully handle the case that L0 was already unloaded. if (ZeResult && ZeResult != ZE_RESULT_ERROR_UNINITIALIZED) return ze2urResult(ZeResult); diff --git a/source/adapters/level_zero/event.hpp b/source/adapters/level_zero/event.hpp index de018e7060..efae32f361 100644 --- a/source/adapters/level_zero/event.hpp +++ b/source/adapters/level_zero/event.hpp @@ -156,6 +156,8 @@ struct ur_event_handle_t_ : _ur_object { reinterpret_cast(HostVisibleEvent)); } + ~ur_event_handle_t_(); + // Provide direct access to Context, instead of going via queue. // Not every PI event has a queue, and we need a handle to Context // to get to event pool related information. diff --git a/source/adapters/level_zero/image.cpp b/source/adapters/level_zero/image.cpp index 8437fcff95..8c205f54c5 100644 --- a/source/adapters/level_zero/image.cpp +++ b/source/adapters/level_zero/image.cpp @@ -265,6 +265,16 @@ ur_result_t ze2urImageFormat(const ze_image_desc_t *ZeImageDesc, return UR_RESULT_SUCCESS; } +static bool Is3ChannelOrder(ur_image_channel_order_t ChannelOrder) { + switch (ChannelOrder) { + case UR_IMAGE_CHANNEL_ORDER_RGB: + case UR_IMAGE_CHANNEL_ORDER_RGX: + return true; + default: + return false; + } +} + /// Construct ZE image desc from UR image format and desc. 
ur_result_t ur2zeImageDesc(const ur_image_format_t *ImageFormat, const ur_image_desc_t *ImageDesc, @@ -843,6 +853,14 @@ ur_result_t urBindlessImagesImageCopyExp( UR_CALL(ur2zeImageDesc(pSrcImageFormat, pSrcImageDesc, ZeImageDesc)); bool UseCopyEngine = hQueue->useCopyEngine(/*PreferCopyEngine*/ true); + // Due to the limitation of the copy engine, disable usage of Copy Engine + // Given 3 channel image + if (Is3ChannelOrder( + ur_cast(pSrcImageFormat->channelOrder)) || + Is3ChannelOrder( + ur_cast(pDstImageFormat->channelOrder))) { + UseCopyEngine = false; + } _ur_ze_event_list_t TmpWaitList; UR_CALL(TmpWaitList.createAndRetainUrZeEventList( @@ -1237,7 +1255,7 @@ ur_result_t urBindlessImagesImportExternalSemaphoreExp( } ZE2UR_CALL(UrPlatform->ZeExternalSemaphoreExt.zexImportExternalSemaphoreExp, - (hDevice->ZeDevice, &ExtSemaphoreHandle, &SemDesc)); + (hDevice->ZeDevice, &SemDesc, &ExtSemaphoreHandle)); *phExternalSemaphoreHandle = (ur_exp_external_semaphore_handle_t)ExtSemaphoreHandle; @@ -1310,7 +1328,7 @@ ur_result_t urBindlessImagesWaitExternalSemaphoreExp( reinterpret_cast(hSemaphore); ZE2UR_CALL(UrPlatform->ZeExternalSemaphoreExt .zexCommandListAppendWaitExternalSemaphoresExp, - (ZeCommandList, &hExtSemaphore, &WaitParams, 1, ZeEvent, + (ZeCommandList, 1, &hExtSemaphore, &WaitParams, ZeEvent, WaitList.Length, WaitList.ZeEventList)); return UR_RESULT_SUCCESS; @@ -1373,7 +1391,7 @@ ur_result_t urBindlessImagesSignalExternalSemaphoreExp( ZE2UR_CALL(UrPlatform->ZeExternalSemaphoreExt .zexCommandListAppendSignalExternalSemaphoresExp, - (ZeCommandList, &hExtSemaphore, &SignalParams, 1, ZeEvent, + (ZeCommandList, 1, &hExtSemaphore, &SignalParams, ZeEvent, WaitList.Length, WaitList.ZeEventList)); return UR_RESULT_SUCCESS; diff --git a/source/adapters/level_zero/memory.cpp b/source/adapters/level_zero/memory.cpp index 5283ea4da3..cf3a3197c4 100644 --- a/source/adapters/level_zero/memory.cpp +++ b/source/adapters/level_zero/memory.cpp @@ -2368,6 +2368,11 @@ _ur_buffer::_ur_buffer(ur_context_handle_t Context, size_t Size, LastDeviceWithValidAllocation = Device; } +_ur_buffer::~_ur_buffer() { + if (isSubBuffer()) + ur::level_zero::urMemRelease(SubBuffer->Parent); +} + ur_result_t _ur_buffer::getZeHandlePtr(char **&ZeHandlePtr, access_mode_t AccessMode, ur_device_handle_t Device, diff --git a/source/adapters/level_zero/memory.hpp b/source/adapters/level_zero/memory.hpp index c2e653b297..652f3e3363 100644 --- a/source/adapters/level_zero/memory.hpp +++ b/source/adapters/level_zero/memory.hpp @@ -116,6 +116,8 @@ struct _ur_buffer final : ur_mem_handle_t_ { Parent->RefCount.increment(); } + ~_ur_buffer(); + // Interop-buffer constructor _ur_buffer(ur_context_handle_t Context, size_t Size, ur_device_handle_t Device, char *ZeMemHandle, bool OwnZeMemHandle); diff --git a/source/adapters/level_zero/platform.cpp b/source/adapters/level_zero/platform.cpp index 2bfc9302db..5e093aa646 100644 --- a/source/adapters/level_zero/platform.cpp +++ b/source/adapters/level_zero/platform.cpp @@ -222,6 +222,7 @@ ur_result_t ur_platform_handle_t_::initialize() { bool MutableCommandListSpecExtensionSupported = false; bool ZeIntelExternalSemaphoreExtensionSupported = false; + bool ZeImmediateCommandListAppendExtensionFound = false; for (auto &extension : ZeExtensions) { // Check if global offset extension is available if (strncmp(extension.name, ZE_GLOBAL_OFFSET_EXP_NAME, @@ -246,6 +247,14 @@ ur_result_t ur_platform_handle_t_::initialize() { ZeDriverEventPoolCountingEventsExtensionFound = true; } } + // Check if the 
ImmediateAppendCommandLists extension is available. + if (strncmp(extension.name, ZE_IMMEDIATE_COMMAND_LIST_APPEND_EXP_NAME, + strlen(ZE_IMMEDIATE_COMMAND_LIST_APPEND_EXP_NAME) + 1) == 0) { + if (extension.version == + ZE_IMMEDIATE_COMMAND_LIST_APPEND_EXP_VERSION_CURRENT) { + ZeImmediateCommandListAppendExtensionFound = true; + } + } // Check if extension is available for Mutable Command List v1.1. if (strncmp(extension.name, ZE_MUTABLE_COMMAND_LIST_EXP_NAME, strlen(ZE_MUTABLE_COMMAND_LIST_EXP_NAME) + 1) == 0) { @@ -375,6 +384,7 @@ ur_result_t ur_platform_handle_t_::initialize() { ZeMutableCmdListExt.Supported |= ZeMutableCmdListExt.zexCommandListGetNextCommandIdWithKernelsExp != nullptr; + ZeMutableCmdListExt.LoaderExtension = true; } else { ZeMutableCmdListExt.Supported |= (ZE_CALL_NOCHECK( @@ -425,6 +435,21 @@ ur_result_t ur_platform_handle_t_::initialize() { &ZeMutableCmdListExt .zexCommandListGetNextCommandIdWithKernelsExp))) == 0); } + + // Check if ImmediateAppendCommandList is supported and initialize the + // function pointer. + if (ZeImmediateCommandListAppendExtensionFound) { + ZeCommandListImmediateAppendExt + .zeCommandListImmediateAppendCommandListsExp = + (ze_pfnCommandListImmediateAppendCommandListsExp_t) + ur_loader::LibLoader::getFunctionPtr( + GlobalAdapter->processHandle, + "zeCommandListImmediateAppendCommandListsExp"); + ZeCommandListImmediateAppendExt.Supported = + ZeCommandListImmediateAppendExt + .zeCommandListImmediateAppendCommandListsExp != nullptr; + } + return UR_RESULT_SUCCESS; } diff --git a/source/adapters/level_zero/platform.hpp b/source/adapters/level_zero/platform.hpp index 4b613fb1e5..1381f51bca 100644 --- a/source/adapters/level_zero/platform.hpp +++ b/source/adapters/level_zero/platform.hpp @@ -96,6 +96,12 @@ struct ur_platform_handle_t_ : public _ur_platform { // associated with particular Level Zero driver, store this extension here. struct ZeMutableCmdListExtension { bool Supported = false; + // If LoaderExtension is true, the L0 loader is aware of the MCL extension. + // If it is false, the extension has to be loaded directly from the driver + // using zeDriverGetExtensionFunctionAddress. If it is loaded directly from + // the driver, any handles passed to it must be translated using + // zelLoaderTranslateHandle. 
+ bool LoaderExtension = false; ze_result_t (*zexCommandListGetNextCommandIdExp)( ze_command_list_handle_t, const ze_mutable_command_id_exp_desc_t *, uint64_t *) = nullptr; @@ -119,19 +125,26 @@ struct ur_platform_handle_t_ : public _ur_platform { struct ZeExternalSemaphoreExtension { bool Supported = false; ze_result_t (*zexImportExternalSemaphoreExp)( - ze_device_handle_t, ze_intel_external_semaphore_exp_handle_t *, - const ze_intel_external_semaphore_exp_desc_t *); + ze_device_handle_t, const ze_intel_external_semaphore_exp_desc_t *, + ze_intel_external_semaphore_exp_handle_t *); ze_result_t (*zexCommandListAppendWaitExternalSemaphoresExp)( - ze_command_list_handle_t, + ze_command_list_handle_t, unsigned int, const ze_intel_external_semaphore_exp_handle_t *, - const ze_intel_external_semaphore_wait_exp_params_t *, unsigned int, + const ze_intel_external_semaphore_wait_exp_params_t *, ze_event_handle_t, uint32_t, ze_event_handle_t *); ze_result_t (*zexCommandListAppendSignalExternalSemaphoresExp)( - ze_command_list_handle_t, + ze_command_list_handle_t, size_t, const ze_intel_external_semaphore_exp_handle_t *, - const ze_intel_external_semaphore_signal_exp_params_t *, size_t, + const ze_intel_external_semaphore_signal_exp_params_t *, ze_event_handle_t, uint32_t, ze_event_handle_t *); ze_result_t (*zexDeviceReleaseExternalSemaphoreExp)( ze_intel_external_semaphore_exp_handle_t); } ZeExternalSemaphoreExt; -}; \ No newline at end of file + + struct ZeCommandListImmediateAppendExtension { + bool Supported = false; + ze_result_t (*zeCommandListImmediateAppendCommandListsExp)( + ze_command_list_handle_t, uint32_t, ze_command_list_handle_t *, + ze_event_handle_t, uint32_t, ze_event_handle_t *); + } ZeCommandListImmediateAppendExt; +}; diff --git a/source/adapters/level_zero/program.cpp b/source/adapters/level_zero/program.cpp index be8c366d6b..b5a64c3eda 100644 --- a/source/adapters/level_zero/program.cpp +++ b/source/adapters/level_zero/program.cpp @@ -452,11 +452,9 @@ ur_result_t urProgramLinkExp( // Build flags may be different for different devices, so handle them // here. Clear values of the previous device first. 
       BuildFlagPtrs.clear();
-      std::vector TemporaryOptionsStrings;
       for (uint32_t I = 0; I < count; I++) {
-        TemporaryOptionsStrings.push_back(
-            phPrograms[I]->getBuildOptions(ZeDevice));
-        BuildFlagPtrs.push_back(TemporaryOptionsStrings.back().c_str());
+        BuildFlagPtrs.push_back(
+            phPrograms[I]->getBuildOptions(ZeDevice).c_str());
       }
       ZeExtModuleDesc.pBuildFlags = BuildFlagPtrs.data();
       if (count == 1)
diff --git a/source/adapters/level_zero/program.hpp b/source/adapters/level_zero/program.hpp
index 4fe8c24acd..90b297fa40 100644
--- a/source/adapters/level_zero/program.hpp
+++ b/source/adapters/level_zero/program.hpp
@@ -169,7 +169,7 @@ struct ur_program_handle_t_ : _ur_object {
     DeviceDataMap[ZeDevice].BuildFlags += Options;
   }
 
-  std::string getBuildOptions(ze_device_handle_t ZeDevice) {
+  std::string &getBuildOptions(ze_device_handle_t ZeDevice) {
     return DeviceDataMap[ZeDevice].BuildFlags;
   }
 
diff --git a/source/common/logger/ur_logger.hpp b/source/common/logger/ur_logger.hpp
index c4dc655444..786bd32a00 100644
--- a/source/common/logger/ur_logger.hpp
+++ b/source/common/logger/ur_logger.hpp
@@ -118,16 +118,15 @@ inline Logger create_logger(std::string logger_name, bool skip_prefix,
                             logger::Level default_log_level) {
     std::transform(logger_name.begin(), logger_name.end(), logger_name.begin(),
                    ::toupper);
-    std::stringstream env_var_name;
     const auto default_flush_level = logger::Level::ERR;
     const std::string default_output = "stderr";
     auto level = default_log_level;
     auto flush_level = default_flush_level;
     std::unique_ptr sink;
-    env_var_name << "UR_LOG_" << logger_name;
+    auto env_var_name = "UR_LOG_" + logger_name;
     try {
-        auto map = getenv_to_map(env_var_name.str().c_str());
+        auto map = getenv_to_map(env_var_name.c_str());
         if (!map.has_value()) {
             return Logger(
                 default_log_level,
@@ -173,7 +172,7 @@ inline Logger create_logger(std::string logger_name, bool skip_prefix,
                            skip_linebreak);
     } catch (const std::invalid_argument &e) {
         std::cerr << "Error when creating a logger instance from the '"
                   << env_var_name.str() << "' environment variable:\n"
+                  << env_var_name << "' environment variable:\n"
                   << e.what() << std::endl;
         return Logger(default_log_level,
                       std::make_unique(
diff --git a/source/loader/CMakeLists.txt b/source/loader/CMakeLists.txt
index d8f6056ae9..931c9dd3ed 100644
--- a/source/loader/CMakeLists.txt
+++ b/source/loader/CMakeLists.txt
@@ -136,6 +136,7 @@ if(UR_ENABLE_SANITIZER)
         ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_buffer.cpp
         ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_buffer.hpp
         ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_ddi.cpp
+        ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_ddi.hpp
         ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_interceptor.cpp
         ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_interceptor.hpp
         ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_libdevice.hpp
@@ -207,7 +208,8 @@ if(UR_ENABLE_SANITIZER)
         if(NOT EXISTS ${LIBCXX_PATH} OR NOT EXISTS ${LIBCXX_ABI_PATH})
             message(FATAL_ERROR "libc++ is required but can't find the libraries")
         endif()
-        target_link_libraries(ur_loader PRIVATE ${LIBCXX_PATH} ${LIBCXX_ABI_PATH})
+        # Link with gcc_s first to avoid some symbols resolving to libc++/libc++abi/libunwind's versions
+        target_link_libraries(ur_loader PRIVATE gcc_s ${LIBCXX_PATH} ${LIBCXX_ABI_PATH})
     endif()
 endif()
diff --git a/source/loader/layers/sanitizer/asan/asan_ddi.cpp b/source/loader/layers/sanitizer/asan/asan_ddi.cpp
index f8ded3ec7a..9bddd82782 100644
--- a/source/loader/layers/sanitizer/asan/asan_ddi.cpp
+++ b/source/loader/layers/sanitizer/asan/asan_ddi.cpp
@@ -28,39 +28,40 @@ ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices,
                          const ur_device_handle_t *phDevices) {
     std::shared_ptr CI;
     UR_CALL(getAsanInterceptor()->insertContext(Context, CI));
-    for (uint32_t i = 0; i < numDevices; ++i) {
-        auto hDevice = phDevices[i];
-        std::shared_ptr DI;
-        UR_CALL(getAsanInterceptor()->insertDevice(hDevice, DI));
-        DI->Type = GetDeviceType(Context, hDevice);
-        if (DI->Type == DeviceType::UNKNOWN) {
-            getContext()->logger.error("Unsupport device");
-            return UR_RESULT_ERROR_INVALID_DEVICE;
-        }
-        getContext()->logger.info(
-            "DeviceInfo {} (Type={}, IsSupportSharedSystemUSM={})",
-            (void *)DI->Handle, ToString(DI->Type),
-            DI->IsSupportSharedSystemUSM);
-        getContext()->logger.info("Add {} into context {}", (void *)DI->Handle,
-                                  (void *)Context);
-        if (!DI->Shadow) {
-            UR_CALL(DI->allocShadowMemory(Context));
+
+    if (numDevices > 0) {
+        auto DeviceType = GetDeviceType(Context, phDevices[0]);
+        auto ShadowMemory = getAsanInterceptor()->getOrCreateShadowMemory(
+            phDevices[0], DeviceType);
+
+        for (uint32_t i = 0; i < numDevices; ++i) {
+            auto hDevice = phDevices[i];
+            std::shared_ptr DI;
+            UR_CALL(getAsanInterceptor()->insertDevice(hDevice, DI));
+            DI->Type = GetDeviceType(Context, hDevice);
+            if (DI->Type == DeviceType::UNKNOWN) {
+                getContext()->logger.error("Unsupported device");
+                return UR_RESULT_ERROR_INVALID_DEVICE;
+            }
+            if (DI->Type != DeviceType) {
+                getContext()->logger.error(
+                    "Different device types in the same context");
+                return UR_RESULT_ERROR_INVALID_DEVICE;
+            }
+            getContext()->logger.info(
+                "DeviceInfo {} (Type={}, IsSupportSharedSystemUSM={})",
+                (void *)DI->Handle, ToString(DI->Type),
+                DI->IsSupportSharedSystemUSM);
+            getContext()->logger.info("Add {} into context {}",
+                                      (void *)DI->Handle, (void *)Context);
+            DI->Shadow = ShadowMemory;
+            CI->DeviceList.emplace_back(hDevice);
+            CI->AllocInfosMap[hDevice];
         }
-        CI->DeviceList.emplace_back(hDevice);
-        CI->AllocInfosMap[hDevice];
     }
 
     return UR_RESULT_SUCCESS;
 }
 
-bool isInstrumentedKernel(ur_kernel_handle_t hKernel) {
-    auto hProgram = GetProgram(hKernel);
-    auto PI = getAsanInterceptor()->getProgramInfo(hProgram);
-    if (PI == nullptr) {
-        return false;
-    }
-    return PI->isKernelInstrumented(hKernel);
-}
-
 } // namespace
 
 ///////////////////////////////////////////////////////////////////////////////
@@ -462,6 +463,12 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueKernelLaunch(
     phEvent ///< [out][optional] return an event object that identifies this particular
             ///< kernel execution instance.
 ) {
+    // This mutex prevents concurrent kernel launches across different queues,
+    // as the DeviceASAN local/private shadow memory does not yet support
+    // concurrent kernel launches.
+ std::scoped_lock Guard( + getAsanInterceptor()->KernelLaunchMutex); + auto pfnKernelLaunch = getContext()->urDdiTable.Enqueue.pfnKernelLaunch; if (nullptr == pfnKernelLaunch) { @@ -470,15 +477,10 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueKernelLaunch( getContext()->logger.debug("==== urEnqueueKernelLaunch"); - if (!isInstrumentedKernel(hKernel)) { - return pfnKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset, - pGlobalWorkSize, pLocalWorkSize, - numEventsInWaitList, phEventWaitList, phEvent); - } - LaunchInfo LaunchInfo(GetContext(hQueue), GetDevice(hQueue), pGlobalWorkSize, pLocalWorkSize, pGlobalWorkOffset, workDim); + UR_CALL(LaunchInfo.Data.syncToDevice(hQueue)); UR_CALL(getAsanInterceptor()->preLaunchKernel(hKernel, hQueue, LaunchInfo)); @@ -1349,30 +1351,6 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemUnmap( return UR_RESULT_SUCCESS; } -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urKernelCreate -__urdlllocal ur_result_t UR_APICALL urKernelCreate( - ur_program_handle_t hProgram, ///< [in] handle of the program instance - const char *pKernelName, ///< [in] pointer to null-terminated string. - ur_kernel_handle_t - *phKernel ///< [out] pointer to handle of kernel object created. -) { - auto pfnCreate = getContext()->urDdiTable.Kernel.pfnCreate; - - if (nullptr == pfnCreate) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urKernelCreate"); - - UR_CALL(pfnCreate(hProgram, pKernelName, phKernel)); - if (isInstrumentedKernel(*phKernel)) { - UR_CALL(getAsanInterceptor()->insertKernel(*phKernel)); - } - - return UR_RESULT_SUCCESS; -} - /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urKernelRetain __urdlllocal ur_result_t UR_APICALL urKernelRetain( @@ -1388,10 +1366,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain( UR_CALL(pfnRetain(hKernel)); - auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel); - if (KernelInfo) { - KernelInfo->RefCount++; - } + auto &KernelInfo = getAsanInterceptor()->getOrCreateKernelInfo(hKernel); + KernelInfo.RefCount++; return UR_RESULT_SUCCESS; } @@ -1408,14 +1384,12 @@ __urdlllocal ur_result_t urKernelRelease( } getContext()->logger.debug("==== urKernelRelease"); - UR_CALL(pfnRelease(hKernel)); - auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel); - if (KernelInfo) { - if (--KernelInfo->RefCount == 0) { - UR_CALL(getAsanInterceptor()->eraseKernel(hKernel)); - } + auto &KernelInfo = getAsanInterceptor()->getOrCreateKernelInfo(hKernel); + if (--KernelInfo.RefCount == 0) { + UR_CALL(getAsanInterceptor()->eraseKernelInfo(hKernel)); } + UR_CALL(pfnRelease(hKernel)); return UR_RESULT_SUCCESS; } @@ -1440,13 +1414,12 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgValue( getContext()->logger.debug("==== urKernelSetArgValue"); std::shared_ptr MemBuffer; - std::shared_ptr KernelInfo; if (argSize == sizeof(ur_mem_handle_t) && (MemBuffer = getAsanInterceptor()->getMemBuffer( - *ur_cast(pArgValue))) && - (KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel))) { - std::scoped_lock Guard(KernelInfo->Mutex); - KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer); + *ur_cast(pArgValue)))) { + auto &KernelInfo = getAsanInterceptor()->getOrCreateKernelInfo(hKernel); + std::scoped_lock Guard(KernelInfo.Mutex); + KernelInfo.BufferArgs[argIndex] = std::move(MemBuffer); } else { UR_CALL( pfnSetArgValue(hKernel, argIndex, argSize, pProperties, 
pArgValue)); @@ -1473,11 +1446,10 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgMemObj( getContext()->logger.debug("==== urKernelSetArgMemObj"); std::shared_ptr MemBuffer; - std::shared_ptr KernelInfo; - if ((MemBuffer = getAsanInterceptor()->getMemBuffer(hArgValue)) && - (KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel))) { - std::scoped_lock Guard(KernelInfo->Mutex); - KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer); + if ((MemBuffer = getAsanInterceptor()->getMemBuffer(hArgValue))) { + auto &KernelInfo = getAsanInterceptor()->getOrCreateKernelInfo(hKernel); + std::scoped_lock Guard(KernelInfo.Mutex); + KernelInfo.BufferArgs[argIndex] = std::move(MemBuffer); } else { UR_CALL(pfnSetArgMemObj(hKernel, argIndex, pProperties, hArgValue)); } @@ -1505,12 +1477,13 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgLocal( "==== urKernelSetArgLocal (argIndex={}, argSize={})", argIndex, argSize); - if (auto KI = getAsanInterceptor()->getKernelInfo(hKernel)) { - std::scoped_lock Guard(KI->Mutex); + { + auto &KI = getAsanInterceptor()->getOrCreateKernelInfo(hKernel); + std::scoped_lock Guard(KI.Mutex); // TODO: get local variable alignment auto argSizeWithRZ = GetSizeAndRedzoneSizeForLocal( argSize, ASAN_SHADOW_GRANULARITY, ASAN_SHADOW_GRANULARITY); - KI->LocalArgs[argIndex] = LocalArgsInfo{argSize, argSizeWithRZ}; + KI.LocalArgs[argIndex] = LocalArgsInfo{argSize, argSizeWithRZ}; argSize = argSizeWithRZ; } @@ -1542,10 +1515,10 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgPointer( pArgValue); std::shared_ptr KI; - if (getAsanInterceptor()->getOptions().DetectKernelArguments && - (KI = getAsanInterceptor()->getKernelInfo(hKernel))) { - std::scoped_lock Guard(KI->Mutex); - KI->PointerArgs[argIndex] = {pArgValue, GetCurrentBacktrace()}; + if (getAsanInterceptor()->getOptions().DetectKernelArguments) { + auto &KI = getAsanInterceptor()->getOrCreateKernelInfo(hKernel); + std::scoped_lock Guard(KI.Mutex); + KI.PointerArgs[argIndex] = {pArgValue, GetCurrentBacktrace()}; } ur_result_t result = @@ -1729,7 +1702,6 @@ __urdlllocal ur_result_t UR_APICALL urGetKernelProcAddrTable( ur_result_t result = UR_RESULT_SUCCESS; - pDdiTable->pfnCreate = ur_sanitizer_layer::asan::urKernelCreate; pDdiTable->pfnRetain = ur_sanitizer_layer::asan::urKernelRetain; pDdiTable->pfnRelease = ur_sanitizer_layer::asan::urKernelRelease; pDdiTable->pfnSetArgValue = ur_sanitizer_layer::asan::urKernelSetArgValue; diff --git a/source/loader/layers/sanitizer/asan/asan_interceptor.cpp b/source/loader/layers/sanitizer/asan/asan_interceptor.cpp index 19af8546c2..4c1fb2033e 100644 --- a/source/loader/layers/sanitizer/asan/asan_interceptor.cpp +++ b/source/loader/layers/sanitizer/asan/asan_interceptor.cpp @@ -36,8 +36,7 @@ AsanInterceptor::~AsanInterceptor() { // We must release these objects before releasing adapters, since // they may use the adapter in their destructor for (const auto &[_, DeviceInfo] : m_DeviceMap) { - [[maybe_unused]] auto URes = DeviceInfo->Shadow->Destory(); - assert(URes == UR_RESULT_SUCCESS); + DeviceInfo->Shadow = nullptr; } m_Quarantine = nullptr; @@ -48,6 +47,11 @@ AsanInterceptor::~AsanInterceptor() { // detection depends on it. 
m_AllocationMap.clear(); + for (auto &[_, ShadowMemory] : m_ShadowMap) { + ShadowMemory->Destory(); + getContext()->urDdiTable.Context.pfnRelease(ShadowMemory->Context); + } + for (auto Adapter : m_Adapters) { getContext()->urDdiTable.Global.pfnAdapterRelease(Adapter); } @@ -301,14 +305,24 @@ ur_result_t AsanInterceptor::postLaunchKernel(ur_kernel_handle_t Kernel, return Result; } -ur_result_t DeviceInfo::allocShadowMemory(ur_context_handle_t Context) { - Shadow = GetShadowMemory(Context, Handle, Type); - assert(Shadow && "Failed to get shadow memory"); - UR_CALL(Shadow->Setup()); - getContext()->logger.info("ShadowMemory(Global): {} - {}", - (void *)Shadow->ShadowBegin, - (void *)Shadow->ShadowEnd); - return UR_RESULT_SUCCESS; +std::shared_ptr +AsanInterceptor::getOrCreateShadowMemory(ur_device_handle_t Device, + DeviceType Type) { + std::scoped_lock Guard(m_ShadowMapMutex); + if (m_ShadowMap.find(Type) == m_ShadowMap.end()) { + ur_context_handle_t InternalContext; + auto Res = getContext()->urDdiTable.Context.pfnCreate( + 1, &Device, nullptr, &InternalContext); + if (Res != UR_RESULT_SUCCESS) { + getContext()->logger.error("Failed to create shadow context"); + return nullptr; + } + std::shared_ptr CI; + insertContext(InternalContext, CI); + m_ShadowMap[Type] = GetShadowMemory(InternalContext, Device, Type); + m_ShadowMap[Type]->Setup(); + } + return m_ShadowMap[Type]; } /// Each 8 bytes of application memory are mapped into one byte of shadow memory @@ -431,6 +445,12 @@ ur_result_t AsanInterceptor::unregisterProgram(ur_program_handle_t Program) { auto ProgramInfo = getProgramInfo(Program); assert(ProgramInfo != nullptr && "unregistered program!"); + std::scoped_lock Guard(m_AllocationMapMutex); + for (auto AI : ProgramInfo->AllocInfoForGlobals) { + m_AllocationMap.erase(AI->AllocBegin); + } + ProgramInfo->AllocInfoForGlobals.clear(); + ProgramInfo->InstrumentedKernels.clear(); return UR_RESULT_SUCCESS; @@ -549,6 +569,10 @@ AsanInterceptor::registerDeviceGlobals(ur_program_handle_t Program) { {}}); ContextInfo->insertAllocInfo({Device}, AI); + ProgramInfo->AllocInfoForGlobals.emplace(AI); + + std::scoped_lock Guard(m_AllocationMapMutex); + m_AllocationMap.emplace(AI->AllocBegin, std::move(AI)); } } @@ -629,16 +653,26 @@ ur_result_t AsanInterceptor::eraseProgram(ur_program_handle_t Program) { return UR_RESULT_SUCCESS; } -ur_result_t AsanInterceptor::insertKernel(ur_kernel_handle_t Kernel) { - std::scoped_lock Guard(m_KernelMapMutex); - if (m_KernelMap.find(Kernel) != m_KernelMap.end()) { - return UR_RESULT_SUCCESS; +KernelInfo &AsanInterceptor::getOrCreateKernelInfo(ur_kernel_handle_t Kernel) { + { + std::shared_lock Guard(m_KernelMapMutex); + if (m_KernelMap.find(Kernel) != m_KernelMap.end()) { + return *m_KernelMap[Kernel].get(); + } } - m_KernelMap.emplace(Kernel, std::make_shared(Kernel)); - return UR_RESULT_SUCCESS; + + // Create new KernelInfo + auto Program = GetProgram(Kernel); + auto PI = getProgramInfo(Program); + bool IsInstrumented = PI->isKernelInstrumented(Kernel); + + std::scoped_lock Guard(m_KernelMapMutex); + m_KernelMap.emplace(Kernel, + std::make_unique(Kernel, IsInstrumented)); + return *m_KernelMap[Kernel].get(); } -ur_result_t AsanInterceptor::eraseKernel(ur_kernel_handle_t Kernel) { +ur_result_t AsanInterceptor::eraseKernelInfo(ur_kernel_handle_t Kernel) { std::scoped_lock Guard(m_KernelMapMutex); assert(m_KernelMap.find(Kernel) != m_KernelMap.end()); m_KernelMap.erase(Kernel); @@ -675,13 +709,24 @@ ur_result_t AsanInterceptor::prepareLaunch( std::shared_ptr 
&ContextInfo, std::shared_ptr &DeviceInfo, ur_queue_handle_t Queue, ur_kernel_handle_t Kernel, LaunchInfo &LaunchInfo) { + auto &KernelInfo = getOrCreateKernelInfo(Kernel); + std::shared_lock Guard(KernelInfo.Mutex); - auto KernelInfo = getKernelInfo(Kernel); - assert(KernelInfo && "Kernel should be instrumented"); + auto ArgNums = GetKernelNumArgs(Kernel); + auto LocalMemoryUsage = + GetKernelLocalMemorySize(Kernel, DeviceInfo->Handle); + auto PrivateMemoryUsage = + GetKernelPrivateMemorySize(Kernel, DeviceInfo->Handle); + + getContext()->logger.info( + "KernelInfo {} (Name={}, ArgNums={}, IsInstrumented={}, " + "LocalMemory={}, PrivateMemory={})", + (void *)Kernel, GetKernelName(Kernel), ArgNums, + KernelInfo.IsInstrumented, LocalMemoryUsage, PrivateMemoryUsage); // Validate pointer arguments if (getOptions().DetectKernelArguments) { - for (const auto &[ArgIndex, PtrPair] : KernelInfo->PointerArgs) { + for (const auto &[ArgIndex, PtrPair] : KernelInfo.PointerArgs) { auto Ptr = PtrPair.first; if (Ptr == nullptr) { continue; @@ -690,13 +735,16 @@ ur_result_t AsanInterceptor::prepareLaunch( ContextInfo->Handle, DeviceInfo->Handle, (uptr)Ptr)) { ReportInvalidKernelArgument(Kernel, ArgIndex, (uptr)Ptr, ValidateResult, PtrPair.second); - exitWithErrors(); + if (ValidateResult.Type != + ValidateUSMResult::MAYBE_HOST_POINTER) { + exitWithErrors(); + } } } } // Set membuffer arguments - for (const auto &[ArgIndex, MemBuffer] : KernelInfo->BufferArgs) { + for (const auto &[ArgIndex, MemBuffer] : KernelInfo.BufferArgs) { char *ArgPointer = nullptr; UR_CALL(MemBuffer->getHandle(DeviceInfo->Handle, ArgPointer)); ur_result_t URes = getContext()->urDdiTable.Kernel.pfnSetArgPointer( @@ -709,11 +757,17 @@ ur_result_t AsanInterceptor::prepareLaunch( } } - auto ArgNums = GetKernelNumArgs(Kernel); + if (!KernelInfo.IsInstrumented) { + return UR_RESULT_SUCCESS; + } + // We must prepare all kernel args before call // urKernelGetSuggestedLocalWorkSize, otherwise the call will fail on // CPU device. - if (ArgNums) { + { + assert(ArgNums >= 1 && + "Sanitized Kernel should have at least one argument"); + ur_result_t URes = getContext()->urDdiTable.Kernel.pfnSetArgPointer( Kernel, ArgNums - 1, nullptr, LaunchInfo.Data.getDevicePtr()); if (URes != UR_RESULT_SUCCESS) { @@ -753,15 +807,6 @@ ur_result_t AsanInterceptor::prepareLaunch( LaunchInfo.Data.Host.DeviceTy = DeviceInfo->Type; LaunchInfo.Data.Host.Debug = getOptions().Debug ? 
1 : 0; - auto LocalMemoryUsage = - GetKernelLocalMemorySize(Kernel, DeviceInfo->Handle); - auto PrivateMemoryUsage = - GetKernelPrivateMemorySize(Kernel, DeviceInfo->Handle); - - getContext()->logger.info( - "KernelInfo {} (LocalMemory={}, PrivateMemory={})", (void *)Kernel, - LocalMemoryUsage, PrivateMemoryUsage); - // Write shadow memory offset for local memory if (getOptions().DetectLocals) { if (DeviceInfo->Shadow->AllocLocalShadow( @@ -807,9 +852,9 @@ ur_result_t AsanInterceptor::prepareLaunch( } // Write local arguments info - if (!KernelInfo->LocalArgs.empty()) { + if (!KernelInfo.LocalArgs.empty()) { std::vector LocalArgsInfo; - for (auto [ArgIndex, ArgInfo] : KernelInfo->LocalArgs) { + for (auto [ArgIndex, ArgInfo] : KernelInfo.LocalArgs) { LocalArgsInfo.push_back(ArgInfo); getContext()->logger.debug( "local_args (argIndex={}, size={}, sizeWithRZ={})", ArgIndex, @@ -821,10 +866,12 @@ ur_result_t AsanInterceptor::prepareLaunch( // sync asan runtime data to device side UR_CALL(LaunchInfo.Data.syncToDevice(Queue)); - getContext()->logger.debug("launch_info {} (numLocalArgs={}, localArgs={})", - (void *)LaunchInfo.Data.getDevicePtr(), - LaunchInfo.Data.Host.NumLocalArgs, - (void *)LaunchInfo.Data.Host.LocalArgs); + getContext()->logger.info( + "LaunchInfo {} (device={}, debug={}, numLocalArgs={}, localArgs={})", + (void *)LaunchInfo.Data.getDevicePtr(), + ToString(LaunchInfo.Data.Host.DeviceTy), LaunchInfo.Data.Host.Debug, + LaunchInfo.Data.Host.NumLocalArgs, + (void *)LaunchInfo.Data.Host.LocalArgs); return UR_RESULT_SUCCESS; } @@ -834,13 +881,15 @@ AsanInterceptor::findAllocInfoByAddress(uptr Address) { std::shared_lock Guard(m_AllocationMapMutex); auto It = m_AllocationMap.upper_bound(Address); if (It == m_AllocationMap.begin()) { - return std::optional{}; + return std::nullopt; } --It; - // Make sure we got the right AllocInfo - assert(Address >= It->second->AllocBegin && - Address < It->second->AllocBegin + It->second->AllocSize && - "Wrong AllocInfo for the address"); + + // Maybe it's a host pointer + if (Address < It->second->AllocBegin || + Address >= It->second->AllocBegin + It->second->AllocSize) { + return std::nullopt; + } return It; } diff --git a/source/loader/layers/sanitizer/asan/asan_interceptor.hpp b/source/loader/layers/sanitizer/asan/asan_interceptor.hpp index f1e80dae56..4324150b6f 100644 --- a/source/loader/layers/sanitizer/asan/asan_interceptor.hpp +++ b/source/loader/layers/sanitizer/asan/asan_interceptor.hpp @@ -56,8 +56,6 @@ struct DeviceInfo { // Device handles are special and alive in the whole process lifetime, // so we needn't retain&release here. 
explicit DeviceInfo(ur_device_handle_t Device) : Handle(Device) {} - - ur_result_t allocShadowMemory(ur_context_handle_t Context); }; struct QueueInfo { @@ -85,6 +83,9 @@ struct KernelInfo { ur_kernel_handle_t Handle; std::atomic RefCount = 1; + // sanitized kernel + bool IsInstrumented = false; + // lock this mutex if the following fields are accessed ur_shared_mutex Mutex; std::unordered_map> BufferArgs; @@ -94,7 +95,8 @@ struct KernelInfo { // Need to preserve the order of local arguments std::map LocalArgs; - explicit KernelInfo(ur_kernel_handle_t Kernel) : Handle(Kernel) { + explicit KernelInfo(ur_kernel_handle_t Kernel, bool IsInstrumented) + : Handle(Kernel), IsInstrumented(IsInstrumented) { [[maybe_unused]] auto Result = getContext()->urDdiTable.Kernel.pfnRetain(Kernel); assert(Result == UR_RESULT_SUCCESS); @@ -112,6 +114,7 @@ struct ProgramInfo { std::atomic RefCount = 1; // Program is built only once, so we don't need to lock it + std::unordered_set> AllocInfoForGlobals; std::unordered_set InstrumentedKernels; explicit ProgramInfo(ur_program_handle_t Program) : Handle(Program) { @@ -303,9 +306,6 @@ class AsanInterceptor { ur_result_t insertProgram(ur_program_handle_t Program); ur_result_t eraseProgram(ur_program_handle_t Program); - ur_result_t insertKernel(ur_kernel_handle_t Kernel); - ur_result_t eraseKernel(ur_kernel_handle_t Kernel); - ur_result_t insertMemBuffer(std::shared_ptr MemBuffer); ur_result_t eraseMemBuffer(ur_mem_handle_t MemHandle); std::shared_ptr getMemBuffer(ur_mem_handle_t MemHandle); @@ -345,13 +345,8 @@ class AsanInterceptor { return nullptr; } - std::shared_ptr getKernelInfo(ur_kernel_handle_t Kernel) { - std::shared_lock Guard(m_KernelMapMutex); - if (m_KernelMap.find(Kernel) != m_KernelMap.end()) { - return m_KernelMap[Kernel]; - } - return nullptr; - } + KernelInfo &getOrCreateKernelInfo(ur_kernel_handle_t Kernel); + ur_result_t eraseKernelInfo(ur_kernel_handle_t Kernel); const AsanOptions &getOptions() { return m_Options; } @@ -362,6 +357,11 @@ class AsanInterceptor { bool isNormalExit() { return m_NormalExit; } + std::shared_ptr + getOrCreateShadowMemory(ur_device_handle_t Device, DeviceType Type); + + ur_shared_mutex KernelLaunchMutex; + private: ur_result_t updateShadowMemory(std::shared_ptr &ContextInfo, std::shared_ptr &DeviceInfo, @@ -378,9 +378,6 @@ class AsanInterceptor { ur_kernel_handle_t Kernel, LaunchInfo &LaunchInfo); - ur_result_t allocShadowMemory(ur_context_handle_t Context, - std::shared_ptr &DeviceInfo); - ur_result_t registerDeviceGlobals(ur_program_handle_t Program); ur_result_t registerSpirKernels(ur_program_handle_t Program); @@ -398,7 +395,7 @@ class AsanInterceptor { m_ProgramMap; ur_shared_mutex m_ProgramMapMutex; - std::unordered_map> + std::unordered_map> m_KernelMap; ur_shared_mutex m_KernelMapMutex; @@ -416,6 +413,9 @@ class AsanInterceptor { ur_shared_mutex m_AdaptersMutex; bool m_NormalExit = true; + + std::unordered_map> m_ShadowMap; + ur_shared_mutex m_ShadowMapMutex; }; } // namespace asan diff --git a/source/loader/layers/sanitizer/asan/asan_libdevice.hpp b/source/loader/layers/sanitizer/asan/asan_libdevice.hpp index a2d5ecd6be..4c6aaaeac8 100644 --- a/source/loader/layers/sanitizer/asan/asan_libdevice.hpp +++ b/source/loader/layers/sanitizer/asan/asan_libdevice.hpp @@ -66,7 +66,7 @@ struct AsanRuntimeData { uint32_t Debug = 0; int ReportFlag = 0; - AsanErrorReport Report[ASAN_MAX_NUM_REPORTS]; + AsanErrorReport Report[ASAN_MAX_NUM_REPORTS] = {}; }; constexpr unsigned ASAN_SHADOW_SCALE = 4; diff --git
a/source/loader/layers/sanitizer/asan/asan_shadow.cpp b/source/loader/layers/sanitizer/asan/asan_shadow.cpp index de0679687b..145fd232c1 100644 --- a/source/loader/layers/sanitizer/asan/asan_shadow.cpp +++ b/source/loader/layers/sanitizer/asan/asan_shadow.cpp @@ -104,10 +104,14 @@ ur_result_t ShadowMemoryGPU::Setup() { // shadow memory for each context; this will cause an out-of-resource error when the user uses // multiple contexts. Therefore, we just create one shadow memory here. static ur_result_t Result = [this]() { - size_t ShadowSize = GetShadowSize(); + const size_t ShadowSize = GetShadowSize(); + // To reserve a very large amount of GPU virtual memory, the pStart param should be beyond + // the SVM range, so that the GFX driver will automatically switch to reservation on the GPU + // heap. + const void *StartAddress = (void *)(0x100'0000'0000'0000ULL); // TODO: Protect Bad Zone auto Result = getContext()->urDdiTable.VirtualMem.pfnReserve( - Context, nullptr, ShadowSize, (void **)&ShadowBegin); + Context, StartAddress, ShadowSize, (void **)&ShadowBegin); if (Result != UR_RESULT_SUCCESS) { getContext()->logger.error( "Shadow memory reserved failed with size {}: {}", diff --git a/source/loader/layers/sanitizer/msan/msan_buffer.cpp b/source/loader/layers/sanitizer/msan/msan_buffer.cpp index 66ebb10326..8c2080b3ac 100644 --- a/source/loader/layers/sanitizer/msan/msan_buffer.cpp +++ b/source/loader/layers/sanitizer/msan/msan_buffer.cpp @@ -48,22 +48,67 @@ ur_result_t EnqueueMemCopyRectHelper( char *DstOrigin = pDst + DstOffset.x + DstRowPitch * DstOffset.y + DstSlicePitch * DstOffset.z; + const bool IsDstDeviceUSM = getMsanInterceptor() + ->findAllocInfoByAddress((uptr)DstOrigin) + .has_value(); + const bool IsSrcDeviceUSM = getMsanInterceptor() + ->findAllocInfoByAddress((uptr)SrcOrigin) + .has_value(); + + ur_device_handle_t Device = GetDevice(Queue); + std::shared_ptr DeviceInfo = + getMsanInterceptor()->getDeviceInfo(Device); std::vector Events; - Events.reserve(Region.depth); + // For now, USM doesn't support the 3D memory copy operation, so we can only // emulate it by looping over the 2D memory copy function. for (size_t i = 0; i < Region.depth; i++) { ur_event_handle_t NewEvent{}; UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy2D( - Queue, Blocking, DstOrigin + (i * DstSlicePitch), DstRowPitch, + Queue, false, DstOrigin + (i * DstSlicePitch), DstRowPitch, SrcOrigin + (i * SrcSlicePitch), SrcRowPitch, Region.width, Region.height, NumEventsInWaitList, EventWaitList, &NewEvent)); - Events.push_back(NewEvent); + + // Update shadow memory + if (IsDstDeviceUSM && IsSrcDeviceUSM) { + NewEvent = nullptr; + uptr DstShadowAddr = DeviceInfo->Shadow->MemToShadow( + (uptr)DstOrigin + (i * DstSlicePitch)); + uptr SrcShadowAddr = DeviceInfo->Shadow->MemToShadow( + (uptr)SrcOrigin + (i * SrcSlicePitch)); + UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy2D( + Queue, false, (void *)DstShadowAddr, DstRowPitch, + (void *)SrcShadowAddr, SrcRowPitch, Region.width, Region.height, + NumEventsInWaitList, EventWaitList, &NewEvent)); + Events.push_back(NewEvent); + } else if (IsDstDeviceUSM && !IsSrcDeviceUSM) { + uptr DstShadowAddr = DeviceInfo->Shadow->MemToShadow( + (uptr)DstOrigin + (i * DstSlicePitch)); + const char Val = 0; + // The OpenCL & L0 adapters don't implement urEnqueueUSMFill2D, so + // emulate the operation with urEnqueueUSMFill.
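+ // Each per-row urEnqueueUSMFill below yields its own completion event; + // collecting all of them in Events lets the trailing wait treat the + // emulated 2D fill as a single operation.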
+ for (size_t HeightIndex = 0; HeightIndex < Region.height; + HeightIndex++) { + NewEvent = nullptr; + UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill( + Queue, (void *)(DstShadowAddr + HeightIndex * DstRowPitch), + 1, &Val, Region.width, NumEventsInWaitList, EventWaitList, + &NewEvent)); + Events.push_back(NewEvent); + } + } } - UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait( - Queue, Events.size(), Events.data(), Event)); + if (Blocking) { + UR_CALL( + getContext()->urDdiTable.Event.pfnWait(Events.size(), &Events[0])); + } + + if (Event) { + UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait( + Queue, Events.size(), &Events[0], Event)); + } return UR_RESULT_SUCCESS; } @@ -93,7 +138,7 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) { USMDesc.align = getAlignment(); ur_usm_pool_handle_t Pool{}; URes = getMsanInterceptor()->allocateMemory( - Context, Device, &USMDesc, Pool, Size, + Context, Device, &USMDesc, Pool, Size, AllocType::DEVICE_USM, ur_cast(&Allocation)); if (URes != UR_RESULT_SUCCESS) { getContext()->logger.error( @@ -112,6 +157,12 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) { Size, HostPtr, this); return URes; } + + // Update shadow memory + std::shared_ptr DeviceInfo = + getMsanInterceptor()->getDeviceInfo(Device); + UR_CALL(DeviceInfo->Shadow->EnqueuePoisonShadow( + Queue, (uptr)Allocation, Size, 0)); } } @@ -130,8 +181,8 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) { ur_usm_desc_t USMDesc{}; USMDesc.align = getAlignment(); ur_usm_pool_handle_t Pool{}; - URes = getMsanInterceptor()->allocateMemory( - Context, nullptr, &USMDesc, Pool, Size, + URes = getContext()->urDdiTable.USM.pfnHostAlloc( + Context, &USMDesc, Pool, Size, ur_cast(&HostAllocation)); if (URes != UR_RESULT_SUCCESS) { getContext()->logger.error("Failed to allocate {} bytes host " diff --git a/source/loader/layers/sanitizer/msan/msan_ddi.cpp b/source/loader/layers/sanitizer/msan/msan_ddi.cpp index 87438a1f99..cc7cfb1ee4 100644 --- a/source/loader/layers/sanitizer/msan/msan_ddi.cpp +++ b/source/loader/layers/sanitizer/msan/msan_ddi.cpp @@ -45,17 +45,10 @@ ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices, UR_CALL(DI->allocShadowMemory(Context)); } CI->DeviceList.emplace_back(hDevice); - CI->AllocInfosMap[hDevice]; } return UR_RESULT_SUCCESS; } -bool isInstrumentedKernel(ur_kernel_handle_t hKernel) { - auto hProgram = GetProgram(hKernel); - auto PI = getMsanInterceptor()->getProgramInfo(hProgram); - return PI->isKernelInstrumented(hKernel); -} - } // namespace /////////////////////////////////////////////////////////////////////////////// @@ -106,8 +99,56 @@ ur_result_t urUSMDeviceAlloc( ) { getContext()->logger.debug("==== urUSMDeviceAlloc"); - return getMsanInterceptor()->allocateMemory(hContext, hDevice, pUSMDesc, - pool, size, ppMem); + return getMsanInterceptor()->allocateMemory( + hContext, hDevice, pUSMDesc, pool, size, AllocType::DEVICE_USM, ppMem); +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urUSMHostAlloc +ur_result_t UR_APICALL urUSMHostAlloc( + ur_context_handle_t hContext, ///< [in] handle of the context object + const ur_usm_desc_t + *pUSMDesc, ///< [in][optional] USM memory allocation descriptor + ur_usm_pool_handle_t + pool, ///< [in][optional] Pointer to a pool created using urUSMPoolCreate + size_t + size, ///< [in] size in bytes of the USM memory object to be allocated + void **ppMem 
///< [out] pointer to USM host memory object +) { + getContext()->logger.debug("==== urUSMHostAlloc"); + + return getMsanInterceptor()->allocateMemory( + hContext, nullptr, pUSMDesc, pool, size, AllocType::HOST_USM, ppMem); +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urUSMSharedAlloc +ur_result_t UR_APICALL urUSMSharedAlloc( + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_device_handle_t hDevice, ///< [in] handle of the device object + const ur_usm_desc_t * + pUSMDesc, ///< [in][optional] Pointer to USM memory allocation descriptor. + ur_usm_pool_handle_t + pool, ///< [in][optional] Pointer to a pool created using urUSMPoolCreate + size_t + size, ///< [in] size in bytes of the USM memory object to be allocated + void **ppMem ///< [out] pointer to USM shared memory object +) { + getContext()->logger.debug("==== urUSMSharedAlloc"); + + return getMsanInterceptor()->allocateMemory( + hContext, hDevice, pUSMDesc, pool, size, AllocType::SHARED_USM, ppMem); +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urUSMFree +ur_result_t UR_APICALL urUSMFree( + ur_context_handle_t hContext, ///< [in] handle of the context object + void *pMem ///< [in] pointer to USM memory object +) { + getContext()->logger.debug("==== urUSMFree"); + + return getMsanInterceptor()->releaseMemory(hContext, pMem); } /////////////////////////////////////////////////////////////////////////////// @@ -354,12 +395,6 @@ ur_result_t urEnqueueKernelLaunch( getContext()->logger.debug("==== urEnqueueKernelLaunch"); - if (!isInstrumentedKernel(hKernel)) { - return pfnKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset, - pGlobalWorkSize, pLocalWorkSize, - numEventsInWaitList, phEventWaitList, phEvent); - } - USMLaunchInfo LaunchInfo(GetContext(hQueue), GetDevice(hQueue), pGlobalWorkSize, pLocalWorkSize, pGlobalWorkOffset, workDim); @@ -517,6 +552,12 @@ ur_result_t urMemBufferCreate( UR_CALL(pMemBuffer->getHandle(hDevice, Handle)); UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy( InternalQueue, true, Handle, Host, size, 0, nullptr, nullptr)); + + // Update shadow memory + std::shared_ptr DeviceInfo = + getMsanInterceptor()->getDeviceInfo(hDevice); + UR_CALL(DeviceInfo->Shadow->EnqueuePoisonShadow( + InternalQueue, (uptr)Handle, size, 0)); } } @@ -732,10 +773,29 @@ ur_result_t urEnqueueMemBufferWrite( if (auto MemBuffer = getMsanInterceptor()->getMemBuffer(hBuffer)) { ur_device_handle_t Device = GetDevice(hQueue); char *pDst = nullptr; + std::vector Events; + ur_event_handle_t Event{}; UR_CALL(MemBuffer->getHandle(Device, pDst)); UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy( hQueue, blockingWrite, pDst + offset, pSrc, size, - numEventsInWaitList, phEventWaitList, phEvent)); + numEventsInWaitList, phEventWaitList, &Event)); + Events.push_back(Event); + + // Update shadow memory + std::shared_ptr DeviceInfo = + getMsanInterceptor()->getDeviceInfo(Device); + const char Val = 0; + uptr ShadowAddr = DeviceInfo->Shadow->MemToShadow((uptr)pDst + offset); + Event = nullptr; + UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill( + hQueue, (void *)ShadowAddr, 1, &Val, size, numEventsInWaitList, + phEventWaitList, &Event)); + Events.push_back(Event); + + if (phEvent) { + UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait( + hQueue, Events.size(), Events.data(), phEvent)); + } } else { UR_CALL(pfnMemBufferWrite(hQueue, hBuffer, blockingWrite, offset, 
size, pSrc, numEventsInWaitList, phEventWaitList, @@ -895,15 +955,36 @@ ur_result_t urEnqueueMemBufferCopy( if (SrcBuffer && DstBuffer) { ur_device_handle_t Device = GetDevice(hQueue); + std::shared_ptr DeviceInfo = + getMsanInterceptor()->getDeviceInfo(Device); char *SrcHandle = nullptr; UR_CALL(SrcBuffer->getHandle(Device, SrcHandle)); char *DstHandle = nullptr; UR_CALL(DstBuffer->getHandle(Device, DstHandle)); + std::vector Events; + ur_event_handle_t Event{}; UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy( hQueue, false, DstHandle + dstOffset, SrcHandle + srcOffset, size, - numEventsInWaitList, phEventWaitList, phEvent)); + numEventsInWaitList, phEventWaitList, &Event)); + Events.push_back(Event); + + // Update shadow memory + uptr DstShadowAddr = + DeviceInfo->Shadow->MemToShadow((uptr)DstHandle + dstOffset); + uptr SrcShadowAddr = + DeviceInfo->Shadow->MemToShadow((uptr)SrcHandle + srcOffset); + Event = nullptr; + UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy( + hQueue, false, (void *)DstShadowAddr, (void *)SrcShadowAddr, size, + numEventsInWaitList, phEventWaitList, &Event)); + Events.push_back(Event); + + if (phEvent) { + UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait( + hQueue, Events.size(), Events.data(), phEvent)); + } } else { UR_CALL(pfnMemBufferCopy(hQueue, hBufferSrc, hBufferDst, srcOffset, dstOffset, size, numEventsInWaitList, @@ -1002,11 +1083,31 @@ ur_result_t urEnqueueMemBufferFill( if (auto MemBuffer = getMsanInterceptor()->getMemBuffer(hBuffer)) { char *Handle = nullptr; + std::vector Events; + ur_event_handle_t Event{}; ur_device_handle_t Device = GetDevice(hQueue); UR_CALL(MemBuffer->getHandle(Device, Handle)); UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill( hQueue, Handle + offset, patternSize, pPattern, size, - numEventsInWaitList, phEventWaitList, phEvent)); + numEventsInWaitList, phEventWaitList, &Event)); + Events.push_back(Event); + + // Update shadow memory + std::shared_ptr DeviceInfo = + getMsanInterceptor()->getDeviceInfo(Device); + const char Val = 0; + uptr ShadowAddr = + DeviceInfo->Shadow->MemToShadow((uptr)Handle + offset); + Event = nullptr; + UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill( + hQueue, (void *)ShadowAddr, 1, &Val, size, numEventsInWaitList, + phEventWaitList, &Event)); + Events.push_back(Event); + + if (phEvent) { + UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait( + hQueue, Events.size(), Events.data(), phEvent)); + } } else { UR_CALL(pfnMemBufferFill(hQueue, hBuffer, pPattern, patternSize, offset, size, numEventsInWaitList, phEventWaitList, @@ -1155,26 +1256,6 @@ ur_result_t urEnqueueMemUnmap( return UR_RESULT_SUCCESS; } -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urKernelCreate -ur_result_t urKernelCreate( - ur_program_handle_t hProgram, ///< [in] handle of the program instance - const char *pKernelName, ///< [in] pointer to null-terminated string. - ur_kernel_handle_t - *phKernel ///< [out] pointer to handle of kernel object created. 
-) { - auto pfnCreate = getContext()->urDdiTable.Kernel.pfnCreate; - - getContext()->logger.debug("==== urKernelCreate"); - - UR_CALL(pfnCreate(hProgram, pKernelName, phKernel)); - if (isInstrumentedKernel(*phKernel)) { - UR_CALL(getMsanInterceptor()->insertKernel(*phKernel)); - } - - return UR_RESULT_SUCCESS; -} - /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urKernelRetain ur_result_t urKernelRetain( @@ -1186,10 +1267,8 @@ ur_result_t urKernelRetain( UR_CALL(pfnRetain(hKernel)); - auto KernelInfo = getMsanInterceptor()->getKernelInfo(hKernel); - if (KernelInfo) { - KernelInfo->RefCount++; - } + auto &KernelInfo = getMsanInterceptor()->getOrCreateKernelInfo(hKernel); + KernelInfo.RefCount++; return UR_RESULT_SUCCESS; } @@ -1202,14 +1281,12 @@ ur_result_t urKernelRelease( auto pfnRelease = getContext()->urDdiTable.Kernel.pfnRelease; getContext()->logger.debug("==== urKernelRelease"); - UR_CALL(pfnRelease(hKernel)); - auto KernelInfo = getMsanInterceptor()->getKernelInfo(hKernel); - if (KernelInfo) { - if (--KernelInfo->RefCount == 0) { - UR_CALL(getMsanInterceptor()->eraseKernel(hKernel)); - } + auto &KernelInfo = getMsanInterceptor()->getOrCreateKernelInfo(hKernel); + if (--KernelInfo.RefCount == 0) { + UR_CALL(getMsanInterceptor()->eraseKernelInfo(hKernel)); } + UR_CALL(pfnRelease(hKernel)); return UR_RESULT_SUCCESS; } @@ -1230,13 +1307,12 @@ ur_result_t urKernelSetArgValue( getContext()->logger.debug("==== urKernelSetArgValue"); std::shared_ptr MemBuffer; - std::shared_ptr KernelInfo; if (argSize == sizeof(ur_mem_handle_t) && (MemBuffer = getMsanInterceptor()->getMemBuffer( - *ur_cast(pArgValue))) && - (KernelInfo = getMsanInterceptor()->getKernelInfo(hKernel))) { - std::scoped_lock Guard(KernelInfo->Mutex); - KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer); + *ur_cast(pArgValue)))) { + auto &KernelInfo = getMsanInterceptor()->getOrCreateKernelInfo(hKernel); + std::scoped_lock Guard(KernelInfo.Mutex); + KernelInfo.BufferArgs[argIndex] = std::move(MemBuffer); } else { UR_CALL( pfnSetArgValue(hKernel, argIndex, argSize, pProperties, pArgValue)); @@ -1260,10 +1336,10 @@ ur_result_t urKernelSetArgMemObj( std::shared_ptr MemBuffer; std::shared_ptr KernelInfo; - if ((MemBuffer = getMsanInterceptor()->getMemBuffer(hArgValue)) && - (KernelInfo = getMsanInterceptor()->getKernelInfo(hKernel))) { - std::scoped_lock Guard(KernelInfo->Mutex); - KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer); + if ((MemBuffer = getMsanInterceptor()->getMemBuffer(hArgValue))) { + auto &KernelInfo = getMsanInterceptor()->getOrCreateKernelInfo(hKernel); + std::scoped_lock Guard(KernelInfo.Mutex); + KernelInfo.BufferArgs[argIndex] = std::move(MemBuffer); } else { UR_CALL(pfnSetArgMemObj(hKernel, argIndex, pProperties, hArgValue)); } @@ -1271,6 +1347,266 @@ ur_result_t urKernelSetArgMemObj( return UR_RESULT_SUCCESS; } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueUSMFill +ur_result_t UR_APICALL urEnqueueUSMFill( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + void *pMem, ///< [in][bounds(0, size)] pointer to USM memory object + size_t + patternSize, ///< [in] the size in bytes of the pattern. Must be a power of 2 and less + ///< than or equal to width. + const void + *pPattern, ///< [in] pointer with the bytes of the pattern to set. + size_t + size, ///< [in] size in bytes to be set. Must be a multiple of patternSize. 
+ uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. If phEventWaitList and phEvent are not NULL, phEvent + ///< must not refer to an element of the phEventWaitList array. +) { + auto pfnUSMFill = getContext()->urDdiTable.Enqueue.pfnUSMFill; + getContext()->logger.debug("==== urEnqueueUSMFill"); + + std::vector Events; + ur_event_handle_t Event{}; + UR_CALL(pfnUSMFill(hQueue, pMem, patternSize, pPattern, size, + numEventsInWaitList, phEventWaitList, &Event)); + Events.push_back(Event); + + const auto Mem = (uptr)pMem; + auto MemInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Mem); + if (MemInfoItOp) { + auto MemInfo = (*MemInfoItOp)->second; + + const auto &DeviceInfo = + getMsanInterceptor()->getDeviceInfo(MemInfo->Device); + const auto MemShadow = DeviceInfo->Shadow->MemToShadow(Mem); + + Event = nullptr; + UR_CALL(EnqueueUSMBlockingSet(hQueue, (void *)MemShadow, 0, size, 0, + nullptr, &Event)); + Events.push_back(Event); + } + + if (phEvent) { + UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait( + hQueue, Events.size(), Events.data(), phEvent)); + } + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueUSMMemcpy +ur_result_t UR_APICALL urEnqueueUSMMemcpy( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + bool blocking, ///< [in] blocking or non-blocking copy + void * + pDst, ///< [in][bounds(0, size)] pointer to the destination USM memory object + const void * + pSrc, ///< [in][bounds(0, size)] pointer to the source USM memory object + size_t size, ///< [in] size in bytes to be copied + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. If phEventWaitList and phEvent are not NULL, phEvent + ///< must not refer to an element of the phEventWaitList array. 
+) { + auto pfnUSMMemcpy = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy; + getContext()->logger.debug("==== urEnqueueUSMMemcpy"); + + std::vector Events; + ur_event_handle_t Event{}; + UR_CALL(pfnUSMMemcpy(hQueue, blocking, pDst, pSrc, size, + numEventsInWaitList, phEventWaitList, &Event)); + Events.push_back(Event); + + const auto Src = (uptr)pSrc, Dst = (uptr)pDst; + auto SrcInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Src); + auto DstInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Dst); + + if (SrcInfoItOp && DstInfoItOp) { + auto SrcInfo = (*SrcInfoItOp)->second; + auto DstInfo = (*DstInfoItOp)->second; + + const auto &DeviceInfo = + getMsanInterceptor()->getDeviceInfo(SrcInfo->Device); + const auto SrcShadow = DeviceInfo->Shadow->MemToShadow(Src); + const auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst); + + Event = nullptr; + UR_CALL(pfnUSMMemcpy(hQueue, blocking, (void *)DstShadow, + (void *)SrcShadow, size, 0, nullptr, &Event)); + Events.push_back(Event); + } else if (DstInfoItOp) { + auto DstInfo = (*DstInfoItOp)->second; + + const auto &DeviceInfo = + getMsanInterceptor()->getDeviceInfo(DstInfo->Device); + auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst); + + Event = nullptr; + UR_CALL(EnqueueUSMBlockingSet(hQueue, (void *)DstShadow, 0, size, 0, + nullptr, &Event)); + Events.push_back(Event); + } + + if (phEvent) { + UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait( + hQueue, Events.size(), Events.data(), phEvent)); + } + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueUSMFill2D +ur_result_t UR_APICALL urEnqueueUSMFill2D( + ur_queue_handle_t hQueue, ///< [in] handle of the queue to submit to. + void * + pMem, ///< [in][bounds(0, pitch * height)] pointer to memory to be filled. + size_t + pitch, ///< [in] the total width of the destination memory including padding. + size_t + patternSize, ///< [in] the size in bytes of the pattern. Must be a power of 2 and less + ///< than or equal to width. + const void + *pPattern, ///< [in] pointer with the bytes of the pattern to set. + size_t + width, ///< [in] the width in bytes of each row to fill. Must be a multiple of + ///< patternSize. + size_t height, ///< [in] the height of the columns to fill. + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before the kernel execution. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait event. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< kernel execution instance. If phEventWaitList and phEvent are not + ///< NULL, phEvent must not refer to an element of the phEventWaitList array.
+) { + auto pfnUSMFill2D = getContext()->urDdiTable.Enqueue.pfnUSMFill2D; + getContext()->logger.debug("==== urEnqueueUSMFill2D"); + + std::vector Events; + ur_event_handle_t Event{}; + UR_CALL(pfnUSMFill2D(hQueue, pMem, pitch, patternSize, pPattern, width, + height, numEventsInWaitList, phEventWaitList, &Event)); + Events.push_back(Event); + + const auto Mem = (uptr)pMem; + auto MemInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Mem); + if (MemInfoItOp) { + auto MemInfo = (*MemInfoItOp)->second; + + const auto &DeviceInfo = + getMsanInterceptor()->getDeviceInfo(MemInfo->Device); + const auto MemShadow = DeviceInfo->Shadow->MemToShadow(Mem); + + const char Pattern = 0; + Event = nullptr; + UR_CALL(pfnUSMFill2D(hQueue, (void *)MemShadow, pitch, 1, &Pattern, + width, height, 0, nullptr, &Event)); + Events.push_back(Event); + } + + if (phEvent) { + UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait( + hQueue, Events.size(), Events.data(), phEvent)); + } + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueUSMMemcpy2D +ur_result_t UR_APICALL urEnqueueUSMMemcpy2D( + ur_queue_handle_t hQueue, ///< [in] handle of the queue to submit to. + bool blocking, ///< [in] indicates if this operation should block the host. + void * + pDst, ///< [in][bounds(0, dstPitch * height)] pointer to memory where data will + ///< be copied. + size_t + dstPitch, ///< [in] the total width of the source memory including padding. + const void * + pSrc, ///< [in][bounds(0, srcPitch * height)] pointer to memory to be copied. + size_t + srcPitch, ///< [in] the total width of the source memory including padding. + size_t width, ///< [in] the width in bytes of each row to be copied. + size_t height, ///< [in] the height of columns to be copied. + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before the kernel execution. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait event. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< kernel execution instance. If phEventWaitList and phEvent are not + ///< NULL, phEvent must not refer to an element of the phEventWaitList array. 
+) { + auto pfnUSMMemcpy2D = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy2D; + getContext()->logger.debug("==== urEnqueueUSMMemcpy2D"); + + std::vector Events; + ur_event_handle_t Event{}; + UR_CALL(pfnUSMMemcpy2D(hQueue, blocking, pDst, dstPitch, pSrc, srcPitch, + width, height, numEventsInWaitList, phEventWaitList, + &Event)); + Events.push_back(Event); + + const auto Src = (uptr)pSrc, Dst = (uptr)pDst; + auto SrcInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Src); + auto DstInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Dst); + + if (SrcInfoItOp && DstInfoItOp) { + auto SrcInfo = (*SrcInfoItOp)->second; + auto DstInfo = (*DstInfoItOp)->second; + + const auto &DeviceInfo = + getMsanInterceptor()->getDeviceInfo(SrcInfo->Device); + const auto SrcShadow = DeviceInfo->Shadow->MemToShadow(Src); + const auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst); + + Event = nullptr; + UR_CALL(pfnUSMMemcpy2D(hQueue, blocking, (void *)DstShadow, dstPitch, + (void *)SrcShadow, srcPitch, width, height, 0, + nullptr, &Event)); + Events.push_back(Event); + } else if (DstInfoItOp) { + auto DstInfo = (*DstInfoItOp)->second; + + const auto &DeviceInfo = + getMsanInterceptor()->getDeviceInfo(DstInfo->Device); + const auto DstShadow = DeviceInfo->Shadow->MemToShadow(Dst); + + const char Pattern = 0; + Event = nullptr; + UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill2D( + hQueue, (void *)DstShadow, dstPitch, 1, &Pattern, width, height, 0, + nullptr, &Event)); + Events.push_back(Event); + } + + if (phEvent) { + UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait( + hQueue, Events.size(), Events.data(), phEvent)); + } + + return UR_RESULT_SUCCESS; +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Exported function for filling application's Global table /// with current process' addresses @@ -1348,7 +1684,6 @@ ur_result_t urGetKernelProcAddrTable( ) { ur_result_t result = UR_RESULT_SUCCESS; - pDdiTable->pfnCreate = ur_sanitizer_layer::msan::urKernelCreate; pDdiTable->pfnRetain = ur_sanitizer_layer::msan::urKernelRetain; pDdiTable->pfnRelease = ur_sanitizer_layer::msan::urKernelRelease; pDdiTable->pfnSetArgValue = ur_sanitizer_layer::msan::urKernelSetArgValue; @@ -1429,6 +1764,10 @@ ur_result_t urGetEnqueueProcAddrTable( pDdiTable->pfnMemUnmap = ur_sanitizer_layer::msan::urEnqueueMemUnmap; pDdiTable->pfnKernelLaunch = ur_sanitizer_layer::msan::urEnqueueKernelLaunch; + pDdiTable->pfnUSMFill = ur_sanitizer_layer::msan::urEnqueueUSMFill; + pDdiTable->pfnUSMMemcpy = ur_sanitizer_layer::msan::urEnqueueUSMMemcpy; + pDdiTable->pfnUSMFill2D = ur_sanitizer_layer::msan::urEnqueueUSMFill2D; + pDdiTable->pfnUSMMemcpy2D = ur_sanitizer_layer::msan::urEnqueueUSMMemcpy2D; return result; } @@ -1446,6 +1785,9 @@ ur_result_t urGetUSMProcAddrTable( ur_result_t result = UR_RESULT_SUCCESS; pDdiTable->pfnDeviceAlloc = ur_sanitizer_layer::msan::urUSMDeviceAlloc; + pDdiTable->pfnHostAlloc = ur_sanitizer_layer::msan::urUSMHostAlloc; + pDdiTable->pfnSharedAlloc = ur_sanitizer_layer::msan::urUSMSharedAlloc; + pDdiTable->pfnFree = ur_sanitizer_layer::msan::urUSMFree; return result; } diff --git a/source/loader/layers/sanitizer/msan/msan_interceptor.cpp b/source/loader/layers/sanitizer/msan/msan_interceptor.cpp index 30a2e07359..21a19e11f3 100644 --- a/source/loader/layers/sanitizer/msan/msan_interceptor.cpp +++ b/source/loader/layers/sanitizer/msan/msan_interceptor.cpp @@ -46,7 +46,8 @@ ur_result_t MsanInterceptor::allocateMemory(ur_context_handle_t Context, ur_device_handle_t
Device, const ur_usm_desc_t *Properties, ur_usm_pool_handle_t Pool, - size_t Size, void **ResultPtr) { + size_t Size, AllocType Type, + void **ResultPtr) { auto ContextInfo = getContextInfo(Context); std::shared_ptr DeviceInfo = @@ -54,11 +55,27 @@ ur_result_t MsanInterceptor::allocateMemory(ur_context_handle_t Context, void *Allocated = nullptr; - UR_CALL(getContext()->urDdiTable.USM.pfnDeviceAlloc( - Context, Device, Properties, Pool, Size, &Allocated)); + if (Type == AllocType::DEVICE_USM) { + UR_CALL(getContext()->urDdiTable.USM.pfnDeviceAlloc( + Context, Device, Properties, Pool, Size, &Allocated)); + } else if (Type == AllocType::HOST_USM) { + UR_CALL(getContext()->urDdiTable.USM.pfnHostAlloc( + Context, Properties, Pool, Size, &Allocated)); + } else if (Type == AllocType::SHARED_USM) { + UR_CALL(getContext()->urDdiTable.USM.pfnSharedAlloc( + Context, Device, Properties, Pool, Size, &Allocated)); + } *ResultPtr = Allocated; + ContextInfo->MaxAllocatedSize = + std::max(ContextInfo->MaxAllocatedSize, Size); + + // For host/shared usm, we only record the alloc size. + if (Type != AllocType::DEVICE_USM) { + return UR_RESULT_SUCCESS; + } + auto AI = std::make_shared(MsanAllocInfo{(uptr)Allocated, Size, @@ -70,18 +87,33 @@ ur_result_t MsanInterceptor::allocateMemory(ur_context_handle_t Context, AI->print(); - // For updating shadow memory - ContextInfo->insertAllocInfo({Device}, AI); - // For memory release { std::scoped_lock Guard(m_AllocationMapMutex); - m_AllocationMap.emplace(AI->AllocBegin, std::move(AI)); + m_AllocationMap.emplace(AI->AllocBegin, AI); } + // Update shadow memory + ManagedQueue Queue(Context, Device); + DeviceInfo->Shadow->EnqueuePoisonShadow(Queue, AI->AllocBegin, + AI->AllocSize, 0xff); + return UR_RESULT_SUCCESS; } +ur_result_t MsanInterceptor::releaseMemory(ur_context_handle_t Context, + void *Ptr) { + auto Addr = reinterpret_cast(Ptr); + auto AddrInfoItOp = findAllocInfoByAddress(Addr); + + if (AddrInfoItOp) { + std::scoped_lock Guard(m_AllocationMapMutex); + m_AllocationMap.erase(*AddrInfoItOp); + } + + return getContext()->urDdiTable.USM.pfnFree(Context, Ptr); +} + ur_result_t MsanInterceptor::preLaunchKernel(ur_kernel_handle_t Kernel, ur_queue_handle_t Queue, USMLaunchInfo &LaunchInfo) { @@ -98,8 +130,6 @@ ur_result_t MsanInterceptor::preLaunchKernel(ur_kernel_handle_t Kernel, UR_CALL(prepareLaunch(DeviceInfo, InternalQueue, Kernel, LaunchInfo)); - UR_CALL(updateShadowMemory(ContextInfo, DeviceInfo, InternalQueue)); - return UR_RESULT_SUCCESS; } @@ -124,29 +154,6 @@ ur_result_t MsanInterceptor::postLaunchKernel(ur_kernel_handle_t Kernel, return Result; } -ur_result_t -MsanInterceptor::enqueueAllocInfo(std::shared_ptr &DeviceInfo, - ur_queue_handle_t Queue, - std::shared_ptr &AI) { - return DeviceInfo->Shadow->EnqueuePoisonShadow(Queue, AI->AllocBegin, - AI->AllocSize, 0xff); -} - -ur_result_t -MsanInterceptor::updateShadowMemory(std::shared_ptr &ContextInfo, - std::shared_ptr &DeviceInfo, - ur_queue_handle_t Queue) { - auto &AllocInfos = ContextInfo->AllocInfosMap[DeviceInfo->Handle]; - std::scoped_lock Guard(AllocInfos.Mutex); - - for (auto &AI : AllocInfos.List) { - UR_CALL(enqueueAllocInfo(DeviceInfo, Queue, AI)); - } - AllocInfos.List.clear(); - - return UR_RESULT_SUCCESS; -} - ur_result_t MsanInterceptor::registerProgram(ur_program_handle_t Program) { ur_result_t Result = UR_RESULT_SUCCESS; @@ -156,6 +163,12 @@ ur_result_t MsanInterceptor::registerProgram(ur_program_handle_t Program) { return Result; } + getContext()->logger.info("registerDeviceGlobals"); 
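+ // Device globals never pass through urUSMDeviceAlloc, so their shadow + // must be primed separately: registerDeviceGlobals() below reads the + // program's global-variable metadata and poisons the corresponding shadow + // with 0 (marking it initialized) before any kernel reads the globals.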
+ Result = registerDeviceGlobals(Program); + if (Result != UR_RESULT_SUCCESS) { + return Result; + } + return Result; } @@ -175,10 +188,7 @@ ur_result_t MsanInterceptor::registerSpirKernels(ur_program_handle_t Program) { Device, Program, kSPIR_MsanSpirKernelMetadata, &MetadataSize, &MetadataPtr); if (Result != UR_RESULT_SUCCESS) { - getContext()->logger.error( - "Can't get the pointer of <{}> under device {}: {}", - kSPIR_MsanSpirKernelMetadata, (void *)Device, Result); - return Result; + continue; } const uint64_t NumOfSpirKernel = MetadataSize / sizeof(SpirKernelInfo); @@ -227,6 +237,56 @@ ur_result_t MsanInterceptor::registerSpirKernels(ur_program_handle_t Program) { return UR_RESULT_SUCCESS; } +ur_result_t +MsanInterceptor::registerDeviceGlobals(ur_program_handle_t Program) { + std::vector Devices = GetDevices(Program); + assert(Devices.size() != 0 && "No devices in registerDeviceGlobals"); + auto Context = GetContext(Program); + auto ContextInfo = getContextInfo(Context); + auto ProgramInfo = getProgramInfo(Program); + assert(ProgramInfo != nullptr && "unregistered program!"); + + for (auto Device : Devices) { + ManagedQueue Queue(Context, Device); + + size_t MetadataSize; + void *MetadataPtr; + auto Result = + getContext()->urDdiTable.Program.pfnGetGlobalVariablePointer( + Device, Program, kSPIR_MsanDeviceGlobalMetadata, &MetadataSize, + &MetadataPtr); + if (Result != UR_RESULT_SUCCESS) { + getContext()->logger.info("No device globals"); + continue; + } + + const uint64_t NumOfDeviceGlobal = + MetadataSize / sizeof(DeviceGlobalInfo); + assert((MetadataSize % sizeof(DeviceGlobalInfo) == 0) && + "DeviceGlobal metadata size is not correct"); + std::vector GVInfos(NumOfDeviceGlobal); + Result = getContext()->urDdiTable.Enqueue.pfnUSMMemcpy( + Queue, true, &GVInfos[0], MetadataPtr, + sizeof(DeviceGlobalInfo) * NumOfDeviceGlobal, 0, nullptr, nullptr); + if (Result != UR_RESULT_SUCCESS) { + getContext()->logger.error("Device Global[{}] Read Failed: {}", + kSPIR_MsanDeviceGlobalMetadata, Result); + return Result; + } + + auto DeviceInfo = getMsanInterceptor()->getDeviceInfo(Device); + for (size_t i = 0; i < NumOfDeviceGlobal; i++) { + const auto &GVInfo = GVInfos[i]; + UR_CALL(DeviceInfo->Shadow->EnqueuePoisonShadow(Queue, GVInfo.Addr, + GVInfo.Size, 0)); + ContextInfo->MaxAllocatedSize = + std::max(ContextInfo->MaxAllocatedSize, GVInfo.Size); + } + } + + return UR_RESULT_SUCCESS; +} + ur_result_t MsanInterceptor::insertContext(ur_context_handle_t Context, std::shared_ptr &CI) { std::scoped_lock Guard(m_ContextMapMutex); @@ -301,16 +361,26 @@ ur_result_t MsanInterceptor::eraseProgram(ur_program_handle_t Program) { return UR_RESULT_SUCCESS; } -ur_result_t MsanInterceptor::insertKernel(ur_kernel_handle_t Kernel) { - std::scoped_lock Guard(m_KernelMapMutex); - if (m_KernelMap.find(Kernel) != m_KernelMap.end()) { - return UR_RESULT_SUCCESS; +KernelInfo &MsanInterceptor::getOrCreateKernelInfo(ur_kernel_handle_t Kernel) { + { + std::shared_lock Guard(m_KernelMapMutex); + if (m_KernelMap.find(Kernel) != m_KernelMap.end()) { + return *m_KernelMap[Kernel].get(); + } } - m_KernelMap.emplace(Kernel, std::make_shared(Kernel)); - return UR_RESULT_SUCCESS; + + // Create new KernelInfo + auto Program = GetProgram(Kernel); + auto PI = getProgramInfo(Program); + bool IsInstrumented = PI->isKernelInstrumented(Kernel); + + std::scoped_lock Guard(m_KernelMapMutex); + m_KernelMap.emplace(Kernel, + std::make_unique(Kernel, IsInstrumented)); + return *m_KernelMap[Kernel].get(); } -ur_result_t 
MsanInterceptor::eraseKernel(ur_kernel_handle_t Kernel) { +ur_result_t MsanInterceptor::eraseKernelInfo(ur_kernel_handle_t Kernel) { std::scoped_lock Guard(m_KernelMapMutex); assert(m_KernelMap.find(Kernel) != m_KernelMap.end()); m_KernelMap.erase(Kernel); @@ -363,10 +433,10 @@ ur_result_t MsanInterceptor::prepareLaunch( }; // Set membuffer arguments - auto KernelInfo = getKernelInfo(Kernel); - assert(KernelInfo && "Kernel must be instrumented"); + auto &KernelInfo = getOrCreateKernelInfo(Kernel); + std::shared_lock Guard(KernelInfo.Mutex); - for (const auto &[ArgIndex, MemBuffer] : KernelInfo->BufferArgs) { + for (const auto &[ArgIndex, MemBuffer] : KernelInfo.BufferArgs) { char *ArgPointer = nullptr; UR_CALL(MemBuffer->getHandle(DeviceInfo->Handle, ArgPointer)); ur_result_t URes = getContext()->urDdiTable.Kernel.pfnSetArgPointer( @@ -379,19 +449,32 @@ ur_result_t MsanInterceptor::prepareLaunch( } } + if (!KernelInfo.IsInstrumented) { + return UR_RESULT_SUCCESS; + } + // Set LaunchInfo + auto ContextInfo = getContextInfo(LaunchInfo.Context); LaunchInfo.Data->GlobalShadowOffset = DeviceInfo->Shadow->ShadowBegin; LaunchInfo.Data->GlobalShadowOffsetEnd = DeviceInfo->Shadow->ShadowEnd; LaunchInfo.Data->DeviceTy = DeviceInfo->Type; LaunchInfo.Data->Debug = getOptions().Debug ? 1 : 0; + UR_CALL(getContext()->urDdiTable.USM.pfnDeviceAlloc( + ContextInfo->Handle, DeviceInfo->Handle, nullptr, nullptr, + ContextInfo->MaxAllocatedSize, &LaunchInfo.Data->CleanShadow)); getContext()->logger.info( "launch_info {} (GlobalShadow={}, Device={}, Debug={})", (void *)LaunchInfo.Data, LaunchInfo.Data->GlobalShadowOffset, ToString(LaunchInfo.Data->DeviceTy), LaunchInfo.Data->Debug); - UR_CALL( - EnqueueWriteGlobal("__MsanLaunchInfo", &LaunchInfo.Data, sizeof(uptr))); + ur_result_t URes = + EnqueueWriteGlobal("__MsanLaunchInfo", &LaunchInfo.Data, sizeof(uptr)); + if (URes != UR_RESULT_SUCCESS) { + getContext()->logger.info("EnqueueWriteGlobal(__MsanLaunchInfo) " + "failed, possibly an empty kernel: {}", + URes); + } return UR_RESULT_SUCCESS; } @@ -401,13 +484,16 @@ MsanInterceptor::findAllocInfoByAddress(uptr Address) { std::shared_lock Guard(m_AllocationMapMutex); auto It = m_AllocationMap.upper_bound(Address); if (It == m_AllocationMap.begin()) { - return std::optional{}; + return std::nullopt; } --It; - // Make sure we got the right MsanAllocInfo - assert(Address >= It->second->AllocBegin && - Address < It->second->AllocBegin + It->second->AllocSize && - "Wrong MsanAllocInfo for the address"); + + // Since we haven't intercepted all USM APIs, we can't be sure the found AllocInfo is correct.
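+ // m_AllocationMap is keyed by AllocBegin, so upper_bound(Address) yields + // the first allocation starting above Address, and the decremented + // iterator is the only candidate owner; the range check below rejects + // addresses that fall in the gap between two recorded allocations.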
+ if (Address < It->second->AllocBegin || + Address >= It->second->AllocBegin + It->second->AllocSize) { + return std::nullopt; + } + return It; } @@ -458,6 +544,11 @@ ur_result_t USMLaunchInfo::initialize() { USMLaunchInfo::~USMLaunchInfo() { [[maybe_unused]] ur_result_t Result; if (Data) { + if (Data->CleanShadow) { + Result = getContext()->urDdiTable.USM.pfnFree(Context, + Data->CleanShadow); + assert(Result == UR_RESULT_SUCCESS); + } Result = getContext()->urDdiTable.USM.pfnFree(Context, (void *)Data); assert(Result == UR_RESULT_SUCCESS); } diff --git a/source/loader/layers/sanitizer/msan/msan_interceptor.hpp b/source/loader/layers/sanitizer/msan/msan_interceptor.hpp index 80dbf389a4..fea52741f3 100644 --- a/source/loader/layers/sanitizer/msan/msan_interceptor.hpp +++ b/source/loader/layers/sanitizer/msan/msan_interceptor.hpp @@ -76,11 +76,15 @@ struct KernelInfo { ur_kernel_handle_t Handle; std::atomic RefCount = 1; + // sanitized kernel + bool IsInstrumented = false; + // lock this mutex if the following fields are accessed ur_shared_mutex Mutex; std::unordered_map> BufferArgs; - explicit KernelInfo(ur_kernel_handle_t Kernel) : Handle(Kernel) { + explicit KernelInfo(ur_kernel_handle_t Kernel, bool IsInstrumented) + : Handle(Kernel), IsInstrumented(IsInstrumented) { [[maybe_unused]] auto Result = getContext()->urDdiTable.Kernel.pfnRetain(Kernel); assert(Result == UR_RESULT_SUCCESS); @@ -117,10 +121,10 @@ struct ProgramInfo { struct ContextInfo { ur_context_handle_t Handle; + size_t MaxAllocatedSize = 1024; std::atomic RefCount = 1; std::vector DeviceList; - std::unordered_map AllocInfosMap; explicit ContextInfo(ur_context_handle_t Context) : Handle(Context) { [[maybe_unused]] auto Result = @@ -129,15 +133,6 @@ struct ContextInfo { } ~ContextInfo(); - - void insertAllocInfo(const std::vector &Devices, - std::shared_ptr &AI) { - for (auto Device : Devices) { - auto &AllocInfos = AllocInfosMap[Device]; - std::scoped_lock Guard(AllocInfos.Mutex); - AllocInfos.List.emplace_back(AI); - } - } }; struct USMLaunchInfo { @@ -165,6 +160,11 @@ struct USMLaunchInfo { ur_result_t initialize(); }; +struct DeviceGlobalInfo { + uptr Size; + uptr Addr; +}; + struct SpirKernelInfo { uptr KernelName; uptr Size; @@ -180,7 +180,8 @@ class MsanInterceptor { ur_device_handle_t Device, const ur_usm_desc_t *Properties, ur_usm_pool_handle_t Pool, size_t Size, - void **ResultPtr); + AllocType Type, void **ResultPtr); + ur_result_t releaseMemory(ur_context_handle_t Context, void *Ptr); ur_result_t registerProgram(ur_program_handle_t Program); ur_result_t unregisterProgram(ur_program_handle_t Program); @@ -203,9 +204,6 @@ class MsanInterceptor { ur_result_t insertProgram(ur_program_handle_t Program); ur_result_t eraseProgram(ur_program_handle_t Program); - ur_result_t insertKernel(ur_kernel_handle_t Kernel); - ur_result_t eraseKernel(ur_kernel_handle_t Kernel); - ur_result_t insertMemBuffer(std::shared_ptr MemBuffer); ur_result_t eraseMemBuffer(ur_mem_handle_t MemHandle); std::shared_ptr getMemBuffer(ur_mem_handle_t MemHandle); @@ -245,13 +243,8 @@ class MsanInterceptor { return m_ProgramMap[Program]; } - std::shared_ptr getKernelInfo(ur_kernel_handle_t Kernel) { - std::shared_lock Guard(m_KernelMapMutex); - if (m_KernelMap.find(Kernel) != m_KernelMap.end()) { - return m_KernelMap[Kernel]; - } - return nullptr; - } + KernelInfo &getOrCreateKernelInfo(ur_kernel_handle_t Kernel); + ur_result_t eraseKernelInfo(ur_kernel_handle_t Kernel); const MsanOptions &getOptions() { return m_Options; } @@ -263,15 +256,6 @@ class
MsanInterceptor { bool isNormalExit() { return m_NormalExit; } private: - ur_result_t - updateShadowMemory(std::shared_ptr &ContextInfo, - std::shared_ptr &DeviceInfo, - ur_queue_handle_t Queue); - - ur_result_t enqueueAllocInfo(std::shared_ptr &DeviceInfo, - ur_queue_handle_t Queue, - std::shared_ptr &AI); - /// Initialize Global Variables & Kernel Name at first Launch ur_result_t prepareLaunch(std::shared_ptr &DeviceInfo, ur_queue_handle_t Queue, @@ -283,6 +267,7 @@ class MsanInterceptor { std::shared_ptr &DeviceInfo); ur_result_t registerSpirKernels(ur_program_handle_t Program); + ur_result_t registerDeviceGlobals(ur_program_handle_t Program); private: std::unordered_map> diff --git a/source/loader/layers/sanitizer/msan/msan_libdevice.hpp b/source/loader/layers/sanitizer/msan/msan_libdevice.hpp index cd05cfa38c..0888c9dc75 100644 --- a/source/loader/layers/sanitizer/msan/msan_libdevice.hpp +++ b/source/loader/layers/sanitizer/msan/msan_libdevice.hpp @@ -52,6 +52,8 @@ struct MsanLaunchInfo { uint32_t IsRecover = 0; MsanErrorReport Report; + + void *CleanShadow = nullptr; }; // Based on the observation, only the last 24 bits of the address of the private diff --git a/source/loader/layers/sanitizer/msan/msan_shadow.cpp b/source/loader/layers/sanitizer/msan/msan_shadow.cpp index add9813db6..2573b4caa5 100644 --- a/source/loader/layers/sanitizer/msan/msan_shadow.cpp +++ b/source/loader/layers/sanitizer/msan/msan_shadow.cpp @@ -111,21 +111,25 @@ uptr MsanShadowMemoryCPU::MemToShadow(uptr Ptr) { return Ptr ^ CPU_SHADOW_MASK; } -ur_result_t MsanShadowMemoryCPU::EnqueuePoisonShadow(ur_queue_handle_t, - uptr Ptr, uptr Size, - u8 Value) { - if (Size == 0) { - return UR_RESULT_SUCCESS; +ur_result_t MsanShadowMemoryCPU::EnqueuePoisonShadow( + ur_queue_handle_t Queue, uptr Ptr, uptr Size, u8 Value, uint32_t NumEvents, + const ur_event_handle_t *EventWaitList, ur_event_handle_t *OutEvent) { + + if (Size) { + const uptr ShadowBegin = MemToShadow(Ptr); + const uptr ShadowEnd = MemToShadow(Ptr + Size - 1); + assert(ShadowBegin <= ShadowEnd); + getContext()->logger.debug( + "EnqueuePoisonShadow(addr={}, count={}, value={})", + (void *)ShadowBegin, ShadowEnd - ShadowBegin + 1, + (void *)(size_t)Value); + memset((void *)ShadowBegin, Value, ShadowEnd - ShadowBegin + 1); } - uptr ShadowBegin = MemToShadow(Ptr); - uptr ShadowEnd = MemToShadow(Ptr + Size - 1); - assert(ShadowBegin <= ShadowEnd); - getContext()->logger.debug( - "EnqueuePoisonShadow(addr={}, count={}, value={})", (void *)ShadowBegin, - ShadowEnd - ShadowBegin + 1, (void *)(size_t)Value); - memset((void *)ShadowBegin, Value, ShadowEnd - ShadowBegin + 1); - + if (OutEvent) { + UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait( + Queue, NumEvents, EventWaitList, OutEvent)); + } return UR_RESULT_SUCCESS; } @@ -134,18 +138,23 @@ ur_result_t MsanShadowMemoryGPU::Setup() { // shadow memory for each context; this will cause an out-of-resource error when the user uses // multiple contexts. Therefore, we just create one shadow memory here. static ur_result_t Result = [this]() { - size_t ShadowSize = GetShadowSize(); + const size_t ShadowSize = GetShadowSize(); + // To reserve a very large amount of GPU virtual memory, the pStart param should be beyond + // the SVM range, so that the GFX driver will automatically switch to reservation on the GPU + // heap.
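+ // Presumably the exact constant below (2^56) is not special beyond lying + // above the SVM range; any fixed start address past that boundary should + // take the GPU-heap reservation path described above.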
+ const void *StartAddress = (void *)(0x100'0000'0000'0000ULL); // TODO: Protect Bad Zone auto Result = getContext()->urDdiTable.VirtualMem.pfnReserve( - Context, nullptr, ShadowSize, (void **)&ShadowBegin); - if (Result == UR_RESULT_SUCCESS) { - ShadowEnd = ShadowBegin + ShadowSize; - // Retain the context which reserves shadow memory - getContext()->urDdiTable.Context.pfnRetain(Context); + Context, StartAddress, ShadowSize, (void **)&ShadowBegin); + if (Result != UR_RESULT_SUCCESS) { + getContext()->logger.error( + "Shadow memory reserved failed with size {}: {}", + (void *)ShadowSize, Result); + return Result; } - - // Set shadow memory for null pointer - ManagedQueue Queue(Context, Device); + ShadowEnd = ShadowBegin + ShadowSize; + // Retain the context which reserves shadow memory + getContext()->urDdiTable.Context.pfnRetain(Context); return UR_RESULT_SUCCESS; }(); return Result; @@ -164,88 +173,98 @@ ur_result_t MsanShadowMemoryGPU::Destory() { return Result; } -ur_result_t MsanShadowMemoryGPU::EnqueuePoisonShadow(ur_queue_handle_t Queue, - uptr Ptr, uptr Size, - u8 Value) { - if (Size == 0) { - return UR_RESULT_SUCCESS; - } +ur_result_t MsanShadowMemoryGPU::EnqueueMapShadow( + ur_queue_handle_t Queue, uptr Ptr, uptr Size, + std::vector &EventWaitList, + ur_event_handle_t *OutEvent) { + + const size_t PageSize = GetVirtualMemGranularity(Context, Device); - uptr ShadowBegin = MemToShadow(Ptr); - uptr ShadowEnd = MemToShadow(Ptr + Size - 1); + const uptr ShadowBegin = MemToShadow(Ptr); + const uptr ShadowEnd = MemToShadow(Ptr + Size - 1); assert(ShadowBegin <= ShadowEnd); - { - static const size_t PageSize = - GetVirtualMemGranularity(Context, Device); - - ur_physical_mem_properties_t Desc{ - UR_STRUCTURE_TYPE_PHYSICAL_MEM_PROPERTIES, nullptr, 0}; - - // Make sure [Ptr, Ptr + Size] is mapped to physical memory - for (auto MappedPtr = RoundDownTo(ShadowBegin, PageSize); - MappedPtr <= ShadowEnd; MappedPtr += PageSize) { - std::scoped_lock Guard(VirtualMemMapsMutex); - if (VirtualMemMaps.find(MappedPtr) == VirtualMemMaps.end()) { - ur_physical_mem_handle_t PhysicalMem{}; - auto URes = getContext()->urDdiTable.PhysicalMem.pfnCreate( - Context, Device, PageSize, &Desc, &PhysicalMem); - if (URes != UR_RESULT_SUCCESS) { - getContext()->logger.error("urPhysicalMemCreate(): {}", - URes); - return URes; - } - - URes = getContext()->urDdiTable.VirtualMem.pfnMap( - Context, (void *)MappedPtr, PageSize, PhysicalMem, 0, - UR_VIRTUAL_MEM_ACCESS_FLAG_READ_WRITE); - if (URes != UR_RESULT_SUCCESS) { - getContext()->logger.error("urVirtualMemMap({}, {}): {}", - (void *)MappedPtr, PageSize, - URes); - return URes; - } - - getContext()->logger.debug("urVirtualMemMap: {} ~ {}", - (void *)MappedPtr, - (void *)(MappedPtr + PageSize - 1)); - - // Initialize to zero - URes = EnqueueUSMBlockingSet(Queue, (void *)MappedPtr, 0, - PageSize); - if (URes != UR_RESULT_SUCCESS) { - getContext()->logger.error("EnqueueUSMBlockingSet(): {}", - URes); - return URes; - } - - VirtualMemMaps[MappedPtr].first = PhysicalMem; + + // Make sure [Ptr, Ptr + Size] is mapped to physical memory + for (auto MappedPtr = RoundDownTo(ShadowBegin, PageSize); + MappedPtr <= ShadowEnd; MappedPtr += PageSize) { + std::scoped_lock Guard(VirtualMemMapsMutex); + if (VirtualMemMaps.find(MappedPtr) == VirtualMemMaps.end()) { + ur_physical_mem_handle_t PhysicalMem{}; + auto URes = getContext()->urDdiTable.PhysicalMem.pfnCreate( + Context, Device, PageSize, nullptr, &PhysicalMem); + if (URes != UR_RESULT_SUCCESS) { + 
getContext()->logger.error("urPhysicalMemCreate(): {}", URes); + return URes; + } + + URes = getContext()->urDdiTable.VirtualMem.pfnMap( + Context, (void *)MappedPtr, PageSize, PhysicalMem, 0, + UR_VIRTUAL_MEM_ACCESS_FLAG_READ_WRITE); + if (URes != UR_RESULT_SUCCESS) { + getContext()->logger.error("urVirtualMemMap({}, {}): {}", + (void *)MappedPtr, PageSize, URes); + return URes; + } + + getContext()->logger.debug("urVirtualMemMap: {} ~ {}", + (void *)MappedPtr, + (void *)(MappedPtr + PageSize - 1)); + + // Initialize to zero + URes = EnqueueUSMBlockingSet(Queue, (void *)MappedPtr, 0, PageSize, + EventWaitList.size(), + EventWaitList.data(), OutEvent); + if (URes != UR_RESULT_SUCCESS) { + getContext()->logger.error("EnqueueUSMSet(): {}", URes); + return URes; } - // We don't need to record virtual memory map for null pointer, - // since it doesn't have an alloc info. - if (Ptr == 0) { - continue; + EventWaitList.clear(); + if (OutEvent) { + EventWaitList.push_back(*OutEvent); } - auto AllocInfoIt = - getMsanInterceptor()->findAllocInfoByAddress(Ptr); - assert(AllocInfoIt); - VirtualMemMaps[MappedPtr].second.insert((*AllocInfoIt)->second); + VirtualMemMaps[MappedPtr].first = PhysicalMem; + } + + auto AllocInfoItOp = getMsanInterceptor()->findAllocInfoByAddress(Ptr); + if (AllocInfoItOp) { + VirtualMemMaps[MappedPtr].second.insert((*AllocInfoItOp)->second); } } - auto URes = EnqueueUSMBlockingSet(Queue, (void *)ShadowBegin, Value, - ShadowEnd - ShadowBegin + 1); + return UR_RESULT_SUCCESS; +} + +ur_result_t MsanShadowMemoryGPU::EnqueuePoisonShadow( + ur_queue_handle_t Queue, uptr Ptr, uptr Size, u8 Value, uint32_t NumEvents, + const ur_event_handle_t *EventWaitList, ur_event_handle_t *OutEvent) { + if (Size == 0) { + if (OutEvent) { + UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait( + Queue, NumEvents, EventWaitList, OutEvent)); + } + return UR_RESULT_SUCCESS; + } + + std::vector Events(EventWaitList, + EventWaitList + NumEvents); + UR_CALL(EnqueueMapShadow(Queue, Ptr, Size, Events, OutEvent)); + + const uptr ShadowBegin = MemToShadow(Ptr); + const uptr ShadowEnd = MemToShadow(Ptr + Size - 1); + assert(ShadowBegin <= ShadowEnd); + + auto Result = EnqueueUSMBlockingSet(Queue, (void *)ShadowBegin, Value, + ShadowEnd - ShadowBegin + 1, + Events.size(), Events.data(), OutEvent); + getContext()->logger.debug( - "EnqueuePoisonShadow (addr={}, count={}, value={}): {}", + "EnqueuePoisonShadow(addr={}, count={}, value={}): {}", (void *)ShadowBegin, ShadowEnd - ShadowBegin + 1, (void *)(size_t)Value, - URes); - if (URes != UR_RESULT_SUCCESS) { - getContext()->logger.error("EnqueueUSMBlockingSet(): {}", URes); - return URes; - } + Result); - return UR_RESULT_SUCCESS; + return Result; } ur_result_t @@ -278,13 +297,21 @@ MsanShadowMemoryGPU::ReleaseShadow(std::shared_ptr AI) { } uptr MsanShadowMemoryPVC::MemToShadow(uptr Ptr) { - assert(Ptr & 0xFF00000000000000ULL && "Ptr must be device USM"); - return ShadowBegin + (Ptr & 0x3FFF'FFFF'FFFFULL); + assert(Ptr & 0xff00'0000'0000'0000ULL && "Ptr must be device USM"); + if (Ptr < ShadowBegin) { + return Ptr + (ShadowBegin - 0xff00'0000'0000'0000ULL); + } else { + return Ptr - (0xff00'ffff'ffff'ffffULL - ShadowEnd); + } } uptr MsanShadowMemoryDG2::MemToShadow(uptr Ptr) { - assert(Ptr & 0xFFFF000000000000ULL && "Ptr must be device USM"); - return ShadowBegin + (Ptr & 0x3FFF'FFFF'FFFFULL); + assert(Ptr & 0xffff'0000'0000'0000ULL && "Ptr must be device USM"); + if (Ptr < ShadowBegin) { + return Ptr + (ShadowBegin - 0xffff'8000'0000'0000ULL); + } else { + 
return Ptr - (0xffff'ffff'ffff'ffffULL - ShadowEnd); + } } } // namespace msan diff --git a/source/loader/layers/sanitizer/msan/msan_shadow.hpp b/source/loader/layers/sanitizer/msan/msan_shadow.hpp index de13683cbc..ca5791385c 100644 --- a/source/loader/layers/sanitizer/msan/msan_shadow.hpp +++ b/source/loader/layers/sanitizer/msan/msan_shadow.hpp @@ -32,8 +32,11 @@ struct MsanShadowMemory { virtual uptr MemToShadow(uptr Ptr) = 0; - virtual ur_result_t EnqueuePoisonShadow(ur_queue_handle_t Queue, uptr Ptr, - uptr Size, u8 Value) = 0; + virtual ur_result_t + EnqueuePoisonShadow(ur_queue_handle_t Queue, uptr Ptr, uptr Size, u8 Value, + uint32_t NumEvents = 0, + const ur_event_handle_t *EventWaitList = nullptr, + ur_event_handle_t *OutEvent = nullptr) = 0; virtual ur_result_t ReleaseShadow(std::shared_ptr) { return UR_RESULT_SUCCESS; @@ -74,8 +77,11 @@ struct MsanShadowMemoryCPU final : public MsanShadowMemory { uptr MemToShadow(uptr Ptr) override; - ur_result_t EnqueuePoisonShadow(ur_queue_handle_t Queue, uptr Ptr, - uptr Size, u8 Value) override; + ur_result_t + EnqueuePoisonShadow(ur_queue_handle_t Queue, uptr Ptr, uptr Size, u8 Value, + uint32_t NumEvents = 0, + const ur_event_handle_t *EventWaitList = nullptr, + ur_event_handle_t *OutEvent = nullptr) override; }; struct MsanShadowMemoryGPU : public MsanShadowMemory { @@ -85,19 +91,27 @@ struct MsanShadowMemoryGPU : public MsanShadowMemory { ur_result_t Setup() override; ur_result_t Destory() override; - ur_result_t EnqueuePoisonShadow(ur_queue_handle_t Queue, uptr Ptr, - uptr Size, u8 Value) override final; + + ur_result_t + EnqueuePoisonShadow(ur_queue_handle_t Queue, uptr Ptr, uptr Size, u8 Value, + uint32_t NumEvents = 0, + const ur_event_handle_t *EventWaitList = nullptr, + ur_event_handle_t *OutEvent = nullptr) override final; ur_result_t ReleaseShadow(std::shared_ptr AI) override final; virtual size_t GetShadowSize() = 0; - ur_mutex VirtualMemMapsMutex; + private: + ur_result_t EnqueueMapShadow(ur_queue_handle_t Queue, uptr Ptr, uptr Size, + std::vector &EventWaitList, + ur_event_handle_t *OutEvent); std::unordered_map< uptr, std::pair>>> VirtualMemMaps; + ur_mutex VirtualMemMapsMutex; }; /// Shadow Memory layout of GPU PVC device diff --git a/source/loader/layers/sanitizer/sanitizer_common/linux/sanitizer_utils.cpp b/source/loader/layers/sanitizer/sanitizer_common/linux/sanitizer_utils.cpp index df64a72ed7..27fe223a5d 100644 --- a/source/loader/layers/sanitizer/sanitizer_common/linux/sanitizer_utils.cpp +++ b/source/loader/layers/sanitizer/sanitizer_common/linux/sanitizer_utils.cpp @@ -40,6 +40,9 @@ uptr MmapNoReserve(uptr Addr, uptr Size) { Addr = RoundDownTo(Addr, EXEC_PAGESIZE); void *P = mmap((void *)Addr, Size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_NORESERVE | MAP_ANONYMOUS, -1, 0); + if (P == MAP_FAILED) { + return 0; + } return (uptr)P; } diff --git a/source/loader/layers/sanitizer/sanitizer_common/sanitizer_utils.cpp b/source/loader/layers/sanitizer/sanitizer_common/sanitizer_utils.cpp index 900eae405b..a267f24433 100644 --- a/source/loader/layers/sanitizer/sanitizer_common/sanitizer_utils.cpp +++ b/source/loader/layers/sanitizer/sanitizer_common/sanitizer_utils.cpp @@ -154,7 +154,10 @@ DeviceType GetDeviceType(ur_context_handle_t Context, // FIXME: There's no API querying the address bits of device, so we guess it by the // value of device USM pointer (see "USM Allocation Range" in asan_shadow.cpp) auto Type = DeviceType::UNKNOWN; - if (Ptr >> 48 == 0xff00U) { + + // L0 changes their VA layout. 
@@ -247,9 +250,6 @@ ur_result_t EnqueueUSMBlockingSet(ur_queue_handle_t Queue, void *Ptr,
                                   char Value, size_t Size, uint32_t NumEvents,
                                   const ur_event_handle_t *EventWaitList,
                                   ur_event_handle_t *OutEvent) {
-    if (Size == 0) {
-        return UR_RESULT_SUCCESS;
-    }
     return getContext()->urDdiTable.Enqueue.pfnUSMFill(
         Queue, Ptr, 1, &Value, Size, NumEvents, EventWaitList, OutEvent);
 }
 
diff --git a/test/conformance/program/program_adapter_level_zero_v2.match b/test/conformance/program/program_adapter_level_zero_v2.match
index 97d6869b81..fd359b3653 100644
--- a/test/conformance/program/program_adapter_level_zero_v2.match
+++ b/test/conformance/program/program_adapter_level_zero_v2.match
@@ -1,3 +1,4 @@
 urProgramSetSpecializationConstantsTest.InvalidValueSize/*
 urProgramSetSpecializationConstantsTest.InvalidValueId/*
 urProgramSetSpecializationConstantsTest.InvalidValuePtr/*
+{{OPT}}urMultiDeviceCommandBufferExpTest.*
diff --git a/test/conformance/program/urMultiDeviceProgramCreateWithBinary.cpp b/test/conformance/program/urMultiDeviceProgramCreateWithBinary.cpp
index 9ff11d9016..5f99747462 100644
--- a/test/conformance/program/urMultiDeviceProgramCreateWithBinary.cpp
+++ b/test/conformance/program/urMultiDeviceProgramCreateWithBinary.cpp
@@ -240,3 +240,141 @@ TEST_F(urMultiDeviceProgramCreateWithBinaryTest, CheckProgramGetInfo) {
         reinterpret_cast<const char *>(property_value.data());
     ASSERT_STRNE(returned_kernel_names, "");
 }
+
+struct urMultiDeviceCommandBufferExpTest
+    : urMultiDeviceProgramCreateWithBinaryTest {
+    void SetUp() override {
+        UUR_RETURN_ON_FATAL_FAILURE(
+            urMultiDeviceProgramCreateWithBinaryTest::SetUp());
+
+        auto kernelName =
+            uur::KernelsEnvironment::instance->GetEntryPointNames("foo")[0];
+
+        ASSERT_SUCCESS(urProgramBuild(context, binary_program, nullptr));
+        ASSERT_SUCCESS(
+            urKernelCreate(binary_program, kernelName.data(), &kernel));
+    }
+
+    void TearDown() override {
+        if (kernel) {
+            EXPECT_SUCCESS(urKernelRelease(kernel));
+        }
+        UUR_RETURN_ON_FATAL_FAILURE(
+            urMultiDeviceProgramCreateWithBinaryTest::TearDown());
+    }
+
+    static bool hasCommandBufferSupport(ur_device_handle_t device) {
+        ur_bool_t cmd_buffer_support = false;
+        auto res = urDeviceGetInfo(
+            device, UR_DEVICE_INFO_COMMAND_BUFFER_SUPPORT_EXP,
+            sizeof(cmd_buffer_support), &cmd_buffer_support, nullptr);
+
+        if (res) {
+            return false;
+        }
+
+        return cmd_buffer_support;
+    }
+
+    static bool hasCommandBufferUpdateSupport(ur_device_handle_t device) {
+        ur_device_command_buffer_update_capability_flags_t
+            update_capability_flags;
+        auto res = urDeviceGetInfo(
+            device, UR_DEVICE_INFO_COMMAND_BUFFER_UPDATE_CAPABILITIES_EXP,
+            sizeof(update_capability_flags), &update_capability_flags,
+            nullptr);
+
+        if (res) {
+            return false;
+        }
+
+        return (0 != update_capability_flags);
+    }
+
+    ur_kernel_handle_t kernel = nullptr;
+
+    static constexpr size_t global_offset = 0;
+    static constexpr size_t n_dimensions = 1;
+    static constexpr size_t global_size = 64;
+    static constexpr size_t local_size = 4;
+};
+
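Both helpers deliberately fold a failed urDeviceGetInfo query into "no support", so each test below can simply skip the device. A minimal standalone sketch of the same probe outside the fixture (assuming ur_api.h is included; the zero-initialization and explicit result comparison are stylistic choices, not something the tests require):

    // Returns true only if the device reports at least one command-buffer
    // update capability; any query failure is treated as "not supported".
    static bool deviceSupportsCommandBufferUpdate(ur_device_handle_t Device) {
        ur_device_command_buffer_update_capability_flags_t Caps = 0;
        if (urDeviceGetInfo(Device,
                            UR_DEVICE_INFO_COMMAND_BUFFER_UPDATE_CAPABILITIES_EXP,
                            sizeof(Caps), &Caps, nullptr) != UR_RESULT_SUCCESS) {
            return false;
        }
        return Caps != 0;
    }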
+TEST_F(urMultiDeviceCommandBufferExpTest, Enqueue) {
+    for (size_t i = 0; i < devices.size(); i++) {
+        auto device = devices[i];
+        if (!hasCommandBufferSupport(device)) {
+            continue;
+        }
+
+        // Create command-buffer
+        uur::raii::CommandBuffer cmd_buf_handle;
+        ASSERT_SUCCESS(urCommandBufferCreateExp(context, device, nullptr,
+                                                cmd_buf_handle.ptr()));
+
+        // Append kernel command to command-buffer and close command-buffer
+        ASSERT_SUCCESS(urCommandBufferAppendKernelLaunchExp(
+            cmd_buf_handle, kernel, n_dimensions, &global_offset, &global_size,
+            &local_size, 0, nullptr, 0, nullptr, 0, nullptr, nullptr, nullptr,
+            nullptr));
+        ASSERT_SUCCESS(urCommandBufferFinalizeExp(cmd_buf_handle));
+
+        // Verify execution succeeds
+        ASSERT_SUCCESS(urCommandBufferEnqueueExp(cmd_buf_handle, queues[i], 0,
+                                                 nullptr, nullptr));
+        ASSERT_SUCCESS(urQueueFinish(queues[i]));
+    }
+}
+
+TEST_F(urMultiDeviceCommandBufferExpTest, Update) {
+    for (size_t i = 0; i < devices.size(); i++) {
+        auto device = devices[i];
+        if (!(hasCommandBufferSupport(device) &&
+              hasCommandBufferUpdateSupport(device))) {
+            continue;
+        }
+
+        // Create a command-buffer with update enabled.
+        ur_exp_command_buffer_desc_t desc{
+            UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC, nullptr, true, false,
+            false};
+
+        // Create command-buffer
+        uur::raii::CommandBuffer cmd_buf_handle;
+        ASSERT_SUCCESS(urCommandBufferCreateExp(context, device, &desc,
+                                                cmd_buf_handle.ptr()));
+
+        // Append kernel command to command-buffer and close command-buffer
+        uur::raii::CommandBufferCommand command;
+        ASSERT_SUCCESS(urCommandBufferAppendKernelLaunchExp(
+            cmd_buf_handle, kernel, n_dimensions, &global_offset, &global_size,
+            &local_size, 0, nullptr, 0, nullptr, 0, nullptr, nullptr, nullptr,
+            command.ptr()));
+        ASSERT_SUCCESS(urCommandBufferFinalizeExp(cmd_buf_handle));
+
+        // Verify execution succeeds
+        ASSERT_SUCCESS(urCommandBufferEnqueueExp(cmd_buf_handle, queues[i], 0,
+                                                 nullptr, nullptr));
+        ASSERT_SUCCESS(urQueueFinish(queues[i]));
+
+        // Update kernel and enqueue command-buffer again
+        ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = {
+            UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype
+            nullptr,      // pNext
+            kernel,       // hNewKernel
+            0,            // numNewMemObjArgs
+            0,            // numNewPointerArgs
+            0,            // numNewValueArgs
+            n_dimensions, // newWorkDim
+            nullptr,      // pNewMemObjArgList
+            nullptr,      // pNewPointerArgList
+            nullptr,      // pNewValueArgList
+            nullptr,      // pNewGlobalWorkOffset
+            nullptr,      // pNewGlobalWorkSize
+            nullptr,      // pNewLocalWorkSize
+        };
+        ASSERT_SUCCESS(
+            urCommandBufferUpdateKernelLaunchExp(command, &update_desc));
+        ASSERT_SUCCESS(urCommandBufferEnqueueExp(cmd_buf_handle, queues[i], 0,
+                                                 nullptr, nullptr));
+        ASSERT_SUCCESS(urQueueFinish(queues[i]));
+    }
+}
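The all-null members of update_desc leave the recorded launch configuration untouched, which is what lets the second enqueue run unchanged. A hypothetical follow-up inside the same test body, sketching how a single parameter could be replaced (here the global work size, assuming the device also advertises global-size update capability):

    // Re-run the same kernel over twice the range; null members keep the
    // launch parameters already recorded for this command.
    size_t new_global_size = 2 * global_size;
    ur_exp_command_buffer_update_kernel_launch_desc_t resize_desc = {
        UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype
        nullptr,          // pNext
        kernel,           // hNewKernel
        0,                // numNewMemObjArgs
        0,                // numNewPointerArgs
        0,                // numNewValueArgs
        n_dimensions,     // newWorkDim
        nullptr,          // pNewMemObjArgList
        nullptr,          // pNewPointerArgList
        nullptr,          // pNewValueArgList
        nullptr,          // pNewGlobalWorkOffset
        &new_global_size, // pNewGlobalWorkSize
        nullptr,          // pNewLocalWorkSize
    };
    ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(command, &resize_desc));
    ASSERT_SUCCESS(urCommandBufferEnqueueExp(cmd_buf_handle, queues[i], 0,
                                             nullptr, nullptr));
    ASSERT_SUCCESS(urQueueFinish(queues[i]));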