[MLIR][NVVM] Extend TMA Bulk Copy Op · llvm/llvm-project@a227f45 · GitHub
[go: up one dir, main page]

Skip to content

Commit a227f45

Browse files
committed
[MLIR][NVVM] Extend TMA Bulk Copy Op
This patch extends the non-tensor TMA Bulk Copy Op (from shared_cta to global) with an optional byte mask operand. This mask helps in selectively copying a particular byte to the destination. * lit tests are added to verify the lowering to the intrinsics. Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
1 parent 9f77c26 commit a227f45

File tree

3 files changed

+59
-28
lines changed

3 files changed

+59
-28
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2599,51 +2599,47 @@ def NVVM_CpAsyncBulkSharedCTAToSharedClusterOp :
25992599
}
26002600

26012601
// Non-tensor TMA bulk copy from shared::cta to global memory. The two
// optional operands (l2CacheHint, byteMask) require AttrSizedOperandSegments
// so the parser/verifier can tell which optional slots are occupied.
def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
  NVVM_Op<"cp.async.bulk.global.shared.cta", [AttrSizedOperandSegments]> {
  let summary = "Async bulk copy from Shared CTA memory to Global memory";
  let description = [{
    Initiates an asynchronous copy operation from Shared CTA memory to
    global memory. The 32-bit operand `size` specifies the amount of
    memory to be copied, in terms of number of bytes. `size` must be a
    multiple of 16. The `l2CacheHint` operand is optional, and it is used
    to specify cache eviction policy that may be used during the memory
    access. The i-th bit in the 16-bit wide `byteMask` operand specifies
    whether the i-th byte of each 16-byte wide chunk of source data is
    copied to the destination. If the bit is set, the byte is copied.

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
  }];

  let arguments = (ins
    LLVM_PointerGlobal:$dstMem,
    LLVM_PointerShared:$srcMem,
    I32:$size,
    Optional<I64>:$l2CacheHint,
    Optional<I16>:$byteMask);

  let assemblyFormat = [{
    $dstMem `,` $srcMem `,` $size
    (`l2_cache_hint` `=` $l2CacheHint^ )?
    (`byte_mask` `=` $byteMask^ )?
    attr-dict `:` type($dstMem) `,` type($srcMem)
  }];

  // The target intrinsic (plain vs ".bytemask" variant) and its argument
  // list depend on which optional operands are present, so both are
  // resolved by a C++ helper at LLVM translation time.
  let extraClassDeclaration = [{
    static llvm::Intrinsic::ID
    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
                          llvm::SmallVector<llvm::Value *> &args,
                          llvm::IRBuilderBase& builder);
  }];

  string llvmBuilder = [{
    llvm::SmallVector<llvm::Value *> args;
    llvm::Intrinsic::ID id =
        NVVM::CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
            *op, moduleTranslation, args, builder);
    createIntrinsicCall(builder, id, args);
  }];
}
26492645

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,6 +1253,31 @@ CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
12531253
return id;
12541254
}
12551255

1256+
/// Selects the NVVM intrinsic for a shared::cta-to-global bulk copy and
/// collects its argument list in \p args.
///
/// Argument layout: dst, src, size, cache-hint (i64), cache-hint-valid (i1),
/// and — only for the ".bytemask" variant — the i16 byte mask.
/// Returns the intrinsic ID matching the operands present on \p op.
llvm::Intrinsic::ID CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt,
    llvm::SmallVector<llvm::Value *> &args, llvm::IRBuilderBase &builder) {
  auto bulkCopyOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);

  // Mandatory leading arguments: destination, source and byte count.
  args.push_back(mt.lookupValue(bulkCopyOp.getDstMem()));
  args.push_back(mt.lookupValue(bulkCopyOp.getSrcMem()));
  args.push_back(mt.lookupValue(bulkCopyOp.getSize()));

  // The cache-hint slot is always present in the intrinsic signature; when
  // the optional operand is absent, pass a dummy 0 plus an i1 "hint valid"
  // flag of false so the backend ignores it.
  auto cacheHint = bulkCopyOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *hintArg =
      hasCacheHint ? mt.lookupValue(cacheHint)
                   : llvm::ConstantInt::get(
                         llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
  args.push_back(hintArg);
  args.push_back(builder.getInt1(hasCacheHint));

  // A byte mask selects the ".bytemask" intrinsic variant, with the mask
  // appended as the trailing argument.
  if (auto byteMask = bulkCopyOp.getByteMask()) {
    args.push_back(mt.lookupValue(byteMask));
    return llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
  }

  return llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
}
1280+
12561281
llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
12571282
bool isIm2Col) {
12581283
switch (tensorDims) {

mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,19 @@ llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster(%dst : !llvm.ptr
2626
// Lowering without a byte mask: both the plain form and the form carrying an
// L2 cache hint map to @llvm.nvvm.cp.async.bulk.shared.cta.to.global (the
// missing hint becomes i64 0 with an i1 false "valid" flag).
// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_global
llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_global(%dst : !llvm.ptr<1>, %src : !llvm.ptr<3>, %size : i32, %ch : i64) {
  // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global(ptr addrspace(1) %[[DST:.*]], ptr addrspace(3) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 0, i1 false)
  // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global(ptr addrspace(1) %[[DST]], ptr addrspace(3) %[[SRC]], i32 %[[SIZE]], i64 %[[CH:.*]], i1 true)
  nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size : !llvm.ptr<1>, !llvm.ptr<3>

  nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch : !llvm.ptr<1>, !llvm.ptr<3>
  llvm.return
}

// Lowering with a byte mask: the ".bytemask" intrinsic variant is selected
// and the i16 mask is appended as the trailing argument.
// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_global_bytemask
llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_global_bytemask(%dst : !llvm.ptr<1>, %src : !llvm.ptr<3>, %size : i32, %ch : i64, %mask : i16) {
  // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global.bytemask(ptr addrspace(1) %[[DST:.*]], ptr addrspace(3) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 0, i1 false, i16 %[[MASK:.*]])
  // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global.bytemask(ptr addrspace(1) %[[DST]], ptr addrspace(3) %[[SRC]], i32 %[[SIZE]], i64 %[[CH:.*]], i1 true, i16 %[[MASK]])
  nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size byte_mask = %mask : !llvm.ptr<1>, !llvm.ptr<3>

  nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch byte_mask = %mask : !llvm.ptr<1>, !llvm.ptr<3>
  llvm.return
}

0 commit comments

Comments
 (0)
0