8000 Fix crash of transpose · linux-on-ibm-z/tensorflow@e9009ce · GitHub
[go: up one dir, main page]

Skip to content

Commit e9009ce

Browse files
Fix crash of transpose
1 parent 5c30489 commit e9009ce

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

tensorflow/core/kernels/transpose_op.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,10 @@ void TransposeOp::Compute(OpKernelContext* ctx) {
150150
bool is_identity = true;
151151
for (int i = 0; i < dims; ++i) {
152152
int32_t d = permutation[i];
153-
if (d < 0) d += dims;
153+
if (d < 0) {
154+
d += dims;
155+
permutation[i] = d;
156+
}
154157
OP_REQUIRES(
155158
ctx, 0 <= d && d < dims,
156159
errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")"));

tensorflow/python/kernel_tests/math_ops/transpose_op_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,13 @@ def testError(self):
541541
self._testError(
542542
np.arange(0., 30).reshape([2, 3, 5]), [0, 1, 1], "2 is missing")
543543

544+
def testNegativePerm(self):
545+
self.assertEqual([15, 100, 37],
546+
array_ops.transpose(
547+
constant_op.constant(
548+
1, dtype=dtypes.int32, shape=[100, 37, 15]),
549+
[-1, -3, -2]).get_shape().dims)
550+
544551

545552
if __name__ == "__main__":
546553
test.main()

0 commit comments

Comments
 (0)
0