diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index b9b56c8f4..91cedbedc 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -87,6 +87,34 @@ def update_model(model, app=None): return Model.from_dict(mlkit_service.update_model(model), app=app) +def publish_model(model_id, app=None): + """Publishes a model in Firebase ML Kit. + + Args: + model_id: The id of the model to publish. + app: A Firebase app instance (or None to use the default app). + + Returns: + Model: The published model. + """ + mlkit_service = _get_mlkit_service(app) + return Model.from_dict(mlkit_service.set_published(model_id, publish=True), app=app) + + +def unpublish_model(model_id, app=None): + """Unpublishes a model in Firebase ML Kit. + + Args: + model_id: The id of the model to unpublish. + app: A Firebase app instance (or None to use the default app). + + Returns: + Model: The unpublished model. + """ + mlkit_service = _get_mlkit_service(app) + return Model.from_dict(mlkit_service.set_published(model_id, publish=False), app=app) + + def get_model(model_id, app=None): """Gets a model from Firebase ML Kit. @@ -562,12 +590,12 @@ class _MLKitService(object): POLL_BASE_WAIT_TIME_SECONDS = 3 def __init__(self, app): - project_id = app.project_id - if not project_id: + self._project_id = app.project_id + if not self._project_id: raise ValueError( 'Project ID is required to access MLKit service. Either set the ' 'projectId option, or use service account credentials.') - self._project_url = _MLKitService.PROJECT_URL.format(project_id) + self._project_url = _MLKitService.PROJECT_URL.format(self._project_id) self._client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), base_url=self._project_url) @@ -595,7 +623,6 @@ def _exponential_backoff(self, current_attempt, stop_time): wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1) time.sleep(wait_time_seconds) - def handle_operation(self, operation, wait_for_operation=False, max_time_seconds=None): """Handles long running operations. @@ -659,6 +686,17 @@ def update_model(self, model, update_mask=None): except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) + def set_published(self, model_id, publish): + _validate_model_id(model_id) + model_name = 'projects/{0}/models/{1}'.format(self._project_id, model_id) + model = Model.from_dict({ + 'name': model_name, + 'state': { + 'published': publish + } + }) + return self.update_model(model, update_mask='state.published') + def get_model(self, model_id): _validate_model_id(model_id) try: diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index e93bbd7e9..50fed4e1b 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -657,7 +657,7 @@ def test_operation_error(self): instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE) with pytest.raises(Exception) as excinfo: mlkit.update_model(MODEL_1) - # The http request succeeded, the operation returned contains a create failure + # The http request succeeded, the operation returned contains an update failure check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) def test_malformed_operation(self): @@ -673,7 +673,7 @@ def test_malformed_operation(self): assert recorder[1].method == 'GET' assert recorder[1].url == TestUpdateModel._url(PROJECT_ID, MODEL_ID_1) - def test_rpc_error_create(self): + def test_rpc_error(self): create_recorder = instrument_mlkit_service( status=400, payload=ERROR_RESPONSE_BAD_REQUEST) with pytest.raises(Exception) as excinfo: @@ -712,6 +712,97 @@ def test_invalid_op_name(self, op_name): check_error(excinfo, ValueError, 'Operation name format is invalid.') +class TestPublishUnpublish(object): + """Tests mlkit.publish_model and mlkit.unpublish_model.""" + + PUBLISH_UNPUBLISH_WITH_ARGS = [ + (mlkit.publish_model, True), + (mlkit.unpublish_model, False) + ] + PUBLISH_UNPUBLISH_FUNCS = [item[0] for item in PUBLISH_UNPUBLISH_WITH_ARGS] + + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def _url(project_id, model_id): + return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + + @staticmethod + def _op_url(project_id, model_id): + return BASE_URL + \ + 'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id) + + @pytest.mark.parametrize('publish_function, published', PUBLISH_UNPUBLISH_WITH_ARGS) + def test_immediate_done(self, publish_function, published): + recorder = instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE) + model = publish_function(MODEL_ID_1) + assert model == CREATED_UPDATED_MODEL_1 + assert len(recorder) == 1 + assert recorder[0].method == 'PATCH' + assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) + body = json.loads(recorder[0].body.decode()) + assert body.get('model', {}).get('state', {}).get('published', None) is published + assert body.get('updateMask', {}) == 'state.published' + + @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) + def test_returns_locked(self, publish_function): + recorder = instrument_mlkit_service( + status=[200, 200], + payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]) + expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) + model = publish_function(MODEL_ID_1) + + assert model == expected_model + assert len(recorder) == 2 + assert recorder[0].method == 'PATCH' + assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].method == 'GET' + assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) + + @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) + def test_operation_error(self, publish_function): + instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE) + with pytest.raises(Exception) as excinfo: + publish_function(MODEL_ID_1) + # The http request succeeded, the operation returned contains an update failure + check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) + + @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) + def test_malformed_operation(self, publish_function): + recorder = instrument_mlkit_service( + status=[200, 200], + payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE]) + expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) + model = publish_function(MODEL_ID_1) + assert model == expected_model + assert len(recorder) == 2 + assert recorder[0].method == 'PATCH' + assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].method == 'GET' + assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) + + @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) + def test_rpc_error(self, publish_function): + create_recorder = instrument_mlkit_service( + status=400, payload=ERROR_RESPONSE_BAD_REQUEST) + with pytest.raises(Exception) as excinfo: + publish_function(MODEL_ID_1) + check_firebase_error( + excinfo, + ERROR_STATUS_BAD_REQUEST, + ERROR_CODE_BAD_REQUEST, + ERROR_MSG_BAD_REQUEST + ) + assert len(create_recorder) == 1 + class TestGetModel(object): """Tests mlkit.get_model.""" @classmethod