8000 add SafeTensorHandle to manage tensor handle reference. · SciSharp/TensorFlow.NET@e73ed66 · GitHub
[go: up one dir, main page]

Skip to content

Commit e73ed66

Browse files
committed
add SafeTensorHandle to manage tensor handle reference.
1 parent f3cbd85 commit e73ed66

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+271
-322
lines changed

src/TensorFlowNET.Console/Tensorflow.Console.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
</PropertyGroup>
2020

2121
<ItemGroup>
22-
<PackageReference Include="SciSharp.TensorFlow.Redist-Windows-GPU" Version="2.5.0" />
22+
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.5.0" />
2323
</ItemGroup>
2424

2525
<ItemGroup>

src/TensorFlowNET.Core/Attributes/c_api.ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ public partial class c_api
9999
public static extern void TF_SetAttrStringList(IntPtr desc, string attr_name, IntPtr[] values, uint[] lengths, int num_values);
100100

101101
[DllImport(TensorFlowLibName)]
102-
public static extern void TF_SetAttrTensor(IntPtr desc, string attr_name, IntPtr value, SafeStatusHandle status);
102+
public static extern void TF_SetAttrTensor(IntPtr desc, string attr_name, SafeTensorHandle value, SafeStatusHandle status);
103103

104104
[DllImport(TensorFlowLibName)]
105105
public static extern void TF_SetAttrType(IntPtr desc, string attr_name, TF_DataType value);

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,6 @@ public static int len(object a)
164164
return arr.Count;
165165
case ICollection arr:
166166
return arr.Count;
167-
case NDArray ndArray:
168-
return ndArray.ndim == 0 ? 1 : (int)ndArray.dims[0];
169167
case IEnumerable enumerable:
170168
return enumerable.OfType<object>().Count();
171169
case Shape arr:

src/TensorFlowNET.Core/Data/MnistDataSet.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ public class MnistDataSet : DataSetBase
1010
public int EpochsCompleted { get; private set; }
1111
public int IndexInEpoch { get; private set; }
1212

