8000 Initial commit of prediction examples · lauro-cesar/python-docs-samples@42645c7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 42645c7

Browse files
committed
Initial commit of prediction examples
1 parent 5dd9b31 commit 42645c7

File tree

4 files changed

+116
-1546
lines changed

4 files changed

+116
-1546
lines changed
Lines changed: 116 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,32 @@
1-
import argparse
2-
import json
1+
# Copyright 2016 Google Inc. All Rights Reserved. Licensed under the Apache
2+
# License, Version 2.0 (the "License"); you may not use this file except in
3+
# compliance with the License. You may obtain a copy of the License at
4+
# http://www.apache.org/licenses/LICENSE-2.0
5+
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
8+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
9+
# License for the specific language governing permissions and limitations under
10+
# the License.
11+
"""Examples of using the Cloud ML Engine's online prediction service."""
12+
313
# [START import_libraries]
414
import googleapiclient.discovery
515
# [END import_libraries]
616

17+
718
# [START authenticating]
def get_ml_engine_service():
    """Build a client for the Cloud ML Engine ('ml', v1beta1) API.

    Relies on application default credentials being discoverable in the
    environment (see the Google API client docs for how those are resolved).

    Returns:
        A googleapiclient service object for issuing ML Engine requests.
    """
    service = googleapiclient.discovery.build('ml', 'v1beta1')
    return service
# [END authenticating]
1222

23+
1324
# [START predict_json]
1425
def predict_json(project, model, instances, version=None):
1526
"""Send data instances to a deployed model for prediction
1627
Args:
17-
project: str, project where the Cloud ML Engine Model is deployed
18-
model: str, model name
28+
project: str, project where the Cloud ML Engine Model is deployed.
29+
model: str, model name.
1930
instances: [dict], dictionaries from string keys defined by the model
2031
to data.
2132
version: [optional] str, version of the model to target.
@@ -26,10 +37,10 @@ def predict_json(project, model, instances, version=None):
2637
name = 'projects/{}/models/{}'.format(project, model)
2738
if version is not None:
2839
name += '/versions/{}'.format(version)
29-
40+
3041
response = service.projects().predict(
3142
name=name,
32-
body={"instances": instances}
43+
body={'instances': instances}
3344
).execute()
3445

3546
if 'error' in response:
@@ -38,15 +49,19 @@ def predict_json(project, model, instances, version=None):
3849
return response['predictions']
3950
# [END predict_json]
4051

52+
4153
# [START predict_tf_records]
42-
def predict_tf_records(project, model, example_bytes_list, key='tfrecord', version=None):
54+
def predict_tf_records(project,
55+
model,
56+
example_bytes_list,
57+
key='tfrecord',
58+
version=None):
4359
"""Send data instances to a deployed model for prediction
4460
Args:
4561
project: str, project where the Cloud ML Engine Model is deployed
46-
model: str, model name
47-
feature_dict_list: [dict], dictionaries from string keys to
48-
tf.train.Feature protos.
49-
version: [optional] str, version of the model to target.
62+
model: str, model name.
63+
example_bytes_list: [str], Serialized tf.train.Example protos.
64+
version: str, version of the model to target.
5065
Returns:
5166
A dictionary of prediction results defined by the model.
5267
"""
@@ -58,7 +73,7 @@ def predict_tf_records(project, model, example_bytes_list, key='tfrecord', versi
5873

5974
response = service.projects().predict(
6075
name=name,
61-
body={"instances": [
76+
body={'instances': [
6277
{key: {'b64': base64.b64encode(example_bytes)}}
6378
for example_bytes in example_bytes_list
6479
]}
@@ -67,8 +82,18 @@ def predict_tf_records(project, model, example_bytes_list, key='tfrecord', versi
6782
raise RuntimeError(response['error'])
6883

6984
return response['predictions']
85+
# [END predict_tf_records]
86+
7087

88+
# [START census_to_example_bytes]
7189
def census_to_example_bytes(json_instance):
90+
"""Serialize a JSON example to the bytes of a tf.train.Example.
91+
This method is specific to the signature of the Census example.
92+
Args:
93+
json_instance: dict, representing data to be serialized.
94+
Returns:
95+
A string (as a container for bytes).
96+
"""
7297
import tensorflow as tf
7398
feature_dict = {}
7499
for key, data in json_instance.iteritems():
@@ -83,18 +108,82 @@ def census_to_example_bytes(json_instance):
83108
feature=feature_dict
84109
)
85110
).SerializeToString()
86-
# [END predict_tf_records]
111+
# [END census_to_example_bytes]
87112

88-
if __name__=='__main__':
89-
import sys
90-
import base64
113+
114+
# [START predict_from_files]
def predict_from_files(project,
                       model,
                       files,
                       version=None,
                       force_tfrecord=False):
    """Read JSON instances from files and send them for prediction in batches.

    Args:
        project: str, project where the Cloud ML Engine Model is deployed.
        model: str, model name.
        files: [file], open file objects with one JSON instance per line.
        version: [optional] str, version of the model to target.
        force_tfrecord: bool, if True each instance is serialized to a
            tf.train.Example proto (via census_to_example_bytes) before
            being sent; otherwise raw JSON instances are sent.
    Returns:
        A list with one entry of prediction results per batch sent.
    """
    import json
    import itertools
    instances = (json.loads(line)
                 for f in files
                 for line in f)

    # Requests to online prediction can have at most 100 instances.
    # NOTE: the previous izip(*[instances] * 100) grouper silently dropped
    # the final partial batch; islice keeps the remainder.
    results = []
    while True:
        batch = list(itertools.islice(instances, 100))
        if not batch:
            break
        if force_tfrecord:
            example_bytes_list = [
                census_to_example_bytes(instance)
                for instance in batch
            ]
            results.append(predict_tf_records(
                project,
                model,
                example_bytes_list,
                version=version
            ))
        else:
            results.append(predict_json(
                project,
                model,
                batch,
                version=version
            ))
    return results
# [END predict_from_files]
153+
154+
155+
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'input_files',
        help='File paths with examples to predict',
        nargs='+',
        # predict_from_files iterates lines of each entry, so it needs open
        # file objects, not path strings: open them here.
        type=argparse.FileType('r')
    )
    parser.add_argument(
        '--project',
        help='Project in which the model is deployed',
        type=str,
        required=True
    )
    parser.add_argument(
        '--model',
        help='Model name',
        type=str,
        required=True
    )
    parser.add_argument(
        '--version',
        help='Name of the version.',
        type=str
    )
    parser.add_argument(
        '--force-tfrecord',
        help='Send predictions as TFRecords rather than raw JSON',
        action='store_true',
        default=False
    )
    args = parser.parse_args()
    # Map arguments explicitly: predict_from_files(**args.__dict__) would
    # raise TypeError because the namespace key is 'input_files' while the
    # function parameter is named 'files'.
    predict_from_files(
        args.project,
        args.model,
        args.input_files,
        version=args.version,
        force_tfrecord=args.force_tfrecord
    )

0 commit comments

Comments
 (0)
0