Add new feature: add LSTMCell and test · SciSharp/TensorFlow.NET@df7d700 · GitHub
[go: up one dir, main page]

Skip to content

Commit df7d700

Browse files
Wanglongzhi2001 authored and AsakusaRinne committed
Add new feature: add LSTMCell and test
1 parent 5bfe098 commit df7d700

File tree

10 files changed

+376
-36
lines changed

10 files changed

+376
-36
lines changed
Lines changed: 30 additions & 2 deletions
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,35 @@
1-
namespace Tensorflow.Keras.ArgsDefinition.Rnn
1+
using Newtonsoft.Json;
2+
using static Tensorflow.Binding;
3+
4+
namespace Tensorflow.Keras.ArgsDefinition.Rnn
25
{
36
// TODO: complete the implementation
4-
public class LSTMCellArgs : LayerArgs
7+
public class LSTMCellArgs : AutoSerializeLayerArgs
58
{
9+
[JsonProperty("units")]
10+
public int Units { get; set; }
11+
// TODO(Rinne): lack of initialized value of Activation. Merging keras
12+
// into tf.net could resolve it.
13+
[JsonProperty("activation")]
14+
public Activation Activation { get; set; }
15+
[JsonProperty("recurrent_activation")]
16+
public Activation RecurrentActivation { get; set; }
17+
[JsonProperty("use_bias")]
18+
public bool UseBias { get; set; } = true;
19+
[JsonProperty("dropout")]
20+
public float Dropout { get; set; } = .0f;
21+
[JsonProperty("recurrent_dropout")]
22+
public float RecurrentDropout { get; set; } = .0f;
23+
[JsonProperty("kernel_initializer")]
24+
public IInitializer KernelInitializer { get; set; }
25+
[JsonProperty("recurrent_initializer")]
26+
public IInitializer RecurrentInitializer { get; set; }
27+
[JsonProperty("bias_initializer")]
28+
public IInitializer BiasInitializer { get; set; }
29+
[JsonProperty("unit_forget_bias")]
30+
public bool UnitForgetBias { get; set; } = true;
31+
[JsonProperty("implementation")]
32+
public int Implementation { get; set; } = 2;
33+
634
735
}

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
11
using Newtonsoft.Json;
2-
using System;
3-
using System.Collections.Generic;
4-
using System.Text;
52

63
namespace Tensorflow.Keras.ArgsDefinition.Rnn
74
{
@@ -25,5 +22,6 @@ public class SimpleRNNCellArgs: AutoSerializeLayerArgs
2522
public IInitializer RecurrentInitializer { get; set; }
2623
[JsonProperty("bias_initializer")]
2724
public IInitializer BiasInitializer { get; set; }
25+
2826
}
2927
}

src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,18 @@ public ILayer LayerNormalization(Axis? axis,
160160
public ILayer Normalization(Shape? input_shape = null, int? axis = -1, float? mean = null, float? variance = null, bool invert = false);
161161
public ILayer LeakyReLU(float alpha = 0.3f);
162162

163+
public IRnnCell LSTMCell(int uints,
164+
string activation = "tanh",
165+
string recurrent_activation = "sigmoid",
166+
bool use_bias = true,
167+
string kernel_initializer = "glorot_uniform",
168+
string recurrent_initializer = "orthogonal",
169+
string bias_initializer = "zeros",
170+
bool unit_forget_bias = true,
171+
float dropout = 0f,
172+
float recurrent_dropout = 0f,
173+
int implementation = 2);
174+
163175
public ILayer LSTM(int units,
164176
Activation activation = null,
165177
Activation recurrent_activation = null,

src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,7 @@ private Tensor _generate_init_val(Shape shape, TF_DataType dtype)
5858

5959
if (num_rows < num_cols)
6060
{
61-
// q = tf.linalg.matrix_transpose(q);
62-
throw new NotImplementedException("");
61+
q = array_ops.matrix_transpose(q);
6362
}
6463

6564
return _gain * tf.reshape(q, shape);

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,49 @@ public static Tensor transpose(Tensor a, Tensor perm, string name = "transpose",
971971
});
972972
}
973973

