10000 fix: error when training SimpleRNN. · SciSharp/TensorFlow.NET@07ea656 · GitHub
[go: up one dir, main page]

Skip to content

Commit 07ea656

Browse files
committed
fix: error when training SimpleRNN.
1 parent f1fbcf2 commit 07ea656

File tree

8 files changed

+78
-35
lines changed

8 files changed

+78
-35
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Exceptions
6+
{
7+
public class NotOkStatusException : TensorflowException
8+
{
9+
public NotOkStatusException() : base()
10+
{
11+
12+
}
13+
14+
public NotOkStatusException(string message) : base(message)
15+
{
16+
17+
}
18+
}
19+
}

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,16 @@ public void run(FeedItem[] feed_dict = null, Session session = null)
186186
}
187187

188188
public virtual T get_attr<T>(string name)
189-
=> (T)get_attr(name);
189+
{
190+
if (typeof(T).IsValueType)
191+
{
192+
return (T)Convert.ChangeType(get_attr(name), typeof(T));
193+
}
194+
else
195+
{
196+
return (T)get_attr(name);
197+
}
198+
}
190199

191200
internal unsafe TF_DataType _get_attr_type(string name)
192201
{

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4633,8 +4633,9 @@ public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool
46334633
var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MatMul", name) { args = new object[] { a, b }, attrs = new Dictionary<string, object>() { ["transpose_a"] = transpose_a, ["transpose_b"] = transpose_b } });
46344634
return _fast_path_result[0];
46354635
}
4636-
catch (Exception)
4636+
catch (Exception ex)
46374637
{
4638+
Console.WriteLine();
46384639
}
46394640
try
46404641
{

src/TensorFlowNET.Core/Status/Status.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
using System;
1818
using System.Diagnostics;
1919
using System.Runtime.CompilerServices;
20+
using Tensorflow.Exceptions;
2021
using Tensorflow.Util;
2122
using static Tensorflow.c_api;
2223

@@ -88,7 +89,7 @@ public void Check(bool throwException = false)
8889
case TF_Code.TF_INVALID_ARGUMENT:
8990
throw new InvalidArgumentError(message);
9091
default:
91-
throw new TensorflowException(message);
92+
throw new NotOkStatusException(message);
9293
}
9394
}
9495
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
namespace System.Runtime.CompilerServices
2+
{
3+
internal static class IsExternalInit { }
4+
}

src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
using System.Linq.Expressions;
1212
using Tensorflow.Keras.Utils;
1313
using Tensorflow.Common.Types;
14+
using System.Runtime.CompilerServices;
1415
// from tensorflow.python.distribute import distribution_strategy_context as ds_context;
1516

1617
namespace Tensorflow.Keras.Layers.Rnn
@@ -30,7 +31,19 @@ public class RNN : RnnBase
3031
private int _num_constants;
3132
protected IVariableV1 _kernel;
3233
protected IVariableV1 _bias;
33-
protected IRnnCell _cell;
34+
private IRnnCell _cell;
35+
protected IRnnCell Cell
36+
{
37+
get
38+
{
39+
return _cell;
40+
}
41+
init
42+
{
43+
_cell = value;
44+
_self_tracked_trackables.Add(_cell);
45+
}
46+
}
3447

3548
public RNN(RNNArgs args) : base(PreConstruct(args))
3649
{
@@ -40,14 +53,14 @@ public RNN(RNNArgs args) : base(PreConstruct(args))
4053
// if is StackedRnncell
4154
if (args.Cells != null)
4255
{
43-
_cell = new StackedRNNCells(new StackedRNNCellsArgs
56+
Cell = new StackedRNNCells(new StackedRNNCellsArgs
4457
{
4558
Cells = args.Cells
4659
});
4760
}
4861
else
4962
{
50-
_cell = args.Cell;
63+
Cell = args.Cell;
5164
}
5265

5366
// get input_shape
@@ -65,7 +78,7 @@ public Tensors States
6578
if (_states == null)
6679
{
6780
// CHECK(Rinne): check if this is correct.
68-
var nested = _cell.StateSize.MapStructure<Tensor?>(x => null);
81+
var nested = Cell.StateSize.MapStructure<Tensor?>(x => null);
6982
_states = nested.AsNest().ToTensors();
7083
}
7184
return _states;
@@ -83,7 +96,7 @@ private OneOf<Shape, List<Shape>> compute_output_shape(Shape input_shape)
8396
}
8497

8598
// state_size is a array of ints or a positive integer
86-
var state_size = _cell.StateSize.ToSingleShape();
99+
var state_size = Cell.StateSize.ToSingleShape();
87100

88101
// TODO(wanglongzhi2001),flat_output_size应该是什么类型的,Shape还是Tensor
89102
Func<Shape, Shape> _get_output_shape;
@@ -110,12 +123,12 @@ private OneOf<Shape, List<Shape>> compute_output_shape(Shape input_shape)
110123
return output_shape;
111124
};
112125

