Add flag (only on flatbuffer export tool) to disable buffer sharing i… · linux-on-ibm-z/tensorflow@6e3ec2a · GitHub
[go: up one dir, main page]

Skip to content

Commit 6e3ec2a

Browse files
LukeBoyer and tensorflower-gardener
authored and committed
Add flag (only on flatbuffer export tool) to disable buffer sharing in flatbuffer.
Some downstream tools won't support this. PiperOrigin-RevId: 730648967
1 parent 988ab99 commit 6e3ec2a

File tree

4 files changed

+128
-14
lines changed

4 files changed

+128
-14
lines changed

tensorflow/compiler/mlir/lite/flatbuffer_export.cc

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,8 @@ class Translator {
558558
OpOrArgNameMapper* op_or_arg_name_mapper,
559559
const std::map<std::string, std::string>& metadata,
560560
bool serialize_stablehlo_ops,
561-
std::optional<size_t> custom_option_alignment);
561+
std::optional<size_t> custom_option_alignment,
562+
bool disable_buffer_deduping = false);
562563

563564
private:
564565
enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp };
@@ -567,7 +568,8 @@ class Translator {
567568
const std::unordered_set<std::string>& saved_model_tags,
568569
OpOrArgNameMapper* op_or_arg_name_mapper,
569570
const std::map<std::string, std::string>& metadata,
570-
std::optional<size_t> custom_option_alignment)
571+
std::optional<size_t> custom_option_alignment,
572+
bool disable_buffer_deduping)
571573
: module_(module),
572574
name_mapper_(*op_or_arg_name_mapper),
573575
builder_(kInitialBufferSize),
@@ -580,7 +582,8 @@ class Translator {
580582
converter_flags.supported_backends().end()),
581583
serialize_debug_metadata_(converter_flags.serialize_debug_metadata()),
582584
use_buffer_offset_(converter_flags.use_buffer_offset()),
583-
custom_option_alignment_(custom_option_alignment) {
585+
custom_option_alignment_(custom_option_alignment),
586+
disable_buffer_deduping_(disable_buffer_deduping) {
584587
// The first buffer must be empty according to the schema definition.
585588
empty_buffer_ = tflite::CreateBuffer(builder_);
586589
buffers_.push_back(empty_buffer_);
@@ -930,6 +933,8 @@ class Translator {
930933

931934
std::optional<size_t> custom_option_alignment_ = std::nullopt;
932935

936+
bool disable_buffer_deduping_ = false;
937+
933938
// Map from mlir constant attribute to the buffer index. This is used to
934939
// deduplicate the buffers in the flatbuffer.
935940
llvm::DenseMap<mlir::ElementsAttr, int> const_attribute_to_buffer_map_;
@@ -967,6 +972,7 @@ std::string Translator::UniqueName(mlir::Value val) {
967972

968973
std::optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
969974
mlir::Value value, bool can_be_deduplicated, int& index) {
975+
can_be_deduplicated = can_be_deduplicated && !disable_buffer_deduping_;
970976
auto inst = value.getDefiningOp();
971977
ElementsAttr attr;
972978
if (auto cst = dyn_cast<mlir::arith::ConstantOp>(inst)) {
@@ -3870,8 +3876,8 @@ std::optional<std::string> Translator::Translate(
38703876
const std::unordered_set<std::string>& tags,
38713877
OpOrArgNameMapper* op_or_arg_name_mapper,
38723878
const std::map<std::string, std::string>& metadata,
3873-
bool serialize_stablehlo_ops,
3874-
std::optional<size_t> custom_option_alignment) {
3879+
bool serialize_stablehlo_ops, std::optional<size_t> custom_option_alignment,
3880+
bool disable_buffer_deduping) {
38753881
OpOrArgLocNameMapper default_op_or_arg_name_mapper;
38763882
if (!op_or_arg_name_mapper)
38773883
op_or_arg_name_mapper = &default_op_or_arg_name_mapper;
@@ -3887,9 +3893,9 @@ std::optional<std::string> Translator::Translate(
38873893
new_converter_flags.set_use_buffer_offset(true);
38883894
}
38893895

3890-
auto translator = std::unique_ptr<Translator>(
3891-
new Translator(module, new_converter_flags, tags, op_or_arg_name_mapper,
3892-
metadata, custom_option_alignment));
3896+
auto translator = std::unique_ptr<Translator>(new Translator(
3897+
module, new_converter_flags, tags, op_or_arg_name_mapper, metadata,
3898+
custom_option_alignment, disable_buffer_deduping));
38933899
translator->convert_stablehlo_ = serialize_stablehlo_ops;
38943900
auto ret = translator->TranslateInternal();
38953901

@@ -3900,9 +3906,9 @@ std::optional<std::string> Translator::Translate(
39003906
ret = std::nullopt;
39013907
auto new_converter_flags = converter_flags;
39023908
new_converter_flags.set_use_buffer_offset(true);
3903-
translator = std::unique_ptr<Translator>(
3904-
new Translator(module, new_converter_flags, tags, op_or_arg_name_mapper,
3905-
metadata, custom_option_alignment));
3909+
translator = std::unique_ptr<Translator>(new Translator(
3910+
module, new_converter_flags, tags, op_or_arg_name_mapper, metadata,
3911+
custom_option_alignment, disable_buffer_deduping));
39063912
return translator->TranslateInternal();
39073913
}
39083914
return ret;
@@ -4495,7 +4501,7 @@ bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
44954501
auto maybe_translated = Translator::Translate(
44964502
module, options.converter_flags, options.saved_model_tags,
44974503
options.op_or_arg_name_mapper, options.metadata, serialize_stablehlo_ops,
4498-
options.custom_option_alignment);
4504+
options.custom_option_alignment, options.disable_buffer_deduping);
44994505
if (!maybe_translated) return false;
45004506
*serialized_flatbuffer = std::move(*maybe_translated);
45014507
return true;

tensorflow/compiler/mlir/lite/flatbuffer_export.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ struct FlatbufferExportOptions {
4848
// options. If specified, the value should be multiplier of 16 (default
4949
// alignment for TFL flatbuffer).
5050
std::optional<size_t> custom_option_alignment = std::nullopt;
51+
// Whether to disable buffer-deduping which emits tensors with shared
52+
// buffers.
53+
bool disable_buffer_deduping = false;
5154
};
5255

5356
// Translates the given MLIR `module` into a FlatBuffer and stores the

tensorflow/compiler/mlir/lite/flatbuffer_translate.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,12 @@ static opt<bool, true> serialize_debug_metadata_flag(
160160
llvm::cl::desc("Wether to serialize debug metadata or not"),
161161
llvm::cl::location(serialize_debug_metadata), llvm::cl::init(false));
162162

163+
// NOLINTNEXTLINE
164+
static opt<bool> disable_buffer_deduping_flag(
165+
"disable-buffer-deduping",
166+
llvm::cl::desc("Whether to disable buffer deduping or not"),
167+
llvm::cl::init(false));
168+
163169
namespace mlir {
164170
namespace {
165171
static OwningOpRef<mlir::ModuleOp> FlatBufferFileToMlirTrans(
@@ -212,6 +218,7 @@ static LogicalResult MlirToFlatBufferFileTranslateFunction(
212218
options.op_or_arg_name_mapper = op_or_arg_name_mapper.get();
213219
options.converter_flags.set_serialize_debug_metadata(
214220
serialize_debug_metadata);
221+
options.disable_buffer_deduping = disable_buffer_deduping_flag.getValue();
215222
if (!tflite::MlirToFlatBufferTranslateFunction(
216223
module, options, &serialized_flatbuffer, emit_stablehlo_ops))
217224
return mlir::failure();

tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/deduplicate_const.mlir

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
1+
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s --check-prefix=CHECK
2+
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer -disable-buffer-deduping %s -o - | flatbuffer_to_string - | FileCheck %s --check-prefix=NO_DEDUPE
23

34
module {
45
func.func @add(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> attributes {tf.entry_function = {inputs = "serving_default_x", outputs = "outputs"}} {
@@ -90,4 +91,101 @@ func.func @sub(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> attributes {tf.entry_f
9091
// CHECK-NEXT: }, {
9192
// CHECK-NEXT: data: [ 49, 46, 54, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
9293
// CHECK-NEXT: } ],
93-
// CHECK: }
94+
// CHECK: }
95+
96+
// NO_DEDUPE: {
97+
// NO_DEDUPE: version: 3,
98+
// NO_DEDUPE: operator_codes: [ {
99+
// NO_DEDUPE: version: 1
100+
// NO_DEDUPE: }, {
101+
// NO_DEDUPE: deprecated_builtin_code: 41,
102+
// NO_DEDUPE: version: 1,
103+
// NO_DEDUPE: builtin_code: SUB
104+
// NO_DEDUPE: } ],
105+
// NO_DEDUPE: subgraphs: [ {
106+
// NO_DEDUPE: tensors: [ {
107+
// NO_DEDUPE: shape: [ 3, 2 ],
108+
// NO_DEDUPE: buffer: 1,
109+
// NO_DEDUPE: name: "serving_default_x",
110+
// NO_DEDUPE: quantization: {
111+
// NO_DEDUPE: },
112+
// NO_DEDUPE: has_rank: true
113+
// NO_DEDUPE: }, {
114+
// NO_DEDUPE: shape: [ 3, 2 ],
115+
// NO_DEDUPE: buffer: 2,
116+
// NO_DEDUPE: name: "tfl.pseudo_const",
117+
// NO_DEDUPE: quantization: {
118+
// NO_DEDUPE: },
119+
// NO_DEDUPE: has_rank: true
120+
// NO_DEDUPE: }, {
121+
// NO_DEDUPE: shape: [ 3, 2 ],
122+
// NO_DEDUPE: buffer: 3,
123+
// NO_DEDUPE: name: "outputs",
124+
// NO_DEDUPE: quantization: {
125+
// NO_DEDUPE: },
126+
// NO_DEDUPE: has_rank: true
127+
// NO_DEDUPE: } ],
128+
// NO_DEDUPE: inputs: [ 0 ],
129+
// NO_DEDUPE: outputs: [ 2 ],
130+
// NO_DEDUPE: operators: [ {
131+
// NO_DEDUPE: inputs: [ 1, 0 ],
132+
// NO_DEDUPE: outputs: [ 2 ],
133+
// NO_DEDUPE: builtin_options_type: AddOptions,
134+
// NO_DEDUPE: builtin_options: {
135+
// NO_DEDUPE: }
136+
// NO_DEDUPE: } ],
137+
// NO_DEDUPE: name: "add"
138+
// NO_DEDUPE: }, {
139+
// NO_DEDUPE: tensors: [ {
140+
// NO_DEDUPE: shape: [ 3, 2 ],
141+
// NO_DEDUPE: buffer: 4,
142+
// NO_DEDUPE: name: "serving_default_x",
143+
// NO_DEDUPE: quantization: {
144+
// NO_DEDUPE: },
145+
// NO_DEDUPE: has_rank: true
146+
// NO_DEDUPE: }, {
147+
// NO_DEDUPE: shape: [ 3, 2 ],
148+
// NO_DEDUPE: buffer: 5,
149+
// NO_DEDUPE: name: "tfl.pseudo_const1",
150+
// NO_DEDUPE: quantization: {
151+
// NO_DEDUPE: },
152+
// NO_DEDUPE: has_rank: true
153+
// NO_DEDUPE: }, {
154+
// NO_DEDUPE: shape: [ 3, 2 ],
155+
// NO_DEDUPE: buffer: 6,
156+
// NO_DEDUPE: name: "outputs",
157+
// NO_DEDUPE: quantization: {
158+
// NO_DEDUPE: },
159+
// NO_DEDUPE: has_rank: true
160+
// NO_DEDUPE: } ],
161+
// NO_DEDUPE: inputs: [ 0 ],
162+
// NO_DEDUPE: outputs: [ 2 ],
163+
// NO_DEDUPE: operators: [ {
164+
// NO_DEDUPE: opcode_index: 1,
165+
// NO_DEDUPE: inputs: [ 1, 0 ],
166+
// NO_DEDUPE: outputs: [ 2 ],
167+
// NO_DEDUPE: builtin_options_type: SubOptions,
168+
// NO_DEDUPE: builtin_options: {
169+
// NO_DEDUPE: }
170+
// NO_DEDUPE: } ],
171+
// NO_DEDUPE: name: "sub"
172+
// NO_DEDUPE: } ],
173+
// NO_DEDUPE: description: "MLIR Converted.",
174+
// NO_DEDUPE: buffers: [ {
175+
// NO_DEDUPE: }, {
176+
// NO_DEDUPE: }, {
177+
// NO_DEDUPE: data: [ 0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64, 0, 0, 160, 64, 0, 0, 192, 64 ]
178+
// NO_DEDUPE: }, {
179+
// NO_DEDUPE: }, {
180+
// NO_DEDUPE: }, {
181+
// NO_DEDUPE: data: [ 0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64, 0, 0, 160, 64, 0, 0, 192, 64 ]
182+
// NO_DEDUPE: }, {
183+
// NO_DEDUPE: }, {
184+
// NO_DEDUPE: data: [ 49, 46, 54, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
185+
// NO_DEDUPE: } ],
186+
// NO_DEDUPE: metadata: [ {
187+
// NO_DEDUPE: name: "min_runtime_version",
188+
// NO_DEDUPE: buffer: 7
189+
// NO_DEDUPE: } ],
190+
// NO_DEDUPE: signature_defs: [ ]
191+
// NO_DEDUPE: }

0 commit comments

Comments
 (0)
0