Reverts f83061d1a8166a1ee80944a1d963855b093153d5 · tensorflow/tensorflow@d3de32a · GitHub
[go: up one dir, main page]

Skip to content

Commit d3de32a

Browse files
Reverts f83061d
PiperOrigin-RevId: 759907649
1 parent 7b3c31b commit d3de32a

File tree

7 files changed

+71
-61
lines changed

7 files changed

+71
-61
lines changed

tensorflow/python/compat/compat.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@
2929
# This value changes every day with an automatic CL. It can be modified in code
3030
# via `forward_compatibility_horizon()` or with the environment variable
3131
# TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
32-
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 5, 16)
32+
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 5, 17)
3333
_FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
3434
_FORWARD_COMPATIBILITY_DATE_NUMBER = None
3535

third_party/xla/xla/hlo/ir/collective_device_list.h

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,6 @@ limitations under the License.
2020
#include <memory>
2121
#include <optional>
2222
#include <string>
23-
#include <utility>
2423
#include <vector>
2524

2625
#include "absl/types/span.h"
@@ -98,10 +97,6 @@ class CollectiveDeviceList {
9897
explicit CollectiveDeviceList()
9998
: replica_groups_(std::make_shared<std::vector<ReplicaGroup>>()) {};
10099

101-
explicit CollectiveDeviceList(std::vector<ReplicaGroup> replica_groups)
102-
: replica_groups_(std::make_shared<std::vector<ReplicaGroup>>(
103-
std::move(replica_groups))) {};
104-
105100
explicit CollectiveDeviceList(absl::Span<const ReplicaGroup> replica_groups)
106101
: replica_groups_(std::make_shared<std::vector<ReplicaGroup>>(
107102
replica_groups.begin(), replica_groups.end())) {};

third_party/xla/xla/service/collective_ops_utils.cc

Lines changed: 44 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -379,26 +379,43 @@ GetParticipatingDevicesGroups(const HloInstruction* collective) {
379379
device_assignment, GetCollectiveReplicaGroups(collective), mode);
380380
}
381381

