Add KernelThunk deserialization · IBMZ-Linux-OSS-Python/tensorflow@0473289 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0473289

Browse files
beckerhe authored and tensorflower-gardener committed
Add KernelThunk deserialization
This adds `KernelThunk::FromProto` and also makes some adjustments to the `KernelThunkProto` message. By moving the `Dim3DProto` de-/serialization code to where `Dim3D` is defined, it will be easier to share that code with our thunks (like `CustomKernelThunk`, for example).

PiperOrigin-RevId: 766552605
1 parent b62cb1c commit 0473289

File tree

6 files changed

+109
-29
lines changed

6 files changed

+109
-29
lines changed

third_party/xla/xla/backends/gpu/runtime/BUILD

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -788,6 +788,7 @@ cc_library(
788788
hdrs = ["kernel_thunk.h"],
789789
deps = [
790790
":thunk",
791+
"//xla:shape_util",
791792
"//xla:types",
792793
"//xla/codegen/emitters:kernel_arguments",
793794
"//xla/hlo/ir:hlo",
@@ -1860,9 +1861,12 @@ tf_proto_library(
18601861
"thunk.proto",
18611862
],
18621863
protodeps = [
1864+
# keep sorted
1865+
"//xla:xla_data_proto",
18631866
"//xla/service:buffer_assignment_proto",
1867+
"//xla/service/gpu:launch_dimensions_proto",
1868+
"//xla/stream_executor:launch_dim_proto",
18641869
"//xla/stream_executor/gpu:gpu_blas_lt_proto",
1865-
"//xla:xla_data_proto",
18661870
],
18671871
)
18681872

@@ -1874,6 +1878,7 @@ cc_library(
18741878
":conditional_thunk",
18751879
":copy_thunk",
18761880
":gemm_thunk",
1881+
":kernel_thunk",
18771882
":sequential_thunk",
18781883
":thunk",
18791884
":thunk_proto_cc",

third_party/xla/xla/backends/gpu/runtime/kernel_thunk.cc

Lines changed: 36 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ limitations under the License.
2222
#include <string>
2323
#include <utility>
2424
#include <variant>
25+
#include <vector>
2526

2627
#include "absl/container/inlined_vector.h"
2728
#include "absl/log/check.h"
@@ -38,6 +39,7 @@ limitations under the License.
3839
#include "xla/service/gpu/kernels/custom_kernel.h"
3940
#include "xla/service/gpu/launch_dimensions.h"
4041
#include "xla/service/gpu/stream_executor_util.h"
42+
#include "xla/shape.h"
4143
#include "xla/stream_executor/device_memory.h"
4244
#include "xla/stream_executor/gpu/tma_metadata.h"
4345
#include "xla/stream_executor/kernel.h"
@@ -53,16 +55,6 @@ namespace gpu {
5355
// KernelThunk
5456
//===----------------------------------------------------------------------===//
5557

56-
namespace {
57-
Dim3DProto Dim3DToProto(const se::Dim3D& dim) {
58-
Dim3DProto proto;
59-
proto.set_x(dim.x);
60-
proto.set_y(dim.y);
61-
proto.set_z(dim.z);
62-
return proto;
63-
}
64-
} // namespace
65-
6658
KernelThunk::KernelThunk(
6759
Thunk::ThunkInfo thunk_info, std::string kernel_name,
6860
absl::Span<const emitters::KernelArgument> kernel_arguments,
@@ -102,17 +94,46 @@ absl::StatusOr<ThunkProto> KernelThunk::ToProto() const {
10294
kernel_proto->add_written(written);
10395
}
10496
kernel_proto->set_kernel_name(kernel_name_);
105-
*kernel_proto->mutable_launch_block_counts() =
106-
Dim3DToProto(launch_dimensions_.block_counts());
107-
*kernel_proto->mutable_launch_thread_counts_per_block() =
108-
Dim3DToProto(launch_dimensions_.thread_counts_per_block());
97+
*kernel_proto->mutable_launch_dimensions() = launch_dimensions_.ToProto();
10998
if (cluster_dim_) {
110-
*kernel_proto->mutable_cluster_dim() = Dim3DToProto(*cluster_dim_);
99+
*kernel_proto->mutable_cluster_dim() = cluster_dim_->ToProto();
111100
}
112101
kernel_proto->set_shmem_bytes(shmem_bytes_);
113102
return proto;
114103
}
115104

105+
absl::StatusOr<std::unique_ptr<KernelThunk>> KernelThunk::FromProto(
106+
ThunkInfo thunk_info, const KernelThunkProto& proto,
107+
absl::Span<const BufferAllocation> buffer_allocations) {
108+
TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
109+
LaunchDimensions::FromProto(proto.launch_dimensions()));
110+
std::optional<stream_executor::ClusterDim> cluster_dim;
111+
if (proto.has_cluster_dim()) {
112+
TF_ASSIGN_OR_RETURN(
113+
cluster_dim.emplace(),
114+
stream_executor::ClusterDim::FromProto(proto.cluster_dim()));
115+
}
116+
117+
if (proto.written().size() != proto.args().size()) {
118+
return absl::InvalidArgumentError(
119+
"Proto fields `written` and `args` need to have the same cardinality.");
120+
}
121+
122+
std::vector<emitters::KernelArgument> arguments;
123+
arguments.reserve(proto.args().size());
124+
for (int i = 0; i < proto.args().size(); ++i) {
125+
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
126+
BufferAllocation::Slice::FromProto(proto.args().at(i),
127+
buffer_allocations));
128+
bool written = proto.written().at(i);
129+
arguments.push_back(emitters::KernelArgument{Shape{}, slice, written});
130+
}
131+
132+
return std::make_unique<KernelThunk>(thunk_info, proto.kernel_name(),
133+
arguments, launch_dimensions,
134+
cluster_dim, proto.shmem_bytes());
135+
}
136+
116137
absl::Status KernelThunk::Initialize(const InitializeParams& params) {
117138
absl::MutexLock lock(&mutex_);
118139

third_party/xla/xla/backends/gpu/runtime/kernel_thunk.h

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -84,6 +84,9 @@ class KernelThunk : public Thunk {
8484
std::string ToString(int indent) const override;
8585

8686
absl::StatusOr<ThunkProto> ToProto() const override;
87+
static absl::StatusOr<std::unique_ptr<KernelThunk>> FromProto(
88+
ThunkInfo thunk_info, const KernelThunkProto& proto,
89+
absl::Span<const BufferAllocation> buffer_allocations);
8790

8891
absl::Status Initialize(const InitializeParams& params) override;
8992
absl::Status ExecuteOnStream(const ExecuteParams& params) override;

third_party/xla/xla/backends/gpu/runtime/kernel_thunk_test.cc

Lines changed: 53 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -15,10 +15,13 @@ limitations under the License.
1515

1616
#include "xla/backends/gpu/runtime/kernel_thunk.h"
1717

18+
#include <array>
19+
#include <memory>
1820
#include <optional>
1921
#include <string>
2022
#include <vector>
2123

24+
#include <gmock/gmock.h>
2225
#include <gtest/gtest.h>
2326
#include "absl/strings/string_view.h"
2427
#include "xla/backends/gpu/runtime/thunk.h"
@@ -138,13 +141,60 @@ TEST(KernelThunkTest, ToProto) {
138141
written: false
139142
written: true
140143
kernel_name: "kernel123"
141-
launch_block_counts { x: 32 y: 31 z: 30 }
142-
launch_thread_counts_per_block { x: 256 y: 255 z: 254 }
143-
cluster_dim { x: 8 y: 7 z: 6 }
144+
launch_dimensions {
145+
block_counts { coordinates { x: 32 y: 31 z: 30 } }
146+
thread_counts_per_block { coordinates { x: 256 y: 255 z: 254 } }
147+
}
148+
cluster_dim { coordinates { x: 8 y: 7 z: 6 } }
144149
shmem_bytes: 1024
145150
}
146151
)pb"));
147152
}
148153

154+
TEST(KernelThunkTest, ToAndFromProto) {
155+
Thunk::ThunkInfo thunk_info;
156+
thunk_info.profile_annotation = "DotGeneral";
157+
thunk_info.execution_stream_id = 123;
158+
159+
std::array allocations{
160+
BufferAllocation{/*index=*/0, /*size=*/1024, /*color=*/0},
161+
BufferAllocation{/*index=*/0, /*size=*/256, /*color=*/0}};
162+
163+
// Note that slices keep a pointer to the allocation. Therefore `allocations`
164+
// shouldn't be mutated afterwards.
165+
BufferAllocation::Slice slice0(&allocations.at(0), /*offset=*/0,
166+
/*size=*/1024);
167+
BufferAllocation::Slice slice1(&allocations.at(1), /*offset=*/0,
168+
/*size=*/256);
169+
170+
std::vector<emitters::KernelArgument> kernel_arguments = {
171+
emitters::KernelArgument(ShapeUtil::MakeShape(F32, {1024}), slice0,
172+
/*written=*/false),
173+
emitters::KernelArgument(ShapeUtil::MakeShape(F32, {256}), slice1,
174+
/*written=*/true)};
175+
176+
LaunchDimensions launch_dimensions(se::BlockDim(32, 31, 30),
177+
se::ThreadDim(256, 255, 254));
178+
se::ClusterDim cluster_dim(8, 7, 6);
179+
constexpr absl::string_view kKernelName = "kernel123";
180+
constexpr int kSharedMemoryBytes = 1024;
181+
KernelThunk thunk(thunk_info, std::string{kKernelName}, kernel_arguments,
182+
launch_dimensions, cluster_dim, kSharedMemoryBytes,
183+
/*tma_metadata=*/std::nullopt);
184+
TF_ASSERT_OK_AND_ASSIGN(ThunkProto proto, thunk.ToProto());
185+
ASSERT_TRUE(proto.has_kernel_thunk());
186+
TF_ASSERT_OK_AND_ASSIGN(
187+
std::unique_ptr<KernelThunk> reconstructed_thunk,
188+
KernelThunk::FromProto(thunk_info, proto.kernel_thunk(), allocations));
189+
190+
EXPECT_THAT(reconstructed_thunk->cluster_dim(), cluster_dim);
191+
EXPECT_THAT(reconstructed_thunk->kernel_name(), kKernelName);
192+
EXPECT_THAT(reconstructed_thunk->launch_dimensions(), launch_dimensions);
193+
EXPECT_THAT(reconstructed_thunk->shmem_bytes(), kSharedMemoryBytes);
194+
EXPECT_THAT(reconstructed_thunk->written(),
195+
::testing::ElementsAre(false, true));
196+
EXPECT_THAT(reconstructed_thunk->arguments(),
197+
::testing::ElementsAre(slice0, slice1));
198+
}
149199
} // namespace
150200
} // namespace xla::gpu

