diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index 06429b5a1..4f3e63c14 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -52,6 +52,9 @@ _TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$') _GCS_TFLITE_URI_PATTERN = re.compile( r'^gs://(?P[a-z0-9_.-]{3,63})/(?P.+)$') +_AUTO_ML_MODEL_PATTERN = re.compile( + r'^projects/(?P[a-z0-9-]{6,30})/locations/(?P[^/]+)/' + + r'models/(?P[A-Za-z0-9]+)$') _RESOURCE_NAME_PATTERN = re.compile( r'^projects/(?P[a-z0-9-]{6,30})/models/(?P[A-Za-z0-9_-]{1,60})$') _OPERATION_NAME_PATTERN = re.compile( @@ -362,15 +365,10 @@ def __init__(self, model_source=None): def from_dict(cls, data): """Create an instance of the object from a dict.""" data_copy = dict(data) - model_source = None - gcs_tflite_uri = data_copy.pop('gcsTfliteUri', None) - if gcs_tflite_uri: - model_source = TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri) - tflite_format = TFLiteFormat(model_source=model_source) + tflite_format = TFLiteFormat(model_source=cls._init_model_source(data_copy)) tflite_format._data = data_copy # pylint: disable=protected-access return tflite_format - def __eq__(self, other): if isinstance(other, self.__class__): # pylint: disable=protected-access @@ -380,6 +378,16 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + @staticmethod + def _init_model_source(data): + gcs_tflite_uri = data.pop('gcsTfliteUri', None) + if gcs_tflite_uri: + return TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri) + auto_ml_model = data.pop('automlModel', None) + if auto_ml_model: + return TFLiteAutoMlSource(auto_ml_model=auto_ml_model) + return None + @property def model_source(self): """The TF Lite model's location.""" @@ -592,6 +600,36 @@ def as_dict(self, for_upload=False): return {'gcsTfliteUri': self._gcs_tflite_uri} +class TFLiteAutoMlSource(TFLiteModelSource): + """TFLite model source representing a tflite model created via AutoML.""" + + def __init__(self, auto_ml_model, app=None): + self._app = app + self.auto_ml_model = auto_ml_model + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.auto_ml_model == other.auto_ml_model + return False + + def __ne__(self, other): + return not self.__eq__(other) + + @property + def auto_ml_model(self): + """Resource name of the model created by the AutoML API.""" + return self._auto_ml_model + + @auto_ml_model.setter + def auto_ml_model(self, auto_ml_model): + self._auto_ml_model = _validate_auto_ml_model(auto_ml_model) + + def as_dict(self, for_upload=False): + """Returns a serializable representation of the object.""" + # Upload is irrelevant for auto_ml models + return {'automlModel': self._auto_ml_model} + + class ListModelsPage: """Represents a page of models in a firebase project. @@ -739,6 +777,11 @@ def _validate_gcs_tflite_uri(uri): raise ValueError('GCS TFLite URI format is invalid.') return uri +def _validate_auto_ml_model(model): + if not _AUTO_ML_MODEL_PATTERN.match(model): + raise ValueError('Model resource name format is invalid.') + return model + def _validate_model_format(model_format): if not isinstance(model_format, ModelFormat): diff --git a/tests/test_ml.py b/tests/test_ml.py index 8813792e6..e8c46c89d 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -120,6 +120,18 @@ } TFLITE_FORMAT_2 = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2) +AUTOML_MODEL_NAME = 'projects/111111111111/locations/us-central1/models/ICN7683346839371803263' +AUTOML_MODEL_SOURCE = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME) +TFLITE_FORMAT_JSON_3 = { + 'automlModel': AUTOML_MODEL_NAME, + 'sizeBytes': '3456789' +} +TFLITE_FORMAT_3 = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_3) + +AUTOML_MODEL_NAME_2 = 'projects/2222222222/locations/us-central1/models/ICN2222222222222222222' +AUTOML_MODEL_NAME_JSON_2 = {'automlModel': AUTOML_MODEL_NAME_2} +AUTOML_MODEL_SOURCE_2 = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME_2) + CREATED_UPDATED_MODEL_JSON_1 = { 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1, @@ -403,7 +415,15 @@ def test_model_keyword_based_creation_and_setters(self): 'tfliteModel': TFLITE_FORMAT_JSON_2 } - def test_model_format_source_creation(self): + model.model_format = TFLITE_FORMAT_3 + assert model.as_dict() == { + 'displayName': DISPLAY_NAME_2, + 'tags': TAGS_2, + 'tfliteModel': TFLITE_FORMAT_JSON_3 + } + + + def test_gcs_tflite_model_format_source_creation(self): model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) model_format = ml.TFLiteFormat(model_source=model_source) model = ml.Model(display_name=DISPLAY_NAME_1, model_format=model_format) @@ -414,6 +434,17 @@ def test_model_format_source_creation(self): } } + def test_auto_ml_tflite_model_format_source_creation(self): + model_source = ml.TFLiteAutoMlSource(auto_ml_model=AUTOML_MODEL_NAME) + model_format = ml.TFLiteFormat(model_source=model_source) + model = ml.Model(display_name=DISPLAY_NAME_1, model_format=model_format) + assert model.as_dict() == { + 'displayName': DISPLAY_NAME_1, + 'tfliteModel': { + 'automlModel': AUTOML_MODEL_NAME + } + } + def test_source_creation_from_tflite_file(self): model_source = ml.TFLiteGCSModelSource.from_tflite_model_file( "my_model.tflite", "my_bucket") @@ -421,12 +452,19 @@ def test_source_creation_from_tflite_file(self): 'gcsTfliteUri': 'gs://my_bucket/Firebase/ML/Models/my_model.tflite' } - def test_model_source_setters(self): + def test_gcs_tflite_model_source_setters(self): model_source = ml.TFLiteGCSModelSource(GCS_TFLITE_URI) model_source.gcs_tflite_uri = GCS_TFLITE_URI_2 assert model_source.gcs_tflite_uri == GCS_TFLITE_URI_2 assert model_source.as_dict() == GCS_TFLITE_URI_JSON_2 + def test_auto_ml_tflite_model_source_setters(self): + model_source = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME) + model_source.auto_ml_model = AUTOML_MODEL_NAME_2 + assert model_source.auto_ml_model == AUTOML_MODEL_NAME_2 + assert model_source.as_dict() == AUTOML_MODEL_NAME_JSON_2 + + def test_model_format_setters(self): model_format = ml.TFLiteFormat(model_source=GCS_TFLITE_MODEL_SOURCE) model_format.model_source = GCS_TFLITE_MODEL_SOURCE_2 @@ -437,6 +475,14 @@ def test_model_format_setters(self): } } + model_format.model_source = AUTOML_MODEL_SOURCE + assert model_format.model_source == AUTOML_MODEL_SOURCE + assert model_format.as_dict() == { + 'tfliteModel': { + 'automlModel': AUTOML_MODEL_NAME + } + } + def test_model_as_dict_for_upload(self): model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI) model_format = ml.TFLiteFormat(model_source=model_source) @@ -522,6 +568,23 @@ def test_gcs_tflite_source_validation_errors(self, uri, exc_type): ml.TFLiteGCSModelSource(gcs_tflite_uri=uri) check_error(excinfo, exc_type) + @pytest.mark.parametrize('auto_ml_model, exc_type', [ + (123, TypeError), + ('abc', ValueError), + ('/projects/123456/locations/us-central1/models/noLeadingSlash', ValueError), + ('projects/123546/models/ICN123456', ValueError), + ('projects//locations/us-central1/models/ICN123456', ValueError), + ('projects/123456/locations//models/ICN123456', ValueError), + ('projects/123456/locations/us-central1/models/', ValueError), + ('projects/ABC/locations/us-central1/models/ICN123456', ValueError), + ('projects/123456/locations/us-central1/models/@#$%^&', ValueError), + ('projects/123456/locations/us-cent/ral1/models/ICN123456', ValueError), + ]) + def test_auto_ml_tflite_source_validation_errors(self, auto_ml_model, exc_type): + with pytest.raises(exc_type) as excinfo: + ml.TFLiteAutoMlSource(auto_ml_model=auto_ml_model) + check_error(excinfo, exc_type) + def test_wait_for_unlocked_not_locked(self): model = ml.Model(display_name="not_locked") model.wait_for_unlocked()