From fd33af914d9c265f9eea68c2986aad82331df2ed Mon Sep 17 00:00:00 2001 From: Alexey Zinoviev Date: Sun, 17 May 2020 15:13:16 +0300 Subject: [PATCH 1/2] Added the TensorCreation.java example --- .../examples/tensors/TensorCreation.java | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 tensorflow-examples/src/main/java/org/tensorflow/model/examples/tensors/TensorCreation.java diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/tensors/TensorCreation.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/tensors/TensorCreation.java new file mode 100644 index 0000000..fc7d342 --- /dev/null +++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/tensors/TensorCreation.java @@ -0,0 +1,62 @@ +package org.tensorflow.model.examples.tensors; + +import org.tensorflow.Tensor; +import org.tensorflow.tools.Shape; +import org.tensorflow.tools.ndarray.IntNdArray; +import org.tensorflow.tools.ndarray.NdArrays; +import org.tensorflow.types.TInt32; + +/** + * Creates a few tensors of ranks: 0, 1, 2, 3. + */ +public class TensorCreation { + public static void main(String[] args) { + // Rank 0 Tensor + Tensor rank0Tensor = TInt32.scalarOf(42); + System.out.println("---- Scalar tensor ---------"); + System.out.println("DataType: " + rank0Tensor.dataType().name()); + System.out.println("Rank: " + rank0Tensor.data().rank()); + System.out.println("NumElements: " + rank0Tensor.data().size()); + + // Rank 1 Tensor + Tensor rank1Tensor = TInt32.vectorOf(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + System.out.println("---- Vector tensor ---------"); + System.out.println("DataType: " + rank1Tensor.dataType().name()); + System.out.println("Rank: " + rank1Tensor.data().rank()); + System.out.println("NumElements: " + rank1Tensor.data().size()); + + // Rank 2 Tensor + // 3x2 matrix of ints. + IntNdArray matrix2d = NdArrays.ofInts(Shape.of(3, 2)); + + matrix2d.set(NdArrays.vectorOf(1, 2), 0) + .set(NdArrays.vectorOf(3, 4), 1) + .set(NdArrays.vectorOf(5, 6), 2); + + Tensor rank2Tensor = TInt32.tensorOf(matrix2d); + + System.out.println("---- Matrix tensor ---------"); + System.out.println("DataType: " + rank2Tensor.dataType().name()); + System.out.println("Rank: " + rank2Tensor.data().rank()); + System.out.println("NumElements: " + rank2Tensor.data().size()); + System.out.println("6th element: " + rank2Tensor.data().getInt(2, 1)); + + // Rank 3 Tensor + // 3*2*4 matrix of ints. + IntNdArray matrix3d = NdArrays.ofInts(Shape.of(3, 2, 4)); + + matrix3d.elements(0).forEach(matrix -> { + matrix + .set(NdArrays.vectorOf(1, 2, 3, 4), 0) + .set(NdArrays.vectorOf(5, 6, 7, 8), 1); + }); + + Tensor rank3Tensor = TInt32.tensorOf(matrix3d); + + System.out.println("---- Matrix tensor ---------"); + System.out.println("DataType: " + rank3Tensor.dataType().name()); + System.out.println("Rank: " + rank3Tensor.data().rank()); + System.out.println("NumElements: " + rank3Tensor.data().size()); + System.out.println("n-th element: " + rank3Tensor.data().getInt(2, 1, 3)); + } +} From 2ebfa9638b24645e9822bc8963ea7d859d6e505d Mon Sep 17 00:00:00 2001 From: Alexey Zinoviev Date: Tue, 19 May 2020 17:16:48 +0300 Subject: [PATCH 2/2] Fixed review comments --- .../examples/tensors/TensorCreation.java | 54 ++++++++++++++++--- 1 file changed, 46 insertions(+), 8 deletions(-) diff --git a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/tensors/TensorCreation.java b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/tensors/TensorCreation.java index fc7d342..2487123 100644 --- a/tensorflow-examples/src/main/java/org/tensorflow/model/examples/tensors/TensorCreation.java +++ b/tensorflow-examples/src/main/java/org/tensorflow/model/examples/tensors/TensorCreation.java @@ -1,3 +1,19 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ package org.tensorflow.model.examples.tensors; import org.tensorflow.Tensor; @@ -6,6 +22,8 @@ import org.tensorflow.tools.ndarray.NdArrays; import org.tensorflow.types.TInt32; +import java.util.Arrays; + /** * Creates a few tensors of ranks: 0, 1, 2, 3. */ @@ -13,17 +31,29 @@ public class TensorCreation { public static void main(String[] args) { // Rank 0 Tensor Tensor rank0Tensor = TInt32.scalarOf(42); + System.out.println("---- Scalar tensor ---------"); + System.out.println("DataType: " + rank0Tensor.dataType().name()); - System.out.println("Rank: " + rank0Tensor.data().rank()); - System.out.println("NumElements: " + rank0Tensor.data().size()); + + System.out.println("Rank: " + rank0Tensor.shape().size()); + + System.out.println("Shape: " + Arrays.toString(rank0Tensor.shape().asArray())); + + rank0Tensor.data().scalars().forEach(value -> System.out.println("Value: " + value.getObject())); // Rank 1 Tensor Tensor rank1Tensor = TInt32.vectorOf(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + System.out.println("---- Vector tensor ---------"); + System.out.println("DataType: " + rank1Tensor.dataType().name()); - System.out.println("Rank: " + rank1Tensor.data().rank()); - System.out.println("NumElements: " + rank1Tensor.data().size()); + + System.out.println("Rank: " + rank1Tensor.shape().size()); + + System.out.println("Shape: " + Arrays.toString(rank1Tensor.shape().asArray())); + + System.out.println("6th element: " + rank1Tensor.data().getInt(5)); // Rank 2 Tensor // 3x2 matrix of ints. @@ -36,9 +66,13 @@ public static void main(String[] args) { Tensor rank2Tensor = TInt32.tensorOf(matrix2d); System.out.println("---- Matrix tensor ---------"); + System.out.println("DataType: " + rank2Tensor.dataType().name()); - System.out.println("Rank: " + rank2Tensor.data().rank()); - System.out.println("NumElements: " + rank2Tensor.data().size()); + + System.out.println("Rank: " + rank2Tensor.shape().size()); + + System.out.println("Shape: " + Arrays.toString(rank2Tensor.shape().asArray())); + System.out.println("6th element: " + rank2Tensor.data().getInt(2, 1)); // Rank 3 Tensor @@ -54,9 +88,13 @@ public static void main(String[] args) { Tensor rank3Tensor = TInt32.tensorOf(matrix3d); System.out.println("---- Matrix tensor ---------"); + System.out.println("DataType: " + rank3Tensor.dataType().name()); - System.out.println("Rank: " + rank3Tensor.data().rank()); - System.out.println("NumElements: " + rank3Tensor.data().size()); + + System.out.println("Rank: " + rank3Tensor.shape().size()); + + System.out.println("Shape: " + Arrays.toString(rank3Tensor.shape().asArray())); + System.out.println("n-th element: " + rank3Tensor.data().getInt(2, 1, 3)); } }