8000 Added Cropping and Permute (+test) · SciSharp/TensorFlow.NET@a7099db · GitHub
[go: up one dir, main page]

Skip to content

Commit a7099db

Browse files
KevinOfCathayOceania2018
authored andcommitted
Added Cropping and Permute (+test)
1 parent 47d0f82 commit a7099db

File tree

15 files changed

+604
-75
lines changed

15 files changed

+604
-75
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.ArgsDefinition {
6+
public class ELUArgs : LayerArgs {
7+
public float Alpha { get; set; } = 0.1f;
8+
}
9+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using Tensorflow.NumPy;
2+
3+
namespace Tensorflow.Keras.ArgsDefinition {
4+
public class Cropping2DArgs : LayerArgs {
5+
/// <summary>
6+
/// channel last: (b, h, w, c)
7+
/// channels_first: (b, c, h, w)
8+
/// </summary>
9+
public enum DataFormat { channels_first = 0, channels_last = 1 }
10+
/// <summary>
11+
/// Accept: int[1][2], int[1][1], int[2][2]
12+
/// </summary>
13+
public NDArray cropping { get; set; }
14+
public DataFormat data_format { get; set; } = DataFormat.channels_last;
15+
}
16+
}
Lines changed: 16 addition F438 s & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using Tensorflow.NumPy;
2+
3+
namespace Tensorflow.Keras.ArgsDefinition {
4+
public class Cropping3DArgs : LayerArgs {
5+
/// <summary>
6+
/// channel last: (b, h, w, c)
7+
/// channels_first: (b, c, h, w)
8+
/// </summary>
9+
public enum DataFormat { channels_first = 0, channels_last = 1 }
10+
/// <summary>
11+
/// Accept: int[1][3], int[1][1], int[3][2]
12+
/// </summary>
13+
public NDArray cropping { get; set; }
14+
public DataFormat data_format { get; set; } = DataFormat.channels_last;
15+
}
16+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using Tensorflow.NumPy;
2+
3+
namespace Tensorflow.Keras.ArgsDefinition {
4+
public class CroppingArgs : LayerArgs {
5+
/// <summary>
6+
/// Accept length 1 or 2
7+
/// </summary>
8+
public NDArray cropping { get; set; }
9+
}
10+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
namespace Tensorflow.Keras.ArgsDefinition {
2+
public class PermuteArgs : LayerArgs {
3+
public int[] dims { get; set; }
4+
}
5+
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Keras.ArgsDefinition;
5+
using Tensorflow.Keras.Engine;
6+
using static Tensorflow.Binding;
7+
8+
namespace Tensorflow.Keras.Layers {
9+
/// <summary>
10+
/// ELU Layer:
11+
/// x = 0 when x > 0, x = alpha( e^x-1 ) elsewhere
12+
/// </summary>
13+
public class ELU : Layer {
14+
ELUArgs args;
15+
float alpha => args.Alpha;
16+
public ELU ( ELUArgs args ) : base(args) {
17+
this.args = args;
18+
}
19+
protected override void build ( Tensors inputs ) {
20+
if ( alpha < 0f ) {
21+
throw new ValueError("Alpha must be a number greater than 0.");
22+
}
23+
built = true;
24+
}
25+
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
26+
Tensor output = inputs;
27+
if ( alpha != 1f ) {
28+
output = tf.where(output > 0f, output, alpha * (tf.exp(output) - 1f));
29+
}
30+
return output;
31+
}
32+
33+
public override Shape ComputeOutputShape ( Shape input_shape ) {
34+
return input_shape;
35+
}
36+
}
37+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Keras.ArgsDefinition;
5+
using Tensorflow.Keras.Engine;
6+
using static Tensorflow.Binding;
7+
8+
namespace Tensorflow.Keras.Layers {
9+
/// <summary>
10+
/// SELU Layer:
11+
/// similar to ELU, but has pre-defined alpha and scale
12+
/// </summary>
13+
public class SELU : Layer {
14+
protected const float alpha = 1.67326324f, scale = 1.05070098f;
15+
public SELU ( LayerArgs args ) : base(args) {
16+
// SELU has no arguments
17+
}
18+
protected override void build ( Tensors inputs ) {
19+
if ( alpha < 0f ) {
20+
throw new ValueError("Alpha must be a number greater than 0.");
21+
}
22+
built = true;
23+
}
24+
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
25+
Tensor output = inputs;
26+
return tf.where(output > 0f, scale * output, scale * alpha * (tf.exp(output) - 1f));
27+
}
28+
public override Shape ComputeOutputShape ( Shape input_shape ) {
29+
return input_shape;
30+
}
31+
}
32+
}
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
using Tensorflow.Keras.ArgsDefinition;
2+
using Tensorflow.Keras.Engine;
3+
4+
namespace Tensorflow.Keras.Layers {
5+
public class Cropping1D : Layer {
6+
CroppingArgs args;
7+
public Cropping1D ( CroppingArgs args ) : base(args) {
8+
this.args = args;
9+
}
10+
11+
protected override void build ( Tensors inputs ) {
12+
if ( args.cropping.rank != 1 ) {
13+
// throw an ValueError exception
14+
throw new ValueError("");
15+
}
16+
else if ( args.cropping.shape[0] > 2 || args.cropping.shape[0] < 1 ) {
17+
throw new ValueError("The `cropping` argument must be a tuple of 2 integers.");
18+
}
19+
built = true;
20+
}
21+
22+
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
23+
Tensor output = inputs;
24+
if ( output.rank != 3 ) {
25+
// throw an ValueError exception
26+
throw new ValueError("Expected dim=3, found dim=" + output.rank);
27+
}
28+
if ( args.cropping.shape[0] == 1 ) {
29+
int crop_start = args.cropping[0];
30+
output = output[new Slice(), new Slice(crop_start, ( int ) output.shape[1] - crop_start), new Slice()];
31+
}
32+
else {
33+
int crop_start = args.cropping[0], crop_end = args.cropping[1];
34+
output = output[new Slice(), new Slice(crop_start, ( int ) (output.shape[1]) - crop_end), new Slice()];
35+
}
36+
return output;
37+
}
38+
39+
public override Shape ComputeOutputShape ( Shape input_shape ) {
40+
if ( args.cropping.shape[0] == 1 ) {
41+
int crop = args.cropping[0];
42+
return new Shape(( int ) (input_shape[0]), ( int ) (input_shape[1] - crop * 2), ( int ) (input_shape[2]));
43+
}
44+
else {
45+
int crop_start = args.cropping[0], crop_end = args.cropping[1];
46+
return new Shape(( int ) (input_shape[0]), ( int ) (input_shape[1] - crop_start - crop_end), ( int ) (input_shape[2]));
47+
}
48+
}
49+
}
50+
}
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
using Tensorflow.Keras.ArgsDefinition;
2+
using Tensorflow.Keras.Engine;
3+
4+
namespace Tensorflow.Keras.Layers {
5+
/// <summary>
6+
/// Crop the input along axis 1 and 2.
7+
/// <para> For example: </para>
8+
/// <para> shape (1, 5, 5, 5) -- crop2D ((1, 2), (1, 3)) --> shape (1, 2, 1, 5) </para>
9+
/// </summary>
10+
public class Cropping2D : Layer {
11+
Cropping2DArgs args;
12+
public Cropping2D ( Cropping2DArgs args ) : base(args) {
13+
this.args = args;
14+
}
15+
protected override void build ( Tensors inputs ) {
16+
built = true;
17+
}
18+
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
19+
Tensor output = inputs;
20+
if ( output.rank != 4 ) {
21+
// throw an ValueError exception
22+
throw new ValueError("Expected dim=4, found dim=" + output.rank);
23+
}
24+
if ( args.cropping.shape == new Shape(1) ) {
25+
int crop = args.cropping[0];
26+
if ( args.data_format == Cropping2DArgs.DataFormat.channels_last ) {
27+
output = output[new Slice(),
28+
new Slice(crop, ( int ) output.shape[1] - crop),
29+
new Slice(crop, ( int ) output.shape[2] - crop),
30+
new Slice()];
31+
}
32+
else {
33+
output = output[new Slice(),
34+
new Slice(),
35+
new Slice(crop, ( int ) output.shape[2] - crop),
36+
new Slice(crop, ( int ) output.shape[3] - crop)];
37+
}
38+
}
39+
// a tuple of 2 integers
40+
else if ( args.cropping.shape == new Shape(2) ) {
41+
int crop_1 = args.cropping[0];
42+
int crop_2 = args.cropping[1];
43+
if ( args.data_format == Cropping2DArgs.DataFormat.channels_last ) {
44+
output = output[new Slice(),
45+
new Slice(crop_1, ( int ) output.shape[1] - crop_1),
46+
new Slice(crop_2, ( int ) output.shape[2] - crop_2),
47+
new Slice()];
48+
}
49+
else {
50+
output = output[new Slice(),
51+
new Slice(),
52+
new Slice(crop_1, ( int ) output.shape[2] - crop_1),
53+
new Slice(crop_2, ( int ) output.shape[3] - crop_2)];
54+
}
55+
}
56+
else if ( args.cropping.shape[0] == 2 && args.cropping.shape[1] == 2 ) {
57+
int x_start = args.cropping[0, 0], x_end = args.cropping[0, 1];
58+
int y_start = args.cropping[1, 0], y_end = args.cropping[1, 1];
59+
if ( args.data_format == Cropping2DArgs.DataFormat.channels_last ) {
60+
output = output[new Slice(),
61+
new Slice(x_start, ( int ) output.shape[1] - x_end),
62+
new Slice(y_start, ( int ) output.shape[2] - y_end),
63+
new Slice()];
64+
}
65+
else {
66+
output = output[new Slice(),
67+
new Slice(),
68+
new Slice(x_start, ( int ) output.shape[2] - x_end),
69+
new Slice(y_start, ( int ) output.shape[3] - y_end)
70+
];
71+
}
72+
}
73+
return output;
74+
}
75+
76+
public override Shape ComputeOutputShape ( Shape input_shape ) {
77+
if ( args.cropping.shape == new Shape(1) ) {
78+
int crop = args.cropping[0];
79+
if ( args.data_format == Cropping2DArgs.DataFormat.channels_last ) {
80+
return new Shape(( int ) input_shape[0], ( int ) input_shape[1] - crop * 2, ( int ) input_shape[2] - crop * 2, ( int ) input_shape[3]);
81+
}
82+
else {
83+
return new Shape(( int ) input_shape[0], ( int ) input_shape[1], ( int ) input_shape[2] - crop * 2, ( int ) input_shape[3] - crop * 2);
84+
}
85+
}
86+
// a tuple of 2 integers
87+
else if ( args.cropping.shape == new Shape(2) ) {
88+
int crop_1 = args.cropping[0], crop_2 = args.cropping[1];
89+
if ( args.data_format == Cropping2DArgs.DataFormat.channels_last ) {
90+
return new Shape(( int ) input_shape[0], ( int ) input_shape[1] - crop_1 * 2, ( int ) input_shape[2] - crop_2 * 2, ( int ) input_shape[3]);
91+
}
92+
else {
93+
return new Shape(( int ) input_shape[0], ( int ) input_shape[1], ( int ) input_shape[2] - crop_1 * 2, ( int ) input_shape[3] - crop_2 * 2);
94+
}
95+
}
96+
else if ( args.cropping.shape == new Shape(2, 2) ) {
97+
int crop_1_start = args.cropping[0, 0], crop_1_end = args.cropping[0, 1];
98+
int crop_2_start = args.cropping[1, 0], crop_2_end = args.cropping[1, 1];
99+
if ( args.data_format == Cropping2DArgs.DataFormat.channels_last ) {
100+
return new Shape(( int ) input_shape[0], ( int ) input_shape[1] - crop_1_start - crop_1_end,
101+
( int ) input_shape[2] - crop_2_start - crop_2_end, ( int ) input_shape[3]);
102+
}
103+
else {
104+
return new Shape(( int ) input_shape[0], ( int ) input_shape[1],
105+
( int ) input_shape[2] - crop_1_start - crop_1_end, ( int ) input_shape[3] - crop_2_start - crop_2_end);
106+
}
107+
}
108+
else {
109+
throw new ValueError();
110+
}
111+
}
112+
}
113+
}

0 commit comments

Comments
 (0)
0