9
9
# License for the specific language governing permissions and limitations under
10
10
# the License.
11
11
"""Examples of using the Cloud ML Engine's online prediction service."""
12
-
13
12
# [START import_libraries]
14
13
import googleapiclient .discovery
15
14
# [END import_libraries]
@@ -111,57 +110,23 @@ def census_to_example_bytes(json_instance):
111
110
# [END census_to_example_bytes]
112
111
113
112
114
- # [START predict_from_files]
115
- def predict_from_files (project ,
116
- model ,
117
- files ,
118
- version = None ,
119
- force_tfrecord = False ):
113
+ def main (project , model , version = None , force_tfrecord = False ):
120
114
import json
121
- import itertools
122
- instances = (json .loads (line )
123
- for f in files
124
- for line in f .readlines ())
125
-
126
- # Requests to online prediction
127
- # can have at most 100 instances
128
- args = [instances ] * 100
129
- instance_batches = itertools .izip (* args )
130
-
131
- results = []
132
- for batch in instance_batches :
115
+ while True :
116
+ user_input = json .loads (raw_input ())
133
117
if force_tfrecord :
134
- example_bytes_list = [
135
- census_to_example_bytes (instance )
136
- for instance in batch
137
- ]
138
- results .append (predict_tf_records (
139
- project ,
140
- model ,
141
- example_bytes_list ,
142
- version = version
143
- ))
118
+ example_bytes = census_to_example_bytes (user_input )
119
+ result = predict_tf_records (
120
+ project , model , [example_bytes ], version = version )
144
121
else :
145
- results .append (predict_json (
146
- project ,
147
- model ,
148
- batch ,
149
- version = version
150
- ))
151
- return results
152
- # [END predict_from_files]
122
+ result = predict_json (
123
+ project , model , [user_input ], version = version )
124
+ print (result )
153
125
154
126
155
127
if __name__ == '__main__' :
156
128
import argparse
157
- import os
158
129
parser = argparse .ArgumentParser ()
159
- parser .add_argument (
160
- 'input_files' ,
161
- help = 'File paths with examples to predict' ,
162
- nargs = '+' ,
163
- type = os .path .abspath
164
- )
165
130
parser .add_argument (
166
131
'--project' ,
167
132
help = 'Project in which the model is deployed' ,
@@ -186,4 +151,4 @@ def predict_from_files(project,
186
151
default = False
187
152
)
188
153
args = parser .parse_args ()
189
- predict_from_files (** args .__dict__ )
154
+ main (** args .__dict__ )
0 commit comments