fix: error after merging LSTM support. · SciSharp/TensorFlow.NET@6b30902 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6b30902

Browse files
committed
fix: error after merging LSTM support.
1 parent df7d700 commit 6b30902

File tree

11 files changed

+79
-89
lines changed

11 files changed

+79
-89
lines changed

src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,6 @@ namespace Tensorflow.Common.Types
77
{
88
public class GeneralizedTensorShape: Nest<Shape>
99
{
10-
////public TensorShapeConfig[] Shapes { get; set; }
11-
///// <summary>
12-
///// create a single-dim generalized Tensor shape.
13-
///// </summary>
14-
///// <param name="dim"></param>
15-
//public GeneralizedTensorShape(int dim, int size = 1)
16-
//{
17-
// var elem = new TensorShapeConfig() { Items = new long?[] { dim } };
18-
// Shapes = Enumerable.Repeat(elem, size).ToArray();
19-
// //Shapes = new TensorShapeConfig[size];
20-
// //Shapes.Initialize(new TensorShapeConfig() { Items = new long?[] { dim } });
21-
// //Array.Initialize(Shapes, new TensorShapeConfig() { Items = new long?[] { dim } });
22-
// ////Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } };
23-
//}
24-
2510
public GeneralizedTensorShape(Shape value, string? name = null)
2611
{
2712
NodeValue = value;

src/TensorFlowNET.Core/Common/Types/NestList.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@ public sealed class NestList<T> : INestStructure<T>, IEnumerable<T>
1515
public int ShallowNestedCount => Values.Count;
1616

1717
public int TotalNestedCount => Values.Count;
18-
18+
19+
public NestList(params T[] values)
20+
{
21+
Values = new List<T>(values);
22+
}
23+
1924
public NestList(IEnumerable<T> values)
2025
{
2126
Values = new List<T>(values);

src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ public interface IRnnCell: ILayer
1010
/// <summary>
1111
/// If the derived class does not intend to implement it, please return null.
1212
/// </summary>
13-
GeneralizedTensorShape? StateSize { get; }
13+
INestStructure<long>? StateSize { get; }
1414
/// <summary>
1515
/// If the derived class does not intend to implement it, please return null.
1616
/// </summary>
17-
GeneralizedTensorShape? OutputSize { get; }
17+
INestStructure<long>? OutputSize { get; }
1818
/// <summary>
1919
/// Whether the optional RNN args are supported when applying the layer.
2020
/// In other words, whether `Apply` is overridden to process `RnnOptionalArgs`.

src/TensorFlowNET.Core/Numpy/Shape.cs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,14 @@ limitations under the License.
1919
using System.Collections.Generic;
2020
using System.Linq;
2121
using System.Text;
22+
using Tensorflow.Common.Types;
2223
using Tensorflow.Keras.Saving.Common;
2324
using Tensorflow.NumPy;
2425

2526
namespace Tensorflow
2627
{
2728
[JsonConverter(typeof(CustomizedShapeJsonConverter))]
28-
public class Shape
29+
public class Shape : INestStructure<long>
2930
{
3031
public int ndim => _dims == null ? -1 : _dims.Length;
3132
long[] _dims;
@@ -41,6 +42,27 @@ public long[] strides
4142
}
4243
}
4344

45+
public NestType NestType => NestType.List;
46+
47+
public int ShallowNestedCount => ndim;
48+
/// <summary>
49+
/// The total item count across all depths of the nested structure.
50+
/// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5.
51+
/// </summary>
52+
public int TotalNestedCount => ndim;
53+
54+
public IEnumerable<long> Flatten() => dims.Select(x => x);
55+
56+
public INestStructure<TOut> MapStructure<TOut>(Func<long, TOut> func)
57+
{
58+
return new NestList<TOut>(dims.Select(x => func(x)));
59+
}
60+
61+
public Nest<long> AsNest()
62+
{
63+
return new NestList<long>(Flatten()).AsNest();
64+
}
65+
4466
#region https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/proposals/csharp-8.0/ranges
4567
public int Length => ndim;
4668
public long[] Slice(int start, int length)

src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,8 @@ public Tensors GetInitialState(Tensors inputs = null, Tensor batch_size = null,
185185
{
186186
throw new NotImplementedException();
187187
}
188-
public GeneralizedTensorShape StateSize => throw new NotImplementedException();
189-
public GeneralizedTensorShape OutputSize => throw new NotImplementedException();
188+
public INestStructure<long> StateSize => throw new NotImplementedException();
189+
public INestStructure<long> OutputSize => throw new NotImplementedException();
190190
public bool IsTFRnnCell => throw new NotImplementedException();
191191
public bool SupportOptionalArgs => throw new NotImplementedException();
192192
}

src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ public DropoutRNNCellMixin(LayerArgs args): base(args)
1818

1919
}
2020

21-
public abstract GeneralizedTensorShape StateSize { get; }
22-
public abstract GeneralizedTensorShape OutputSize { get; }
21+
public abstract INestStructure<long> StateSize { get; }
22+
public abstract INestStructure<long> OutputSize { get; }
2323
public abstract bool SupportOptionalArgs { get; }
2424
public virtual Tensors GetInitialState(Tensors inputs, Tensor batch_size, TF_DataType dtype)
2525
{

src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,11 @@ public class LSTMCell : DropoutRNNCellMixin
2222
IVariableV1 _recurrent_kernel;
2323
IInitializer _bias_initializer;
2424
IVariableV1 _bias;
25-
GeneralizedTensorShape _state_size;
26-
GeneralizedTensorShape _output_size;
27-
public override GeneralizedTensorShape StateSize => _state_size;
25+
INestStructure<long> _state_size;
26+
INestStructure<long> _output_size;
27+
public override INestStructure<long> StateSize => _state_size;
2828

29-
public override GeneralizedTensorShape OutputSize => _output_size;
30-
31-
public override bool IsTFRnnCell => true;
29+
public override INestStructure<long> OutputSize => _output_size;
3230

3331
public override bool SupportOptionalArgs => false;
3432
public LSTMCell(LSTMCellArgs args)
@@ -49,10 +47,8 @@ public LSTMCell(LSTMCellArgs args)
4947
_args.Implementation = 1;
5048
}
5149

52-
_state_size = new GeneralizedTensorShape(_args.Units, 2);
53-
_output_size = new GeneralizedTensorShape(_args.Units);
54-
55-
50+
_state_size = new NestList<long>(_args.Units, _args.Units);
51+
_output_size = new NestNode<long>(_args.Units);
5652
}
5753

5854
public override void build(KerasShapesWrapper input_shape)
@@ -229,11 +225,6 @@ public Tensors _compute_carry_and_output_fused(Tensor[] z, Tensor c_tm1)
229225
var o = _args.RecurrentActivation.Apply(z3);
230226
return new Tensors(c, o);
231227
}
232-
233-
public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null)
234-
{
235-
return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size.Value, dtype.Value);
236-
}
237228
}
238229

