8000 Add pb model save (#976) · SciSharp/TensorFlow.NET@197224f · GitHub
[go: up one dir, main page]

Skip to content

Commit 197224f

Browse files
Add pb model save (#976)
* Add check for dims of x and y in model.fit. * Init the serialization of keras pb model. * Add more facilities to the saved model framework. * Add ListWrapper and ITrackable, and revise implmentations. * Add serialized attributes. * Implement layer serializations. * Add lacked implementations (mainly MultiDeviceSaver). * Support autograph.to_graph under graph mode. * Add more implementations to the pb model save. * Add more implementations to the keras part of pb model save. * Refine some code after merge. * Add two simple sequential test case of pb model save. * Implement serializing attributes other keras arg definitions. * Add alexnet pb save test. * Check and refine the code. --------- Co-authored-by: AsakusaRinne <AsakusaRinne@gmail.com>
1 parent 43625ab commit 197224f

File tree

181 files changed

+6968
-567
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

181 files changed

+6968
-567
lines changed

src/TensorFlowNET.Core/APIs/tf.compat.cs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using System.Text;
18+
1719
namespace Tensorflow
1820
{
1921
public partial class tensorflow
@@ -23,6 +25,26 @@ public partial class tensorflow
2325
public class CompatApi
2426
{
2527
public CompatV1Api v1 { get; } = new CompatV1Api();
28+
29+
internal string as_text(string bytes_or_text, Encoding? encoding = null)
30+
{
31+
if(encoding is null) encoding = Encoding.UTF8;
32+
return bytes_or_text;
33+
}
34+
internal string as_text(byte[] bytes_or_text, Encoding? encoding = null)
35+
{
36+
if(encoding is null) encoding = Encoding.UTF8;
37+
return encoding.GetString(bytes_or_text);
38+
}
39+
40+
internal string as_str(string bytes_or_text, Encoding? encoding = null)
41+
{
42+
return as_text(bytes_or_text, encoding);
43+
}
44+
internal string as_str(byte[] bytes_or_text, Encoding? encoding = null)
45+
{
46+
return as_text(bytes_or_text, encoding);
47+
}
2648
}
2749

2850
public bool executing_eagerly()
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Diagnostics;
4+
using System.IO;
5+
using System.Linq;
6+
using Tensorflow.Train;
7+
using Tensorflow.Training;
8+
using pbc = global::Google.Protobuf.Collections;
9+
10+
namespace Tensorflow.Checkpoint;
11+
12+
public static class CheckPointUtils
13+
{
14+
private static string _ESCAPE_CHAR = ".";
15+
public static (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>, IDictionary<Trackable, int>,
16+
IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>,
17+
IDictionary<Trackable, string>) objects_ids_and_slot_variables_and_paths(ObjectGraphView graph_view)
18+
{
19+
var (trackable_objects, node_paths) = graph_view.breadth_first_traversal();
20+
Dictionary<Trackable, string> object_names = new();
21+
foreach (var pair in node_paths)
22+
{
23+
object_names[pair.Key] = TrackableUtils.object_path_to_string(pair.Value);
24+
}
25+
26+
Dictionary<Trackable, int> node_ids = new();
27+
for (int i = 0; i < trackable_objects.Count; i++)
28+
{
29+
node_ids[trackable_objects[i]] = i;
30+
}
31+
32+
var slot_variables = serialize_slot_variables(trackable_objects, node_ids, object_names);
33+
return (trackable_objects, node_paths, node_ids, slot_variables, object_names);
34+
}
35+
36+
public static
37+
IDictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
38+
serialize_slot_variables(IEnumerable<Trackable> trackable_objects,
39+
IDictionary<Trackable, int> node_ids, IDictionary<Trackable, string> object_names)
40+
{
41+
var non_slot_objects = trackable_objects.ToList();
42+
Dictionary<Trackable, pbc::RepeatedField<global::Tensorflow.TrackableObjectGraph.Types.TrackableObject.Types.SlotVariableReference>>
43+
slot_variables = new();
44+
foreach (var trackable in non_slot_objects)
45+
{
46+
if (trackable is not Optimizer)
47+
{
48+
continue;
49+
}
50+
51+
var optim = (Optimizer)trackable;
52+
var slot_names = optim.get_slot_names();
53+
foreach (var slot_name in slot_names)
54+
{
55+
for (int original_variable_node_id = 0;
56+
original_variable_node_id < non_slot_objects.Count;
57+
original_variable_node_id++)
58+
{
59+
var original_variable = non_slot_objects[original_variable_node_id];
60+
IVariableV1 slot_variable;
61+
if (original_variable is not IVariableV1)
62+
{
63+
slot_variable = null;
64+
}
65+
slot_variable = optim.get_slot((IVariableV1)original_variable, slot_name);
66+
if(slot_variable is null) continue;
67+
68+
// There're some problems about the inherits of `Variable` and `Trackable`.
69+
throw new NotImplementedException();
70+
}
71+
}
72+
}
73+
74+
return slot_variables;
75+
}
76+
77+
public static Trackable get_mapped_trackable(Trackable trackable, IDictionary<Trackable, Trackable>? object_map)
78+
{
79+
if (object_map is null || !object_map.TryGetValue(trackable, out var possible_res))
80+
{
81+
return trackable;
82+
}
83+
else
84+
{
85+
return possible_res;
86+
}
87+
}
88+
89+
public static string get_full_name(Trackable variable)
90+
{
91+
// TODO: This state is not correct, the whole framework need to be updated in the future.
92+
if (!(variable is IVariableV1 || resource_variable_ops.is_resource_variable(variable)))
93+
{
94+
return "";
95+
}
96+
// skip the check of attribute `_save_slice_info` .
97+
98+
// TODO: Need to be revised!!!
99+
Debug.Assert(variable is BaseResourceVariable);
100+
return ((BaseResourceVariable)variable).Name;
101+
}
102+
103+
public static void add_checkpoint_values_check(TrackableObjectGraph object_graph_proto)
104+
{
105+
HashSet<int> checkpointed_trackables = new();
106+
Dictionary<int, HashSet<int>> parents = new();
107+
for (int i = 0; i < object_graph_proto.Nodes.Count; i++)
108+
{
109+
var object_proto = object_graph_proto.Nodes[i];
110+
// skip the process of registered saver.
111+
if (object_proto.Attributes is not null && object_proto.Attributes.Count > 0 ||
112+
object_proto.SlotVariables is not null && object_proto.SlotVariables.Count > 0)
113+
{
114+
checkpointed_trackables.Add(i);
115+
}
116+
117+
foreach (var child_proto in object_proto.Children)
118+
{
119+
var child = child_proto.NodeId;
120+
if (!parents.ContainsKey(child))
121+
{
122+
parents[child] = new HashSet<int>();
123+
}
124+
125+
parents[child].Add(i);
126+
}
127+
}
128+
129+
Queue<int> to_visit = new(checkpointed_trackables.AsEnumerable());
130+
while (to_visit.Count > 0)
131+
{
132+
var trackable = to_visit.Dequeue();
133+
if (!parents.ContainsKey(trackable)) continue;
134+
var current_parents = parents[trackable];
135+
foreach (var parent in current_parents)
136+
{
137+
checkpointed_trackables.Add(parent);
138+
if (parents.ContainsKey(parent))
139+
{
140+
to_visit.Enqueue(parent);
141+
}
142+
}
143+
parents.Remove(trackable);
144+
}
145+
146+
// TODO: Complete it after supporting checkpoint.
147+
// for (int i = 0; i < object_graph_proto.Nodes.Count; i++)
148+
// {
149+
// object_graph_proto.Nodes[i].has_checkpoint_values.value = checkpointed_trackables.Contains(i);
150+
// }
151+
}
152+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
namespace Tensorflow.Checkpoint;
2+
3+
public record class CheckpointOptions(
4+
string? experimental_io_device = null,
5+
bool experimental_enable_async_checkpoint = false);
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using Serilog.Debugging;
5+
using Tensorflow.Keras.Saving.SavedModel;
6+
using Tensorflow.Train;
7+
8+
namespace Tensorflow.Checkpoint;
9+
10+
public class ObjectGraphView: TrackableView, ICloneable
11+
{
12+
protected IEnumerable<TrackableReference>? _attached_dependencies;
13+
// TODO: attached_dependencies
14+
public ObjectGraphView(Trackable root, IEnumerable<TrackableReference>? attached_dependencies = null): base(root)
15+
{
16+
_attached_dependencies = attached_dependencies;
17+
}
18+
19+
public object Clone()
20+
{
21+
// TODO: Implement real deep copy corresponding to tensorflow/python/checkpoint/graph_view.ObjectGraphView.__deepcopy__
22+
return new ObjectGraphView(Root, _attached_dependencies);
23+
}
24+
25+
public virtual List<TrackableReference> list_children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null)
26+
{
27+
List<TrackableReference> res = base.children(obj, save_type, serialization_cache)
28+
.Select(x => new TrackableReference(x.Key, x.Value)).ToList();
29+
// Check the reference, not value.
30+
if (obj == Root && _attached_dependencies is not null)
31+
{
32+
res.AddRange(_attached_dependencies);
33+
}
34+
35+
return res;
36+
}
37+
38+
public override IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? serialization_cache = null)
39+
{
40+
return list_children(obj, save_type, serialization_cache).ToDictionary(x => x.Name, x => x.Refer);
41+
}
42+
43+
public IEnumerable<TrackableReference>? AttachedDependencies
44+
{
45+
get => _attached_dependencies;
46+
} B610
47+
48+
public virtual (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) breadth_first_traversal()
49+
{
50+
return base._descendants_with_paths();
51+
}
52+
53+
// TODO: complete the implementation
54+
public void serialize_object_graph(object? saveables_cache = null)
55+
{
56+
throw new NotImplementedException();
57+
}
58+
59+
// TODO: complete the implementation
60+
public void frozen_saveable_objects(object? object_map = null, object? to_graph = null, object call_with_mapped_captures = null)
61+
{
62+
throw new NotImplementedException();
63+
}
64+
}

0 commit comments

Comments
 (0)
0