8000 fix: temporarily fix the sequential nest error. · SciSharp/TensorFlow.NET@4bca319 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4bca319

Browse files
AsakusaRinneOceania2018
authored andcommitted
fix: temporarily fix the sequential nest error.
1 parent 6fb930a commit 4bca319

File tree

5 files changed

+55
-6
lines changed

5 files changed

+55
-6
lines changed

src/TensorFlowNET.Core/Training/Saving/SavedModel/save.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ private static (MetaGraphDef, Graph, TrackableSaver, AssetInfo, IList<Trackable>
8888
{
8989
if (ops.inside_function())
9090
{
91-
throw new AssertionError("`tf.saved_model.save` is not supported inside a traced @tf.function. " +
91+
throw new AssertionError("`tf.saved_model.save` is not supported inside a traced [AutoGraph]. " +
9292
"Move the call to the outer eagerly-executed context.");
9393
}
9494

src/TensorFlowNET.Keras/Engine/Layer.Apply.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,19 @@ public Tensors Apply(Tensors inputs, Tensor state = null, bool training = false)
4141

4242
return outputs;
4343
}
44+
45+
// TODO(Rinne): remove it and completely fix issue 1084
46+
[Obsolete]
47+
private bool _enforce_layer_construction = false;
48+
[Obsolete]
49+
internal void enforce_layer_construction()
50+
{
51+
_enforce_layer_construction = true;
52+
}
53+
[Obsolete]
54+
internal void unset_layer_construction()
55+
{
56+
_enforce_layer_construction = false;
57+
}
4458
}
4559
}

src/TensorFlowNET.Keras/Engine/Layer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ internal virtual void Initialize(LayerArgs args)
291291
bool _in_functional_construction_mode(Tensors inputs)
292292
{
293293
return tf.Context.executing_eagerly()
294-
&& inputs.Count(x => x is not EagerTensor && x is not NDArray) == inputs.Count();
294+
&& inputs.Count(x => x is not EagerTensor && x is not NDArray) == inputs.Count() || _enforce_layer_construction;
295295
}
296296

297297
public void SetConnectivityMetadata(Tensors inputs, Tensors outputs)

src/TensorFlowNET.Keras/Engine/Sequential.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,17 @@ public void InitLayers(IEnumerable<ILayer> layers)
6262
{
6363
foreach(var layer in layers)
6464
{
65+
// TODO(Rinne): remove it and completely fix issue 1084
66+
if(layer is Sequential s)
67+
{
68+
s.Layers.ForEach(x => ((Layer)x).enforce_layer_construction());
69+
}
6570
add(layer);
71+
// TODO(Rinne): remove it and completely fix issue 1084
72+
if (layer is Sequential s2)
73+
{
74+
s2.Layers.ForEach(x => ((Layer)x).unset_layer_construction());
75+
}
6676
}
6777
}
6878

@@ -163,7 +173,7 @@ void _build_graph_network_for_inferred_shape(Shape input_shape, TF_DataType inpu
163173
Tensors layer_output = null;
164174
Tensors outputs = null;
165175
List<INode> created_nodes = new List<INode>();
166-
foreach (var layer in args.Layers)
176+
foreach (var layer in Layers)
167177
{
168178
clear_previously_created_nodes(layer, _created_nodes);
169179
layer_output = layer.Apply(layer_input);

test/TensorFlowNET.Keras.UnitTest/Model/ModelBuildTest.cs

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using System;
23
using static Tensorflow.Binding;
4+
using static Tensorflow.KerasApi;
35

46
namespace Tensorflow.Keras.UnitTest.Model
57
{
@@ -14,24 +16,47 @@ public void DenseBuild()
1416
var dense = tf.keras.layers.Dense(64);
1517
var output = dense.Apply(input);
1618
var model = tf.keras.Model(input, output);
19+
model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.CategoricalCrossentropy());
1720

1821
// one dimensions input with unknown batchsize
1922
var input_2 = tf.keras.layers.Input((60));
2023
var dense_2 = tf.keras.layers.Dense(64);
21-
var output_2 = dense.Apply(input_2);
24+
var output_2 = dense_2.Apply(input_2);
2225
var model_2 = tf.keras.Model(input_2, output_2);
26+
model_2.compile(tf.keras.optimizers.Adam(), tf.keras.losses.CategoricalCrossentropy());
2327

2428
// two dimensions input with specified batchsize
2529
var input_3 = tf.keras.layers.Input((17, 60), 8);
2630
var dense_3 = tf.keras.layers.Dense(64);
27-
var output_3 = dense.Apply(input_3);
31+
var output_3 = dense_3.Apply(input_3);
2832
var model_3 = tf.keras.Model(input_3, output_3);
33+
model_3.compile(tf.keras.optimizers.Adam(), tf.keras.losses.CategoricalCrossentropy());
2934

3035
// one dimensions input with specified batchsize
3136
var input_4 = tf.keras.layers.Input((60), 8);
3237
var dense_4 = tf.keras.layers.Dense(64);
33-
var output_4 = dense.Apply(input_4);
38+
var output_4 = dense_4.Apply(input_4);
3439
var model_4 = tf.keras.Model(input_4, output_4);
40+
model_4.compile(tf.keras.optimizers.Adam(), tf.keras.losses.CategoricalCrossentropy());
41+
}
42+
43+
[TestMethod]
44+
public void NestedSequential()
45+
{
46+
var block1 = keras.Sequential(new[] {
47+
keras.layers.InputLayer((3, 3)),
48+
keras.Sequential(new []
49+
{
50+
keras.layers.Flatten(),
51+
keras.layers.Dense(5)
52+
}
53+
)
54+
});
55+
block1.compile(tf.keras.optimizers.Adam(), tf.keras.losses.CategoricalCrossentropy());
56+
57+
var x = tf.ones((1, 3, 3));
58+
var y = block1.predict(x);
59+
Console.WriteLine(y);
3560
}
3661
}
3762
}

0 commit comments

Comments
 (0)
0