8000 Fix user input loop · lauro-cesar/python-docs-samples@71daf1f · GitHub
[go: up one dir, main page]

Skip to content

Commit 71daf1f

Browse files
committed
Fix user input loop
1 parent 52e37a5 commit 71daf1f

File tree

2 files changed

+22
-9
lines changed

2 files changed

+22
-9
lines changed

ml_engine/online_prediction/predict.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# License for the specific language governing permissions and limitations under
1010
# the License.
1111
"""Examples of using the Cloud ML Engine's online prediction service."""
12+
from __future__ import print_function
1213
# [START import_libraries]
1314
import googleapiclient.discovery
1415
# [END import_libraries]
@@ -113,15 +114,28 @@ def census_to_example_bytes(json_instance):
113114
def main(project, model, version=None, force_tfrecord=False):
114115
import json
115116
while True:
116-
user_input = json.loads(raw_input())
117-
if force_tfrecord:
118-
example_bytes = census_to_example_bytes(user_input)
119-
result = predict_tf_records(
120-
project, model, [example_bytes], version=version)
117+
try:
118+
user_input = json.loads(raw_input("Valid JSON >>>"))
119+
except KeyboardInterrupt:
120+
return
121+
122+
if not isinstance(user_input, list):
123+
user_input = [user_input]
124+
try:
125+
if force_tfrecord:
126+
example_bytes_list = [
127+
census_to_example_bytes(e)
128+
for e in user_input
129+
]
130+
result = predict_tf_records(
131+
project, model, example_bytes_list, version=version)
132+
else:
133+
result = predict_json(
134+
project, model, user_input, version=version)
135+
except RuntimeError as err:
136+
print(str(err))
121137
else:
122-
result = predict_json(
123-
project, model, [user_input], version=version)
124-
print(result)
138+
print(result)
125139

126140

127141
if __name__ == '__main__':

ml_engine/online_prediction/resources/test.json

Lines changed: 0 additions & 1 deletion
This file was deleted.

0 commit comments

Comments
 (0)
0