8000 [tsl:concurrency] Fix warnings in async_value.{h,cc} by copybara-service[bot] · Pull Request #93585 · tensorflow/tensorflow · GitHub
[go: up one dir, main page]

Skip to content

[tsl:concurrency] Fix warnings in async_value.{h,cc} #93585

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tensorflow/compiler/tf2xla/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ filegroup(
"@local_xla//xla/tsl/framework:xla_cpu_runtime_hdrs",
"@local_xla//xla/tsl/framework/fixedpoint:xla_cpu_runtime_hdrs",
"@local_xla//xla/tsl/lib/math:xla_cpu_runtime_hdrs",
"@local_xla//xla/tsl/util:xla_cpu_runtime_hdrs",
],
visibility = [
"//tensorflow/tools/pip_package:__pkg__",
Expand Down
18 changes: 13 additions & 5 deletions third_party/xla/xla/tsl/concurrency/async_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ void IndirectAsyncValue::ForwardTo(RCReference<AsyncValue> value) {
//===----------------------------------------------------------------------===//

void BlockUntilReady(AsyncValue* async_value) {
if (ABSL_PREDICT_TRUE(async_value->IsAvailable())) return;
if (ABSL_PREDICT_TRUE(async_value->IsAvailable())) {
return;
}

absl::Notification notification;
async_value->AndThen([&notification] { notification.Notify(); });
Expand All @@ -169,12 +171,16 @@ void RunWhenReady(absl::Span<AsyncValue* const> values,
// Perform a quick scan of the arguments. If they are all available,
// then we can run the callee synchronously.
absl::InlinedVector<AsyncValue*, 4> unavailable_values;
for (auto i : values) {
if (!i->IsAvailable()) unavailable_values.push_back(i);
for (AsyncValue* value : values) {
if (!value->IsAvailable()) {
unavailable_values.push_back(value);
}
}

// If we can synchronously call 'callee', then do it and we're done.
if (unavailable_values.empty()) return std::move(callee)();
if (unavailable_values.empty()) {
return std::move(callee)();
}

// If there is exactly one unavailable value, then we can just AndThen it.
if (unavailable_values.size() == 1) {
Expand All @@ -196,7 +202,9 @@ void RunWhenReady(absl::Span<AsyncValue* const> values,
for (auto* val : unavailable_values) {
val->AndThen([data]() {
// Decrement the counter unless we're the last to be here.
if (data->counter.fetch_sub(1) != 1) return;
if (data->counter.fetch_sub(1) != 1) {
return;
}

// If we are the last one, then run the callee and free the data.
std::move(data->callee)();
Expand Down
54 changes: 25 additions & 29 deletions third_party/xla/xla/tsl/concurrency/async_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,7 @@ class AsyncValue {
// for the base type, which we do not have sufficient information to perform
// at runtime.
template <typename T>
const T& get() const;

// Same as the const overload of get(), except for returning a non-const ref.
template <typename T>
T& get();
T& get() const;

// Returns the underlying error. IsError() must be true.
const absl::Status& GetError() const;
Expand Down Expand Up @@ -291,8 +287,9 @@ class AsyncValue {
is_refcounted_(is_refcounted),
type_id_(GetTypeId<T>()),
waiters_and_state_(WaitersAndState(nullptr, state)) {
if (AsyncValueAllocationTrackingEnabled() && is_refcounted)
if (AsyncValueAllocationTrackingEnabled() && is_refcounted) {
total_allocated_async_values_.fetch_add(1, std::memory_order_relaxed);
}
}

AsyncValue(Kind kind, State state, bool is_refcounted)
Expand All @@ -302,8 +299,9 @@ class AsyncValue {
is_refcounted_(is_refcounted),
type_id_(kUnknownTypeId),
waiters_and_state_(WaitersAndState(nullptr, state)) {
if (AsyncValueAllocationTrackingEnabled() && is_refcounted)
if (AsyncValueAllocationTrackingEnabled() && is_refcounted) {
total_allocated_async_values_.fetch_add(1, std::memory_order_relaxed);
}
}

AsyncValue(const AsyncValue&) = delete;
Expand Down Expand Up @@ -463,7 +461,7 @@ class AsyncValue {
static uint16_t CreateTypeInfoAndReturnTypeIdImpl(const TypeInfo& type_info);

template <typename T>
const T& GetConcreteValue() const;
T& GetConcreteValue() const;

// Returns the TypeInfoTable instance (there is one per process).
using TypeInfoTable = internal::ConcurrentVector<TypeInfo>;
Expand Down Expand Up @@ -615,12 +613,7 @@ class ConcreteAsyncValue : public AsyncValue {
NotifyAvailable(State::kError);
}

const T& get() const {
DCHECK(HasData());
return data_store_.data();
}

T& get() {
T& get() const {
DCHECK(HasData());
return data_store_.data();
}
Expand Down Expand Up @@ -708,7 +701,9 @@ class ConcreteAsyncValue : public AsyncValue {
}

void Destroy(State s) {
if (HasData()) data().~T();
if (HasData()) {
data().~T();
}
error_.reset();
has_data_ = false;
}
Expand All @@ -725,8 +720,9 @@ class ConcreteAsyncValue : public AsyncValue {
has_data_ = true;
}

T& data() { return *reinterpret_cast<T*>(&data_); }
const T& data() const { return *reinterpret_cast<const T*>(&data_); }
T& data() { return *reinterpret_cast<T*>(data_); }

const T& data() const { return *reinterpret_cast<const T*>(data_); }

bool HasData(State s) const { return has_data_; }
bool HasData() const { return has_data_; }
Expand Down Expand Up @@ -845,8 +841,9 @@ class TypedIndirectAsyncValue : public IndirectAsyncValue {
inline AsyncValue::~AsyncValue() {
DCHECK_EQ(waiters_and_state_.load().waiter(), nullptr)
<< "An async value with waiters should never have refcount of zero";
if (AsyncValueAllocationTrackingEnabled() && is_refcounted_)
if (AsyncValueAllocationTrackingEnabled() && is_refcounted_) {
total_allocated_async_values_.fetch_sub(1, std::memory_order_relaxed);
}

// Catch use-after-free errors more eagerly, by triggering the size assertion
// in the 'get' accessor.
Expand Down Expand Up @@ -927,18 +924,18 @@ inline void AsyncValue::DropRef(uint32_t count) {
}

template <typename T>
const T& AsyncValue::GetConcreteValue() const {
T& AsyncValue::GetConcreteValue() const {
// Make sure both T (the stored type) and BaseT have vtable_ptr or
// neither have the vtable_ptr.
DCHECK_EQ(std::is_polymorphic<T>::value, has_vtable_);
DCHECK(IsTypeIdCompatible<T>()) << "Incorrect accessor";

const char* this_ptr = reinterpret_cast<const char*>(this);
return *reinterpret_cast<const T*>(this_ptr + AsyncValue::kDataOffset);
uintptr_t base = reinterpret_cast<uintptr_t>(this);
return *reinterpret_cast<T*>(base + AsyncValue::kDataOffset);
}

template <typename T>
const T& AsyncValue::get() const {
T& AsyncValue::get() const {
auto s = state();
(void)s;

Expand Down Expand Up @@ -968,11 +965,6 @@ const T& AsyncValue::get() const {
}
}

template <typename T>
T& AsyncValue::get() {
return const_cast<T&>(static_cast<const AsyncValue*>(this)->get<T>());
}

inline void AsyncValue::SetStateConcrete() {
DCHECK(IsConstructed() && kind() == Kind::kConcrete);
NotifyAvailable(State::kConcrete);
Expand All @@ -991,13 +983,17 @@ void AsyncValue::emplace(Args&&... args) {
inline const absl::Status* AsyncValue::GetErrorIfPresent() const {
switch (kind()) {
case Kind::kConcrete: {
if (state() != State::kError) return nullptr;
if (state() != State::kError) {
return nullptr;
}
return &GetTypeInfo().get_error(this);
}
case Kind::kIndirect: {
auto* iv_value = static_cast<const IndirectAsyncValue*>(this)->value_;
// Unresolved IndirectAsyncValues are not errors.
if (!iv_value) return nullptr;
if (!iv_value) {
return nullptr;
}

DCHECK(iv_value->kind() != Kind::kIndirect);
return iv_value->GetErrorIfPresent();
Expand Down
19 changes: 7 additions & 12 deletions third_party/xla/xla/tsl/concurrency/async_value_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ limitations under the License.

#include "xla/tsl/concurrency/async_value.h"

#include <cstddef>
#include <cstdint>
#include <memory>
#include <utility>
Expand Down Expand Up @@ -197,20 +196,16 @@ TEST(AsyncValueTest, MoveOnlyCallback) {
//===----------------------------------------------------------------------===//

static void BM_AddAndThenCallback(benchmark::State& state) {
size_t n = 0;

auto ref = MakeConstructedAsyncValueRef<int32_t>(42);
for (auto _ : state) {
// Reset AsyncValue to avoid keeping enqueued callbacks alive.
if (++n % 1024 == 0) {
ref.SetStateConcrete();
ref = MakeConstructedAsyncValueRef<int32_t>(42);
}
internal::AsyncValueStorage<int32_t> storage;

ref.AndThen([] {});
}
AsyncValueOwningRef<int32_t> owner =
MakeConstructedAsyncValueRef<int32_t>(storage, 42);
AsyncValuePtr<int32_t> ptr = owner.AsPtr();

ref.SetStateConcrete();
ptr.AndThen([] {});
ptr.SetStateConcrete();
}
}

BENCHMARK(BM_AddAndThenCallback);
Expand Down
0