8000 [XLA:GPU] Fix `HasCycle` function · linux-on-ibm-z/tensorflow@d313af9 · GitHub
[go: up one dir, main page]

Skip to content

Commit d313af9

Browse files
frgossentensorflower-gardener
authored andcommitted
[XLA:GPU] Fix HasCycle function
This is needed to avoid deadlocks when running maxtext with PP and FSDP. In this case, we see collective-permutes with multiple cycles, that were falsely categorized as acyclic. The result is a decomposed collective-permute issuing a cyclic recv leading into a deadlock. PiperOrigin-RevId: 730578883
1 parent dfae6d7 commit d313af9

File tree

3 files changed

+83
-6
lines changed

3 files changed

+83
-6
lines changed

third_party/xla/xla/service/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ cc_library(
293293
"//xla/hlo/parser:hlo_parser",
294294
"//xla/service/graphcycles",
295295
"//xla/tsl/platform:statusor",
296-
"@com_google_absl//absl/container:flat_hash_map",
296+
"@com_google_absl//absl/container:flat_hash_set",
297297
"@com_google_absl//absl/container:inlined_vector",
298298
"@com_google_absl//absl/log:check",
299299
"@com_google_absl//absl/status:statusor",

third_party/xla/xla/service/collective_permute_cycle.cc

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ limitations under the License.
1818
#include <cstddef>
1919
#include <cstdint>
2020
#include <utility>
21+
#include <vector>
2122

23+
#include "absl/container/flat_hash_set.h"
2224
#include "xla/service/source_target_pairs.h"
2325

2426
namespace xla {
@@ -124,7 +126,33 @@ CycleType GetCycleType(const SourceTargetPairs& pairs) {
124126
}
125127

126128
bool HasCycles(const SourceTargetPairs& pairs) {
127-
return GetCycleType(pairs) != CycleType::kNone;
129+
// Build source-target map for quick lookup.
130+
std::vector<int64_t> source_target_map(pairs.size(), -1);
131+
for (int64_t i = 0; i < pairs.size(); ++i) {
132+
int64_t source = pairs[i].source;
133+
int64_t target = pairs[i].target;
134+
while (source_target_map.size() <= source) source_target_map.push_back(-1);
135+
source_target_map[source] = target;
136+
}
137+
138+
// Cache indices known to be acyclic.
139+
absl::flat_hash_set<int64_t> acyclic;
140+
< 8000 /td>141+
// Search for cycles.
142+
int64_t n = source_target_map.size();
143+
for (int64_t i = 0; i < n; ++i) {
144+
absl::flat_hash_set<int64_t> path;
145+
int64_t current = i;
146+
while (current != -1 && !acyclic.contains(current)) {
147+
if (path.contains(current)) return true;
148+
path.insert(current);
149+
current = current < n ? source_target_map[current] : -1;
150+
}
151+
acyclic.insert(path.begin(), path.end());
152+
}
153+
154+
// No cycles found.
155+
return false;
128156
}
129157

130158
} // namespace collective_permute_cycle

third_party/xla/xla/service/collective_permute_cycle_test.cc

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,8 @@ TEST_F(CollectivePermuteUtilsTest, HasCycles) {
5757
EXPECT_TRUE(HasCycles(fwd4_.cycle));
5858
EXPECT_TRUE(HasCycles(bwd4_.cycle));
5959

60-
EXPECT_FALSE(HasCycles(SourceTargetPairs({{0, 1}, {1, 2}, {2, 3}, {3, 2}})))
61-
<< "Lasso 3->2";
62-
EXPECT_FALSE(HasCycles(SourceTargetPairs({{0, 1}, {1, 2}, {2, 3}, {3, 1}})))
63-
<< "Lasso 3->1";
60+
EXPECT_TRUE(HasCycles(SourceTargetPairs({{0, 1}, {1, 2}, {2, 3}, {3, 2}})));
61+
EXPECT_TRUE(HasCycles(SourceTargetPairs({{0, 1}, {1, 2}, {2, 3}, {3, 1}})));
6462

6563
EXPECT_FALSE(HasCycles(SourceTargetPairs({{1, 2}, {2, 3}, {3, 0}})))
6664
<< "Forward only";
@@ -159,6 +157,57 @@ TEST_F(CollectivePermuteUtilsTest, GetCycleType) {
159157
<< "Lasso 3->1";
160158
}
161159

160+
TEST_F(CollectivePermuteUtilsTest, HasCyclesTwoCycles) {
161+
// Cycle: 0->1, 1->2, 2->3, 3->0
162+
// Cycle: 4->5, 5->6, 6->7, 7->4
163+
EXPECT_TRUE(HasCycles(SourceTargetPairs(
164+
{{0, 1}, {1, 2}, {2, 3}, {3, 0}, {4, 5}, {5, 6}, {6, 7}, {7, 4}})));
165+
}
166+
167+
TEST_F(CollectivePermuteUtilsTest, HasCyclesOneCycleAndOneAlmostCycle) {
168+
// Not a cycle: 0->1, 1->2, 2->3 (missing: 3->4)
169+
// Cycle: 4->5, 5->6, 6->7, 7->4
170+
EXPECT_TRUE(HasCycles(SourceTargetPairs(
171+
{{0, 1}, {1, 2}, {2, 3}, {4, 5}, {5, 6}, {6, 7}, {7, 4}})));
172+
}
173+
174+
TEST_F(CollectivePermuteUtilsTest, HasCyclesTwoAlmostCycles) {
175+
// Not a cycle: 0->1, 1->2, 3->0 (missing: 2->3)
176+
// Not a cycle: 4->5, 5->6, 7->4 (missing: 6->7)
177+
EXPECT_FALSE(HasCycles(
178+
SourceTargetPairs({{0, 1}, {1, 2}, {3, 0}, {4, 5}, {5, 6}, {7, 4}})));
179+
}
180+
181+
TEST_F(CollectivePermuteUtilsTest, HasCyclesTwoCyclesInterleaved) {
182+
// Cycle: 0->2, 2->4, 4->6, 6->0
183+
// Cycle: 1->3, 3->5, 5->7, 7->1
184+
EXPECT_TRUE(HasCycles(SourceTargetPairs(
185+
{{0, 2}, {2, 4}, {4, 6}, {6, 0}, {1, 3}, {3, 5}, {5, 7}, {7, 1}})));
186+
}
187+
188+
TEST_F(CollectivePermuteUtilsTest, HasCyclesSimpleCycle) {
189+
// Cycle: 0->1, 1->2, 2->3, 3->4, 4->5, 5->6, 6->7, 7->0
190+
EXPECT_TRUE(HasCycles(SourceTargetPairs(
191+
{{0, 1}, {1, 2}, {2, 3}, {3, 4}, {4, 5}, {5, 6}, {6, 7}, {7, 0}})));
192+
}
193+
194+
TEST_F(CollectivePermuteUtilsTest, HasCyclesSimpleAlmostCycle) {
195+
// Not a cycle: 0->1, 1->2, 2->3, 4->5, 5->6, 6->7, 7->0 (missing: 3->4)
196+
EXPECT_FALSE(HasCycles(SourceTargetPairs(
197+
{{0, 1}, {1, 2}, {2, 3}, {4, 5}, {5, 6}, {6, 7}, {7, 0}})));
198+
}
199+
200+
TEST_F(CollectivePermuteUtilsTest, HasCyclesSelfCycle) {
201+
// Self cycle: 0->0
202+
EXPECT_TRUE(HasCycles(SourceTargetPairs({{0, 0}})));
203+
}
204+
205+
TEST_F(CollectivePermuteUtilsTest, HasCyclesSkippingFirstDeviceCycle) {
206+
// Cycle: 1->2, 2->3, 3->4, 4->5, 5->6, 6->7, 7->1
207+
EXPECT_TRUE(HasCycles(SourceTargetPairs(
208+
{{1, 2}, {2, 3}, {3, 4}, {4, 5}, {5, 6}, {6, 7}, {7, 1}})));
209+
}
210+
162211
} // namespace
163212
} // namespace collective_permute_cycle
164213
} // namespace xla

0 commit comments

Comments
 (0)
0