|
| 1 | +# Copyright 2016 Google Inc. All Rights Reserved. Licensed under the Apache |
| 2 | +# License, Version 2.0 (the "License"); you may not use this file except in |
| 3 | +# compliance with the License. You may obtain a copy of the License at |
| 4 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 5 | + |
| 6 | +# Unless required by applicable law or agreed to in writing, software |
| 7 | +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |
| 8 | +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |
| 9 | +# License for the specific language governing permissions and limitations under |
| 10 | +# the License. |
| 11 | +"""Tests for predict.py .""" |
| 12 | +import base64 |
| 13 | +import pytest |
| 14 | +from predict import predict_json, predict_tf_records, census_to_example_bytes |
| 15 | + |
| 16 | + |
| 17 | +MODEL = 'census' |
| 18 | +VERSION = 'v1' |
| 19 | +PROJECT = 'python-docs-samples-test' |
| 20 | +JSON = {'age': 25, 'workclass': ' Private', 'education': ' 11th', 'education_num': 7, 'marital_status': ' Never-married', 'occupation': ' Machine-op-inspct', 'relationship': ' Own-child', 'race': ' Black', 'gender': ' Male', 'capital_gain': 0, 'capital_loss': 0, 'hours_per_week': 40, 'native_country': ' United-States'} |
| 21 | +EXAMPLE_BYTE_STRING = 'CuoCChoKDmhvdXJzX3Blcl93ZWVrEggSBgoEAAAgQgoZCgl3b3JrY2xhc3MSDAoKCgggUHJpdmF0ZQoeCgxyZWxhdGlvbnNoaXASDgoMCgogT3duLWNoaWxkChMKBmdlbmRlchIJCgcKBSBNYWxlCg8KA2FnZRIIEgYKBAAAyEEKJAoObWFyaXRhbF9zdGF0dXMSEgoQCg4gTmV2ZXItbWFycmllZAoSCgRyYWNlEgoKCAoGIEJsYWNrChkKDWVkdWNhdGlvbl9udW0SCBIGCgQAAOBACiQKDm5hdGl2ZV9jb3VudHJ5EhIKEAoOIFVuaXRlZC1TdGF0ZXMKGAoMY2FwaXRhbF9sb3NzEggSBgoEAAAAAAoWCgllZHVjYXRpb24SCQoHCgUgMTF0aAoYCgxjYXBpdGFsX2dhaW4SCBIGCgQAAAAACiQKCm9jY3VwYXRpb24SFgoUChIgTWFjaGluZS1vcC1pbnNwY3Q=' |
| 22 | + |
| 23 | +EXPECTED_OUTPUT = {u'probabilities': [0.9942260384559631, 0.005774002522230148], u'logits': [-5.148599147796631], u'classes': 0, u'logistic': [0.005774001590907574]} |
| 24 | + |
| 25 | + |
| 26 | +def test_predict_json(): |
| 27 | + result = predict_json(PROJECT, MODEL, [JSON, JSON], version=VERSION) |
| 28 | + assert [EXPECTED_OUTPUT, EXPECTED_OUTPUT] == result |
| 29 | + |
| 30 | +def test_predict_json_error(): |
| 31 | + with pytest.raises(RuntimeError): |
| 32 | + predict_json(PROJECT, MODEL, [{"foo": "bar"}], version=VERSION) |
| 33 | + |
| 34 | +def test_census_example_to_bytes(): |
| 35 | + b = census_to_example_bytes(JSON) |
| 36 | + assert EXAMPLE_BYTE_STRING == base64.b64encode(b) |
0 commit comments