@@ -1924,6 +1924,13 @@ class RaggedAllToAllTest : public AsyncMemcpyCollectiveOps {
         ragged_all_to_all->shape().dimensions().begin(),
         ragged_all_to_all->shape().dimensions().end()};
 
+    // The ragged-all-to-all accepts an output tensor as a parameter to allow
+    // buffer reuse. We initialize the output tensor with -1 to make sure that
+    // we don't accidentally overwrite data that is not part of the
+    // ragged-all-to-all update.
+    Array<float> output_init_data(ragged_tensor_sizes);
+    output_init_data.Fill(-1);
+
     Array<IndexType> output_sizes = input_sizes;
     output_sizes.TransposeDimensions({1, 0, 2});
 
@@ -1934,8 +1941,7 @@ class RaggedAllToAllTest : public AsyncMemcpyCollectiveOps {
     int64_t num_replicas = input_sizes.dim(0);
     std::vector<Array<float>> input_data(num_replicas,
                                          Array<float>(ragged_tensor_sizes));
-    std::vector<Array<float>> output_data(num_replicas,
-                                          Array<float>(ragged_tensor_sizes));
+    std::vector<Array<float>> output_data(num_replicas, output_init_data);
     FillWithRandomData(input_data, output_data, input_offsets, output_offsets,
                        input_sizes);
 
@@ -1955,9 +1961,7 @@ class RaggedAllToAllTest : public AsyncMemcpyCollectiveOps {
           GetReplicaSlice(replica_id, output_sizes)));
     }
 
-    // The ragged-all-to-all accepts an output tensor as a parameter to allow
-    // buffer reuse. We initialize the output tensor with zeros.
-    output_init_ = LiteralUtil::CreateFull(ragged_tensor_sizes, 0);
+    output_init_ = LiteralUtil::CreateFromArray(output_init_data);
  }
 
  // Returns a vector of pointers to the literals in the format needed for
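For context, the sentinel-initialization idea behind this change can be shown with a small standalone C++ sketch, independent of the XLA test harness above; the helper ApplyRaggedUpdate, the buffer sizes, and the offset are hypothetical and exist only for illustration. Pre-filling the output buffer with -1 makes any write outside the intended ragged update visible as a clobbered sentinel, which a zero-filled buffer could not reveal.

// Minimal sketch of the -1 sentinel idea; ApplyRaggedUpdate is a
// hypothetical stand-in for the ragged-all-to-all update of a slice
// of the output buffer, not part of the XLA test above.
#include <cassert>
#include <cstddef>
#include <vector>

// Copies `input` into `output` starting at `offset` (the "ragged update").
void ApplyRaggedUpdate(const std::vector<float>& input, std::size_t offset,
                       std::vector<float>& output) {
  for (std::size_t i = 0; i < input.size(); ++i) {
    output[offset + i] = input[i];
  }
}

int main() {
  constexpr float kSentinel = -1.0f;
  // Output buffer pre-filled with the sentinel instead of zeros.
  std::vector<float> output(16, kSentinel);
  const std::vector<float> input = {1.0f, 2.0f, 3.0f};
  const std::size_t offset = 4;

  ApplyRaggedUpdate(input, offset, output);

  // Every element outside [offset, offset + input.size()) must still hold
  // the sentinel; with a zero-filled output, an accidental write of 0 to
  // one of these elements would go unnoticed.
  for (std::size_t i = 0; i < output.size(); ++i) {
    if (i < offset || i >= offset + input.size()) {
      assert(output[i] == kSentinel);
    }
  }
  return 0;
}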