|
9 | 9 | # License for the specific language governing permissions and limitations under
|
10 | 10 | # the License.
|
11 | 11 | """Examples of using the Cloud ML Engine's online prediction service."""
|
| 12 | +from __future__ import print_function |
12 | 13 | # [START import_libraries]
|
13 | 14 | import googleapiclient.discovery
|
14 | 15 | # [END import_libraries]
|
@@ -113,15 +114,28 @@ def census_to_example_bytes(json_instance):
|
113 | 114 | def main(project, model, version=None, force_tfrecord=False):
|
114 | 115 | import json
|
115 | 116 | while True:
|
116 |
| - user_input = json.loads(raw_input()) |
117 |
| - if force_tfrecord: |
118 |
| - example_bytes = census_to_example_bytes(user_input) |
119 |
| - result = predict_tf_records( |
120 |
| - project, model, [example_bytes], version=version) |
| 117 | + try: |
| 118 | + user_input = json.loads(raw_input("Valid JSON >>>")) |
| 119 | + except KeyboardInterrupt: |
| 120 | + return |
| 121 | + |
| 122 | + if not isinstance(user_input, list): |
| 123 | + user_input = [user_input] |
| 124 | + try: |
| 125 | + if force_tfrecord: |
| 126 | + example_bytes_list = [ |
| 127 | + census_to_example_bytes(e) |
| 128 | + for e in user_input |
| 129 | + ] |
| 130 | + result = predict_tf_records( |
| 131 | + project, model, example_bytes_list, version=version) |
| 132 | + else: |
| 133 | + result = predict_json( |
| 134 | + project, model, user_input, version=version) |
| 135 | + except RuntimeError as err: |
| 136 | + print(str(err)) |
121 | 137 | else:
|
122 |
| - result = predict_json( |
123 |
| - project, model, [user_input], version=version) |
124 |
| - print(result) |
| 138 | + print(result) |
125 | 139 |
|
126 | 140 |
|
127 | 141 | if __name__ == '__main__':
|
|
0 commit comments