File tree 11 files changed +428
-18
lines changed
TensorFlowNET.Keras/Layers
test/TensorFlowNET.Keras.UnitTest/Layers 11 files changed +428
-18
lines changed Original file line number Diff line number Diff line change
1
+ using Newtonsoft . Json ;
2
+ using System ;
3
+ using System . Collections . Generic ;
4
+ using System . Text ;
5
+ using Tensorflow . NumPy ;
6
+
7
+ namespace Tensorflow . Keras . ArgsDefinition
8
+ {
9
+ public class BidirectionalArgs : AutoSerializeLayerArgs
10
+ {
11
+ [ JsonProperty ( "layer" ) ]
12
+ public ILayer Layer { get ; set ; }
13
+ [ JsonProperty ( "merge_mode" ) ]
14
+ public string ? MergeMode { get ; set ; }
15
+ [ JsonProperty ( "backward_layer" ) ]
16
+ public ILayer BackwardLayer { get ; set ; }
17
+ public NDArray Weights { get ; set ; }
18
+ }
19
+
20
+ }
Original file line number Diff line number Diff line change @@ -5,5 +5,10 @@ public class LSTMArgs : RNNArgs
5
5
// TODO: maybe change the `RNNArgs` and implement this class.
6
6
public bool UnitForgetBias { get ; set ; }
7
7
public int Implementation { get ; set ; }
8
+
9
+ public LSTMArgs Clone ( )
10
+ {
11
+ return ( LSTMArgs ) MemberwiseClone ( ) ;
12
+ }
8
13
}
9
14
}
Original file line number Diff line number Diff line change @@ -40,5 +40,10 @@ public class RNNArgs : AutoSerializeLayerArgs
40
40
public bool ZeroOutputForMask { get ; set ; } = false ;
41
41
[ JsonProperty ( "recurrent_dropout" ) ]
42
42
public float RecurrentDropout { get ; set ; } = .0f ;
43
+
44
+ public RNNArgs Clone ( )
45
+ {
46
+ return ( RNNArgs ) MemberwiseClone ( ) ;
47
+ }
43
48
}
44
49
}
Original file line number Diff line number Diff line change
1
+ using Newtonsoft . Json ;
2
+ using System ;
3
+ using System . Collections . Generic ;
4
+ using System . Runtime . CompilerServices ;
5
+ using System . Text ;
6
+
7
+
8
+ namespace Tensorflow . Keras . ArgsDefinition
9
+ {
10
+ public class WrapperArgs : AutoSerializeLayerArgs
11
+ {
12
+ [ JsonProperty ( "layer" ) ]
13
+ public ILayer Layer { get ; set ; }
14
+
15
+ public WrapperArgs ( ILayer layer )
16
+ {
17
+ Layer = layer ;
18
+ }
19
+
20
+ public static implicit operator WrapperArgs ( BidirectionalArgs args )
21
+ => new WrapperArgs ( args . Layer ) ;
22
+ }
23
+
24
+ }
Original file line number Diff line number Diff line change @@ -258,7 +258,19 @@ public IRnnCell GRUCell(
258
258
float dropout = 0f ,
259
259
float recurrent_dropout = 0f ,
260
260
bool reset_after = true ) ;
261
-
261
+
262
+ /// <summary>
263
+ /// Bidirectional wrapper for RNNs.
264
+ /// </summary>
265
+ /// <param name="layer">`keras.layers.RNN` instance, such as `keras.layers.LSTM` or `keras.layers.GRU`</param>
266
+ /// automatically.</param>
267
+ /// <returns></returns>
268
+ public ILayer Bidirectional (
269
+ ILayer layer ,
270
+ string merge_mode = "concat" ,
271
+ NDArray weights = null ,
272
+ ILayer backward_layer = null ) ;
273
+
262
274
public ILayer Subtract ( ) ;
263
275
}
264
276
}
Original file line number Diff line number Diff line change @@ -908,6 +908,20 @@ public IRnnCell GRUCell(
908
908
ResetAfter = reset_after
909
909
} ) ;
910
910
911
+ public ILayer Bidirectional (
912
+ ILayer layer ,
913
+ string merge_mode = "concat" ,
914
+ NDArray weights = null ,
915
+ ILayer backward_layer = null )
916
+ => new Bidirectional ( new BidirectionalArgs
917
+ {
918
+ Layer = layer ,
919
+ MergeMode = merge_mode ,
920
+ Weights = weights ,
921
+ BackwardLayer = backward_layer
922
+ } ) ;
923
+
924
+
911
925
/// <summary>
912
926
///
913
927
/// </summary>
Original file line number Diff line number Diff line change
1
+ using System ;
2
+ using System . Collections . Generic ;
3
+ using System . Diagnostics ;
4
+ using System . Text ;
5
+ using Tensorflow . Keras . ArgsDefinition ;
6
+ using Tensorflow . Keras . Saving ;
7
+
8
+ namespace Tensorflow . Keras . Layers
9
+ {
10
+ /// <summary>
11
+ /// Abstract wrapper base class. Wrappers take another layer and augment it in various ways.
12
+ /// Do not use this class as a layer, it is only an abstract base class.
13
+ /// Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers.
14
+ /// </summary>
15
+ public abstract class Wrapper : Layer
16
+ {
17
+ public ILayer _layer ;
18
+ public Wrapper ( WrapperArgs args ) : base ( args )
19
+ {
20
+ _layer = args . Layer ;
21
+ }
22
+
23
+ public virtual void Build ( KerasShapesWrapper input_shape )
24
+ {
25
+ if ( ! _layer . Built )
26
+ {
27
+ _layer . build ( input_shape ) ;
28
+ }
29
+ built = true ;
30
+ }
31
+
32
+ }
33
+ }
You can’t perform that action at this time.
0 commit comments