8000 copy value when get numpy. · SciSharp/TensorFlow.NET@432ae20 · GitHub
[go: up one dir, main page]

Skip to content

Commit 432ae20

Browse files
committed
copy value when get numpy.
1 parent f566505 commit 432ae20

File tree

7 files changed

+12
-7
lines changed

7 files changed

+12
-7
lines changed

src/TensorFlowNET.Core/NumPy/AutoNumPyAttribute.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
using MethodBoundaryAspect.Fody.Attributes;
22
using System;
33
using System.Collections.Generic;
4+
using System.Diagnostics;
45
using System.Linq;
56
using Tensorflow.Eager;
67
using Tensorflow.Functions;
78
using static Tensorflow.Binding;
89

910
namespace Tensorflow.NumPy
1011
{
12+
[DebuggerStepThrough]
1113
public sealed class AutoNumPyAttribute : OnMethodBoundaryAspect
1214
{
1315
bool _changedMode = false;

src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ public partial class np
1010
{
1111
[AutoNumPy]
1212
public static NDArray argmax(NDArray a, Axis axis = null)
13-
=> new NDArray(math_ops.argmax(a, axis));
13+
=> new NDArray(math_ops.argmax(a, axis ?? 0));
1414

1515
[AutoNumPy]
1616
public static NDArray argsort(NDArray a, Axis axis = null)

src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public NDArray(long[] value, Shape? shape = null)
3131
public NDArray(IntPtr address, Shape shape, TF_DataType dtype)
3232
: base(address, shape, dtype) { NewEagerTensorHandle(); }
3333

34-
public NDArray(Tensor tensor, bool eval = true) : base(tensor.Handle)
34+
public NDArray(Tensor tensor, bool clone = false) : base(tensor.Handle, clone: clone)
3535
{
3636
if (_handle is null)
3737
{

src/TensorFlowNET.Core/Operations/Operation.Input.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public int InputListLength(string name)
3535
tf.Status.Check(true);
3636
return num;
3737
}
38-
public int NumInputs => c_api.TF_OperationNumInputs(_handle);
38+
public int NumInputs => _handle == IntPtr.Zero ? -1 : c_api.TF_OperationNumInputs(_handle);
3939
private TF_DataType[] _input_types => _inputs_val._inputs.Select(x => x.dtype).ToArray();
4040

4141
protected InputList _inputs_val;

src/TensorFlowNET.Core/Operations/Operation.Output.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ namespace Tensorflow
2323
{
2424
public partial class Operation
2525
{
26-
public int NumOutputs => c_api.TF_OperationNumOutputs(_handle);
26+
public int NumOutputs => _handle == IntPtr.Zero ? -1 : c_api.TF_OperationNumOutputs(_handle);
2727
public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(_tf_output(index));
2828

2929
public int OutputListLength(string name)
@@ -38,7 +38,7 @@ public int OutputListLength(string name)
3838
public virtual Tensor[] outputs => _outputs;
3939
public Tensor output => _outputs.FirstOrDefault();
4040

41-
public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle);
41+
public int NumControlOutputs => _handle == IntPtr.Zero ? -1 : c_api.TF_OperationNumControlOutputs(_handle);
4242

4343
public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index));
4444

src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,12 @@ public Tensor()
3939
/// Create a Tensor object from an existing TF handle
4040
/// </summary>
4141
/// <param name="handle">Handle to a <see cref="Tensor"/> object.</param>
42-
public Tensor(SafeTensorHandle handle)
42+
public unsafe Tensor(SafeTensorHandle handle, bool clone = false)
4343
{
4444
_handle = handle;
45+
if (clone)
46+
_handle = TF_NewTensor(shape, dtype, data: TensorDataPointer.ToPointer());
47+
4548
isCreatedInGraphMode = !tf.executing_eagerly();
4649
}
4750

src/TensorFlowNET.Core/Tensors/Tensor.Value.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ protected NDArray GetNDArray(TF_DataType dtype)
5555
return new NDArray(str, shape);
5656
}
5757

58-
return new NDArray(this);
58+
return new NDArray(this, clone: true);
5959
}
6060

6161
/// <summary>

0 commit comments

Comments
 (0)
0