Skip to content

added support for automl-models #428

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 49 additions & 6 deletions firebase_admin/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@
_TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$')
_GCS_TFLITE_URI_PATTERN = re.compile(
r'^gs://(?P<bucket_name>[a-z0-9_.-]{3,63})/(?P<blob_name>.+)$')
_AUTO_ML_MODEL_PATTERN = re.compile(
r'^projects/(?P<project_id>[a-z0-9-]{6,30})/locations/(?P<location_id>[^/]+)/' +
r'models/(?P<model_id>[A-Za-z0-9]+)$')
_RESOURCE_NAME_PATTERN = re.compile(
r'^projects/(?P<project_id>[a-z0-9-]{6,30})/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$')
_OPERATION_NAME_PATTERN = re.compile(
Expand Down Expand Up @@ -362,15 +365,10 @@ def __init__(self, model_source=None):
def from_dict(cls, data):
"""Create an instance of the object from a dict."""
data_copy = dict(data)
model_source = None
gcs_tflite_uri = data_copy.pop('gcsTfliteUri', None)
if gcs_tflite_uri:
model_source = TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri)
tflite_format = TFLiteFormat(model_source=model_source)
tflite_format = TFLiteFormat(model_source=cls._init_model_source(data_copy))
tflite_format._data = data_copy # pylint: disable=protected-access
return tflite_format


def __eq__(self, other):
if isinstance(other, self.__class__):
# pylint: disable=protected-access
Expand All @@ -380,6 +378,16 @@ def __eq__(self, other):
def __ne__(self, other):
return not self.__eq__(other)

@staticmethod
def _init_model_source(data):
gcs_tflite_uri = data.pop('gcsTfliteUri', None)
if gcs_tflite_uri:
return TFLiteGCSModelSource(gcs_tflite_uri=gcs_tflite_uri)
auto_ml_model = data.pop('automlModel', None)
if auto_ml_model:
return TFLiteAutoMlSource(auto_ml_model=auto_ml_model)
return None

@property
def model_source(self):
"""The TF Lite model's location."""
Expand Down Expand Up @@ -592,6 +600,36 @@ def as_dict(self, for_upload=False):
return {'gcsTfliteUri': self._gcs_tflite_uri}


class TFLiteAutoMlSource(TFLiteModelSource):
"""TFLite model source representing a tflite model created via AutoML."""

def __init__(self, auto_ml_model, app=None):
self._app = app
self.auto_ml_model = auto_ml_model

def __eq__(self, other):
if isinstance(other, self.__class__):
return self.auto_ml_model == other.auto_ml_model
return False

def __ne__(self, other):
return not self.__eq__(other)

@property
def auto_ml_model(self):
"""Resource name of the model created by the AutoML API."""
return self._auto_ml_model

@auto_ml_model.setter
def auto_ml_model(self, auto_ml_model):
self._auto_ml_model = _validate_auto_ml_model(auto_ml_model)

def as_dict(self, for_upload=False):
"""Returns a serializable representation of the object."""
# Upload is irrelevant for auto_ml models
return {'automlModel': self._auto_ml_model}


