8000 Merge pull request #1149 from Wanglongzhi2001/master · SciSharp/TensorFlow.NET@fa2d2dc · GitHub
[go: up one dir, main page]

Skip to content

Commit fa2d2dc

Browse files
authored
Merge pull request #1149 from Wanglongzhi2001/master
feat: add Bidirectional layer
2 parents fa5d19d + 0c9437a commit fa2d2dc

File tree

11 files changed

+428
-18
lines changed

11 files changed

+428
-18
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using Newtonsoft.Json;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Text;
5+
using Tensorflow.NumPy;
6+
7+
namespace Tensorflow.Keras.ArgsDefinition
8+
{
9+
public class BidirectionalArgs : AutoSerializeLayerArgs
10+
{
11+
[JsonProperty("layer")]
12+
public ILayer Layer { get; set; }
13+
[JsonProperty("merge_mode")]
14+
public string? MergeMode { get; set; }
15+
[JsonProperty("backward_layer")]
16+
public ILayer BackwardLayer { get; set; }
17+
public NDArray Weights { get; set; }
18+
}
19+
20+
}

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,10 @@ public class LSTMArgs : RNNArgs
55
// TODO: maybe change the `RNNArgs` and implement this class.
66
public bool UnitForgetBias { get; set; }
77
public int Implementation { get; set; }
8+
9+
public LSTMArgs Clone()
10+
{
11+
return (LSTMArgs)MemberwiseClone();
12+
}
813
}
914
}

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,10 @@ public class RNNArgs : AutoSerializeLayerArgs
4040
public bool ZeroOutputForMask { get; set; } = false;
4141
[JsonProperty("recurrent_dropout")]
4242
public float RecurrentDropout { get; set; } = .0f;
43+
44+
public RNNArgs Clone()
45+
{
46+
return (RNNArgs)MemberwiseClone();
47+
}
4348
}
4449
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
using Newtonsoft.Json;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Runtime.CompilerServices;
5+
using System.Text;
6+
7+
8+
namespace Tensorflow.Keras.ArgsDefinition
9+
{
10+
public class WrapperArgs : AutoSerializeLayerArgs
11+
{
12+
[JsonProperty("layer")]
13+
public ILayer Layer { get; set; }
14+
15+
public WrapperArgs(ILayer layer)
16+
{
17+
Layer = layer;
18+
}
19+
20+
public static implicit operator WrapperArgs(BidirectionalArgs args)
21+
=> new WrapperArgs(args.Layer);
22+
}
23+
24+
}

src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,19 @@ public IRnnCell GRUCell(
258258
float dropout = 0f,
259259
float recurrent_dropout = 0f,
260260
bool reset_after = true);
261-
261+
262+
/// <summary>
263+
/// Bidirectional wrapper for RNNs.
264+
/// </summary>
265+
/// <param name="layer">`keras.layers.RNN` instance, such as `keras.layers.LSTM` or `keras.layers.GRU`</param>
266+
/// automatically.</param>
267+
/// <returns></returns>
268+
public ILayer Bidirectional(
269+
ILayer layer,
270+
string merge_mode = "concat",
271+
NDArray weights = null,
272+
ILayer backward_layer = null);
273+
262274
public ILayer Subtract();
263275
}
264276
}

src/TensorFlowNET.Keras/Layers/LayersApi.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -908,6 +908,20 @@ public IRnnCell GRUCell(
908908
ResetAfter = reset_after
909909
});
910910

911+
public ILayer Bidirectional(
912+
ILayer layer,
913+
string merge_mode = "concat",
914+
NDArray weights = null,
915+
ILayer backward_layer = null)
916+
=> new Bidirectional(new BidirectionalArgs
917+
{
918+
Layer = layer,
919+
MergeMode = merge_mode,
920+
Weights = weights,
921+
BackwardLayer = backward_layer
922+
});
923+
924+
911925
/// <summary>
912926
///
913927
/// </summary>
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Diagnostics;
4+
using System.Text;
5+
using Tensorflow.Keras.ArgsDefinition;
6+
using Tensorflow.Keras.Saving;
7+
8+
namespace Tensorflow.Keras.Layers
9+
{
10+
/// <summary>
11+
/// Abstract wrapper base class. Wrappers take another layer and augment it in various ways.
12+
/// Do not use this class as a layer, it is only an abstract base class.
13+
/// Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers.
14+
/// </summary>
15+
public abstract class Wrapper: Layer
16+
{
17+
public ILayer _layer;
18+
public Wrapper(WrapperArgs args):base(args)
19+
{
20+
_layer = args.Layer;
21+
}
22+
23+
public virtual void Build(KerasShapesWrapper input_shape)
24+
{
25+
if (!_layer.Built)
26+
{
27+
_layer.build(input_shape);
28+
}
29+
built = true;
30+
}
31+
32+
}
33+
}

0 commit comments

Comments
 (0)
0