File tree Expand file tree Collapse file tree 4 files changed +14
-10
lines changed
third_party/xla/xla/service/spmd/shardy Expand file tree Collapse file tree 4 files changed +14
-10
lines changed Original file line number Diff line number Diff line change @@ -22,6 +22,7 @@ cc_library(
22
22
"//xla/service/spmd/shardy:constants" ,
23
23
"//xla/service/spmd/shardy:utils" ,
24
24
"@com_google_absl//absl/log:check" ,
25
+ "@com_google_absl//absl/strings:string_view" ,
25
26
"@llvm-project//llvm:Support" ,
26
27
"@llvm-project//mlir:IR" ,
27
28
"@llvm-project//mlir:Pass" ,
Original file line number Diff line number Diff line change @@ -22,6 +22,7 @@ limitations under the License.
22
22
#include < vector>
23
23
24
24
#include " absl/log/check.h"
25
+ #include " absl/strings/string_view.h"
25
26
#include " llvm/ADT/StringRef.h"
26
27
#include " mlir/IR/BuiltinAttributes.h"
27
28
#include " mlir/IR/BuiltinOps.h"
@@ -60,9 +61,10 @@ mlir::LogicalResult rewriteShardingCustomCall(
60
61
61
62
std::vector<int64_t > unspecDims;
62
63
if (std::optional<mlir::Attribute> backendConfig = op.getBackendConfig ()) {
64
+ StringRef configStr =
65
+ mlir::dyn_cast<mlir::StringAttr>(*backendConfig).getValue ();
63
66
CHECK_OK (xla::sharding_op_util::ParseAttributes (
64
- mlir::dyn_cast<mlir::StringAttr>(*backendConfig).getValue ().str (),
65
- &unspecDims));
67
+ absl::string_view (configStr.data (), configStr.size ()), &unspecDims));
66
68
}
67
69
68
70
if (op->getNumResults () != 1 ) {
Original file line number Diff line number Diff line change @@ -77,8 +77,7 @@ class ManualComputationPattern : public OpConversionPattern<CallOp> {
77
77
mlir::LogicalResult matchAndRewrite (
78
78
CallOp callOp, OpAdaptor adaptor,
79
79
mlir::ConversionPatternRewriter& rewriter) const override {
80
- if (!absl::StrContains (callOp.getCallee ().str (),
81
- kManualComputationBodyFuncName .str ())) {
80
+ if (!callOp.getCallee ().contains (kManualComputationBodyFuncName )) {
82
81
return mlir::failure ();
83
82
}
84
83
@@ -159,8 +158,7 @@ class SdyRoundTripShardMapImportPass
159
158
MLIRContext& context = getContext ();
160
159
mlir::ConversionTarget target (context);
161
160
target.addDynamicallyLegalOp <CallOp>([](CallOp op) {
162
- return !absl::StrContains (op.getCallee ().str (),
163
- kManualComputationBodyFuncName .str ());
161
+ return !op.getCallee ().contains (kManualComputationBodyFuncName );
164
162
});
165
163
target.addLegalOp <sdy::ManualComputationOp, sdy::ReturnOp, CustomCallOp>();
166
164
mlir::RewritePatternSet patterns (&context);
Original file line number Diff line number Diff line change @@ -22,6 +22,7 @@ limitations under the License.
22
22
23
23
#include " absl/log/check.h"
24
24
#include " absl/strings/escaping.h"
25
+ #include " absl/strings/string_view.h"
25
26
#include " mlir/AsmParser/AsmParser.h"
26
27
#include " mlir/Dialect/Func/IR/FuncOps.h"
27
28
#include " mlir/IR/Attributes.h"
@@ -80,14 +81,16 @@ template <typename AttrTy>
80
81
AttrTy parseStringAttr (mlir::DictionaryAttr dictAttr,
81
82
llvm::StringRef attrName) {
82
83
if (mlir::Attribute stringAttr = dictAttr.get (attrName)) {
83
- std::string value ;
84
+ std::string unescapedValue ;
84
85
std::string error;
86
+ llvm::StringRef escapedValue =
87
+ mlir::cast<mlir::StringAttr>(stringAttr).getValue ();
85
88
CHECK (absl::CUnescape (
86
- mlir::cast<mlir::StringAttr>(stringAttr). getValue (). str (), &value ,
87
- &error))
89
+ absl::string_view (escapedValue. data (), escapedValue. size ()) ,
90
+ &unescapedValue, & error))
88
91
<< error;
89
92
return mlir::cast<AttrTy>(
90
- mlir::parseAttribute (value , stringAttr.getContext ()));
93
+ mlir::parseAttribute (unescapedValue , stringAttr.getContext ()));
91
94
}
92
95
return nullptr ;
93
96
}
You can’t perform that action at this time.
0 commit comments