8000 added support for automl-models (#428) · ClarkDing/firebase-admin-python@c4275be · GitHub
[go: up one dir, main page]

Skip to content

Commit c4275be

Browse files
authored
added support for automl-models (firebase#428)
* added support for automl-models
1 parent e49add8 co
8000
mmit c4275be

File tree

2 files changed

+114
-8
lines changed

2 files changed

+114
-8
lines changed

firebase_admin/ml.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@
5252
_TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$')
5353
_GCS_TFLITE_URI_PATTERN = re.compile(
5454
r'^gs://(?P<bucket_name>[a-z0-9_.-]{3,63})/(?P<blob_name>.+)$')
55+
_AUTO_ML_MODEL_PATTERN = re.compile(
56+
r'^projects/(?P<project_id>[a-z0-9-]{6,30})/locations/(?P<location_id>[^/]+)/' +
57+
r'models/(?P<model_id>[A-Za-z0-9]+)$')
5558
_RESOURCE_NAME_PATTERN = re.compile(
5659
r'^projects/(?P<project_id>[a-z0-9-]{6,30})/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$')
5760
_OPERATION_NAME_PATTERN = re.compile(
@@ -362,15 +365,10 @@ def __init__(self, model_source=None):
362365
def from_dict(cls, data):
363366
"""Create an instance of the object from a dict."""
364367
data_copy = dict(data)
365-
model_source = None
366-
gcs_tflite_uri = data_copy.pop('gcsTfliteUri', None)
367-
if gcs_tflite_uri:
368-
model_source = TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri)
369-
tflite_format = TFLiteFormat(model_source=model_source)
368+
tflite_format = TFLiteFormat(model_source=cls._init_model_source(data_copy))
370369
tflite_format._data = data_copy # pylint: disable=protected-access
371370
return tflite_format
372371

373-
374372
def __eq__(self, other):
375373
if isinstance(other, self.__class__):
376374
# pylint: disable=protected-access
@@ -380,6 +378,16 @@ def __eq__(self, other):
380378
def __ne__(self, other):
381379
return not self.__eq__(other)
382380

381+
@staticmethod
382+
def _init_model_source(data):
383+
gcs_tflite_uri = data.pop('gcsTfliteUri', None)
384+
if gcs_tflite_uri:
385+
return TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri)
386+
auto_ml_model = data.pop('automlModel', None)
387+
if auto_ml_model:
388+
return TFLiteAutoMlSource(auto_ml_model=auto_ml_model)
389+
return None
390+
383391
@property
384392
def model_source(self):
385393
"""The TF Lite model's location."""
@@ -592,6 +600,36 @@ def as_dict(self, for_upload=False):
592600
return {'gcsTfliteUri': self._gcs_tflite_uri}
593601

594602

603+
class TFLiteAutoMlSource(TFLiteModelSource):
604+
"""TFLite model source representing a tflite model created via AutoML."""
605+
606+
def __init__(self, auto_ml_model, app=None):
607+
self._app = app
608+
self.auto_ml_model = auto_ml_model
609+
610+
def __eq__(self, other):
611+
if isinstance(other, self.__class__):
612+
return self.auto_ml_model == other.auto_ml_model
613+
return False
614+
615+
def __ne__(self, other):
616+
return not self.__eq__(other)
617+
618+
@property
619+
def auto_ml_model(self):
620+
"""Resource name of the model created by the AutoML API."""
621+
return self._auto_ml_model
622+
623+
@auto_ml_model.setter
624+
def auto_ml_model(self, auto_ml_model):
625+
self._auto_ml_model = _validate_auto_ml_model(auto_ml_model)
626+
627+
def as_dict(self, for_upload=False):
628+
"""Returns a serializable representation of the object."""
629+
# Upload is irrelevant for auto_ml models
630+
return {'automlModel': self._auto_ml_model}
631+
632+
595633
class ListModelsPage:
596634
"""Represents a page of models in a firebase project.
597635
@@ -739,6 +777,11 @@ def _validate_gcs_tflite_uri(uri):
739777
raise ValueError('GCS TFLite URI format is invalid.')
740778
return uri
741779

780+
def _validate_auto_ml_model(model):
781+
if not _AUTO_ML_MODEL_PATTERN.match(model):
782+
raise ValueError('Model resource name format is invalid.')
783+
return model
784+
742785

743786
def _validate_model_format(model_format):
744787
if not isinstance(model_format, ModelFormat):

tests/test_ml.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,18 @@
120120
}
121121
TFLITE_FORMAT_2 = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2)
122122

123+
AUTOML_MODEL_NAME = 'projects/111111111111/locations/us-central1/models/ICN7683346839371803263'
124+
AUTOML_MODEL_SOURCE = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME)
125+
TFLITE_FORMAT_JSON_3 = {
126+
'automlModel': AUTOML_MODEL_NAME,
127+
'sizeBytes': '3456789'
128+
}
129+
TFLITE_FORMAT_3 = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_3)
130+
131+
AUTOML_MODEL_NAME_2 = 'projects/2222222222/locations/us-central1/models/ICN2222222222222222222'
132+
AUTOML_MODEL_NAME_JSON_2 = {'automlModel': AUTOML_MODEL_NAME_2}
133+
AUTOML_MODEL_SOURCE_2 = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME_2)
134+
123135
CREATED_UPDATED_MODEL_JSON_1 = {
124136
'name': MODEL_NAME_1,
125137
'displayName': DISPLAY_NAME_1,
@@ -403,7 +415,15 @@ def test_model_keyword_based_creation_and_setters(self):
403415
'tfliteModel': TFLITE_FORMAT_JSON_2
404416
}
405417

406-
def test_model_format_source_creation(self):
418+
model.model_format = TFLITE_FORMAT_3
419+
assert model.as_dict() == {
420+
'displayName': DISPLAY_NAME_2,
421+
'tags': TAGS_2,
422+
'tfliteModel': TFLITE_FORMAT_JSON_3
423+
}
424+
425+
426+
def test_gcs_tflite_model_format_source_creation(self):
407427
model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI)
408428
model_format = ml.TFLiteFormat(model_source=model_source)
409429
model = ml.Model(display_name=DISPLAY_NAME_1, model_format=model_format)
@@ -414,19 +434,37 @@ def test_model_format_source_creation(self):
414434
}
415435
}
416436

437+
def test_auto_ml_tflite_model_format_source_creation(self):
438+
model_source = ml.TFLiteAutoMlSource(auto_ml_model=AUTOML_MODEL_NAME)
439+
model_format = ml.TFLiteFormat(model_source=model_source)
440+
model = ml.Model(display_name=DISPLAY_NAME_1, model_format=model_format)
441+
assert model.as_dict() == {
442+
'displayName': DISPLAY_NAME_1,
443+
'tfliteModel': {
444+
'automlModel': AUTOML_MODEL_NAME
445+
}
446+
}
447+
417448
def test_source_creation_from_tflite_file(self):
418449
model_source = ml.TFLiteGCSModelSource.from_tflite_model_file(
419450
"my_model.tflite", "my_bucket")
420451
assert model_source.as_dict() == {
421452
'gcsTfliteUri': 'gs://my_bucket/Firebase/ML/Models/my_model.tflite'
422453
}
423454

424-
def test_model_source_setters(self):
455+
def test_gcs_tflite_model_source_setters(self):
425456
model_source = ml.TFLiteGCSModelSource(GCS_TFLITE_URI)
426457
model_source.gcs_tflite_uri = GCS_TFLITE_URI_2
427458
assert model_source.gcs_tflite_uri == GCS_TFLITE_URI_2
428459
assert model_source.as_dict() == GCS_TFLITE_URI_JSON_2
429460

461+
def test_auto_ml_tflite_model_source_setters(self):
462+
model_source = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME)
463+
model_source.auto_ml_model = AUTOML_MODEL_NAME_2
464+
assert model_source.auto_ml_model == AUTOML_MODEL_NAME_2
465+
assert model_source.as_dict() == AUTOML_MODEL_NAME_JSON_2
466+
467+
430468
def test_model_format_setters(self):
431469
model_format = ml.TFLiteFormat(model_source=GCS_TFLITE_MODEL_SOURCE)
432470
model_format.model_source = GCS_TFLITE_MODEL_SOURCE_2
@@ -437,6 +475,14 @@ def test_model_format_setters(self):
437475
}
438476
}
439477

478+
model_format.model_source = AUTOML_MODEL_SOURCE
479+
assert model_format.model_source == AUTOML_MODEL_SOURCE
480+
assert model_format.as_dict() == {
481+
'tfliteModel': {
482+
'automlModel': AUTOML_MODEL_NAME
483+
}
484+
}
485+
440486
def test_model_as_dict_for_upload(self):
441487
model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI)
442488
model_format = ml.TFLiteFormat(model_source=model_source)
@@ -522,6 +568,23 @@ def test_gcs_tflite_source_validation_errors(self, uri, exc_type):
522568
ml.TFLiteGCSModelSource(gcs_tflite_uri=uri)
523569
check_error(excinfo, exc_type)
524570

571+
@pytest.mark.parametrize('auto_ml_model, exc_type', [
572+
(123, TypeError),
573+
('abc', ValueError),
574+
('/projects/123456/locations/us-central1/models/noLeadingSlash', ValueError),
575+
('projects/123546/models/ICN123456', ValueError),
576+
('projects//locations/us-central1/models/ICN123456', ValueError),
577+
('projects/123456/locations//models/ICN123456', ValueError),
578+
('projects/123456/locations/us-central1/models/', ValueError),
579+
('projects/ABC/locations/us-central1/models/ICN123456', ValueError),
580+
('projects/123456/locations/us-central1/models/@#$%^&', ValueError),
581+
('projects/123456/locations/us-cent/ral1/models/ICN123456', ValueError),
582+
])
583+
def test_auto_ml_tflite_source_validation_errors(self, auto_ml_model, exc_type):
584+
with pytest.raises(exc_type) as excinfo:
585+
ml.TFLiteAutoMlSource(auto_ml_model=auto_ml_model)
586+
check_error(excinfo, exc_type)
587+
525588
def test_wait_for_unlocked_not_locked(self):
526589
model = ml.Model(display_name="not_locked")
527590
model.wait_for_unlocked()

0 commit comments

Comments
 (0)
0