@@ -37,7 +37,6 @@ limitations under the License.
37
37
#include " xla/hlo/utils/hlo_query.h"
38
38
#include " xla/literal_util.h"
39
39
#include " xla/service/collective_ops_utils.h"
40
- #include " xla/service/collective_permute_cycle.h"
41
40
#include " xla/service/gpu/backend_configs.pb.h"
42
41
#include " xla/service/source_target_pairs.h"
43
42
#include " xla/shape.h"
@@ -53,6 +52,67 @@ namespace {
53
52
54
53
using CycleType = collective_permute_cycle::CycleType;
55
54
55
+ using SourceTargetPairType = std::pair<int64_t , int64_t >;
56
+ using SourceTargetPairsType = std::vector<SourceTargetPairType>;
57
+
58
+ // Returns the cycle type and indices of the vertices that form cycles. For
59
+ // example, GetCycleTypeAndIndices({{0,3},{1,0},{2,1},{3,2}}) returns
60
+ // {kBackward, {0}}, since the communication pattern contains a backward cycle
61
+ // with the cycle-inducing vertex at index 0 in the input source-target pairs
62
+ // array. This function uses the assumption that, in practice, in forward
63
+ // cycles, most edges will have the target replica ID greater than the source
64
+ // replica ID except for the back edges that form cycles (similar logic applies
65
+ // to backward cycles).
66
+ std::pair<CycleType, std::set<int >> GetCycleTypeAndIndices (
67
+ const SourceTargetPairsType& pairs) {
68
+ std::set<int > seen_replica_ids;
69
+ std::set<std::pair<int64_t , int64_t >> tentative_results;
70
+ // first figure out if we're dealing with a potential forward or backward
71
+ // cycle.
72
+ int forward_edge_counter = 0 ;
73
+ int backward_edge_counter = 0 ;
74
+ for (auto pair : pairs) {
75
+ pair.first < pair.second ? forward_edge_counter++ : backward_edge_counter++;
76
+ }
77
+ bool is_forward_cycle = forward_edge_counter > backward_edge_counter;
78
+ for (int64_t i = 0 ; i < pairs.size (); ++i) {
79
+ const SourceTargetPairType& pair = pairs[i];
80
+ if (is_forward_cycle) {
81
+ // check if the source of the current pair is smaller than the target
82
+ if (pair.first < pair.second ) {
83
+ seen_replica_ids.insert (pair.first );
84
+ } else {
85
+ // the source of the current pair is larger than the target, so the
86
+ // current pair may be part of a cycle. We keep track of the target ID
87
+ // and the index of the pair in the original pairs array.
88
+ tentative_results.insert (std::make_pair (pair.second , i));
89
+ }
90
+ } else {
91
+ // The backward cycle check uses similar logic but in reverse.
92
+ if (pair.first > pair.second ) {
93
+ seen_replica_ids.insert (pair.second );
94
+ } else {
95
+ tentative_results.insert (std::make_pair (pair.first , i));
96
+ }
97
+ }
98
+ }
99
+ std::set<int > final_results;
100
+ // Iterate over the tentative results and only keep the indices that form an
101
+ // actual cycle. This is done by checking if the target replica ID of the
102
+ // pair is in the set of seen replica IDs. Note that the tentative results
103
+ // array will be fairly small in practice, so this is not adding too much to
104
+ // the runtime.
105
+ for (auto & [replica_id, index] : tentative_results) {
106
+ if (seen_replica_ids.find (replica_id) != seen_replica_ids.end ()) {
107
+ final_results.insert (index);
108
+ }
109
+ }
110
+ CycleType cycle_type = final_results.empty () ? CycleType::kNone
111
+ : is_forward_cycle ? CycleType::kForward
112
+ : CycleType::kBackward ;
113
+ return std::make_pair (cycle_type, final_results);
114
+ }
115
+
56
116
// Returns the cycle type and indices of the vertices that form cycles. If the
57
117
// cycle type is kUnknown, the set of indices will be empty.
58
118
std::pair<CycleType, std::set<int >> GetCycleTypeAndIndicesArray (
0 commit comments