113-
Type type = _cell.GetType();
126+
Type type = Cell.GetType();
114127
PropertyInfo output_size_info = type.GetProperty("output_size");
115128
Shape output_shape;
116129
if (output_size_info != null)
117130
{
118-
output_shape = nest.map_structure(_get_output_shape, _cell.OutputSize.ToSingleShape());
131+
output_shape = nest.map_structure(_get_output_shape, Cell.OutputSize.ToSingleShape());
119132
// TODO(wanglongzhi2001),output_shape应该简单的就是一个元组还是一个Shape类型
120133
output_shape = (output_shape.Length == 1 ? (int)output_shape[0] : output_shape);
121134
}
@@ -171,7 +184,9 @@ private Tensors compute_mask(Tensors inputs, Tensors mask)
171184

172185
public override void build(KerasShapesWrapper input_shape)
173186
{
174-
object get_input_spec(Shape shape)
187+
input_shape = new KerasShapesWrapper(input_shape.Shapes[0]);
188+
189+
InputSpec get_input_spec(Shape shape)
175190
{
176191
var input_spec_shape = shape.as_int_list();
177192

@@ -213,10 +228,13 @@ object get_state_spec(Shape shape)
213228
// numpy inputs.
214229

215230

216-
if (!_cell.Built)
231+
if (Cell is Layer layer && !layer.Built)
217232
{
218-
_cell.build(input_shape);
233+
layer.build(input_shape);
234+
layer.Built = true;
219235
}
236+
237+
this.built = true;
220238
}
221239

222240
/// <summary>
@@ -247,10 +265,10 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bo
247265

248266
(inputs, initial_state, constants) = _process_inputs(inputs, initial_state, constants);
249267

250-
_maybe_reset_cell_dropout_mask(_cell);
251-
if (_cell is StackedRNNCells)
268+
_maybe_reset_cell_dropout_mask(Cell);
269+
if (Cell is StackedRNNCells)
252270
{
253-
var stack_cell = _cell as StackedRNNCells;
271+
var stack_cell = Cell as StackedRNNCells;
254272
foreach (IRnnCell cell in stack_cell.Cells)
255273
{
256274
_maybe_reset_cell_dropout_mask(cell);
@@ -300,10 +318,10 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bo
300318
bool is_tf_rnn_cell = false;
301319
if (constants is not null)
302320
{
303-
if (!_cell.SupportOptionalArgs)
321+
if (!Cell.SupportOptionalArgs)
304322
{
305323
throw new ValueError(
306-
$"RNN cell {_cell} does not support constants." +
324+
$"RNN cell {Cell} does not support constants." +
307325
$"Received: constants={constants}");
308326
}
309327

@@ -312,7 +330,7 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bo
312330
constants = new Tensors(states.TakeLast(_num_constants).ToArray());
313331
states = new Tensors(states.SkipLast(_num_constants).ToArray());
314332
states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
315-
var (output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
333+
var (output, new_states) = Cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
316334
return (output, new_states.Single);
317335
};
318336
}
@@ -321,7 +339,7 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bo
321339
step = (inputs, states) =>
322340
{
323341
states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states.First()) : states;
324-
var (output, new_states) = _cell.Apply(inputs, states);
342+
var (output, new_states) = Cell.Apply(inputs, states);
325343
return (output, new_states);
326344
};
327345
}
@@ -562,7 +580,7 @@ protected Tensors get_initial_state(Tensors inputs)
562580
var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0];
563581
var dtype = input.dtype;
564582

565-
Tensors init_state = _cell.GetInitialState(null, batch_size, dtype);
583+
Tensors init_state = Cell.GetInitialState(null, batch_size, dtype);
566584

567585
return init_state;
568586
}

src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,5 @@ private static SimpleRNNArgs CreateCellForArgs(SimpleRNNArgs args)
3232
});
3333
return args;
3434
}
35-
36-
public override void build(KerasShapesWrapper input_shape)
37-
{
38-
var single_shape = input_shape.ToSingleShape();
39-
var input_dim = single_shape[-1];
40-
_buildInputShape = input_shape;
41-
42-
_kernel = add_weight("kernel", (single_shape[-1], args.Units),
43-
initializer: args.KernelInitializer
44-
//regularizer = self.kernel_regularizer,
45-
//constraint = self.kernel_constraint,
46-
//caching_device = default_caching_device,
47-
);
48-
}
4935
}
5036
}

test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ public void SimpleRNN()
7777
var output = keras.layers.Dense(10).Apply(x);
7878
var model = keras.Model(inputs, output);
7979
model.summary();
80+
81+
model.compile(keras.optimizers.Adam(), keras.losses.SparseCategoricalCrossentropy());
82+
var datax = np.ones((16, 10, 8), dtype: dtypes.float32);
83+
var datay = np.ones((16));
84+
model.fit(datax, datay, epochs: 20);
8085
}
8186
[TestMethod]
8287
381A public void RNNForSimpleRNNCell()

0 commit comments

Comments
 (0)
0