8000 Introduce `FunctionBody::Finalize()` to populate `AllocatorAttribute`… · linux-on-ibm-z/tensorflow@446fac2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 446fac2

Browse files
eunjaekim-0tensorflower-gardener
authored andcommitted
Introduce FunctionBody::Finalize() to populate AllocatorAttributes for arg/ret nodes and release unnecessary resources
PiperOrigin-RevId: 731143677
1 parent 58269e0 commit 446fac2

File tree

4 files changed

+283
-0
lines changed

4 files changed

+283
-0
lines changed

tensorflow/core/common_runtime/BUILD

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -911,10 +911,38 @@ cc_library(
911911
hdrs = ["function_body.h"],
912912
copts = tf_copts(),
913913
deps = [
914+
":arg_ret_placement",
914915
"//tensorflow/core:framework",
915916
"//tensorflow/core:graph",
916917
"//tensorflow/core:lib",
918+
"//tensorflow/core/platform:hash",
917919
"//tensorflow/core/platform:refcount",
920+
"@com_google_absl//absl/log:check",
921+
"@com_google_absl//absl/status",
922+
"@com_google_absl//absl/strings:string_view",
923+
"@local_xla//xla/tsl/platform:status",
924+
],
925+
)
926+
927+
tf_cc_test(
928+
name = "function_body_test",
929+
srcs = ["function_body_test.cc"],
930+
deps = [
931+
"//tensorflow/core:core_cpu_base",
932+
"//tensorflow/core:framework",
933+
"//tensorflow/core:ops",
934+
"//tensorflow/core/framework:full_type_proto_cc",
935+
"//tensorflow/core/framework:function_proto_cc",
936+
"//tensorflow/core/framework:function_testlib",
937+
"//tensorflow/core/framework:node_def_proto_cc",
938+
"//tensorflow/core/framework:types_proto_cc",
939+
"//tensorflow/core/platform:refcount",
940+
"@com_google_absl//absl/status",
941+
"@com_google_absl//absl/strings:string_view",
942+
"@com_google_absl//absl/types:span",
943+
"@co 8000 m_google_googletest//:gtest_main",
944+
"@local_xla//xla/tsl/lib/core:status_test_util",
945+
"@local_xla//xla/tsl/platform:status_matchers",
918946
],
919947
)
920948

tensorflow/core/common_runtime/function_body.cc

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,25 @@ limitations under the License.
1515

1616
#include "tensorflow/core/common_runtime/function_body.h"
1717

18+
#include <algorithm>
19+
#include <iterator>
20+
#include <unordered_set>
1821
#include <utility>
22+
#include <vector>
1923

24+
#include "absl/log/check.h"
25+
#include "absl/status/status.h"
26+
#include "ab 67ED sl/strings/string_view.h"
27+
#include "xla/tsl/platform/errors.h"
28+
#include "xla/tsl/platform/status.h"
29+
#include "tensorflow/core/common_runtime/arg_ret_placement.h"
30+
#include "tensorflow/core/framework/allocator.h"
2031
#include "tensorflow/core/framework/function.h"
2132
#include "tensorflow/core/framework/node_def_util.h"
2233
#include "tensorflow/core/framework/types.h"
2334
#include "tensorflow/core/graph/graph.h"
35+
#include "tensorflow/core/lib/gtl/inlined_vector.h"
36+
#include "tensorflow/core/platform/hash.h"
2437
#include "tensorflow/core/platform/refcount.h"
2538