13-
public MnistDataSet(NDArray images, NDArray labels, Type dataType, bool reshape)
13+
public MnistDataSet(NDArray images, NDArray labels, TF_DataType dataType, bool reshape)
1414
{
1515
EpochsCompleted = 0;
1616
IndexInEpoch = 0;

src/TensorFlowNET.Core/Data/ModelLoadSetting.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ public class ModelLoadSetting
66
{
77
public string TrainDir { get; set; }
88
public bool OneHot { get; set; }
9-
public Type DataType { get; set; } = typeof(float);
9+
public TF_DataType DataType { get; set; } = TF_DataType.TF_FLOAT;
1010
public bool ReShape { get; set; }
1111
public int ValidationSize { get; set; } = 5000;
1212
public int? TrainSize { get; set; }

src/TensorFlowNET.Core/DisposableObject.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ private void Dispose(bool disposing)
4848
}
4949

5050
// free unmanaged memory
51-
if (_handle != IntPtr.Zero)
51+
// if (_handle != IntPtr.Zero)
5252
{
5353
// Call the appropriate methods to clean up
5454
// unmanaged resources here.

src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public EagerTensor(Array array, Shape shape) : base(array, shape)
5656
public EagerTensor(byte[] bytes, Shape shape, TF_DataType dtype) : base(bytes, shape, dtype)
5757
=> NewEagerTensorHandle(_handle);
5858

59-
void NewEagerTensorHandle(IntPtr h)
59+
void NewEagerTensorHandle(SafeTensorHandle h)
6060
{
6161
_id = ops.uid();
6262
_eagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status.Handle);

src/TensorFlowNET.Core/Eager/c_api.eager.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ public static void TFE_Execute(SafeOpHandle op, SafeTensorHandleHandle[] retvals
303303
/// <param name="t">const tensorflow::Tensor&amp;</param>
304304
/// <returns>TFE_TensorHandle*</returns>
305305
[DllImport(TensorFlowLibName)]
306-
public static extern SafeTensorHandleHandle TFE_NewTensorHandle(IntPtr t, SafeStatusHandle status);
306+
public static extern SafeTensorHandleHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status);
307307

308308
[DllImport(TensorFlowLibName)]
309309
public static extern SafeTensorHandleHandle TFE_EagerTensorHandle(IntPtr t);
@@ -334,7 +334,7 @@ public static void TFE_Execute(SafeOpHandle op, SafeTensorHandleHandle[] retvals
334334
/// <param name="status">TF_Status*</param>
335335
/// <returns></returns>
336336
[DllImport(TensorFlowLibName)]
337-
public static extern IntPtr TFE_TensorHandleResolve(SafeTensorHandleHandle h, SafeStatusHandle status);
337+
public static extern SafeTensorHandle TFE_TensorHandleResolve(SafeTensorHandleHandle h, SafeStatusHandle status);
338338

339339

340340
/// <summary>

src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,5 @@ public static implicit operator NDArray(float value)
4646

4747
public static implicit operator NDArray(double value)
4848
=> new NDArray(value);
49-
50-
public static implicit operator Tensor(NDArray nd)
51-
=> nd?._tensor;
52-
53-
public static implicit operator NDArray(Tensor tensor)
54-
=> new NDArray(tensor);
5549
}
5650
}

src/TensorFlowNET.Core/NumPy/NDArray.Index.cs

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@ namespace Tensorflow.NumPy
88
{
99
public partial class NDArray
1010
{
11-
public NDArray this[params int[] index]
11+
public NDArray this[params int[] indices]
1212
{
13-
get => GetData(index.Select(x => new Slice
13+
get => GetData(indices.Select(x => new Slice
1414
{
1515
Start = x,
1616
Stop = x + 1,
1717
IsIndex = true
1818
}));
1919

20-
set => SetData(index.Select(x =>
20+
set => SetData(indices.Select(x =>
2121
{
2222
if(x < 0)
2323
x = (int)dims[0] + x;
@@ -57,12 +57,37 @@ public NDArray this[NDArray mask]
5757

5858
NDArray GetData(IEnumerable<Slice> slices)
5959
{
60-
var tensor = _tensor[slices.ToArray()];
61-
return new NDArray(tensor);
60+
if (shape.IsScalar)
61+
return GetScalar();
62+
63+
var tensor = base[slices.ToArray()];
64+
if (tensor.Handle == null)
65+
tensor = tf.defaultSession.eval(tensor);
66+
return new NDArray(tensor.Handle);
67+
}
68+
69+
unsafe T GetAtIndex<T>(params int[] indices) where T : unmanaged
70+
{
71+
var offset = (ulong)ShapeHelper.GetOffset(shape, indices);
72+
return *((T*)data + offset);
73+
}
74+
75+
NDArray GetScalar()
76+
{
77+
var array = new NDArray(Shape.Scalar, dtype: dtype);
78+
unsafe
79+
{
80+
var src = (byte*)data + dtypesize;
81+
System.Buffer.MemoryCopy(src, array.buffer.ToPointer(), bytesize, bytesize);
82+
}
83+
return array;
6284
}
6385

6486
NDArray GetData(int[] indices, int axis = 0)
6587
{
88+
if (shape.IsScalar)
89+
return GetScalar();
90+
6691
if(axis == 0)
6792
{
6893
var dims = shape.as_int_list();

src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@ namespace Tensorflow.NumPy
88
{
99
public partial class NDArray
1010
{
11-
public static NDArray operator +(NDArray lhs, NDArray rhs) => lhs.Tensor + rhs.Tensor;
12-
public static NDArray operator -(NDArray lhs, NDArray rhs) => lhs.Tensor - rhs.Tensor;
13-
public static NDArray operator *(NDArray lhs, NDArray rhs) => lhs.Tensor * rhs.Tensor;
14-
public static NDArray operator /(NDArray lhs, NDArray rhs) => lhs.Tensor / rhs.Tensor;
15-
public static NDArray operator >(NDArray lhs, NDArray rhs) => lhs.Tensor > rhs.Tensor;
16-
public static NDArray operator <(NDArray lhs, NDArray rhs) => lhs.Tensor < rhs.Tensor;
11+
public static NDArray operator +(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("add", lhs, rhs));
12+
public static NDArray operator -(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("sub", lhs, rhs));
13+
public static NDArray operator *(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("mul", lhs, rhs));
14+
public static NDArray operator /(NDArray lhs, NDArray rhs) => new NDArray(BinaryOpWrapper("div", lhs, rhs));
15+
public static NDArray operator >(NDArray lhs, NDArray rhs) => new NDArray(gen_math_ops.greater(lhs, rhs));
16+
public static NDArray operator <(NDArray lhs, NDArray rhs) => new NDArray(gen_math_ops.less(lhs, rhs));
17+
public static NDArray operator -(NDArray lhs) => new NDArray(gen_math_ops.neg(lhs));
1718
}
1819
}

src/TensorFlowNET.Core/NumPy/NumPy.Logical.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ namespace Tensorflow.NumPy
1010
public partial class np
1111
{
1212
public static NDArray logical_or(NDArray x1, NDArray x2)
13-
=> tf.logical_or(x1, x2);
13+
=> new NDArray(tf.logical_or(x1, x2));
1414

1515
public static NDArray logical_and(NDArray x1, NDArray x2)
16-
=> tf.logical_and(x1, x2);
16+
=> new NDArray(tf.logical_and(x1, x2));
1717
}
1818
}

src/TensorFlowNET.Core/NumPy/NumPy.Statistics.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ namespace Tensorflow.NumPy
1010
public partial class np
1111
{
1212
public static NDArray amin(NDArray x, int axis = 0)
13-
=> tf.arg_min(x, axis);
13+
=> new NDArray(tf.arg_min(x, axis));
1414

1515
public static NDArray amax(NDArray x, int axis = 0)
16-
=> tf.arg_max(x, axis);
16+
=> new NDArray(tf.arg_max(x, axis));
1717
}
1818
}

src/TensorFlowNET.Core/NumPy/Numpy.Math.cs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,30 +10,30 @@ namespace Tensorflow.NumPy
1010
public partial class np
1111
{
1212
public static NDArray exp(NDArray x)
13-
=> tf.exp(x);
13+
=> new NDArray(tf.exp(x));
1414

1515
public static NDArray log(NDArray x)
16-
=> tf.log(x);
16+
=> new NDArray(tf.log(x));
1717

1818
public static NDArray multiply(NDArray x1, NDArray x2)
19-
=> tf.multiply(x1, x2);
19+
=> new NDArray(tf.multiply(x1, x2));
2020

2121
public static NDArray maximum(NDArray x1, NDArray x2)
22-
=> tf.maximum(x1, x2);
22+
=> new NDArray(tf.maximum(x1, x2));
2323

2424
public static NDArray minimum(NDArray x1, NDArray x2)
25-
=> tf.minimum(x1, x2);
25+
=> new NDArray(tf.minimum(x1, x2));
2626

2727
public static NDArray prod(NDArray array, Axis? axis = null, Type? dtype = null, bool keepdims = false)
28-
=> tf.reduce_prod(array, axis: axis);
28+
=> new NDArray(tf.reduce_prod(array, axis: axis));
2929

3030
public static NDArray prod<T>(params T[] array) where T : unmanaged
31-
=> tf.reduce_prod(ops.convert_to_tensor(array));
31+
=> new NDArray(tf.reduce_prod(new NDArray(array)));
3232

3333
public static NDArray sqrt(NDArray x)
34-
=> tf.sqrt(x);
34+
=> new NDArray(tf.sqrt(x));
3535

3636
public static NDArray sum(NDArray x1, Axis? axis = null)
37-
=> tf.math.sum(x1, axis);
37+
=> new NDArray(tf.math.sum(x1, axis));
3838
}
3939
}

src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs

Lines changed: 34 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,36 @@ namespace Tensorflow.NumPy
88
{
99
public partial class NDArray
1010
{
11-
public NDArray(bool value) => Init(value);
12-
public NDArray(byte value) => Init(value);
13-
public NDArray(short value) => Init(value);
14-
public NDArray(int value) => Init(value);
15-
public NDArray(long value) => Init(value);
16-
public NDArray(float value) => Init(value);
17-
public NDArray(double value) => Init(value);
18-
public NDArray(Array value, Shape? shape = null) => Init(value, shape);
19-
public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) => Init(shape, dtype: dtype);
20-
public NDArray(Tensor value, Shape? shape = null) => Init(value, shape);
21-
public NDArray(byte[] bytes, Shape shape, TF_DataType dtype) => Init(bytes, shape, dtype);
22-
public NDArray(IntPtr address, Shape shape, TF_DataType dtype) => Init(address, shape, dtype);
11+
public NDArray(bool value) : base(value) { NewEagerTensorHandle(); }
12+
public NDArray(byte value) : base(value) { NewEagerTensorHandle(); }
13+
public NDArray(short value) : base(value) { NewEagerTensorHandle(); }
14+
public NDArray(int value) : base(value) { NewEagerTensorHandle(); }
15+
public NDArray(long value) : base(value) { NewEagerTensorHandle(); }
16+
public NDArray(float value) : base(value) { NewEagerTensorHandle(); }
17+
public NDArray(double value) : base(value) { NewEagerTensorHandle(); }
18+
19+
public NDArray(Array value, Shape? shape = null)
20+
: base(value, shape) { NewEagerTensorHandle(); }
21+
22+
public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE)
23+
: base(shape, dtype: dtype) { NewEagerTensorHandle(); }
24+
25+
public NDArray(byte[] bytes, Shape shape, TF_DataType dtype)
26+
: base(bytes, shape, dtype) { NewEagerTensorHandle(); }
27+
28+
public NDArray(IntPtr address, Shape shape, TF_DataType dtype)
29+
: base(address, shape, dtype) { NewEagerTensorHandle(); }
30+
31+
public NDArray(Tensor tensor) : base(tensor.Handle)
32+
{
33+
if (_handle is null)
34+
{
35+
tensor = tf.defaultSession.eval(tensor);
36+
_handle = tensor.Handle;
37+
}
38+
39+
NewEagerTensorHandle();
40+
}
2341

2442
public static NDArray Scalar<T>(T value) where T : unmanaged
2543
=> value switch
@@ -33,59 +51,11 @@ public static NDArray Scalar<T>(T value) where T : unmanaged
3351
_ => throw new NotImplementedException("")
3452
};
3553

36-
void Init<T>(T value) where T : unmanaged
37-
{
38-
_tensor = value switch
39-
{
40-
bool val => new Tensor(val),
41-
byte val => new Tensor(val),
42-
int val => new Tensor(val),
43-
long val => new Tensor(val),
44-
float val => new Tensor(val),
45-
double val => new Tensor(val),
46-
_ => throw new NotImplementedException("")
47-
};
48-
49-
_tensor.SetReferencedByNDArray();
50-
}
51-
52-
void Init(Array value, Shape? shape = null)
53-
{
54-
_tensor = new Tensor(value, shape ?? value.GetShape());
55-
_tensor.SetReferencedByNDArray();
56-
}
57-
58-
void Init(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE)
59-
{
60-
_tensor = new Tensor(shape, dtype: dtype);
61-
_tensor.SetReferencedByNDArray();
62-
}
63-
64-
void Init(Tensor value, Shape? shape = null)
65-
{
66-
// created tensor in graph mode
67-
if (value.TensorDataPointer == IntPtr.Zero)
68-
{
69-
if (!value.graph.building_function)
70-
{
71-
value = tf.defaultSession.eval(value);
72-
value = new Tensor(value.TensorDataPointer, shape ?? value.shape, value.dtype);
73-
}
74-
}
75-
_tensor = value;
76-
_tensor.SetReferencedByNDArray();
77-
}
78-
79-
void Init(byte[] bytes, Shape shape, TF_DataType dtype)
80-
{
81-
_tensor = new Tensor(bytes, shape, dtype);
82-
_tensor.SetReferencedByNDArray();
83-
}
84-
85-
void Init(IntPtr address, Shape shape, TF_DataType dtype)
54+
void NewEagerTensorHandle()
8655
{
87-
_tensor = new Tensor(address, shape, dtype);
88-
_tensor.SetReferencedByNDArray();
56+
_id = ops.uid();
57+
_eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle);
58+
tf.Status.Check(true);
8959
}
9060
}
9161
}

0 commit comments

Comments
 (0)
0