8000 refacto: Standardize TensorFlowNET.Keras/Losses/ · DevNullx64/TensorFlow.NET@ec8bd2e · GitHub
[go: up one dir, main page]

Skip to content

Commit ec8bd2e

Browse files
committed
refacto: Standardize TensorFlowNET.Keras/Losses/
Smooth implementation
1 parent e9f2cac commit ec8bd2e

13 files changed

+200
-253
lines changed

src/TensorFlowNET.Keras/Losses/BinaryCrossentropy.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
namespace Tensorflow.Keras.Losses;
22

3-
public class BinaryCrossentropy : LossFunctionWrapper, ILossFunc
3+
public class BinaryCrossentropy : LossFunctionWrapper
44
{
55
float label_smoothing;
6+
67
public BinaryCrossentropy(
78
bool from_logits = false,
89
float label_smoothing = 0,
@@ -15,7 +16,6 @@ public BinaryCrossentropy(
1516
this.label_smoothing = label_smoothing;
1617
}
1718

18-
1919
public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
2020
{
2121
var sum = keras.backend.binary_crossentropy(y_true, y_pred, from_logits: from_logits);

src/TensorFlowNET.Keras/Losses/CategoricalCrossentropy.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
namespace Tensorflow.Keras.Losses;
22

3-
public class CategoricalCrossentropy : LossFunctionWrapper, ILossFunc
3+
public class CategoricalCrossentropy : LossFunctionWrapper
44
{
55
float label_smoothing;
6+
67
public CategoricalCrossentropy(
78
bool from_logits = false,
89
float label_smoothing = 0,
@@ -15,7 +16,6 @@ public CategoricalCrossentropy(
1516
this.label_smoothing = label_smoothing;
1617
}
1718

18-
1919
public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
2020
{
2121
// Try to adjust the shape so that rank of labels = rank of logits - 1.
Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,22 @@
1-
using System;
2-
using System.Collections.Generic;
3-
using System.Text;
4-
using static Tensorflow.Binding;
5-
using static Tensorflow.KerasApi;
1+
namespace Tensorflow.Keras.Losses;
62

7-
namespace Tensorflow.Keras.Losses
3+
public class CosineSimilarity : LossFunctionWrapper
84
{
9-
public class CosineSimilarity : LossFunctionWrapper, ILossFunc
5+
protected int axis = -1;
6+
7+
public CosineSimilarity(
8+
string reduction = null,
9+
int axis = -1,
10+
string name = null) :
11+
base(reduction: reduction, name: name == null ? "cosine_similarity" : name)
1012
{
11-
protected int axis=-1;
12-
public CosineSimilarity(
13-
string reduction = null,
14-
int axis=-1,
15-
string name = null) :
16-
base(reduction: reduction, name: name == null ? "cosine_similarity" : name)
17-
{
18-
this.axis = axis;
19-
}
13+
this.axis = axis;
14+
}
2015

21-
public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
22-
{
23-
Tensor y_true_normalize = nn_impl.l2_normalize(y_true, axis : this.axis);
24-
Tensor y_pred_normalize = nn_impl.l2_normalize(y_pred, axis: this.axis);
25-
return -math_ops.reduce_sum(y_true_normalize * y_pred_normalize, axis : constant_op.constant(this.axis));
26-
}
16+
public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
17+
{
18+
Tensor y_true_normalize = nn_impl.l2_normalize(y_true, axis: this.axis);
19+
Tensor y_pred_normalize = nn_impl.l2_normalize(y_pred, axis: this.axis);
20+
return -math_ops.reduce_sum(y_true_normalize * y_pred_normalize, axis: constant_op.constant(this.axis));
2721
}
28-
}
22+
}
Lines changed: 23 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,29 @@
1-
using System;
2-
using System.Collections.Generic;
3-
using System.Text;
4-
using static Tensorflow.Binding;
5-
using static Tensorflow.KerasApi;
1+
namespace Tensorflow.Keras.Losses;
62

7-
namespace Tensorflow.Keras.Losses
3+
public class Huber : LossFunctionWrapper
84
{
9-
public class Huber : LossFunctionWrapper, ILossFunc
5+
protected Tensor delta = tf.Variable(1.0);
6+
7+
public Huber(
8+
string reduction = null,
9+
Tensor delta = null,
10+
string name = null) :
11+
base(reduction: reduction, name: name == null ? "huber" : name)
1012
{
11-
protected Tensor delta = tf.Variable(1.0) ;
12-
public Huber (
13-
string reduction = null,
14-
Tensor delta = null,
15-
string name = null) :
16-
base(reduction: reduction, name: name == null ? "huber" : name)
17-
{
18-
this.delta = delta==null? this.delta: delta;
19-
20-
}
13+
this.delta = delta == null ? this.delta : delta;
14+
}
2115

22-
public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
23-
{
24-
Tensor y_pred_cast = math_ops.cast(y_pred, dtype: TF_DataType.TF_FLOAT);
25-
Tensor y_true_cast = math_ops.cast(y_true, dtype: TF_DataType.TF_FLOAT);
26-
Tensor delta = math_ops.cast(this.delta, dtype: TF_DataType.TF_FLOAT);
27-
Tensor error = math_ops.subtract(y_pred_cast, y_true_cast);
28-
Tensor abs_error = math_ops.abs(error);
29-
Tensor half = ops.convert_to_tensor(0.5, dtype: abs_error.dtype);
30-
return gen_math_ops.mean(array_ops.where_v2(abs_error <= delta,
31-
half * math_ops.pow(error, 2),
32-
half * math_ops.pow(delta, 2) + delta * (abs_error - delta)),
33-
ops.convert_to_tensor(-1));
34-
}
16+
public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
17+
{
18+
Tensor y_pred_cast = math_ops.cast(y_pred, dtype: TF_DataType.TF_FLOAT);
19+
Tensor y_true_cast = math_ops.cast(y_true, dtype: TF_DataType.TF_FLOAT);
20+
Tensor delta = math_ops.cast(this.delta, dtype: TF_DataType.TF_FLOAT);
21+
Tensor error = math_ops.subtract(y_pred_cast, y_true_cast);
22+
Tensor abs_error = math_ops.abs(error);
23+
Tensor half = ops.convert_to_tensor(0.5, dtype: abs_error.dtype);
24+
return gen_math_ops.mean(array_ops.where_v2(abs_error <= delta,
25+
half * math_ops.pow(error, 2),
26+
half * math_ops.pow(delta, 2) + delta * (abs_error - delta)),
27+
ops.convert_to_tensor(-1));
3528
}
3629
}
Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,20 @@
1-
using System;
2-
using System.Collections.Generic;
3-
using System.Text;
4-
using Tensorflow.Operations;
5-
using static Tensorflow.Binding;
6-
using static Tensorflow.KerasApi;
1+
namespace Tensorflow.Keras.Losses;
72

8-
namespace Tensorflow.Keras.Losses
3+
public class LogCosh : LossFunctionWrapper
94
{
10-
public class LogCosh : LossFunctionWrapper, ILossFunc
11-
{
12-
public LogCosh(
13-
string reduction = null,
14-
string name = null) :
15-
base(reduction: reduction, name: name == null ? "log_cosh" : name){ }
5+
public LogCosh(
6+
string reduction = null,
7+
string name = null) :
8+
base(reduction: reduction, name: name == null ? "log_cosh" : name)
9+
{ }
1610

17-
public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
18-
{
19-
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
20-
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
21-
Tensor x = y_pred_dispatch - y_true_cast;
11+
public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
12+
{
13+
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
14+
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
15+
Tensor x = y_pred_dispatch - y_true_cast;
2216

23-
return gen_math_ops.mean(x + gen_nn_ops.softplus(-2.0 * x) - math_ops.cast(math_ops.log(tf.Variable(2.0)), x.dtype),
24-
ops.convert_to_tensor(-1));
25-
}
17+
return gen_math_ops.mean(x + gen_nn_ops.softplus(-2.0 * x) - math_ops.cast(math_ops.log(tf.Variable(2.0)), x.dtype),
18+
ops.convert_to_tensor(-1));
2619
}
27-
}
20+
}
Lines changed: 43 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,51 @@
1-
using System;
2-
using Tensorflow.Keras.Utils;
1+
using Tensorflow.Keras.Utils;
32

4-
namespace Tensorflow.Keras.Losses
3+
namespace Tensorflow.Keras.Losses;
4+
5+
/// <summary>
6+
/// Loss base class.
7+
/// </summary>
8+
public abstract class Loss : ILossFunc
59
{
6-
/// <summary>
7-
/// Loss base class.
8-
/// </summary>
9-
public abstract class Loss
10+
protected string reduction;
11+
protected string name;
12+
bool _allow_sum_over_batch_size;
13+
protected bool from_logits = false;
14+
string _name_scope;
15+
16+
public string Reduction => reduction;
17+
public string Name => name;
18+
19+
public Loss(string reduction = ReductionV2.AUTO,
20+
string name = null,
21+
bool from_logits = false)
1022
{
11-
protected string reduction;
12-
protected string name;
13-
bool _allow_sum_over_batch_size;
14-
protected bool from_logits = false;
15-
string _name_scope;
16-
17-
public string Reduction => reduction;
18-
public string Name => name;
19-
public Loss(string reduction = ReductionV2.AUTO,
20-
string name = null,
21-
bool from_logits = false)
22-
{
23-
this.reduction = reduction == null ? ReductionV2.SUM_OVER_BATCH_SIZE : reduction;
24-
this.name = name;
25-
this.from_logits = from_logits;
26-
_allow_sum_over_batch_size = false;
27-
}
23+
this.reduction = reduction == null ? ReductionV2.SUM_OVER_BATCH_SIZE : reduction;
24+
this.name = name;
25+
this.from_logits = from_logits;
26+
_allow_sum_over_batch_size = false;
27+
}
2828

29-
public virtual Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
30-
{
31-
throw new NotImplementedException("");
32-
}
29+
public abstract Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1);
3330

34-
public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
35-
{
36-
var losses = Apply(y_true, y_pred, from_logits: from_logits);
37-
var reduction = GetReduction();
38-
return losses_utils.compute_weighted_loss(losses, reduction: reduction, sample_weight: sample_weight);
39-
}
31+
public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
32+
{
33+
var losses = Apply(y_true, y_pred, from_logits: from_logits);
34+
var reduction = GetReduction();
35+
return losses_utils.compute_weighted_loss(losses, reduction: reduction, sample_weight: sample_weight);
36+
}
4037

41-
string GetReduction()
42-
{
43-
return reduction switch
44-
{
45-
ReductionV2.AUTO => ReductionV2.SUM_OVER_BATCH_SIZE,
46-
_ => reduction
47-
};
48-
}
49-
50-
void _set_name_scope()
38+
string GetReduction()
39+
{
40+
return reduction switch
5141
{
52-
_name_scope = name;
53-
}
42+
ReductionV2.AUTO => ReductionV2.SUM_OVER_BATCH_SIZE,
43+
_ => reduction
44+
};
45+
}
46+
47+
void _set_name_scope()
48+
{
49+
_name_scope = name;
5450
}
55-
}
51+
}
Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
using Tensorflow.Keras.Utils;
22

3-
namespace Tensorflow.Keras.Losses
3+
namespace Tensorflow.Keras.Losses;
4+
5+
public abstract class LossFunctionWrapper : Loss
46
{
5-
public class LossFunctionWrapper : Loss
6-
{
7-
public LossFunctionWrapper(string reduction = ReductionV2.AUTO,
8-
string name = null,
9-
bool from_logits = false)
10-
: base(reduction: reduction,
11-
name: name,
12-
from_logits: from_logits)
13-
{
14-
}
15-
}
7+
public LossFunctionWrapper(string reduction = ReductionV2.AUTO,
8+
string name = null,
9+
bool from_logits = false)
10+
: base(reduction: reduction,
11+
name: name,
12+
from_logits: from_logits)
13+
{ }
1614
}
Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,16 @@
1-
using System;
2-
using System.Collections.Generic;
3-
using System.Text;
4-
using static Tensorflow.Binding;
5-
using static Tensorflow.KerasApi;
1+
namespace Tensorflow.Keras.Losses;
62

7-
namespace Tensorflow.Keras.Losses
3+
public class MeanAbsoluteError : LossFunctionWrapper
84
{
9-
public class MeanAbsoluteError : LossFunctionWrapper, ILossFunc
10-
{
11-
public MeanAbsoluteError(
12-
string reduction = null,
13-
string name = null) :
14-
base(reduction: reduction, name: name == null ? "mean_absolute_error" : name){ }
5+
public MeanAbsoluteError(
6+
string reduction = null,
7+
string name = null) :
8+
base(reduction: reduction, name: name == null ? "mean_absolute_error" : name){ }
159

16-
public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
17-
{
18-
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
19-
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
20-
return gen_math_ops.mean(math_ops.abs(y_pred_dispatch - y_true_cast), ops.convert_to_tensor(-1));
21-
}
10+
public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
11+
{
12+
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
13+
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
14+
return gen_math_ops.mean(math_ops.abs(y_pred_dispatch - y_true_cast), ops.convert_to_tensor(-1));
2215
}
2316
}
Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,17 @@
1-
using System;
2-
using System.Collections.Generic;
3-
using System.Text;
4-
using static Tensorflow.Binding;
5-
using static Tensorflow.KerasApi;
1+
namespace Tensorflow.Keras.Losses;
62

7-
namespace Tensorflow.Keras.Losses
3+
public class MeanAbsolutePercentageError : LossFunctionWrapper
84
{
9-
public class MeanAbsolutePercentageError : LossFunctionWrapper, ILossFunc
10-
{
11-
public MeanAbsolutePercentageError(
12-
string reduction = null,
13-
string name = null) :
14-
base(reduction: reduction, name: name == null ? "mean_absolute_percentage_error" : name){ }
5+
public MeanAbsolutePercentageError(
6+
string reduction = null,
7+
string name = null) :
8+
base(reduction: reduction, name: name == null ? "mean_absolute_percentage_error" : name){ }
159

16-
public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
17-
{
18-
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
19-
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
20-
Tensor diff = math_ops.abs(y_true_cast - y_pred_dispatch) / gen_math_ops.maximum(math_ops.abs(y_true_cast), gen_math_ops.cast(tf.constant(1e-7), y_pred_dispatch.dtype));
21-
return gen_math_ops.cast(tf.constant(100), y_pred_dispatch.dtype) * gen_math_ops.mean(diff, ops.convert_to_tensor(-1));
22-
}
10+
public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
11+
{
12+
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
13+
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
14+
Tensor diff = math_ops.abs(y_true_cast - y_pred_dispatch) / gen_math_ops.maximum(math_ops.abs(y_true_cast), gen_math_ops.cast(tf.constant(1e-7), y_pred_dispatch.dtype));
15+
return gen_math_ops.cast(tf.constant(100), y_pred_dispatch.dtype) * gen_math_ops.mean(diff, ops.convert_to_tensor(-1));
2316
}
2417
}

0 commit comments

Comments
 (0)
0