8000 Merge pull request #59589 from pak-laura/cherry9f9b9edd1b7e0486e9a4ec… · IBMZ-Linux-OSS-Python/tensorflow@a5e0e59 · GitHub
[go: up one dir, main page]

Skip to content

Commit a5e0e59

Browse files
Merge pull request tensorflow#59589 from pak-laura/cherry9f9b9edd1b7e0486e9a4ec3b355f6b9ee2b5e77c
Copy input tensor in a RandomShuffleOp to output directly when its nu…
2 parents 6a28a0d + 20e119c commit a5e0e59

File tree

4 files changed

+30
-7
lines changed

4 files changed

+30
-7
lines changed

tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5175,6 +5175,17 @@ func.func @tensor_scatter_max(%tensor: tensor<?x?x?xf32>, %indices: tensor<?x2xi
51755175

51765176
// -----
51775177

5178+
// CHECK-LABEL: @random_shuffle_num_elems_le_1
5179+
func.func @random_shuffle_num_elems_le_1() -> tensor<f32> {
5180+
// CHECK: [[INPUT:%.*]] = mhlo.constant dense<1.000000e+20> : tensor<f32>
5181+
// CHECK-NEXT: return [[INPUT]]
5182+
%cst = "tf.Const"() {value = dense<1.000000e+20> : tensor<f32>} : () -> tensor<f32>
5183+
%0 = "tf.RandomShuffle"(%cst) {device = "", seed = -4294967297 : i64, seed2 = -2147483649 : i64} : (tensor<f32>) -> tensor<f32>
5184+
return %0 : tensor<f32>
5185+
}
5186+
5187+
// -----
5188+
51785189
// CHECK-LABEL: @random_shuffle_first_dim_1
51795190
// CHECK-SAME: [[INPUT:%.*]]: tensor<1x?xf32>
51805191
func.func @random_shuffle_first_dim_1(%input: tensor<1x?xf32>) -> tensor<1x?xf32> {

tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5818,19 +5818,24 @@ class ConvertRandomShuffleOp : public OpRewritePattern<TF::RandomShuffleOp> {
58185818

58195819
LogicalResult matchAndRewrite(TF::RandomShuffleOp op,
58205820
PatternRewriter &rewriter) const override {
5821-
auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
5821+
auto no_op = [&]() {
5822+
rewriter.replaceOp(op, op.getValue());
5823+
return success();
5824+
};
5825+
5826+
auto input_type = op.getValue().getType().dyn_cast<RankedTensorType>();
58225827
if (!input_type) return failure();
5828+
if (input_type.hasStaticShape() && input_type.getNumElements() <= 1)
5829+
// No shuffling is required, so copy input directly to output.
5830+
return no_op();
58235831

58245832
int64_t input_rank = input_type.getRank();
58255833
int64_t first_dim_size = input_type.getDimSize(0);
58265834
if (ShapedType::isDynamic(first_dim_size)) return failure();
58275835

5828-
// We are shuffling along the first dimension. If its size is <= 1, then
5829-
// shuffling is a no-op.
5830-
if (first_dim_size <= 1) {
5831-
rewriter.replaceOp(op, op.value());
5832-
return success();
5833-
}
5836+
if (first_dim_size <= 1)
5837+
// No shuffling is required, so copy input directly to output.
5838+
return no_op();
58345839

58355840
// For vectors, shuffle values by sorting instead of the obvious
58365841
// Fisher-Yates algorithm. Fisher-Yates is simple to implement and correct,

tensorflow/compiler/tests/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,6 +1119,7 @@ tf_xla_py_test(
11191119
],
11201120
deps = [
11211121
":xla_test",
1122+
"//tensorflow:tensorflow_py",
11221123
"//tensorflow/python:array_ops",
11231124
"//tensorflow/python:framework",
11241125
"//tensorflow/python:math_ops",

tensorflow/compiler/tests/random_ops_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,12 @@ def testShuffle2d(self):
279279
self.assertAllEqual(len(result.flatten()), len(expected))
280280
self.assertAllEqual(set(result.flatten()), set(expected))
281281

282+
def testRandomShuffleInputRank0(self):
283+
with self.session():
284+
with self.test_scope():
285+
shuffle = random_ops.random_shuffle(value=1e20)
286+
self.evaluate(shuffle)
287+
282288

283289
if __name__ == '__main__':
284290
googletest.main()

0 commit comments

Comments
 (0)
0