239230

src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ public Tensors States
8686
set { _states = value; }
8787
}
8888

89-
private OneOf<Shape, List<Shape>> compute_output_shape(Shape input_shape)
89+
private INestStructure<Shape> compute_output_shape(Shape input_shape)
9090
{
9191
var batch = input_shape[0];
9292
var time_step = input_shape[1];
@@ -96,13 +96,15 @@ private OneOf<Shape, List<Shape>> compute_output_shape(Shape input_shape)
9696
}
9797

9898
// state_size is an array of ints or a positive integer
99-
var state_size = Cell.StateSize.ToSingleShape();
99+
var state_size = Cell.StateSize;
100+
if(state_size?.TotalNestedCount == 1)
101+
{
102+
state_size = new NestList<long>(state_size.Flatten().First());
103+
}
100104

101-
// TODO(wanglongzhi2001),flat_output_size应该是什么类型的,Shape还是Tensor
102-
Func<Shape, Shape> _get_output_shape;
103-
_get_output_shape = (flat_output_size) =>
105+
Func<long, Shape> _get_output_shape = (flat_output_size) =>
104106
{
105-
var output_dim = flat_output_size.as_int_list();
107+
var output_dim = new Shape(flat_output_size).as_int_list();
106108
Shape output_shape;
107109
if (_args.ReturnSequences)
108110
{
@@ -125,31 +127,28 @@ private OneOf<Shape, List<Shape>> compute_output_shape(Shape input_shape)
125127

126128
Type type = Cell.GetType();
127129
PropertyInfo output_size_info = type.GetProperty("output_size");
128-
Shape output_shape;
130+
INestStructure<Shape> output_shape;
129131
if (output_size_info != null)
130132
{
131-
output_shape = nest.map_structure(_get_output_shape, Cell.OutputSize.ToSingleShape());
132-
// TODO(wanglongzhi2001),output_shape应该简单的就是一个元组还是一个Shape类型
133-
output_shape = (output_shape.Length == 1 ? (int)output_shape[0] : output_shape);
133+
output_shape = Nest.MapStructure(_get_output_shape, Cell.OutputSize);
134134
}
135135
else
136136
{
137-
output_shape = _get_output_shape(state_size);
137+
output_shape = new NestNode<Shape>(_get_output_shape(state_size.Flatten().First()));
138138
}
139139

140140
if (_args.ReturnState)
141141
{
142-
Func<Shape, Shape> _get_state_shape;
143-
_get_state_shape = (flat_state) =>
142+
Func<long, Shape> _get_state_shape = (flat_state) =>
144143
{
145-
var state_shape = new int[] { (int)batch }.concat(flat_state.as_int_list());
144+
var state_shape = new int[] { (int)batch }.concat(new Shape(flat_state).as_int_list());
146145
return new Shape(state_shape);
147146
};
148147

149148

150-
var state_shape = _get_state_shape(state_size);
149+
var state_shape = Nest.MapStructure(_get_state_shape, state_size);
151150

152-
return new List<Shape> { output_shape, state_shape };
151+
return new Nest<Shape>(new[] { output_shape, state_shape } );
153152
}
154153
else
155154
{
@@ -435,7 +434,7 @@ public override Tensors Apply(Tensors inputs, Tensors initial_states = null, boo
435434
tmp.add(tf.math.count_nonzero(s.Single()));
436435
}
437436
var non_zero_count = tf.add_n(tmp);
438-
//initial_state = tf.cond(non_zero_count > 0, () => States, () => initial_state);
437+
initial_state = tf.cond(non_zero_count > 0, States, initial_state);
439438
if ((int)non_zero_count.numpy() > 0)
440439
{
441440
initial_state = States;
@@ -445,16 +444,7 @@ public override Tensors Apply(Tensors inputs, Tensors initial_states = null, boo
445444
{
446445
initial_state = States;
447446
}
448-
// TODO(Wanglongzhi2001),
449-
// initial_state = tf.nest.map_structure(
450-
//# When the layer has a inferred dtype, use the dtype from the
451-
//# cell.
452-
// lambda v: tf.cast(
453-
// v, self.compute_dtype or self.cell.compute_dtype
454-
// ),
455-
// initial_state,
456-
// )
457-
447+
//initial_state = Nest.MapStructure(v => tf.cast(v, this.), initial_state);
458448
}
459449
else if (initial_state is null)
460450
{

src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ public class SimpleRNNCell : DropoutRNNCellMixin
2424
IVariableV1 _kernel;
2525
IVariableV1 _recurrent_kernel;
2626
IVariableV1 _bias;
27-
GeneralizedTensorShape _state_size;
28-
GeneralizedTensorShape _output_size;
27+
INestStructure<long> _state_size;
28+
INestStructure<long> _output_size;
2929

30-
public override GeneralizedTensorShape StateSize => _state_size;
31-
public override GeneralizedTensorShape OutputSize => _output_size;
30+
public override INestStructure<long> StateSize => _state_size;
31+
public override INestStructure<long> OutputSize => _output_size;
3232
public override bool SupportOptionalArgs => false;
3333

3434
public SimpleRNNCell(SimpleRNNCellArgs args) : base(args)
@@ -41,8 +41,8 @@ public SimpleRNNCell(SimpleRNNCellArgs args) : base(args)
4141
}
4242
this._args.Dropout = Math.Min(1f, Math.Max(0f, this._args.Dropout));
4343
this._args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this._args.RecurrentDropout));
44-
_state_size = new GeneralizedTensorShape(args.Units);
45-
_output_size = new GeneralizedTensorShape(args.Units);
44+
_state_size = new NestNode<long>(args.Units);
45+
_output_size = new NestNode<long>(args.Units);
4646
}
4747

4848
public override void build(KerasShapesWrapper input_shape)

src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
using System;
2-
using System.Collections.Generic;
32
using System.ComponentModel;
43
using System.Linq;
54
using Tensorflow.Common.Extensions;
65
using Tensorflow.Common.Types;
7-
using Tensorflow.Keras.ArgsDefinition;
86
using Tensorflow.Keras.ArgsDefinition.Rnn;
97
using Tensorflow.Keras.Engine;
108
using Tensorflow.Keras.Saving;
@@ -38,24 +36,24 @@ public StackedRNNCells(StackedRNNCellsArgs args) : base(args)
3836

3937
public bool SupportOptionalArgs => false;
4038

41-
public GeneralizedTensorShape StateSize
39+
public INestStructure<long> StateSize
4240
{
4341
get
4442
{
4543
if (_reverse_state_order)
4644
{
4745
var state_sizes = Cells.Reverse().Select(cell => cell.StateSize);
48-
return new GeneralizedTensorShape(new Nest<Shape>(state_sizes.Select(s => new Nest<Shape>(s))));
46+
return new Nest<long>(state_sizes);
4947
}
5048
else
5149
{
5250
var state_sizes = Cells.Select(cell => cell.StateSize);
53-
return new GeneralizedTensorShape(new Nest<Shape>(state_sizes.Select(s => new Nest<Shape>(s))));
51+
return new Nest<long>(state_sizes);
5452
}
5553
}
5654
}
5755

58-
public GeneralizedTensorShape OutputSize
56+
public INestStructure<long> OutputSize
5957
{
6058
get
6159
{
@@ -66,7 +64,7 @@ public GeneralizedTensorShape OutputSize
6664
}
6765
else if (RnnUtils.is_multiple_state(lastCell.StateSize))
6866
{
69-
return lastCell.StateSize.First();
67+
return new NestNode<long>(lastCell.StateSize.Flatten().First());
7068
}
7169
else
7270
{
@@ -89,7 +87,7 @@ public Tensors GetInitialState(Tensors inputs = null, Tensor batch_size = null,
8987
protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null)
9088
{
9189
// Recover per-cell states.
92-
var state_size = _reverse_state_order ? new GeneralizedTensorShape(StateSize.Reverse()) : StateSize;
90+
var state_size = _reverse_state_order ? new NestList<long>(StateSize.Flatten().Reverse()) : StateSize;
9391
var nested_states = Nest.PackSequenceAs(state_size, Nest.Flatten(states).ToArray());
9492

9593
var new_nest_states = Nest<Tensor>.Empty;
@@ -118,20 +116,20 @@ public override void build(KerasShapesWrapper input_shape)
118116
layer.build(shape);
119117
layer.Built = true;
120118
}
121-
GeneralizedTensorShape output_dim;
119+
INestStructure<long> output_dim;
122120
if(cell.OutputSize is not null)
123121
{
124122
output_dim = cell.OutputSize;
125123
}
126124
else if (RnnUtils.is_multiple_state(cell.StateSize))
127125
{
128-
output_dim = cell.StateSize.First();
126+
output_dim = new NestNode<long>(cell.StateSize.Flatten().First());
129127
}
130128
else
131129
{
132130
output_dim = cell.StateSize;
133131
}
134-
shape = new Shape(new long[] { shape.dims[0] }.Concat(output_dim.ToSingleShape().dims).ToArray());
132+
shape = new Shape(new long[] { shape.dims[0] }.Concat(output_dim.Flatten()).ToArray());
135133
}
136134
this.Built = true;
137135
}

0 commit comments

Comments
 (0)
0