-
Notifications
You must be signed in to change notification settings - Fork 6.6k
Add custom prediction routine samples for AI Platform #2121
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
tswast
merged 2 commits into
GoogleCloudPlatform:master
from
alecglassford:ai-platform-custom-code
Apr 25, 2019
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Custom prediction routines (beta) | ||
|
||
Read the AI Platform documentation about custom prediction routines to learn how | ||
to use these samples: | ||
|
||
* [Custom prediction routines (with a TensorFlow Keras | ||
example)](https://cloud.google.com/ml-engine/docs/tensorflow/custom-prediction-routines) | ||
* [Custom prediction routines (with a scikit-learn | ||
example)](https://cloud.google.com/ml-engine/docs/scikit/custom-prediction-routines) | ||
|
||
If you want to package a predictor directly from this directory, make sure to | ||
edit `setup.py`: replace the reference to `predictor.py` with either | ||
`tensorflow-predictor.py` or `scikit-predictor.py`. | ||
|
||
## What's next | ||
|
||
For a more complete example of how to train and deploy a custom prediction | ||
routine, check out one of the following tutorials: | ||
|
||
* [Creating a custom prediction routine with | ||
Keras](https://cloud.google.com/ml-engine/docs/tensorflow/custom-prediction-routine-keras) | ||
(also available as [a Jupyter | ||
notebook](https://colab.research.google.com/github/GoogleCloudPlatform/cloudml-samples/blob/master/notebooks/tensorflow/custom-prediction-routine-keras.ipynb)) | ||
|
||
* [Creating a custom prediction routine with | ||
scikit-learn](https://cloud.google.com/ml-engine/docs/scikit/custom-prediction-routine-scikit-learn) | ||
(also available as [a Jupyter | ||
notebook](https://colab.research.google.com/github/GoogleCloudPlatform/cloudml-samples/blob/master/notebooks/scikit-learn/custom-prediction-routine-scikit-learn.ipynb)) |
50 changes: 50 additions & 0 deletions
50
ml_engine/custom-prediction-routines/predictor-interface.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Copyright 2019 Google LLC | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
class Predictor(object): | ||
"""Interface for constructing custom predictors.""" | ||
|
||
def predict(self, instances, **kwargs): | ||
"""Performs custom prediction. | ||
|
||
Instances are the decoded values from the request. They have already | ||
been deserialized from JSON. | ||
|
||
Args: | ||
instances: A list of prediction input instances. | ||
**kwargs: A dictionary of keyword args provided as additional | ||
fields on the predict request body. | ||
|
||
Returns: | ||
A list of outputs containing the prediction results. This list must | ||
be JSON serializable. | ||
""" | ||
raise NotImplementedError() | ||
|
||
@classmethod | ||
def from_path(cls, model_dir): | ||
"""Creates an instance of Predictor using the given path. | ||
|
||
Loading of the predictor should be done in this method. | ||
|
||
Args: | ||
model_dir: The local directory that contains the exported model | ||
file along with any additional files uploaded when creating the | ||
version resource. | ||
|
||
Returns: | ||
An instance implementing this Predictor class. | ||
""" | ||
raise NotImplementedError() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Copyright 2019 Google LLC | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import numpy as np | ||
|
||
|
||
class ZeroCenterer(object): | ||
"""Stores means of each column of a matrix and uses them for preprocessing. | ||
""" | ||
|
||
def __init__(self): | ||
"""On initialization, is not tied to any distribution.""" | ||
self._means = None | ||
|
||
def preprocess(self, data): | ||
"""Transforms a matrix. | ||
|
||
The first time this is called, it stores the means of each column of | ||
the input. Then it transforms the input so each column has mean 0. For | ||
subsequent calls, it subtracts the stored means from each column. This | ||
lets you 'center' data at prediction time based on the distribution of | ||
the original training data. | ||
|
||
Args: | ||
data: A NumPy matrix of numerical data. | ||
|
||
Returns: | ||
A transformed matrix with the same dimensions as the input. | ||
""" | ||
if self._means is None: # during training only | ||
self._means = np.mean(data, axis=0) | ||
return data - self._means |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# Copyright 2019 Google LLC | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
import pickle | ||
|
||
import numpy as np | ||
from sklearn.externals import joblib | ||
|
||
|
||
class MyPredictor(object): | ||
"""An example Predictor for an AI Platform custom prediction routine.""" | ||
|
||
def __init__(self, model, preprocessor): | ||
"""Stores artifacts for prediction. Only initialized via `from_path`. | ||
""" | ||
self._model = model | ||
self._preprocessor = preprocessor | ||
|
||
de 8000 f predict(self, instances, **kwargs): | ||
"""Performs custom prediction. | ||
|
||
Preprocesses inputs, then performs prediction using the trained | ||
scikit-learn model. | ||
|
||
Args: | ||
instances: A list of prediction input instances. | ||
**kwargs: A dictionary of keyword args provided as additional | ||
fields on the predict request body. | ||
|
||
Returns: | ||
A list of outputs containing the prediction results. | ||
""" | ||
inputs = np.asarray(instances) | ||
preprocessed_inputs = self._preprocessor.preprocess(inputs) | ||
outputs = self._model.predict(preprocessed_inputs) | ||
return outputs.tolist() | ||
|
||
@classmethod | ||
def from_path(cls, model_dir): | ||
"""Creates an instance of MyPredictor using the given path. | ||
|
||
This loads artifacts that have been copied from your model directory in | ||
Cloud Storage. MyPredictor uses them during prediction. | ||
|
||
Args: | ||
model_dir: The local directory that contains the trained | ||
scikit-learn model and the pickled preprocessor instance. These | ||
are copied from the Cloud Storage model directory you provide | ||
when you deploy a version resource. | ||
|
||
Returns: | ||
An instance of `MyPredictor`. | ||
""" | ||
model_path = os.path.join(model_dir, 'model.joblib') | ||
model = joblib.load(model_path) | ||
|
||
preprocessor_path = os.path.join(model_dir, 'preprocessor.pkl') | ||
with open(preprocessor_path, 'rb') as f: | ||
preprocessor = pickle.load(f) | ||
|
||
return cls(model, preprocessor) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Copyright 2019 Google LLC | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from setuptools import setup | ||
|
||
setup( | ||
name='my_custom_code', | ||
version='0.1', | ||
scripts=['predictor.py', 'preprocess.py']) |
73 changes: 73 additions & 0 deletions
73
ml_engine/custom-prediction-routines/tensorflow-predictor.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# Copyright 2019 Google LLC | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
import pickle | ||
|
||
import numpy as np | ||
from tensorflow import keras | ||
|
||
|
||
class MyPredictor(object): | ||
"""An example Predictor for an AI Platform custom prediction routine.""" | ||
|
||
def __init__(self, model, preprocessor): | ||
"""Stores artifacts for prediction. Only initialized via `from_path`. | ||
""" | ||
self._model = model | ||
self._preprocessor = preprocessor | ||
|
||
def predict(self, instances, **kwargs): | ||
"""Performs custom prediction. | ||
|
||
Preprocesses inputs, then performs prediction using the trained Keras | ||
model. | ||
|
||
Args: | ||
instances: A list of prediction input instances. | ||
**kwargs: A dictionary of keyword args provided as additional | ||
fields on the predict request body. | ||
|
||
Returns: | ||
A list of outputs containing the prediction results. | ||
""" | ||
inputs = np.asarray(instances) | ||
preprocessed_inputs = self._preprocessor.preprocess(inputs) | ||
outputs = self._model.predict(preprocessed_inputs) | ||
return outputs.tolist() | ||
|
||
@classmethod | ||
def from_path(cls, model_dir): | ||
"""Creates an instance of MyPredictor using the given path. | ||
|
||
This loads artifacts that have been copied from your model directory in | ||
Cloud Storage. MyPredictor uses them during prediction. | ||
|
||
Args: | ||
model_dir: The local directory that contains the trained Keras | ||
model and the pickled preprocessor instance. These are copied | ||
from the Cloud Storage model directory you provide when you | ||
deploy a version resource. | ||
|
||
Returns: | ||
An instance of `MyPredictor`. | ||
""" | ||
model_path = os.path.join(model_dir, 'model.h5') | ||
model = keras.models.load_model(model_path) | ||
|
||
preprocessor_path = os.path.join(model_dir, 'preprocessor.pkl') | ||
with open(preprocessor_path, 'rb') as f: | ||
preprocessor = pickle.load(f) | ||
|
||
return cls(model, preprocessor) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must
2CFB
change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"Notebook not found." Is this in a separate PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes! GoogleCloudPlatform/cloudml-samples#406