class ListModelsPage:
"""Represents a page of models in a firebase project.

Expand Down Expand Up @@ -739,6 +777,11 @@ def _validate_gcs_tflite_uri(uri):
raise ValueError('GCS TFLite URI format is invalid.')
return uri

def _validate_auto_ml_model(model):
if not _AUTO_ML_MODEL_PATTERN.match(model):
raise ValueError('Model resource name format is invalid.')
return model


def _validate_model_format(model_format):
if not isinstance(model_format, ModelFormat):
Expand Down
67 changes: 65 additions & 2 deletions tests/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,18 @@
}
TFLITE_FORMAT_2 = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2)

AUTOML_MODEL_NAME = 'projects/111111111111/locations/us-central1/models/ICN7683346839371803263'
AUTOML_MODEL_SOURCE = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME)
TFLITE_FORMAT_JSON_3 = {
'automlModel': AUTOML_MODEL_NAME,
'sizeBytes': '3456789'
}
TFLITE_FORMAT_3 = ml.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_3)

AUTOML_MODEL_NAME_2 = 'projects/2222222222/locations/us-central1/models/ICN2222222222222222222'
AUTOML_MODEL_NAME_JSON_2 = {'automlModel': AUTOML_MODEL_NAME_2}
AUTOML_MODEL_SOURCE_2 = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME_2)

CREATED_UPDATED_MODEL_JSON_1 = {
'name': MODEL_NAME_1,
'displayName': DISPLAY_NAME_1,
Expand Down Expand Up @@ -403,7 +415,15 @@ def test_model_keyword_based_creation_and_setters(self):
'tfliteModel': TFLITE_FORMAT_JSON_2
}

def test_model_format_source_creation(self):
model.model_format = TFLITE_FORMAT_3
assert model.as_dict() == {
'displayName': DISPLAY_NAME_2,
'tags': TAGS_2,
'tfliteModel': TFLITE_FORMAT_JSON_3
}


def test_gcs_tflite_model_format_source_creation(self):
model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI)
model_format = ml.TFLiteFormat(model_source=model_source)
model = ml.Model(display_name=DISPLAY_NAME_1, model_format=model_format)
Expand All @@ -414,19 +434,37 @@ def test_model_format_source_creation(self):
}
}

def test_auto_ml_tflite_model_format_source_creation(self):
model_source = ml.TFLiteAutoMlSource(auto_ml_model=AUTOML_MODEL_NAME)
model_format = ml.TFLiteFormat(model_source=model_source)
model = ml.Model(display_name=DISPLAY_NAME_1, model_format=model_format)
assert model.as_dict() == {
'displayName': DISPLAY_NAME_1,
'tfliteModel': {
'automlModel': AUTOML_MODEL_NAME
}
}

def test_source_creation_from_tflite_file(self):
model_source = ml.TFLiteGCSModelSource.from_tflite_model_file(
"my_model.tflite", "my_bucket")
assert model_source.as_dict() == {
'gcsTfliteUri': 'gs://my_bucket/Firebase/ML/Models/my_model.tflite'
}

def test_model_source_setters(self):
def test_gcs_tflite_model_source_setters(self):
model_source = ml.TFLiteGCSModelSource(GCS_TFLITE_URI)
model_source.gcs_tflite_uri = GCS_TFLITE_URI_2
assert model_source.gcs_tflite_uri == GCS_TFLITE_URI_2
assert model_source.as_dict() == GCS_TFLITE_URI_JSON_2

def test_auto_ml_tflite_model_source_setters(self):
model_source = ml.TFLiteAutoMlSource(AUTOML_MODEL_NAME)
model_source.auto_ml_model = AUTOML_MODEL_NAME_2
assert model_source.auto_ml_model == AUTOML_MODEL_NAME_2
assert model_source.as_dict() == AUTOML_MODEL_NAME_JSON_2


def test_model_format_setters(self):
model_format = ml.TFLiteFormat(model_source=GCS_TFLITE_MODEL_SOURCE)
model_format.model_source = GCS_TFLITE_MODEL_SOURCE_2
Expand All @@ -437,6 +475,14 @@ def test_model_format_setters(self):
}
}

model_format.model_source = AUTOML_MODEL_SOURCE
assert model_format.model_source == AUTOML_MODEL_SOURCE
assert model_format.as_dict() == {
'tfliteModel': {
'automlModel': AUTOML_MODEL_NAME
}
}

def test_model_as_dict_for_upload(self):
model_source = ml.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI)
model_format = ml.TFLiteFormat(model_source=model_source)
Expand Down Expand Up @@ -522,6 +568,23 @@ def test_gcs_tflite_source_validation_errors(self, uri, exc_type):
ml.TFLiteGCSModelSource(gcs_tflite_uri=uri)
check_error(excinfo, exc_type)

@pytest.mark.parametrize('auto_ml_model, exc_type', [
(123, TypeError),
('abc', ValueError),
('/projects/123456/locations/us-central1/models/noLeadingSlash', ValueError),
('projects/123546/models/ICN123456', ValueError),
('projects//locations/us-central1/models/ICN123456', ValueError),
('projects/123456/locations//models/ICN123456', ValueError),
('projects/123456/locations/us-central1/models/', ValueError),
('projects/ABC/locations/us-central1/models/ICN123456', ValueError),
('projects/123456/locations/us-central1/models/@#$%^&', ValueError),
('projects/123456/locations/us-cent/ral1/models/ICN123456', ValueError),
])
def test_auto_ml_tflite_source_validation_errors(self, auto_ml_model, exc_type):
with pytest.raises(exc_type) as excinfo:
ml.TFLiteAutoMlSource(auto_ml_model=auto_ml_model)
check_error(excinfo, exc_type)

def test_wait_for_unlocked_not_locked(self):
model = ml.Model(display_name="not_locked")
model.wait_for_unlocked()
Expand Down