From ec8bd2eb330642d39b62ce1d743ce805932ce08e Mon Sep 17 00:00:00 2001 From: Luc BOLOGNA Date: Thu, 1 Jun 2023 23:50:55 +0200 Subject: [PATCH] refacto: Standardize TensorFlowNET.Keras/Losses/ Smooth implementation --- .../Losses/BinaryCrossentropy.cs | 4 +- .../Losses/CategoricalCrossentropy.cs | 4 +- .../Losses/CosineSimilarity.cs | 40 ++++----- src/TensorFlowNET.Keras/Losses/Huber.cs | 53 +++++------ src/TensorFlowNET.Keras/Losses/LogCosh.cs | 37 ++++---- src/TensorFlowNET.Keras/Losses/Loss.cs | 90 +++++++++---------- .../Losses/LossFunctionWrapper.cs | 22 +++-- .../Losses/MeanAbsoluteError.cs | 29 +++--- .../Losses/MeanAbsolutePercentageError.cs | 31 +++---- .../Losses/MeanSquaredError.cs | 29 +++--- .../Losses/MeanSquaredLogarithmicError.cs | 49 +++++----- .../Losses/SigmoidFocalCrossEntropy.cs | 3 +- .../Losses/SparseCategoricalCrossentropy.cs | 62 ++++++------- 13 files changed, 200 insertions(+), 253 deletions(-) diff --git a/src/TensorFlowNET.Keras/Losses/BinaryCrossentropy.cs b/src/TensorFlowNET.Keras/Losses/BinaryCrossentropy.cs index ff7bb6b70..0de50a7ec 100644 --- a/src/TensorFlowNET.Keras/Losses/BinaryCrossentropy.cs +++ b/src/TensorFlowNET.Keras/Losses/BinaryCrossentropy.cs @@ -1,8 +1,9 @@ namespace Tensorflow.Keras.Losses; -public class BinaryCrossentropy : LossFunctionWrapper, ILossFunc +public class BinaryCrossentropy : LossFunctionWrapper { float label_smoothing; + public BinaryCrossentropy( bool from_logits = false, float label_smoothing = 0, @@ -15,7 +16,6 @@ public BinaryCrossentropy( this.label_smoothing = label_smoothing; } - public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1) { var sum = keras.backend.binary_crossentropy(y_true, y_pred, from_logits: from_logits); diff --git a/src/TensorFlowNET.Keras/Losses/CategoricalCrossentropy.cs b/src/TensorFlowNET.Keras/Losses/CategoricalCrossentropy.cs index feb052244..1af57b552 100644 --- a/src/TensorFlowNET.Keras/Losses/CategoricalCrossentropy.cs +++ b/src/TensorFlowNET.Keras/Losses/CategoricalCrossentropy.cs @@ -1,8 +1,9 @@ namespace Tensorflow.Keras.Losses; -public class CategoricalCrossentropy : LossFunctionWrapper, ILossFunc +public class CategoricalCrossentropy : LossFunctionWrapper { float label_smoothing; + public CategoricalCrossentropy( bool from_logits = false, float label_smoothing = 0, @@ -15,7 +16,6 @@ public CategoricalCrossentropy( this.label_smoothing = label_smoothing; } - public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1) { // Try to adjust the shape so that rank of labels = rank of logits - 1. diff --git a/src/TensorFlowNET.Keras/Losses/CosineSimilarity.cs b/src/TensorFlowNET.Keras/Losses/CosineSimilarity.cs index 16ab4b799..cf9df8d0d 100644 --- a/src/TensorFlowNET.Keras/Losses/CosineSimilarity.cs +++ b/src/TensorFlowNET.Keras/Losses/CosineSimilarity.cs @@ -1,28 +1,22 @@ -using System; -using System.Collections.Generic; -using System.Text; -using static Tensorflow.Binding; -using static Tensorflow.KerasApi; +namespace Tensorflow.Keras.Losses; -namespace Tensorflow.Keras.Losses +public class CosineSimilarity : LossFunctionWrapper { - public class CosineSimilarity : LossFunctionWrapper, ILossFunc + protected int axis = -1; + + public CosineSimilarity( + string reduction = null, + int axis = -1, + string name = null) : + base(reduction: reduction, name: name == null ? "cosine_similarity" : name) { - protected int axis=-1; - public CosineSimilarity( - string reduction = null, - int axis=-1, - string name = null) : - base(reduction: reduction, name: name == null ? "cosine_similarity" : name) - { - this.axis = axis; - } + this.axis = axis; + } - public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1) - { - Tensor y_true_normalize = nn_impl.l2_normalize(y_true, axis : this.axis); - Tensor y_pred_normalize = nn_impl.l2_normalize(y_pred, axis: this.axis); - return -math_ops.reduce_sum(y_true_normalize * y_pred_normalize, axis : constant_op.constant(this.axis)); - } + public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1) + { + Tensor y_true_normalize = nn_impl.l2_normalize(y_true, axis: this.axis); + Tensor y_pred_normalize = nn_impl.l2_normalize(y_pred, axis: this.axis); + return -math_ops.reduce_sum(y_true_normalize * y_pred_normalize, axis: constant_op.constant(this.axis)); } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Losses/Huber.cs b/src/TensorFlowNET.Keras/Losses/Huber.cs index 7169ba461..61f006d2b 100644 --- a/src/TensorFlowNET.Keras/Losses/Huber.cs +++ b/src/TensorFlowNET.Keras/Losses/Huber.cs @@ -1,36 +1,29 @@ -using System; -using System.Collections.Generic; -using System.Text; -using static Tensorflow.Binding; -using static Tensorflow.KerasApi; +namespace Tensorflow.Keras.Losses; -namespace Tensorflow.Keras.Losses +public class Huber : LossFunctionWrapper { - public class Huber : LossFunctionWrapper, ILossFunc + protected Tensor delta = tf.Variable(1.0); + + public Huber( + string reduction = null, + Tensor delta = null, + string name = null) : + base(reduction: reduction, name: name == null ? "huber" : name) { - protected Tensor delta = tf.Variable(1.0) ; - public Huber ( - string reduction = null, - Tensor delta = null, - string name = null) : - base(reduction: reduction, name: name == null ? "huber" : name) - { - this.delta = delta==null? this.delta: delta; - - } + this.delta = delta == null ? this.delta : delta; + } - public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1) - { - Tensor y_pred_cast = math_ops.cast(y_pred, dtype: TF_DataType.TF_FLOAT); - Tensor y_true_cast = math_ops.cast(y_true, dtype: TF_DataType.TF_FLOAT); - Tensor delta = math_ops.cast(this.delta, dtype: TF_DataType.TF_FLOAT); - Tensor error = math_ops.subtract(y_pred_cast, y_true_cast); - Tensor abs_error = math_ops.abs(error); - Tensor half = ops.convert_to_tensor(0.5, dtype: abs_error.dtype); - return gen_math_ops.mean(array_ops.where_v2(abs_error <= delta, - half * math_ops.pow(error, 2), - half * math_ops.pow(delta, 2) + delta * (abs_error - delta)), - ops.convert_to_tensor(-1)); - } + public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1) + { + Tensor y_pred_cast = math_ops.cast(y_pred, dtype: TF_DataType.TF_FLOAT); + Tensor y_true_cast = math_ops.cast(y_true, dtype: TF_DataType.TF_FLOAT); + Tensor delta = math_ops.cast(this.delta, dtype: TF_DataType.TF_FLOAT); + Tensor error = math_ops.subtract(y_pred_cast, y_true_cast); + Tensor abs_error = math_ops.abs(error); + Tensor half = ops.convert_to_tensor(0.5, dtype: abs_error.dtype); + return gen_math_ops.mean(array_ops.where_v2(abs_error <= delta, + half * math_ops.pow(error, 2), + half * math_ops.pow(delta, 2) + delta * (abs_error - delta)), + ops.convert_to_tensor(-1)); } } diff --git a/src/TensorFlowNET.Keras/Losses/LogCosh.cs b/src/TensorFlowNET.Keras/Losses/LogCosh.cs index 7cfd4f67b..0c7a9b6e2 100644 --- a/src/TensorFlowNET.Keras/Losses/LogCosh.cs +++ b/src/TensorFlowNET.Keras/Losses/LogCosh.cs @@ -1,27 +1,20 @@ -using System; -using System.Collections.Generic; -using System.Text; -using Tensorflow.Operations; -using static Tensorflow.Binding; -using static Tensorflow.KerasApi; +namespace Tensorflow.Keras.Losses; -namespace Tensorflow.Keras.Losses +public class LogCosh : LossFunctionWrapper { - public class LogCosh : LossFunctionWrapper, ILossFunc - { - public LogCosh( - string reduction = null, - string name = null) : - base(reduction: reduction, name: name == null ? "log_cosh" : name){ } + public LogCosh( + string reduction = null, + string name = null) : + base(reduction: reduction, name: name == null ? "log_cosh" : name) + { } - public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1) - { - Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); - Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); - Tensor x = y_pred_dispatch - y_true_cast; + public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1) + { + Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); + Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); + Tensor x = y_pred_dispatch - y_true_cast; - return gen_math_ops.mean(x + gen_nn_ops.softplus(-2.0 * x) - math_ops.cast(math_ops.log(tf.Variable(2.0)), x.dtype), - ops.convert_to_tensor(-1)); - } + return gen_math_ops.mean(x + gen_nn_ops.softplus(-2.0 * x) - math_ops.cast(math_ops.log(tf.Variable(2.0)), x.dtype), + ops.convert_to_tensor(-1)); } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Losses/Loss.cs b/src/TensorFlowNET.Keras/Losses/Loss.cs index 77bf7e1dc..ce77f6d63 100644 --- a/src/TensorFlowNET.Keras/Losses/Loss.cs +++ b/src/TensorFlowNET.Keras/Losses/Loss.cs @@ -1,55 +1,51 @@ -using System; -using Tensorflow.Keras.Utils; +using Tensorflow.Keras.Utils; -namespace Tensorflow.Keras.Losses +namespace Tensorflow.Keras.Losses; + +/// +/// Loss base class. +/// +public abstract class Loss : ILossFunc { - /// - /// Loss base class. - /// - public abstract class Loss + protected string reduction; + protected string name; + bool _allow_sum_over_batch_size; + protected bool from_logits = false; + string _name_scope; + + public string Reduction => reduction; + public string Name => name; + + public Loss(string reduction = ReductionV2.AUTO, + string name = null, + bool from_logits = false) { - protected string reduction; - protected string name; - bool _allow_sum_over_batch_size; - protected bool from_logits = false; - string _name_scope; - - public string Reduction => reduction; - public string Name => name; - public Loss(string reduction = ReductionV2.AUTO, - string name = null, - bool from_logits = false) - { - this.reduction = reduction == null ? ReductionV2.SUM_OVER_BATCH_SIZE : reduction; - this.name = name; - this.from_logits = from_logits; - _allow_sum_over_batch_size = false; - } + this.reduction = reduction == null ? ReductionV2.SUM_OVER_BATCH_SIZE : reduction; + this.name = name; + this.from_logits = from_logits; + _allow_sum_over_batch_size = false; + } - public virtual Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1) - { - throw new NotImplementedException(""); - } + public abstract Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1); - public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) - { - var losses = Apply(y_true, y_pred, from_logits: from_logits); - var reduction = GetReduction(); - return losses_utils.compute_weighted_loss(losses, reduction: reduction, sample_weight: sample_weight); - } + public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) + { + var losses = Apply(y_true, y_pred, from_logits: from_logits); + var reduction = GetReduction(); + return losses_utils.compute_weighted_loss(losses, reduction: reduction, sample_weight: sample_weight); + } - string GetReduction() - { - return reduction switch - { - ReductionV2.AUTO => ReductionV2.SUM_OVER_BATCH_SIZE, - _ => reduction - }; - } - - void _set_name_scope() + string GetReduction() + { + return reduction switch { - _name_scope = name; - } + ReductionV2.AUTO => ReductionV2.SUM_OVER_BATCH_SIZE, + _ => reduction + }; + } + + void _set_name_scope() + { + _name_scope = name; } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Losses/LossFunctionWrapper.cs b/src/TensorFlowNET.Keras/Losses/LossFunctionWrapper.cs index 758b46f4b..f4ee2b346 100644 --- a/src/TensorFlowNET.Keras/Losses/LossFunctionWrapper.cs +++ b/src/TensorFlowNET.Keras/Losses/LossFunctionWrapper.cs @@ -1,16 +1,14 @@ using Tensorflow.Keras.Utils; -namespace Tensorflow.Keras.Losses +namespace Tensorflow.Keras.Losses; + +public abstract class LossFunctionWrapper : Loss { - public class LossFunctionWrapper : Loss - { - public LossFunctionWrapper(string reduction = ReductionV2.AUTO, - string name = null, - bool from_logits = false) - : base(reduction: reduction, - name: name, - from_logits: from_logits) - { - } - } + public LossFunctionWrapper(string reduction = ReductionV2.AUTO, + string name = null, + bool from_logits = false) + : base(reduction: reduction, + name: name, + from_logits: from_logits) + { } } diff --git a/src/TensorFlowNET.Keras/Losses/MeanAbsoluteError.cs b/src/TensorFlowNET.Keras/Losses/MeanAbsoluteError.cs index c203bc5ad..19476a68a 100644 --- a/src/TensorFlowNET.Keras/Losses/MeanAbsoluteError.cs +++ b/src/TensorFlowNET.Keras/Losses/MeanAbsoluteError.cs @@ -1,23 +1,16 @@ -using System; -using System.Collections.Generic; -using System.Text; -using static Tensorflow.Binding; -using static Tensorflow.KerasApi; +namespace Tensorflow.Keras.Losses; -namespace Tensorflow.Keras.Losses +public class MeanAbsoluteError : LossFunctionWrapper { - public class MeanAbsoluteError : LossFunctionWrapper, ILossFunc - { - public MeanAbsoluteError( - string reduction = null, - string name = null) : - base(reduction: reduction, name: name == null ? "mean_absolute_error" : name){ } + public MeanAbsoluteError( + string reduction = null, + string name = null) : + base(reduction: reduction, name: name == null ? "mean_absolute_error" : name){ } - public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1) - { - Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); - Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); - return gen_math_ops.mean(math_ops.abs(y_pred_dispatch - y_true_cast), ops.convert_to_tensor(-1)); - } + public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1) + { + Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); + Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); + return gen_math_ops.mean(math_ops.abs(y_pred_dispatch - y_true_cast), ops.convert_to_tensor(-1)); } } diff --git a/src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs b/src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs index 8dcaa1bcc..226c4237a 100644 --- a/src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs +++ b/src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs @@ -1,24 +1,17 @@ -using System; -using System.Collections.Generic; -using System.Text; -using static Tensorflow.Binding; -using static Tensorflow.KerasApi; +namespace Tensorflow.Keras.Losses; -namespace Tensorflow.Keras.Losses +public class MeanAbsolutePercentageError : LossFunctionWrapper { - public class MeanAbsolutePercentageError : LossFunctionWrapper, ILossFunc - { - public MeanAbsolutePercentageError( - string reduction = null, - string name = null) : - base(reduction: reduction, name: name == null ? "mean_absolute_percentage_error" : name){ } + public MeanAbsolutePercentageError( + string reduction = null, + string name = null) : + base(reduction: reduction, name: name == null ? "mean_absolute_percentage_error" : name){ } - public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1) - { - Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); - Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); - Tensor diff = math_ops.abs(y_true_cast - y_pred_dispatch) / gen_math_ops.maximum(math_ops.abs(y_true_cast), gen_math_ops.cast(tf.constant(1e-7), y_pred_dispatch.dtype)); - return gen_math_ops.cast(tf.constant(100), y_pred_dispatch.dtype) * gen_math_ops.mean(diff, ops.convert_to_tensor(-1)); - } + public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1) + { + Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); + Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); + Tensor diff = math_ops.abs(y_true_cast - y_pred_dispatch) / gen_math_ops.maximum(math_ops.abs(y_true_cast), gen_math_ops.cast(tf.constant(1e-7), y_pred_dispatch.dtype)); + return gen_math_ops.cast(tf.constant(100), y_pred_dispatch.dtype) * gen_math_ops.mean(diff, ops.convert_to_tensor(-1)); } } diff --git a/src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs b/src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs index 73cddef14..a937c1963 100644 --- a/src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs +++ b/src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs @@ -1,23 +1,16 @@ -using System; -using System.Collections.Generic; -using System.Text; -using static Tensorflow.Binding; -using static Tensorflow.KerasApi; +namespace Tensorflow.Keras.Losses; -namespace Tensorflow.Keras.Losses +public class MeanSquaredError : LossFunctionWrapper { - public class MeanSquaredError : LossFunctionWrapper, ILossFunc - { - public MeanSquaredError( - string reduction = null, - string name = null) : - base(reduction: reduction, name: name==null? "mean_squared_error" : name){ } + public MeanSquaredError( + string reduction = null, + string name = null) : + base(reduction: reduction, name: name==null? "mean_squared_error" : name){ } - public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1) - { - Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); - Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); - return gen_math_ops.mean(gen_math_ops.squared_difference(y_pred_dispatch, y_true_cast), ops.convert_to_tensor(-1)); - } + public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1) + { + Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); + Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); + return gen_math_ops.mean(gen_math_ops.squared_difference(y_pred_dispatch, y_true_cast), ops.convert_to_tensor(-1)); } } diff --git a/src/TensorFlowNET.Keras/Losses/MeanSquaredLogarithmicError.cs b/src/TensorFlowNET.Keras/Losses/MeanSquaredLogarithmicError.cs index e29659218..0a4e7d3c5 100644 --- a/src/TensorFlowNET.Keras/Losses/MeanSquaredLogarithmicError.cs +++ b/src/TensorFlowNET.Keras/Losses/MeanSquaredLogarithmicError.cs @@ -1,33 +1,28 @@ -using System; -using System.Collections.Generic; -using System.Text; -using static Tensorflow.Binding; -using static Tensorflow.KerasApi; +namespace Tensorflow.Keras.Losses; -namespace Tensorflow.Keras.Losses +public class MeanSquaredLogarithmicError : LossFunctionWrapper { - public class MeanSquaredLogarithmicError : LossFunctionWrapper, ILossFunc - { - public MeanSquaredLogarithmicError( - string reduction = null, - string name = null) : - base(reduction: reduction, name: name == null ? "mean_squared_logarithmic_error" : name){ } - + public MeanSquaredLogarithmicError( + string reduction = null, + string name = null) : + base(reduction: reduction, name: name == null ? "mean_squared_logarithmic_error" : name) + { } - public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1) + public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1) + { + Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); + Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); + Tensor first_log = null, second_log = null; + if (y_pred_dispatch.dtype == TF_DataType.TF_DOUBLE) + { + first_log = math_ops.log(math_ops.maximum(y_pred_dispatch, 1e-7) + 1.0); + second_log = math_ops.log(math_ops.maximum(y_true_cast, 1e-7) + 1.0); + } + else { - Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); - Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); - Tensor first_log=null, second_log=null; - if (y_pred_dispatch.dtype == TF_DataType.TF_DOUBLE) { - first_log = math_ops.log(math_ops.maximum(y_pred_dispatch, 1e-7) + 1.0); - second_log = math_ops.log(math_ops.maximum(y_true_cast, 1e-7) + 1.0); - } - else { - first_log = math_ops.log(math_ops.maximum(y_pred_dispatch, 1e-7f) + 1.0f); - second_log = math_ops.log(math_ops.maximum(y_true_cast, 1e-7f) + 1.0f); - } - return gen_math_ops.mean(gen_math_ops.squared_difference(first_log, second_log), ops.convert_to_tensor(-1)); + first_log = math_ops.log(math_ops.maximum(y_pred_dispatch, 1e-7f) + 1.0f); + second_log = math_ops.log(math_ops.maximum(y_true_cast, 1e-7f) + 1.0f); } + return gen_math_ops.mean(gen_math_ops.squared_difference(first_log, second_log), ops.convert_to_tensor(-1)); } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Keras/Losses/SigmoidFocalCrossEntropy.cs b/src/TensorFlowNET.Keras/Losses/SigmoidFocalCrossEntropy.cs index 7ac3fa0bb..ec6dcedf8 100644 --- a/src/TensorFlowNET.Keras/Losses/SigmoidFocalCrossEntropy.cs +++ b/src/TensorFlowNET.Keras/Losses/SigmoidFocalCrossEntropy.cs @@ -2,7 +2,7 @@ namespace Tensorflow.Keras.Losses; -public class SigmoidFocalCrossEntropy : LossFunctionWrapper, ILossFunc +public class SigmoidFocalCrossEntropy : LossFunctionWrapper { float _alpha; float _gamma; @@ -20,7 +20,6 @@ public SigmoidFocalCrossEntropy(bool from_logits = false, _gamma = gamma; } - public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1) { y_true = tf.cast(y_true, dtype: y_pred.dtype); diff --git a/src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs b/src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs index 4e2790ab1..17ce2d30b 100644 --- a/src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs +++ b/src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs @@ -1,41 +1,41 @@ using static Tensorflow.Binding; -namespace Tensorflow.Keras.Losses +namespace Tensorflow.Keras.Losses; + +public class SparseCategoricalCrossentropy : LossFunctionWrapper { - public class SparseCategoricalCrossentropy : LossFunctionWrapper, ILossFunc + private bool _from_logits = false; + + public SparseCategoricalCrossentropy( + bool from_logits = false, + string reduction = null, + string name = null) : + base(reduction: reduction, name: name == null ? "sparse_categorical_crossentropy" : name) + { + _from_logits = from_logits; + } + + public override Tensor Apply(Tensor target, Tensor output, bool from_logits = false, int axis = -1) { - private bool _from_logits = false; - public SparseCategoricalCrossentropy( - bool from_logits = false, - string reduction = null, - string name = null) : - base(reduction: reduction, name: name == null ? "sparse_categorical_crossentropy" : name) + target = tf.cast(target, dtype: TF_DataType.TF_INT64); + + if (!_from_logits) { - _from_logits = from_logits; + var epsilon = tf.constant(KerasApi.keras.backend.epsilon(), output.dtype); + output = tf.clip_by_value(output, epsilon, 1 - epsilon); + output = tf.log(output); } - public override Tensor Apply(Tensor target, Tensor output, bool from_logits = false, int axis = -1) + // Try to adjust the shape so that rank of labels = rank of logits - 1. + var output_shape = array_ops.shape_v2(output); + var output_rank = output.shape.ndim; + var target_rank = target.shape.ndim; + var update_shape = target_rank != output_rank - 1; + if (update_shape) { - target = tf.cast(target, dtype: TF_DataType.TF_INT64); - - if (!_from_logits) - { - var epsilon = tf.constant(KerasApi.keras.backend.epsilon(), output.dtype); - output = tf.clip_by_value(output, epsilon, 1 - epsilon); - output = tf.log(output); - } - - // Try to adjust the shape so that rank of labels = rank of logits - 1. - var output_shape = array_ops.shape_v2(output); - var output_rank = output.shape.ndim; - var target_rank = target.shape.ndim; - var update_shape = target_rank != output_rank - 1; - if (update_shape) - { - target = array_ops.reshape(target, new int[] { -1 }); - output = array_ops.reshape(output, new int[] { -1, output_shape[-1].numpy() }); - } - return tf.nn.sparse_softmax_cross_entropy_with_logits(target, output); + target = array_ops.reshape(target, new int[] { -1 }); + output = array_ops.reshape(output, new int[] { -1, output_shape[-1].numpy() }); } + return tf.nn.sparse_softmax_cross_entropy_with_logits(target, output); } -} +} \ No newline at end of file