8000 fix: none gradient error when training LSTM. · SciSharp/TensorFlow.NET@675b93a · GitHub
[go: up one dir, main page]

Skip to content

Commit 675b93a

Browse files
committed
fix: none gradient error when training LSTM.
1 parent 0114885 commit 675b93a

File tree

29 files changed

+1743
-295
lines changed

29 files changed

+1743
-295
lines changed

src/TensorFlowNET.Core/APIs/tf.tensor.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,15 @@ public Tensor strided_slice<T>(Tensor input, T[] begin, T[] end, T[] strides = n
7171
public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null)
7272
=> array_ops.split(
7373
value: value,
74-
num_split: num_split,
74+
num_or_size_splits: num_split,
7575
axis: axis,
7676
name: name);
7777

7878
public Tensor[] split(Tensor value, int num_split, int axis, string name = null)
7979
=> array_ops.split(
8080
value: value,
81-
num_split: num_split,
82-
axis: axis,
81+
num_or_size_splits: num_split,
82+
axis: ops.convert_to_tensor(axis),
8383
name: name);
8484

8585
public Tensor ensure_shape(Tensor x, Shape shape, string name = null)

src/TensorFlowNET.Core/Common/Types/Nest.cs

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -197,25 +197,11 @@ public bool IsNested()
197197
}
198198
else if(NestType is NestType.List)
199199
{
200-
foreach(var item in ListValue!)
201-
{
202-
if(item.NestType is NestType.List or NestType.Dictionary)
203-
{
204-
return true;
205-
}
206-
}
207-
return false;
200+
return ListValue!.Count > 0;
208201
}
209202
else
210203
{
211-
foreach (var item in DictValue!.Values)
212-
{
213-
if (item.NestType is NestType.List or NestType.Dictionary)
214-
{
215-
return true;
216-
}
217-
}
218-
return false;
204+
return DictValue!.Count > 0;
219205
}
220206
}
221207

src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,11 @@ bool SetOpAttrScalar(Context ctx, SafeEagerOpHandle op,
352352
c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value));
353353
break;
354354
case TF_AttrType.TF_ATTR_SHAPE:
355-
var dims = (value as long[]).ToArray();
355+
long[] dims;
356+
if (value is Shape shape) dims = shape.dims.ToArray();
357+
else if (value is long[] longs) dims = longs.ToArray();
358+
else if (value is int[] ints) dims = ints.Select(x => (long)x).ToArray();
359+
else dims = ((long[])value).ToArray();
356360
c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status);
357361
status.Check(true);
358362
break;

src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,14 @@ TapeTensor TapeTensorFromTensor(Tensor tensor)
137137
{
138138
dims[i] = c_api.TFE_TensorHandleDim(handle, i, status);
139139
}
140-
Shape tensor_shape = new(dims);
141140

142141
if(status.Code != TF_Code.TF_OK)
143142
{
144143
return new TapeTensor(id, TF_DataType.DtInvalid, Shape.Null);
145144
}
146145
else
147146
{
147+
Shape tensor_shape = new(dims);
148148
return new TapeTensor(id, dtype, tensor_shape);
149149
}
150150
}
@@ -173,8 +173,12 @@ bool DTypeNeedsHandleData(TF_DataType dtype)
173173
return dtype == dtypes.variant || dtype == dtypes.resource;
174174
}
175175

