1
- import argparse
2
- import json
1
+ # Copyright 2016 Google Inc. All Rights Reserved. Licensed under the Apache
2
+ # License, Version 2.0 (the "License"); you may not use this file except in
3
+ # compliance with the License. You may obtain a copy of the License at
4
+ # http://www.apache.org/licenses/LICENSE-2.0
5
+
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
8
+ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
9
+ # License for the specific language governing permissions and limitations under
10
+ # the License.
11
+ """Examples of using the Cloud ML Engine's online prediction service."""
12
+
3
13
# [START import_libraries]
4
14
import googleapiclient .discovery
5
15
# [END import_libraries]
6
16
17
+
7
18
# [START authenticating]
def get_ml_engine_service():
    """Build an authenticated client for the Cloud ML Engine API.

    Credentials are discovered automatically by the Google API client
    library (application default credentials).

    Returns:
        A googleapiclient service object for the 'ml' API, v1beta1.
    """
    service = googleapiclient.discovery.build('ml', 'v1beta1')
    return service
# [END authenticating]
12
22
23
+
13
24
# [START predict_json]
14
25
def predict_json (project , model , instances , version = None ):
15
26
"""Send data instances to a deployed model for prediction
16
27
Args:
17
- project: str, project where the Cloud ML Engine Model is deployed
18
- model: str, model name
28
+ project: str, project where the Cloud ML Engine Model is deployed.
29
+ model: str, model name.
19
30
instances: [dict], dictionaries from string keys defined by the model
20
31
to data.
21
32
version: [optional] str, version of the model to target.
@@ -26,10 +37,10 @@ def predict_json(project, model, instances, version=None):
26
37
name = 'projects/{}/models/{}' .format (project , model )
27
38
if version is not None :
28
39
name += '/versions/{}' .format (version )
29
-
40
+
30
41
response = service .projects ().predict (
31
42
name = name ,
32
- body = {" instances" : instances }
43
+ body = {' instances' : instances }
33
44
).execute ()
34
45
35
46
if 'error' in response :
@@ -38,15 +49,19 @@ def predict_json(project, model, instances, version=None):
38
49
return response ['predictions' ]
39
50
# [END predict_json]
40
51
52
+
41
53
# [START predict_tf_records]
42
- def predict_tf_records (project , model , example_bytes_list , key = 'tfrecord' , version = None ):
54
+ def predict_tf_records (project ,
55
+ model ,
56
+ example_bytes_list ,
57
+ key = 'tfrecord' ,
58
+ version = None ):
43
59
"""Send data instances to a deployed model for prediction
44
60
Args:
45
61
project: str, project where the Cloud ML Engine Model is deployed
46
- model: str, model name
47
- feature_dict_list: [dict], dictionaries from string keys to
48
- tf.train.Feature protos.
49
- version: [optional] str, version of the model to target.
62
+ model: str, model name.
63
+ example_bytes_list: [str], Serialized tf.train.Example protos.
64
+ version: str, version of the model to target.
50
65
Returns:
51
66
A dictionary of prediction results defined by the model.
52
67
"""
@@ -58,7 +73,7 @@ def predict_tf_records(project, model, example_bytes_list, key='tfrecord', versi
58
73
59
74
response = service .projects ().predict (
60
75
name = name ,
61
- body = {" instances" : [
76
+ body = {' instances' : [
62
77
{key : {'b64' : base64 .b64encode (example_bytes )}}
63
78
for example_bytes in example_bytes_list
64
79
]}
@@ -67,8 +82,18 @@ def predict_tf_records(project, model, example_bytes_list, key='tfrecord', versi
67
82
raise RuntimeError (response ['error' ])
68
83
69
84
return response ['predictions' ]
85
+ # [END predict_tf_records]
86
+
70
87
88
+ # [START census_to_example_bytes]
71
89
def census_to_example_bytes (json_instance ):
90
+ """Serialize a JSON example to the bytes of a tf.train.Example.
91
+ This method is specific to the signature of the Census example.
92
+ Args:
93
+ json_instance: dict, representing data to be serialized.
94
+ Returns:
95
+ A string (as a container for bytes).
96
+ """
72
97
import tensorflow as tf
73
98
feature_dict = {}
74
99
for key , data in json_instance .iteritems ():
@@ -83,18 +108,82 @@ def census_to_example_bytes(json_instance):
83
108
feature = feature_dict
84
109
)
85
110
).SerializeToString ()
86
- # [END predict_tf_records ]
111
+ # [END census_to_example_bytes ]
87
112
88
- if __name__ == '__main__' :
89
- import sys
90
- import base64
113
+
114
# [START predict_from_files]
def predict_from_files(project,
                       model,
                       files,
                       version=None,
                       force_tfrecord=False):
    """Read JSON instances from files and send them for prediction in batches.

    Args:
        project: str, project where the Cloud ML Engine Model is deployed.
        model: str, model name.
        files: [file], open file objects with one JSON instance per line.
        version: [optional] str, version of the model to target.
        force_tfrecord: bool, if True serialize each instance to a
            tf.train.Example proto before sending; otherwise send raw JSON.

    Returns:
        [dict], one prediction response per batch sent to the service.
    """
    import itertools
    import json

    instances = (json.loads(line)
                 for f in files
                 for line in f)

    # Requests to online prediction can have at most 100 instances, so
    # send the data in batches of at most 100. islice-based batching
    # keeps the final, possibly-short batch -- unlike zip(*[it] * 100),
    # which silently drops any remainder smaller than the batch size.
    batch_size = 100

    results = []
    while True:
        batch = list(itertools.islice(instances, batch_size))
        if not batch:
            break
        if force_tfrecord:
            example_bytes_list = [
                census_to_example_bytes(instance)
                for instance in batch
            ]
            results.append(predict_tf_records(
                project,
                model,
                example_bytes_list,
                version=version
            ))
        else:
            results.append(predict_json(
                project,
                model,
                batch,
                version=version
            ))
    return results
# [END predict_from_files]
153
+
154
+
155
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    # The positional is named 'files' (not 'input_files') so that the
    # parsed namespace matches predict_from_files's keyword arguments,
    # and FileType('r') opens each path -- predict_from_files reads
    # lines from open file objects, not from path strings.
    parser.add_argument(
        'files',
        help='File paths with examples to predict',
        nargs='+',
        type=argparse.FileType('r')
    )
    parser.add_argument(
        '--project',
        help='Project in which the model is deployed',
        type=str,
        required=True
    )
    parser.add_argument(
        '--model',
        help='Model name',
        type=str,
        required=True
    )
    parser.add_argument(
        '--version',
        help='Name of the version.',
        type=str
    )
    parser.add_argument(
        '--force-tfrecord',
        help='Send predictions as TFRecords rather than raw JSON',
        action='store_true',
        default=False
    )
    args = parser.parse_args()
    predict_from_files(**vars(args))
0 commit comments