third_party/xla/xla/backends/gpu/runtime/thunk.proto

Lines changed: 5 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,9 @@ syntax = "proto3";
1818
package xla.gpu;
1919

2020
import "xla/service/buffer_assignment.proto";
21+
import "xla/service/gpu/launch_dimensions.proto";
2122
import "xla/stream_executor/gpu/gpu_blas_lt.proto";
23+
import "xla/stream_executor/launch_dim.proto";
2224
import "xla/xla_data.proto";
2325

2426
// Contains basic pieces of information that every thunk type has.
@@ -55,20 +57,13 @@ message WhileThunkProto {
5557
optional int64 trip_count = 4;
5658
}
5759

58-
message Dim3DProto {
59-
int64 x = 1;
60-
int64 y = 2;
61-
int64 z = 3;
62-
}
63-
6460
message KernelThunkProto {
6561
repeated xla.buffer_assignment.BufferAllocationSliceProto args = 1;
6662
repeated bool written = 2;
6763
string kernel_name = 3;
68-
Dim3DProto launch_block_counts = 4;
69-
Dim3DProto launch_thread_counts_per_block = 5;
70-
optional Dim3DProto cluster_dim = 6;
71-
int64 shmem_bytes = 7;
64+
LaunchDimensionsProto launch_dimensions = 4;
65+
optional stream_executor.ClusterDimProto cluster_dim = 5;
66+
int64 shmem_bytes = 6;
7267
}
7368

