8000 to_numpy_string · SciSharp/TensorFlow.NET@76abb2c · GitHub
[go: up one dir, main page]

Skip to content

Commit 76abb2c

Browse files
committed
to_numpy_string
1 parent d88fec4 commit 76abb2c

File tree

2 files changed

+88
-7
lines changed

2 files changed

+88
-7
lines changed

src/TensorFlowNET.Core/Tensors/tensor_util.cs

Lines changed: 78 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,20 @@ public static Tensor shape_tensor(int[] shape)
470470
return ops.convert_to_tensor(shape, dtype: TF_DataType.TF_INT32, name: "shape");
471471
}
472472

473-
public static string to_numpy_string(Tensor tensor)
473+
public static string to_numpy_string(Tensor array)
474+
{
475+
Shape shape = array.shape;
476+
if (shape.ndim == 0)
477+
return array[0].ToString();
478+
479+
var s = new StringBuilder();
480+
s.Append("array(");
481+
PrettyPrint(s, array);
482+
s.Append(")");
483+
return s.ToString();
484+
}
485+
486+
static string Render(Tensor tensor)
474487
{
475488
if (tensor.buffer == IntPtr.Zero)
476489
return "Empty";
@@ -487,7 +500,7 @@ public static string to_numpy_string(Tensor tensor)
487500
else
488501
return $"['{string.Join("', '", tensor.StringData().Take(25))}']";
489502
}
490-
else if(dtype == TF_DataType.TF_VARIANT)
503+
else if (dtype == TF_DataType.TF_VARIANT)
491504
{
492505
return "<unprintable>";
493506
}
@@ -515,7 +528,7 @@ public static string to_numpy_string(Tensor tensor)
515528
var array = tensor.ToArray<float>();
516529
return DisplayArrayAsString(array, tensor.shape);
517530
}
518-
else if(dtype == TF_DataType.TF_DOUBLE)
531+
else if (dtype == TF_DataType.TF_DOUBLE)
519532
{
520533
var array = tensor.ToArray<double>();
521534
return DisplayArrayAsString(array, tensor.shape);
@@ -532,14 +545,72 @@ static string DisplayArrayAsString<T>(T[] array, Shape shape)
532545
if (shape.ndim == 0)
533546
return array[0].ToString();
534547

535-
var display = "array([";
548+
var display = "";
536549
if (array.Length < 10)
537550
display += string.Join(", ", array);
538551
else
539-
display += string.Join(", ", array.Take(3)) + " ... " + string.Join(", ", array.Skip(array.Length - 3));
540-
return display + "])";
552+
display += string.Join(", ", array.Take(3)) + ", ..., " + string.Join(", ", array.Skip(array.Length - 3));
553+
return display;
554+
}
555+
556+
static void PrettyPrint(StringBuilder s, Tensor array, bool flat = false)
557+
{
558+
var shape = array.shape;
559+
560+
if (shape.Length == 1)
561+
{
562+
s.Append("[");
563+
s.Append(Render(array));
564+
s.Append("]");
565+
return;
566+
}
567+
568+
var len = shape[0];
569+
s.Append("[");
570+
571+
if (len <= 10)
572+
{
573+
for (int i = 0; i < len; i++)
574+
{
575+
PrettyPrint(s, array[i], flat);
576+
if (i < len - 1)
577+
{
578+
s.Append(", ");
579+
if (!flat)
580+
s.AppendLine();
581+
}
582+
}
583+
}
584+
else
585+
{
586+
for (int i = 0; i < 5; i++)
587+
{
588+
PrettyPrint(s, array[i], flat);
589+
if (i < len - 1)
590+
{
591+
s.Append(", ");
592+
if (!flat)
593+
s.AppendLine();
594+
}
595+
}
596+
597+
s.Append(" ... ");
598+
s.AppendLine();
599+
600+
for (int i = (int)array.size - 5; i < len; i++)
601+
{
602+
PrettyPrint(s, array[i], flat);
603+
if (i < len - 1)
604+
{
605+
s.Append(", ");
606+
if (!flat)
607+
s.AppendLine();
608+
}
609+
}
610+
}
611+
612+
s.Append("]");
541613
}
542-
543614

544615
public static ParsedSliceArgs ParseSlices(Slice[] slices)
545616
{

test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.Collections.Generic;
44
using System.Linq;
55
using System.Text;
6+
using Tensorflow;
67
using Tensorflow.NumPy;
78

89
namespace TensorFlowNET.UnitTest.NumPy
@@ -88,5 +89,14 @@ public void meshgrid_same_ndim()
8889
AssetSequenceEqual(a.ToArray<int>(), new int[] { 0, 1, 2, 0, 1, 2, 0, 1, 2 });
8990
AssetSequenceEqual(b.ToArray<int>(), new int[] { 0, 0, 0, 1, 1, 1, 2, 2, 2 });
9091
}
92+
93+
[TestMethod]
94+
public void to_numpy_string()
95+
{
96+
var nd = np.arange(10 * 10 * 10 * 10).reshape((10, 10, 10, 10));
97+
var str = tensor_util.to_numpy_string(nd);
98+
Assert.AreEqual("array([[[[0, 1, 2, ..., 7, 8, 9],", str.Substring(0, 33));
99+
Assert.AreEqual("[9990, 9991, 9992, ..., 9997, 9998, 9999]]]])", str.Substring(str.Length - 45));
100+
}
91101
}
92102
}

0 commit comments

Comments
 (0)
0