974+
/// <summary>
975+
/// Transposes last two dimensions of tensor `a`.
976+
/// For example:
977+
/// <code> python
978+
/// x = tf.constant([[1, 2, 3], [4, 5, 6]])
979+
/// tf.matrix_transpose(x) # [[1, 4],
980+
/// # [2, 5],
981+
/// # [3, 6]]
982+
/// </code>
983+
/// Matrix with two batch dimensions.
984+
/// x.shape is [1, 2, 3, 4]
985+
/// tf.linalg.matrix_transpose(x) is shape [1, 2, 4, 3]
986+
/// </summary>
987+
/// <param name="a"></param>
988+
/// <param name="name"></param>
989+
/// <param name="conjugate"></param>
990+
/// <returns></returns>
991+
/// <exception cref="ValueError"></exception>
992+
public static Tensor matrix_transpose(Tensor a, string name = "matrix_transpose", bool conjugate = false)
993+
{
994+
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
995+
{
996+
var a_shape = a.shape;
997+
var ndims = a.shape.ndim;
998+
Axis perm;
999+
if(ndims != 0)
1000+
{
1001+
if (ndims < 2)
1002+
{
1003+
throw new ValueError("Argument `a` should be a (batch) matrix with rank " +
1004+
$">= 2. Received `a` = {a} with shape: {a_shape}");
1005+
}
1006+
perm = new Axis(Enumerable.Range(0, ndims - 2).Concat(new int[] { ndims - 1, ndims - 2 }).ToArray());
1007+
}
1008+
else
1009+
{
1010+
var a_rank = a.rank;
1011+
perm = new Axis(Enumerable.Range(0, a_rank - 2).Concat(new int[] { a_rank - 1, a_rank - 2 }).ToArray());
1012+
}
1013+
return transpose(a, perm:perm, conjugate:conjugate);
1014+
});
1015+
}
1016+
9741017
public static Tensor[] split(Tensor value, Tensor size_splits, int axis, int num = -1,
9751018
string name = "split")
9761019
{

src/TensorFlowNET.Keras/Layers/LayersApi.cs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,7 @@ public IRnnCell SimpleRNNCell(
702702
UseBias = use_bias,
703703
KernelInitializer = GetInitializerByName(kernel_initializer),
704704
RecurrentInitializer = GetInitializerByName(recurrent_initializer),
705+
BiasInitializer = GetInitializerByName(bias_initializer),
705706
Dropout = dropout,
706707
RecurrentDropout = recurrent_dropout
707708
});
@@ -786,6 +787,33 @@ public ILayer RNN(
786787
TimeMajor = time_major
787788
});
788789

790+
791+
public IRnnCell LSTMCell(int uints,
792+
string activation = "tanh",
793+
string recurrent_activation = "sigmoid",
794+
bool use_bias = true,
795+
string kernel_initializer = "glorot_uniform",
796+
string recurrent_initializer = "orthogonal", // TODO(Wanglongzhi2001),glorot_uniform has not been developed.
797+
string bias_initializer = "zeros",
798+
bool unit_forget_bias = true,
799+
float dropout = 0f,
800+
float recurrent_dropout = 0f,
801+
int implementation = 2)
802+
=> new LSTMCell(new LSTMCellArgs
803+
{
804+
Units = uints,
805+
Activation = keras.activations.GetActivationFromName(activation),
806+
RecurrentActivation = keras.activations.GetActivationFromName(recurrent_activation),
807+
UseBias = use_bias,
808+
KernelInitializer = GetInitializerByName(kernel_initializer),
809+
RecurrentInitializer = GetInitializerByName(recurrent_initializer),
810+
BiasInitializer = GetInitializerByName(bias_initializer),
811+
UnitForgetBias = unit_forget_bias,
812+
Dropout = dropout,
813+
RecurrentDropout = recurrent_dropout,
814+
Implementation = implementation
815+
});
816+
789817
/// <summary>
790818
/// Long Short-Term Memory layer - Hochreiter 1997.
791819
/// </summary>

src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ public void reset_recurrent_dropout_mask()
4141

4242
}
4343

44-
public Tensors? get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
44+
public Tensors? get_dropout_mask_for_cell(Tensors input, bool training, int count = 1)
4545
{
4646
if (dropout == 0f)
4747
return null;
@@ -53,7 +53,7 @@ public void reset_recurrent_dropout_mask()
5353
}
5454

5555
// Get the recurrent dropout mask for RNN cell.
56-
public Tensors? get_recurrent_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
56+
public Tensors? get_recurrent_dropout_mask_for_cell(Tensors input, bool training, int count = 1)
5757
{
5858
if (dropout == 0f)
5959
return null;

0 commit comments

Comments (0)