8000 feat: support training of RNN · SciSharp/TensorFlow.NET@cb31cf4 · GitHub
[go: up one dir, main page]

Skip to content

Commit cb31cf4

Browse files
feat: support training of RNN
2 parents e1ece66 + edbf89b commit cb31cf4

File tree

159 files changed

+11058
-865
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

159 files changed

+11058
-865
lines changed

src/TensorFlowNET.Core/APIs/c_api.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616

1717
using System;
1818
using System.Runtime.InteropServices;
19+
using static Tensorflow.CppShapeInferenceResult.Types;
1920

2021
namespace Tensorflow
2122
{
@@ -50,6 +51,19 @@ public static string StringPiece(IntPtr handle)
5051
return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle);
5152
}
5253

54+
public unsafe static byte[] ByteStringPiece(IntPtr handle)
55+
{
56+
byte* str_data = (byte*)handle.ToPointer();
57+
List<byte> bytes = new List<byte>();
58+
byte current = 255;
59+
while (current != ((byte)'\0'))
60+
{
61+
current = *(str_data++);
62+
bytes.Add(current);
63+
}
64+
return bytes.Take(bytes.Count - 1).ToArray();
65+
}
66+
5367
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
5468
public delegate void Deallocator(IntPtr data, IntPtr size, ref DeallocatorArgs args);
5569

src/TensorFlowNET.Core/APIs/tf.control_flow.cs

