diff --git a/firebase_admin/_auth_providers.py b/firebase_admin/_auth_providers.py new file mode 100644 index 000000000..5a9ade9b9 --- /dev/null +++ b/firebase_admin/_auth_providers.py @@ -0,0 +1,100 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Firebase auth providers management sub module.""" + +import requests + +from firebase_admin import _auth_utils + + +class ProviderConfig: + """Parent type for all authentication provider config types.""" + + def __init__(self, data): + self._data = data + + @property + def provider_id(self): + name = self._data['name'] + return name.split('/')[-1] + + @property + def display_name(self): + return self._data.get('displayName') + + @property + def enabled(self): + return self._data['enabled'] + + +class SAMLProviderConfig(ProviderConfig): + """Represents he SAML auth provider configuration. + + See http://docs.oasis-open.org/security/saml/Post2.0/sstc-saml-tech-overview-2.0.html.""" + + @property + def idp_entity_id(self): + return self._data.get('idpConfig', {})['idpEntityId'] + + @property + def sso_url(self): + return self._data.get('idpConfig', {})['ssoUrl'] + + @property + def x509_certificates(self): + certs = self._data.get('idpConfig', {})['idpCertificates'] + return [c['x509Certificate'] for c in certs] + + @property + def request_signing_enabled(self): + return self._data.get('idpConfig', {})['signRequest'] + + @property + def callback_url(self): + return self._data.get('spConfig', {})['callbackUri'] + + @property + def rp_entity_id(self): + return self._data.get('spConfig', {})['spEntityId'] + + +class ProviderConfigClient: + """Client for managing Auth provider configurations.""" + + PROVIDER_CONFIG_URL = 'https://identitytoolkit.googleapis.com/v2beta1' + + def __init__(self, http_client, project_id, tenant_id=None): + self.http_client = http_client + self.base_url = '{0}/projects/{1}'.format(self.PROVIDER_CONFIG_URL, project_id) + if tenant_id: + self.base_url += '/tenants/{0}'.format(tenant_id) + + def get_saml_provider_config(self, provider_id): + if not isinstance(provider_id, str): + raise ValueError( + 'Invalid SAML provider ID: {0}. Provider ID must be a non-empty string.'.format( + provider_id)) + if not provider_id.startswith('saml.'): + raise ValueError('Invalid SAML provider ID: {0}.'.format(provider_id)) + + body = self._make_request('get', '/inboundSamlConfigs/{0}'.format(provider_id)) + return SAMLProviderConfig(body) + + def _make_request(self, method, path, body=None): + url = '{0}{1}'.format(self.base_url, path) + try: + return self.http_client.body(method, url, json=body) + except requests.exceptions.RequestException as error: + raise _auth_utils.handle_auth_backend_error(error) diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index 95e7f2718..e05793d8f 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -307,7 +307,17 @@ def __init__(self, message): exceptions.InvalidArgumentError.__init__(self, message) +class ConfigurationNotFoundError(exceptions.NotFoundError): + """No auth provider found for the specified identifier.""" + + default_message = 'No auth provider found for the given identifier' + + def __init__(self, message, cause=None, http_response=None): + exceptions.NotFoundError.__init__(self, message, cause, http_response) + + _CODE_TO_EXC_TYPE = { + 'CONFIGURATION_NOT_FOUND': ConfigurationNotFoundError, 'DUPLICATE_EMAIL': EmailAlreadyExistsError, 'DUPLICATE_LOCAL_ID': UidAlreadyExistsError, 'EMAIL_EXISTS': EmailAlreadyExistsError, diff --git a/firebase_admin/_token_gen.py b/firebase_admin/_token_gen.py index 2e40414cd..18a8008c7 100644 --- a/firebase_admin/_token_gen.py +++ b/firebase_admin/_token_gen.py @@ -82,10 +82,13 @@ def from_iam(cls, request, google_cred, service_account): class TokenGenerator: """Generates custom tokens and session cookies.""" - def __init__(self, app, client): + ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1' + + def __init__(self, app, http_client): self.app = app - self.client = client + self.http_client = http_client self.request = transport.requests.Request() + self.base_url = '{0}/projects/{1}'.format(self.ID_TOOLKIT_URL, app.project_id) self._signing_provider = None def _init_signing_provider(self): @@ -192,13 +195,13 @@ def create_session_cookie(self, id_token, expires_in): raise ValueError('Illegal expiry duration: {0}. Duration must be at most {1} ' 'seconds.'.format(expires_in, MAX_SESSION_COOKIE_DURATION_SECONDS)) + url = '{0}:createSessionCookie'.format(self.base_url) payload = { 'idToken': id_token, 'validDuration': expires_in, } try: - body, http_resp = self.client.body_and_response( - 'post', ':createSessionCookie', json=payload) + body, http_resp = self.http_client.body_and_response('post', url, json=payload) except requests.exceptions.RequestException as error: raise _auth_utils.handle_auth_backend_error(error) else: diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index 533259e70..0f3dc1a94 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -454,8 +454,13 @@ def encode_action_code_settings(settings): class UserManager: """Provides methods for interacting with the Google Identity Toolkit.""" - def __init__(self, client): - self._client = client + ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1' + + def __init__(self, http_client, project_id, tenant_id=None): + self.http_client = http_client + self.base_url = '{0}/projects/{1}'.format(self.ID_TOOLKIT_URL, project_id) + if tenant_id: + self.base_url += '/tenants/{0}'.format(tenant_id) def get_user(self, **kwargs): """Gets the user data corresponding to the provided key.""" @@ -471,17 +476,12 @@ def get_user(self, **kwargs): else: raise TypeError('Unsupported keyword arguments: {0}.'.format(kwargs)) - try: - body, http_resp = self._client.body_and_response( - 'post', '/accounts:lookup', json=payload) - except requests.exceptions.RequestException as error: - raise _auth_utils.handle_auth_backend_error(error) - else: - if not body or not body.get('users'): - raise _auth_utils.UserNotFoundError( - 'No user record found for the provided {0}: {1}.'.format(key_type, key), - http_response=http_resp) - return body['users'][0] + body, http_resp = self._make_request('post', '/accounts:lookup', json=payload) + if not body or not body.get('users'): + raise _auth_utils.UserNotFoundError( + 'No user record found for the provided {0}: {1}.'.format(key_type, key), + http_response=http_resp) + return body['users'][0] def list_users(self, page_token=None, max_results=MAX_LIST_USERS_RESULTS): """Retrieves a batch of users.""" @@ -498,10 +498,8 @@ def list_users(self, page_token=None, max_results=MAX_LIST_USERS_RESULTS): payload = {'maxResults': max_results} if page_token: payload['nextPageToken'] = page_token - try: - return self._client.body('get', '/accounts:batchGet', params=payload) - except requests.exceptions.RequestException as error: - raise _auth_utils.handle_auth_backend_error(error) + body, _ = self._make_request('get', '/accounts:batchGet', params=payload) + return body def create_user(self, uid=None, display_name=None, email=None, phone_number=None, photo_url=None, password=None, disabled=None, email_verified=None): @@ -517,15 +515,11 @@ def create_user(self, uid=None, display_name=None, email=None, phone_number=None 'disabled': bool(disabled) if disabled is not None else None, } payload = {k: v for k, v in payload.items() if v is not None} - try: - body, http_resp = self._client.body_and_response('post', '/accounts', json=payload) - except requests.exceptions.RequestException as error: - raise _auth_utils.handle_auth_backend_error(error) - else: - if not body or not body.get('localId'): - raise _auth_utils.UnexpectedResponseError( - 'Failed to create new user.', http_response=http_resp) - return body.get('localId') + body, http_resp = self._make_request('post', '/accounts', json=payload) + if not body or not body.get('localId'): + raise _auth_utils.UnexpectedResponseError( + 'Failed to create new user.', http_response=http_resp) + return body.get('localId') def update_user(self, uid, display_name=None, email=None, phone_number=None, photo_url=None, password=None, disabled=None, email_verified=None, @@ -568,29 +562,19 @@ def update_user(self, uid, display_name=None, email=None, phone_number=None, payload['customAttributes'] = _auth_utils.validate_custom_claims(json_claims) payload = {k: v for k, v in payload.items() if v is not None} - try: - body, http_resp = self._client.body_and_response( - 'post', '/accounts:update', json=payload) - except requests.exceptions.RequestException as error: - raise _auth_utils.handle_auth_backend_error(error) - else: - if not body or not body.get('localId'): - raise _auth_utils.UnexpectedResponseError( - 'Failed to update user: {0}.'.format(uid), http_response=http_resp) - return body.get('localId') + body, http_resp = self._make_request('post', '/accounts:update', json=payload) + if not body or not body.get('localId'): + raise _auth_utils.UnexpectedResponseError( + 'Failed to update user: {0}.'.format(uid), http_response=http_resp) + return body.get('localId') def delete_user(self, uid): """Deletes the user identified by the specified user ID.""" _auth_utils.validate_uid(uid, required=True) - try: - body, http_resp = self._client.body_and_response( - 'post', '/accounts:delete', json={'localId' : uid}) - except requests.exceptions.RequestException as error: - raise _auth_utils.handle_auth_backend_error(error) - else: - if not body or not body.get('kind'): - raise _auth_utils.UnexpectedResponseError( - 'Failed to delete user: {0}.'.format(uid), http_response=http_resp) + body, http_resp = self._make_request('post', '/accounts:delete', json={'localId' : uid}) + if not body or not body.get('kind'): + raise _auth_utils.UnexpectedResponseError( + 'Failed to delete user: {0}.'.format(uid), http_response=http_resp) def import_users(self, users, hash_alg=None): """Imports the given list of users to Firebase Auth.""" @@ -609,16 +593,11 @@ def import_users(self, users, hash_alg=None): if not isinstance(hash_alg, _user_import.UserImportHash): raise ValueError('A UserImportHash is required to import users with passwords.') payload.update(hash_alg.to_dict()) - try: - body, http_resp = self._client.body_and_response( - 'post', '/accounts:batchCreate', json=payload) - except requests.exceptions.RequestException as error: - raise _auth_utils.handle_auth_backend_error(error) - else: - if not isinstance(body, dict): - raise _auth_utils.UnexpectedResponseError( - 'Failed to import users.', http_response=http_resp) - return body + body, http_resp = self._make_request('post', '/accounts:batchCreate', json=payload) + if not isinstance(body, dict): + raise _auth_utils.UnexpectedResponseError( + 'Failed to import users.', http_response=http_resp) + return body def generate_email_action_link(self, action_type, email, action_code_settings=None): """Fetches the email action links for types @@ -646,16 +625,18 @@ def generate_email_action_link(self, action_type, email, action_code_settings=No if action_code_settings: payload.update(encode_action_code_settings(action_code_settings)) + body, http_resp = self._make_request('post', '/accounts:sendOobCode', json=payload) + if not body or not body.get('oobLink'): + raise _auth_utils.UnexpectedResponseError( + 'Failed to generate email action link.', http_response=http_resp) + return body.get('oobLink') + + def _make_request(self, method, path, **kwargs): + url = '{0}{1}'.format(self.base_url, path) try: - body, http_resp = self._client.body_and_response( - 'post', '/accounts:sendOobCode', json=payload) + return self.http_client.body_and_response(method, url, **kwargs) except requests.exceptions.RequestException as error: raise _auth_utils.handle_auth_backend_error(error) - else: - if not body or not body.get('oobLink'): - raise _auth_utils.UnexpectedResponseError( - 'Failed to generate email action link.', http_response=http_resp) - return body.get('oobLink') class _UserIterator: diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index f6fb1da43..e6ab15a80 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -22,6 +22,7 @@ import time import firebase_admin +from firebase_admin import _auth_providers from firebase_admin import _auth_utils from firebase_admin import _http_client from firebase_admin import _token_gen @@ -50,8 +51,10 @@ 'InvalidSessionCookieError', 'ListUsersPage', 'PhoneNumberAlreadyExistsError', + 'ProviderConfig', 'RevokedIdTokenError', 'RevokedSessionCookieError', + 'SAMLProviderConfig', 'TokenSignError', 'UidAlreadyExistsError', 'UnexpectedResponseError', @@ -70,6 +73,7 @@ 'generate_email_verification_link', 'generate_password_reset_link', 'generate_sign_in_with_email_link', + 'get_saml_provider_config', 'get_user', 'get_user_by_email', 'get_user_by_phone_number', @@ -84,6 +88,7 @@ ActionCodeSettings = _user_mgt.ActionCodeSettings CertificateFetchError = _token_gen.CertificateFetchError +ConfigurationNotFoundError = _auth_utils.ConfigurationNotFoundError DELETE_ATTRIBUTE = _user_mgt.DELETE_ATTRIBUTE EmailAlreadyExistsError = _auth_utils.EmailAlreadyExistsError ErrorInfo = _user_import.ErrorInfo @@ -97,8 +102,10 @@ InvalidSessionCookieError = _token_gen.InvalidSessionCookieError ListUsersPage = _user_mgt.ListUsersPage PhoneNumberAlreadyExistsError = _auth_utils.PhoneNumberAlreadyExistsError +ProviderConfig = _auth_providers.ProviderConfigClient RevokedIdTokenError = _token_gen.RevokedIdTokenError RevokedSessionCookieError = _token_gen.RevokedSessionCookieError +SAMLProviderConfig = _auth_providers.SAMLProviderConfig TokenSignError = _token_gen.TokenSignError UidAlreadyExistsError = _auth_utils.UidAlreadyExistsError UnexpectedResponseError = _auth_utils.UnexpectedResponseError @@ -521,6 +528,7 @@ def generate_sign_in_with_email_link(email, action_code_settings, app=None): the link is to be handled by a mobile app and the additional state information to be passed in the deep link. app: An App instance (optional). + Returns: link: The email sign-in link created by the API @@ -533,32 +541,46 @@ def generate_sign_in_with_email_link(email, action_code_settings, app=None): email, action_code_settings=action_code_settings) +def get_saml_provider_config(provider_id, app=None): + """Returns the SAMLProviderConfig with the given ID. + + Args: + provider_id: Provider ID string. + app: An App instance (optional). + + Returns: + SAMLProviderConfig: A SAMLProviderConfig instance. + + Raises: + ValueError: If the provider ID is invalid, empty or does not have ``saml.`` prefix. + ConfigurationNotFoundError: If no SAML provider is available with the given identifier. + FirebaseError: If an error occurs while retrieving the SAML provider. + """ + client = _get_client(app) + return client.get_saml_provider_config(provider_id) + + class Client: """Firebase Authentication client scoped to a specific tenant.""" - ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1/projects/' - def __init__(self, app, tenant_id=None): - credential = app.credential.get_credential() - version_header = 'Python/Admin/{0}'.format(firebase_admin.__version__) - if not app.project_id: raise ValueError("""Project ID is required to access the auth service. 1. Use a service account credential, or 2. set the project ID explicitly via Firebase App options, or 3. set the project ID via the GOOGLE_CLOUD_PROJECT environment variable.""") - url_path = app.project_id - if tenant_id: - url_path += '/tenants/{0}'.format(tenant_id) - + credential = app.credential.get_credential() + version_header = 'Python/Admin/{0}'.format(firebase_admin.__version__) http_client = _http_client.JsonHttpClient( - credential=credential, base_url=self.ID_TOOLKIT_URL + url_path, - headers={'X-Client-Version': version_header}) + credential=credential, headers={'X-Client-Version': version_header}) + self._tenant_id = tenant_id self._token_generator = _token_gen.TokenGenerator(app, http_client) self._token_verifier = _token_gen.TokenVerifier(app) - self._user_manager = _user_mgt.UserManager(http_client) + self._user_manager = _user_mgt.UserManager(http_client, app.project_id, tenant_id) + self._provider_manager = _auth_providers.ProviderConfigClient( + http_client, app.project_id, tenant_id) @property def tenant_id(self): @@ -903,6 +925,22 @@ def generate_sign_in_with_email_link(self, email, action_code_settings): return self._user_manager.generate_email_action_link( 'EMAIL_SIGNIN', email, action_code_settings=action_code_settings) + def get_saml_provider_config(self, provider_id): + """Returns the SAMLProviderConfig with the given ID. + + Args: + provider_id: Provider ID string. + + Returns: + SAMLProviderConfig: A SAMLProviderConfig instance. + + Raises: + ValueError: If the provider ID is invalid, empty or does not have ``saml.`` prefix. + ConfigurationNotFoundError: If no SAML provider is available with the given identifier. + FirebaseError: If an error occurs while retrieving the SAML provider. + """ + return self._provider_manager.get_saml_provider_config(provider_id) + def _check_jwt_revoked(self, verified_claims, exc_type, label): user = self.get_user(verified_claims.get('uid')) if verified_claims.get('iat') * 1000 < user.tokens_valid_after_timestamp: diff --git a/tests/data/saml_provider_config.json b/tests/data/saml_provider_config.json new file mode 100644 index 000000000..577340f2a --- /dev/null +++ b/tests/data/saml_provider_config.json @@ -0,0 +1,18 @@ +{ + "name": "projects/mock-project-id/inboundSamlConfigs/saml.provider", + "idpConfig": { + "idpEntityId": "IDP_ENTITY_ID", + "ssoUrl": "https://example.com/login", + "signRequest": true, + "idpCertificates": [ + {"x509Certificate": "CERT1"}, + {"x509Certificate": "CERT2"} + ] + }, + "spConfig": { + "spEntityId": "RP_ENTITY_ID", + "callbackUri": "https://projectId.firebaseapp.com/__/auth/handler" + }, + "displayName": "samlProviderName", + "enabled": true +} \ No newline at end of file diff --git a/tests/test_auth_providers.py b/tests/test_auth_providers.py new file mode 100644 index 000000000..95d68b40e --- /dev/null +++ b/tests/test_auth_providers.py @@ -0,0 +1,94 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test cases for the firebase_admin._auth_providers module.""" + +import pytest + +import firebase_admin +from firebase_admin import auth +from firebase_admin import exceptions +from firebase_admin import _auth_providers +from tests import testutils + +USER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v2beta1/projects/mock-project-id' +SAML_PROVIDER_CONFIG_RESPONSE = testutils.resource('saml_provider_config.json') + +CONFIG_NOT_FOUND_RESPONSE = """{ + "error": { + "message": "CONFIGURATION_NOT_FOUND" + } +}""" + + +@pytest.fixture(scope='module') +def user_mgt_app(): + app = firebase_admin.initialize_app(testutils.MockCredential(), name='providerConfig', + options={'projectId': 'mock-project-id'}) + yield app + firebase_admin.delete_app(app) + + +def _instrument_provider_mgt(app, status, payload): + client = auth._get_client(app) + provider_manager = client._provider_manager + recorder = [] + provider_manager.http_client.session.mount( + _auth_providers.ProviderConfigClient.PROVIDER_CONFIG_URL, + testutils.MockAdapter(payload, status, recorder)) + return recorder + + +class TestSAMLProviderConfig: + + @pytest.mark.parametrize('provider_id', [ + None, True, False, 1, 0, list(), tuple(), dict(), '', 'oidc.provider' + ]) + def test_invalid_provider_id(self, user_mgt_app, provider_id): + with pytest.raises(ValueError) as excinfo: + auth.get_saml_provider_config(provider_id, app=user_mgt_app) + + assert str(excinfo.value).startswith('Invalid SAML provider ID') + + def test_get_saml_provider_config(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, SAML_PROVIDER_CONFIG_RESPONSE) + + provider_config = auth.get_saml_provider_config('saml.provider', app=user_mgt_app) + + assert provider_config.provider_id == 'saml.provider' + assert provider_config.display_name == 'samlProviderName' + assert provider_config.enabled is True + assert provider_config.idp_entity_id == 'IDP_ENTITY_ID' + assert provider_config.sso_url == 'https://example.com/login' + assert provider_config.request_signing_enabled is True + assert provider_config.x509_certificates == ['CERT1', 'CERT2'] + assert provider_config.rp_entity_id == 'RP_ENTITY_ID' + assert provider_config.callback_url == 'https://projectId.firebaseapp.com/__/auth/handler' + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/inboundSamlConfigs/saml.provider') + + def test_config_not_found(self, user_mgt_app): + _instrument_provider_mgt(user_mgt_app, 500, CONFIG_NOT_FOUND_RESPONSE) + + with pytest.raises(auth.ConfigurationNotFoundError) as excinfo: + auth.get_saml_provider_config('saml.provider', app=user_mgt_app) + + error_msg = 'No auth provider found for the given identifier (CONFIGURATION_NOT_FOUND).' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None diff --git a/tests/test_tenant_mgt.py b/tests/test_tenant_mgt.py index 47a647129..c40176b9f 100644 --- a/tests/test_tenant_mgt.py +++ b/tests/test_tenant_mgt.py @@ -24,6 +24,8 @@ from firebase_admin import credentials from firebase_admin import exceptions from firebase_admin import tenant_mgt +from firebase_admin import _auth_providers +from firebase_admin import _user_mgt from tests import testutils from tests import test_token_gen @@ -76,10 +78,13 @@ MOCK_GET_USER_RESPONSE = testutils.resource('get_user.json') MOCK_LIST_USERS_RESPONSE = testutils.resource('list_users.json') +SAML_PROVIDER_CONFIG_RESPONSE = testutils.resource('saml_provider_config.json') + INVALID_TENANT_IDS = [None, '', 0, 1, True, False, list(), tuple(), dict()] INVALID_BOOLEANS = ['', 1, 0, list(), tuple(), dict()] USER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v1/projects/mock-project-id' +PROVIDER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v2beta1/projects/mock-project-id' TENANT_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v2beta1/projects/mock-project-id' @@ -103,8 +108,17 @@ def _instrument_tenant_mgt(app, status, payload): def _instrument_user_mgt(client, status, payload): recorder = [] user_manager = client._user_manager - user_manager._client.session.mount( - auth.Client.ID_TOOLKIT_URL, + user_manager.http_client.session.mount( + _user_mgt.UserManager.ID_TOOLKIT_URL, + testutils.MockAdapter(payload, status, recorder)) + return recorder + + +def _instrument_provider_mgt(client, status, payload): + recorder = [] + provider_manager = client._provider_manager + provider_manager.http_client.session.mount( + _auth_providers.ProviderConfigClient.PROVIDER_CONFIG_URL, testutils.MockAdapter(payload, status, recorder)) return recorder @@ -686,6 +700,28 @@ def test_generate_sign_in_with_email_link(self, tenant_mgt_app): 'continueUrl': 'http://localhost', }) + def test_get_saml_provider_config(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_provider_mgt(client, 200, SAML_PROVIDER_CONFIG_RESPONSE) + + provider_config = client.get_saml_provider_config('saml.provider') + + assert provider_config.provider_id == 'saml.provider' + assert provider_config.display_name == 'samlProviderName' + assert provider_config.enabled is True + assert provider_config.idp_entity_id == 'IDP_ENTITY_ID' + assert provider_config.sso_url == 'https://example.com/login' + assert provider_config.request_signing_enabled is True + assert provider_config.x509_certificates == ['CERT1', 'CERT2'] + assert provider_config.rp_entity_id == 'RP_ENTITY_ID' + assert provider_config.callback_url == 'https://projectId.firebaseapp.com/__/auth/handler' + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/tenants/tenant-id/inboundSamlConfigs/saml.provider'.format( + PROVIDER_MGT_URL_PREFIX) + def test_tenant_not_found(self, tenant_mgt_app): client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) _instrument_user_mgt(client, 500, TENANT_NOT_FOUND_RESPONSE) diff --git a/tests/test_token_gen.py b/tests/test_token_gen.py index dbb7020cf..f88c87ff4 100644 --- a/tests/test_token_gen.py +++ b/tests/test_token_gen.py @@ -120,8 +120,8 @@ def _instrument_user_manager(app, status, payload): client = auth._get_client(app) user_manager = client._user_manager recorder = [] - user_manager._client.session.mount( - auth.Client.ID_TOOLKIT_URL, + user_manager.http_client.session.mount( + _token_gen.TokenGenerator.ID_TOOLKIT_URL, testutils.MockAdapter(payload, status, recorder)) return user_manager, recorder diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index 4f4efbd28..b64a4d1f3 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -50,6 +50,9 @@ } MOCK_ACTION_CODE_SETTINGS = auth.ActionCodeSettings(**MOCK_ACTION_CODE_DATA) +USER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v1/projects/mock-project-id' + + @pytest.fixture(scope='module') def user_mgt_app(): app = firebase_admin.initialize_app(testutils.MockCredential(), name='userMgt', @@ -61,8 +64,8 @@ def _instrument_user_manager(app, status, payload): client = auth._get_client(app) user_manager = client._user_manager recorder = [] - user_manager._client.session.mount( - auth.Client.ID_TOOLKIT_URL, + user_manager.http_client.session.mount( + _user_mgt.UserManager.ID_TOOLKIT_URL, testutils.MockAdapter(payload, status, recorder)) return user_manager, recorder @@ -101,12 +104,22 @@ def _check_user_record(user, expected_uid='testuser'): assert provider.provider_id == 'phone' +def _check_request(recorder, want_url, want_body=None): + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'POST' + assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, want_url) + if want_body: + body = json.loads(req.body.decode()) + assert body == want_body + + class TestAuthServiceInitialization: def test_default_timeout(self, user_mgt_app): client = auth._get_client(user_mgt_app) user_manager = client._user_manager - assert user_manager._client.timeout == _http_client.DEFAULT_TIMEOUT_SECONDS + assert user_manager.http_client.timeout == _http_client.DEFAULT_TIMEOUT_SECONDS def test_fail_on_no_project_id(self): app = firebase_admin.initialize_app(testutils.MockCredential(), name='userMgt2') @@ -203,8 +216,9 @@ def test_invalid_get_user(self, arg, user_mgt_app): auth.get_user(arg, app=user_mgt_app) def test_get_user(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 200, MOCK_GET_USER_RESPONSE) + _, recorder = _instrument_user_manager(user_mgt_app, 200, MOCK_GET_USER_RESPONSE) _check_user_record(auth.get_user('testuser', user_mgt_app)) + _check_request(recorder, '/accounts:lookup', {'localId': ['testuser']}) @pytest.mark.parametrize('arg', INVALID_STRINGS + ['not-an-email']) def test_invalid_get_user_by_email(self, arg, user_mgt_app): @@ -212,8 +226,9 @@ def test_invalid_get_user_by_email(self, arg, user_mgt_app): auth.get_user_by_email(arg, app=user_mgt_app) def test_get_user_by_email(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 200, MOCK_GET_USER_RESPONSE) + _, recorder = _instrument_user_manager(user_mgt_app, 200, MOCK_GET_USER_RESPONSE) _check_user_record(auth.get_user_by_email('testuser@example.com', user_mgt_app)) + _check_request(recorder, '/accounts:lookup', {'email': ['testuser@example.com']}) @pytest.mark.parametrize('arg', INVALID_STRINGS + ['not-a-phone']) def test_invalid_get_user_by_phone(self, arg, user_mgt_app): @@ -221,8 +236,9 @@ def test_invalid_get_user_by_phone(self, arg, user_mgt_app): auth.get_user_by_phone_number(arg, app=user_mgt_app) def test_get_user_by_phone(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 200, MOCK_GET_USER_RESPONSE) + _, recorder = _instrument_user_manager(user_mgt_app, 200, MOCK_GET_USER_RESPONSE) _check_user_record(auth.get_user_by_phone_number('+1234567890', user_mgt_app)) + _check_request(recorder, '/accounts:lookup', {'phoneNumber': ['+1234567890']}) def test_get_user_non_existing(self, user_mgt_app): _instrument_user_manager(user_mgt_app, 200, '{"users":[]}') @@ -1050,7 +1066,7 @@ def test_import_users(self, user_mgt_app): assert result.failure_count == 0 assert result.errors == [] expected = {'users': [{'localId': 'user1'}, {'localId': 'user2'}]} - self._check_rpc_calls(recorder, expected) + _check_request(recorder, '/accounts:batchCreate', expected) def test_import_users_error(self, user_mgt_app): _, recorder = _instrument_user_manager(user_mgt_app, 200, """{"error": [ @@ -1073,7 +1089,7 @@ def test_import_users_error(self, user_mgt_app): assert err.index == 2 assert err.reason == 'Another error occured in user3' expected = {'users': [{'localId': 'user1'}, {'localId': 'user2'}, {'localId': 'user3'}]} - self._check_rpc_calls(recorder, expected) + _check_request(recorder, '/accounts:batchCreate', expected) def test_import_users_missing_required_hash(self, user_mgt_app): users = [ @@ -1106,7 +1122,7 @@ def test_import_users_with_hash(self, user_mgt_app): 'memoryCost': 14, 'saltSeparator': _user_import.b64_encode(b'sep'), } - self._check_rpc_calls(recorder, expected) + _check_request(recorder, '/accounts:batchCreate', expected) def test_import_users_http_error(self, user_mgt_app): _instrument_user_manager(user_mgt_app, 401, '{"error": {"message": "ERROR_CODE"}}') @@ -1127,11 +1143,6 @@ def test_import_users_unexpected_response(self, user_mgt_app): with pytest.raises(auth.UnexpectedResponseError): auth.import_users(users, app=user_mgt_app) - def _check_rpc_calls(self, recorder, expected): - assert len(recorder) == 1 - request = json.loads(recorder[0].body.decode()) - assert request == expected - class TestRevokeRefreshTokkens: