8000 optimize slice. · SciSharp/TensorFlow.NET@d88fec4 · GitHub
[go: up one dir, main page]

Skip to content

Commit d88fec4

Browse files
committed
optimize slice.
1 parent 3052e1f commit d88fec4

14 files changed

+209
-37
lines changed

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -513,10 +513,13 @@ public static Shape GetShape(this object data)
513513
if (data is NDArray nd)
514514
return nd.shape;
515515

516-
if (data is Tensor tensor)
516+
else if (data is Tensor tensor)
517517
return tensor.shape;
518518

519-
if (!data.GetType().IsArray)
519+
else if (data is Axis axis)
520+
return axis.IsScalar ? Shape.Scalar : new Shape(axis.axis);
521+
522+
else if (!data.GetType().IsArray)
520523
return Shape.Scalar;
521524

522525
switch (data)

src/TensorFlowNET.Core/Data/MnistDataSet.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ public MnistDataSet(NDArray images, NDArray labels, TF_DataType dataType, bool r
1717

1818
NumOfExamples = (int)images.dims[0];
1919

20-
images = images.reshape((images.dims[0], images.dims[1] * images.dims[2]));
20+
// images = images.reshape((images.dims[0], images.dims[1] * images.dims[2]));
2121
images = images.astype(dataType);
2222
// for debug np.multiply performance
2323
var sw = new Stopwatch();

src/TensorFlowNET.Core/Data/MnistModelLoader.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,7 @@ private NDArray ExtractImages(string file, int? limit = null)
123123

124124
bytestream.Read(buf, 0, buf.Length);
125125

126-
var data = np.frombuffer(buf, new Shape(buf.Length), np.@byte);
127-
data = data.reshape((num_images, rows, cols, 1));
128-
126+
var data = np.frombuffer(buf, (num_images, rows * cols), np.@byte);
129127
return data;
130128
}
131129
}

src/TensorFlowNET.Core/NumPy/Axis.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ namespace Tensorflow
2424
public record Axis(params int[] axis)
2525
{
2626
public int size => axis == null ? -1 : axis.Length;
27+
public bool IsScalar { get; init; }
2728

2829
public int this[int index] => axis[index];
2930

@@ -34,7 +35,7 @@ public static implicit operator int(Axis axis)
3435
=> axis.axis[0];
3536

3637
public static implicit operator Axis(int axis)
37-
=> new Axis(axis);
38+
=> new Axis(axis) { IsScalar = true };
3839

3940
public static implicit operator Axis((int, int) axis)
4041
=> new Axis(axis.Item1, axis.Item2);

src/TensorFlowNET.Core/NumPy/NDArray.Index.cs

Lines changed: 58 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ public NDArray this[params int[] indices]
1515
Start = x,
1616
Stop = x + 1,
1717
IsIndex = true
18-
}));
18+
}).ToArray());
1919

2020
set => SetData(indices.Select(x =>
2121
{
@@ -55,21 +55,58 @@ public NDArray this[NDArray mask]
5555
}
5656
}
5757

58-
NDArray GetData(IEnumerable<Slice> slices)
58+
59+
unsafe NDArray GetData(Slice[] slices)
5960
{
6061
if (shape.IsScalar)
6162
return GetScalar();
6263

64+
if (SliceHelper.AreAllIndex(slices, out var indices1))
65+
{
66+
var newshape = ShapeHelper.GetShape(shape, slices);
67+
if (newshape.IsScalar)
68+
{
69+
var offset = ShapeHelper.GetOffset(shape, indices1);
70+
return GetScalar((ulong)offset);
71+
}
72+
else
73+
{
74+
return GetArrayData(newshape, indices1);
75+
}
76+
}
77+
else if (slices.Count() == 1)
78+
{
79+
var slice = slices[0];
80+
if (slice.Step == 1)
81+
{
82+
var newshape = ShapeHelper.GetShape(shape, slice);
83+
var array = new NDArray(newshape, dtype: dtype);
84+
85+
var new_dims = new int[shape.ndim];
86+
new_dims[0] = slice.Start ?? 0;
87+
//for (int i = 1; i < shape.ndim; i++)
88+
//new_dims[i] = (int)shape.dims[i];
89+
90+
var offset = ShapeHelper.GetOffset(shape, new_dims);
91+
var src = (byte*)data + (ulong)offset * dtypesize;
92+
var dst = (byte*)array.data;
93+
var len = (ulong)newshape.size * dtypesize;
94+
95+
System.Buffer.MemoryCopy(src, dst, len, len);
96+
97+
return array;
98+
}
99+
}
100+
101+
// default, performance is bad
63102
var tensor = base[slices.ToArray()];
64103
if (tensor.Handle == null)
65104
{
66105
if (tf.executing_eagerly())
67106
tensor = tf.defaultSession.eval(tensor);
68-
else
69-
return new NDArray(tensor);
70107
}
71-
72-
return new NDArray(tensor);
108+
109+
return new NDArray(tensor, tf.executing_eagerly());
73110
}
74111

75112
unsafe T GetAtIndex<T>(params int[] indices) where T : unmanaged
@@ -78,17 +115,26 @@ unsafe T GetAtIndex<T>(params int[] indices) where T : unmanaged
78115
return *((T*)data + offset);
79116
}
80117

