fix: some possible errors of RNN. by AsakusaRinne · Pull Request #1098 · SciSharp/TensorFlow.NET · GitHub
[go: up one dir, main page]

Skip to content

fix: some possible errors of RNN. #1098

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 31 additions & 10 deletions src/TensorFlowNET.Core/Tensors/Tensors.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,12 @@ public Tensor? SingleOrNull
public Tensor this[params string[] slices]
=> this.First()[slices];

public Tensors(Tensor tensor) : base(tensor)
{

}

/// <summary>
/// Builds a <see cref="Tensors"/> from an existing nest by forwarding it unchanged to the base constructor.
/// </summary>
/// <param name="nested">The nest passed to the base class.</param>
private Tensors(Nest<Tensor> nested)
    : base(nested)
{
}

public Tensors(params Tensor[] tensors): base(tensors.Select(x => new Nest<Tensor>(x)))
public Tensors(params Tensor[] tensors): base(DealWithConstructorArrayInput(tensors))
{

}
Expand All @@ -83,6 +78,22 @@ public Tensors(NDArray nd): base(ops.convert_to_tensor(nd))

}

/// <summary>
/// Turns the array received by the <c>params Tensor[]</c> constructor into the nest handed to the base class.
/// </summary>
/// <param name="tensors">The tensors supplied by the caller.</param>
/// <returns>
/// <c>Nest&lt;Tensor&gt;.Empty</c> for an empty array, a nest built directly from the tensor for a
/// single-element array, and otherwise a nest built from one per-tensor nest for each element.
/// </returns>
private static Nest<Tensor> DealWithConstructorArrayInput(Tensor[] tensors)
    => tensors.Length switch
    {
        0 => Nest<Tensor>.Empty,
        1 => new Nest<Tensor>(tensors[0]),
        _ => new Nest<Tensor>(tensors.Select(item => new Nest<Tensor>(item))),
    };

public bool IsSingle()
{
return Length == 1;
Expand All @@ -107,9 +118,14 @@ public void Add(Tensor tensor)
ListValue = new() { new Nest<Tensor>(Value), new Nest<Tensor>(tensor) };
Value = null;
}
else
else if(NestType == NestType.List)
{
ListValue!.Add(new Nest<Tensor>(tensor));
}
else //Empty
{
ListValue.Add(new Nest<Tensor>(tensor));
NestType = NestType.Node;
Value = tensor;
}
}

Expand All @@ -128,9 +144,14 @@ public void AddRange(IEnumerable<Tensor> tensors)
ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x)));
Value = null;
}
else
else if(NestType == NestType.List)
{
ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x)));
ListValue!.AddRange(tensors.Select(x => new Nest<Tensor>(x)));
}
else // empty
{
NestType = NestType.List;
ListValue = tensors.Select(x => new Nest<Tensor>(x)).ToList();
}
}

Expand Down
40 changes: 15 additions & 25 deletions src/TensorFlowNET.Keras/BackendImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -651,13 +651,13 @@ object _get_input_tensor(int time)
states = Nest.PackSequenceAs(states, flat_final_states).ToTensors();
if (return_all_outputs)
{
successive_outputs.Add(output);
successive_states.Add(states);
successive_outputs = successive_outputs.MergeWith(output);
// The merged states must be stored in successive_states; assigning them to
// successive_outputs would overwrite the outputs merged on the line above.
successive_states = successive_states.MergeWith(states);
}
else
{
successive_outputs = new Tensors { output };
successive_states = new Tensors { states };
successive_outputs = new Tensors(output);
successive_states = new Tensors(states);
}

}
Expand Down Expand Up @@ -722,16 +722,11 @@ object _get_input_tensor(int time)
// Get the time(0) input and compute the output for that, the output will
// be used to determine the dtype of output tensor array. Don't read from
// input_ta because TensorArray's clear_after_read defaults to True.
var inps = new Tensors();
foreach (var inp in flatted_inptus)
{
inps.Add(inp[0]);
}
var input_time_zero = Nest.PackSequenceAs(inputs, inps).ToTensors();
var input_time_zero = Nest.PackSequenceAs(inputs, flatted_inptus.Select(x => x[0]).ToArray()).ToTensors();

// output_time_zero is used to determine the cell output shape and its
// dtype. the value is discarded.
(output_time_zero, _) = step_function((Tensor)input_time_zero,
(output_time_zero, _) = step_function(input_time_zero,
constants is null ? initial_states : initial_states.MergeWith(constants));

int output_ta_size = return_all_outputs ? time_steps_t : 1;
Expand Down Expand Up @@ -816,6 +811,7 @@ object _get_input_tensor(int time)

Func<Tensor, Tensor> cond = (time) => (time < time_steps_t);
int parallel_iterations = 32;
new_states = states;
if (masking_fn != null)
{
// Mask for the T output will be based on the output of T - 1. In the
Expand Down Expand Up @@ -846,7 +842,7 @@ RNN step function.
// TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors();
var mask_t = masking_fn(time);
var (output, new_states_internal) = step_function(current_input, states.MergeWith(constants));
var (output, new_states_internal) = step_function(current_input, new_states.MergeWith(constants));
// mask output
var flat_output = Nest.Flatten(output).ToList();

Expand All @@ -871,11 +867,12 @@ RNN step function.
new_states_internal = Nest.PackSequenceAs(new_states, flat_final_state).ToTensors();

var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
// TODO(Wanglongzhi2001),deal with zip output_ta_t
foreach (var (ta, Out) in zip(output_ta_t, flat_new_output))
output_ta_t = zip(output_ta_t, flat_new_output).Select(item =>
{
output_ta_t.Add(ta.write(ta_index_to_write, Out));
}
var (ta, out_) = item;
return ta.write(ta_index_to_write, out_);
}).ToList();


new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors();

Expand Down Expand Up @@ -921,15 +918,8 @@ Tensor _step(Tensor time)
}
var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations);
}
//Tensors outputs = new Tensors();
foreach (var o in output_ta)
{
outputs.Add(o.stack());
}
foreach (var o in outputs)
{
last_output.Add(o[-1]);
}
outputs = outputs.MergeWith(output_ta.Select(o => o.stack()).ToTensors());
last_output = last_output.MergeWith(outputs.Select(o => o[-1]).ToTensors());
outputs = Nest.PackSequenceAs(output_time_zero, outputs).ToTensors();
last_output = Nest.PackSequenceAs(output_time_zero, last_output).ToTensors();

Expand Down
0