Lines changed: 5 additions & 5 deletions
Or D7AE iginal file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ public Tensor while_loop(Func<Tensor, Tensor> cond,
4646
Tensor loop_vars,
4747
int parallel_iterations = 10)
4848
{
49-
Func<Tensor[], Tensor> cond1 = x
49+
Func<Tensors, Tensor> cond1 = x
5050
=> cond(x[0]);
5151

52-
Func<Tensor[], Tensor[]> body1 = x
52+
Func<Tensors, Tensors> body1 = x
5353
=> new[] { body(x[0]) };
5454

5555
var results = control_flow_ops.while_loop(cond1,
@@ -58,9 +58,9 @@ public Tensor while_loop(Func<Tensor, Tensor> cond,
5858
return results[0];
5959
}
6060

61-
public Tensor[] while_loop(Func<Tensor[], Tensor> cond,
62-
Func<Tensor[], Tensor[]> body,
63-
Tensor[] loop_vars,
61+
public Tensor[] while_loop(Func<Tensors, Tensor> cond,
62+
Func<Tensors, Tensors> body,
63+
Tensors loop_vars,
6464
int parallel_iterations = 10,
6565
string name = null)
6666
=> control_flow_ops.while_loop(cond, body, loop_vars,

src/TensorFlowNET.Core/APIs/tf.tensor.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,15 @@ public Tensor strided_slice<T>(Tensor input, T[] begin, T[] end, T[] strides = n
7171
public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null)
7272
=> array_ops.split(
7373
value: value,
74-
num_split: num_split,
74+
num_or_size_splits: num_split,
7575
axis: axis,
7676
name: name);
7777

7878
public Tensor[] split(Tensor value, int num_split, int axis, string name = null)
7979
=> array_ops.split(
8080
value: value,
81-
num_split: num_split,
82-
axis: axis,
81+
num_or_size_splits: num_split,
82+
axis: ops.convert_to_tensor(axis),
8383
name: name);
8484

8585
public Tensor ensure_shape(Tensor x, Shape shape, string name = null)

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ public static TF_DataType GetDataType(this object data)
503503
case Tensors tensors:
504504
return tensors.dtype;
505505
case IEnumerable<Tensor> tensors:
506-
return tensors.First().dtype;
506+
return tensors.Where(x => x is not null).First().dtype;
507507
case RefVariable variable:
508508
return variable.dtype;
509509
case ResourceVariable variable:

src/TensorFlowNET.Core/Extensions/JObjectExtensions.cs renamed to src/TensorFlowNET.Core/Common/Extensions/JObjectExtensions.cs

Lines changed: 3 additions & 3 deletions
15
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@
33
using System.Collections.Generic;
44
using System.Text;
55

6-
namespace Tensorflow.Extensions
6+
namespace Tensorflow.Common.Extensions
77
{
88
public static class JObjectExtensions
99
{
1010
public static T? TryGetOrReturnNull<T>(this JObject obj, string key)
1111
{
1212
var res = obj[key];
13-
if(res is null)
13+
if (res is null)
1414
{
-
return default(T);
15+
return default;
1616
}
1717
else
1818
{
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
6+
namespace Tensorflow.Common.Extensions
7+
{
8+
public static class LinqExtensions
9+
{
10+
#if NETSTANDARD2_0
11+
public static IEnumerable<T> TakeLast<T>(this IEnumerable<T> sequence, int count)
12+
{
13+
return sequence.Skip(sequence.Count() - count);
14+
}
15+
16+
public static IEnumerable<T> SkipLast<T>(this IEnumerable<T> sequence, int count)
17+
{
18+
return sequence.Take(sequence.Count() - count);
19+
}
20+
#endif
21+
public static Tensors ToTensors(this Tensor[] tensors)
22+
{
23+
return new Tensors(tensors);
24+
}
25+
26+
public static Tensors ToTensors(this IList<Tensor> tensors)
27+
{
28+
return new Tensors(tensors);
29+
}
30+
31+
public static void Deconstruct<T1, T2, T3>(this (T1, T2, T3) values, out T1 first, out T2 second, out T3 third)
32+
{
33+
first = values.Item1;
34+
second = values.Item2;
35+
third = values.Item3;
36+
}
37+
}
38+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Common.Types;
5+
6+
namespace Tensorflow.Common.Extensions
7+
{
8+
public static class NestExtensions
9+
{
10+
public static Tensors ToTensors(this INestable<Tensor> tensors)
11+
{
12+
return new Tensors(tensors.AsNest());
13+
}
14+
15+
public static Tensors? ToTensors(this Nest<Tensor> tensors)
16+
{
17+
return Tensors.FromNest(tensors);
18+
}
19+
20+
/// <summary>
21+
/// If the nested object is already a nested type, this function could reduce it.
22+
/// For example, `Nest[Nest[T]]` can be reduced to `Nest[T]`.
23+
/// </summary>
24+
/// <typeparam name="TIn"></typeparam>
25+
/// <typeparam name="TOut"></typeparam>
26+
/// <param name="input"></param>
27+
/// <returns></returns>
28+
public static Nest<TOut> ReduceTo<TIn, TOut>(this INestStructure<TIn> input) where TIn: INestStructure<TOut>
29+
{
30+
return Nest<TOut>.ReduceFrom(input);
31+
}
32+
}
33+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Common.Types
6+
{
7+
/// <summary>
8+
/// This is a temp solution, which should be removed after refactoring `Tensors`
9+
/// </summary>
10+
[Obsolete]
11+
public class FakeTensorByTensorArray: Tensor
12+
{
13+
public TensorArray TensorArray { get; set; }
14+
15+
public FakeTensorByTensorArray(TensorArray array)
16+
{
17+
TensorArray = array;
18+
}
19+
}
20+
}
Lines changed: 69 additions & 0 deletions
10000
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Diagnostics;
4+
using System.Text;
5+
6+
namespace Tensorflow.Common.Types
7+
{
8+
public class GeneralizedTensorShape: Nest<Shape>
9+
{
10+
public GeneralizedTensorShape(Shape value, string? name = null)
11+
{
12+
NodeValue = value;
13+
NestType = NestType.Node;
14+
}
15+
16+
public GeneralizedTensorShape(IEnumerable<Shape> values, string? name = null)
17+
{
18+
ListValue = values.Select(s => new Nest<Shape>(s) as INestStructure<Shape>).ToList();
19+
Name = name;
20+
NestType = NestType.List;
21+
}
22+
23+
public GeneralizedTensorShape(Dictionary<string, Shape> value, string? name = null)
24+
{
25+
DictValue = value.ToDictionary(x => x.Key, x => new Nest<Shape>(x.Value) as INestStructure<Shape>);
26+
Name = name;
27+
NestType = NestType.Dictionary;
28+
}
29+
30+
public GeneralizedTensorShape(Nest<Shape> other)
31+
{
32+
NestType = other.NestType;
33+
NodeValue = other.NodeValue;
34+
DictValue = other.DictValue;
35+
ListValue = other.ListValue;
36+
Name = other.Name;
37+
}
38+
39+
public Shape ToSingleShape()
40+
{
41+
var shapes = Flatten().ToList();
42+
if (shapes.Count != 1)
43+
{
44+
throw new ValueError("The generalized shape contains more than 1 dim.");
45+
}
46+
return shapes[0];
47+
}
48+
49+
public long ToNumber()
50+
{
51+
var shapes = Flatten().ToList();
52+
if (shapes.Count != 1 || shapes[0].ndim != 1)
53+
{
54+
throw new ValueError("The generalized shape contains more than 1 dim.");
55+
}
56+
return shapes[0].dims[0];
57+
}
58+
59+
public INestStructure<TensorShapeConfig> ToTensorShapeConfigs()
60+
{
61+
return MapStructure(s => new TensorShapeConfig() { Items = s.dims.Select<long, long?>(x => x == -1 ? null : x).ToArray() });
62+
}
63+
64+
public static implicit operator GeneralizedTensorShape(Shape shape)
65+
{
66+
return new GeneralizedTensorShape(shape);
67+
}
68+
}
69+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Common.Types
6+
{
7+
/// <summary>
8+
/// This interface indicates that a class may have a nested structure and provide
9+
/// methods to manipulate with the structure.
10+
/// </summary>
11+
public interface INestStructure<T>: INestable<T>
12+
{
13+
NestType NestType { get; }
14+
15+
/// <summary>
16+
/// The item count of depth 1 of the nested structure.
17+
/// For example, [1, 2, [3, 4, 5]] has ShallowNestedCount = 3.
18+
/// </summary>
19+
int ShallowNestedCount { get; }
20+
/// <summary>
21+
/// The total item count of depth 1 of the nested structure.
22+
/// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5.
23+
/// </summary>
24+
int TotalNestedCount { get; }
25+
26+
/// <summary>
27+
/// Flatten the Nestable object. Node that if the object contains only one value,
28+
/// it will be flattened to an enumerable with one element.
29+
/// </summary>
30+
/// <returns></returns>
31+
IEnumerable<T> Flatten();
32+
/// <summary>
33+
/// Construct a new object with the same nested structure.
34+
/// </summary>
35+
/// <typeparam name="TOut"></typeparam>
36+
/// <param name="func"></param>
37+
/// <returns></returns>
38+
INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func);
39+
}
40+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Common.Types
6+
{
7+
public interface INestable<T>
8+
{
9+
Nest<T> 741A AsNest();
10+
}
11+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Common.Types
6+
{
7+
/// <summary>
8+
/// This interface is used when some corresponding python methods have optional args.
9+
/// For example, `Keras.Layer.Apply` generally takes three args as the inputs, while
10+
/// `Keras.Layer.RNN` takes more. Then when calling RNN, you should add `RnnOptionalArgs`
11+
/// as the parameter of the method.
12+
/// </summary>
13+
public interface IOptionalArgs
14+
{
15+
/// <summary>
16+
/// The identifier of the class. It is not an argument but only something to
17+
/// separate different OptionalArgs.
18+
/// </summary>
19+
string Identifier { get; }
20+
}
21+
}

0 commit comments

Comments
 (0)
0