Lower to BatchFunctionOpWithDeviceOp · IBMZ-Linux-OSS-Python/tensorflow@7f32242 · GitHub

Commit 7f32242

deqiangc authored and tensorflower-gardener committed
Lower to BatchFunctionOpWithDeviceOp
PiperOrigin-RevId: 766370877
1 parent bae08f5 commit 7f32242

4 files changed: +63 −4 lines changed

tensorflow/compiler/mlir/tfrt/tests/mlrt/tpu_conversions.mlir

Lines changed: 19 additions & 0 deletions
@@ -166,3 +166,22 @@ func.func private @NopMapFnBody(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: t
   %a = "tf.AddV2"(%arg2, %const) {__op_key = 3: i32}: (tensor<i32>, tensor<i32>) -> tensor<i32>
   return
 }
+
+
+// -----
+func.func @callee(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>) {
+  %1 = "tf.TPUCompileMlirAndExecute"(%arg0) {metadata = "metadata", mlir_module = "mlir_module", operandSegmentSizes = array<i32: 1, 0>, producer_name = "producer_name"} : (tensor<i32>) -> tensor<i32>
+  func.return %1: tensor<i32>
+}
+
+// CHECK-LABEL: func @batch_function
+func.func @batch_function(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>) {
+  // CHECK: [[device:%.*]] = tf_mlrt_tpu.get_tpu_host_device
+  // CHECK: [[batch_result_future:%.*]] = tf_mlrt.batch_function.device([[device]]) (%arg0, %arg1)
+  // CHECK: [[batch_result:%.*]] = tf_mlrt.await [[batch_result_future]]
+  // CHECK: return [[batch_result]]
+  %0 = "tf.BatchFunction"(%arg0, %arg1) {device = "/device:CPU:0", allowed_batch_sizes = [64], batch_timeout_micros = 1 : i64, batching_queue = "", container = "", f = @callee, max_batch_size = 256 : i64, num_batch_threads = 2 : i64, operandSegmentSizes = array<i32: 1, 1>, shared_name = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  func.return %0 : tensor<i32>
+}
+

tensorflow/compiler/mlir/tfrt/transforms/mlrt/BUILD

Lines changed: 1 addition & 0 deletions
@@ -144,6 +144,7 @@ cc_library(
         "//tensorflow/compiler/mlir/tfrt/ir/mlrt:tf_mlrt_tpu_ops",
         "@com_google_absl//absl/log:check",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TransformUtils",

tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc

Lines changed: 14 additions & 3 deletions
@@ -842,9 +842,20 @@ class BatchFunctionOpConversion
     llvm::SmallVector<mlir::Type, 4> result_types(
         op->getNumResults(), rewriter.getType<mlrt::compiler::FutureType>());
 
-    rewriter.replaceOpWithNewOp<tf_mlrt::BatchFunctionOp>(
-        op, result_types, adaptor.getOperands(), node_def.device(),
-        op.getFAttr(), node_def_text);
+    if (auto custom_device =
+            op->getAttrOfType<mlir::StringAttr>(kTfMlrtCustomDevice)) {
+      mlir::Value device =
+          CreateCustomDevice(op->getLoc(), custom_device.getValue(), rewriter);
+      if (!device) return op->emitWarning("Failed to create custom device.");
+
+      rewriter.replaceOpWithNewOp<tf_mlrt::BatchFunctionWithDeviceOp>(
+          op, result_types, device, adaptor.getOperands(), node_def.device(),
+          op.getFAttr(), node_def_text);
+    } else {
+      rewriter.replaceOpWithNewOp<tf_mlrt::BatchFunctionOp>(
+          op, result_types, adaptor.getOperands(), node_def.device(),
+          op.getFAttr(), node_def_text);
+    }
 
     return mlir::success();
   }
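
The conversion now branches on the kTfMlrtCustomDevice string attribute: when it is present, the op is lowered to tf_mlrt::BatchFunctionWithDeviceOp, which takes a materialized device handle as an extra operand; otherwise the existing tf_mlrt::BatchFunctionOp lowering is unchanged. CreateCustomDevice itself is not part of this diff. A minimal sketch of what it plausibly does, assuming kTpuHostDevice is the only recognized value and that the tf_mlrt_tpu.get_tpu_host_device op exercised by the test above is exposed as a C++ class named GetTpuHostDeviceOp with a location-only builder (both assumptions, not confirmed by this commit):

// Hedged sketch only, not the actual helper. GetTpuHostDeviceOp is a
// hypothetical class name for tf_mlrt_tpu.get_tpu_host_device.
mlir::Value CreateCustomDevice(mlir::Location loc, llvm::StringRef device_name,
                               mlir::ConversionPatternRewriter& rewriter) {
  if (device_name == kTpuHostDevice) {
    // Materialize the TPU host device handle consumed by
    // tf_mlrt.batch_function.device.
    return rewriter.create<tf_mlrt_tpu::GetTpuHostDeviceOp>(loc);
  }
  // Unknown custom device: return a null Value so the caller can warn.
  return mlir::Value();
}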

tensorflow/compiler/mlir/tfrt/transforms/mlrt/tpu_conversion_patterns.cc

Lines changed: 29 additions & 1 deletion
@@ -14,6 +14,8 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/tpu_conversion_patterns.h"
 
+#include "llvm/Support/Casting.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h"
 #include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.h"
@@ -90,8 +92,18 @@ class TPUCompileMlirAndExecuteOpPreParallelizationConversion
           }
         }
       }
+      if (replaced_ops.empty()) {
+        auto caller_batch_ops = FindCallerBatchFunctionOps(op);
+        for (auto* batch_op : caller_batch_ops) {
+          mlir::ConversionPatternRewriter::InsertionGuard guard(rewriter);
+          rewriter.setInsertionPoint(batch_op);
+          mlir::Operation* batch_op_with_device = rewriter.clone(*batch_op);
+          batch_op_with_device->setAttr(kTfMlrtCustomDevice,
+                                        rewriter.getStringAttr(kTpuHostDevice));
+          rewriter.replaceOp(batch_op, batch_op_with_device->getResults());
+        }
+      }
     }
-
     auto compile_and_execute_op =
         rewriter.create<tf_mlrt::TFTPUCompileAndExecuteOp>(
             op.getLoc(), op.getResultTypes(), operands,
@@ -108,6 +120,22 @@ class TPUCompileMlirAndExecuteOpPreParallelizationConversion
 
  private:
   bool use_tpu_host_allocator_for_inputs_ = false;
+
+  llvm::SmallVector<mlir::Operation*, 4> FindCallerBatchFunctionOps(
+      mlir::Operation* op) const {
+    llvm::SmallVector<mlir::Operation*, 4> result;
+    if (auto func = llvm::dyn_cast<mlir::func::FuncOp>(op->getParentOp())) {
+      if (auto uses = func.getSymbolUses(func->getParentOp())) {
+        for (auto& use : uses.value()) {
+          auto* user = use.getUser();
+          if (auto batch_op = llvm::dyn_cast<mlir::TF::BatchFunctionOp>(user)) {
+            result.push_back(batch_op);
+          }
+        }
+      }
+    }
+    return result;
+  }
 };
 
 class TPUCompileMlirAndExecuteOpConversion
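
FindCallerBatchFunctionOps locates the tf.BatchFunction ops that call the function enclosing the TPUCompileMlirAndExecute op by enumerating symbol uses of that function within its parent. The same walk, reduced to a self-contained helper over standard MLIR APIs (generic over the user op, whereas the pattern above additionally filters for mlir::TF::BatchFunctionOp):

#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/SymbolTable.h"

// Collect every operation that references `func` by symbol within its
// parent op (typically the enclosing module). Returns an empty vector
// if the symbol uses cannot be enumerated.
llvm::SmallVector<mlir::Operation*, 4> CollectSymbolUsers(
    mlir::func::FuncOp func) {
  llvm::SmallVector<mlir::Operation*, 4> users;
  if (auto uses =
          mlir::SymbolTable::getSymbolUses(func, func->getParentOp())) {
    for (const mlir::SymbolTable::SymbolUse& use : *uses)
      users.push_back(use.getUser());
  }
  return users;
}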
