8000 Refactor: Change Model evaluate · SciSharp/TensorFlow.NET@02cb239 · GitHub
[go: up one dir, main page]

Skip to content

Commit 02cb239

Browse files
DevNullx64Oceania2018
authored andcommitted
Refactor: Change Model evaluate
IModel.Dictionary<string, float> evaluate(NDArray, NDArray, ...) is now IModel.Dictionary<string, float> evaluate(Tensor, Tensor, ...) Merge Model.Evaluate.test_step_multi_inputs_function(...) and Model.Evaluate.test_function(...) Note: An internal function need to add an explicit cast in Tensor
1 parent f45b35b commit 02cb239

File tree

3 files changed

+7
-13
lines changed
  • src
    • TensorFlowNET.Core/Keras/Engine
    • TensorFlowNET.Keras/Engine

3 files changed

+7
-13
lines changed

src/TensorFlowNET.Core/Keras/Engine/IModel.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ void load_weights(string filepath,
6060
bool skip_mismatch = false,
6161
object options = null);
6262

63-
Dictionary<string, float> evaluate(NDArray x, NDArray y,
63+
Dictionary<string, float> evaluate(Tensor x, Tensor y,
6464
int batch_size = -1,
6565
int verbose = 1,
6666
int steps = -1,

src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public partial class Model
2727
/// <param name="use_multiprocessing"></param>
2828
/// <param name="return_dict"></param>
2929
/// <param name="is_val"></param>
30-
public Dictionary<string, float> evaluate(NDArray x, NDArray y,
30+
public Dictionary<string, float> evaluate(Tensor x, Tensor y,
3131
int batch_size = -1,
3232
int verbose = 1,
3333
int steps = -1,
@@ -91,7 +91,7 @@ public Dictionary<string, float> evaluate(NDArray x, NDArray y,
9191
return results;
9292
}
9393

94-
public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, NDArray y, int verbose = 1, bool is_val = false)
94+
public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int verbose = 1, bool is_val = false)
9595
{
9696
var data_handler = new DataHandler(new DataHandlerArgs
9797
{
@@ -119,7 +119,7 @@ public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, NDArray y, int
119119
foreach (var step in data_handler.steps())
120120
{
121121
callbacks.on_test_batch_begin(step);
122-
logs = test_step_multi_inputs_function(data_handler, iterator);
122+
logs = test_function(data_handler, iterator);
123123
var end_step = step + data_handler.StepIncrement;
124124
if (is_val == false)
125125
callbacks.on_test_batch_end(end_step, logs);
@@ -178,20 +178,14 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is
178178
}
179179

180180
Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator)
181-
{
182-
var data = iterator.next();
183-
var outputs = test_step(data_handler, data[0], data[1]);
184-
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
185-
return outputs;
186-
}
187-
Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator)
188181
{
189182
var data = iterator.next();
190183
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
191184
var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
192-
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
185+
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
193186
return outputs;
194187
}
188+
195189
Dictionary<string, float> test_step(DataHandler data_handler, Tensor x, Tensor y)
196190
{
197191
(x, y) = data_handler.DataAdapter.Expand1d(x, y);

src/TensorFlowNET.Keras/Engine/Model.Fit.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICal
266266
{
267267
// Because evaluate calls call_test_batch_end, this interferes with our output on the screen
268268
// so we need to pass a is_val parameter to stop on_test_batch_end
269-
var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2, is_val:true);
269+
var val_logs = evaluate((Tensor)validation_data.Value.Item1, validation_data.Value.Item2, is_val:true);
270270
foreach (var log in val_logs)
271271
{
272272
logs["val_" + log.Key] = log.Value;

0 commit comments

Comments
 (0)
0