8000 Move _maybe_save_assets and its helper functions to the top level. · staticfloat/tensorflow@69249af · GitHub
[go: up one dir, main page]

Skip to content

Commit 69249af

Browse files
Move _maybe_save_assets and its helper functions to the top level.
They do not depend on any state of class SavedModelBuilder. PiperOrigin-RevId: 155848011
1 parent e094f0b commit 69249af

File tree

1 file changed

+79
-77
lines changed

1 file changed

+79
-77
lines changed

tensorflow/python/saved_model/builder_impl.py

Lines changed: 79 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -96,53 +96,13 @@ def __init__(self, export_dir):
9696
# weights.
9797
self._has_saved_variables = False
9898

99-
def _asset_path_from_tensor(self, path_tensor):
100-
"""Returns the filepath value stored in constant `path_tensor`.
101-
102-
Args:
103-
path_tensor: Tensor of a file-path.
104-
105-
Returns:
106-
The string value i.e. path of the tensor, if valid.
107-
108-
Raises:
109-
TypeError if tensor does not match expected op type, dtype or value.
110-
"""
111-
if not isinstance(path_tensor, ops.Tensor):
112-
raise TypeError("Asset path tensor must be a Tensor.")
113-
if path_tensor.op.type != "Const":
114-
raise TypeError("Asset path tensor must be of type constant.")
115-
if path_tensor.dtype != dtypes.string:
116-
raise TypeError("Asset path tensor must be of dtype string.")
117-
str_values = path_tensor.op.get_attr("value").string_val
118-
if len(str_values) != 1:
119-
raise TypeError("Asset path tensor must be a scalar.")
120-
return str_values[0]
121-
122-
def _add_asset_to_collection(self, asset_filename, asset_tensor):
123-
"""Builds an asset proto and adds it to the asset collection of the graph.
124-
125-
Args:
126-
asset_filename: The filename of the asset to be added.
127-
asset_tensor: The asset tensor used to populate the tensor info of the
128-
asset proto.
129-
"""
130-
asset_proto = meta_graph_pb2.AssetFileDef()
131- 10000
asset_proto.filename = asset_filename
132-
asset_proto.tensor_info.name = asset_tensor.name
133-
134-
asset_any_proto = Any()
135-
asset_any_proto.Pack(asset_proto)
136-
ops.add_to_collection(constants.ASSETS_KEY, asset_any_proto)
137-
13899
def _save_and_write_assets(self, assets_collection_to_add=None):
139100
"""Saves asset to the meta graph and writes asset files to disk.
140101
141102
Args:
142103
assets_collection_to_add: The collection where the asset paths are setup.
143104
"""
144-
asset_source_filepath_list = self._maybe_save_assets(
145-
assets_collection_to_add)
105+
asset_source_filepath_list = _maybe_save_assets(assets_collection_to_add)
146106

147107
# Return if there are no assets to write.
148108
if len(asset_source_filepath_list) is 0:
@@ -201,42 +161,6 @@ def _add_main_op(self, main_op):
201161
raise TypeError("main_op needs to be an Operation: %r" % main_op)
202162
ops.add_to_collection(constants.MAIN_OP_KEY, main_op)
203163

