10000 refactor: Standardize TensorFlowNET.Keras/Losses/* by DevNullx64 · Pull Request #1089 · SciSharp/TensorFlow.NET · GitHub
[go: up one dir, main page]

Skip to content

refactor: Standardize TensorFlowNET.Keras/Losses/* #1089

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/TensorFlowNET.Keras/Losses/BinaryCrossentropy.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
namespace Tensorflow.Keras.Losses;

public class BinaryCrossentropy : LossFunctionWrapper, ILossFunc
public class BinaryCrossentropy : LossFunctionWrapper
{
float label_smoothing;

public BinaryCrossentropy(
bool from_logits = false,
float label_smoothing = 0,
Expand All @@ -15,7 +16,6 @@ public BinaryCrossentropy(
this.label_smoothing = label_smoothing;
}


public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
{
var sum = keras.backend.binary_crossentropy(y_true, y_pred, from_logits: from_logits);
Expand Down
4 changes: 2 additions & 2 deletions src/TensorFlowNET.Keras/Losses/CategoricalCrossentropy.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
namespace Tensorflow.Keras.Losses;

public class CategoricalCrossentropy : LossFunctionWrapper, ILossFunc
public class CategoricalCrossentropy : LossFunctionWrapper
{
float label_smoothing;

public CategoricalCrossentropy(
bool from_logits = false,
float label_smoothing = 0,
Expand All @@ -15,7 +16,6 @@ public CategoricalCrossentropy(
this.label_smoothing = label_smoothing;
}


public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
{
// Try to adjust the shape so that rank of labels = rank of logits - 1.
Expand Down
40 changes: 17 additions & 23 deletions src/TensorFlowNET.Keras/Losses/CosineSimilarity.cs
Original file line number Diff line number Diff line change
@@ -1,28 +1,22 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Losses;

namespace Tensorflow.Keras.Losses
public class CosineSimilarity : LossFunctionWrapper
{
public class CosineSimilarity : LossFunctionWrapper, ILossFunc
protected int axis = -1;

public CosineSimilarity(
string reduction = null,
int axis = -1,
string name = null) :
base(reduction: reduction, name: name == null ? "cosine_similarity" : name)
{
protected int axis=-1;
public CosineSimilarity(
string reduction = null,
int axis=-1,
string name = null) :
base(reduction: reduction, name: name == null ? "cosine_similarity" : name)
{
this.axis = axis;
}
this.axis = axis;
}

public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
{
Tensor y_true_normalize = nn_impl.l2_normalize(y_true, axis : this.axis);
Tensor y_pred_normalize = nn_impl.l2_normalize(y_pred, axis: this.axis);
return -math_ops.reduce_sum(y_true_normalize * y_pred_normalize, axis : constant_op.constant(this.axis));
}
public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
{
Tensor y_true_normalize = nn_impl.l2_normalize(y_true, axis: this.axis);
Tensor y_pred_normalize = nn_impl.l2_normalize(y_pred, axis: this.axis);
return -math_ops.reduce_sum(y_true_normalize * y_pred_normalize, axis: constant_op.constant(this.axis));
}
}
}
53 changes: 23 additions & 30 deletions src/TensorFlowNET.Keras/Losses/Huber.cs
Original file line number Diff line number Diff line change
@@ -1,36 +1,29 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Losses;

namespace Tensorflow.Keras.Losses
public class Huber : LossFunctionWrapper
{
public class Huber : LossFunctionWrapper, ILossFunc
protected Tensor delta = tf.Variable(1.0);

public Huber(
string reduction = null,
Tensor delta = null,
string name = null) :
base(reduction: reduction, name: name == null ? "huber" : name)
{
protected Tensor delta = tf.Variable(1.0) ;
public Huber (
string reduction = null,
Tensor delta = null,
string name = null) :
base(reduction: reduction, name: name == null ? "huber" : name)
{
this.delta = delta==null? this.delta: delta;

}
this.delta = delta == null ? this.delta : delta;
}

public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
{
Tensor y_pred_cast = math_ops.cast(y_pred, dtype: TF_DataType.TF_FLOAT);
Tensor y_true_cast = math_ops.cast(y_true, dtype: TF_DataType.TF_FLOAT);
Tensor delta = math_ops.cast(this.delta, dtype: TF_DataType.TF_FLOAT);
Tensor error = math_ops.subtract(y_pred_cast, y_true_cast);
Tensor abs_error = math_ops.abs(error);
Tensor half = ops.convert_to_tensor(0.5, dtype: abs_error.dtype);
return gen_math_ops.mean(array_ops.where_v2(abs_error <= delta,
half * math_ops.pow(error, 2),
half * math_ops.pow(delta, 2) + delta * (abs_error - delta)),
ops.convert_to_tensor(-1));
}
public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
{
Tensor y_pred_cast = math_ops.cast(y_pred, dtype: TF_DataType.TF_FLOAT);
Tensor y_true_cast = math_ops.cast(y_true, dtype: TF_DataType.TF_FLOAT);
Tensor delta = math_ops.cast(this.delta, dtype: TF_DataType.TF_FLOAT);
Tensor error = math_ops.subtract(y_pred_cast, y_true_cast);
Tensor abs_error = math_ops.abs(error);
Tensor half = ops.convert_to_tensor(0.5, dtype: abs_error.dtype);
return gen_math_ops.mean(array_ops.where_v2(abs_error <= delta,
half * math_ops.pow(e 8000 rror, 2),
half * math_ops.pow(delta, 2) + delta * (abs_error - delta)),
ops.convert_to_tensor(-1));
}
}
37 changes: 15 additions & 22 deletions src/TensorFlowNET.Keras/Losses/LogCosh.cs
Original file line number Diff line number Diff line change
@@ -1,27 +1,20 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Operations;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Losses;

namespace Tensorflow.Keras.Losses
public class LogCosh : LossFunctionWrapper
{
public class LogCosh : LossFunctionWrapper, ILossFunc
{
public LogCosh(
string reduction = null,
string name = null) :
base(reduction: reduction, name: name == null ? "log_cosh" : name){ }
public LogCosh(
string reduction = null,
string name = null) :
base(reduction: reduction, name: name == null ? "log_cosh" : name)
{ }

public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
{
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
Tensor x = y_pred_dispatch - y_true_cast;
public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
{
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
Tensor x = y_pred_dispatch - y_true_cast;

return gen_math_ops.mean(x + gen_nn_ops.softplus(-2.0 * x) - math_ops.cast(math_ops.log(tf.Variable(2.0)), x.dtype),
ops.convert_to_tensor(-1));
}
return gen_math_ops.mean(x + gen_nn_ops.softplus(-2.0 * x) - math_ops.cast(math_ops.log(tf.Variable(2.0)), x.dtype),
ops.convert_to_tensor(-1));
}
}
}
90 changes: 43 additions & 47 deletions src/TensorFlowNET.Keras/Losses/Loss.cs
A36C
Original file line number Diff line number Diff line change
@@ -1,55 +1,51 @@
using System;
using Tensorflow.Keras.Utils;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Losses
namespace Tensorflow.Keras.Losses;

/// <summary>
/// Loss base class.
/// </summary>
public abstract class Loss : ILossFunc
{
/// <summary>
/// Loss base class.
/// </summary>
public abstract class Loss
protected string reduction;
protected string name;
bool _allow_sum_over_batch_size;
protected bool from_logits = false;
string _name_scope;

public string Reduction => reduction;
public string Name => name;

public Loss(string reduction = ReductionV2.AUTO,
string name = null,
bool from_logits = false)
{
protected string reduction;
protected string name;
bool _allow_sum_over_batch_size;
protected bool from_logits = false;
string _name_scope;

public string Reduction => reduction;
public string Name => name;
public Loss(string reduction = ReductionV2.AUTO,
string name = null,
bool from_logits = false)
{
this.reduction = reduction == null ? ReductionV2.SUM_OVER_BATCH_SIZE : reduction;
this.name = name;
this.from_logits = from_logits;
_allow_sum_over_batch_size = false;
}
this.reduction = reduction == null ? ReductionV2.SUM_OVER_BATCH_SIZE : reduction;
this.name = name;
this.from_logits = from_logits;
_allow_sum_over_batch_size = false;
}

public virtual Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
{
throw new NotImplementedException("");
}
public abstract Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1);

public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
{
var losses = Apply(y_true, y_pred, from_logits: from_logits);
var reduction = GetReduction();
return losses_utils.compute_weighted_loss(losses, reduction: reduction, sample_weight: sample_weight);
}
public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
{
var losses = Apply(y_true, y_pred, from_logits: from_logits);
var reduction = GetReduction();
return losses_utils.compute_weighted_loss(losses, reduction: reduction, sample_weight: sample_weight);
}

string GetReduction()
{
return reduction switch
{
ReductionV2.AUTO => ReductionV2.SUM_OVER_BATCH_SIZE,
_ => reduction
};
}

void _set_name_scope()
string GetReduction()
{
return reduction switch
{
_name_scope = name;
}
ReductionV2.AUTO => ReductionV2.SUM_OVER_BATCH_SIZE,
_ => reduction
};
}

void _set_name_scope()
{
_name_scope = name;
}
}
}
22 changes: 10 additions & 12 deletions src/TensorFlowNET.Keras/Losses/LossFunctionWrapper.cs
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Losses
namespace Tensorflow.Keras.Losses;

public abstract class LossFunctionWrapper : Loss
{
public class LossFunctionWrapper : Loss
{
public LossFunctionWrapper(string reduction = ReductionV2.AUTO,
string name = null,
bool from_logits = false)
: base(reduction: reduction,
name: name,
from_logits: from_logits)
{
}
}
public LossFunctionWrapper(string reduction = ReductionV2.AUTO,
string name = null,
bool from_logits = false)
: base(reduction: reduction,
name: name,
from_logits: from_logits)
{ }
}
29 changes: 11 additions & 18 deletions src/TensorFlowNET.Keras/Losses/MeanAbsoluteError.cs
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Losses;

namespace Tensorflow.Keras.Losses
public class MeanAbsoluteError : LossFunctionWrapper
{
public class MeanAbsoluteError : LossFunctionWrapper, ILossFunc
{
public MeanAbsoluteError(
string reduction = null,
string name = null) :
base(reduction: reduction, name: name == null ? "mean_absolute_error" : name){ }
public MeanAbsoluteError(
string reduction = null,
string name = null) :
base(reduction: reduction, name: name == null ? "mean_absolute_error" : name){ }

public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
{
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
return gen_math_ops.mean(math_ops.abs(y_pred_dispatch - y_true_cast), ops.convert_to_tensor(-1));
}
public override Tensor Apply D500 (Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
{
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
return gen_math_ops.mean(math_ops.abs(y_pred_dispatch - y_true_cast), ops.convert_to_tensor(-1));
}
}
31 changes: 12 additions & 19 deletions src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs
Original file line number Diff line number Diff line change
@@ -1,24 +1,17 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Losses;

namespace Tensorflow.Keras.Losses
public class MeanAbsolutePercentageError : LossFunctionWrapper
{
public class MeanAbsolutePercentageError : LossFunctionWrapper, ILossFunc
{
public MeanAbsolutePercentageError(
string reduction = null,
string name = null) :
base(reduction: reduction, name: name == null ? "mean_absolute_percentage_error" : name){ }
public MeanAbsolutePercentageError(
string reduction = null,
string name = null) :
base(reduction: reduction, name: name == null ? "mean_absolute_percentage_error" : name){ }

public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
{
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
Tensor diff = math_ops.abs(y_true_cast - y_pred_dispatch) / gen_math_ops.maximum(math_ops.abs(y_true_cast), gen_math_ops.cast(tf.constant(1e-7), y_pred_dispatch.dtype));
return gen_math_ops.cast(tf.constant(100), y_pred_dispatch.dtype) * gen_math_ops.mean(diff, ops.convert_to_tensor(-1));
}
public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
{
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
Tensor diff = math_ops.abs(y_true_cast - y_pred_dispatch) / gen_math_ops.maximum(math_ops.abs(y_true_cast), gen_math_ops.cast(tf.constant(1e-7), y_pred_dispatch.dtype));
return gen_math_ops.cast(tf.constant(100), y_pred_dispatch.dtype) * gen_math_ops.mean(diff, ops.convert_to_tensor(-1));
}
}
Loading
0