From 7f4c6ed7c2867900f943397639088f27c03eebc6 Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 3 Sep 2019 16:06:14 -0400 Subject: [PATCH 01/14] create model plus long running operation handling --- firebase_admin/_utils.py | 21 ++++ firebase_admin/mlkit.py | 61 ++++++++++ tests/test_mlkit.py | 237 ++++++++++++++++++++++++++++++++++----- 3 files changed, 294 insertions(+), 25 deletions(-) diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index 95ed2c414..fb6e32932 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -106,6 +106,27 @@ def handle_platform_error_from_requests(error, handle_func=None): return exc if exc else _handle_func_requests(error, message, error_dict) +def handle_operation_error(error): + """Constructs a ``FirebaseError`` from the given operation error. + + Args: + error: An error returned by a long running operation. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code. + """ + if not isinstance(error, dict): + return exceptions.UnknownError( + message='Unknown error while making a remote service call: {0}'.format(error), + cause=error) + + status_code = error.get('code') + message = error.get('message') + error_code = _http_status_to_error_code(status_code) + err_type = _error_code_to_exception_type(error_code) + return err_type(message=message) + + def _handle_func_requests(error, message, error_dict): """Constructs a ``FirebaseError`` from the given GCP error. diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 3f1a825f6..fd704a25c 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -23,6 +23,7 @@ import re import requests import six +import time from firebase_admin import _http_client from firebase_admin import _utils @@ -36,6 +37,8 @@ _GCS_TFLITE_URI_PATTERN = re.compile(r'^gs://[a-z0-9_.-]{3,63}/.+') _RESOURCE_NAME_PATTERN = re.compile( r'^projects/(?P[^/]+)/models/(?P[A-Za-z0-9_-]{1,60})$') +_OPERATION_NAME_PATTERN = re.compile( + r'^operations/project/[^/]+/model/[A-Za-z0-9_-]{1,60}/operation/[^/]+$') def _get_mlkit_service(app): @@ -53,6 +56,11 @@ def _get_mlkit_service(app): return _utils.get_app_service(app, _MLKIT_ATTRIBUTE, _MLKitService) +def create_model(model, app=None): + mlkit_service = _get_mlkit_service(app) + return Model.from_dict(mlkit_service.create_model(model)) + + def get_model(model_id, app=None): mlkit_service = _get_mlkit_service(app) return Model.from_dict(mlkit_service.get_model(model_id)) @@ -390,11 +398,23 @@ def _validate_and_parse_name(name): return matcher.group('project_id'), matcher.group('model_id') +def _validate_model(model): + if not isinstance(model, Model): + raise TypeError('Model must be an mlkit.Model.') + if not model.display_name: + raise ValueError('Model must have a display name.') + + def _validate_model_id(model_id): if not _MODEL_ID_PATTERN.match(model_id): raise ValueError('Model ID format is invalid.') +def _validate_operation_name(op_name): + if not _OPERATION_NAME_PATTERN.match(op_name): + raise ValueError('Operation name format is invalid.') + + def _validate_display_name(display_name): if not _DISPLAY_NAME_PATTERN.match(display_name): raise ValueError('Display name format is invalid.') @@ -448,6 +468,8 @@ class _MLKitService(object): """Firebase MLKit service.""" PROJECT_URL = 'https://mlkit.googleapis.com/v1beta1/projects/{0}/' + OPERATION_URL = 'https://mlkit.googleapis.com/v1beta1/' + OPERATION_POLL_DELAY_SECONDS = 30 def __init__(self, app): project_id = app.project_id @@ -459,6 +481,45 @@ def __init__(self, app): self._client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), base_url=self._project_url) + self._operation_client = _http_client.JsonHttpClient( + credential=app.credential.get_credential(), + base_url=_MLKitService.OPERATION_URL) + + def get_operation(self, op_name): + _validate_operation_name(op_name) + try: + return self._operation_client.body('get', url=op_name) + except requests.exceptions.RequestException as error: + raise _utils.handle_platform_error_from_requests(error) + + def handle_operation(self, operation): + if not isinstance(operation, dict): + raise TypeError('Operation must be a dictionary.') + op_name = operation.get('name') + _validate_operation_name(op_name) + + while True: + if operation.get('done'): + if operation.get('response'): + return operation.get('response') + elif operation.get('error'): + raise _utils.handle_operation_error(operation.get('error')) + else: + # A 'done' operation must have either a response or an error. + raise ValueError('Operation is malformed.') + else: + # We just got this operation wait 30s before getting another + # so we don't exceed the GetOperation maximum request rate. + time.sleep(_MLKitService.OPERATION_POLL_DELAY_SECONDS) + operation = self.get_operation(op_name) + + def create_model(self, model): + _validate_model(model) + try: + return self.handle_operation( + self._client.body('post', url='models', json=model.as_dict())) + except requests.exceptions.RequestException as error: + raise _utils.handle_platform_error_from_requests(error) def get_model(self, model_id): _validate_model_id(model_id) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index c20982a2b..320196c7f 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -113,6 +113,46 @@ } TFLITE_FORMAT_2 = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON_2) +CREATED_MODEL_JSON_1 = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_1, + 'createTime': CREATE_TIME_JSON, + 'updateTime': UPDATE_TIME_JSON, + 'state': MODEL_STATE_ERROR_JSON, + 'etag': ETAG, + 'modelHash': MODEL_HASH, + 'tags': TAGS, +} +CREATED_MODEL_1 = mlkit.Model.from_dict(CREATED_MODEL_JSON_1) + +OPERATION_DONE_MODEL_JSON_1 = { + 'name': 'operations/project/{0}/model/{1}/operation/123'.format(PROJECT_ID, MODEL_ID_1), + 'done': True, + 'response': CREATED_MODEL_JSON_1 +} + +OPERATION_MALFORMED_JSON_1 = { + 'name': 'operations/project/{0}/model/{1}/operation/123'.format(PROJECT_ID, MODEL_ID_1), + 'done': True, + # if done is true then either response or error should be populated +} + +OPERATION_MISSING_NAME = { + 'done': False +} + +OPERATION_ERROR_CODE = 400 +OPERATION_ERROR_MSG = "Invalid argument" +OPERATION_ERROR_EXPECTED_STATUS = 'INVALID_ARGUMENT' +OPERATION_ERROR_JSON_1 = { + 'name': 'operations/project/{0}/model/{1}/operation/123'.format(PROJECT_ID, MODEL_ID_1), + 'done': True, + 'error': { + 'code': OPERATION_ERROR_CODE, + 'message': OPERATION_ERROR_MSG, + } +} + FULL_MODEL_ERR_STATE_LRO_JSON = { 'name': MODEL_NAME_1, 'displayName': DISPLAY_NAME_1, @@ -137,6 +177,11 @@ } EMPTY_RESPONSE = json.dumps({}) +OPERATION_NOT_DONE_RESPONSE = json.dumps(OPERATION_NOT_DONE_JSON_1) +OPERATION_DONE_RESPONSE = json.dumps(OPERATION_DONE_MODEL_JSON_1) +OPERATION_ERROR_RESPONSE = json.dumps(OPERATION_ERROR_JSON_1) +OPERATION_MALFORMED_RESPONSE = json.dumps(OPERATION_MALFORMED_JSON_1) +OPERATION_MISSING_NAME_RESPONSE = json.dumps(OPERATION_MISSING_NAME) DEFAULT_GET_RESPONSE = json.dumps(MODEL_JSON_1) NO_MODELS_LIST_RESPONSE = json.dumps({}) DEFAULT_LIST_RESPONSE = json.dumps({ @@ -185,29 +230,47 @@ invalid_string_or_none_args = [0, -1, 4.2, 0x10, False, list(), dict()] +# For validation type errors def check_error(err, err_type, msg=None): - assert isinstance(err, err_type) + err_value = err.value + assert isinstance(err_value, err_type) if msg: - assert str(err) == msg + assert str(err_value) == msg + + +# For errors that are returned in an operation +def check_operation_error(err, code, msg): + err_value = err.value + assert isinstance(err_value, exceptions.FirebaseError) + assert err_value.code == code + assert str(err_value) == msg +# For rpc errors def check_firebase_error(err, code, status, msg): - assert isinstance(err, exceptions.FirebaseError) - assert err.code == code - assert err.http_response is not None - assert err.http_response.status_code == status - assert str(err) == msg + err_value = err.value + assert isinstance(err_value, exceptions.FirebaseError) + assert err_value.code == code + assert err_value.http_response is not None + assert err_value.http_response.status_code == status + assert str(err_value) == msg -def instrument_mlkit_service(app=None, status=200, payload=None): +def instrument_mlkit_service(app=None, status=200, uri=None, payload=None): if not app: app = firebase_admin.get_app() mlkit_service = mlkit._get_mlkit_service(app) recorder = [] - mlkit_service._client.session.mount( - 'https://mlkit.googleapis.com', - testutils.MockAdapter(payload, status, recorder) - ) + if uri is None or uri is 'projects': + mlkit_service._client.session.mount( + 'https://mlkit.googleapis.com/', + testutils.MockAdapter(payload, status, recorder) + ) + elif uri is 'operations': + mlkit_service._operation_client.session.mount( + 'https://mlkit.googleapis.com/', + testutils.MockAdapter(payload, status, recorder) + ) return recorder @@ -299,7 +362,7 @@ def test_model_format_setters(self): def test_model_display_name_validation_errors(self, display_name, exc_type): with pytest.raises(exc_type) as err: mlkit.Model(display_name=display_name) - check_error(err.value, exc_type) + check_error(err, exc_type) @pytest.mark.parametrize('tags, exc_type, error_message', [ ('tag1', TypeError, 'Tags must be a list of strings.'), @@ -313,7 +376,7 @@ def test_model_display_name_validation_errors(self, display_name, exc_type): def test_model_tags_validation_errors(self, tags, exc_type, error_message): with pytest.raises(exc_type) as err: mlkit.Model(tags=tags) - check_error(err.value, exc_type, error_message) + check_error(err, exc_type, error_message) @pytest.mark.parametrize('model_format', [ 123, @@ -325,7 +388,7 @@ def test_model_tags_validation_errors(self, tags, exc_type, error_message): def test_model_format_validation_errors(self, model_format): with pytest.raises(TypeError) as err: mlkit.Model(model_format=model_format) - check_error(err.value, TypeError, 'Model format must be a ModelFormat object.') + check_error(err, TypeError, 'Model format must be a ModelFormat object.') @pytest.mark.parametrize('model_source', [ 123, @@ -337,7 +400,7 @@ def test_model_format_validation_errors(self, model_format): def test_model_source_validation_errors(self, model_source): with pytest.raises(TypeError) as err: mlkit.TFLiteFormat(model_source=model_source) - check_error(err.value, TypeError, 'Model source must be a TFLiteModelSource object.') + check_error(err, TypeError, 'Model source must be a TFLiteModelSource object.') @pytest.mark.parametrize('uri, exc_type', [ (123, TypeError), @@ -353,7 +416,131 @@ def test_model_source_validation_errors(self, model_source): def test_gcs_tflite_source_validation_errors(self, uri, exc_type): with pytest.raises(exc_type) as err: mlkit.TFLiteGCSModelSource(gcs_tflite_uri=uri) - check_error(err.value, exc_type) + check_error(err, exc_type) + + +class TestCreateModel(object): + """Tests mlkit.create_model.""" + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + mlkit._MLKitService.OPERATION_POLL_DELAY_SECONDS = 0.1 # shorter for test + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def _url(project_id): + return BASE_URL + 'projects/{0}/models'.format(project_id) + + @staticmethod + def _op_url(project_id, model_id): + return BASE_URL + \ + 'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id) + + def test_create_model_immediate_done(self): + recorder = instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE) + model = mlkit.create_model(MODEL_1) + assert model == CREATED_MODEL_1 + + def test_create_model_with_get_operation(self): + create_recorder = instrument_mlkit_service( + status=200, uri='projects', payload=OPERATION_NOT_DONE_RESPONSE) + operation_recorder = instrument_mlkit_service( + status=200, uri='operations', payload=OPERATION_DONE_RESPONSE) + model = mlkit.create_model(MODEL_1) + assert model == CREATED_MODEL_1 + assert len(create_recorder) == 1 + assert create_recorder[0].method == 'POST' + assert create_recorder[0].url == TestCreateModel._url(PROJECT_ID) + assert len(operation_recorder) == 1 + assert operation_recorder[0].method == 'GET' + assert operation_recorder[0].url == TestCreateModel._op_url(PROJECT_ID, MODEL_ID_1) + + def test_create_model_operation_error(self): + recorder = instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE) + with pytest.raises(Exception) as err: + mlkit.create_model(MODEL_1) + # The http request succeeded, the operation returned contains a create failure + check_operation_error(err, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) + + def test_create_model_malformed_operation(self): + recorder = instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) + with pytest.raises(ValueError) as err: + mlkit.create_model(MODEL_1) + check_error(err, ValueError, 'Operation is malformed.') + + def test_create_model_rpc_error_create(self): + create_recorder = instrument_mlkit_service( + status=400, uri='projects', payload=ERROR_RESPONSE_BAD_REQUEST) + with pytest.raises(Exception) as err: + mlkit.create_model(MODEL_1) + check_firebase_error( + err, + ERROR_STATUS_BAD_REQUEST, + ERROR_CODE_BAD_REQUEST, + ERROR_MSG_BAD_REQUEST + ) + assert len(create_recorder) == 1 + + def test_create_model_rpc_error_operation(self): + create_recorder = instrument_mlkit_service( + status=200, uri='projects', payload=OPERATION_NOT_DONE_RESPONSE) + operation_recorder = instrument_mlkit_service( + status=400, uri='operations', payload=ERROR_RESPONSE_BAD_REQUEST) + with pytest.raises(Exception) as err: + mlkit.create_model(MODEL_1) + check_firebase_error( + err, + ERROR_STATUS_BAD_REQUEST, + ERROR_CODE_BAD_REQUEST, + ERROR_MSG_BAD_REQUEST + ) + assert len(create_recorder) == 1 + assert len(operation_recorder) == 1 + + @pytest.mark.parametrize('model', [ + 'abc', + 4.2, + list(), + dict(), + True, + -1, + 0, + None + ]) + def test_create_model_not_model(self, model): + with pytest.raises(Exception) as err: + mlkit.create_model(model) + check_error(err, TypeError, 'Model must be an mlkit.Model.') + + def test_create_model_missing_display_name(self): + with pytest.raises(Exception) as err: + mlkit.create_model(mlkit.Model.from_dict({})) + check_error(err, ValueError, 'Model must have a display name.') + + def test_create_model_missing_op_name(self): + recorder = instrument_mlkit_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE) + with pytest.raises(Exception) as err: + mlkit.create_model(MODEL_1) + check_error(err, TypeError) + + @pytest.mark.parametrize('op_name', [ + 'abc', + '123', + 'projects/operations/project/1234/model/abc/operation/123', + 'operations/project/model/abc/operation/123', + 'operations/project/123/model/$#@/operation/123', + 'operations/project/1234/model/abc/operation/123/extrathing', + ]) + def test_create_model_invalid_op_name(self, op_name): + payload = json.dumps({'name': op_name}) + recorder = instrument_mlkit_service(status=200, payload=payload) + with pytest.raises(Exception) as err: + mlkit.create_model(MODEL_1) + check_error(err, ValueError, 'Operation name format is invalid.') class TestGetModel(object): @@ -385,14 +572,14 @@ def test_get_model(self): def test_get_model_validation_errors(self, model_id, exc_type): with pytest.raises(exc_type) as err: mlkit.get_model(model_id) - check_error(err.value, exc_type) + check_error(err, exc_type) def test_get_model_error(self): recorder = instrument_mlkit_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) with pytest.raises(exceptions.NotFoundError) as err: mlkit.get_model(MODEL_ID_1) check_firebase_error( - err.value, + err, ERROR_STATUS_NOT_FOUND, ERROR_CODE_NOT_FOUND, ERROR_MSG_NOT_FOUND @@ -435,14 +622,14 @@ def test_delete_model(self): def test_delete_model_validation_errors(self, model_id, exc_type): with pytest.raises(exc_type) as err: mlkit.delete_model(model_id) - check_error(err.value, exc_type) + check_error(err, exc_type) def test_delete_model_error(self): recorder = instrument_mlkit_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) with pytest.raises(exceptions.NotFoundError) as err: mlkit.delete_model(MODEL_ID_1) check_firebase_error( - err.value, + err, ERROR_STATUS_NOT_FOUND, ERROR_CODE_NOT_FOUND, ERROR_MSG_NOT_FOUND @@ -516,7 +703,7 @@ def test_list_models_with_all_args(self): def test_list_models_list_filter_validation(self, list_filter): with pytest.raises(TypeError) as err: mlkit.list_models(list_filter=list_filter) - check_error(err.value, TypeError, 'List filter must be a string or None.') + check_error(err, TypeError, 'List filter must be a string or None.') @pytest.mark.parametrize('page_size, exc_type, error_message', [ ('abc', TypeError, 'Page size must be a number or None.'), @@ -531,20 +718,20 @@ def test_list_models_list_filter_validation(self, list_filter): def test_list_models_page_size_validation(self, page_size, exc_type, error_message): with pytest.raises(exc_type) as err: mlkit.list_models(page_size=page_size) - check_error(err.value, exc_type, error_message) + check_error(err, exc_type, error_message) @pytest.mark.parametrize('page_token', invalid_string_or_none_args) def test_list_models_page_token_validation(self, page_token): with pytest.raises(TypeError) as err: mlkit.list_models(page_token=page_token) - check_error(err.value, TypeError, 'Page token must be a string or None.') + check_error(err, TypeError, 'Page token must be a string or None.') def test_list_models_error(self): recorder = instrument_mlkit_service(status=400, payload=ERROR_RESPONSE_BAD_REQUEST) with pytest.raises(exceptions.InvalidArgumentError) as err: mlkit.list_models() check_firebase_error( - err.value, + err, ERROR_STATUS_BAD_REQUEST, ERROR_CODE_BAD_REQUEST, ERROR_MSG_BAD_REQUEST From 9afa1758da9213b2b23f18fdf8676842b2551915 Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 3 Sep 2019 16:11:48 -0400 Subject: [PATCH 02/14] fixed lint --- firebase_admin/mlkit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index fd704a25c..7aeaa3d96 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -21,9 +21,10 @@ import datetime import numbers import re +import time import requests import six -import time + from firebase_admin import _http_client from firebase_admin import _utils From 53a22bdae6dd3bd72cec15e71f885fef06d7f21b Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 3 Sep 2019 16:18:46 -0400 Subject: [PATCH 03/14] fixed more lint --- tests/test_mlkit.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 320196c7f..2a79b4565 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -444,6 +444,7 @@ def test_create_model_immediate_done(self): recorder = instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE) model = mlkit.create_model(MODEL_1) assert model == CREATED_MODEL_1 + assert len(recorder) == 1 def test_create_model_with_get_operation(self): create_recorder = instrument_mlkit_service( @@ -460,14 +461,14 @@ def test_create_model_with_get_operation(self): assert operation_recorder[0].url == TestCreateModel._op_url(PROJECT_ID, MODEL_ID_1) def test_create_model_operation_error(self): - recorder = instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE) + instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE) with pytest.raises(Exception) as err: mlkit.create_model(MODEL_1) # The http request succeeded, the operation returned contains a create failure check_operation_error(err, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) def test_create_model_malformed_operation(self): - recorder = instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) + instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) with pytest.raises(ValueError) as err: mlkit.create_model(MODEL_1) check_error(err, ValueError, 'Operation is malformed.') @@ -522,7 +523,7 @@ def test_create_model_missing_display_name(self): check_error(err, ValueError, 'Model must have a display name.') def test_create_model_missing_op_name(self): - recorder = instrument_mlkit_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE) + instrument_mlkit_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE) with pytest.raises(Exception) as err: mlkit.create_model(MODEL_1) check_error(err, TypeError) @@ -537,7 +538,7 @@ def test_create_model_missing_op_name(self): ]) def test_create_model_invalid_op_name(self, op_name): payload = json.dumps({'name': op_name}) - recorder = instrument_mlkit_service(status=200, payload=payload) + instrument_mlkit_service(status=200, payload=payload) with pytest.raises(Exception) as err: mlkit.create_model(MODEL_1) check_error(err, ValueError, 'Operation name format is invalid.') From 64d55b8dd531040c344f7f1605c33d4f55355cbc Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 3 Sep 2019 16:19:44 -0400 Subject: [PATCH 04/14] fixed more lint --- tests/test_mlkit.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 2a79b4565..ea8932078 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -441,10 +441,9 @@ def _op_url(project_id, model_id): 'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id) def test_create_model_immediate_done(self): - recorder = instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE) + instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE) model = mlkit.create_model(MODEL_1) assert model == CREATED_MODEL_1 - assert len(recorder) == 1 def test_create_model_with_get_operation(self): create_recorder = instrument_mlkit_service( From 550e58e53d66979ca8302f6d176c0f7a404ee30d Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 4 Sep 2019 19:45:02 -0400 Subject: [PATCH 05/14] review changes --- firebase_admin/mlkit.py | 53 ++++++++++++++++++++++++++++-- tests/test_mlkit.py | 73 +++++++++++++++++++++-------------------- 2 files changed, 87 insertions(+), 39 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 7aeaa3d96..380d4fdba 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -28,6 +28,7 @@ from firebase_admin import _http_client from firebase_admin import _utils +from firebase_admin import exceptions _MLKIT_ATTRIBUTE = '_mlkit' @@ -58,22 +59,59 @@ def _get_mlkit_service(app): def create_model(model, app=None): + """Creates a model in Firebase ML Kit. + + Args: + model: An mlkit.Model to create. + app: A Firebase app instance (or None to use the default app). + + Returns: + Model: The model that was created in Firebase ML Kit. + """ mlkit_service = _get_mlkit_service(app) return Model.from_dict(mlkit_service.create_model(model)) def get_model(model_id, app=None): + """Gets a model from Firebase ML Kit. + + Args: + model_id: The id of the model to get. + app: A Firebase app instance (or None to use the default app). + + Returns: + Model: The requested model. + """ mlkit_service = _get_mlkit_service(app) return Model.from_dict(mlkit_service.get_model(model_id)) def list_models(list_filter=None, page_size=None, page_token=None, app=None): + """Lists models from Firebase ML Kit. + + Args: + list_filter: a list filter string such as "tags:'tag_1'". None will return all models. + page_size: A number between 1 and 100 inclusive that specifies the maximum + number of models to return per page. None for default. + page_token: A next page token returned from a previous page of results. None + for first page of results. + app: A Firebase app instance (or None to use the default app). + + Returns: + ListModelsPage: A (filtered) list of models. + """ mlkit_service = _get_mlkit_service(app) return ListModelsPage( mlkit_service.list_models, list_filter, page_size, page_token) def delete_model(model_id, app=None): + """Deletes a model from Firebase ML Kit. + + Args: + model_id: The id of the model you wish to delete. + app: A Firebase app instance (or None to use the default app). + """ mlkit_service = _get_mlkit_service(app) mlkit_service.delete_model(model_id) @@ -471,6 +509,9 @@ class _MLKitService(object): PROJECT_URL = 'https://mlkit.googleapis.com/v1beta1/projects/{0}/' OPERATION_URL = 'https://mlkit.googleapis.com/v1beta1/' OPERATION_POLL_DELAY_SECONDS = 30 + MAX_POLLING_ATTEMPTS = 10 + POLL_EXPONENTIAL_BACKOFF_FACTOR = 2 + POLL_BASE_WAIT_TIME_SECONDS = 1 def __init__(self, app): project_id = app.project_id @@ -499,7 +540,7 @@ def handle_operation(self, operation): op_name = operation.get('name') _validate_operation_name(op_name) - while True: + for current_attempt in range(_MLKitService.MAX_POLLING_ATTEMPTS): if operation.get('done'): if operation.get('response'): return operation.get('response') @@ -509,10 +550,16 @@ def handle_operation(self, operation): # A 'done' operation must have either a response or an error. raise ValueError('Operation is malformed.') else: - # We just got this operation wait 30s before getting another + # We just got this operation. Wait before getting another # so we don't exceed the GetOperation maximum request rate. - time.sleep(_MLKitService.OPERATION_POLL_DELAY_SECONDS) + delay_factor = pow( + _MLKitService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt) + wait_time_seconds = delay_factor * _MLKitService.POLL_BASE_WAIT_TIME_SECONDS + time.sleep(wait_time_seconds) operation = self.get_operation(op_name) + # Model validation took too long for the SDK to wait. The backend request + # is still running. Call ListModels with a displayName filter later to find it.s + raise exceptions.DeadlineExceededError('Polling deadline exceeded.') def create_model(self, model): _validate_model(model) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index ea8932078..0fe24d2fe 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -231,46 +231,47 @@ # For validation type errors -def check_error(err, err_type, msg=None): - err_value = err.value - assert isinstance(err_value, err_type) +def check_error(excinfo, err_type, msg=None): + err = excinfo.value + assert isinstance(err, err_type) if msg: - assert str(err_value) == msg + assert str(err) == msg # For errors that are returned in an operation -def check_operation_error(err, code, msg): - err_value = err.value - assert isinstance(err_value, exceptions.FirebaseError) - assert err_value.code == code - assert str(err_value) == msg +def check_operation_error(excinfo, code, msg): + err = excinfo.value + assert isinstance(err, exceptions.FirebaseError) + assert err.code == code + assert str(err) == msg # For rpc errors -def check_firebase_error(err, code, status, msg): - err_value = err.value - assert isinstance(err_value, exceptions.FirebaseError) - assert err_value.code == code - assert err_value.http_response is not None - assert err_value.http_response.status_code == status - assert str(err_value) == msg +def check_firebase_error(excinfo, code, status, msg): + err = excinfo.value + assert isinstance(err, exceptions.FirebaseError) + assert err.code == code + assert err.http_response is not None + assert err.http_response.status_code == status + assert str(err) == msg -def instrument_mlkit_service(app=None, status=200, uri=None, payload=None): +def instrument_mlkit_service(app=None, status=200, operations=False, payload=None): if not app: app = firebase_admin.get_app() mlkit_service = mlkit._get_mlkit_service(app) recorder = [] - if uri is None or uri is 'projects': - mlkit_service._client.session.mount( + if operations: + mlkit_service._operation_client.session.mount( 'https://mlkit.googleapis.com/', testutils.MockAdapter(payload, status, recorder) ) - elif uri is 'operations': - mlkit_service._operation_client.session.mount( + else: + mlkit_service._client.session.mount( 'https://mlkit.googleapis.com/', testutils.MockAdapter(payload, status, recorder) ) + return recorder @@ -440,16 +441,16 @@ def _op_url(project_id, model_id): return BASE_URL + \ 'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id) - def test_create_model_immediate_done(self): + def test_immediate_done(self): instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE) model = mlkit.create_model(MODEL_1) assert model == CREATED_MODEL_1 - def test_create_model_with_get_operation(self): + def test_with_get_operation(self): create_recorder = instrument_mlkit_service( - status=200, uri='projects', payload=OPERATION_NOT_DONE_RESPONSE) + status=200, payload=OPERATION_NOT_DONE_RESPONSE) operation_recorder = instrument_mlkit_service( - status=200, uri='operations', payload=OPERATION_DONE_RESPONSE) + status=200, operations=True, payload=OPERATION_DONE_RESPONSE) model = mlkit.create_model(MODEL_1) assert model == CREATED_MODEL_1 assert len(create_recorder) == 1 @@ -459,22 +460,22 @@ def test_create_model_with_get_operation(self): assert operation_recorder[0].method == 'GET' assert operation_recorder[0].url == TestCreateModel._op_url(PROJECT_ID, MODEL_ID_1) - def test_create_model_operation_error(self): + def test_operation_error(self): instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE) with pytest.raises(Exception) as err: mlkit.create_model(MODEL_1) # The http request succeeded, the operation returned contains a create failure check_operation_error(err, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) - def test_create_model_malformed_operation(self): + def test_malformed_operation(self): instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) with pytest.raises(ValueError) as err: mlkit.create_model(MODEL_1) check_error(err, ValueError, 'Operation is malformed.') - def test_create_model_rpc_error_create(self): + def test_rpc_error_create(self): create_recorder = instrument_mlkit_service( - status=400, uri='projects', payload=ERROR_RESPONSE_BAD_REQUEST) + status=400, payload=ERROR_RESPONSE_BAD_REQUEST) with pytest.raises(Exception) as err: mlkit.create_model(MODEL_1) check_firebase_error( @@ -485,11 +486,11 @@ def test_create_model_rpc_error_create(self): ) assert len(create_recorder) == 1 - def test_create_model_rpc_error_operation(self): + def test_rpc_error_operation(self): create_recorder = instrument_mlkit_service( - status=200, uri='projects', payload=OPERATION_NOT_DONE_RESPONSE) + status=200, payload=OPERATION_NOT_DONE_RESPONSE) operation_recorder = instrument_mlkit_service( - status=400, uri='operations', payload=ERROR_RESPONSE_BAD_REQUEST) + status=400, operations=True, payload=ERROR_RESPONSE_BAD_REQUEST) with pytest.raises(Exception) as err: mlkit.create_model(MODEL_1) check_firebase_error( @@ -511,17 +512,17 @@ def test_create_model_rpc_error_operation(self): 0, None ]) - def test_create_model_not_model(self, model): + def test_not_model(self, model): with pytest.raises(Exception) as err: mlkit.create_model(model) check_error(err, TypeError, 'Model must be an mlkit.Model.') - def test_create_model_missing_display_name(self): + def test_missing_display_name(self): with pytest.raises(Exception) as err: mlkit.create_model(mlkit.Model.from_dict({})) check_error(err, ValueError, 'Model must have a display name.') - def test_create_model_missing_op_name(self): + def test_missing_op_name(self): instrument_mlkit_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE) with pytest.raises(Exception) as err: mlkit.create_model(MODEL_1) @@ -535,7 +536,7 @@ def test_create_model_missing_op_name(self): 'operations/project/123/model/$#@/operation/123', 'operations/project/1234/model/abc/operation/123/extrathing', ]) - def test_create_model_invalid_op_name(self, op_name): + def test_invalid_op_name(self, op_name): payload = json.dumps({'name': op_name}) instrument_mlkit_service(status=200, payload=payload) with pytest.raises(Exception) as err: From ce7f90a6456452f9b408f327b7baa33607e222d7 Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 4 Sep 2019 19:54:22 -0400 Subject: [PATCH 06/14] fix lint --- firebase_admin/mlkit.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 380d4fdba..8981452d4 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -535,6 +535,14 @@ def get_operation(self, op_name): raise _utils.handle_platform_error_from_requests(error) def handle_operation(self, operation): + """Handles long running operations. + + Args: + operation: The operation to handle. + + Returns + dict: A dictionary of the returned model properties. + """ if not isinstance(operation, dict): raise TypeError('Operation must be a dictionary.') op_name = operation.get('name') From b14a7b86e870292d235fd4da1d0efd510f46092c Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 4 Sep 2019 20:06:35 -0400 Subject: [PATCH 07/14] fix more lint --- firebase_admin/mlkit.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 8981452d4..e71e38f52 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -540,8 +540,12 @@ def handle_operation(self, operation): Args: operation: The operation to handle. - Returns + Returns: dict: A dictionary of the returned model properties. + + Raises: + TypeError: if the operation is not a dictionary. + ValueError: If the operation is malformed. """ if not isinstance(operation, dict): raise TypeError('Operation must be a dictionary.') @@ -565,8 +569,6 @@ def handle_operation(self, operation): wait_time_seconds = delay_factor * _MLKitService.POLL_BASE_WAIT_TIME_SECONDS time.sleep(wait_time_seconds) operation = self.get_operation(op_name) - # Model validation took too long for the SDK to wait. The backend request - # is still running. Call ListModels with a displayName filter later to find it.s raise exceptions.DeadlineExceededError('Polling deadline exceeded.') def create_model(self, model): From e9c77fef06e02d221fc0a5761f6c8cdec0bf5fb1 Mon Sep 17 00:00:00 2001 From: ifielker Date: Mon, 9 Sep 2019 19:59:57 -0400 Subject: [PATCH 08/14] review fixes --- firebase_admin/mlkit.py | 91 +++++++++++++---- tests/test_mlkit.py | 215 ++++++++++++++++++++++++++++++---------- 2 files changed, 236 insertions(+), 70 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index e71e38f52..68eb48dc9 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -40,7 +40,8 @@ _RESOURCE_NAME_PATTERN = re.compile( r'^projects/(?P[^/]+)/models/(?P[A-Za-z0-9_-]{1,60})$') _OPERATION_NAME_PATTERN = re.compile( - r'^operations/project/[^/]+/model/[A-Za-z0-9_-]{1,60}/operation/[^/]+$') + r'^operations/project/(?P[^/]+)/model/(?P[A-Za-z0-9_-]{1,60})' + + r'/operation/[^/]+$') def _get_mlkit_service(app): @@ -69,7 +70,7 @@ def create_model(model, app=None): Model: The model that was created in Firebase ML Kit. """ mlkit_service = _get_mlkit_service(app) - return Model.from_dict(mlkit_service.create_model(model)) + return Model.from_dict(mlkit_service.create_model(model), app=app) def get_model(model_id, app=None): @@ -83,7 +84,7 @@ def get_model(model_id, app=None): Model: The requested model. """ mlkit_service = _get_mlkit_service(app) - return Model.from_dict(mlkit_service.get_model(model_id)) + return Model.from_dict(mlkit_service.get_model(model_id), app=app) def list_models(list_filter=None, page_size=None, page_token=None, app=None): @@ -102,7 +103,7 @@ def list_models(list_filter=None, page_size=None, page_token=None, app=None): """ mlkit_service = _get_mlkit_service(app) return ListModelsPage( - mlkit_service.list_models, list_filter, page_size, page_token) + mlkit_service.list_models, list_filter, page_size, page_token, app=app) def delete_model(model_id, app=None): @@ -125,6 +126,7 @@ class Model(object): model_format: A subclass of ModelFormat. (e.g. TFLiteFormat) Specifies the model details. """ def __init__(self, display_name=None, tags=None, model_format=None): + self._app = None # Only needed for wait_for_unlo self._data = {} self._model_format = None @@ -136,7 +138,7 @@ def __init__(self, display_name=None, tags=None, model_format=None): self.model_format = model_format @classmethod - def from_dict(cls, data): + def from_dict(cls, data, app=None): data_copy = dict(data) tflite_format = None tflite_format_data = data_copy.pop('tfliteModel', None) @@ -144,8 +146,18 @@ def from_dict(cls, data): tflite_format = TFLiteFormat.from_dict(tflite_format_data) model = Model(model_format=tflite_format) model._data = data_copy # pylint: disable=protected-access + model._app = app # pylint: disable=protected-access return model + def _update_from_dict(self, data): + data_copy = dict(data) + tflite_format = None + tflite_format_data = data_copy.pop('tfliteModel', None) + if tflite_format_data: + tflite_format = TFLiteFormat.from_dict(tflite_format_data) + self.model_format = tflite_format + self._data = data_copy + def __eq__(self, other): if isinstance(other, self.__class__): # pylint: disable=protected-access @@ -220,6 +232,15 @@ def locked(self): return bool(self._data.get('activeOperations') and len(self._data.get('activeOperations')) > 0) + def wait_for_unlocked(self, max_time_seconds=None): + if self.locked: + mlkit_service = _get_mlkit_service(self._app) + op_name = self._data.get('activeOperations')[0].get('name') + model_dict = mlkit_service.handle_operation( + mlkit_service.get_operation(op_name), + max_time_seconds=max_time_seconds) + self._update_from_dict(model_dict) + @property def model_format(self): return self._model_format @@ -343,17 +364,20 @@ class ListModelsPage(object): ``iterate_all()`` can be used to iterate through all the models in the Firebase project starting from this page. """ - def __init__(self, list_models_func, list_filter, page_size, page_token): + def __init__(self, list_models_func, list_filter, page_size, page_token, app): self._list_models_func = list_models_func self._list_filter = list_filter self._page_size = page_size self._page_token = page_token + self._app = app self._list_response = list_models_func(list_filter, page_size, page_token) @property def models(self): """A list of Models from this page.""" - return [Model.from_dict(model) for model in self._list_response.get('models', [])] + return [ + Model.from_dict(model, app=self._app) for model in self._list_response.get('models', []) + ] @property def list_filter(self): @@ -380,7 +404,8 @@ def get_next_page(self): self._list_models_func, self._list_filter, self._page_size, - self.next_page_token) + self.next_page_token, + self._app) return None def iterate_all(self): @@ -449,9 +474,11 @@ def _validate_model_id(model_id): raise ValueError('Model ID format is invalid.') -def _validate_operation_name(op_name): - if not _OPERATION_NAME_PATTERN.match(op_name): +def _validate_and_parse_operation_name(op_name): + matcher = _OPERATION_NAME_PATTERN.match(op_name) + if not matcher: raise ValueError('Operation name format is invalid.') + return matcher.group('project_id'), matcher.group('model_id') def _validate_display_name(display_name): @@ -476,11 +503,13 @@ def _validate_gcs_tflite_uri(uri): raise ValueError('GCS TFLite URI format is invalid.') return uri + def _validate_model_format(model_format): if not isinstance(model_format, ModelFormat): raise TypeError('Model format must be a ModelFormat object.') return model_format + def _validate_list_filter(list_filter): if list_filter is not None: if not isinstance(list_filter, six.string_types): @@ -508,10 +537,8 @@ class _MLKitService(object): PROJECT_URL = 'https://mlkit.googleapis.com/v1beta1/projects/{0}/' OPERATION_URL = 'https://mlkit.googleapis.com/v1beta1/' - OPERATION_POLL_DELAY_SECONDS = 30 - MAX_POLLING_ATTEMPTS = 10 - POLL_EXPONENTIAL_BACKOFF_FACTOR = 2 - POLL_BASE_WAIT_TIME_SECONDS = 1 + POLL_EXPONENTIAL_BACKOFF_FACTOR = 1.5 + POLL_BASE_WAIT_TIME_SECONDS = 3 def __init__(self, app): project_id = app.project_id @@ -528,17 +555,24 @@ def __init__(self, app): base_url=_MLKitService.OPERATION_URL) def get_operation(self, op_name): - _validate_operation_name(op_name) + _validate_and_parse_operation_name(op_name) try: return self._operation_client.body('get', url=op_name) except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) - def handle_operation(self, operation): + def handle_operation(self, operation, max_polling_attempts=None, max_time_seconds=None, + always_return_model=False): """Handles long running operations. Args: operation: The operation to handle. + max_polling_attempts: The maximum number of polling requests to make. + (None for no limit) + max_time_seconds: The maximum seconds to try polling for operation complete. + (None for no limit) + always_return_model: If true, returns a locked Model instead of raising deadline + exceeded exceptions. Returns: dict: A dictionary of the returned model properties. @@ -550,9 +584,13 @@ def handle_operation(self, operation): if not isinstance(operation, dict): raise TypeError('Operation must be a dictionary.') op_name = operation.get('name') - _validate_operation_name(op_name) + _, model_id = _validate_and_parse_operation_name(op_name) - for current_attempt in range(_MLKitService.MAX_POLLING_ATTEMPTS): + current_attempt = 0 + start_time = datetime.datetime.now() + stop_time = (None if max_time_seconds is None else + start_time + datetime.timedelta(seconds=max_time_seconds)) + while True: if operation.get('done'): if operation.get('response'): return operation.get('response') @@ -564,18 +602,31 @@ def handle_operation(self, operation): else: # We just got this operation. Wait before getting another # so we don't exceed the GetOperation maximum request rate. + if max_polling_attempts is not None and current_attempt >= max_polling_attempts: + if always_return_model: + return get_model(model_id).as_dict() + raise exceptions.DeadlineExceededError('Polling max attempts exceeded.') delay_factor = pow( _MLKitService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt) wait_time_seconds = delay_factor * _MLKitService.POLL_BASE_WAIT_TIME_SECONDS + after_sleep_time = (datetime.datetime.now() + + datetime.timedelta(seconds=wait_time_seconds)) + if stop_time is not None and after_sleep_time > stop_time: + if always_return_model: + return get_model(model_id).as_dict() + raise exceptions.DeadlineExceededError('Polling max time exceeded.') time.sleep(wait_time_seconds) operation = self.get_operation(op_name) - raise exceptions.DeadlineExceededError('Polling deadline exceeded.') + current_attempt += 1 + def create_model(self, model): _validate_model(model) try: return self.handle_operation( - self._client.body('post', url='models', json=model.as_dict())) + self._client.body('post', url='models', json=model.as_dict()), + max_polling_attempts=1, + always_return_model=True) except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 0fe24d2fe..2c98fc0db 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -29,16 +29,24 @@ PAGE_TOKEN = 'pageToken' NEXT_PAGE_TOKEN = 'nextPageToken' CREATE_TIME_SECONDS = 1566426374 +CREATE_TIME_SECONDS_2 = 1566426385 CREATE_TIME_JSON = { 'seconds': CREATE_TIME_SECONDS } CREATE_TIME_DATETIME = datetime.datetime.fromtimestamp(CREATE_TIME_SECONDS) +CREATE_TIME_JSON_2 = { + 'seconds': CREATE_TIME_SECONDS_2 +} UPDATE_TIME_SECONDS = 1566426678 +UPDATE_TIME_SECONDS_2 = 1566426691 UPDATE_TIME_JSON = { 'seconds': UPDATE_TIME_SECONDS } UPDATE_TIME_DATETIME = datetime.datetime.fromtimestamp(UPDATE_TIME_SECONDS) +UPDATE_TIME_JSON_2 = { + 'seconds': UPDATE_TIME_SECONDS_2 +} ETAG = '33a64df551425fcc55e4d42a148795d9f25f89d4' MODEL_HASH = '987987a98b98798d098098e09809fc0893897' TAG_1 = 'Tag1' @@ -86,8 +94,9 @@ } } +OPERATION_NAME_1 = 'operations/project/{0}/model/{1}/operation/123'.format(PROJECT_ID, MODEL_ID_1) OPERATION_NOT_DONE_JSON_1 = { - 'name': 'operations/project/{0}/model/{1}/operation/123'.format(PROJECT_ID, MODEL_ID_1), + 'name': OPERATION_NAME_1, 'metadata': { '@type': 'type.googleapis.com/google.firebase.ml.v1beta1.ModelOperationMetadata', 'name': 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1), @@ -125,14 +134,32 @@ } CREATED_MODEL_1 = mlkit.Model.from_dict(CREATED_MODEL_JSON_1) +LOCKED_MODEL_JSON_1 = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_1, + 'createTime': CREATE_TIME_JSON, + 'updateTime': UPDATE_TIME_JSON, + 'tags': TAGS, + 'activeOperations': [OPERATION_NOT_DONE_JSON_1] +} + +LOCKED_MODEL_JSON_2 = { + 'name': MODEL_NAME_1, + 'displayName': DISPLAY_NAME_2, + 'createTime': CREATE_TIME_JSON_2, + 'updateTime': UPDATE_TIME_JSON_2, + 'tags': TAGS_2, + 'activeOperations': [OPERATION_NOT_DONE_JSON_1] +} + OPERATION_DONE_MODEL_JSON_1 = { - 'name': 'operations/project/{0}/model/{1}/operation/123'.format(PROJECT_ID, MODEL_ID_1), + 'name': OPERATION_NAME_1, 'done': True, 'response': CREATED_MODEL_JSON_1 } OPERATION_MALFORMED_JSON_1 = { - 'name': 'operations/project/{0}/model/{1}/operation/123'.format(PROJECT_ID, MODEL_ID_1), + 'name': OPERATION_NAME_1, 'done': True, # if done is true then either response or error should be populated } @@ -145,7 +172,7 @@ OPERATION_ERROR_MSG = "Invalid argument" OPERATION_ERROR_EXPECTED_STATUS = 'INVALID_ARGUMENT' OPERATION_ERROR_JSON_1 = { - 'name': 'operations/project/{0}/model/{1}/operation/123'.format(PROJECT_ID, MODEL_ID_1), + 'name': OPERATION_NAME_1, 'done': True, 'error': { 'code': OPERATION_ERROR_CODE, @@ -175,14 +202,22 @@ 'tags': TAGS, 'tfliteModel': TFLITE_FORMAT_JSON } +FULL_MODEL_PUBLISHED = mlkit.Model.from_dict(FULL_MODEL_PUBLISHED_JSON) +OPERATION_DONE_FULL_MODEL_PUBLISHED_JSON = { + 'name': OPERATION_NAME_1, + 'done': True, + 'response': FULL_MODEL_PUBLISHED_JSON +} EMPTY_RESPONSE = json.dumps({}) OPERATION_NOT_DONE_RESPONSE = json.dumps(OPERATION_NOT_DONE_JSON_1) OPERATION_DONE_RESPONSE = json.dumps(OPERATION_DONE_MODEL_JSON_1) +OPERATION_DONE_PUBLISHED_RESPONSE = json.dumps(OPERATION_DONE_FULL_MODEL_PUBLISHED_JSON) OPERATION_ERROR_RESPONSE = json.dumps(OPERATION_ERROR_JSON_1) OPERATION_MALFORMED_RESPONSE = json.dumps(OPERATION_MALFORMED_JSON_1) OPERATION_MISSING_NAME_RESPONSE = json.dumps(OPERATION_MISSING_NAME) DEFAULT_GET_RESPONSE = json.dumps(MODEL_JSON_1) +LOCKED_MODEL_2_RESPONSE = json.dumps(LOCKED_MODEL_JSON_2) NO_MODELS_LIST_RESPONSE = json.dumps({}) DEFAULT_LIST_RESPONSE = json.dumps({ 'models': [MODEL_JSON_1, MODEL_JSON_2], @@ -256,27 +291,58 @@ def check_firebase_error(excinfo, code, status, msg): assert str(err) == msg -def instrument_mlkit_service(app=None, status=200, operations=False, payload=None): +def instrument_mlkit_service(status=200, payload=None, operations=False, app=None): if not app: app = firebase_admin.get_app() mlkit_service = mlkit._get_mlkit_service(app) recorder = [] if operations: mlkit_service._operation_client.session.mount( - 'https://mlkit.googleapis.com/', + 'https://mlkit.googleapis.com/v1beta1/', testutils.MockAdapter(payload, status, recorder) ) else: mlkit_service._client.session.mount( - 'https://mlkit.googleapis.com/', + 'https://mlkit.googleapis.com/v1beta1/', testutils.MockAdapter(payload, status, recorder) ) + return recorder + +def instrument_mlkit_service_multi(statuses, payloads, operations=False, app=None): + if not app: + app = firebase_admin.get_app() + mlkit_service = mlkit._get_mlkit_service(app) + recorder = [] + if operations: + mlkit_service._operation_client.session.mount( + 'https://mlkit.googleapis.com/v1beta1/', + testutils.MockMultiRequestAdapter(payloads, statuses, recorder) + ) + else: + mlkit_service._client.session.mount( + 'https://mlkit.googleapis.com/v1beta1/', + testutils.MockMultiRequestAdapter(payloads, statuses, recorder) + ) return recorder class TestModel(object): """Tests mlkit.Model class.""" + @classmethod + def setup_class(cls): + cred = testutils.MockCredential() + firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) + mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test + + @classmethod + def teardown_class(cls): + testutils.cleanup_apps() + + @staticmethod + def _op_url(project_id, model_id): + return BASE_URL + \ + 'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id) def test_model_success_err_state_lro(self): model = mlkit.Model.from_dict(FULL_MODEL_ERR_STATE_LRO_JSON) @@ -361,9 +427,9 @@ def test_model_format_setters(self): (12345, TypeError) ]) def test_model_display_name_validation_errors(self, display_name, exc_type): - with pytest.raises(exc_type) as err: + with pytest.raises(exc_type) as excinfo: mlkit.Model(display_name=display_name) - check_error(err, exc_type) + check_error(excinfo, exc_type) @pytest.mark.parametrize('tags, exc_type, error_message', [ ('tag1', TypeError, 'Tags must be a list of strings.'), @@ -375,9 +441,9 @@ def test_model_display_name_validation_errors(self, display_name, exc_type): 'tag2'], ValueError, 'Tag format is invalid.') ]) def test_model_tags_validation_errors(self, tags, exc_type, error_message): - with pytest.raises(exc_type) as err: + with pytest.raises(exc_type) as excinfo: mlkit.Model(tags=tags) - check_error(err, exc_type, error_message) + check_error(excinfo, exc_type, error_message) @pytest.mark.parametrize('model_format', [ 123, @@ -387,9 +453,9 @@ def test_model_tags_validation_errors(self, tags, exc_type, error_message): True ]) def test_model_format_validation_errors(self, model_format): - with pytest.raises(TypeError) as err: + with pytest.raises(TypeError) as excinfo: mlkit.Model(model_format=model_format) - check_error(err, TypeError, 'Model format must be a ModelFormat object.') + check_error(excinfo, TypeError, 'Model format must be a ModelFormat object.') @pytest.mark.parametrize('model_source', [ 123, @@ -399,9 +465,9 @@ def test_model_format_validation_errors(self, model_format): True ]) def test_model_source_validation_errors(self, model_source): - with pytest.raises(TypeError) as err: + with pytest.raises(TypeError) as excinfo: mlkit.TFLiteFormat(model_source=model_source) - check_error(err, TypeError, 'Model source must be a TFLiteModelSource object.') + check_error(excinfo, TypeError, 'Model source must be a TFLiteModelSource object.') @pytest.mark.parametrize('uri, exc_type', [ (123, TypeError), @@ -415,9 +481,34 @@ def test_model_source_validation_errors(self, model_source): ValueError) ]) def test_gcs_tflite_source_validation_errors(self, uri, exc_type): - with pytest.raises(exc_type) as err: + with pytest.raises(exc_type) as excinfo: mlkit.TFLiteGCSModelSource(gcs_tflite_uri=uri) - check_error(err, exc_type) + check_error(excinfo, exc_type) + + def test_wait_for_unlocked_not_locked(self): + model = mlkit.Model(display_name="not_locked") + model.wait_for_unlocked() + + def test_wait_for_unlocked(self): + recorder = instrument_mlkit_service(status=200, + operations=True, + payload=OPERATION_DONE_PUBLISHED_RESPONSE) + model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_1) + model.wait_for_unlocked() + assert model == FULL_MODEL_PUBLISHED + assert len(recorder) == 1 + assert recorder[0].method == 'GET' + assert recorder[0].url == TestModel._op_url(PROJECT_ID, MODEL_ID_1) + + def test_wait_for_unlocked_timeout(self): + recorder = instrument_mlkit_service(status=200, + operations=True, + payload=OPERATION_NOT_DONE_RESPONSE) + mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 5 # longer for timeout + model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_1) + with pytest.raises(Exception) as excinfo: + model.wait_for_unlocked(max_time_seconds=3) + check_error(excinfo, exceptions.DeadlineExceededError, 'Polling max time exceeded.') class TestCreateModel(object): @@ -426,7 +517,7 @@ class TestCreateModel(object): def setup_class(cls): cred = testutils.MockCredential() firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID}) - mlkit._MLKitService.OPERATION_POLL_DELAY_SECONDS = 0.1 # shorter for test + mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test @classmethod def teardown_class(cls): @@ -441,6 +532,10 @@ def _op_url(project_id, model_id): return BASE_URL + \ 'operations/project/{0}/model/{1}/operation/123'.format(project_id, model_id) + @staticmethod + def _get_url(project_id, model_id): + return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) + def test_immediate_done(self): instrument_mlkit_service(status=200, payload=OPERATION_DONE_RESPONSE) model = mlkit.create_model(MODEL_1) @@ -460,26 +555,46 @@ def test_with_get_operation(self): assert operation_recorder[0].method == 'GET' assert operation_recorder[0].url == TestCreateModel._op_url(PROJECT_ID, MODEL_ID_1) + def test_with_get_returns_locked(self): + recorder = instrument_mlkit_service_multi( + statuses=[200, 200], + payloads=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]) + operation_recorder = instrument_mlkit_service( + status=200, operations=True, payload=OPERATION_NOT_DONE_RESPONSE) + + expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) + model = mlkit.create_model(MODEL_1) + + assert model == expected_model + assert len(recorder) == 2 + assert len(operation_recorder) == 1 + assert recorder[0].method == 'POST' + assert recorder[0].url == TestCreateModel._url(PROJECT_ID) + assert operation_recorder[0].method == 'GET' + assert operation_recorder[0].url == TestCreateModel._op_url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].method == 'GET' + assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1) + def test_operation_error(self): instrument_mlkit_service(status=200, payload=OPERATION_ERROR_RESPONSE) - with pytest.raises(Exception) as err: + with pytest.raises(Exception) as excinfo: mlkit.create_model(MODEL_1) # The http request succeeded, the operation returned contains a create failure - check_operation_error(err, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) + check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) def test_malformed_operation(self): instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) - with pytest.raises(ValueError) as err: + with pytest.raises(ValueError) as excinfo: mlkit.create_model(MODEL_1) - check_error(err, ValueError, 'Operation is malformed.') + check_error(excinfo, ValueError, 'Operation is malformed.') def test_rpc_error_create(self): create_recorder = instrument_mlkit_service( status=400, payload=ERROR_RESPONSE_BAD_REQUEST) - with pytest.raises(Exception) as err: + with pytest.raises(Exception) as excinfo: mlkit.create_model(MODEL_1) check_firebase_error( - err, + excinfo, ERROR_STATUS_BAD_REQUEST, ERROR_CODE_BAD_REQUEST, ERROR_MSG_BAD_REQUEST @@ -491,10 +606,10 @@ def test_rpc_error_operation(self): status=200, payload=OPERATION_NOT_DONE_RESPONSE) operation_recorder = instrument_mlkit_service( status=400, operations=True, payload=ERROR_RESPONSE_BAD_REQUEST) - with pytest.raises(Exception) as err: + with pytest.raises(Exception) as excinfo: mlkit.create_model(MODEL_1) check_firebase_error( - err, + excinfo, ERROR_STATUS_BAD_REQUEST, ERROR_CODE_BAD_REQUEST, ERROR_MSG_BAD_REQUEST @@ -513,20 +628,20 @@ def test_rpc_error_operation(self): None ]) def test_not_model(self, model): - with pytest.raises(Exception) as err: + with pytest.raises(Exception) as excinfo: mlkit.create_model(model) - check_error(err, TypeError, 'Model must be an mlkit.Model.') + check_error(excinfo, TypeError, 'Model must be an mlkit.Model.') def test_missing_display_name(self): - with pytest.raises(Exception) as err: + with pytest.raises(Exception) as excinfo: mlkit.create_model(mlkit.Model.from_dict({})) - check_error(err, ValueError, 'Model must have a display name.') + check_error(excinfo, ValueError, 'Model must have a display name.') def test_missing_op_name(self): instrument_mlkit_service(status=200, payload=OPERATION_MISSING_NAME_RESPONSE) - with pytest.raises(Exception) as err: + with pytest.raises(Exception) as excinfo: mlkit.create_model(MODEL_1) - check_error(err, TypeError) + check_error(excinfo, TypeError) @pytest.mark.parametrize('op_name', [ 'abc', @@ -539,9 +654,9 @@ def test_missing_op_name(self): def test_invalid_op_name(self, op_name): payload = json.dumps({'name': op_name}) instrument_mlkit_service(status=200, payload=payload) - with pytest.raises(Exception) as err: + with pytest.raises(Exception) as excinfo: mlkit.create_model(MODEL_1) - check_error(err, ValueError, 'Operation name format is invalid.') + check_error(excinfo, ValueError, 'Operation name format is invalid.') class TestGetModel(object): @@ -571,16 +686,16 @@ def test_get_model(self): @pytest.mark.parametrize('model_id, exc_type', invalid_model_id_args) def test_get_model_validation_errors(self, model_id, exc_type): - with pytest.raises(exc_type) as err: + with pytest.raises(exc_type) as excinfo: mlkit.get_model(model_id) - check_error(err, exc_type) + check_error(excinfo, exc_type) def test_get_model_error(self): recorder = instrument_mlkit_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) - with pytest.raises(exceptions.NotFoundError) as err: + with pytest.raises(exceptions.NotFoundError) as excinfo: mlkit.get_model(MODEL_ID_1) check_firebase_error( - err, + excinfo, ERROR_STATUS_NOT_FOUND, ERROR_CODE_NOT_FOUND, ERROR_MSG_NOT_FOUND @@ -621,16 +736,16 @@ def test_delete_model(self): @pytest.mark.parametrize('model_id, exc_type', invalid_model_id_args) def test_delete_model_validation_errors(self, model_id, exc_type): - with pytest.raises(exc_type) as err: + with pytest.raises(exc_type) as excinfo: mlkit.delete_model(model_id) - check_error(err, exc_type) + check_error(excinfo, exc_type) def test_delete_model_error(self): recorder = instrument_mlkit_service(status=404, payload=ERROR_RESPONSE_NOT_FOUND) - with pytest.raises(exceptions.NotFoundError) as err: + with pytest.raises(exceptions.NotFoundError) as excinfo: mlkit.delete_model(MODEL_ID_1) check_firebase_error( - err, + excinfo, ERROR_STATUS_NOT_FOUND, ERROR_CODE_NOT_FOUND, ERROR_MSG_NOT_FOUND @@ -702,9 +817,9 @@ def test_list_models_with_all_args(self): @pytest.mark.parametrize('list_filter', invalid_string_or_none_args) def test_list_models_list_filter_validation(self, list_filter): - with pytest.raises(TypeError) as err: + with pytest.raises(TypeError) as excinfo: mlkit.list_models(list_filter=list_filter) - check_error(err, TypeError, 'List filter must be a string or None.') + check_error(excinfo, TypeError, 'List filter must be a string or None.') @pytest.mark.parametrize('page_size, exc_type, error_message', [ ('abc', TypeError, 'Page size must be a number or None.'), @@ -717,22 +832,22 @@ def test_list_models_list_filter_validation(self, list_filter): (mlkit._MAX_PAGE_SIZE + 1, ValueError, PAGE_SIZE_VALUE_ERROR_MSG) ]) def test_list_models_page_size_validation(self, page_size, exc_type, error_message): - with pytest.raises(exc_type) as err: + with pytest.raises(exc_type) as excinfo: mlkit.list_models(page_size=page_size) - check_error(err, exc_type, error_message) + check_error(excinfo, exc_type, error_message) @pytest.mark.parametrize('page_token', invalid_string_or_none_args) def test_list_models_page_token_validation(self, page_token): - with pytest.raises(TypeError) as err: + with pytest.raises(TypeError) as excinfo: mlkit.list_models(page_token=page_token) - check_error(err, TypeError, 'Page token must be a string or None.') + check_error(excinfo, TypeError, 'Page token must be a string or None.') def test_list_models_error(self): recorder = instrument_mlkit_service(status=400, payload=ERROR_RESPONSE_BAD_REQUEST) - with pytest.raises(exceptions.InvalidArgumentError) as err: + with pytest.raises(exceptions.InvalidArgumentError) as excinfo: mlkit.list_models() check_firebase_error( - err, + excinfo, ERROR_STATUS_BAD_REQUEST, ERROR_CODE_BAD_REQUEST, ERROR_MSG_BAD_REQUEST From 3dae68ea792763b85d0791bb71cf900e66a38f41 Mon Sep 17 00:00:00 2001 From: ifielker Date: Mon, 9 Sep 2019 20:13:14 -0400 Subject: [PATCH 09/14] fix lint --- tests/test_mlkit.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 2c98fc0db..4295b3c9f 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -501,9 +501,7 @@ def test_wait_for_unlocked(self): assert recorder[0].url == TestModel._op_url(PROJECT_ID, MODEL_ID_1) def test_wait_for_unlocked_timeout(self): - recorder = instrument_mlkit_service(status=200, - operations=True, - payload=OPERATION_NOT_DONE_RESPONSE) + instrument_mlkit_service(status=200, operations=True, payload=OPERATION_NOT_DONE_RESPONSE) mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 5 # longer for timeout model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_1) with pytest.raises(Exception) as excinfo: From 5e634068a00cf8b5c1eae251c270369350b38182 Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 10 Sep 2019 16:43:04 -0400 Subject: [PATCH 10/14] review comments --- firebase_admin/mlkit.py | 88 +++++++++++++++++++++-------------------- tests/test_mlkit.py | 37 +++++------------ 2 files changed, 56 insertions(+), 69 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 68eb48dc9..ba70c6a3d 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -150,13 +150,9 @@ def from_dict(cls, data, app=None): return model def _update_from_dict(self, data): - data_copy = dict(data) - tflite_format = None - tflite_format_data = data_copy.pop('tfliteModel', None) - if tflite_format_data: - tflite_format = TFLiteFormat.from_dict(tflite_format_data) - self.model_format = tflite_format - self._data = data_copy + copy = Model.from_dict(data) + self.model_format = copy.model_format + self._data = copy._data # pylint: disable=protected-access def __eq__(self, other): if isinstance(other, self.__class__): @@ -233,13 +229,15 @@ def locked(self): len(self._data.get('activeOperations')) > 0) def wait_for_unlocked(self, max_time_seconds=None): - if self.locked: - mlkit_service = _get_mlkit_service(self._app) - op_name = self._data.get('activeOperations')[0].get('name') - model_dict = mlkit_service.handle_operation( - mlkit_service.get_operation(op_name), - max_time_seconds=max_time_seconds) - self._update_from_dict(model_dict) + if not self.locked: + return + + mlkit_service = _get_mlkit_service(self._app) + op_name = self._data.get('activeOperations')[0].get('name') + model_dict = mlkit_service.handle_operation( + mlkit_service.get_operation(op_name), + max_time_seconds=max_time_seconds) + self._update_from_dict(model_dict) @property def model_format(self): @@ -561,6 +559,19 @@ def get_operation(self, op_name): except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) + def _exponential_backoff(self, max_polling_attempts, current_attempt, stop_time): + """Sleeps for the appropriate amount of time. Or thows deadline exceeded.""" + if max_polling_attempts is not None and current_attempt >= max_polling_attempts: + raise exceptions.DeadlineExceededError('Polling max attempts exceeded.') + delay_factor = pow( + _MLKitService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt) + wait_time_seconds = delay_factor * _MLKitService.POLL_BASE_WAIT_TIME_SECONDS + after_sleep_time = (datetime.datetime.now() + + datetime.timedelta(seconds=wait_time_seconds)) + if stop_time is not None and after_sleep_time > stop_time: + raise exceptions.DeadlineExceededError('Polling max time exceeded.') + time.sleep(wait_time_seconds) + def handle_operation(self, operation, max_polling_attempts=None, max_time_seconds=None, always_return_model=False): """Handles long running operations. @@ -580,6 +591,7 @@ def handle_operation(self, operation, max_polling_attempts=None, max_time_second Raises: TypeError: if the operation is not a dictionary. ValueError: If the operation is malformed. + err: If the operation exceeds polling attempts or stop_time """ if not isinstance(operation, dict): raise TypeError('Operation must be a dictionary.') @@ -590,34 +602,26 @@ def handle_operation(self, operation, max_polling_attempts=None, max_time_second start_time = datetime.datetime.now() stop_time = (None if max_time_seconds is None else start_time + datetime.timedelta(seconds=max_time_seconds)) - while True: - if operation.get('done'): - if operation.get('response'): - return operation.get('response') - elif operation.get('error'): - raise _utils.handle_operation_error(operation.get('error')) - else: - # A 'done' operation must have either a response or an error. - raise ValueError('Operation is malformed.') - else: - # We just got this operation. Wait before getting another - # so we don't exceed the GetOperation maximum request rate. - if max_polling_attempts is not None and current_attempt >= max_polling_attempts: - if always_return_model: - return get_model(model_id).as_dict() - raise exceptions.DeadlineExceededError('Polling max attempts exceeded.') - delay_factor = pow( - _MLKitService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt) - wait_time_seconds = delay_factor * _MLKitService.POLL_BASE_WAIT_TIME_SECONDS - after_sleep_time = (datetime.datetime.now() + - datetime.timedelta(seconds=wait_time_seconds)) - if stop_time is not None and after_sleep_time > stop_time: - if always_return_model: - return get_model(model_id).as_dict() - raise exceptions.DeadlineExceededError('Polling max time exceeded.') - time.sleep(wait_time_seconds) - operation = self.get_operation(op_name) - current_attempt += 1 + while not operation.get('done'): + # We just got this operation. Wait before getting another + # so we don't exceed the GetOperation maximum request rate. + try: + self._exponential_backoff(max_polling_attempts, current_attempt, stop_time) + except exceptions.DeadlineExceededError as err: + if always_return_model: + return get_model(model_id).as_dict() + raise err + operation = self.get_operation(op_name) + current_attempt += 1 + + if operation.get('response'): + return operation.get('response') + elif operation.get('error'): + raise _utils.handle_operation_error(operation.get('error')) + else: + # A 'done' operation must have either a response or an error. + raise ValueError('Operation is malformed.') + def create_model(self, model): diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 4295b3c9f..35ad0ded2 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -296,34 +296,17 @@ def instrument_mlkit_service(status=200, payload=None, operations=False, app=Non app = firebase_admin.get_app() mlkit_service = mlkit._get_mlkit_service(app) recorder = [] - if operations: - mlkit_service._operation_client.session.mount( - 'https://mlkit.googleapis.com/v1beta1/', - testutils.MockAdapter(payload, status, recorder) - ) - else: - mlkit_service._client.session.mount( - 'https://mlkit.googleapis.com/v1beta1/', - testutils.MockAdapter(payload, status, recorder) - ) - return recorder + session_url = 'https://mlkit.googleapis.com/v1beta1/' + if isinstance(status, list): + adapter = testutils.MockMultiRequestAdapter(payload, status, recorder) + else: + adapter = testutils.MockAdapter(payload, status, recorder) -def instrument_mlkit_service_multi(statuses, payloads, operations=False, app=None): - if not app: - app = firebase_admin.get_app() - mlkit_service = mlkit._get_mlkit_service(app) - recorder = [] if operations: - mlkit_service._operation_client.session.mount( - 'https://mlkit.googleapis.com/v1beta1/', - testutils.MockMultiRequestAdapter(payloads, statuses, recorder) - ) + mlkit_service._operation_client.session.mount(session_url, adapter) else: - mlkit_service._client.session.mount( - 'https://mlkit.googleapis.com/v1beta1/', - testutils.MockMultiRequestAdapter(payloads, statuses, recorder) - ) + mlkit_service._client.session.mount(session_url, adapter) return recorder @@ -554,9 +537,9 @@ def test_with_get_operation(self): assert operation_recorder[0].url == TestCreateModel._op_url(PROJECT_ID, MODEL_ID_1) def test_with_get_returns_locked(self): - recorder = instrument_mlkit_service_multi( - statuses=[200, 200], - payloads=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]) + recorder = instrument_mlkit_service( + status=[200, 200], + payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]) operation_recorder = instrument_mlkit_service( status=200, operations=True, payload=OPERATION_NOT_DONE_RESPONSE) From 4faa883fd1e642540f57192ac0ae2e64c9d77b5f Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 10 Sep 2019 16:57:47 -0400 Subject: [PATCH 11/14] fixed lint --- tests/test_mlkit.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 35ad0ded2..323356d08 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -299,14 +299,16 @@ def instrument_mlkit_service(status=200, payload=None, operations=False, app=Non session_url = 'https://mlkit.googleapis.com/v1beta1/' if isinstance(status, list): - adapter = testutils.MockMultiRequestAdapter(payload, status, recorder) + adapter = testutils.MockMultiRequestAdapter else: - adapter = testutils.MockAdapter(payload, status, recorder) + adapter = testutils.MockAdapter if operations: - mlkit_service._operation_client.session.mount(session_url, adapter) + mlkit_service._operation_client.session.mount( + session_url, adapter(payload, status, recorder)) else: - mlkit_service._client.session.mount(session_url, adapter) + mlkit_service._client.session.mount( + session_url, adapter(payload, status, recorder)) return recorder From fce637879145d56cac25c06005dd7e9375487bbc Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 10 Sep 2019 19:39:14 -0400 Subject: [PATCH 12/14] review changes --- firebase_admin/mlkit.py | 62 ++++++++++++++++++++++++----------------- tests/test_mlkit.py | 46 ++++-------------------------- 2 files changed, 43 insertions(+), 65 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index ba70c6a3d..5f05b8959 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -229,14 +229,24 @@ def locked(self): len(self._data.get('activeOperations')) > 0) def wait_for_unlocked(self, max_time_seconds=None): + """Waits for the model to be unlocked. (All active operations complete) + + Args: + max_time_seconds: The maximum number of seconds to wait for the model to unlock. + (None for no limit) + + Raises: + exceptions.DeadlineExceeded: If max_time_seconds passed and the model is still locked. + """ if not self.locked: return - mlkit_service = _get_mlkit_service(self._app) op_name = self._data.get('activeOperations')[0].get('name') model_dict = mlkit_service.handle_operation( mlkit_service.get_operation(op_name), - max_time_seconds=max_time_seconds) + polling=True, + max_time_seconds=max_time_seconds, + always_return_model=False) self._update_from_dict(model_dict) @property @@ -559,27 +569,28 @@ def get_operation(self, op_name): except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) - def _exponential_backoff(self, max_polling_attempts, current_attempt, stop_time): + def _exponential_backoff(self, current_attempt, stop_time): """Sleeps for the appropriate amount of time. Or thows deadline exceeded.""" - if max_polling_attempts is not None and current_attempt >= max_polling_attempts: - raise exceptions.DeadlineExceededError('Polling max attempts exceeded.') delay_factor = pow( _MLKitService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt) wait_time_seconds = delay_factor * _MLKitService.POLL_BASE_WAIT_TIME_SECONDS - after_sleep_time = (datetime.datetime.now() + - datetime.timedelta(seconds=wait_time_seconds)) - if stop_time is not None and after_sleep_time > stop_time: - raise exceptions.DeadlineExceededError('Polling max time exceeded.') + + if stop_time is not None: + max_seconds_left = (stop_time - datetime.datetime.now()).total_seconds() + if max_seconds_left < 1: # allow a bit of time for rpc + raise exceptions.DeadlineExceededError('Polling max time exceeded.') + else: + wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1) time.sleep(wait_time_seconds) - def handle_operation(self, operation, max_polling_attempts=None, max_time_seconds=None, - always_return_model=False): + + def handle_operation(self, operation, polling=False, max_time_seconds=None, + always_return_model=True): """Handles long running operations. Args: operation: The operation to handle. - max_polling_attempts: The maximum number of polling requests to make. - (None for no limit) + polling: Should we allow polling for the operation to complete. max_time_seconds: The maximum seconds to try polling for operation complete. (None for no limit) always_return_model: If true, returns a locked Model instead of raising deadline @@ -602,11 +613,11 @@ def handle_operation(self, operation, max_polling_attempts=None, max_time_second start_time = datetime.datetime.now() stop_time = (None if max_time_seconds is None else start_time + datetime.timedelta(seconds=max_time_seconds)) - while not operation.get('done'): + while polling and not operation.get('done'): # We just got this operation. Wait before getting another # so we don't exceed the GetOperation maximum request rate. try: - self._exponential_backoff(max_polling_attempts, current_attempt, stop_time) + self._exponential_backoff(current_attempt, stop_time) except exceptions.DeadlineExceededError as err: if always_return_model: return get_model(model_id).as_dict() @@ -614,13 +625,16 @@ def handle_operation(self, operation, max_polling_attempts=None, max_time_second operation = self.get_operation(op_name) current_attempt += 1 - if operation.get('response'): - return operation.get('response') - elif operation.get('error'): - raise _utils.handle_operation_error(operation.get('error')) - else: - # A 'done' operation must have either a response or an error. - raise ValueError('Operation is malformed.') + if operation.get('done'): + if operation.get('response'): + return operation.get('response') + elif operation.get('error'): + raise _utils.handle_operation_error(operation.get('error')) + else: + # A 'done' operation must have either a response or an error. + raise ValueError('Operation is malformed.') + elif always_return_model: + return get_model(model_id).as_dict() @@ -628,9 +642,7 @@ def create_model(self, model): _validate_model(model) try: return self.handle_operation( - self._client.body('post', url='models', json=model.as_dict()), - max_polling_attempts=1, - always_return_model=True) + self._client.body('post', url='models', json=model.as_dict())) except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 323356d08..22953d95d 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -486,12 +486,14 @@ def test_wait_for_unlocked(self): assert recorder[0].url == TestModel._op_url(PROJECT_ID, MODEL_ID_1) def test_wait_for_unlocked_timeout(self): - instrument_mlkit_service(status=200, operations=True, payload=OPERATION_NOT_DONE_RESPONSE) - mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 5 # longer for timeout + recorder = instrument_mlkit_service( + status=200, operations=True, payload=OPERATION_NOT_DONE_RESPONSE) + mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 3 # longer so timeout applies immediately model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_1) with pytest.raises(Exception) as excinfo: - model.wait_for_unlocked(max_time_seconds=3) + model.wait_for_unlocked(max_time_seconds=0.1) check_error(excinfo, exceptions.DeadlineExceededError, 'Polling max time exceeded.') + assert len(recorder) == 1 class TestCreateModel(object): @@ -524,37 +526,17 @@ def test_immediate_done(self): model = mlkit.create_model(MODEL_1) assert model == CREATED_MODEL_1 - def test_with_get_operation(self): - create_recorder = instrument_mlkit_service( - status=200, payload=OPERATION_NOT_DONE_RESPONSE) - operation_recorder = instrument_mlkit_service( - status=200, operations=True, payload=OPERATION_DONE_RESPONSE) - model = mlkit.create_model(MODEL_1) - assert model == CREATED_MODEL_1 - assert len(create_recorder) == 1 - assert create_recorder[0].method == 'POST' - assert create_recorder[0].url == TestCreateModel._url(PROJECT_ID) - assert len(operation_recorder) == 1 - assert operation_recorder[0].method == 'GET' - assert operation_recorder[0].url == TestCreateModel._op_url(PROJECT_ID, MODEL_ID_1) - - def test_with_get_returns_locked(self): + def test_returns_locked(self): recorder = instrument_mlkit_service( status=[200, 200], payload=[OPERATION_NOT_DONE_RESPONSE, LOCKED_MODEL_2_RESPONSE]) - operation_recorder = instrument_mlkit_service( - status=200, operations=True, payload=OPERATION_NOT_DONE_RESPONSE) - expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) model = mlkit.create_model(MODEL_1) assert model == expected_model assert len(recorder) == 2 - assert len(operation_recorder) == 1 assert recorder[0].method == 'POST' assert recorder[0].url == TestCreateModel._url(PROJECT_ID) - assert operation_recorder[0].method == 'GET' - assert operation_recorder[0].url == TestCreateModel._op_url(PROJECT_ID, MODEL_ID_1) assert recorder[1].method == 'GET' assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1) @@ -584,22 +566,6 @@ def test_rpc_error_create(self): ) assert len(create_recorder) == 1 - def test_rpc_error_operation(self): - create_recorder = instrument_mlkit_service( - status=200, payload=OPERATION_NOT_DONE_RESPONSE) - operation_recorder = instrument_mlkit_service( - status=400, operations=True, payload=ERROR_RESPONSE_BAD_REQUEST) - with pytest.raises(Exception) as excinfo: - mlkit.create_model(MODEL_1) - check_firebase_error( - excinfo, - ERROR_STATUS_BAD_REQUEST, - ERROR_CODE_BAD_REQUEST, - ERROR_MSG_BAD_REQUEST - ) - assert len(create_recorder) == 1 - assert len(operation_recorder) == 1 - @pytest.mark.parametrize('model', [ 'abc', 4.2, From 3b8fe4521136e96ea39d28c012a8c2e9ef6ae3bf Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 10 Sep 2019 20:41:57 -0400 Subject: [PATCH 13/14] review comments --- firebase_admin/mlkit.py | 36 +++++++++++++----------------------- 1 file changed, 13 insertions(+), 23 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 5f05b8959..f222f36d5 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -244,9 +244,8 @@ def wait_for_unlocked(self, max_time_seconds=None): op_name = self._data.get('activeOperations')[0].get('name') model_dict = mlkit_service.handle_operation( mlkit_service.get_operation(op_name), - polling=True, - max_time_seconds=max_time_seconds, - always_return_model=False) + wait_for_operation=True, + max_time_seconds=max_time_seconds) self._update_from_dict(model_dict) @property @@ -570,9 +569,8 @@ def get_operation(self, op_name): raise _utils.handle_platform_error_from_requests(error) def _exponential_backoff(self, current_attempt, stop_time): - """Sleeps for the appropriate amount of time. Or thows deadline exceeded.""" - delay_factor = pow( - _MLKitService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt) + """Sleeps for the appropriate amount of time. Or throws deadline exceeded.""" + delay_factor = pow(_MLKitService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt) wait_time_seconds = delay_factor * _MLKitService.POLL_BASE_WAIT_TIME_SECONDS if stop_time is not None: @@ -584,17 +582,15 @@ def _exponential_backoff(self, current_attempt, stop_time): time.sleep(wait_time_seconds) - def handle_operation(self, operation, polling=False, max_time_seconds=None, - always_return_model=True): + def handle_operation(self, operation, wait_for_operation=False, max_time_seconds=None): """Handles long running operations. Args: operation: The operation to handle. - polling: Should we allow polling for the operation to complete. + wait_for_operation: Should we allow polling for the operation to complete. + If no polling is requested, a locked model will be returned instead. max_time_seconds: The maximum seconds to try polling for operation complete. (None for no limit) - always_return_model: If true, returns a locked Model instead of raising deadline - exceeded exceptions. Returns: dict: A dictionary of the returned model properties. @@ -613,15 +609,10 @@ def handle_operation(self, operation, polling=False, max_time_seconds=None, start_time = datetime.datetime.now() stop_time = (None if max_time_seconds is None else start_time + datetime.timedelta(seconds=max_time_seconds)) - while polling and not operation.get('done'): + while wait_for_operation and not operation.get('done'): # We just got this operation. Wait before getting another # so we don't exceed the GetOperation maximum request rate. - try: - self._exponential_backoff(current_attempt, stop_time) - except exceptions.DeadlineExceededError as err: - if always_return_model: - return get_model(model_id).as_dict() - raise err + self._exponential_backoff(current_attempt, stop_time) operation = self.get_operation(op_name) current_attempt += 1 @@ -630,12 +621,11 @@ def handle_operation(self, operation, polling=False, max_time_seconds=None, return operation.get('response') elif operation.get('error'): raise _utils.handle_operation_error(operation.get('error')) - else: - # A 'done' operation must have either a response or an error. - raise ValueError('Operation is malformed.') - elif always_return_model: - return get_model(model_id).as_dict() + # A 'done' operation must have either a response or an error. + raise ValueError('Operation is malformed.') + # If the operation is not complete or timed out, return a locked model instead + return get_model(model_id).as_dict() def create_model(self, model): From 6ce9fa274a6adb7a0c7390d1e5298419dbf6c8e8 Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 10 Sep 2019 20:53:58 -0400 Subject: [PATCH 14/14] review comments --- firebase_admin/mlkit.py | 4 +--- tests/test_mlkit.py | 15 +++++++++++---- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index f222f36d5..8cf8d1f7f 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -621,10 +621,8 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds return operation.get('response') elif operation.get('error'): raise _utils.handle_operation_error(operation.get('error')) - # A 'done' operation must have either a response or an error. - raise ValueError('Operation is malformed.') - # If the operation is not complete or timed out, return a locked model instead + # If the operation is not complete or timed out, return a (locked) model instead return get_model(model_id).as_dict() diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 22953d95d..78afbdf49 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -548,10 +548,17 @@ def test_operation_error(self): check_operation_error(excinfo, OPERATION_ERROR_EXPECTED_STATUS, OPERATION_ERROR_MSG) def test_malformed_operation(self): - instrument_mlkit_service(status=200, payload=OPERATION_MALFORMED_RESPONSE) - with pytest.raises(ValueError) as excinfo: - mlkit.create_model(MODEL_1) - check_error(excinfo, ValueError, 'Operation is malformed.') + recorder = instrument_mlkit_service( + status=[200, 200], + payload=[OPERATION_MALFORMED_RESPONSE, LOCKED_MODEL_2_RESPONSE]) + expected_model = mlkit.Model.from_dict(LOCKED_MODEL_JSON_2) + model = mlkit.create_model(MODEL_1) + assert model == expected_model + assert len(recorder) == 2 + assert recorder[0].method == 'POST' + assert recorder[0].url == TestCreateModel._url(PROJECT_ID) + assert recorder[1].method == 'GET' + assert recorder[1].url == TestCreateModel._get_url(PROJECT_ID, MODEL_ID_1) def test_rpc_error_create(self): create_recorder = instrument_mlkit_service(