Add traceme for BlockHostUntilDone to identify the time consumption. · IBMZ-Linux-OSS-Python/tensorflow@5c31bcf · GitHub
[go: up one dir, main page]

Skip to content

Commit 5c31bcf

Browse files
hhb authored and tensorflower-gardener committed
Add traceme for BlockHostUntilDone to identify the time consumption.
PiperOrigin-RevId: 766382982
1 parent 7f32242 commit 5c31bcf

File tree

1 file changed

+48
-12
lines changed

1 file changed

+48
-12
lines changed

third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc

Lines changed: 48 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -487,7 +487,11 @@ class TfrtGpuAsyncHostToDeviceTransferManager final
487487
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
488488
stream, literal, shaped_buffer));
489489

490-
absl::Status status = stream->BlockHostUntilDone();
490+
absl::Status status;
491+
{
492+
tsl::profiler::TraceMe traceme("BlockHostUntilDone");
493+
status = stream->BlockHostUntilDone();
494+
}
491495
VLOG(3) << "Finish transfer h2d for literal with shape "
492496
<< literal.shape().ToString() << " on device "
493497
<< device_->DebugString() << " with status " << status;
@@ -596,7 +600,11 @@ class TfrtGpuAsyncHostToDeviceTransferManager final
596600
TF_CHECK_OK(stream->Memcpy(&sub_buffer, host_data_ptr, transfer_size))
597601
<< "Failed to copy data to GPU";
598602

599-
absl::Status status = stream->BlockHostUntilDone();
603+
absl::Status status;
604+
{
605+
tsl::profiler::TraceMe traceme("BlockHostUntilDone");
606+
status = stream->BlockHostUntilDone();
607+
}
600608
VLOG(3) << "H2D copy done: " << status;
601609
CHECK_OK(status) << "Failed to block host until done";
602610
}
@@ -907,12 +915,15 @@ SendDeviceMemoryFunction ConvertSendCallbacksToSendFunction(
907915
}
908916

909917
// Wait for the data to be available on the host.
910-
absl::Status st = stream->BlockHostUntilDone();
918+
{
919+
tsl::profiler::TraceMe traceme("BlockHostUntilDone");
920+
status = stream->BlockHostUntilDone();
921+
}
911922
VLOG(3) << "D2H copy done. " << status;
912-
if (!st.ok()) {
923+
if (!status.ok()) {
913924
done_event.SetError(absl::InternalError(absl::StrFormat(
914925
"failed to synchronize send operation with a stream: %s",
915-
st.message())));
926+
status.message())));
916927
return;
917928
}
918929

@@ -2063,7 +2074,10 @@ absl::StatusOr<std::unique_ptr<PjRtBuffer>> TfrtGpuClient::BufferFromHostBuffer(
20632074
dst_definition_event.SetError(status);
20642075
return;
20652076
}
2066-
status = stream->BlockHostUntilDone();
2077+
{
2078+
tsl::profiler::TraceMe traceme("BlockHostUntilDone");
2079+
status = stream->BlockHostUntilDone();
2080+
}
20672081
VLOG(3) << "H2D copy done. " << status;
20682082
if (status.ok()) {
20692083
copy_event.SetStateConcrete();
@@ -2177,7 +2191,11 @@ TfrtGpuClient::BufferFromHostLiteral(const LiteralSlice& literal,
21772191
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
21782192
stream, literal, shaped_buffer));
21792193

2180-
auto status = stream->BlockHostUntilDone();
2194+
absl::Status status;
2195+
{
2196+
tsl::profiler::TraceMe traceme("BlockHostUntilDone");
2197+
status = stream->BlockHostUntilDone();
2198+
}
21812199
CHECK_OK(status) << "Failed to block host until done";
21822200
VLOG(3) << "BufferFromHostLiteral done for device_buffer: "
21832201
<< device_buffer << " AsyncValue: " << av.get();
@@ -2592,7 +2610,10 @@ absl::StatusOr<Shape> TfrtGpuBuffer::logical_on_device_shape() {
25922610
auto stream = device_->stream();
25932611
TF_RETURN_IF_ERROR(
25942612
transfer_manager->ReadDynamicShapes(stream, &shaped_buffer, &ret_shape));
2595-
TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
2613+
{
2614+
tsl::profiler::TraceMe traceme("BlockHostUntilDone");
2615+
TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
2616+
}
25962617
return ret_shape;
25972618
}
25982619

@@ -2870,7 +2891,11 @@ PjRtFuture<> TfrtGpuBuffer::ToLiteral(MutableLiteralBase* literal) {
28702891
byte_size))
28712892
<< "stream->Memcpy failed copying from GPU to host";
28722893

2873-
absl::Status status = stream->BlockHostUntilDone();
2894+
absl::Status status;
2895+
{
2896+
tsl::profiler::TraceMe traceme("BlockHostUntilDone");
2897+
status = stream->BlockHostUntilDone();
2898+
}
28742899
VLOG(3) << "D2H copy done. " << status;
28752900
if (!status.ok()) {
28762901
VLOG(3) << "stream->BlockHostUntilDone failed: " << status;
@@ -3010,7 +3035,11 @@ PjRtFuture<> TfrtGpuBuffer::CopyRawToHostFuture(PjRtFuture<void*> dst,
30103035
<< host_ptr << " (" << transfer_size << " bytes)";
30113036
CHECK_OK(stream->Memcpy(host_ptr, *sub_buffer, transfer_size))
30123037
<< "stream->Memcpy failed copying from GPU to host";
3013-
absl::Status status = stream->BlockHostUntilDone();
3038+
absl::Status status;
3039+
{
3040+
tsl::profiler::TraceMe traceme("BlockHostUntilDone");
3041+
status = stream->BlockHostUntilDone();
3042+
}
30143043
VLOG(3) << "D2H copy done. " << status;
30153044
if (!status.ok()) {
30163045
LOG(ERROR) << "stream->BlockHostUntilDone failed: " << status;
@@ -3204,7 +3233,10 @@ absl::StatusOr<std::unique_ptr<PjRtBuffer>> TfrtGpuBuffer::CopyToMemorySpace(
32043233
dst_definition_event.SetError(status);
32053234
return;
32063235
}
3207-
status = stream->BlockHostUntilDone();
3236+
{
3237+
tsl::profiler::TraceMe traceme("BlockHostUntilDone");
3238+
status = stream->BlockHostUntilDone();
3239+
}
32083240
if (status.ok()) {
32093241
VLOG(3) << "D2D copy done. dst: " << dst.opaque();
32103242
dst_definition_event.SetStateConcrete();
@@ -3788,7 +3820,11 @@ absl::StatusOr<PjRtLoadedExecutable::Result> TfrtGpuExecutable::ExecuteHelper(
37883820
// has completed, so that the next execute_fn can start.
37893821
scheduled_event.SetStateConcrete();
37903822

3791-
absl::Status status = stream->BlockHostUntilDone();
3823+
absl::Status status;
3824+
{
3825+
tsl::profiler::TraceMe traceme("BlockHostUntilDone");
3826+
status = stream->BlockHostUntilDone();
3827+
}
37923828
if (!status.ok()) {
37933829
LOG(ERROR) << "BlockHostUntilDone failed for executable "
37943830
<< executable_name << " on device "

0 commit comments

Comments
 (0)
0