From fb8914cc8fd21844991bb14c0bd522a6a1458759 Mon Sep 17 00:00:00 2001 From: ifielker Date: Mon, 13 Jul 2020 13:05:07 -0400 Subject: [PATCH 1/5] added integration test for AutoML --- integration/test_ml.py | 85 +++++++++++++++++++++++++++++++++++------- 1 file changed, 71 insertions(+), 14 deletions(-) diff --git a/integration/test_ml.py b/integration/test_ml.py index be791d8fa..5dbcf2bf0 100644 --- a/integration/test_ml.py +++ b/integration/test_ml.py @@ -22,6 +22,7 @@ import pytest +import firebase_admin from firebase_admin import exceptions from firebase_admin import ml from tests import testutils @@ -34,6 +35,11 @@ except ImportError: _TF_ENABLED = False +try: + from google.cloud import automl_v1 + _AUTOML_ENABLED = True +except ImportError: + _AUTOML_ENABLED = False def _random_identifier(prefix): #pylint: disable=unused-variable @@ -42,27 +48,26 @@ def _random_identifier(prefix): NAME_ONLY_ARGS = { - 'display_name': _random_identifier('TestModel123_') + 'display_name': _random_identifier('TestModel_') } NAME_ONLY_ARGS_UPDATED = { - 'display_name': _random_identifier('TestModel123_updated_') + 'display_name': _random_identifier('TestModel_updated_') } NAME_AND_TAGS_ARGS = { - 'display_name': _random_identifier('TestModel123_tags_'), + 'display_name': _random_identifier('TestModel_tags_'), 'tags': ['test_tag123'] } FULL_MODEL_ARGS = { - 'display_name': _random_identifier('TestModel123_full_'), + 'display_name': _random_identifier('TestModel_full_'), 'tags': ['test_tag567'], 'file_name': 'model1.tflite' } INVALID_FULL_MODEL_ARGS = { - 'display_name': _random_identifier('TestModel123_invalid_full_'), + 'display_name': _random_identifier('TestModel_invalid_full_'), 'tags': ['test_tag890'], 'file_name': 'invalid_model.tflite' } - @pytest.fixture def firebase_model(request): args = request.param @@ -101,6 +106,7 @@ def _clean_up_model(model): try: # Try to delete the model. # Some tests delete the model as part of the test. + model.wait_for_unlocked() ml.delete_model(model.model_id) except exceptions.NotFoundError: pass @@ -133,17 +139,20 @@ def check_model(model, args): assert model.etag is not None -def check_model_format(model, has_model_format=False, validation_error=None): +def check_model_format(model, has_model_format=False, validation_error=None, is_automl=False): if has_model_format: assert model.validation_error == validation_error assert model.published is False - assert model.model_format.model_source.gcs_tflite_uri.startswith('gs://') - if validation_error: - assert model.model_format.size_bytes is None - assert model.model_hash is None + if is_automl: + assert model.model_format.model_source.auto_ml_model.startswith('projects/') else: - assert model.model_format.size_bytes is not None - assert model.model_hash is not None + assert model.model_format.model_source.gcs_tflite_uri.startswith('gs://') + if validation_error: + assert model.model_format.size_bytes is None + assert model.model_hash is None + else: + assert model.model_format.size_bytes is not None + assert model.model_hash is not None else: assert model.model_format is None assert model.validation_error == 'No model file has been uploaded.' @@ -290,7 +299,7 @@ def test_delete_model(firebase_model): # Test tensor flow conversion functions if tensor flow is enabled. #'pip install tensorflow' in the environment if you want _TF_ENABLED = True -#'pip install tensorflow==2.0.0b' for version 2 etc. +#'pip install tensorflow==2.2.0' for version 2.2.0 etc. def _clean_up_directory(save_dir): @@ -334,6 +343,7 @@ def saved_model_dir(keras_model): _clean_up_directory(parent) + @pytest.mark.skipif(not _TF_ENABLED, reason='Tensor flow is required for this test.') def test_from_keras_model(keras_model): source = ml.TFLiteGCSModelSource.from_keras_model(keras_model, 'model2.tflite') @@ -371,3 +381,50 @@ def test_from_saved_model(saved_model_dir): assert created_model.validation_error is None finally: _clean_up_model(created_model) + + +# Test AutoML functionality if AutoML is enabled. +#'pip install google-cloud-automl' in the environment if you want _AUTOML_ENABLED = True +# You will also need a predefined AutoML model named 'py_sdk_integ_test1' to run the +# successful test. (Test is skipped otherwise) + +@pytest.fixture +def automl_model(): + assert _AUTOML_ENABLED + + # It takes > 20 minutes to train a model, so we expect a predefined AutoMl + # model named 'py_sdk_integ_test1' to exist in the project, or we skip + # the test. + automl_client = automl_v1.AutoMlClient() + project_id = firebase_admin.get_app().project_id + parent = automl_client.location_path(project_id, 'us-central1') + models = automl_client.list_models(parent, filter_="display_name=py_sdk_integ_test1") + # Expecting exactly one. (Ok to use last one if somehow more than 1) + automl_ref = None + for model in models: + automl_ref = model.name + + # Skip if no pre-defined model. (It takes min > 20 minutes to train a model) + if automl_ref is None: + pytest.skip("No pre-existing AutoML model found. Skipping test") + + source = ml.TFLiteAutoMlSource(automl_ref) + tflite_format = ml.TFLiteFormat(model_source=source) + ml_model = ml.Model( + display_name=_random_identifier('TestModel_automl_'), + tags=['test_automl'], + model_format=tflite_format) + model = ml.create_model(model=ml_model) + yield model + _clean_up_model(model) + +@pytest.mark.skipif(not _AUTOML_ENABLED, reason='AutoML is required for this test.') +def test_automl_model(automl_model): + # This test looks for a predefined automl model with display_name = 'py_sdk_integ_test1' + automl_model.wait_for_unlocked() + + check_model(automl_model, { + 'display_name': automl_model.display_name, + 'tags': ['test_automl'], + }) + check_model_format(automl_model, has_model_format=True, validation_error=None, is_automl=True) From 0bf1a474e8e148371af6138a87ff4718d63d29c6 Mon Sep 17 00:00:00 2001 From: ifielker Date: Mon, 13 Jul 2020 14:39:08 -0400 Subject: [PATCH 2/5] added astroid to fix lint errors --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index d7fb6d736..dbeaee3b6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +astroid == 2.3.3 pylint == 2.3.1 pytest >= 3.6.0 pytest-cov >= 2.4.0 From dab056ce831bde44a8848d87161e64fabf99e210 Mon Sep 17 00:00:00 2001 From: ifielker Date: Mon, 13 Jul 2020 18:15:53 -0400 Subject: [PATCH 3/5] review changes --- integration/test_ml.py | 67 +++++++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/integration/test_ml.py b/integration/test_ml.py index 5dbcf2bf0..52cb1bb7e 100644 --- a/integration/test_ml.py +++ b/integration/test_ml.py @@ -138,38 +138,45 @@ def check_model(model, args): assert model.locked is False assert model.etag is not None +# Model Format Checks -def check_model_format(model, has_model_format=False, validation_error=None, is_automl=False): - if has_model_format: - assert model.validation_error == validation_error - assert model.published is False - if is_automl: - assert model.model_format.model_source.auto_ml_model.startswith('projects/') - else: - assert model.model_format.model_source.gcs_tflite_uri.startswith('gs://') - if validation_error: - assert model.model_format.size_bytes is None - assert model.model_hash is None - else: - assert model.model_format.size_bytes is not None - assert model.model_hash is not None - else: - assert model.model_format is None - assert model.validation_error == 'No model file has been uploaded.' - assert model.published is False +def check_no_model_format(model): + assert model.model_format is None + assert model.validation_error == 'No model file has been uploaded.' + assert model.published is False + assert model.model_hash is None + + +def check_tflite_gcs_format(model, validation_error=None): + assert model.validation_error == validation_error + assert model.published is False + assert model.model_format.model_source.gcs_tflite_uri.startswith('gs://') + if validation_error: + assert model.model_format.size_bytes is None assert model.model_hash is None + else: + assert model.model_format.size_bytes is not None + assert model.model_hash is not None + + +def check_tflite_automl_format(model): + assert model.validation_error is None + assert model.published is False + assert model.model_format.model_source.auto_ml_model.startswith('projects/') + # Automl models don't have validation errors since they are references + # to valid automl models. @pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True) def test_create_simple_model(firebase_model): check_model(firebase_model, NAME_AND_TAGS_ARGS) - check_model_format(firebase_model) + check_no_model_format(firebase_model) @pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) def test_create_full_model(firebase_model): check_model(firebase_model, FULL_MODEL_ARGS) - check_model_format(firebase_model, True) + check_tflite_gcs_format(firebase_model) @pytest.mark.parametrize('firebase_model', [FULL_MODEL_ARGS], indirect=True) @@ -184,14 +191,14 @@ def test_create_already_existing_fails(firebase_model): @pytest.mark.parametrize('firebase_model', [INVALID_FULL_MODEL_ARGS], indirect=True) def test_create_invalid_model(firebase_model): check_model(firebase_model, INVALID_FULL_MODEL_ARGS) - check_model_format(firebase_model, True, 'Invalid flatbuffer format') + check_tflite_gcs_format(firebase_model, 'Invalid flatbuffer format') @pytest.mark.parametrize('firebase_model', [NAME_AND_TAGS_ARGS], indirect=True) def test_get_model(firebase_model): get_model = ml.get_model(firebase_model.model_id) check_model(get_model, NAME_AND_TAGS_ARGS) - check_model_format(get_model) + check_no_model_format(get_model) @pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) @@ -210,12 +217,12 @@ def test_update_model(firebase_model): firebase_model.display_name = new_model_name updated_model = ml.update_model(firebase_model) check_model(updated_model, NAME_ONLY_ARGS_UPDATED) - check_model_format(updated_model) + check_no_model_format(updated_model) # Second call with same model does not cause error updated_model2 = ml.update_model(updated_model) check_model(updated_model2, NAME_ONLY_ARGS_UPDATED) - check_model_format(updated_model2) + check_no_model_format(updated_model2) @pytest.mark.parametrize('firebase_model', [NAME_ONLY_ARGS], indirect=True) @@ -358,7 +365,7 @@ def test_from_keras_model(keras_model): try: check_model(created_model, {'display_name': model.display_name}) - check_model_format(created_model, True) + check_tflite_gcs_format(created_model) finally: _clean_up_model(created_model) @@ -385,7 +392,7 @@ def test_from_saved_model(saved_model_dir): # Test AutoML functionality if AutoML is enabled. #'pip install google-cloud-automl' in the environment if you want _AUTOML_ENABLED = True -# You will also need a predefined AutoML model named 'py_sdk_integ_test1' to run the +# You will also need a predefined AutoML model named 'admin_sdk_integ_test1' to run the # successful test. (Test is skipped otherwise) @pytest.fixture @@ -393,12 +400,12 @@ def automl_model(): assert _AUTOML_ENABLED # It takes > 20 minutes to train a model, so we expect a predefined AutoMl - # model named 'py_sdk_integ_test1' to exist in the project, or we skip + # model named 'admin_sdk_integ_test1' to exist in the project, or we skip # the test. automl_client = automl_v1.AutoMlClient() project_id = firebase_admin.get_app().project_id parent = automl_client.location_path(project_id, 'us-central1') - models = automl_client.list_models(parent, filter_="display_name=py_sdk_integ_test1") + models = automl_client.list_models(parent, filter_="display_name=admin_sdk_integ_test1") # Expecting exactly one. (Ok to use last one if somehow more than 1) automl_ref = None for model in models: @@ -420,11 +427,11 @@ def automl_model(): @pytest.mark.skipif(not _AUTOML_ENABLED, reason='AutoML is required for this test.') def test_automl_model(automl_model): - # This test looks for a predefined automl model with display_name = 'py_sdk_integ_test1' + # This test looks for a predefined automl model with display_name = 'admin_sdk_integ_test1' automl_model.wait_for_unlocked() check_model(automl_model, { 'display_name': automl_model.display_name, 'tags': ['test_automl'], }) - check_model_format(automl_model, has_model_format=True, validation_error=None, is_automl=True) + check_tflite_automl_format(automl_model) From 3f6c34c50cbd900b1b740749c69652f376d2fbee Mon Sep 17 00:00:00 2001 From: ifielker Date: Mon, 13 Jul 2020 19:35:23 -0400 Subject: [PATCH 4/5] forcing google-auth at 1.18.0 --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index dbeaee3b6..6b2ad2460 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,6 @@ pytest-localserver >= 0.4.1 cachecontrol >= 0.12.6 google-api-core[grpc] >= 1.14.0, < 2.0.0dev; platform.python_implementation != 'PyPy' google-api-python-client >= 1.7.8 +google-auth == 1.18.0 google-cloud-firestore >= 1.4.0; platform.python_implementation != 'PyPy' google-cloud-storage >= 1.18.0 From cecc967204ddbafeb9d52276ed53c0d2fac135b6 Mon Sep 17 00:00:00 2001 From: ifielker Date: Mon, 13 Jul 2020 19:41:10 -0400 Subject: [PATCH 5/5] adding comment about temporary workaround --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 6b2ad2460..1a55482da 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,6 @@ pytest-localserver >= 0.4.1 cachecontrol >= 0.12.6 google-api-core[grpc] >= 1.14.0, < 2.0.0dev; platform.python_implementation != 'PyPy' google-api-python-client >= 1.7.8 -google-auth == 1.18.0 +google-auth == 1.18.0 # temporary workaround google-cloud-firestore >= 1.4.0; platform.python_implementation != 'PyPy' google-cloud-storage >= 1.18.0