2639
namespace tensorflow {
@@ -67,4 +80,39 @@ FunctionBody::FunctionBody(core::RefCountPtr<FunctionRecord>&& record,
6780

6881
FunctionBody::~FunctionBody() { delete this->graph; }
6982

83+
absl::Status FunctionBody::Finalize() {
84+
// Get the allocator attributes for the function body args and rets first to
85+
// avoid mutating the struct in case of an error.
86+
std::vector<AllocatorAttributes> args_alloc_attrs;
87+
std::vector<AllocatorAttributes> rets_alloc_attrs;
88+
TF_RETURN_IF_ERROR(full_type::SetAllocAttrsForArgs(
89+
this->arg_nodes, this->arg_types, args_alloc_attrs));
90+
TF_RETURN_IF_ERROR(full_type::SetAllocAttrsForRets(
91+
this->ret_nodes, this->ret_types, rets_alloc_attrs));
92+
// Move them to the struct.
93+
this->args_alloc_attrs.clear();
94+
this->rets_alloc_attrs.clear();
95+
std::move(args_alloc_attrs.begin(), args_alloc_attrs.end(),
96+
std::back_inserter(this->args_alloc_attrs));
97+
std::move(rets_alloc_attrs.begin(), rets_alloc_attrs.end(),
98+
std::back_inserter(this->rets_alloc_attrs));
99+
100+
// Unreference the function record.
101+
this->record.reset();
102+
103+
// Destruct the owned graph.
104+
if (this->graph != nullptr) {
105+
delete this->graph;
106+
this->graph = nullptr;
107+
}
108+
109+
// Clear the vectors holding the pointers to the nodes in the destructed
110+
// graph.
111+
this->arg_nodes.clear();
112+
this->ret_nodes.clear();
113+
this->control_ret_nodes.clear();
114+
115+
return absl::OkStatus();
116+
}
117+
70118
} // end namespace tensorflow

tensorflow/core/common_runtime/function_body.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ limitations under the License.
1616
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_BODY_H_
1717
#define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_BODY_H_
1818

19+
#include "absl/status/status.h"
20+
#include "tensorflow/core/framework/allocator.h"
1921
#include "tensorflow/core/framework/function.h"
2022
#include "tensorflow/core/framework/types.h"
2123
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -43,10 +45,22 @@ struct FunctionBody {
4345
absl::InlinedVector<Node*, 4UL> ret_nodes;
4446
absl::InlinedVector<Node*, 4UL> control_ret_nodes;
4547

48+
// Allocator attributes arg/ret nodes of the function body.
49+
absl::InlinedVector<AllocatorAttributes, 4UL> args_alloc_attrs;
50+
absl::InlinedVector<AllocatorAttributes, 4UL> rets_alloc_attrs;
51+
4652
FunctionBody() {}
4753
FunctionBody(core::RefCountPtr<FunctionRecord>&& record,
4854
DataTypeSlice arg_types, DataTypeSlice ret_types, Graph* g);
4955
~FunctionBody();
56+
57+
// Finalizes the function body by unreferencing the function record,
58+
// destructing the graph it own, and resetting the node pointers. It populates
59+
// the alloc attrs for the function body, so that
60+
// FunctionLibraryRuntime::RunRemote can use it to allocate tensors.
61+
//
62+
// Returns an error if the allocator attributes cannot be populated.
63+
absl::Status Finalize();
5064
};
5165

5266
} // end namespace tensorflow
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/core/common_runtime/function_body.h"
17+
18+
#include <string>
19+
#include <utility>
20+
21+
#include <gmock/gmock.h>
22+
#include <gtest/gtest.h>
23+
#include "absl/status/status.h"
24+
#include "absl/strings/string_view.h"
25+
#include "absl/types/span.h"
26+
#include "xla/tsl/lib/core/status_test_util.h"
27+
#include "xla/tsl/platform/status_matchers.h"
28+
#include "tensorflow/core/framework/allocator.h"
29+
#include "tensorflow/core/framework/full_type.pb.h"
30+
#include "tensorflow/core/framework/function.h"
31+
#include "tensorflow/core/framework/function.pb.h"
32+
#include "tensorflow/core/framework/function_testlib.h"
33+
#include "tensorflow/core/framework/node_def.pb.h"
34+
#include "tensorflow/core/framework/op.h"
35+
#include "tensorflow/core/framework/types.pb.h"
36+
#include "tensorflow/core/graph/graph.h"
37+
#include "tensorflow/core/platform/refcount.h"
38+
39+
namespace tensorflow {
40+
namespace {
41+
42+
using ::testing::IsEmpty;
43+
using ::testing::IsNull;
44+
using ::testing::Not;
45+
using ::testing::Pointee;
46+
using ::testing::Property;
47+
using ::testing::UnorderedElementsAre;
48+
using ::tsl::testing::StatusIs;
49+
50+
NodeDef GetNodeDef(
51+
absl::string_view name, absl::string_view op,
52+
absl::Span<const std::string> inputs,
53+
absl::Span<
54+
const std::pair<std::string, FunctionDefHelper::AttrValueWrapper>>
55+
attrs,
56+
bool set_full_type_def = false) {
57+
NodeDef node_def = test::function::NDef(name, op, inputs, attrs);
58+
if (!set_full_type_def) return node_def;
59+
60+
FullTypeDef& experiment_type = *node_def.mutable_experimental_type();
61+
experiment_type.set_type_id(TFT_PRODUCT);
62+
experiment_type.add_args()->set_type_id(TFT_SHAPE_TENSOR);
63+
return node_def;
64+
}
65+
66+
TEST(FunctionBodyTest, EmptyFunctionBody) {
67+
core::RefCountPtr<FunctionRecord> record(new FunctionRecord(
68+
FunctionDef(), /*stack_traces=*/{}, /*finalized=*/false));
69+
Graph* graph = new Graph(OpRegistry::Global());
70+
FunctionBody fbody(std::move(record), {}, {}, graph);
71+
72+
EXPECT_THAT(fbody.record, Not(IsNull()));
73+
EXPECT_THAT(fbody.graph, Not(IsNull()));
74+
EXPECT_THAT(fbody.arg_types, IsEmpty());
75+
EXPECT_THAT(fbody.ret_types, IsEmpty());
76+
EXPECT_THAT(fbody.arg_nodes, IsEmpty());
77+
EXPECT_THAT(fbody.ret_nodes, IsEmpty());
78+
EXPECT_THAT(fbody.control_ret_nodes, IsEmpty());
79+
EXPECT_THAT(fbody.args_alloc_attrs, IsEmpty());
80+
EXPECT_THAT(fbody.rets_alloc_attrs, IsEmpty());
81+
}
82+
83+
TEST(FunctionBodyTest, SimpleFunctionBody) {
84+
core::RefCountPtr<FunctionRecord> record(
85+
new FunctionRecord(test::function::XTimesTwoWithControlOutput(),
86+
/*stack_traces=*/{}, /*finalized=*/false));
87+
Graph* graph = new Graph(OpRegistry::Global());
88+
TF_ASSERT_OK(graph->AddNode(
89+
GetNodeDef("x", FunctionLibraryDefinition::kArgOp, /*inputs=*/{},
90+
/*attrs=*/{{"T", DT_INT32}, {"index", 0}})));
91+
TF_ASSERT_OK(graph->AddNode(
92+
GetNodeDef("y", FunctionLibraryDefinition::kRetOp, /*inputs=*/{},
93+
/*attrs=*/{{"T", DT_INT32}, {"index", 0}})));
94+
TF_ASSERT_OK(graph->AddNode(GetNodeDef("dummy", "Const", /*inputs=*/{},
95+
/*attrs=*/{{"dtype", DT_INT32}})));
96+
FunctionBody fbody(std::move(record), {DT_INT32}, {DT_INT32}, graph);
97+
98+
EXPECT_THAT(fbody.record, Not(IsNull()));
99+
EXPECT_THAT(fbody.graph, Not(IsNull()));
100+
EXPECT_THAT(fbody.arg_types, UnorderedElementsAre(DT_INT32));
101+
EXPECT_THAT(fbody.ret_types, UnorderedElementsAre(DT_INT32));
102+
EXPECT_THAT(fbody.arg_nodes,
103+
UnorderedElementsAre(Pointee(Property(
104+
&Node::type_string, FunctionLibraryDefinition::kArgOp))));
105+
EXPECT_THAT(fbody.ret_nodes,
106+
UnorderedElementsAre(Pointee(Property(
107+
&Node::type_string, FunctionLibraryDefinition::kRetOp))));
108+
EXPECT_THAT(fbody.control_ret_nodes,
109+
UnorderedElementsAre(Pointee(Property(&Node::name, "dummy"))));
110+
EXPECT_THAT(fbody.args_alloc_attrs, IsEmpty());
111+
EXPECT_THAT(fbody.rets_alloc_attrs, IsEmpty());
112+
}
113+
114+
TEST(FunctionBodyTest, FunctionBodyFinalized) {
115+
core::RefCountPtr<FunctionRecord> record(
116+
new FunctionRecord(test::function::XTimesTwoWithControlOutput(),
117+
/*stack_traces=*/{}, /*finalized=*/false));
118+
Graph* graph = new Graph(OpRegistry::Global());
119+
TF_ASSERT_OK(graph->AddNode(GetNodeDef(
120+
"x", FunctionLibraryDefinition::kArgOp, /*inputs=*/{},
121+
/*attrs=*/{{"T", DT_INT32}, {"index", 0}}, /*set_full_type_def=*/true)));
122+
TF_ASSERT_OK_AND_ASSIGN(
123+
Node * output_const,
124+
graph->AddNode(GetNodeDef("output_const", "Const", /*inputs=*/{},
125+
/*attrs=*/{{"dtype", DT_INT32}},
126+
/*set_full_type_def=*/true)));
127+
TF_ASSERT_OK_AND_ASSIGN(
128+
Node * ret_node,
129+
graph->AddNode(GetNodeDef("y", FunctionLibraryDefinition::kRetOp,
130+
/*inputs=*/{"output_const"},
131+
/*attrs=*/{{"T", DT_INT32}, {"index", 0}})));
132+
graph->AddEdge(output_const, 0, ret_node, 0);
133+
TF_ASSERT_OK(graph->AddNode(GetNodeDef("dummy", "Const", /*inputs=*/{},
134+
/*attrs=*/{{"dtype", DT_INT32}})));
135+
FunctionBody fbody(std::move(record), {DT_INT32}, {DT_INT32}, graph);
136+
137+
// Finalize the function body.
138+
TF_EXPECT_OK(fbody.Finalize());
139+
140+
// Check the function body properties after finalization.
141+
EXPECT_THAT(fbody.record, IsNull());
142+
EXPECT_THAT(fbody.graph, IsNull());
143+
EXPECT_THAT(fbody.arg_types, UnorderedElementsAre(DT_INT32));
144+
EXPECT_THAT(fbody.ret_types, UnorderedElementsAre(DT_INT32));
145+
EXPECT_THAT(fbody.arg_nodes, IsEmpty());
146+
EXPECT_THAT(fbody.ret_nodes, IsEmpty());
147+
EXPECT_THAT(fbody.control_ret_nodes, IsEmpty());
148+
EXPECT_THAT(
149+
fbody.args_alloc_attrs,
150+
UnorderedElementsAre(Property(&AllocatorAttributes::on_host, true)));
151+
EXPECT_THAT(
152+
fbody.rets_alloc_attrs,
153+
UnorderedElementsAre(Property(&AllocatorAttributes::on_host, true)));
154+
}
155+
156+
TEST(FunctionBodyTest, FunctionBodyNotUpdatedWithFinalizationFailure) {
157+
core::RefCountPtr<FunctionRecord> record(
158+
new FunctionRecord(test::function::XTimesTwoWithControlOutput(),
159+
/*stack_traces=*/{}, /*finalized=*/false));
160+
Graph* graph = new Graph(OpRegistry::Global());
161+
TF_ASSERT_OK(graph->AddNode(GetNodeDef(
162+
"x", FunctionLibraryDefinition::kArgOp, /*inputs=*/{},
163+
/*attrs=*/{{"T", DT_INT32}, {"index", 0}}, /*set_full_type_def=*/true)));
164+
TF_ASSERT_OK(
165+
graph->AddNode(GetNodeDef("y", FunctionLibraryDefinition::kRetOp,
166+
/*inputs=*/{"output_const"},
167+
/*attrs=*/{{"T", DT_INT32}, {"index", 0}})));
168+
TF_ASSERT_OK(graph->AddNode(GetNodeDef("dummy", "Const", /*inputs=*/{},
169+
/*attrs=*/{{"dtype", DT_INT32}})));
170+
FunctionBody fbody(std::move(record), {DT_INT32}, {DT_INT32}, graph);
171+
172+
// Finalization fails due to missing input to the ret node.
173+
EXPECT_THAT(fbody.Finalize(), Not(StatusIs(absl::StatusCode::kOk)));
174+
175+
// Check the function body properties after finalization.
176+
EXPECT_THAT(fbody.record, Not(IsNull()));
177+
EXPECT_THAT(fbody.graph, Not(IsNull()));
178+
EXPECT_THAT(fbody.arg_types, UnorderedElementsAre(DT_INT32));
179+
EXPECT_THAT(fbody.ret_types, UnorderedElementsAre(DT_INT32));
180+
EXPECT_THAT(fbody.arg_nodes,
181+
UnorderedElementsAre(Pointee(Property(
182+
&Node::type_string, FunctionLibraryDefinition::kArgOp))));
183+
EXPECT_THAT(fbody.ret_nodes,
184+
UnorderedElementsAre(Pointee(Property(
185+
&Node::type_string, FunctionLibraryDefinition::kRetOp))));
186+
EXPECT_THAT(fbody.control_ret_nodes,
187+
UnorderedElementsAre(Pointee(Property(&Node::name, "dummy"))));
188+
EXPECT_THAT(fbody.args_alloc_attrs, IsEmpty());
189+
EXPECT_THAT(fbody.rets_alloc_attrs, IsEmpty());
190+
}
191+
192+
} // namespace
193+
} // namespace tensorflow

0 commit comments

Comments
 (0)
0