81-
NDArray GetScalar()
118+
unsafe NDArray GetScalar(ulong offset = 0)
82119
{
83120
var array = new NDArray(Shape.Scalar, dtype: dtype);
84-
unsafe
85-
{
86-
var src = (byte*)data + dtypesize;
87-
System.Buffer.MemoryCopy(src, array.buffer.ToPointer(), bytesize, bytesize);
88-
}
121+
var src = (byte*)data + offset * dtypesize;
122+
System.Buffer.MemoryCopy(src, array.buffer.ToPointer(), dtypesize, dtypesize);
89123
return array;
90124
}
91125

126+
unsafe NDArray GetArrayData(Shape newshape, int[] indices)
127+
{
128+
var offset = ShapeHelper.GetOffset(shape, indices);
129+
var len = (ulong)newshape.size * dtypesize;
130+
var array = new NDArray(newshape, dtype: dtype);
131+
132+
var src = (byte*)data + (ulong)offset * dtypesize;
133+
System.Buffer.MemoryCopy(src, array.data.ToPointer(), len, len);
134+
135+
return array;
136+
}
137+
92138
NDArray GetData(int[] indices, int axis = 0)
93139
{
94140
if (shape.IsScalar)

src/TensorFlowNET.Core/NumPy/ShapeHelper.cs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
namespace Tensorflow.NumPy
77
{
8-
internal class ShapeHelper
8+
public class ShapeHelper
99
{
1010
public static long GetSize(Shape shape)
1111
{
@@ -41,6 +41,34 @@ public static long[] GetStrides(Shape shape)
4141
return strides;
4242
}
4343

44+
public static Shape GetShape(Shape shape1, params Slice[] slices)
45+
{
46+
var new_dims = shape1.dims.ToArray();
47+
slices = SliceHelper.AlignWithShape(shape1, slices);
48+
49+
for (int i = 0; i < shape1.dims.Length; i++)
50+
{
51+
Slice slice = slices[i];
52+
if (slice.Equals(Slice.All))
53+
new_dims[i] = shape1.dims[i];
54+
else if (slice.IsIndex)
55+
new_dims[i] = 1;
56+
else // range
57+
new_dims[i] = (slice.Stop ?? shape1.dims[i]) - (slice.Start ?? 0);
58+
}
59+
60+
// strip first dim if is index
61+
var return_dims = new List<long>();
62+
for (int i = 0; i< new_dims.Length; i++)
63+
{
64+
if (slices[i].IsIndex)
65+
continue;
66+
return_dims.add(new_dims[i]);
67+
}
68+
69+
return new Shape(return_dims.ToArray());
70+
}
71+
4472
public static bool Equals(Shape shape, object target)
4573
{
4674
switch (target)
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
6+
namespace Tensorflow.NumPy
7+
{
8+
public class SliceHelper
9+
{
10+
public static Slice[] AlignWithShape(Shape shape, Slice[] slices)
11+
{
12+
// align slices
13+
var ndim = shape.ndim;
14+
var new_slices = new List<Slice>();
15+
var slice_index = 0;
16+
17+
for (int i = 0; i < ndim; i++)
18+
{
19+
if (slice_index > slices.Length - 1)
20+
{
21+
new_slices.Add(Slice.All);
22+
continue;
23+
}
24+
25+
if (slices[slice_index] == Slice.All)
26+
{
27+
new_slices.Add(Slice.All);
28+
for (int j = 0; j < ndim - slices.Length; j++)
29+
{
30+
new_slices.Add(Slice.All);
31+
i++;
32+
}
33+
}
34+
else
35+
{
36+
new_slices.Add(slices[slice_index]);
37+
}
38+
slice_index++;
39+
}
40+
41+
return new_slices.ToArray();
42+
}
43+
44+
public static bool AreAllIndex(Slice[] slices, out int[] indices)
45+
{
46+
indices = new int[slices.Length];
47+
for (int i = 0; i< slices.Length; i++)
48+
{
49+
indices[i] = slices[i].Start ?? 0;
50+
if (!slices[i].IsIndex)
51+
return false;
52+
}
53+
return true;
54+
}
55+
}
56+
}

src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public NDArray(byte[] bytes, Shape shape, TF_DataType dtype)
2828
public NDArray(IntPtr address, Shape shape, TF_DataType dtype)
2929
: base(address, shape, dtype) { NewEagerTensorHandle(); }
3030

31-
public NDArray(Tensor tensor) : base(tensor.Handle)
31+
public NDArray(Tensor tensor, bool eval = true) : base(tensor.Handle)
3232
{
3333
if (_handle is null)
3434
{
@@ -53,9 +53,12 @@ public static NDArray Scalar<T>(T value) where T : unmanaged
5353

5454
void NewEagerTensorHandle()
5555
{
56-
_id = ops.uid();
57-
_eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle);
58-
tf.Status.Check(true);
56+
if(_handle is not null)
57+
{
58+
_id = ops.uid();
59+
_eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle);
60+
tf.Status.Check(true);
61+
}
5962
}
6063
}
6164
}

src/TensorFlowNET.Core/Numpy/Slice.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,12 @@ public class Slice
115115
/// <param name="start">Start index of the slice, null means from the start of the array</param>
116116
/// <param name="stop">Stop index (first index after end of slice), null means to the end of the array</param>
117117
/// <param name="step">Optional step to select every n-th element, defaults to 1</param>
118-
public Slice(int? start = null, int? stop = null, int step = 1)
118+
public Slice(int? start = null, int? stop = null, int step = 1, bool isIndex = false)
119119
{
120120
Start = start;
121121
Stop = stop;
122122
Step = step;
123+
IsIndex = isIndex;
123124
}
124125

125126
public Slice(string slice_notation)

src/TensorFlowNET.Core/Tensors/SafeStringTensorHandle.cs

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace Tensorflow
88
public sealed class SafeStringTensorHandle : SafeTensorHandle
99
{
1010
Shape _shape;
11-
SafeTensorHandle _handle;
11+
IntPtr _handle;
1212
const int TF_TSRING_SIZE = 24;
1313

1414
protected SafeStringTensorHandle()
@@ -18,7 +18,7 @@ protected SafeStringTensorHandle()
1818
public SafeStringTensorHandle(SafeTensorHandle handle, Shape shape)
1919
: base(handle.DangerousGetHandle())
2020
{
21-
_handle = handle;
21+
_handle = c_api.TF_TensorData(handle);
2222
_shape = shape;
2323
}
2424

@@ -28,15 +28,10 @@ protected override bool ReleaseHandle()
2828
print($"Delete StringTensorHandle 0x{handle.ToString("x16")}");
2929
#endif
3030

31-
long size = 1;
32-
foreach (var s in _shape.dims)
33-
size *= s;
34-
var tstr = c_api.TF_TensorData(_handle);
35-
36-
for (int i = 0; i < size; i++)
31+
for (int i = 0; i < _shape.size; i++)
3732
{
38-
c_api.TF_StringDealloc(tstr);
39-
tstr += TF_TSRING_SIZE;
33+
c_api.TF_StringDealloc(_handle);
34+
_handle += TF_TSRING_SIZE;
4035
}
4136

4237
SetHandle(IntPtr.Zero);

src/TensorFlowNET.Core/Tensors/Tensor.String.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ public SafeStringTensorHandle StringTensor(string[] strings, Shape shape)
2323
public SafeStringTensorHandle StringTensor(byte[][] buffer, Shape shape)
2424
{
2525
var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING,
26-
shape.ndim == 0 ? null : shape.dims,
26+
shape.dims,
2727
10000 shape.ndim,
2828
(ulong)shape.size * TF_TSRING_SIZE);
2929

src/TensorFlowNET.Core/Tensors/tensor_util.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,9 @@ public static Tensor shape_tensor(int[] shape)
472472

473473
public static string to_numpy_string(Tensor tensor)
474474
{
475+
if (tensor.buffer == IntPtr.Zero)
476+
return "Empty";
477+
475478
var dtype = tensor.dtype;
476479
var shape = tensor.shape;
477480

src/TensorFlowNET.Core/ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ public static Tensor convert_to_tensor(object value,
161161
IEnumerable<Tensor> tensors => array_ops._autopacking_helper(tensors, dtype, name == null ? "packed" : name),
162162
RefVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref),
163163
ResourceVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref),
164-
Axis ts => constant_op.constant(ts.axis, dtype: dtype, name: name),
164+
Axis ts => constant_op.constant(ts, dtype: dtype, name: name),
165165
Shape ts => constant_op.constant(ts.dims, dtype: dtype, name: name),
166166
string str => constant_op.constant(str, dtype: tf.@string, name: name),
167167
string[] str => constant_op.constant(str, dtype: tf.@string, name: name),

0 commit comments

Comments
 (0)
0