11
11
using System . Linq . Expressions ;
12
12
using Tensorflow . Keras . Utils ;
13
13
using Tensorflow . Common . Types ;
14
+ using System . Runtime . CompilerServices ;
14
15
// from tensorflow.python.distribute import distribution_strategy_context as ds_context;
15
16
16
17
namespace Tensorflow . Keras . Layers . Rnn
@@ -30,7 +31,19 @@ public class RNN : RnnBase
30
31
private int _num_constants ;
31
32
protected IVariableV1 _kernel ;
32
33
protected IVariableV1 _bias ;
33
- protected IRnnCell _cell ;
34
+ private IRnnCell _cell ;
35
+ protected IRnnCell Cell
36
+ {
37
+ get
38
+ {
39
+ return _cell ;
40
+ }
41
+ init
42
+ {
43
+ _cell = value ;
44
+ _self_tracked_trackables . Add ( _cell ) ;
45
+ }
46
+ }
34
47
35
48
public RNN ( RNNArgs args ) : base ( PreConstruct ( args ) )
36
49
{
@@ -40,14 +53,14 @@ public RNN(RNNArgs args) : base(PreConstruct(args))
40
53
// if is StackedRnncell
41
54
if ( args . Cells != null )
42
55
{
43
- _cell = new StackedRNNCells ( new StackedRNNCellsArgs
56
+ Cell = new StackedRNNCells ( new StackedRNNCellsArgs
44
57
{
45
58
Cells = args . Cells
46
59
} ) ;
47
60
}
48
61
else
49
62
{
50
- _cell = args . Cell ;
63
+ Cell = args . Cell ;
51
64
}
52
65
53
66
// get input_shape
@@ -65,7 +78,7 @@ public Tensors States
65
78
if ( _states == null )
66
79
{
67
80
// CHECK(Rinne): check if this is correct.
68
- var nested = _cell . StateSize . MapStructure < Tensor ? > ( x => null ) ;
81
+ var nested = Cell . StateSize . MapStructure < Tensor ? > ( x => null ) ;
69
82
_states = nested . AsNest ( ) . ToTensors ( ) ;
70
83
}
71
84
return _states ;
@@ -83,7 +96,7 @@ private OneOf<Shape, List<Shape>> compute_output_shape(Shape input_shape)
83
96
}
84
97
85
98
// state_size is a array of ints or a positive integer
86
- var state_size = _cell . StateSize . ToSingleShape ( ) ;
99
+ var state_size = Cell . StateSize . ToSingleShape ( ) ;
87
100
88
101
// TODO(wanglongzhi2001),flat_output_size应该是什么类型的,Shape还是Tensor
89
102
Func < Shape , Shape > _get_output_shape ;
@@ -110,12 +123,12 @@ private OneOf<Shape, List<Shape>> compute_output_shape(Shape input_shape)
110
123
return output_shape ;
111
124
} ;
112
125
113
- Type type = _cell . GetType ( ) ;
126
+ Type type = Cell . GetType ( ) ;
114
127
PropertyInfo output_size_info = type . GetProperty ( "output_size" ) ;
115
128
Shape output_shape ;
116
129
if ( output_size_info != null )
117
130
{
118
- output_shape = nest . map_structure ( _get_output_shape , _cell . OutputSize . ToSingleShape ( ) ) ;
131
+ output_shape = nest . map_structure ( _get_output_shape , Cell . OutputSize . ToSingleShape ( ) ) ;
119
132
// TODO(wanglongzhi2001),output_shape应该简单的就是一个元组还是一个Shape类型
120
133
output_shape = ( output_shape . Length == 1 ? ( int ) output_shape [ 0 ] : output_shape ) ;
121
134
}
@@ -171,7 +184,9 @@ private Tensors compute_mask(Tensors inputs, Tensors mask)
171
184
172
185
public override void build ( KerasShapesWrapper input_shape )
173
186
{
174
- object get_input_spec ( Shape shape )
187
+ input_shape = new KerasShapesWrapper ( input_shape . Shapes [ 0 ] ) ;
188
+
189
+ InputSpec get_input_spec ( Shape shape )
175
190
{
176
191
var input_spec_shape = shape . as_int_list ( ) ;
177
192
@@ -213,10 +228,13 @@ object get_state_spec(Shape shape)
213
228
// numpy inputs.
214
229
215
230
216
- if ( ! _cell . Built )
231
+ if ( Cell is Layer layer && ! layer . Built )
217
232
{
218
- _cell . build ( input_shape ) ;
233
+ layer . build ( input_shape ) ;
234
+ layer . Built = true ;
219
235
}
236
+
237
+ this . built = true ;
220
238
}
221
239
222
240
/// <summary>
@@ -247,10 +265,10 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bo
247
265
248
266
( inputs , initial_state , constants ) = _process_inputs ( inputs , initial_state , constants ) ;
249
267
250
- _maybe_reset_cell_dropout_mask ( _cell ) ;
251
- if ( _cell is StackedRNNCells )
268
+ _maybe_reset_cell_dropout_mask ( Cell ) ;
269
+ if ( Cell is StackedRNNCells )
252
270
{
253
- var stack_cell = _cell as StackedRNNCells ;
271
+ var stack_cell = Cell as StackedRNNCells ;
254
272
foreach ( IRnnCell cell in stack_cell . Cells )
255
273
{
256
274
_maybe_reset_cell_dropout_mask ( cell ) ;
@@ -300,10 +318,10 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bo
300
318
bool is_tf_rnn_cell = false ;
301
319
if ( constants is not null )
302
320
{
303
- if ( ! _cell . SupportOptionalArgs )
321
+ if ( ! Cell . SupportOptionalArgs )
304
322
{
305
323
throw new ValueError (
306
- $ "RNN cell { _cell } does not support constants." +
324
+ $ "RNN cell { Cell } does not support constants." +
307
325
$ "Received: constants={ constants } ") ;
308
326
}
309
327
@@ -312,7 +330,7 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bo
312
330
constants = new Tensors ( states . TakeLast ( _num_constants ) . ToArray ( ) ) ;
313
331
states = new Tensors ( states . SkipLast ( _num_constants ) . ToArray ( ) ) ;
314
332
states = len ( states ) == 1 && is_tf_rnn_cell ? new Tensors ( states [ 0 ] ) : states ;
315
- var ( output , new_states ) = _cell . Apply ( inputs , states , optional_args : new RnnOptionalArgs ( ) { Constants = constants } ) ;
333
+ var ( output , new_states ) = Cell . Apply ( inputs , states , optional_args : new RnnOptionalArgs ( ) { Constants = constants } ) ;
316
334
return ( output , new_states . Single ) ;
317
335
} ;
318
336
}
@@ -321,7 +339,7 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bo
321
339
step = ( inputs , states ) =>
322
340
{
323
341
states = len ( states ) == 1 && is_tf_rnn_cell ? new Tensors ( states . First ( ) ) : states ;
324
- var ( output , new_states ) = _cell . Apply ( inputs , states ) ;
342
+ var ( output , new_states ) = Cell . Apply ( inputs , states ) ;
325
343
return ( output , new_states ) ;
326
344
} ;
327
345
}
@@ -562,7 +580,7 @@ protected Tensors get_initial_state(Tensors inputs)
562
580
var batch_size = _args . TimeMajor ? input_shape [ 1 ] : input_shape [ 0 ] ;
563
581
var dtype = input . dtype ;
564
582
565
- Tensors init_state = _cell . GetInitialState ( null , batch_size , dtype ) ;
583
+ Tensors init_state = Cell . GetInitialState ( null , batch_size , dtype ) ;
566
584
567
585
return init_state ;
568
586
}
0 commit comments