[MLIR][NVVM] Extend TMA Bulk Copy Op · llvm/llvm-project@a7144d6 · GitHub
[go: up one dir, main page]

Skip to content

Commit a7144d6

Browse files
committed
[MLIR][NVVM] Extend TMA Bulk Copy Op
This patch extends the non-tensor TMA Bulk Copy Op (from shared_cta to global) with an optional byte mask operand. The mask selects which bytes of each 16-byte chunk of source data are copied to the destination. Lit tests are added to verify the lowering to the intrinsics. Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
1 parent 9f77c26 commit a7144d6

File tree

4 files changed

+83
-28
lines changed

4 files changed

+83
-28
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ enum NVVMMemorySpace {
4949
kSharedClusterMemorySpace = 7,
5050
};
5151

52+
/// A pair type of LLVM's Intrinsic ID and args (which are llvm values).
53+
/// This type is returned by the getIntrinsicIDAndArgs() methods.
54+
using IDArgPair =
55+
std::pair<llvm::Intrinsic::ID, llvm::SmallVector<llvm::Value *>>;
56+
5257
/// Return the element type and number of elements associated with a wmma matrix
5358
/// of given characteristics. This matches the logic in IntrinsicsNVVM.td
5459
/// WMMA_REGS structure.

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2599,51 +2599,63 @@ def NVVM_CpAsyncBulkSharedCTAToSharedClusterOp :
25992599
}
26002600

26012601
def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
2602-
NVVM_Op<"cp.async.bulk.global.shared.cta"> {
2602+
NVVM_Op<"cp.async.bulk.global.shared.cta", [AttrSizedOperandSegments]> {
26032603
let summary = "Async bulk copy from Shared CTA memory to Global memory";
26042604
let description = [{
26052605
Initiates an asynchronous copy operation from Shared CTA memory to
2606-
global memory.
2606+
global memory. The 32-bit operand `size` specifies the amount of
2607+
memory to be copied, in terms of number of bytes. `size` must be a
2608+
multiple of 16. The `l2CacheHint` operand is optional, and it is used
2609+
to specify cache eviction policy that may be used during the memory
2610+
access. The `byteMask` operand is optional. The i-th bit in the 16-bit
2611+
wide `byteMask` specifies whether the i-th byte of each 16-byte wide
2612+
chunk of source data is copied to the destination. If the bit is set,
2613+
the byte is copied.
2614+
2615+
Example:
2616+
```mlir
2617+
nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size
2618+
: !llvm.ptr<1>, !llvm.ptr<3>
2619+
2620+
// with l2_cache_hint
2621+
nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch
2622+
: !llvm.ptr<1>, !llvm.ptr<3>
2623+
2624+
// with byte_mask
2625+
nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size byte_mask = %mask
2626+
: !llvm.ptr<1>, !llvm.ptr<3>
2627+
2628+
// with both l2_cache_hint and byte_mask
2629+
nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch byte_mask = %mask
2630+
: !llvm.ptr<1>, !llvm.ptr<3>
2631+
```
26072632

2608-
The `l2CacheHint` operand is optional, and it is used to specify cache
2609-
eviction policy that may be used during the memory access.
2610-
26112633
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
26122634
}];
26132635

26142636
let arguments = (ins
26152637
LLVM_PointerGlobal:$dstMem,
26162638
LLVM_PointerShared:$srcMem,
26172639
I32:$size,
2618-
Optional<I64>:$l2CacheHint);
2640+
Optional<I64>:$l2CacheHint,
2641+
Optional<I16>:$byteMask);
26192642

26202643
let assemblyFormat = [{
26212644
$dstMem `,` $srcMem `,` $size
26222645
(`l2_cache_hint` `=` $l2CacheHint^ )?
2623-
attr-dict `:` type($dstMem) `,` type($srcMem)
2646+
(`byte_mask` `=` $byteMask^ )?
2647+
attr-dict `:` type($dstMem) `,` type($srcMem)
26242648
}];
26252649

2650+
let extraClassDeclaration = [{
2651+
static mlir::NVVM::IDArgPair
2652+
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
2653+
llvm::IRBuilderBase& builder);
2654+
}];
26262655
string llvmBuilder = [{
2627-
// Arguments to the intrinsic:
2628-
// dst, src, size, cache_hint,
2629-
// Flag for cache_hint
2630-
//
2631-
llvm::SmallVector<llvm::Value *> translatedOperands;
2632-
translatedOperands.push_back($dstMem);
2633-
translatedOperands.push_back($srcMem);
2634-
translatedOperands.push_back($size);
2635-
2636-
// Cachehint, if available
2637-
llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
2638-
auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
2639-
bool isCacheHint = op.getL2CacheHint() ? true : false;
2640-
translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);
2641-
2642-
// Flag argument for cachehint
2643-
translatedOperands.push_back(builder.getInt1(isCacheHint));
2644-
2645-
createIntrinsicCall(builder,
2646-
llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global, translatedOperands);
2656+
auto [id, args] = NVVM::CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
2657+
*op, moduleTranslation, builder);
2658+
createIntrinsicCall(builder, id, args);
26472659
}];
26482660
}
26492661

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,6 +1253,34 @@ CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
12531253
return id;
12541254
}
12551255

