8000 Update PredictInternational on Model.Predict.cs · SciSharp/TensorFlow.NET@e9f2cac · GitHub
[go: up one dir, main page]

Skip to content

Commit e9f2cac

Browse files
DevNullx64Oceania2018
authored andcommitted
Update PredictInternational on Model.Predict.cs
Fix issue if data_handler.steps() > 1
1 parent 4bca319 commit e9f2cac

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/TensorFlowNET.Keras/Engine/Model.Predict.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ Tensors PredictInternal(DataHandler data_handler, int verbose)
9999
}
100100
else
101101
{
102-
batch_outputs = tf.concat(new Tensor[] { batch_outputs, tmp_batch_outputs[0] }, axis: 0);
102+
for (int i = 0; i < batch_outputs.Length; i++)
103+
batch_outputs[i] = tf.concat(new Tensor[] { batch_outputs[i], tmp_batch_outputs[i] }, axis: 0);
103104
}
104105

105106
var end_step = step + data_handler.StepIncrement;
@@ -116,7 +117,7 @@ Tensors run_predict_step(OwnedIterator iterator)
116117
{
117118
var data = iterator.next();
118119
var outputs = predict_step(data);
119-
tf_with(ops.control_dependencies(new object[0]), ctl => _predict_counter.assign_add(1));
120+
tf_with(ops.control_dependencies(Array.Empty<object>()), ctl => _predict_counter.assign_add(1));
120121
return outputs;
121122
}
122123

0 commit comments

Comments
 (0)
0