[XLA:GPU] Move GetCycleTypeAndIndices to its only use site · IBMZ-Linux-OSS-Python/tensorflow@bae08f5 · GitHub
[go: up one dir, main page]

Skip to content

Commit bae08f5

Browse files
frgossen authored and tensorflower-gardener committed
[XLA:GPU] Move GetCycleTypeAndIndices to its only use site
PiperOrigin-RevId: 766369503
1 parent 61688d4 commit bae08f5

File tree

4 files changed

+61
-84
lines changed

4 files changed

+61
-84
lines changed

third_party/xla/xla/service/collective_ops_utils.cc

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -815,58 +815,6 @@ HloInstruction* IsOrHasCollectiveWithChannelId(HloInstruction* instruction) {
815815
return nullptr;
816816
}
817817

818-
using SourceTargetPairType = std::pair<int64_t, int64_t>;
819-
using SourceTargetPairsType = std::vector<SourceTargetPairType>;
820-
821-
std::pair<CycleType, std::set<int>> GetCycleTypeAndIndices(
822-
const SourceTargetPairsType& pairs) {
823-
std::set<int> seen_replica_ids;
824-
std::set<std::pair<int64_t, int64_t>> tentative_results;
825-
// first figure out if we're dealing with a potential forward or backward
826-
// cycle.
827-
int forward_edge_counter = 0;
828-
int backward_edge_counter = 0;
829-
for (auto pair : pairs) {
830-
pair.first < pair.second ? forward_edge_counter++ : backward_edge_counter++;
831-
}
832-
bool is_forward_cycle = forward_edge_counter > backward_edge_counter;
833-
for (int64_t i = 0; i < pairs.size(); ++i) {
834-
const SourceTargetPairType& pair = pairs[i];
835-
if (is_forward_cycle) {
836-
// check if the source of the current pair is smaller than the target
837-
if (pair.first < pair.second) {
838-
seen_replica_ids.insert(pair.first);
839-
} else {
840-
// the source of the current pair is larger than the target, so the
841-
// current pair may be part of a cycle. We keep track of the target ID
842-
// and the index of the pair in the original pairs array.
843-
tentative_results.insert(std::make_pair(pair.second, i));
844-
}
845-
} else {
846-
// The backward cycle check uses similar logic but in reverse.
847-
if (pair.first > pair.second) {
848-
seen_replica_ids.insert(pair.second);
849-
} else {
850-
tentative_results.insert(std::make_pair(pair.first, i));
851-
}
852-
}
853-
}
854-
std::set<int> final_results;
855-
// Iterate over the tentative results and only keep the indices that form an
856-
// actual cycle. This is done by checking if the target replica ID of the
857-
// pair is in the set of seen replica IDs. Note that the tentative results
858-
// array will be fairly small in practice, so this is not adding too much to
859-
// the runtime.
860-
for (auto& [replica_id, index] : tentative_results) {
861-
if (seen_replica_ids.find(replica_id) != seen_replica_ids.end()) {
862-
final_results.insert(index);
863-
}
864-
}
865-
CycleType cycle_type = final_results.empty() ? CycleType::kNone
866-
: is_forward_cycle ? CycleType::kForward
867-
: CycleType::kBackward;
868-
return std::make_pair(cycle_type, final_results);
869-
}
870818

