8000 Support the multiple inputs of keras model.fit. by AsakusaRinne · Pull Request #996 · SciSharp/TensorFlow.NET · GitHub
[go: up one dir, main page]

Skip to content

Support the multiple inputs of keras model.fit. #996

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 4, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension 10000

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/TensorFlowNET.Core/Data/DatasetV2.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ public class DatasetV2 : IDatasetV2

public TensorSpec[] structure { get; set; }

public int FirstInputTensorCount { get; set; } = 1;

public Shape[] output_shapes => structure.Select(x => x.shape).ToArray();

public TF_DataType[] output_types => structure.Select(x => x.dtype).ToArray();
Expand Down Expand Up @@ -131,6 +133,7 @@ public IDatasetV2 apply_options()

// (4) Apply stats aggregator options

dataset.FirstInputTensorCount = this.FirstInputTensorCount;
return dataset;
}

Expand All @@ -142,7 +145,7 @@ public override string ToString()
$"types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}, " +
$"len: {length}";

public IEnumerator<(Tensor, Tensor)> GetEnumerator()
public IEnumerator<(Tensors, Tensors)> GetEnumerator()
{
using var ownedIterator = new OwnedIterator(this);

Expand All @@ -158,7 +161,8 @@ public override string ToString()
break;
}

yield return (results[0], results.Length == 1 ? null : results[1]);
yield return (new Tensors(results.Take(FirstInputTensorCount)), results.Length == FirstInputTensorCount ?
null : new Tensors(results.Skip(FirstInputTensorCount)));
}
}

Expand Down
4 changes: 3 additions & 1 deletion src/TensorFlowNET.Core/Data/IDatasetV2.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace Tensorflow
{
public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)>
public interface IDatasetV2 : IEnumerable<(Tensors, Tensors)>
{
string[] class_names { get; set; }

Expand All @@ -18,6 +18,8 @@ public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)>

TensorSpec[] structure { get; set; }

int FirstInputTensorCount { get; set; }

/// <summary>
/// Caches the elements in this dataset.
/// </summary>
Expand Down
5 changes: 3 additions & 2 deletions src/TensorFlowNET.Core/Data/OwnedIterator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ void _create_iterator(IDatasetV2 dataset)
_dataset = dataset;
_element_spec = dataset.element_spec;
// _flat_output_types =
(_iterator_resource, _deleter) = ops.anonymous_iterator_v2(_dataset.output_types, _dataset.output_shapes);
_iterator_resource = ops.anonymous_iterator_v3(_dataset.output_types, _dataset.output_shapes);
// TODO(Rinne): deal with graph mode.
ops.make_iterator(dataset.variant_tensor, _iterator_resource);
}

Expand All @@ -48,7 +49,7 @@ public Tensor[] next()

