8000 Merge pull request #1098 from AsakusaRinne/rnn-dev · SciSharp/TensorFlow.NET@81a9d23 · GitHub
[go: up one dir, main page]

Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 81a9d23

Browse files
authored
Merge pull request #1098 from AsakusaRinne/rnn-dev
fix: some possible errors of RNN.
2 parents 9da157f + dcaa0f4 commit 81a9d23

File tree

2 files changed

+46
-35
lines changed

2 files changed

+46
-35
lines changed

src/TensorFlowNET.Core/Tensors/Tensors.cs

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,12 @@ public Tensor? SingleOrNull
5858
public Tensor this[params string[] slices]
5959
=> this.First()[slices];
6060

61-
public Tensors(Tensor tensor) : base(tensor)
62-
{
63-
64-
}
65-
6661
private Tensors(Nest<Tensor> nested) : base(nested)
6762
{
6863

6964
}
7065

71-
public Tensors(params Tensor[] tensors): base(tensors.Select(x => new Nest<Tensor>(x)))
66+
public Tensors(params Tensor[] tensors): base(DealWithConstructorArrayInput(tensors))
7267
{
7368

7469
}
@@ -83,6 +78,22 @@ public Tensors(NDArray nd): base(ops.convert_to_tensor(nd))
8378

8479
}
8580

81+
private static Nest<Tensor> DealWithConstructorArrayInput(Tensor[] tensors)
82+
{
83+
if (tensors.Length == 0)
84+
{
85+
return Nest<Tensor>.Empty;
86+
}
87+
else if(tensors.Length == 1)
88+
{
89+
return new Nest<Tensor>(tensors[0]);
90+
}
91+
else
92+
{
93+
return new Nest<Tensor>(tensors.Select(x => new Nest<Tensor>(x)));
94+
}
95+
}
96+
8697
public bool IsSingle()
8798
{
8899
return Length == 1;
@@ -107,9 +118,14 @@ public void Add(Tensor tensor)
107118
ListValue = new() { new Nest<Tensor>(Value), new Nest<Tensor>(tensor) };
108119
Value = null;
109120
}
110-
else
121+
else if(NestType == NestType.List)
122+
{
123+
ListValue!.Add(new Nest<Tensor>(tensor));
124+
}
125+
else //Empty
111126
{
112-
ListValue.Add(new Nest<Tensor>(tensor));
127+
NestType = NestType.Node;
128+
Value = tensor;
113129
}
114130
}
115131

@@ -128,9 +144,14 @@ public void AddRange(IEnumerable<Tensor> tensors)
128144
ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x)));
129145
Value = null;
130146
}
131-
else
147+
else if(NestType == NestType.List)
132148
{
133-
ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x)));
149+
ListValue!.AddRange(tensors.Select(x => new Nest<Tensor>(x)));
150+
}
151+
else // empty
152+
{
153+
NestType = NestType.List;
154+
ListValue = tensors.Select(x => new Nest<Tensor>(x)).ToList();
134155
}
135156
}
136157

src/TensorFlowNET.Keras/BackendImpl.cs

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -651,13 +651,13 @@ object _get_input_tensor(int time)
651651
states = Nest.PackSequenceAs(states, flat_final_states).ToTensors();
652652
if (return_all_outputs)
653653
{
654-
successive_outputs.Add(output);
655-
successive_states.Add(states);
654+
successive_outputs = successive_outputs.MergeWith(output);
655+
successive_outputs = successive_states.MergeWith(states);
656656
}
657657
else
658658
{
659-
successive_outputs = new Tensors { output };
660-
successive_states = new Tensors { states };
659+
successive_outputs = new Tensors(output);
660+
successive_states = new Tensors(states);
661661
}
662662

663663
}
@@ -722,16 +722,11 @@ object _get_input_tensor(int time)
722722
// Get the time(0) input and compute the output for that, the output will
723723
// be used to determine the dtype of output tensor array. Don't read from
724724
// input_ta due to TensorArray clear_after_read default to True.
725-
var inps = new Tensors();
726-
foreach (var inp in flatted_inptus)
727-
{
728-
inps.Add(inp[0]);
729-
}
730-
var input_time_zero = Nest.PackSequenceAs(inputs, inps).ToTensors();
725+
var input_time_zero = Nest.PackSequenceAs(inputs, flatted_inptus.Select(x => x[0]).ToArray()).ToTensors();
731726

732727
// output_time_zero is used to determine the cell output shape and its
733728
// dtype. the value is discarded.
734-
(output_time_zero, _) = step_function((Tensor)input_time_zero,
729+
(output_time_zero, _) = step_function(input_time_zero,
735730
constants is null ? initial_states : initial_states.MergeWith(constants));
736731

737732
int output_ta_size = return_all_outputs ? time_steps_t : 1;
@@ -816,6 +811,7 @@ object _get_input_tensor(int time)
816811

817812
Func<Tensor, Tensor> cond = (time) => (time < time_steps_t);
818813
int parallel_iterations = 32;
814+
new_states = states;
819815
if (masking_fn != null)
820816
{
821817
// Mask for the T output will be base on the output of T - 1. In the
@@ -846,7 +842,7 @@ RNN step function.
846842
// TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
847843
var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors();
848844
var mask_t = masking_fn(time);
849-
var (output, new_states_internal) = step_function(current_input, states.MergeWith(constants));
845+
var (output, new_states_internal) = step_function(current_input, new_states.MergeWith(constants));
850846
// mask output
851847
var flat_output = Nest.Flatten(output).ToList();
852848

@@ -871,11 +867,12 @@ RNN step function.
871867
new_states_internal = Nest.PackSequenceAs(new_states, flat_final_state).ToTensors();
872868

873869
var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
874-
// TODO(Wanglongzhi2001),deal with zip output_ta_t
875-
foreach (var (ta, Out) in zip(output_ta_t, flat_new_output))
870+
output_ta_t = zip(output_ta_t, flat_new_output).Select(item =>
876871
{
877-
output_ta_t.Add(ta.write(ta_index_to_write, Out));
878-
}
872+
var (ta, out_) = item;
873+
return ta.write(ta_index_to_write, out_);
874+
}).ToList();
875+
879876

880877
new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors();
881878

@@ -921,15 +918,8 @@ Tensor _step(Tensor time)
921918
}
922919
var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations);
923920
}
924-
//Tensors outputs = new Tensors();
925-
foreach (var o in output_ta)
926-
{
927-
outputs.Add(o.stack());
928-
}
929-
foreach (var o in outputs)
930-
{
931-
last_output.Add(o[-1]);
932-
}
921+
outputs = outputs.MergeWith(output_ta.Select(o => o.stack()).ToTensors());
922+
last_output = last_output.MergeWith(outputs.Select(o => o[-1]).ToTensors());
933923
outputs = Nest.PackSequenceAs(output_time_zero, outputs).ToTensors();
934924
last_output = Nest.PackSequenceAs(output_time_zero, last_output).ToTensors();
935925

0 commit comments

Comments
 (0)
0