871819
bool IsExclusivelyCrossModule(absl::Span<const ReplicaGroup> replica_groups,
872820
bool use_global_ids, bool has_channel_id,

third_party/xla/xla/service/collective_ops_utils.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -262,16 +262,6 @@ absl::StatusOr<bool> IsAsyncCollective(const HloInstruction* instruction);
262262
// collective fusion) with channel_id.
263263
HloInstruction* IsOrHasCollectiveWithChannelId(HloInstruction* instruction);
264264

265-
// Returns the cycle type and indices of the vertices that form cycles. For
266-
// example, GetCycleTypeAndIndices({{0,3},{1,0},{2,1},{3,2}}) returns
267-
// {kBackward, {0}}, since the communication pattern contains a backward cycle
268-
// with the cycle-inducing vertex at index 0 in the input source-target pairs
269-
// array. This function uses the assumption that, in practice, in forward
270-
// cycles, most edges will have the target replica ID greater than the source
271-
// replica ID except for the back edges that form cycles (similar logic applies
272-
// to backward cycles).
273-
std::pair<collective_permute_cycle::CycleType, std::set<int>>
274-
GetCycleTypeAndIndices(const std::vector<std::pair<int64_t, int64_t>>& pairs);
275265

276266
// Key that identifies a particular Rendezvous object in our global hashtable.
277267
// This determines which calls to ExecuteOnStream communicate with each other.

third_party/xla/xla/service/collective_ops_utils_test.cc

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -188,27 +188,6 @@ TEST(CollectiveOpsUtilsTest, CollectiveWithChannelId2) {
188188
EXPECT_EQ(IsOrHasCollectiveWithChannelId(fusion2.get()), nullptr);
189189
}
190190

191-
TEST(CollectiveOpsUtilsTest, GetForwardCycleIndices) {
192-
auto res_one_cycle = GetCycleTypeAndIndices({{0, 1}, {1, 2}, {2, 3}, {3, 0}});
193-
EXPECT_EQ(res_one_cycle.first, CycleType::kForward);
194-
EXPECT_THAT(res_one_cycle.second, testing::UnorderedElementsAreArray({3}));
195-
auto res_two_cycles =
196-
GetCycleTypeAndIndices({{0, 1}, {1, 2}, {2, 3}, {3, 0}, {4, 1}});
197-
EXPECT_EQ(res_two_cycles.first, CycleType::kForward);
198-
EXPECT_THAT(res_two_cycles.second,
199-
testing::UnorderedElementsAreArray({3, 4}));
200-
}
201-
202-
TEST(CollectiveOpsUtilsTest, GetBackwardCycleIndices) {
203-
auto res_one_cycle = GetCycleTypeAndIndices({{0, 3}, {1, 0}, {2, 1}, {3, 2}});
204-
EXPECT_EQ(res_one_cycle.first, CycleType::kBackward);
205-
EXPECT_THAT(res_one_cycle.second, testing::UnorderedElementsAreArray({0}));
206-
auto res_two_cycles =
207-
GetCycleTypeAndIndices({{0, 3}, {1, 4}, {2, 1}, {3, 2}, {4, 3}, {3, 0}});
208-
EXPECT_EQ(res_two_cycles.first, CycleType::kBackward);
209-
EXPECT_THAT(res_two_cycles.second,
210-
testing::UnorderedElementsAreArray({0, 1}));
211-
}
212191

213192
TEST(IsExclusivelyCrossModuleTest, CrossReplicaNoChannelSet) {
214193
int64_t num_replicas = 4;

third_party/xla/xla/service/gpu/transforms/collectives/collective_permute_cycle_decomposer.cc

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ limitations under the License.
3737
#include "xla/hlo/utils/hlo_query.h"
3838
#include "xla/literal_util.h"
3939
#include "xla/service/collective_ops_utils.h"
40-
#include "xla/service/collective_permute_cycle.h"
4140
#include "xla/service/gpu/backend_configs.pb.h"
4241
#include "xla/service/source_target_pairs.h"
4342
#include "xla/shape.h"
@@ -53,6 +52,67 @@ namespace {
5352

5453
using CycleType = collective_permute_cycle::CycleType;
5554

55+
using SourceTargetPairType = std::pair<int64_t, int64_t>;
56+
using SourceTargetPairsType = std::vector<SourceTargetPairType>;
57+
58+
// Returns the cycle type and indices of the vertices that form cycles. For
59+
// example, GetCycleTypeAndIndices({{0,3},{1,0},{2,1},{3,2}}) returns
60+
// {kBackward, {0}}, since the communication pattern contains a backward cycle
61+
// with the cycle-inducing vertex at index 0 in the input source-target pairs
62+
// array. This function uses the assumption that, in practice, in forward
63+
// cycles, most edges will have the target replica ID greater than the source
64+
// replica ID except for the back edges that form cycles (similar logic applies
65+
// to backward cycles).
66+
std::pair<CycleType, std::set<int>> GetCycleTypeAndIndices(
67+
const SourceTargetPairsType& pairs) {
68+
std::set<int> seen_replica_ids;
69+
std::set<std::pair<int64_t, int64_t>> tentative_results;
70+
// first figure out if we're dealing with a potential forward or backward
71+
// cycle.
72+
int forward_edge_counter = 0;
73+
int backward_edge_counter = 0;
74+
for (auto pair : pairs) {
75+
pair.first < pair.second ? forward_edge_counter++ : backward_edge_counter++;
76+
}
77+
bool is_forward_cycle = forward_edge_counter > backward_edge_counter;
78+
for (int64_t i = 0; i < pairs.size(); ++i) {
79+
const SourceTargetPairType& pair = pairs[i];
80+
if (is_forward_cycle) {
81+
// check if the source of the current pair is smaller than the target
82+
if (pair.first < pair.second) {
83+
seen_replica_ids.insert(pair.first);
84+
} else {
85+
// the source of the current pair is larger than the target, so the
86+
// current pair may be part of a cycle. We keep track of the target ID
87+
// and the index of the pair in the original pairs array.
88+
tentative_results.insert(std::make_pair(pair.second, i));
89+
}
90+
} else {
91+
// The backward cycle check uses similar logic but in reverse.
92+
if (pair.first > pair.second) {
93+
seen_replica_ids.insert(pair.second);
94+
} else {
95+
tentative_results.insert(std::make_pair(pair.first, i));
96+
}
97+
}
98+
}
99+
std::set<int> final_results;
100+
// Iterate over the tentative results and only keep the indices that form an
101+
// actual cycle. This is done by checking if the target replica ID of the
102+
// pair is in the set of seen replica IDs. Note that the tentative results
103+
// array will be fairly small in practice, so this is not adding too much to
104+
// the runtime.
105+
for (auto& [replica_id, index] : tentative_results) {
106+
if (seen_replica_ids.find(replica_id) != seen_replica_ids.end()) {
107+
final_results.insert(index);
108+
}
109+
}
110+
CycleType cycle_type = final_results.empty() ? CycleType::kNone
111+
: is_forward_cycle ? CycleType::kForward
112+
: CycleType::kBackward;
113+
return std::make_pair(cycle_type, final_results);
114+
}
115+
56116
// Returns the cycle type and indices of the vertices that form cycles. If the
57117
// cycle type is kUnknown, the set of indices will be empty.
58118
std::pair<CycleType, std::set<int>> GetCycleTypeAndIndicesArray(

0 commit comments

Comments
 (0)
0