public void Dispose()
{
tf.Runner.Execute(tf.Context, "DeleteIterator", 0, new[] { _iterator_resource, _deleter }, null);
//tf.Runner.Execute(tf.Context, "DeleteIterator", 0, new[] { _iterator_resource, _deleter }, null);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ namespace Tensorflow.Keras.ArgsDefinition
{
public class DataAdapterArgs: IKerasConfig
{
public Tensor X { get; set; }
public Tensor Y { get; set; }
public Tensors X { get; set; }
public Tensors Y { get; set; }
public IDatasetV2 Dataset { get; set; }
public int BatchSize { get; set; } = 32;
public int Steps { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ namespace Tensorflow.Keras.ArgsDefinition
{
public class DataHandlerArgs: IKerasConfig
{
public Tensor X { get; set; }
public Tensor Y { get; set; }
public Tensors X { get; set; }
public Tensors Y { get; set; }
public IDatasetV2 Dataset { get; set; }
public int BatchSize { get; set; } = 32;
public int StepsPerEpoch { get; set; } = -1;
Expand Down
11 changes: 11 additions & 0 deletions src/TensorFlowNET.Core/Keras/Engine/IModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ ICallback fit(NDArray x, NDArray y,
int workers = 1,
bool use_multiprocessing = false);

ICallback fit(IEnumerable<NDArray> x, NDArray y,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
float validation_split = 0f,
bool shuffle = true,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
bool use_multiprocessing = false);

void save(string filepath,
bool overwrite = true,
bool include_optimizer = true,
Expand Down
71 changes: 70 additions & 1 deletion src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,76 @@ public void Deconstruct(out byte blue, out byte green, out byte red)
red = data[2];
}

public static implicit operator NDArray(Array array)
public static implicit operator NDArray(int[] array)
=> new NDArray(array);

public static implicit operator NDArray(byte[] array)
=> new NDArray(array);

public static implicit operator NDArray(float[] array)
=> new NDArray(array);

public static implicit operator NDArray(double[] array)
=> new NDArray(array);

public static implicit operator NDArray(long[] array)
=> new NDArray(array);

public static implicit operator NDArray(bool[] array)
=> new NDArray(array);

public static implicit operator NDArray(uint[] array)
=> new NDArray(array);

public static implicit operator NDArray(ulong[] array)
=> new NDArray(array);

public static implicit operator NDArray(int[,] array)
=> new NDArray(array);

public static implicit operator NDArray(byte[,] array)
=> new NDArray(array);

public static implicit operator NDArray(float[,] array)
=> new NDArray(array);

public static implicit operator NDArray(double[,] array)
=> new NDArray(array);

public static implicit operator NDArray(long[,] array)
=> new NDArray(array);

public static implicit operator NDArray(bool[,] array)
=> new NDArray(array);

public static implicit operator NDArray(uint[,] array)
=> new NDArray(array);

public static implicit operator NDArray(ulong[,] array)
=> new NDArray(array);

public static implicit operator NDArray(int[,,] array)
=> new NDArray(array);

public static implicit operator NDArray(byte[,,] array)
=> new NDArray(array);

public static implicit operator NDArray(float[,,] array)
=> new NDArray(array);

public static implicit operator NDArray(double[,,] array)
=> new NDArray(array);

public static implicit operator NDArray(long[,,] array)
=> new NDArray(array);

public static implicit operator NDArray(bool[,,] array)
=> new NDArray(array);

public static implicit operator NDArray(uint[,,] array)
=> new NDArray(array);

public static implicit operator NDArray(ulong[,,] array)
=> new NDArray(array);

public unsafe static implicit operator bool(NDArray nd)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ private NDArray OpenEntry(ZipArchiveEntry entry)
return array;

using var s = entry.Open();
return LoadMatrix(s);
return (NDArray)LoadMatrix(s);
}

public Array LoadMatrix(Stream stream)
Expand Down
3 changes: 3 additions & 0 deletions src/TensorFlowNET.Core/Numpy/NDArray.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,8 @@ public IEnumerator<NDArray> GetEnumerator()

IEnumerator IEnumerable.GetEnumerator()
=> GetEnumerator();

public static explicit operator NDArray(Array array)
=> new NDArray(array);
}
}
34 changes: 34 additions & 0 deletions src/TensorFlowNET.Core/Operations/dataset_ops.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
using System;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using Tensorflow.Framework.Models;
using Tensorflow.Functions;
using Tensorflow.Operations;
using static Tensorflow.Binding;

namespace Tensorflow
Expand Down Expand Up @@ -220,6 +223,37 @@ public Tensor model_dataset(Tensor input_dataset,
return (results[0], results[1]);
}

public Tensor anonymous_iterator_v3(TF_DataType[] output_types, Shape[] output_shapes, string name = null)
{
var ctx = tf.Context;
Dictionary<string, object> attrs = new();
attrs["output_types"] = output_types;
attrs["output_shapes"] = output_shapes;
if (ctx.executing_eagerly())
{
try
{
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("AnonymousIteratorV3", name)
{
attrs = attrs
});
return result[0];
}
catch (Exception)
{
return anonymous_iterator_v3_eager_fallback(output_types, output_shapes, name, ctx);
}
}
return tf.OpDefLib._apply_op_helper("AnonymousIteratorV3", name, attrs).outputs[0];
}

public Tensor anonymous_iterator_v3_eager_fallback(TF_DataType[] output_types, Shape[] output_shapes, string name, Context ctx)
{
object[] attrs = new object[] { output_types, output_shapes };
var result = execute.quick_execute("AnonymousIteratorV3", 1, new Tensor[] { }, attrs, ctx, name);
return result[0];
}

/// <summary>
/// Makes a new iterator from the given `dataset` and stores it in `iterator`.
/// </summary>
Expand Down
18 changes: 12 additions & 6 deletions src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public abstract class DataAdapter
protected DataAdapterArgs args;
protected IDatasetV2 dataset;

public virtual bool CanHandle(Tensor x, Tensor y = null)
public virtual bool CanHandle(Tensors x, Tensors y = null)
=> throw new NotImplementedException();

public virtual IDatasetV2 GetDataset()
Expand All @@ -19,12 +19,18 @@ public virtual IDatasetV2 GetDataset()
public virtual int GetSize()
=> throw new NotImplementedException("");

public virtual (Tensor, Tensor) Expand1d(Tensor x, Tensor y)
public virtual (Tensors, Tensors) Expand1d(Tensors x, Tensors y)
{
if (x.shape.ndim == 1)
x = array_ops.expand_dims(x, axis: -1);
if (y.shape.ndim == 1)
y = array_ops.expand_dims(y, axis: -1);
for(int i = 0; i < x.Length; i++)
{
if (x[i].shape.ndim == 1)
x[i] = array_ops.expand_dims(x[i], axis: -1);
}
for (int i = 0; i < y.Length; i++)
{
if (y[i].shape.ndim == 1)
y[i] = array_ops.expand_dims(y[i], axis: -1);
}
return (x, y);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,15 @@ long _infer_steps(int steps_per_epoch, IDatasetV2 dataset)

public IEnumerable<(int, OwnedIterator)> enumerate_epochs()
{
var data_iterator = new OwnedIterator(_dataset);
foreach (var epoch in range(_initial_epoch, _epochs))
{
if (_insufficient_data)
break;
using var data_iterator = new OwnedIterator(_dataset);
if (_adapter.ShouldRecreateIterator())
{
data_iterator = new OwnedIterator(_dataset);
}
yield return (epoch, data_iterator);
}
// _adapter.on_epoch_end()
Expand Down
4 changes: 2 additions & 2 deletions src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ public interface IDataAdapter
/// <param name="x">input features</param>
/// <param name="y">target labels</param>
/// <returns></returns>
bool CanHandle(Tensor x, Tensor y = null);
bool CanHandle(Tensors x, Tensors y = null);
IDatasetV2 GetDataset();
int GetSize();
(Tensor, Tensor) Expand1d(Tensor x, Tensor y);
(Tensors, Tensors) Expand1d(Tensors x, Tensors y);
bool ShouldRecreateIterator();
}
}
13 changes: 8 additions & 5 deletions src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using static Tensorflow.Binding;
Expand All @@ -20,7 +21,7 @@ public TensorLikeDataAdapter(DataAdapterArgs args)
{
this.args = args;
_process_tensorlike();
num_samples = (int)args.X.shape[0];
num_samples = (int)args.X[0].shape[0];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use (int)args.X.shape[0] instead of (int)args.X[0].shape[0]?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I have changed it.

var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize;
_batch_size = batch_size;
_size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0.0f)));
Expand All @@ -33,10 +34,11 @@ public TensorLikeDataAdapter(DataAdapterArgs args)
indices_dataset = indices_dataset.flat_map(slice_batch_indices);
var inputs = new Tensors();
if (args.X != null)
inputs.Add(args.X);
inputs.AddRange(args.X);
if (args.Y != null)
inputs.Add(args.Y);
inputs.AddRange(args.Y);
dataset = slice_inputs(indices_dataset, inputs);
dataset.FirstInputTensorCount = args.X.Length;
}

Tensors permutation(Tensors tensor)
Expand Down Expand Up @@ -87,8 +89,9 @@ IDatasetV2 slice_inputs(IDatasetV2 indices_dataset, Tensors elements)
return dataset.with_options(new DatasetOptions { });
}

public override int GetSize()
=> _size;
public override int GetSize() => _size;

public override bool ShouldRecreateIterator() => false;

void _process_tensorlike()
{
Expand Down
Loading
0