32
32
#include < string>
33
33
#include < unistd.h>
34
34
35
+ #define STRING (s ) #s
36
+ #define TO_STRING (x ) STRING(x)
37
+
35
38
using namespace mlir ;
36
39
using namespace mlir ::linalg;
37
40
using llvm::Error;
@@ -53,26 +56,31 @@ struct LinalgCodegenPass : public PassWrapper<LinalgCodegenPass, FunctionPass> {
53
56
registry.insert <linalg::LinalgDialect, AffineDialect, scf::SCFDialect>();
54
57
}
55
58
LinalgCodegenPass () = default ;
56
- LinalgCodegenPass (int M, int N, int K) : M(M), N(N), K(K) {}
59
+ LinalgCodegenPass (int M, int N, int K, const std::string &target_cpu,
60
+ const std::string &vector_width) : M(M), N(N), K(K),
61
+ target_cpu (target_cpu), vector_width(vector_width) {}
57
62
LinalgCodegenPass (const LinalgCodegenPass &pass) {
58
63
M = pass.M ;
59
64
N = pass.N ;
60
65
K = pass.K ;
66
+ target_cpu = pass.target_cpu ;
67
+ vector_width = pass.vector_width ;
61
68
}
62
69
void runOnFunction () override ;
63
70
64
71
int M, N, K;
72
+ std::string target_cpu, vector_width;
65
73
};
66
74
} // namespace
67
75
68
76
void LinalgCodegenPass::runOnFunction () {
69
77
MLIRContext *ctx = getFunction ().getContext ();
70
78
SmallVector<Attribute, 4 > attrs;
71
79
attrs.push_back (ArrayAttr::get ({StringAttr::get (" prefer-vector-width" , ctx),
72
- StringAttr::get (" 512 " , ctx)},
80
+ StringAttr::get (vector_width , ctx)},
73
81
ctx));
74
82
attrs.push_back (ArrayAttr::get ({StringAttr::get (" target-cpu" , ctx),
75
- StringAttr::get (" skylake-avx512 " , ctx)},
83
+ StringAttr::get (target_cpu , ctx)},
76
84
ctx));
77
85
getFunction ()->setAttr (" passthrough" , ArrayAttr::get (attrs, ctx));
78
86
@@ -181,8 +189,9 @@ void LinalgCodegenPass::runOnFunction() {
181
189
// getFunction().dump();
182
190
}
183
191
184
- std::unique_ptr<OperationPass<FuncOp>> createLinalgCodegenPass (int M, int N, int K) {
185
- return std::make_unique<LinalgCodegenPass>(M, N, K);
192
+ std::unique_ptr<OperationPass<FuncOp>> createLinalgCodegenPass (int M, int N, int K,
193
+ const std::string &target_cpu, const std::string &vector_width) {
194
+ return std::make_unique<LinalgCodegenPass>(M, N, K, target_cpu, vector_width);
186
195
}
187
196
188
197
}
@@ -237,7 +246,9 @@ Error compile(Options &options, mlir::DialectRegistry ®istry) {
237
246
int M, N, K;
238
247
get_dimensions (options.inputFile , M, N, K);
239
248
pm.addPass (createCanonicalizerPass ());
240
- pm.addPass (createLinalgCodegenPass (M, N, K));
249
+ std::string target_cpu = TO_STRING (TARGET_CPU);
250
+ std::string vector_width = TO_STRING (VECTOR_WIDTH);
251
+ pm.addPass (createLinalgCodegenPass (M, N, K, target_cpu, vector_width));
241
252
242
253
// Lower to LLVM
243
254
pm.addPass (createConvertVectorToSCFPass ());
0 commit comments