diff --git a/firebase_admin/_auth_client.py b/firebase_admin/_auth_client.py index 9af0db56f..761c1a1f7 100644 --- a/firebase_admin/_auth_client.py +++ b/firebase_admin/_auth_client.py @@ -476,6 +476,44 @@ def update_saml_provider_config( x509_certificates=x509_certificates, rp_entity_id=rp_entity_id, callback_url=callback_url, display_name=display_name, enabled=enabled) + def delete_saml_provider_config(self, provider_id): + """Deletes the SAMLProviderConfig with the given ID. + + Args: + provider_id: Provider ID string. + + Raises: + ValueError: If the provider ID is invalid, empty or does not have ``saml.`` prefix. + ConfigurationNotFoundError: If no SAML provider is available with the given identifier. + FirebaseError: If an error occurs while deleting the SAML provider. + """ + self._provider_manager.delete_saml_provider_config(provider_id) + + def list_saml_provider_configs( + self, page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS): + """Retrieves a page of SAML provider configs from a Firebase project. + + The ``page_token`` argument governs the starting point of the page. The ``max_results`` + argument governs the maximum number of configs that may be included in the returned + page. This function never returns None. If there are no SAML configs in the Firebase + project, this returns an empty page. + + Args: + page_token: A non-empty page token string, which indicates the starting point of the + page (optional). Defaults to ``None``, which will retrieve the first page of users. + max_results: A positive integer indicating the maximum number of users to include in + the returned page (optional). Defaults to 100, which is also the maximum number + allowed. + + Returns: + ListProviderConfigsPage: A ListProviderConfigsPage instance. + + Raises: + ValueError: If max_results or page_token are invalid. + FirebaseError: If an error occurs while retrieving the SAML provider configs. + """ + return self._provider_manager.list_saml_provider_configs(page_token, max_results) + def _check_jwt_revoked(self, verified_claims, exc_type, label): user = self.get_user(verified_claims.get('uid')) if verified_claims.get('iat') * 1000 < user.tokens_valid_after_timestamp: diff --git a/firebase_admin/_auth_providers.py b/firebase_admin/_auth_providers.py index 6cec0f29a..9bcb7cc4b 100644 --- a/firebase_admin/_auth_providers.py +++ b/firebase_admin/_auth_providers.py @@ -22,6 +22,9 @@ from firebase_admin import _user_mgt +MAX_LIST_CONFIGS_RESULTS = 100 + + class ProviderConfig: """Parent type for all authentication provider config types.""" @@ -69,6 +72,72 @@ def rp_entity_id(self): return self._data.get('spConfig', {})['spEntityId'] +class ListProviderConfigsPage: + """Represents a page of AuthProviderConfig instances retrieved from a Firebase project. + + Provides methods for traversing the provider configs included in this page, as well as + retrieving subsequent pages. The iterator returned by ``iterate_all()`` can be used to iterate + through all provider configs in the Firebase project starting from this page. + """ + + def __init__(self, download, page_token, max_results): + self._download = download + self._max_results = max_results + self._current = download(page_token, max_results) + + @property + def provider_configs(self): + """A list of ``AuthProviderConfig`` instances available in this page.""" + raise NotImplementedError + + @property + def next_page_token(self): + """Page token string for the next page (empty string indicates no more pages).""" + return self._current.get('nextPageToken', '') + + @property + def has_next_page(self): + """A boolean indicating whether more pages are available.""" + return bool(self.next_page_token) + + def get_next_page(self): + """Retrieves the next page of provider configs, if available. + + Returns: + ListProviderConfigsPage: Next page of provider configs, or None if this is the last + page. + """ + if self.has_next_page: + return self.__class__(self._download, self.next_page_token, self._max_results) + return None + + def iterate_all(self): + """Retrieves an iterator for provider configs. + + Returned iterator will iterate through all the provider configs in the Firebase project + starting from this page. The iterator will never buffer more than one page of configs + in memory at a time. + + Returns: + iterator: An iterator of AuthProviderConfig instances. + """ + return _ProviderConfigIterator(self) + + +class _ListSAMLProviderConfigsPage(ListProviderConfigsPage): + + @property + def provider_configs(self): + return [SAMLProviderConfig(data) for data in self._current.get('inboundSamlConfigs', [])] + + +class _ProviderConfigIterator(_auth_utils.PageIterator): + + @property + def items(self): + return self._current_page.provider_configs + + class ProviderConfigClient: """Client for managing Auth provider configurations.""" @@ -151,6 +220,31 @@ def update_saml_provider_config( body = self._make_request('patch', url, json=req, params=params) return SAMLProviderConfig(body) + def delete_saml_provider_config(self, provider_id): + _validate_saml_provider_id(provider_id) + self._make_request('delete', '/inboundSamlConfigs/{0}'.format(provider_id)) + + def list_saml_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + return _ListSAMLProviderConfigsPage( + self._fetch_saml_provider_configs, page_token, max_results) + + def _fetch_saml_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + """Fetches a page of SAML provider configs""" + if page_token is not None: + if not isinstance(page_token, str) or not page_token: + raise ValueError('Page token must be a non-empty string.') + if not isinstance(max_results, int): + raise ValueError('Max results must be an integer.') + if max_results < 1 or max_results > MAX_LIST_CONFIGS_RESULTS: + raise ValueError( + 'Max results must be a positive integer less than or equal to ' + '{0}.'.format(MAX_LIST_CONFIGS_RESULTS)) + + params = 'pageSize={0}'.format(max_results) + if page_token: + params += '&pageToken={0}'.format(page_token) + return self._make_request('get', '/inboundSamlConfigs', params=params) + def _make_request(self, method, path, **kwargs): url = '{0}{1}'.format(self.base_url, path) try: diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index e05793d8f..f1ce97dee 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -30,6 +30,42 @@ VALID_EMAIL_ACTION_TYPES = set(['VERIFY_EMAIL', 'EMAIL_SIGNIN', 'PASSWORD_RESET']) +class PageIterator: + """An iterator that allows iterating over a sequence of items, one at a time. + + This implementation loads a page of items into memory, and iterates on them. When the whole + page has been traversed, it loads another page. This class never keeps more than one page + of entries in memory. + """ + + def __init__(self, current_page): + if not current_page: + raise ValueError('Current page must not be None.') + self._current_page = current_page + self._index = 0 + + def next(self): + if self._index == len(self.items): + if self._current_page.has_next_page: + self._current_page = self._current_page.get_next_page() + self._index = 0 + if self._index < len(self.items): + result = self.items[self._index] + self._index += 1 + return result + raise StopIteration + + @property + def items(self): + raise NotImplementedError + + def __next__(self): + return self.next() + + def __iter__(self): + return self + + def validate_uid(uid, required=False): if uid is None and not required: return None diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index 0f3dc1a94..8b0a81adf 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -639,33 +639,8 @@ def _make_request(self, method, path, **kwargs): raise _auth_utils.handle_auth_backend_error(error) -class _UserIterator: - """An iterator that allows iterating over user accounts, one at a time. +class _UserIterator(_auth_utils.PageIterator): - This implementation loads a page of users into memory, and iterates on them. When the whole - page has been traversed, it loads another page. This class never keeps more than one page - of entries in memory. - """ - - def __init__(self, current_page): - if not current_page: - raise ValueError('Current page must not be None.') - self._current_page = current_page - self._index = 0 - - def next(self): - if self._index == len(self._current_page.users): - if self._current_page.has_next_page: - self._current_page = self._current_page.get_next_page() - self._index = 0 - if self._index < len(self._current_page.users): - result = self._current_page.users[self._index] - self._index += 1 - return result - raise StopIteration - - def __next__(self): - return self.next() - - def __iter__(self): - return self + @property + def items(self): + return self._current_page.users diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index c422c3ab7..7d11bd58c 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -46,6 +46,7 @@ 'InvalidDynamicLinkDomainError', 'InvalidIdTokenError', 'InvalidSessionCookieError', + 'ListProviderConfigsPage', 'ListUsersPage', 'PhoneNumberAlreadyExistsError', 'ProviderConfig', @@ -67,6 +68,7 @@ 'create_saml_provider_config', 'create_session_cookie', 'create_user', + 'delete_saml_provider_config', 'delete_user', 'generate_email_verification_link', 'generate_password_reset_link', @@ -76,6 +78,7 @@ 'get_user_by_email', 'get_user_by_phone_number', 'import_users', + 'list_saml_provider_configs', 'list_users', 'revoke_refresh_tokens', 'set_custom_user_claims', @@ -100,6 +103,7 @@ InvalidDynamicLinkDomainError = _auth_utils.InvalidDynamicLinkDomainError InvalidIdTokenError = _auth_utils.InvalidIdTokenError InvalidSessionCookieError = _token_gen.InvalidSessionCookieError +ListProviderConfigsPage = _auth_providers.ListProviderConfigsPage ListUsersPage = _user_mgt.ListUsersPage PhoneNumberAlreadyExistsError = _auth_utils.PhoneNumberAlreadyExistsError ProviderConfig = _auth_providers.ProviderConfigClient @@ -633,3 +637,47 @@ def update_saml_provider_config( provider_id, idp_entity_id=idp_entity_id, sso_url=sso_url, x509_certificates=x509_certificates, rp_entity_id=rp_entity_id, callback_url=callback_url, display_name=display_name, enabled=enabled) + + +def delete_saml_provider_config(provider_id, app=None): + """Deletes the SAMLProviderConfig with the given ID. + + Args: + provider_id: Provider ID string. + app: An App instance (optional). + + Raises: + ValueError: If the provider ID is invalid, empty or does not have ``saml.`` prefix. + ConfigurationNotFoundError: If no SAML provider is available with the given identifier. + FirebaseError: If an error occurs while deleting the SAML provider. + """ + client = _get_client(app) + client.delete_saml_provider_config(provider_id) + + +def list_saml_provider_configs( + page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS, app=None): + """Retrieves a page of SAML provider configs from a Firebase project. + + The ``page_token`` argument governs the starting point of the page. The ``max_results`` + argument governs the maximum number of configs that may be included in the returned + page. This function never returns None. If there are no SAML configs in the Firebase + project, this returns an empty page. + + Args: + page_token: A non-empty page token string, which indicates the starting point of the + page (optional). Defaults to ``None``, which will retrieve the first page of users. + max_results: A positive integer indicating the maximum number of users to include in + the returned page (optional). Defaults to 100, which is also the maximum number + allowed. + app: An App instance (optional). + + Returns: + ListProviderConfigsPage: A ListProviderConfigsPage instance. + + Raises: + ValueError: If max_results or page_token are invalid. + FirebaseError: If an error occurs while retrieving the SAML provider configs. + """ + client = _get_client(app) + return client.list_saml_provider_configs(page_token, max_results) diff --git a/tests/data/list_saml_provider_configs.json b/tests/data/list_saml_provider_configs.json new file mode 100644 index 000000000..b568e1e09 --- /dev/null +++ b/tests/data/list_saml_provider_configs.json @@ -0,0 +1,40 @@ +{ + "inboundSamlConfigs": [ + { + "name": "projects/mock-project-id/inboundSamlConfigs/saml.provider0", + "idpConfig": { + "idpEntityId": "IDP_ENTITY_ID", + "ssoUrl": "https://example.com/login", + "signRequest": true, + "idpCertificates": [ + {"x509Certificate": "CERT1"}, + {"x509Certificate": "CERT2"} + ] + }, + "spConfig": { + "spEntityId": "RP_ENTITY_ID", + "callbackUri": "https://projectId.firebaseapp.com/__/auth/handler" + }, + "displayName": "samlProviderName", + "enabled": true + }, + { + "name": "projects/mock-project-id/inboundSamlConfigs/saml.provider1", + "idpConfig": { + "idpEntityId": "IDP_ENTITY_ID", + "ssoUrl": "https://example.com/login", + "signRequest": true, + "idpCertificates": [ + {"x509Certificate": "CERT1"}, + {"x509Certificate": "CERT2"} + ] + }, + "spConfig": { + "spEntityId": "RP_ENTITY_ID", + "callbackUri": "https://projectId.firebaseapp.com/__/auth/handler" + }, + "displayName": "samlProviderName", + "enabled": true + } + ] +} diff --git a/tests/test_auth_providers.py b/tests/test_auth_providers.py index 9ef59fbff..f5a66a7c5 100644 --- a/tests/test_auth_providers.py +++ b/tests/test_auth_providers.py @@ -26,6 +26,7 @@ USER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v2beta1/projects/mock-project-id' SAML_PROVIDER_CONFIG_RESPONSE = testutils.resource('saml_provider_config.json') +LIST_SAML_PROVIDER_CONFIGS_RESPONSE = testutils.resource('list_saml_provider_configs.json') CONFIG_NOT_FOUND_RESPONSE = """{ "error": { @@ -33,6 +34,8 @@ } }""" +INVALID_PROVIDER_IDS = [None, True, False, 1, 0, list(), tuple(), dict(), ''] + @pytest.fixture(scope='module') def user_mgt_app(): @@ -79,10 +82,8 @@ class TestSAMLProviderConfig: } } - @pytest.mark.parametrize('provider_id', [ - None, True, False, 1, 0, list(), tuple(), dict(), '', 'oidc.provider' - ]) - def test_invalid_provider_id(self, user_mgt_app, provider_id): + @pytest.mark.parametrize('provider_id', INVALID_PROVIDER_IDS + ['oidc.provider']) + def test_get_invalid_provider_id(self, user_mgt_app, provider_id): with pytest.raises(ValueError) as excinfo: auth.get_saml_provider_config(provider_id, app=user_mgt_app) @@ -239,6 +240,23 @@ def test_update_empty_values(self, user_mgt_app): got = json.loads(req.body.decode()) assert got == {'displayName': None, 'enabled': False} + @pytest.mark.parametrize('provider_id', INVALID_PROVIDER_IDS + ['oidc.provider']) + def test_delete_invalid_provider_id(self, user_mgt_app, provider_id): + with pytest.raises(ValueError) as excinfo: + auth.delete_saml_provider_config(provider_id, app=user_mgt_app) + + assert str(excinfo.value).startswith('Invalid SAML provider ID') + + def test_delete(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, '{}') + + auth.delete_saml_provider_config('saml.provider', app=user_mgt_app) + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'DELETE' + assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/inboundSamlConfigs/saml.provider') + def test_config_not_found(self, user_mgt_app): _instrument_provider_mgt(user_mgt_app, 500, CONFIG_NOT_FOUND_RESPONSE) @@ -251,8 +269,112 @@ def test_config_not_found(self, user_mgt_app): assert excinfo.value.http_response is not None assert excinfo.value.cause is not None - def _assert_provider_config(self, provider_config): - assert provider_config.provider_id == 'saml.provider' + @pytest.mark.parametrize('arg', [None, 'foo', list(), dict(), 0, -1, 101, False]) + def test_invalid_max_results(self, user_mgt_app, arg): + with pytest.raises(ValueError): + auth.list_saml_provider_configs(max_results=arg, app=user_mgt_app) + + @pytest.mark.parametrize('arg', ['', list(), dict(), 0, -1, 101, False]) + def test_invalid_page_token(self, user_mgt_app, arg): + with pytest.raises(ValueError): + auth.list_saml_provider_configs(page_token=arg, app=user_mgt_app) + + def test_list_single_page(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, LIST_SAML_PROVIDER_CONFIGS_RESPONSE) + page = auth.list_saml_provider_configs(app=user_mgt_app) + + self._assert_page(page) + provider_configs = list(config for config in page.iterate_all()) + assert len(provider_configs) == 2 + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/inboundSamlConfigs?pageSize=100') + + def test_list_multiple_pages(self, user_mgt_app): + sample_response = json.loads(SAML_PROVIDER_CONFIG_RESPONSE) + configs = self._create_list_response(sample_response) + + # Page 1 + response = { + 'inboundSamlConfigs': configs[:2], + 'nextPageToken': 'token' + } + recorder = _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + page = auth.list_saml_provider_configs(max_results=10, app=user_mgt_app) + + self._assert_page(page, next_page_token='token') + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/inboundSamlConfigs?pageSize=10'.format(USER_MGT_URL_PREFIX) + + # Page 2 (also the last page) + response = {'inboundSamlConfigs': configs[2:]} + recorder = _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + page = page.get_next_page() + + self._assert_page(page, count=1, start=2) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/inboundSamlConfigs?pageSize=10&pageToken=token'.format( + USER_MGT_URL_PREFIX) + + def test_paged_iteration(self, user_mgt_app): + sample_response = json.loads(SAML_PROVIDER_CONFIG_RESPONSE) + configs = self._create_list_response(sample_response) + + # Page 1 + response = { + 'inboundSamlConfigs': configs[:2], + 'nextPageToken': 'token' + } + recorder = _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + page = auth.list_saml_provider_configs(app=user_mgt_app) + iterator = page.iterate_all() + + for index in range(2): + provider_config = next(iterator) + assert provider_config.provider_id == 'saml.provider{0}'.format(index) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/inboundSamlConfigs?pageSize=100'.format(USER_MGT_URL_PREFIX) + + # Page 2 (also the last page) + response = {'inboundSamlConfigs': configs[2:]} + recorder = _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + + provider_config = next(iterator) + assert provider_config.provider_id == 'saml.provider2' + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/inboundSamlConfigs?pageSize=100&pageToken=token'.format( + USER_MGT_URL_PREFIX) + + with pytest.raises(StopIteration): + next(iterator) + + def test_list_empty_response(self, user_mgt_app): + response = {'inboundSamlConfigs': []} + _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + page = auth.list_saml_provider_configs(app=user_mgt_app) + assert len(page.provider_configs) == 0 + provider_configs = list(config for config in page.iterate_all()) + assert len(provider_configs) == 0 + + def test_list_error(self, user_mgt_app): + _instrument_provider_mgt(user_mgt_app, 500, '{"error":"test"}') + with pytest.raises(exceptions.InternalError) as excinfo: + auth.list_saml_provider_configs(app=user_mgt_app) + assert str(excinfo.value) == 'Unexpected error response: {"error":"test"}' + + def _assert_provider_config(self, provider_config, want_id='saml.provider'): + assert isinstance(provider_config, auth.SAMLProviderConfig) + assert provider_config.provider_id == want_id assert provider_config.display_name == 'samlProviderName' assert provider_config.enabled is True assert provider_config.idp_entity_id == 'IDP_ENTITY_ID' @@ -260,3 +382,27 @@ def _assert_provider_config(self, provider_config): assert provider_config.x509_certificates == ['CERT1', 'CERT2'] assert provider_config.rp_entity_id == 'RP_ENTITY_ID' assert provider_config.callback_url == 'https://projectId.firebaseapp.com/__/auth/handler' + + def _assert_page(self, page, count=2, start=0, next_page_token=''): + assert isinstance(page, auth.ListProviderConfigsPage) + index = start + assert len(page.provider_configs) == count + for provider_config in page.provider_configs: + self._assert_provider_config(provider_config, want_id='saml.provider{0}'.format(index)) + index += 1 + + if next_page_token: + assert page.next_page_token == next_page_token + assert page.has_next_page is True + else: + assert page.next_page_token == '' + assert page.has_next_page is False + assert page.get_next_page() is None + + def _create_list_response(self, sample_response, count=3): + configs = [] + for idx in range(count): + config = dict(sample_response) + config['name'] += str(idx) + configs.append(config) + return configs diff --git a/tests/test_tenant_mgt.py b/tests/test_tenant_mgt.py index e08eaf8de..7cb8e7bab 100644 --- a/tests/test_tenant_mgt.py +++ b/tests/test_tenant_mgt.py @@ -93,6 +93,8 @@ } } +LIST_SAML_PROVIDER_CONFIGS_RESPONSE = testutils.resource('list_saml_provider_configs.json') + INVALID_TENANT_IDS = [None, '', 0, 1, True, False, list(), tuple(), dict()] INVALID_BOOLEANS = ['', 1, 0, list(), tuple(), dict()] @@ -761,6 +763,44 @@ def test_update_saml_provider_config(self, tenant_mgt_app): recorder, url, SAML_PROVIDER_CONFIG_REQUEST, method='PATCH', prefix=PROVIDER_MGT_URL_PREFIX) + def test_delete_saml_provider_config(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_provider_mgt(client, 200, SAML_PROVIDER_CONFIG_RESPONSE) + + client.delete_saml_provider_config('saml.provider') + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'DELETE' + assert req.url == '{0}/tenants/tenant-id/inboundSamlConfigs/saml.provider'.format( + PROVIDER_MGT_URL_PREFIX) + + def test_list_saml_provider_configs(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_provider_mgt(client, 200, LIST_SAML_PROVIDER_CONFIGS_RESPONSE) + + page = client.list_saml_provider_configs() + + assert isinstance(page, auth.ListProviderConfigsPage) + index = 0 + assert len(page.provider_configs) == 2 + for provider_config in page.provider_configs: + self._assert_saml_provider_config( + provider_config, want_id='saml.provider{0}'.format(index)) + index += 1 + + assert page.next_page_token == '' + assert page.has_next_page is False + assert page.get_next_page() is None + provider_configs = list(config for config in page.iterate_all()) + assert len(provider_configs) == 2 + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}{1}'.format( + PROVIDER_MGT_URL_PREFIX, '/tenants/tenant-id/inboundSamlConfigs?pageSize=100') + def test_tenant_not_found(self, tenant_mgt_app): client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) _instrument_user_mgt(client, 500, TENANT_NOT_FOUND_RESPONSE) @@ -782,8 +822,9 @@ def _assert_request( body = json.loads(req.body.decode()) assert body == want_body - def _assert_saml_provider_config(self, provider_config): - assert provider_config.provider_id == 'saml.provider' + def _assert_saml_provider_config(self, provider_config, want_id='saml.provider'): + assert isinstance(provider_config, auth.SAMLProviderConfig) + assert provider_config.provider_id == want_id assert provider_config.display_name == 'samlProviderName' assert provider_config.enabled is True assert provider_config.idp_entity_id == 'IDP_ENTITY_ID'