From 7078cb6a61d1fc2d06a84803d117cc12b15bf56d Mon Sep 17 00:00:00 2001 From: ifielker Date: Mon, 2 Dec 2019 14:29:12 -0500 Subject: [PATCH 01/12] Adding support for TensorFlow 2.x --- firebase_admin/mlkit.py | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index da8bda3c8..5d5c47a47 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -495,12 +495,30 @@ def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None): return TFLiteGCSModelSource(gcs_tflite_uri=gcs_uri, app=app) @staticmethod - def _assert_tf_version_1_enabled(): + def _assert_tf_enabled(): if not _TF_ENABLED: raise ImportError('Failed to import the tensorflow library for Python. Make sure ' 'to install the tensorflow module.') - if not tf.VERSION.startswith('1.'): - raise ImportError('Expected tensorflow version 1.x, but found {0}'.format(tf.VERSION)) + if not tf.version.VERSION.startswith('1.') and not tf.Version.startswith('2.'): + raise ImportError('Expected tensorflow version 1.x or 2.x, but found {0}' + .format(tf.version.VERSION)) + + @staticmethod + def _tf_convert_from_saved_model(saved_model_dir): + # Same for both v1.x and v2.x + converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) + return converter.convert() + + @staticmethod + def _tf_convert_from_keras_model(keras_model): + if tf.version.VERSION.startswith('1.'): + keras_file = 'firebase_keras_model.h5' + tf.keras.models.save_model(keras_model, keras_file) + converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) + return converter.convert() + else: + converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) + return converter.convert() @classmethod def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): @@ -518,9 +536,8 @@ def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None): Raises: ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed. """ - TFLiteGCSModelSource._assert_tf_version_1_enabled() - converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) - tflite_model = converter.convert() + TFLiteGCSModelSource._assert_tf_enabled() + tflite_model = TFLiteGCSModelSource._tf_convert_from_saved_model(saved_model_dir) open('firebase_mlkit_model.tflite', 'wb').write(tflite_model) return TFLiteGCSModelSource.from_tflite_model_file( 'firebase_mlkit_model.tflite', bucket_name, app) @@ -541,11 +558,8 @@ def from_keras_model(cls, keras_model, bucket_name=None, app=None): Raises: ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed. """ - TFLiteGCSModelSource._assert_tf_version_1_enabled() - keras_file = 'keras_model.h5' - tf.keras.models.save_model(keras_model, keras_file) - converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) - tflite_model = converter.convert() + TFLiteGCSModelSource._assert_tf_enabled() + tflite_model = TFLiteGCSModelSource._tf_convert_from_keras_model(keras_model) open('firebase_mlkit_model.tflite', 'wb').write(tflite_model) return TFLiteGCSModelSource.from_tflite_model_file( 'firebase_mlkit_model.tflite', bucket_name, app) From d88a66b4cb920861200d6258dccd6b830796efd2 Mon Sep 17 00:00:00 2001 From: ifielker Date: Mon, 2 Dec 2019 14:37:34 -0500 Subject: [PATCH 02/12] fix typo --- firebase_admin/mlkit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 5d5c47a47..580f29d4c 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -499,7 +499,7 @@ def _assert_tf_enabled(): if not _TF_ENABLED: raise ImportError('Failed to import the tensorflow library for Python. Make sure ' 'to install the tensorflow module.') - if not tf.version.VERSION.startswith('1.') and not tf.Version.startswith('2.'): + if not tf.version.VERSION.startswith('1.') and not tf.version.VERSION.startswith('2.'): raise ImportError('Expected tensorflow version 1.x or 2.x, but found {0}' .format(tf.version.VERSION)) From 8b5a6b34a01a58b16f800a5c6b116f0d96fd59cb Mon Sep 17 00:00:00 2001 From: ifielker Date: Mon, 2 Dec 2019 14:44:01 -0500 Subject: [PATCH 03/12] remove extraneous @type from operations --- firebase_admin/mlkit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 580f29d4c..027786400 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -200,6 +200,7 @@ def from_dict(cls, data, app=None): data_copy = dict(data) tflite_format = None tflite_format_data = data_copy.pop('tfliteModel', None) + data_copy.pop('@type', None) # Returned by Operations. (Not needed) if tflite_format_data: tflite_format = TFLiteFormat.from_dict(tflite_format_data) model = Model(model_format=tflite_format) From 7ee369d9d573723f37af10c4a942b4d781110e93 Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 4 Dec 2019 14:29:20 -0500 Subject: [PATCH 04/12] send updateMask in query parameter --- firebase_admin/mlkit.py | 6 +++--- tests/test_mlkit.py | 16 ++++++++++------ 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 027786400..b89fa5fdf 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -867,12 +867,12 @@ def create_model(self, model): def update_model(self, model, update_mask=None): _validate_model(model, update_mask) - data = {'model': model.as_dict(for_upload=True)} + path = 'models/{0}'.format(model.model_id) if update_mask is not None: - data['updateMask'] = update_mask + path = path + '?updateMask={0}'.format(update_mask) try: return self.handle_operation( - self._client.body('patch', url='models/{0}'.format(model.model_id), json=data)) + self._client.body('patch', url=path, json=model.as_dict(for_upload=True))) except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index fbe31aec4..8b0816b0a 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -763,7 +763,12 @@ def teardown_class(cls): testutils.cleanup_apps() @staticmethod - def _url(project_id, model_id): + def _update_url(project_id, model_id): + update_url = 'projects/{0}/models/{1}?updateMask=state.published' + return BASE_URL + update_url.format(project_id, model_id) + + @staticmethod + def _get_url(project_id, model_id): return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id) @staticmethod @@ -778,10 +783,9 @@ def test_immediate_done(self, publish_function, published): assert model == CREATED_UPDATED_MODEL_1 assert len(recorder) == 1 assert recorder[0].method == 'PATCH' - assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1) body = json.loads(recorder[0].body.decode()) - assert body.get('model', {}).get('state', {}).get('published', None) is published - assert body.get('updateMask', {}) == 'state.published' + assert body.get('state', {}).get('published', None) is published @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_returns_locked(self, publish_function): @@ -794,9 +798,9 @@ def test_returns_locked(self, publish_function): assert model == expected_model assert len(recorder) == 2 assert recorder[0].method == 'PATCH' - assert recorder[0].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) + assert recorder[0].url == TestPublishUnpublish._update_url(PROJECT_ID, MODEL_ID_1) assert recorder[1].method == 'GET' - assert recorder[1].url == TestPublishUnpublish._url(PROJECT_ID, MODEL_ID_1) + assert recorder[1].url == TestPublishUnpublish._get_url(PROJECT_ID, MODEL_ID_1) @pytest.mark.parametrize('publish_function', PUBLISH_UNPUBLISH_FUNCS) def test_operation_error(self, publish_function): From d1a993335fe8eceb22612487597b389176b69df5 Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 4 Dec 2019 15:22:21 -0500 Subject: [PATCH 05/12] send list filters etc in query parameters --- firebase_admin/mlkit.py | 13 ++++++++----- tests/test_mlkit.py | 10 ++++------ 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index b89fa5fdf..6e204029d 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -899,15 +899,18 @@ def list_models(self, list_filter, page_size, page_token): _validate_list_filter(list_filter) _validate_page_size(page_size) _validate_page_token(page_token) - payload = {} + path = 'models' + joiner = '?' if list_filter: - payload['list_filter'] = list_filter + path = path + joiner + 'listFilter=\'{0}\''.format(list_filter) + joiner = '&' if page_size: - payload['page_size'] = page_size + path = path + joiner + 'pageSize={0}'.format(page_size) + joiner = '&' if page_token: - payload['page_token'] = page_token + path = path + joiner + 'pageToken={0}'.format(page_token) try: - return self._client.body('get', url='models', json=payload) + return self._client.body('get', url=path) except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 8b0816b0a..577e5b420 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -977,12 +977,10 @@ def test_list_models_with_all_args(self): page_token=PAGE_TOKEN) assert len(recorder) == 1 assert recorder[0].method == 'GET' - assert recorder[0].url == TestListModels._url(PROJECT_ID) - assert json.loads(recorder[0].body.decode()) == { - 'list_filter': 'display_name=displayName3', - 'page_size': 10, - 'page_token': PAGE_TOKEN - } + assert recorder[0].url == ( + TestListModels._url(PROJECT_ID) + + '?listFilter=\'display_name=displayName3\'&pageSize=10&pageToken={0}' + .format(PAGE_TOKEN)) assert isinstance(models_page, mlkit.ListModelsPage) assert len(models_page.models) == 1 assert models_page.models[0] == MODEL_3 From 68ebe7732f1f254ff9356c963cbf1687e5851a06 Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 4 Dec 2019 17:16:14 -0500 Subject: [PATCH 06/12] fix typo --- firebase_admin/mlkit.py | 6 +++--- tests/test_mlkit.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 6e204029d..624a4000d 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -902,13 +902,13 @@ def list_models(self, list_filter, page_size, page_token): path = 'models' joiner = '?' if list_filter: - path = path + joiner + 'listFilter=\'{0}\''.format(list_filter) + path = path + joiner + 'list_filter=\'{0}\''.format(list_filter) joiner = '&' if page_size: - path = path + joiner + 'pageSize={0}'.format(page_size) + path = path + joiner + 'page_size={0}'.format(page_size) joiner = '&' if page_token: - path = path + joiner + 'pageToken={0}'.format(page_token) + path = path + joiner + 'page_token={0}'.format(page_token) try: return self._client.body('get', url=path) except requests.exceptions.RequestException as error: diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 577e5b420..8926be6dc 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -979,7 +979,7 @@ def test_list_models_with_all_args(self): assert recorder[0].method == 'GET' assert recorder[0].url == ( TestListModels._url(PROJECT_ID) + - '?listFilter=\'display_name=displayName3\'&pageSize=10&pageToken={0}' + '?list_filter=\'display_name=displayName3\'&page_size=10&page_token={0}' .format(PAGE_TOKEN)) assert isinstance(models_page, mlkit.ListModelsPage) assert len(models_page.models) == 1 From 5633b9f32e9a198519107d095f5df1360a4b7838 Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 4 Dec 2019 17:21:26 -0500 Subject: [PATCH 07/12] fix typo --- firebase_admin/mlkit.py | 2 +- tests/test_mlkit.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 624a4000d..17d15989f 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -902,7 +902,7 @@ def list_models(self, list_filter, page_size, page_token): path = 'models' joiner = '?' if list_filter: - path = path + joiner + 'list_filter=\'{0}\''.format(list_filter) + path = path + joiner + 'filter=\'{0}\''.format(list_filter) joiner = '&' if page_size: path = path + joiner + 'page_size={0}'.format(page_size) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 8926be6dc..5f12dcb7e 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -979,7 +979,7 @@ def test_list_models_with_all_args(self): assert recorder[0].method == 'GET' assert recorder[0].url == ( TestListModels._url(PROJECT_ID) + - '?list_filter=\'display_name=displayName3\'&page_size=10&page_token={0}' + '?filter=\'display_name=displayName3\'&page_size=10&page_token={0}' .format(PAGE_TOKEN)) assert isinstance(models_page, mlkit.ListModelsPage) assert len(models_page.models) == 1 From d715628543d73000c1f7f803cdc418c653a52840 Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 4 Dec 2019 17:49:10 -0500 Subject: [PATCH 08/12] fix typo --- firebase_admin/mlkit.py | 2 +- tests/test_mlkit.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 17d15989f..050a1ffba 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -902,7 +902,7 @@ def list_models(self, list_filter, page_size, page_token): path = 'models' joiner = '?' if list_filter: - path = path + joiner + 'filter=\'{0}\''.format(list_filter) + path = path + joiner + 'filter={0}'.format(list_filter) joiner = '&' if page_size: path = path + joiner + 'page_size={0}'.format(page_size) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 5f12dcb7e..0daec946b 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -979,7 +979,7 @@ def test_list_models_with_all_args(self): assert recorder[0].method == 'GET' assert recorder[0].url == ( TestListModels._url(PROJECT_ID) + - '?filter=\'display_name=displayName3\'&page_size=10&page_token={0}' + '?filter=display_name=displayName3&page_size=10&page_token={0}' .format(PAGE_TOKEN)) assert isinstance(models_page, mlkit.ListModelsPage) assert len(models_page.models) == 1 From b63a4feca17dd52b0c808c7cbcdcbabbc328219d Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 4 Dec 2019 18:15:35 -0500 Subject: [PATCH 09/12] urlEncode filter string --- firebase_admin/mlkit.py | 16 +++++++++------- tests/test_mlkit.py | 2 +- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 050a1ffba..b75dd2233 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -27,6 +27,7 @@ import six +from six.moves import urllib from firebase_admin import _http_client from firebase_admin import _utils from firebase_admin import exceptions @@ -899,16 +900,17 @@ def list_models(self, list_filter, page_size, page_token): _validate_list_filter(list_filter) _validate_page_size(page_size) _validate_page_token(page_token) - path = 'models' - joiner = '?' + params = {} if list_filter: - path = path + joiner + 'filter={0}'.format(list_filter) - joiner = '&' + params['filter'] = list_filter if page_size: - path = path + joiner + 'page_size={0}'.format(page_size) - joiner = '&' + params['page_size'] = page_size if page_token: - path = path + joiner + 'page_token={0}'.format(page_token) + params['page_token'] = page_token + path = 'models' + if params != {}: + param_str = urllib.parse.urlencode(sorted(params.items()), True) + path = path + '?' + param_str try: return self._client.body('get', url=path) except requests.exceptions.RequestException as error: diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 0daec946b..3198eff34 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -979,7 +979,7 @@ def test_list_models_with_all_args(self): assert recorder[0].method == 'GET' assert recorder[0].url == ( TestListModels._url(PROJECT_ID) + - '?filter=display_name=displayName3&page_size=10&page_token={0}' + '?filter=display_name%3DdisplayName3&page_size=10&page_token={0}' .format(PAGE_TOKEN)) assert isinstance(models_page, mlkit.ListModelsPage) assert len(models_page.models) == 1 From 9892054b2c9c96d633e955d9472cd590d2d017e6 Mon Sep 17 00:00:00 2001 From: ifielker Date: Wed, 4 Dec 2019 19:52:20 -0500 Subject: [PATCH 10/12] fix lint --- firebase_admin/mlkit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index b75dd2233..7303fa057 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -909,6 +909,7 @@ def list_models(self, list_filter, page_size, page_token): params['page_token'] = page_token path = 'models' if params != {}: + # pylint: disable=too-many-function-args param_str = urllib.parse.urlencode(sorted(params.items()), True) path = path + '?' + param_str try: From f54e634998d5c272d80826959558d8b386b40042 Mon Sep 17 00:00:00 2001 From: ifielker Date: Mon, 9 Dec 2019 15:48:25 -0500 Subject: [PATCH 11/12] added comments and fix --- firebase_admin/mlkit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/firebase_admin/mlkit.py b/firebase_admin/mlkit.py index 7303fa057..bb277abf9 100644 --- a/firebase_admin/mlkit.py +++ b/firebase_admin/mlkit.py @@ -513,6 +513,7 @@ def _tf_convert_from_saved_model(saved_model_dir): @staticmethod def _tf_convert_from_keras_model(keras_model): + # Version 1.x conversion function takes a model file. Version 2.x takes the model itself. if tf.version.VERSION.startswith('1.'): keras_file = 'firebase_keras_model.h5' tf.keras.models.save_model(keras_model, keras_file) @@ -908,7 +909,7 @@ def list_models(self, list_filter, page_size, page_token): if page_token: params['page_token'] = page_token path = 'models' - if params != {}: + if params: # pylint: disable=too-many-function-args param_str = urllib.parse.urlencode(sorted(params.items()), True) path = path + '?' + param_str From 9a148aa6c1db83c7b1dd39c939baa0371c26838a Mon Sep 17 00:00:00 2001 From: ifielker Date: Tue, 10 Dec 2019 12:26:31 -0500 Subject: [PATCH 12/12] applied review suggestion --- tests/test_mlkit.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_mlkit.py b/tests/test_mlkit.py index 3198eff34..dbe590673 100644 --- a/tests/test_mlkit.py +++ b/tests/test_mlkit.py @@ -764,8 +764,9 @@ def teardown_class(cls): @staticmethod def _update_url(project_id, model_id): - update_url = 'projects/{0}/models/{1}?updateMask=state.published' - return BASE_URL + update_url.format(project_id, model_id) + update_url = 'projects/{0}/models/{1}?updateMask=state.published'.format( + project_id, model_id) + return BASE_URL + update_url @staticmethod def _get_url(project_id, model_id):