8000 Switch to user input stream · marcusjc/python-docs-samples@52e37a5 · GitHub
[go: up one dir, main page]

Skip to content

Commit 52e37a5

Browse files
committed
Switch to user input stream
1 parent 42645c7 commit 52e37a5

File tree

1 file changed

+10
-45
lines changed

1 file changed

+10
-45
lines changed

ml_engine/online_prediction/predict.py

Lines changed: 10 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
# License for the specific language governing permissions and limitations under
1010
# the License.
1111
"""Examples of using the Cloud ML Engine's online prediction service."""
12-
1312
# [START import_libraries]
1413
import googleapiclient.discovery
1514
# [END import_libraries]
@@ -111,57 +110,23 @@ def census_to_example_bytes(json_instance):
111110
# [END census_to_example_bytes]
112111

113112

114-
# [START predict_from_files]
115-
def predict_from_files(project,
116-
model,
117-
files,
118-
version=None,
119-
force_tfrecord=False):
113+
def main(project, model, version=None, force_tfrecord=False):
120114
import json
121-
import itertools
122-
instances = (json.loads(line)
123-
for f in files
124-
for line in f.readlines())
125-
126-
# Requests to online prediction
127-
# can have at most 100 instances
128-
args = [instances] * 100
129-
instance_batches = itertools.izip(*args)
130-
131-
results = []
132-
for batch in instance_batches:
115+
while True:
116+
user_input = json.loads(raw_input())
133117
if force_tfrecord:
134-
example_bytes_list = [
135-
census_to_example_bytes(instance)
136-
for instance in batch
137-
]
138-
results.append(predict_tf_records(
139-
project,
140-
model,
141-
example_bytes_list,
142-
version=version
143-
))
118+
example_bytes = census_to_example_bytes(user_input)
119+
result = predict_tf_records(
120+
project, model, [example_bytes], version=version)
144121
else:
145-
results.append(predict_json(
146-
project,
147-
model,
148-
batch,
149-
version=version
150-
))
151-
return results
152-
# [END predict_from_files]
122+
result = predict_json(
123+
project, model, [user_input], version=version)
124+
print(result)
153125

154126

155127
if __name__ == '__main__':
156128
import argparse
157-
import os
158129
parser = argparse.ArgumentParser()
159-
parser.add_argument(
160-
'input_files',
161-
help='File paths with examples to predict',
162-
nargs='+',
163-
type=os.path.abspath
164-
)
165130
parser.add_argument(
166131
'--project',
167132
help='Project in which the model is deployed',
@@ -186,4 +151,4 @@ def predict_from_files(project,
186151
default=False
187152
)
188153
args = parser.parse_args()
189-
predict_from_files(**args.__dict__)
154+
main(**args.__dict__)

0 commit comments

Comments
 (0)
0