204-
def _maybe_save_assets(self, assets_collection_to_add=None):
205-
"""Saves assets to the meta graph.
206-
207-
Args:
208-
assets_collection_to_add: The collection where the asset paths are setup.
209-
210-
Returns:
211-
The list of filepaths to the assets in the assets collection.
212-
213-
Raises:
214-
ValueError: Indicating an invalid filepath tensor.
215-
"""
216-
asset_source_filepath_list = []
217-
218-
if assets_collection_to_add is None:
219-
tf_logging.info("No assets to save.")
220-
return asset_source_filepath_list
221-
222-
# Iterate over the supplied asset collection, build the `AssetFile` proto
223-
# and add them to the collection with key `constants.ASSETS_KEY`, in the
224-
# graph.
225-
for asset_tensor in assets_collection_to_add:
226-
asset_source_filepath = self._asset_path_from_tensor(asset_tensor)
227-
if not asset_source_filepath:
228-
raise ValueError("Invalid asset filepath tensor %s" % asset_tensor)
229-
230-
asset_source_filename = os.path.basename(asset_source_filepath)
231-
232-
# Build `AssetFile` proto and add it to the asset collection in the graph.
233-
self._add_asset_to_collection(asset_source_filename, asset_tensor)
234-
235-
asset_source_filepath_list.append(asset_source_filepath)
236-
237-
tf_logging.info("Assets added to graph.")
238-
return asset_source_filepath_list
239-
240164
def _tag_and_add_meta_graph(self, meta_graph_def, tags, signature_def_map):
241165
"""Tags the meta graph def and adds it to the SavedModel.
242166
@@ -475,3 +399,81 @@ def save(self, as_text=False):
475399
tf_logging.info("SavedModel written to: %s", path)
476400

477401
return path
402+
403+
404+
def _maybe_save_assets(assets_collection_to_add=None):
405+
"""Saves assets to the meta graph.
406+
407+
Args:
408+
assets_collection_to_add: The collection where the asset paths are setup.
409+
410+
Returns:
411+
The list of filepaths to the assets in the assets collection.
412+
413+
Raises:
414+
ValueError: Indicating an invalid filepath tensor.
415+
"""
416+
asset_source_filepath_list = []
417+
418+
if assets_collection_to_add is None:
419+
tf_logging.info("No assets to save.")
420+
return asset_source_filepath_list
421+
422+
# Iterate over the supplied asset collection, build the `AssetFile` proto
423+
# and add them to the collection with key `constants.ASSETS_KEY`, in the
424+
# graph.
425+
for asset_tensor in assets_collection_to_add:
426+
asset_source_filepath = _asset_path_from_tensor(asset_tensor)
427+
if not asset_source_filepath:
428+
raise ValueError("Invalid asset filepath tensor %s" % asset_tensor)
429+
430+
asset_source_filename = os.path.basename(asset_source_filepath)
431+
432+
# Build `AssetFile` proto and add it to the asset collection in the graph.
433+
_add_asset_to_collection(asset_source_filename, asset_tensor)
434+
435+
asset_source_filepath_list.append(asset_source_filepath)
436+
437+
tf_logging.info("Assets added to graph.")
438+
return asset_source_filepath_list
439+
440+
441+
def _asset_path_from_tensor(path_tensor):
442+
"""Returns the filepath value stored in constant `path_tensor`.
443+
444+
Args:
445+
path_tensor: Tensor of a file-path.
446+
447+
Returns:
448+
The string value i.e. path of the tensor, if valid.
449+
450+
Raises:
451+
TypeError if tensor does not match expected op type, dtype or value.
452+
"""
453+
if not isinstance(path_tensor, ops.Tensor):
454+
raise TypeError("Asset path tensor must be a Tensor.")
455+
if path_tensor.op.type != "Const":
456+
raise TypeError("Asset path tensor must be of type constant.")
457+
if path_tensor.dtype != dtypes.string:
458+
raise TypeError("Asset path tensor must be of dtype string.")
459+
str_values = path_tensor.op.get_attr("value").string_val
460+
if len(str_values) != 1:
461+
raise TypeError("Asset path tensor must be a scalar.")
462+
return str_values[0]
463+
464+
465+
def _add_asset_to_collection(asset_filename, asset_tensor):
466+
"""Builds an asset proto and adds it to the asset collection of the graph.
467+
468+
Args:
469+
asset_filename: The filename of the asset to be added.
470+
asset_tensor: The asset tensor used to populate the tensor info of the
471+
asset proto.
472+
"""
473+
asset_proto = meta_graph_pb2.AssetFileDef()
474+
asset_proto.filename = asset_filename
475+
asset_proto.tensor_info.name = asset_tensor.name
476+
477+
asset_any_proto = Any()
478+
asset_any_proto.Pack(asset_proto)
479+
ops.add_to_collection(constants.ASSETS_KEY, asset_any_proto)

0 commit comments

Comments
 (0)
0