382-
absl::StatusOr<CollectiveDeviceList> GetParticipatingFlattenedIdGroups(
382+
absl::StatusOr<std::vector<ReplicaGroup>> GetParticipatingFlattenedIdGroups(
383383
const DeviceAssignment& device_assignment,
384-
const CollectiveDeviceList& collective_device_list,
384+
absl::Span<const ReplicaGroup> replica_groups,
385385
CollectiveOpGroupMode group_mode) {
386-
return GetParticipatingFlattenedIdGroups(
387-
collective_device_list, group_mode, device_assignment.replica_count(),
388-
device_assignment.computation_count());
386+
// Compute the device_id to flattened_id mapping once to avoid brute force
387+
// searching through device assignment repeatedly.
388+
absl::flat_hash_map<GlobalDeviceId, int64_t> device_id_to_flattened_id;
389+
for (int r = 0; r < device_assignment.replica_count(); ++r) {
390+
for (int c = 0; c < device_assignment.computation_count(); ++c) {
391+
GlobalDeviceId device_id = GlobalDeviceId(device_assignment(r, c));
392+
int64_t flattened_id = r * device_assignment.computation_count() + c;
393+
device_id_to_flattened_id[device_id] = flattened_id;
394+
}
395+
}
396+
397+
std::vector<ReplicaGroup> flattened_id_groups;
398+
TF_ASSIGN_OR_RETURN(std::vector<std::vector<GlobalDeviceId>> device_groups,
399+
GetParticipatingDevicesGroups(
400+
device_assignment, replica_groups, group_mode));
401+
for (const auto& device_group : device_groups) {
402+
ReplicaGroup flattened_id_group;
403+
flattened_id_group.mutable_replica_ids()->Reserve(device_group.size());
404+
for (const GlobalDeviceId& device_id : device_group) {
405+
flattened_id_group.add_replica_ids(device_id_to_flattened_id[device_id]);
406+
}
407+
flattened_id_groups.push_back(flattened_id_group);
408+
}
409+
return flattened_id_groups;
389410
}
390411

391-
absl::StatusOr<CollectiveDeviceList> GetParticipatingFlattenedIdGroups(
392-
const CollectiveDeviceList& collective_device_list,
412+
absl::StatusOr<std::vector<ReplicaGroup>> GetParticipatingFlattenedIdGroups(
413+
absl::Span<const ReplicaGroup> replica_groups,
393414
CollectiveOpGroupMode group_mode, int replica_count, int partition_count) {
394-
if (group_mode == CollectiveOpGroupMode::kFlattenedID) {
395-
return collective_device_list;
396-
}
397415
std::vector<ReplicaGroup> filled_empty_replica_group;
398-
absl::Span<const ReplicaGroup> original_replica_groups =
399-
collective_device_list.replica_groups();
416+
absl::Span<const ReplicaGroup> original_replica_groups = replica_groups;
400417
std::vector<ReplicaGroup> flattened_replica_groups;
401-
if (collective_device_list.replica_groups().empty()) {
418+
if (replica_groups.empty()) {
402419
filled_empty_replica_group.emplace_back();
403420
const int64_t id_count =
404421
group_mode == CollectiveOpGroupMode::kCrossPartition ? partition_count
@@ -408,7 +425,11 @@ absl::StatusOr<CollectiveDeviceList> GetParticipatingFlattenedIdGroups(
408425
}
409426
original_replica_groups = filled_empty_replica_group;
410427
}
411-
if (group_mode == CollectiveOpGroupMode::kCrossReplica) {
428+
if (group_mode == CollectiveOpGroupMode::kFlattenedID) {
429+
flattened_replica_groups.insert(flattened_replica_groups.end(),
430+
original_replica_groups.begin(),
431+
original_replica_groups.end());
432+
} else if (group_mode == CollectiveOpGroupMode::kCrossReplica) {
412433
flattened_replica_groups.resize(original_replica_groups.size() *
413434
partition_count);
414435
for (int64_t i = 0, current_group_offset = 0;
@@ -453,30 +474,30 @@ absl::StatusOr<CollectiveDeviceList> GetParticipatingFlattenedIdGroups(
453474
}
454475
}
455476
}
456-
return CollectiveDeviceList(flattened_replica_groups);
477+
return flattened_replica_groups;
457478
}
458479

459-
absl::StatusOr<CollectiveDeviceList> GetParticipatingFlattenedIdGroups(
480+
absl::StatusOr<std::vector<ReplicaGroup>> GetParticipatingFlattenedIdGroups(
460481
const HloInstruction* hlo, const DeviceAssignment& device_assignment) {
461482
TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode mode,
462483
GetCollectiveOpGroupMode(hlo));
463484
TF_ASSIGN_OR_RETURN(
464-
CollectiveDeviceList collective_device_list,
485+
std::vector<ReplicaGroup> replica_groups,
465486
GetParticipatingFlattenedIdGroups(device_assignment,
466-
GetCollectiveDeviceList(hlo), mode));
467-
return collective_device_list;
487+
GetCollectiveReplicaGroups(hlo), mode));
488+
return replica_groups;
468489
}
469490

470491
// Same as above, used for cases where static_device_assignment is not present.
471-
absl::StatusOr<CollectiveDeviceList> GetParticipatingFlattenedIdGroups(
492+
absl::StatusOr<std::vector<ReplicaGroup>> GetParticipatingFlattenedIdGroups(
472493
const HloInstruction* hlo, int replica_count, int partition_count) {
473494
TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode mode,
474495
GetCollectiveOpGroupMode(hlo));
475496
TF_ASSIGN_OR_RETURN(
476-
CollectiveDeviceList collective_device_list,
477-
GetParticipatingFlattenedIdGroups(GetCollectiveDeviceList(hlo), mode,
497+
std::vector<ReplicaGroup> replica_groups,
498+
GetParticipatingFlattenedIdGroups(GetCollectiveReplicaGroups(hlo), mode,
478499
replica_count, partition_count));
479-
return collective_device_list;
500+
return replica_groups;
480501
}
481502

482503
absl::StatusOr<std::vector<GlobalDeviceId>> GetParticipatingDevices(

third_party/xla/xla/service/collective_ops_utils.h

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -187,23 +187,23 @@ GetParticipatingDevicesGroups(const HloInstruction* collective);
187187

188188
// Same as above, except that it returns the flattened id in the replica groups
189189
// instead of device id.
190-
absl::StatusOr<CollectiveDeviceList> GetParticipatingFlattenedIdGroups(
190+
absl::StatusOr<std::vector<ReplicaGroup>> GetParticipatingFlattenedIdGroups(
191191
const DeviceAssignment& device_assignment,
192-
const CollectiveDeviceList& collective_device_list,
192+
absl::Span<const ReplicaGroup> replica_groups,
193193
CollectiveOpGroupMode group_mode);
194194

195195
// Same as above, but take replica/partition count instead of device assignment.
196-
absl::StatusOr<CollectiveDeviceList> GetParticipatingFlattenedIdGroups(
197-
const CollectiveDeviceList& collective_device_list,
196+
absl::StatusOr<std::vector<ReplicaGroup>> GetParticipatingFlattenedIdGroups(
197+
absl::Span<const ReplicaGroup> replica_groups,
198198
CollectiveOpGroupMode group_mode, int replica_count, int partition_count);
199199

200200
// Same as above, with collective group mode determined by the collective
201201
// instruction.
202-
absl::StatusOr<CollectiveDeviceList> GetParticipatingFlattenedIdGroups(
202+
absl::StatusOr<std::vector<ReplicaGroup>> GetParticipatingFlattenedIdGroups(
203203
const HloInstruction* hlo, const DeviceAssignment& device_assignment);
204204

205205
// Same as above, used for cases where static_device_assignment is not present.
206-
absl::StatusOr<CollectiveDeviceList> GetParticipatingFlattenedIdGroups(
206+
absl::StatusOr<std::vector<ReplicaGroup>> GetParticipatingFlattenedIdGroups(
207207
const HloInstruction* hlo, int replica_count, int partition_count);
208208

209209
// Figures out which devices are participating in the collective subgroup.

third_party/xla/xla/service/collective_ops_utils_test.cc

Lines changed: 7 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -160,8 +160,7 @@ TEST(CollectiveOpsUtilsTest, CollectiveWithChannelId2) {
160160
HloInstruction *instr =
161161
builder.AddInstruction(HloInstruction::CreateAllGather(
162162
ShapeUtil::MakeShape(BF16, {1, 4096, 4096}), {param_0}, 1,
163-
CollectiveDeviceList(std::vector<ReplicaGroup>({group})), true, 231,
164-
true));
163+
CollectiveDeviceList({group}), true, 231, true));
165164
auto computation = builder.Build(
166165
builder.AddInstruction(HloInstruction::CreateTuple({instr})));
167166
auto fusion =
@@ -1098,20 +1097,18 @@ TEST_P(GetParticipatingTest, Test) {
10981097
testing::UnorderedElementsAreArray(expect_device_groups));
10991098

11001099
// Test GetParticipatingFlattenedIdGroups.
1101-
absl::StatusOr<CollectiveDeviceList> collective_device_list =
1102-
GetParticipatingFlattenedIdGroups(
1103-
device_assignment, CollectiveDeviceList(replica_groups), *group_mode);
1104-
if (!collective_device_list.ok()) {
1100+
absl::StatusOr<std::vector<ReplicaGroup>> actual_flattened_id_groups =
1101+
GetParticipatingFlattenedIdGroups(device_assignment, replica_groups,
1102+
*group_mode);
1103+
if (!actual_flattened_id_groups.ok()) {
11051104
EXPECT_TRUE(tc.expected_failure);
11061105
return;
11071106
}
1108-
const std::vector<ReplicaGroup> &actual_flattened_id_groups =
1109-
collective_device_list.value().replica_groups();
11101107

11111108
std::vector<std::vector<int64_t>> actual_flattened_id_groups_int;
1112-
actual_flattened_id_groups_int.reserve(actual_flattened_id_groups.size());
1109+
actual_flattened_id_groups_int.reserve(actual_flattened_id_groups->size());
11131110

1114-
for (auto subgroup : actual_flattened_id_groups) {
1111+
for (auto subgroup : *actual_flattened_id_groups) {
11151112
std::vector<int64_t> replica_group;
11161113
for (int id : subgroup.replica_ids()) {
11171114
replica_group.push_back(id);

third_party/xla/xla/service/spmd/dot_handler.cc

Lines changed: 11 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -606,7 +606,7 @@ std::optional<WindowedEinsumConfig> GetWindowedEinsumConfiguration(
606606
collective->opcode() == HloOpcode::kAllReduce) {
607607
communication_time_in_ms = visitor->GetCommunicationTimeInMilliSec(
608608
ShapeUtil::ByteSizeOf(collective->shape()),
609-
collective->device_list());
609+
collective->replica_groups());
610610
}
611611
} else {
612612
auto new_lhs =
@@ -647,17 +647,18 @@ std::optional<WindowedEinsumConfig> GetWindowedEinsumConfiguration(
647647
collective = collective->mutable_operand(0);
648648
}
649649
communication_time_in_ms = visitor->GetCommunicationTimeInMilliSec(
650-
ShapeUtil::ByteSizeOf(dot->shape()), collective->device_list());
650+
ShapeUtil::ByteSizeOf(dot->shape()), collective->replica_groups());
651651
}
652652

653653
double extra_collective_permute_time = 0.0;
654654
if (communication_time_in_ms != 0.0) {
655655
extra_collective_permute_time =
656656
communication_time_in_ms *
657-
visitor->GetCommunicationMultiplier(collective->device_list()) * 2 /
658-
num_partitions;
657+
visitor->GetCommunicationMultiplier(collective->replica_groups()) *
658+
2 / num_partitions;
659659
VLOG(2) << "GetCommunicationMultiplier: "
660-
<< visitor->GetCommunicationMultiplier(collective->device_list());
660+
<< visitor->GetCommunicationMultiplier(
661+
collective->replica_groups());
661662
}
662663
VLOG(2) << "collective: " << collective->ToString() << "\n"
663664
<< "dot: " << dot->ToString() << "\n"
@@ -672,7 +673,7 @@ std::optional<WindowedEinsumConfig> GetWindowedEinsumConfiguration(
672673
(std::max(
673674
computation_time_in_ms,
674675
communication_time_in_ms * visitor->GetCommunicationMultiplier(
675-
collective->device_list())) +
676+
collective->replica_groups())) +
676677
extra_collective_permute_time) >=
677678
(computation_time_in_ms + communication_time_in_ms)) {
678679
VLOG(2) << "Overhead outweighs benefit. Skipping windowed einsum";
@@ -3406,13 +3407,11 @@ bool PrioritizeContractingDimensionsPartitioning(
34063407
auto reduce_scatter_subgroups = GetPartitionGroupsForReplication(
34073408
outer_output_tmp_sharding, output_slice_dims);
34083409
const double all_gather_time_in_ms = visitor->GetCommunicationTimeInMilliSec(
3409-
all_gather_bytes,
3410-
CollectiveDeviceList(visitor->CreateReplicaGroups(all_gather_subgroups)));
3410+
all_gather_bytes, visitor->CreateReplicaGroups(all_gather_subgroups));
34113411
const double reduce_scatter_time_in_ms =
34123412
visitor->GetCommunicationTimeInMilliSec(
34133413
reduce_scatter_bytes,
3414-
CollectiveDeviceList(
3415-
visitor->CreateReplicaGroups(reduce_scatter_subgroups)));
3414+
visitor->CreateReplicaGroups(reduce_scatter_subgroups));
34163415

34173416
Shape other_original_shape = other_hlo->shape();
34183417
*other_hlo->mutable_shape() =
@@ -3528,13 +3527,11 @@ bool LhsIsBestMatchForNonContractingPartitioning(
35283527
const double lhs_all_gather_time_in_ms =
35293528
visitor->GetCommunicationTimeInMilliSec(
35303529
lhs_all_gather_bytes,
3531-
CollectiveDeviceList(
3532-
visitor->CreateReplicaGroups(lhs_all_gather_subgroups)));
3530+
visitor->CreateReplicaGroups(lhs_all_gather_subgroups));
35333531
const double rhs_all_gather_time_in_ms =
35343532
visitor->GetCommunicationTimeInMilliSec(
35353533
rhs_all_gather_bytes,
3536-
CollectiveDeviceList(
3537-
visitor->CreateReplicaGroups(rhs_all_gather_subgroups)));
3534+
visitor->CreateReplicaGroups(rhs_all_gather_subgroups));
35383535

35393536
HloInstruction* compute_lhs = lhs.hlo();
35403537
Shape lhs_original_shape = compute_lhs->shape();

third_party/xla/xla/service/spmd/spmd_partitioner.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -719,12 +719,12 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
719719
}
720720

721721
virtual double GetCommunicationTimeInMilliSec(
722-
int64_t bytes, const CollectiveDeviceList& collective_device_list) {
722+
int64_t bytes, absl::Span<const ReplicaGroup> device_groups) {
723723
return 0.0;
724724
}
725725

726726
virtual int GetCommunicationMultiplier(
727-
const CollectiveDeviceList& collective_device_list) {
727+
absl::Span<const ReplicaGroup> device_groups) {
728728
return 1;
729729
}
730730

0 commit comments

Comments
 (0)
0