1256+
mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
1257+
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
1258+
auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
1259+
llvm::SmallVector<llvm::Value *> args;
1260+
llvm::Intrinsic::ID id =
1261+
llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
1262+
1263+
// Fill the Intrinsic Args
1264+
args.push_back(mt.lookupValue(thisOp.getDstMem()));
1265+
args.push_back(mt.lookupValue(thisOp.getSrcMem()));
1266+
args.push_back(mt.lookupValue(thisOp.getSize()));
1267+
1268+
mlir::Value cacheHint = thisOp.getL2CacheHint();
1269+
const bool hasCacheHint = static_cast<bool>(cacheHint);
1270+
llvm::Value *i64Unused =
1271+
llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
1272+
args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
1273+
args.push_back(builder.getInt1(hasCacheHint));
1274+
1275+
// Choose the bytemask variant
1276+
if (mlir::Value byteMask = thisOp.getByteMask()) {
1277+
args.push_back(mt.lookupValue(byteMask));
1278+
id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
1279+
}
1280+
1281+
return {id, std::move(args)};
1282+
}
1283+
12561284
llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
12571285
bool isIm2Col) {
12581286
switch (tensorDims) {

mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,19 @@ llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster(%dst : !llvm.ptr
2626
// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_global
2727
llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_global(%dst : !llvm.ptr<1>, %src : !llvm.ptr<3>, %size : i32, %ch : i64) {
2828
// CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global(ptr addrspace(1) %[[DST:.*]], ptr addrspace(3) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 0, i1 false)
29-
// CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global(ptr addrspace(1) %[[DST:.*]], ptr addrspace(3) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 %[[CH:.*]], i1 true)
29+
// CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global(ptr addrspace(1) %[[DST]], ptr addrspace(3) %[[SRC]], i32 %[[SIZE]], i64 %[[CH:.*]], i1 true)
3030
nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size : !llvm.ptr<1>, !llvm.ptr<3>
3131

3232
nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch : !llvm.ptr<1>, !llvm.ptr<3>
3333
llvm.return
3434
}
35+
36+
// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_global_bytemask
37+
llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_global_bytemask(%dst : !llvm.ptr<1>, %src : !llvm.ptr<3>, %size : i32, %ch : i64, %mask : i16) {
38+
// CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global.bytemask(ptr addrspace(1) %[[DST:.*]], ptr addrspace(3) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 0, i1 false, i16 %[[MASK:.*]])
39+
// CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global.bytemask(ptr addrspace(1) %[[DST]], ptr addrspace(3) %[[SRC]], i32 %[[SIZE]], i64 %[[CH:.*]], i1 true, i16 %[[MASK]])
40+
nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size byte_mask = %mask : !llvm.ptr<1>, !llvm.ptr<3>
41+
42+
nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch byte_mask = %mask : !llvm.ptr<1>, !llvm.ptr<3>
43+
llvm.return
44+
}

0 commit comments

Comments
 (0)
0