8000 Add additional categorical cross entropy logic · SciSharp/TensorFlow.NET@47d0f82 · GitHub
[go: up one dir, main page]

Skip to content

Commit 47d0f82

Browse files
HallupaOceania2018
authored andcommitted
Add additional categorical cross entropy logic
1 parent 1478a2c commit 47d0f82

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

src/TensorFlowNET.Keras/BackendImpl.cs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,19 @@ public Tensor categorical_crossentropy(Tensor target, Tensor output, bool from_l
260260
if (from_logits)
261261
return tf.nn.softmax_cross_entropy_with_logits_v2(labels: target, logits: output, axis: axis);
262262

263-
throw new NotImplementedException("");
263+
if (output.op != null && output.op.type == "Softmax")
264+
{
265+
if (output.op.inputs.Length != 1) throw new ApplicationException();
266+
var o = output = output.op.inputs[0];
267+
return tf.nn.softmax_cross_entropy_with_logits_v2(labels: target, logits: o, axis: axis);
268+
}
269+
270+
// scale preds so that the class probas of each sample sum to 1
271+
output = output / math_ops.reduce_sum(output, new Axis(axis), true);
272+
// Compute cross entropy from probabilities.
273+
var epsilon_ = constant_op.constant(epsilon(), output.dtype.as_base_dtype());
274+
output = clip_ops.clip_by_value(output, epsilon_, 1.0 - epsilon_);
275+
return -math_ops.reduce_sum(target * math_ops.log(output), new Axis(axis));
264276
}
265277

266278
/// <summary>

0 commit comments

Comments
 (0)
0