8000 Add pb model save by Oceania2018 · Pull Request #976 · SciSharp/TensorFlow.NET · GitHub
[go: up one dir, main page]

Skip to content
8000

Add pb model save #976

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Feb 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/TensorFlowNET.Core/APIs/tf.compat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ You may obtain a copy of the License at
limitations under the License.
******************************************************************************/

using System.Text;

namespace Tensorflow
{
public partial class tensorflow
Expand All @@ -23,6 +25,26 @@ public partial class tensorflow
public class CompatApi
{
public CompatV1Api v1 { get; } = new CompatV1Api();

internal string as_text(string bytes_or_text, Encoding? encoding = null)
{
if(encoding is null) encoding = Encoding.UTF8;
return bytes_or_text;
}
internal string as_text(byte[] bytes_or_text, Encoding? encoding = null)
{
if(encoding is null) encoding = Encoding.UTF8;
return encoding.GetString(bytes_or_text);
}

internal string as_str(string bytes_or_text, Encoding? encoding = null)
{
return as_text(bytes_or_text, encoding);
}
internal string as_str(byte[] bytes_or_text, Encoding? encoding = null)
{
return as_text(bytes_or_text, encoding);
}
}

public bool executing_eagerly()
Expand Down
152 changes: 152 additions & 0 deletions src/TensorFlowNET.Core/Checkpoint/CheckPointUtils.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using Tensorflow.Train;
using Tensorflow.Training;
using pbc = global::Google.Protobuf.Collections;

namespace Tensorflow.Checkpoint;

public static class CheckPointUtils
{
private static string _ESCAPE_CHAR = ".";
public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>, IDictionary<Trackable, int>,
IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>,
IDictionary<Trackable, string>) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view)
{
var (trackable_objects, node_paths) = graph_view.breadth_first_traversal();
Dictionary<Trackable, string> object_names = new();
foreach (var pair in node_paths)
{
object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value);
}

Dictionary<Trackable, int> node_ids = new();
for (int i = 0; i < trackable_objects.Count; i++)
{
node_ids[trackable_objects[i]] = i;
}

var slot_variables = serialize_slot_variables(trackable_objects, node_ids, object_names);
return (trackable_objects, node_paths, node_ids, slot_variables, object_names);
}

public static
IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
serialize_slot_variables(IEnumerable<Trackable> trackable_objects,
IDictionary<Trackable, int> node_ids, IDictionary<Trackable, string> object_names)
{
var non_slot_objects = trackable_objects.ToList();
Dictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
slot_variables = new();
foreach (var trackable in non_slot_objects)
{
if (trackable is not Optimizer)
{
continue;
}

var optim = (Optimizer)trackable;
var slot_names = optim.get_slot_names();
foreach (var slot_name in slot_names)
{
for (int original_variable_node_id = 0;
original_variable_node_id < non_slot_objects.Count;
original_variable_node_id++)
{
var original_variable = non_slot_objects[original_variable_node_id];
IVariableV1 slot_variable;
if (original_variable is not IVariableV1)
{
slot_variable = null;
}
slot_variable = optim.get_slot((IVariableV1)original_variable, slot_name);
if(slot_variable is null) continue;

// There're some problems about the inherits of `Variable` and `Trackable`.
throw new NotImplementedException();
}
}
}

return slot_variables;
}

public static Trackable get_mapped_trackable(Trackable trackable, IDictionary<Trackable, Trackable>? object_map)
{
if (object_map is null || !object_map.TryGetValue(trackable, out var possible_res))
{
return trackable;
}
else
{
return possible_res;
}
}

public static string get_full_name(Trackable variable)
{
// TODO: This state is not correct, the whole framework need to be updated in the future.
if (!(variable is IVariableV1 || resource_variable_ops.is_resource_variable(variable)))
{
return "";
}
// skip the check of attribute `_save_slice_info` .

// TODO: Need to be revised!!!
Debug.Assert(variable is BaseResourceVariable);
return ((BaseResourceVariable)variable).Name;
}

public static void add_checkpoint_values_check(TrackableObjectGraph object_graph_proto)
{
HashSet<int> checkpointed_trackables = new();
Dictionary<int, HashSet<int>> parents = new();
for (int i = 0; i < object_graph_proto.Nodes.Count; i++)
{
var object_proto = object_graph_proto.Nodes[i];
// skip the process of registered saver.
if (object_proto.Attributes is not null && object_proto.Attributes.Count > 0 ||
object_proto.SlotVariables is not null && object_proto.SlotVariables.Count > 0)
{
checkpointed_trackables.Add(i);
}

foreach (var child_proto in object_proto.Children)
{
var child = child_proto.NodeId;
if (!parents.ContainsKey(child))
{
parents[child] = new HashSet<int>();
}

parents[child].Add(i);
}
}

Queue<int> to_visit = new(checkpointed_trackables.AsEnumerable());
while (to_visit.Count > 0)
{
var trackable = to_visit.Dequeue();
if (!parents.ContainsKey(trackable)) continue;
var current_parents = parents[trackable];
foreach (var parent in current_parents)
{
checkpointed_trackables.Add(parent);
if (parents.ContainsKey(parent))
{
to_visit.Enqueue(parent);
}
}
parents.Remove(trackable);
}

// TODO: Complete it after supporting checkpoint.
// for (int i = 0; i < object_graph_proto.Nodes.Count; i++)
// {
// object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i);
// }
}
}
5 changes: 5 additions & 0 deletions src/TensorFlowNET.Core/Checkpoint/CheckpointOptions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
namespace Tensorflow.Checkpoint;

public record class CheckpointOptions(
string? experimental_io_device = null,
bool experimental_enable_async_checkpoint = false);
64 changes: 64 additions & 0 deletions src/TensorFlowNET.Core/Checkpoint/ObjectGraphView.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Serilog.Debugging;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Train;

namespace Tensorflow.Checkpoint;

public class ObjectGraphView: TrackableView, ICloneable
{
protected IEnumerable<TrackableReference>? _attached_dependencies;
// TODO: attached_dependencies
public ObjectGraphView(Trackable root, IEnumerable<TrackableReference>? attached_dependencies = null): base(root)
{
_attached_dependencies = attached_dependencies;
}

public object Clone()
{
// TODO: Implement real deep copy corresponding to tensorflow/python/checkpoint/graph_view.ObjectGraphView.__deepcopy__
return new ObjectGraphView(Root, _attached_dependencies);
}

public virtual List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null)
{
List<TrackableReference> res = base.children(obj, save_type, serialization_cache)
.Select(x => new TrackableReference(x.Key, x.Value)).ToList();
// Check the reference, not value.
if (obj == Root && _attached_dependencies is not null)
{
res.AddRange(_attached_dependencies);
}

return res;
}

public override IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null)
{
return list_children(obj, save_type, serialization_cache).ToDictionary(x => x.Name, x => x.Refer);
}

public IEnumerable<TrackableReference>? AttachedDependencies
{
get => _attached_dependencies;
}

public virtual (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal()
{
return base._descendants_with_paths();
}

// TODO: complete the implementation
public void serialize_object_graph(object? saveables_cache = null)
{
throw new NotImplementedException();
}

// TODO: complete the implementation
public void frozen_saveable_objects(object? object_map = null, object? to_graph = null, object call_with_mapped_captures = null)
{
throw new NotImplementedException();
}
}
Loading
0