From c2f463f57ac6858a4b0085006ed6f2c3abc17ccb Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Sat, 4 Mar 2023 21:41:25 +0800 Subject: [PATCH 1/4] Support the multiple inputs of keras model.fit. --- src/TensorFlowNET.Core/Data/DatasetV2.cs | 8 ++- src/TensorFlowNET.Core/Data/IDatasetV2.cs | 4 +- src/TensorFlowNET.Core/Data/OwnedIterator.cs | 5 +- .../Keras/ArgsDefinition/DataAdapterArgs.cs | 4 +- .../Keras/ArgsDefinition/DataHandlerArgs.cs | 4 +- src/TensorFlowNET.Core/Keras/Engine/IModel.cs | 11 +++ .../NumPy/NDArray.Implicit.cs | 71 ++++++++++++++++++- .../NumPy/Persistence/NpzDictionaryArray.cs | 2 +- src/TensorFlowNET.Core/Numpy/NDArray.cs | 3 + .../Operations/dataset_ops.cs | 34 +++++++++ .../Engine/DataAdapters/DataAdapter.cs | 18 +++-- .../Engine/DataAdapters/DataHandler.cs | 6 +- .../Engine/DataAdapters/IDataAdapter.cs | 4 +- .../DataAdapters/TensorLikeDataAdapter.cs | 13 ++-- src/TensorFlowNET.Keras/Engine/Model.Fit.cs | 65 +++++++++++++++-- src/TensorFlowNET.Keras/Engine/Model.Train.cs | 11 ++- .../Helpers/RandomDataset.cs | 30 ++++++++ .../MultiInputModelTest.cs | 69 ++++++++++++++++++ .../SaveModel/SequentialModelLoad.cs | 1 + .../SaveModel/SequentialModelSave.cs | 22 +----- .../Dataset/DatasetTest.cs | 18 ++--- 21 files changed, 343 insertions(+), 60 deletions(-) create mode 100644 test/TensorFlowNET.Keras.UnitTest/Helpers/RandomDataset.cs create mode 100644 test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs index 103d7cfff..324d7e834 100644 --- a/src/TensorFlowNET.Core/Data/DatasetV2.cs +++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs @@ -19,6 +19,8 @@ public class DatasetV2 : IDatasetV2 public TensorSpec[] structure { get; set; } + public int FirstInputTensorCount { get; set; } = 1; + public Shape[] output_shapes => structure.Select(x => x.shape).ToArray(); public TF_DataType[] output_types => structure.Select(x => x.dtype).ToArray(); @@ -131,6 +133,7 @@ public IDatasetV2 apply_options() // (4) Apply stats aggregator options + dataset.FirstInputTensorCount = this.FirstInputTensorCount; return dataset; } @@ -142,7 +145,7 @@ public override string ToString() $"types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}, " + $"len: {length}"; - public IEnumerator<(Tensor, Tensor)> GetEnumerator() + public IEnumerator<(Tensors, Tensors)> GetEnumerator() { using var ownedIterator = new OwnedIterator(this); @@ -158,7 +161,8 @@ public override string ToString() break; } - yield return (results[0], results.Length == 1 ? null : results[1]); + yield return (new Tensors(results.Take(FirstInputTensorCount)), results.Length == FirstInputTensorCount ? + null : new Tensors(results.Skip(FirstInputTensorCount))); } } diff --git a/src/TensorFlowNET.Core/Data/IDatasetV2.cs b/src/TensorFlowNET.Core/Data/IDatasetV2.cs index 5cfeb27cc..320cbe348 100644 --- a/src/TensorFlowNET.Core/Data/IDatasetV2.cs +++ b/src/TensorFlowNET.Core/Data/IDatasetV2.cs @@ -4,7 +4,7 @@ namespace Tensorflow { - public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)> + public interface IDatasetV2 : IEnumerable<(Tensors, Tensors)> { string[] class_names { get; set; } @@ -18,6 +18,8 @@ public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)> TensorSpec[] structure { get; set; } + int FirstInputTensorCount { get; set; } + /// /// Caches the elements in this dataset. 
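For context, the enumerator change in DatasetV2 above means a dataset now yields (Tensors, Tensors) pairs rather than single tensors, with FirstInputTensorCount marking where the features end and the labels begin. A minimal consumption sketch (illustrative only, not part of the patch):

    // Each element is (features, labels); labels is null when the dataset
    // yields only FirstInputTensorCount tensors.
    foreach (var (features, labels) in dataset)
    {
        // features can hold several input tensors for a multi-input model
        var firstInput = features[0];
    }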
/// diff --git a/src/TensorFlowNET.Core/Data/OwnedIterator.cs b/src/TensorFlowNET.Core/Data/OwnedIterator.cs index eb91272c7..1dafc87ea 100644 --- a/src/TensorFlowNET.Core/Data/OwnedIterator.cs +++ b/src/TensorFlowNET.Core/Data/OwnedIterator.cs @@ -27,7 +27,8 @@ void _create_iterator(IDatasetV2 dataset) _dataset = dataset; _element_spec = dataset.element_spec; // _flat_output_types = - (_iterator_resource, _deleter) = ops.anonymous_iterator_v2(_dataset.output_types, _dataset.output_shapes); + _iterator_resource = ops.anonymous_iterator_v3(_dataset.output_types, _dataset.output_shapes); + // TODO(Rinne): deal with graph mode. ops.make_iterator(dataset.variant_tensor, _iterator_resource); } @@ -48,7 +49,7 @@ public Tensor[] next() public void Dispose() { - tf.Runner.Execute(tf.Context, "DeleteIterator", 0, new[] { _iterator_resource, _deleter }, null); + //tf.Runner.Execute(tf.Context, "DeleteIterator", 0, new[] { _iterator_resource, _deleter }, null); } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs index 8ce1ec655..78882e82d 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs @@ -5,8 +5,8 @@ namespace Tensorflow.Keras.ArgsDefinition { public class DataAdapterArgs: IKerasConfig { - public Tensor X { get; set; } - public Tensor Y { get; set; } + public Tensors X { get; set; } + public Tensors Y { get; set; } public IDatasetV2 Dataset { get; set; } public int BatchSize { get; set; } = 32; public int Steps { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs index fd603a85e..82530e950 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs @@ -5,8 +5,8 @@ namespace Tensorflow.Keras.ArgsDefinition { public class DataHandlerArgs: IKerasConfig { - public Tensor X { get; set; } - public Tensor Y { get; set; } + public Tensors X { get; set; } + public Tensors Y { get; set; } public IDatasetV2 Dataset { get; set; } public int BatchSize { get; set; } = 32; public int StepsPerEpoch { get; set; } = -1; diff --git a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs index 8bcfcbbbd..e02642dcf 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs @@ -24,6 +24,17 @@ ICallback fit(NDArray x, NDArray y, int workers = 1, bool use_multiprocessing = false); + ICallback fit(IEnumerable x, NDArray y, + int batch_size = -1, + int epochs = 1, + int verbose = 1, + float validation_split = 0f, + bool shuffle = true, + int initial_epoch = 0, + int max_queue_size = 10, + int workers = 1, + bool use_multiprocessing = false); + void save(string filepath, bool overwrite = true, bool include_optimizer = true, diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs index 53401a444..fd4f93fc1 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs @@ -14,7 +14,76 @@ public void Deconstruct(out byte blue, out byte green, out byte red) red = data[2]; } - public static implicit operator NDArray(Array array) + public static implicit operator NDArray(int[] array) + => new NDArray(array); + + public static implicit operator NDArray(byte[] 
array) + => new NDArray(array); + + public static implicit operator NDArray(float[] array) + => new NDArray(array); + + public static implicit operator NDArray(double[] array) + => new NDArray(array); + + public static implicit operator NDArray(long[] array) + => new NDArray(array); + + public static implicit operator NDArray(bool[] array) + => new NDArray(array); + + public static implicit operator NDArray(uint[] array) + => new NDArray(array); + + public static implicit operator NDArray(ulong[] array) + => new NDArray(array); + + public static implicit operator NDArray(int[,] array) + => new NDArray(array); + + public static implicit operator NDArray(byte[,] array) + => new NDArray(array); + + public static implicit operator NDArray(float[,] array) + => new NDArray(array); + + public static implicit operator NDArray(double[,] array) + => new NDArray(array); + + public static implicit operator NDArray(long[,] array) + => new NDArray(array); + + public static implicit operator NDArray(bool[,] array) + => new NDArray(array); + + public static implicit operator NDArray(uint[,] array) + => new NDArray(array); + + public static implicit operator NDArray(ulong[,] array) + => new NDArray(array); + + public static implicit operator NDArray(int[,,] array) + => new NDArray(array); + + public static implicit operator NDArray(byte[,,] array) + => new NDArray(array); + + public static implicit operator NDArray(float[,,] array) + => new NDArray(array); + + public static implicit operator NDArray(double[,,] array) + => new NDArray(array); + + public static implicit operator NDArray(long[,,] array) + => new NDArray(array); + + public static implicit operator NDArray(bool[,,] array) + => new NDArray(array); + + public static implicit operator NDArray(uint[,,] array) + => new NDArray(array); + + public static implicit operator NDArray(ulong[,,] array) => new NDArray(array); public unsafe static implicit operator bool(NDArray nd) diff --git a/src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs b/src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs index 6e81216ea..ba7868faa 100644 --- a/src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs +++ b/src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs @@ -25,7 +25,7 @@ private NDArray OpenEntry(ZipArchiveEntry entry) return array; using var s = entry.Open(); - return LoadMatrix(s); + return (NDArray)LoadMatrix(s); } public Array LoadMatrix(Stream stream) diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.cs b/src/TensorFlowNET.Core/Numpy/NDArray.cs index 3a2cb3ee2..6e4c6b32c 100644 --- a/src/TensorFlowNET.Core/Numpy/NDArray.cs +++ b/src/TensorFlowNET.Core/Numpy/NDArray.cs @@ -49,5 +49,8 @@ public IEnumerator GetEnumerator() IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + public static explicit operator NDArray(Array array) + => new NDArray(array); } } diff --git a/src/TensorFlowNET.Core/Operations/dataset_ops.cs b/src/TensorFlowNET.Core/Operations/dataset_ops.cs index 9407fd5aa..c7e627772 100644 --- a/src/TensorFlowNET.Core/Operations/dataset_ops.cs +++ b/src/TensorFlowNET.Core/Operations/dataset_ops.cs @@ -1,6 +1,9 @@ using System; +using Tensorflow.Contexts; +using Tensorflow.Eager; using Tensorflow.Framework.Models; using Tensorflow.Functions; +using Tensorflow.Operations; using static Tensorflow.Binding; namespace Tensorflow @@ -220,6 +223,37 @@ public Tensor model_dataset(Tensor input_dataset, return (results[0], results[1]); } + public Tensor anonymous_iterator_v3(TF_DataType[] output_types, Shape[] 
output_shapes, string name = null) + { + var ctx = tf.Context; + Dictionary attrs = new(); + attrs["output_types"] = output_types; + attrs["output_shapes"] = output_shapes; + if (ctx.executing_eagerly()) + { + try + { + var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("AnonymousIteratorV3", name) + { + attrs = attrs + }); + return result[0]; + } + catch (Exception) + { + return anonymous_iterator_v3_eager_fallback(output_types, output_shapes, name, ctx); + } + } + return tf.OpDefLib._apply_op_helper("AnonymousIteratorV3", name, attrs).outputs[0]; + } + + public Tensor anonymous_iterator_v3_eager_fallback(TF_DataType[] output_types, Shape[] output_shapes, string name, Context ctx) + { + object[] attrs = new object[] { output_types, output_shapes }; + var result = execute.quick_execute("AnonymousIteratorV3", 1, new Tensor[] { }, attrs, ctx, name); + return result[0]; + } + /// /// Makes a new iterator from the given `dataset` and stores it in `iterator`. /// diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs index 3314f5c40..6c7d53b2f 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs @@ -10,7 +10,7 @@ public abstract class DataAdapter protected DataAdapterArgs args; protected IDatasetV2 dataset; - public virtual bool CanHandle(Tensor x, Tensor y = null) + public virtual bool CanHandle(Tensors x, Tensors y = null) => throw new NotImplementedException(); public virtual IDatasetV2 GetDataset() @@ -19,12 +19,18 @@ public virtual IDatasetV2 GetDataset() public virtual int GetSize() => throw new NotImplementedException(""); - public virtual (Tensor, Tensor) Expand1d(Tensor x, Tensor y) + public virtual (Tensors, Tensors) Expand1d(Tensors x, Tensors y) { - if (x.shape.ndim == 1) - x = array_ops.expand_dims(x, axis: -1); - if (y.shape.ndim == 1) - y = array_ops.expand_dims(y, axis: -1); + for(int i = 0; i < x.Length; i++) + { + if (x[i].shape.ndim == 1) + x[i] = array_ops.expand_dims(x[i], axis: -1); + } + for (int i = 0; i < y.Length; i++) + { + if (y[i].shape.ndim == 1) + y[i] = array_ops.expand_dims(y[i], axis: -1); + } return (x, y); } diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs index 1ddddd111..4723222f2 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs @@ -93,11 +93,15 @@ long _infer_steps(int steps_per_epoch, IDatasetV2 dataset) public IEnumerable<(int, OwnedIterator)> enumerate_epochs() { + var data_iterator = new OwnedIterator(_dataset); foreach (var epoch in range(_initial_epoch, _epochs)) { if (_insufficient_data) break; - using var data_iterator = new OwnedIterator(_dataset); + if (_adapter.ShouldRecreateIterator()) + { + data_iterator = new OwnedIterator(_dataset); + } yield return (epoch, data_iterator); } // _adapter.on_epoch_end() diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs index df414b9fd..4bdc49795 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs @@ -13,10 +13,10 @@ public interface IDataAdapter /// input features /// target labels /// - bool CanHandle(Tensor x, Tensor y = null); + bool CanHandle(Tensors x, Tensors y = null); IDatasetV2 
GetDataset();
 int GetSize();
- (Tensor, Tensor) Expand1d(Tensor x, Tensor y);
+ (Tensors, Tensors) Expand1d(Tensors x, Tensors y);
 bool ShouldRecreateIterator();
 }
 }
diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
index fc61aa715..f53c67c4b 100644
--- a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
+++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
@@ -1,4 +1,5 @@
 using System;
+using System.Diagnostics;
 using System.Linq;
 using Tensorflow.Keras.ArgsDefinition;
 using static Tensorflow.Binding;
@@ -20,7 +21,7 @@ public TensorLikeDataAdapter(DataAdapterArgs args)
 {
 this.args = args;
 _process_tensorlike();
- num_samples = (int)args.X.shape[0];
+ num_samples = (int)args.X[0].shape[0];
 var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize;
 _batch_size = batch_size;
 _size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0.0f)));
@@ -33,10 +34,11 @@ public TensorLikeDataAdapter(DataAdapterArgs args)
 indices_dataset = indices_dataset.flat_map(slice_batch_indices);
 var inputs = new Tensors();
 if (args.X != null)
- inputs.Add(args.X);
+ inputs.AddRange(args.X);
 if (args.Y != null)
- inputs.Add(args.Y);
+ inputs.AddRange(args.Y);
 dataset = slice_inputs(indices_dataset, inputs);
+ dataset.FirstInputTensorCount = args.X.Length;
 }
 
 Tensors permutation(Tensors tensor)
@@ -87,8 +89,9 @@ IDatasetV2 slice_inputs(IDatasetV2 indices_dataset, Tensors elements)
 return dataset.with_options(new DatasetOptions { });
 }
 
- public override int GetSize()
- => _size;
+ public override int GetSize() => _size;
+
+ public override bool ShouldRecreateIterator() => false;
 
 void _process_tensorlike()
 {
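The adapter above packs all X tensors followed by all Y tensors into a single Tensors list and records args.X.Length in FirstInputTensorCount, so the pipeline can split features from labels again downstream. A self-contained sketch of that round trip (illustrative names, not part of the patch):

    using System.Linq;

    static (T[] Features, T[] Labels) SplitPacked<T>(T[] packed, int firstInputTensorCount)
        => (packed.Take(firstInputTensorCount).ToArray(),
            packed.Skip(firstInputTensorCount).ToArray());

    // packed = { x1, x2, y }, firstInputTensorCount = 2
    // -> Features = { x1, x2 }, Labels = { y }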
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
index 1ebd56d33..39004183b 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
@@ -59,7 +59,62 @@ public ICallback fit(NDArray x, NDArray y,
 StepsPerExecution = _steps_per_execution
 });
 
- return FitInternal(data_handler, epochs, verbose);
+ return FitInternal(data_handler, epochs, verbose, validation_data: null,
+ train_step_func: train_step_function);
+ }
+
+ public ICallback fit(IEnumerable<NDArray> x, NDArray y,
+ int batch_size = -1,
+ int epochs = 1,
+ int verbose = 1,
+ float validation_split = 0f,
+ bool shuffle = true,
+ int initial_epoch = 0,
+ int max_queue_size = 10,
+ int workers = 1,
+ bool use_multiprocessing = false)
+ {
+ foreach(var tx in x)
+ {
+ if (tx.dims[0] != y.dims[0])
+ {
+ throw new InvalidArgumentError(
+ $"The arrays x and y should have the same length at dim 0, but got {tx.dims[0]} and {y.dims[0]}");
+ }
+ }
+ int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split));
+
+ var train_x = x.Select(tx => tx[new Slice(0, train_count)] as Tensor);
+ var train_y = y[new Slice(0, train_count)];
+ var val_x = x.Select(tx => tx[new Slice(train_count)] as Tensor);
+ var val_y = y[new Slice(train_count)];
+
+ var data_handler = new DataHandler(new DataHandlerArgs
+ {
+ X = new Tensors(train_x),
+ Y = train_y,
+ BatchSize = batch_size,
+ InitialEpoch = initial_epoch,
+ Epochs = epochs,
+ Shuffle = shuffle,
+ MaxQueueSize = max_queue_size,
+ Workers = workers,
+ UseMultiprocessing = use_multiprocessing,
+ Model = this,
+ StepsPerExecution = _steps_per_execution
+ });
+
+ if (data_handler.DataAdapter.GetDataset().structure.Length > 2 ||
+ data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1)
+ {
+ return FitInternal(data_handler, epochs, verbose, validation_data: null,
+ train_step_func: train_step_multi_inputs_function);
+ }
+ else
+ {
+ return FitInternal(data_handler, epochs, verbose, validation_data: null,
+ train_step_func: train_step_function);
+ }
 }
 
 public History fit(IDatasetV2 dataset,
@@ -88,10 +143,12 @@ public History fit(IDatasetV2 dataset,
 StepsPerExecution = _steps_per_execution
 });
 
- return FitInternal(data_handler, epochs, verbose, validation_data: validation_data);
+ return FitInternal(data_handler, epochs, verbose, validation_data: validation_data,
+ train_step_func: train_step_function);
 }
 
- History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV2 validation_data = null)
+ History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV2 validation_data,
+ Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
 {
 stop_training = false;
 _train_counter.assign(0);
@@ -113,7 +170,7 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV
 foreach (var step in data_handler.steps())
 {
 callbacks.on_train_batch_begin(step);
- logs = train_step_function(data_handler, iterator);
+ logs = train_step_func(data_handler, iterator);
 var end_step = step + data_handler.StepIncrement;
 callbacks.on_train_batch_end(end_step, logs);
 }
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Train.cs b/src/TensorFlowNET.Keras/Engine/Model.Train.cs
index 8d85d70de..d8171e2a9 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Train.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Train.cs
@@ -17,12 +17,21 @@ Dictionary<string, float> train_step_function(DataHandler data_handler, OwnedIterator iterator)
 return outputs;
 }
 
+ Dictionary<string, float> train_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator)
+ {
+ var data = iterator.next();
+ var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
+ var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
+ tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
+ return outputs;
+ }
+
 ///
 /// The logic for one training step.
/// /// /// - Dictionary train_step(DataHandler data_handler, Tensor x, Tensor y) + Dictionary train_step(DataHandler data_handler, Tensors x, Tensors y) { (x, y) = data_handler.DataAdapter.Expand1d(x, y); using var tape = tf.GradientTape(); diff --git a/test/TensorFlowNET.Keras.UnitTest/Helpers/RandomDataset.cs b/test/TensorFlowNET.Keras.UnitTest/Helpers/RandomDataset.cs new file mode 100644 index 000000000..e145ce585 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Helpers/RandomDataset.cs @@ -0,0 +1,30 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Tensorflow.NumPy; + +namespace Tensorflow.Keras.UnitTest.Helpers +{ + public class RandomDataSet : DataSetBase + { + private Shape _shape; + + public RandomDataSet(Shape shape, int count) + { + _shape = shape; + Debug.Assert(_shape.ndim == 3); + long[] dims = new long[4]; + dims[0] = count; + for (int i = 1; i < 4; i++) + { + dims[i] = _shape[i - 1]; + } + Shape s = new Shape(dims); + Data = np.random.normal(0, 2, s); + Labels = np.random.uniform(0, 1, (count, 1)); + } + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs b/test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs new file mode 100644 index 000000000..490178bc9 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs @@ -0,0 +1,69 @@ +using Microsoft.VisualStudio.TestPlatform.Utilities; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using System.Xml.Linq; +using Tensorflow.Operations; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; +using Tensorflow.NumPy; +using Microsoft.VisualBasic; +using static HDF.PInvoke.H5T; +using Tensorflow.Keras.UnitTest.Helpers; +using Tensorflow.Keras.Optimizers; + +namespace Tensorflow.Keras.UnitTest +{ + [TestClass] + public class MultiInputModelTest + { + [TestMethod] + public void SimpleModel() + { + var inputs = keras.Input((28, 28, 1)); + var conv1 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs); + var pool1 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv1); + var conv2 = keras.layers.Conv2D(32, (3, 3), activation: "relu", padding: "same").Apply(pool1); + var pool2 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv2); + var flat1 = keras.layers.Flatten().Apply(pool2); + + var inputs_2 = keras.Input((28, 28, 1)); + var conv1_2 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs_2); + var pool1_2 = keras.layers.MaxPooling2D((4, 4), 4).Apply(conv1_2); + var conv2_2 = keras.layers.Conv2D(32, (1, 1), activation: "relu", padding: "same").Apply(pool1_2); + var pool2_2 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv2_2); + var flat1_2 = keras.layers.Flatten().Apply(pool2_2); + + var concat = keras.layers.Concatenate().Apply((flat1, flat1_2)); + var dense1 = keras.layers.Dense(512, activation: "relu").Apply(concat); + var dense2 = keras.layers.Dense(128, activation: "relu").Apply(dense1); + var dense3 = keras.layers.Dense(10, activation: "relu").Apply(dense2); + var output = keras.layers.Softmax(-1).Apply(dense3); + + var model = keras.Model((inputs, inputs_2), output); + model.summary(); + + var data_loader = new MnistModelLoader(); + + var dataset = data_loader.LoadAsync(new ModelLoadSetting + { + TrainDir = "mnist", + OneHot = false, + ValidationSize = 59000, + }).Result; + 
+ var loss = keras.losses.SparseCategoricalCrossentropy(); + var optimizer = new Adam(0.001f); + model.compile(optimizer, loss, new string[] { "accuracy" }); + + NDArray x1 = np.reshape(dataset.Train.Data, (dataset.Train.Data.shape[0], 28, 28, 1)); + NDArray x2 = x1; + + var x = new NDArray[] { x1, x2 }; + model.fit(x, dataset.Train.Labels, batch_size: 8, epochs: 3); + } + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs index e778a5a4a..385ec0f7c 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs @@ -13,6 +13,7 @@ using Tensorflow.Keras.Optimizers; using static Tensorflow.KerasApi; using Tensorflow.NumPy; +using Tensorflow.Keras.UnitTest.Helpers; using static TensorFlowNET.Keras.UnitTest.SaveModel.SequentialModelSave; namespace TensorFlowNET.Keras.UnitTest.SaveModel; diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs index 5b7c2b62e..251afde3d 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs @@ -6,7 +6,7 @@ using Tensorflow.Keras.Engine; using Tensorflow.Keras.Losses; using Tensorflow.Keras.Optimizers; -using Tensorflow.NumPy; +using Tensorflow.Keras.UnitTest.Helpers; using static Tensorflow.Binding; using static Tensorflow.KerasApi; @@ -175,24 +175,4 @@ public void AlexnetFromSequential() // ) #endregion } - - public class RandomDataSet : DataSetBase - { - private Shape _shape; - - public RandomDataSet(Shape shape, int count) - { - _shape = shape; - Debug.Assert(_shape.ndim == 3); - long[] dims = new long[4]; - dims[0] = count; - for (int i = 1; i < 4; i++) - { - dims[i] = _shape[i - 1]; - } - Shape s = new Shape(dims); - Data = np.random.normal(0, 2, s); - Labels = np.random.uniform(0, 1, (count, 1)); - } - } } \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs index 8317346ea..01f35a417 100644 --- a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs +++ b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs @@ -20,7 +20,7 @@ public void Range() Assert.AreEqual(iStep, step); iStep++; - Assert.AreEqual(value, (long)item.Item1); + Assert.AreEqual(value, (long)item.Item1[0]); value++; } } @@ -39,7 +39,7 @@ public void Prefetch() Assert.AreEqual(iStep, step); iStep++; - Assert.AreEqual(value, (long)item.Item1); + Assert.AreEqual(value, (long)item.Item1[0]); value += 2; } } @@ -54,7 +54,7 @@ public void FromTensorSlices() int n = 0; foreach (var (item_x, item_y) in dataset) { - print($"x:{item_x.numpy()},y:{item_y.numpy()}"); + print($"x:{item_x[0].numpy()},y:{item_y[0].numpy()}"); n += 1; } Assert.AreEqual(5, n); @@ -69,7 +69,7 @@ public void FromTensor() int n = 0; foreach (var x in dataset) { - Assert.IsTrue(X.SequenceEqual(x.Item1.ToArray())); + Assert.IsTrue(X.SequenceEqual(x.Item1[0].ToArray())); n += 1; } Assert.AreEqual(1, n); @@ -85,7 +85,7 @@ public void Shard() foreach (var item in dataset2) { - Assert.AreEqual(value, (long)item.Item1); + Assert.AreEqual(value, (long)item.Item1[0]); value += 3; } @@ -93,7 +93,7 @@ public void Shard() var dataset3 = dataset1.shard(num_shards: 3, index: 1); foreach (var item in dataset3) { - Assert.AreEqual(value, (long)item.Item1); + Assert.AreEqual(value, 
(long)item.Item1[0]);
 value += 3;
 }
 }
@@ -108,7 +108,7 @@ public void Skip()
 
 foreach (var item in dataset)
 {
- Assert.AreEqual(value, (long)item.Item1);
+ Assert.AreEqual(value, (long)item.Item1[0]);
 value++;
 }
 }
@@ -123,7 +123,7 @@ public void Map()
 
 foreach (var item in dataset)
 {
- Assert.AreEqual(value + 10, (long)item.Item1);
+ Assert.AreEqual(value + 10, (long)item.Item1[0]);
 value++;
 }
 }
@@ -138,7 +138,7 @@ public void Cache()
 
 foreach (var item in dataset)
 {
- Assert.AreEqual(value, (long)item.Item1);
+ Assert.AreEqual(value, (long)item.Item1[0]);
 value++;
 }
 }

From 3cd01f43a0de249e39c6c39422ea218b13b47448 Mon Sep 17 00:00:00 2001
From: Haiping Chen
Date: Sat, 4 Mar 2023 09:55:01 -0600
Subject: [PATCH 2/4] GradientTest inherits from EagerModeTestBase.

---
 test/TensorFlowNET.Keras.UnitTest/Gradient.cs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/TensorFlowNET.Keras.UnitTest/Gradient.cs b/test/TensorFlowNET.Keras.UnitTest/Gradient.cs
index fad8e1187..f20eae0e0 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Gradient.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Gradient.cs
@@ -9,7 +9,7 @@ namespace TensorFlowNET.Keras.UnitTest;
 
 [TestClass]
-public class GradientTest
+public class GradientTest : EagerModeTestBase
 {
 public IModel get_actor(int num_states)
 {
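Taken together, patch 1 lets a caller pass several input arrays directly to fit, mirroring the MultiInputModelTest above. A usage sketch (shapes and names are illustrative):

    // Two input arrays with matching first dimensions, one label array.
    var x = new NDArray[] { x1, x2 };
    model.fit(x, labels, batch_size: 8, epochs: 3);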
From 6a295b68fc0a56423aa05f55c5cdb3d3668d6193 Mon Sep 17 00:00:00 2001
From: Yaohui Liu
Date: Sun, 5 Mar 2023 01:12:41 +0800
Subject: [PATCH 3/4] Add more explicit conversions for Tensors.

---
 src/TensorFlowNET.Core/Tensors/Tensors.cs | 101 ++++++++++++++++++
 .../Dataset/DatasetTest.cs | 18 ++--
 2 files changed, 110 insertions(+), 9 deletions(-)

diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs
index ecd844d1f..7fa4dd443 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensors.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs
@@ -65,6 +65,93 @@ public void Insert(int index, Tensor tensor)
 IEnumerator IEnumerable.GetEnumerator()
 => GetEnumerator();
 
+ public NDArray numpy()
+ {
+ EnsureSingleTensor(this, "numpy");
+ return this[0].numpy();
+ }
+
+ public T[] ToArray<T>() where T: unmanaged
+ {
+ EnsureSingleTensor(this, $"ToArray<{typeof(T)}>");
+ return this[0].ToArray<T>();
+ }
+
+ #region Explicit Conversions
+ public unsafe static explicit operator bool(Tensors tensor)
+ {
+ EnsureSingleTensor(tensor, "explicit conversion to bool");
+ return (bool)tensor[0];
+ }
+
+ public unsafe static explicit operator sbyte(Tensors tensor)
+ {
+ EnsureSingleTensor(tensor, "explicit conversion to sbyte");
+ return (sbyte)tensor[0];
+ }
+
+ public unsafe static explicit operator byte(Tensors tensor)
+ {
+ EnsureSingleTensor(tensor, "explicit conversion to byte");
+ return (byte)tensor[0];
+ }
+
+ public unsafe static explicit operator ushort(Tensors tensor)
+ {
+ EnsureSingleTensor(tensor, "explicit conversion to ushort");
+ return (ushort)tensor[0];
+ }
+
+ public unsafe static explicit operator short(Tensors tensor)
+ {
+ EnsureSingleTensor(tensor, "explicit conversion to short");
+ return (short)tensor[0];
+ }
+
+ public unsafe static explicit operator int(Tensors tensor)
+ {
+ EnsureSingleTensor(tensor, "explicit conversion to int");
+ return (int)tensor[0];
+ }
+
+ public unsafe static explicit operator uint(Tensors tensor)
+ {
+ EnsureSingleTensor(tensor, "explicit conversion to uint");
+ return (uint)tensor[0];
+ }
+
+ public unsafe static explicit operator long(Tensors tensor)
+ {
+ EnsureSingleTensor(tensor, "explicit conversion to long");
+ return (long)tensor[0];
+ }
+
+ public unsafe static explicit operator ulong(Tensors tensor)
+ {
+ EnsureSingleTensor(tensor, "explicit conversion to ulong");
+ return (ulong)tensor[0];
+ }
+
+ public unsafe static explicit operator float(Tensors tensor)
+ {
+ EnsureSingleTensor(tensor, "explicit conversion to float");
+ return (float)tensor[0];
+ }
+
+ public unsafe static explicit operator double(Tensors tensor)
+ {
+ EnsureSingleTensor(tensor, "explicit conversion to double");
+ return (double)tensor[0];
+ }
+
+ public unsafe static explicit operator string(Tensors tensor)
+ {
+ EnsureSingleTensor(tensor, "explicit conversion to string");
+ return (string)tensor[0];
+ }
+ #endregion
+
+ #region Implicit Conversions
 public static implicit operator Tensors(Tensor tensor)
 => new Tensors(tensor);
 
@@ -87,12 +174,26 @@ public static implicit operator Tensor(Tensors tensors)
 public static implicit operator Tensor[](Tensors tensors)
 => tensors.items.ToArray();
 
+ #endregion
+
 public void Deconstruct(out Tensor a, out Tensor b)
 {
 a = items[0];
 b = items[1];
 }
 
+ private static void EnsureSingleTensor(Tensors tensors, string methodName)
+ {
+ if(tensors.Length == 0)
+ {
+ throw new ValueError($"Method `{methodName}` of `Tensors` cannot be used when `Tensors` contains no Tensor.");
+ }
+ else if(tensors.Length > 1)
+ {
+ throw new ValueError($"Method `{methodName}` of `Tensors` cannot be used when `Tensors` contains more than one Tensor.");
+ }
+ }
+
 public override string ToString()
 => items.Count() == 1
 ? items.First().ToString()
diff --git a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs
index 01f35a417..8317346ea 100644
--- a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs
+++ b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs
@@ -20,7 +20,7 @@ public void Range()
 Assert.AreEqual(iStep, step);
 iStep++;
 
- Assert.AreEqual(value, (long)item.Item1[0]);
+ Assert.AreEqual(value, (long)item.Item1);
 value++;
 }
 }
@@ -39,7 +39,7 @@ public void Prefetch()
 Assert.AreEqual(iStep, step);
 iStep++;
 
- Assert.AreEqual(value, (long)item.Item1[0]);
+ Assert.AreEqual(value, (long)item.Item1);
 value += 2;
 }
 }
@@ -54,7 +54,7 @@ public void FromTensorSlices()
 int n = 0;
 foreach (var (item_x, item_y) in dataset)
 {
- print($"x:{item_x[0].numpy()},y:{item_y[0].numpy()}");
+ print($"x:{item_x.numpy()},y:{item_y.numpy()}");
 n += 1;
 }
 Assert.AreEqual(5, n);
@@ -69,7 +69,7 @@ public void FromTensor()
 int n = 0;
 foreach (var x in dataset)
 {
- Assert.IsTrue(X.SequenceEqual(x.Item1[0].ToArray()));
+ Assert.IsTrue(X.SequenceEqual(x.Item1.ToArray()));
 n += 1;
 }
 Assert.AreEqual(1, n);
@@ -85,7 +85,7 @@ public void Shard()
 
 foreach (var item in dataset2)
 {
- Assert.AreEqual(value, (long)item.Item1[0]);
+ Assert.AreEqual(value, (long)item.Item1);
 value += 3;
 }
 
@@ -93,7 +93,7 @@ public void Shard()
 var dataset3 = dataset1.shard(num_shards: 3, index: 1);
 foreach (var item in dataset3)
 {
- Assert.AreEqual(value, (long)item.Item1[0]);
+ Assert.AreEqual(value, (long)item.Item1);
 value += 3;
 }
 }
@@ -108,7 +108,7 @@ public void Skip()
 
 foreach (var item in dataset)
 {
- Assert.AreEqual(value, (long)item.Item1[0]);
+ Assert.AreEqual(value, (long)item.Item1);
 value++;
 }
 }
@@ -123,7 +123,7 @@ public void Map()
 
 foreach (var item in dataset)
 {
- Assert.AreEqual(value + 10, (long)item.Item1[0]);
+ Assert.AreEqual(value + 10, (long)item.Item1);
 value++;
 }
 }
@@ -138,7 +138,7 @@ public void Cache()
 
 foreach (var item in dataset)
 {
- Assert.AreEqual(value, (long)item.Item1[0]);
+ Assert.AreEqual(value, (long)item.Item1);
 value++;
 }
 }
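The conversions added in patch 3 are only valid when the Tensors instance wraps exactly one tensor; EnsureSingleTensor throws a ValueError otherwise. A sketch of the intended use (illustrative values):

    var single = new Tensors(tf.constant(3L));
    var value = (long)single;            // ok: exactly one tensor, yields 3

    var pair = new Tensors(tf.constant(1L), tf.constant(2L));
    // var bad = (long)pair;             // throws ValueError: more than one tensor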
From bebf9b4fd20b5a10c1f910855cee2d09e57ba5c5 Mon Sep 17 00:00:00 2001
From: Yaohui Liu
Date: Sun, 5 Mar 2023 02:05:19 +0800
Subject: [PATCH 4/4] Resolve the review comment.

---
 .../Engine/DataAdapters/TensorLikeDataAdapter.cs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
index f53c67c4b..a7e1d7e34 100644
--- a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
+++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
@@ -21,7 +21,7 @@ public TensorLikeDataAdapter(DataAdapterArgs args)
 {
 this.args = args;
 _process_tensorlike();
- num_samples = (int)args.X[0].shape[0];
+ num_samples = (int)args.X.shape[0];
 var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize;
 _batch_size = batch_size;
 _size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0.0f)));
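One behavioral note on patch 1 worth keeping in mind: DataHandler.enumerate_epochs now creates the OwnedIterator once up front and only rebuilds it when the adapter asks for it, and TensorLikeDataAdapter.ShouldRecreateIterator returns false, so the same iterator is reused across epochs. A simplified sketch of that flow (not the actual implementation):

    var iterator = new OwnedIterator(dataset);        // created once
    for (int epoch = initialEpoch; epoch < epochs; epoch++)
    {
        if (adapter.ShouldRecreateIterator())         // false for tensor-like data
            iterator = new OwnedIterator(dataset);    // start a fresh pass
        TrainOneEpoch(epoch, iterator);               // illustrative helper
    }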