176-
bool ListContainNone(long[] list)
176+
bool ListContainNone(long[]? list)
177177
{
178+
if(list is null)
179+
{
180+
return true;
181+
}
178182
int len = list.Length;
179183
if(len == 0)
180184
{

src/TensorFlowNET.Core/Gradients/array_grad.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,7 @@ private static Tensor[] _ConcatGradHelper(Operation op, Tensor grad, int start_v
9090
? input_values[0].rank + dim_int
9191
: dim_int % input_values[0].rank;
9292
var sizes = input_values.Select(x => x.shape[non_neg_concat_dim]).ToArray();
93-
var sizes_tensor = constant_op.constant(sizes);
94-
out_grads = array_ops.split(grad, sizes_tensor, non_neg_concat_dim).ToList();
93+
out_grads = array_ops.split(grad, sizes.Select(x => (int)x).ToArray(), ops.convert_to_tensor(non_neg_concat_dim)).ToList();
9594
}
9695
else if (constant_op.is_constant(concat_dim))
9796
{
@@ -127,7 +126,7 @@ there will be a small number of performance regressions.*/
127126
new Tensor[] { non_neg_concat_dim, tf.constant(0) },
128127
new Tensor[] { tf.constant(1), tf.constant(-1) });
129128
var squeeze_sizes = array_ops.squeeze(slice);
130-
out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_split: (int)non_neg_concat_dim).ToList();
129+
out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_or_size_splits: (int)non_neg_concat_dim).ToList();
131130
}
132131
else
133132
{

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ public class LSTMArgs : RNNArgs
44
{
55
// TODO: maybe change the `RNNArgs` and implement this class.
66
public bool UnitForgetBias { get; set; }
7-
public float Dropout { get; set; }
8-
public float RecurrentDropout { get; set; }
97
public int Implementation { get; set; }
108
}
119
}

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ public class LSTMCellArgs : AutoSerializeLayerArgs
2929
[JsonProperty("unit_forget_bias")]
3030
public bool UnitForgetBias { get; set; } = true;
3131
[JsonProperty("implementation")]
32-
public int Implementation { get; set; } = 2;
32+
public int Implementation { get; set; } = 1;
3333

3434
}
3535
}

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,6 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
77
// TODO(Rinne): add regularizers.
88
public class RNNArgs : AutoSerializeLayerArgs
99
{
10-
[JsonProperty("cell")]
11-
// TODO: the cell should be serialized with `serialize_keras_object`.
12-
public IRnnCell Cell { get; set; } = null;
13-
[JsonProperty("cells")]
14-
public IList<IRnnCell> Cells { get; set; } = null;
15-
1610
[JsonProperty("return_sequences")]
1711
public bool ReturnSequences { get; set; } = false;
1812
[JsonProperty("return_state")]
@@ -25,8 +19,10 @@ public class RNNArgs : AutoSerializeLayerArgs
2519
public bool Unroll { get; set; } = false;
2620
[JsonProperty("time_major")]
2721
public bool TimeMajor { get; set; } = false;
22+
23+
public int? InputDim { get; set; }
24+
public int? InputLength { get; set; }
2825
// TODO: Add `num_constants` and `zero_output_for_mask`.
29-
public Dictionary<string, object> Kwargs { get; set; } = null;
3026

3127
public int Units { get; set; }
3228
public Activation Activation { get; set; }
@@ -38,21 +34,5 @@ public class RNNArgs : AutoSerializeLayerArgs
3834
public float Dropout { get; set; } = .0f;
3935
public bool ZeroOutputForMask { get; set; } = false;
4036
public float RecurrentDropout { get; set; } = .0f;
41-
42-
// kernel_regularizer=None,
43-
// recurrent_regularizer=None,
44-
// bias_regularizer=None,
45-
// activity_regularizer=None,
46-
// kernel_constraint=None,
47-
// recurrent_constraint=None,
48-
// bias_constraint=None,
49-
// dropout=0.,
50-
// recurrent_dropout=0.,
51-
// return_sequences=False,
52-
// return_state=False,
53-
// go_backwards=False,
54-
// stateful=False,
55-
// unroll=False,
56-
// **kwargs):
5737
}
5838
}

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
55
{
66
public class StackedRNNCellsArgs : LayerArgs
77
{
8-
public IList<IRnnCell> Cells { get; set; }
9-
public Dictionary<string, object> Kwargs { get; set; } = null;
8+
public bool ReverseStateOrder = false;
109
}
1110
}

src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ public ILayer LSTM(int units,
182182
bool unit_forget_bias = true,
183183
float dropout = 0f,
184184
float recurrent_dropout = 0f,
185-
int implementation = 2,
185+
int implementation = 1,
186186
bool return_sequences = false,
187187
bool return_state = false,
188188
bool go_backwards = false,

0 commit comments

Comments (0)