7469
message GemmThunkProto {

third_party/xla/xla/backends/gpu/runtime/thunk_proto_deserialization.cc

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@ limitations under the License.
2424
#include "xla/backends/gpu/runtime/conditional_thunk.h"
2525
#include "xla/backends/gpu/runtime/copy_thunk.h"
2626
#include "xla/backends/gpu/runtime/gemm_thunk.h"
27+
#include "xla/backends/gpu/runtime/kernel_thunk.h"
2728
#include "xla/backends/gpu/runtime/sequential_thunk.h"
2829
#include "xla/backends/gpu/runtime/thunk.h"
2930
#include "xla/backends/gpu/runtime/triangular_solve_thunk.h"
@@ -90,6 +91,11 @@ absl::StatusOr<std::unique_ptr<Thunk>> DeserializeThunkProto(
9091
thunk_proto.triangular_solve_thunk(),
9192
buffer_allocations);
9293
}
94+
95+
if (thunk_proto.has_kernel_thunk()) {
96+
return KernelThunk::FromProto(
97+
std::move(thunk_info), thunk_proto.kernel_thunk(), buffer_allocations);
98+
}
9399
return absl::InvalidArgumentError("Unknown thunk type found in ThunkProto.");
94100
}
95101

0 commit comments

Comments
 (0)
29CC
0