diff --git a/firebase_admin/_auth_client.py b/firebase_admin/_auth_client.py new file mode 100644 index 000000000..b7af6ddb6 --- /dev/null +++ b/firebase_admin/_auth_client.py @@ -0,0 +1,625 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Firebase auth client sub module.""" + +import time + +import firebase_admin +from firebase_admin import _auth_providers +from firebase_admin import _auth_utils +from firebase_admin import _http_client +from firebase_admin import _token_gen +from firebase_admin import _user_import +from firebase_admin import _user_mgt + + +class Client: + """Firebase Authentication client scoped to a specific tenant.""" + + def __init__(self, app, tenant_id=None): + if not app.project_id: + raise ValueError("""A project ID is required to access the auth service. + 1. Use a service account credential, or + 2. set the project ID explicitly via Firebase App options, or + 3. set the project ID via the GOOGLE_CLOUD_PROJECT environment variable.""") + + credential = app.credential.get_credential() + version_header = 'Python/Admin/{0}'.format(firebase_admin.__version__) + http_client = _http_client.JsonHttpClient( + credential=credential, headers={'X-Client-Version': version_header}) + + self._tenant_id = tenant_id + self._token_generator = _token_gen.TokenGenerator(app, http_client) + self._token_verifier = _token_gen.TokenVerifier(app) + self._user_manager = _user_mgt.UserManager(http_client, app.project_id, tenant_id) + self._provider_manager = _auth_providers.ProviderConfigClient( + http_client, app.project_id, tenant_id) + + @property + def tenant_id(self): + """Tenant ID associated with this client.""" + return self._tenant_id + + def create_custom_token(self, uid, developer_claims=None): + """Builds and signs a Firebase custom auth token. + + Args: + uid: ID of the user for whom the token is created. + developer_claims: A dictionary of claims to be included in the token + (optional). + + Returns: + bytes: A token minted from the input parameters. + + Raises: + ValueError: If input parameters are invalid. + TokenSignError: If an error occurs while signing the token using the remote IAM service. + """ + return self._token_generator.create_custom_token( + uid, developer_claims, tenant_id=self.tenant_id) + + def verify_id_token(self, id_token, check_revoked=False): + """Verifies the signature and data for the provided JWT. + + Accepts a signed token string, verifies that it is current, was issued + to this project, and that it was correctly signed by Google. + + Args: + id_token: A string of the encoded JWT. + check_revoked: Boolean, If true, checks whether the token has been revoked (optional). + + Returns: + dict: A dictionary of key-value pairs parsed from the decoded JWT. + + Raises: + ValueError: If ``id_token`` is a not a string or is empty. + InvalidIdTokenError: If ``id_token`` is not a valid Firebase ID token. + ExpiredIdTokenError: If the specified ID token has expired. + RevokedIdTokenError: If ``check_revoked`` is ``True`` and the ID token has been + revoked. + TenantIdMismatchError: If ``id_token`` belongs to a tenant that is different than + this ``Client`` instance. + CertificateFetchError: If an error occurs while fetching the public key certificates + required to verify the ID token. + """ + if not isinstance(check_revoked, bool): + # guard against accidental wrong assignment. + raise ValueError('Illegal check_revoked argument. Argument must be of type ' + ' bool, but given "{0}".'.format(type(check_revoked))) + + verified_claims = self._token_verifier.verify_id_token(id_token) + if self.tenant_id: + token_tenant_id = verified_claims.get('firebase', {}).get('tenant') + if self.tenant_id != token_tenant_id: + raise _auth_utils.TenantIdMismatchError( + 'Invalid tenant ID: {0}'.format(token_tenant_id)) + + if check_revoked: + self._check_jwt_revoked(verified_claims, _token_gen.RevokedIdTokenError, 'ID token') + return verified_claims + + def revoke_refresh_tokens(self, uid): + """Revokes all refresh tokens for an existing user. + + This method updates the user's ``tokens_valid_after_timestamp`` to the current UTC + in seconds since the epoch. It is important that the server on which this is called has its + clock set correctly and synchronized. + + While this revokes all sessions for a specified user and disables any new ID tokens for + existing sessions from getting minted, existing ID tokens may remain active until their + natural expiration (one hour). To verify that ID tokens are revoked, use + ``verify_id_token(idToken, check_revoked=True)``. + + Args: + uid: A user ID string. + + Raises: + ValueError: If the user ID is None, empty or malformed. + FirebaseError: If an error occurs while revoking the refresh token. + """ + self._user_manager.update_user(uid, valid_since=int(time.time())) + + def get_user(self, uid): + """Gets the user data corresponding to the specified user ID. + + Args: + uid: A user ID string. + + Returns: + UserRecord: A user record instance. + + Raises: + ValueError: If the user ID is None, empty or malformed. + UserNotFoundError: If the specified user ID does not exist. + FirebaseError: If an error occurs while retrieving the user. + """ + response = self._user_manager.get_user(uid=uid) + return _user_mgt.UserRecord(response) + + def get_user_by_email(self, email): + """Gets the user data corresponding to the specified user email. + + Args: + email: A user email address string. + + Returns: + UserRecord: A user record instance. + + Raises: + ValueError: If the email is None, empty or malformed. + UserNotFoundError: If no user exists by the specified email address. + FirebaseError: If an error occurs while retrieving the user. + """ + response = self._user_manager.get_user(email=email) + return _user_mgt.UserRecord(response) + + def get_user_by_phone_number(self, phone_number): + """Gets the user data corresponding to the specified phone number. + + Args: + phone_number: A phone number string. + + Returns: + UserRecord: A user record instance. + + Raises: + ValueError: If the phone number is ``None``, empty or malformed. + UserNotFoundError: If no user exists by the specified phone number. + FirebaseError: If an error occurs while retrieving the user. + """ + response = self._user_manager.get_user(phone_number=phone_number) + return _user_mgt.UserRecord(response) + + def list_users(self, page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESULTS): + """Retrieves a page of user accounts from a Firebase project. + + The ``page_token`` argument governs the starting point of the page. The ``max_results`` + argument governs the maximum number of user accounts that may be included in the returned + page. This function never returns ``None``. If there are no user accounts in the Firebase + project, this returns an empty page. + + Args: + page_token: A non-empty page token string, which indicates the starting point of the + page (optional). Defaults to ``None``, which will retrieve the first page of users. + max_results: A positive integer indicating the maximum number of users to include in + the returned page (optional). Defaults to 1000, which is also the maximum number + allowed. + + Returns: + ListUsersPage: A page of user accounts. + + Raises: + ValueError: If max_results or page_token are invalid. + FirebaseError: If an error occurs while retrieving the user accounts. + """ + def download(page_token, max_results): + return self._user_manager.list_users(page_token, max_results) + return _user_mgt.ListUsersPage(download, page_token, max_results) + + def create_user(self, **kwargs): # pylint: disable=differing-param-doc + """Creates a new user account with the specified properties. + + Args: + kwargs: A series of keyword arguments (optional). + + Keyword Args: + uid: User ID to assign to the newly created user (optional). + display_name: The user's display name (optional). + email: The user's primary email (optional). + email_verified: A boolean indicating whether or not the user's primary email is + verified (optional). + phone_number: The user's primary phone number (optional). + photo_url: The user's photo URL (optional). + password: The user's raw, unhashed password. (optional). + disabled: A boolean indicating whether or not the user account is disabled (optional). + + Returns: + UserRecord: A UserRecord instance for the newly created user. + + Raises: + ValueError: If the specified user properties are invalid. + FirebaseError: If an error occurs while creating the user account. + """ + uid = self._user_manager.create_user(**kwargs) + return self.get_user(uid=uid) + + def update_user(self, uid, **kwargs): # pylint: disable=differing-param-doc + """Updates an existing user account with the specified properties. + + Args: + uid: A user ID string. + kwargs: A series of keyword arguments (optional). + + Keyword Args: + display_name: The user's display name (optional). Can be removed by explicitly passing + ``auth.DELETE_ATTRIBUTE``. + email: The user's primary email (optional). + email_verified: A boolean indicating whether or not the user's primary email is + verified (optional). + phone_number: The user's primary phone number (optional). Can be removed by explicitly + passing ``auth.DELETE_ATTRIBUTE``. + photo_url: The user's photo URL (optional). Can be removed by explicitly passing + ``auth.DELETE_ATTRIBUTE``. + password: The user's raw, unhashed password. (optional). + disabled: A boolean indicating whether or not the user account is disabled (optional). + custom_claims: A dictionary or a JSON string contining the custom claims to be set on + the user account (optional). To remove all custom claims, pass + ``auth.DELETE_ATTRIBUTE``. + valid_since: An integer signifying the seconds since the epoch (optional). This field + is set by ``revoke_refresh_tokens`` and it is discouraged to set this field + directly. + + Returns: + UserRecord: An updated UserRecord instance for the user. + + Raises: + ValueError: If the specified user ID or properties are invalid. + FirebaseError: If an error occurs while updating the user account. + """ + self._user_manager.update_user(uid, **kwargs) + return self.get_user(uid=uid) + + def set_custom_user_claims(self, uid, custom_claims): + """Sets additional claims on an existing user account. + + Custom claims set via this function can be used to define user roles and privilege levels. + These claims propagate to all the devices where the user is already signed in (after token + expiration or when token refresh is forced), and next time the user signs in. The claims + can be accessed via the user's ID token JWT. If a reserved OIDC claim is specified (sub, + iat, iss, etc), an error is thrown. Claims payload must also not be larger then 1000 + characters when serialized into a JSON string. + + Args: + uid: A user ID string. + custom_claims: A dictionary or a JSON string of custom claims. Pass None to unset any + claims set previously. + + Raises: + ValueError: If the specified user ID or the custom claims are invalid. + FirebaseError: If an error occurs while updating the user account. + """ + if custom_claims is None: + custom_claims = _user_mgt.DELETE_ATTRIBUTE + self._user_manager.update_user(uid, custom_claims=custom_claims) + + def delete_user(self, uid): + """Deletes the user identified by the specified user ID. + + Args: + uid: A user ID string. + + Raises: + ValueError: If the user ID is None, empty or malformed. + FirebaseError: If an error occurs while deleting the user account. + """ + self._user_manager.delete_user(uid) + + def import_users(self, users, hash_alg=None): + """Imports the specified list of users into Firebase Auth. + + At most 1000 users can be imported at a time. This operation is optimized for bulk imports + and ignores checks on identifier uniqueness, which could result in duplications. The + ``hash_alg`` parameter must be specified when importing users with passwords. Refer to the + ``UserImportHash`` class for supported hash algorithms. + + Args: + users: A list of ``ImportUserRecord`` instances to import. Length of the list must not + exceed 1000. + hash_alg: A ``UserImportHash`` object (optional). Required when importing users with + passwords. + + Returns: + UserImportResult: An object summarizing the result of the import operation. + + Raises: + ValueError: If the provided arguments are invalid. + FirebaseError: If an error occurs while importing users. + """ + result = self._user_manager.import_users(users, hash_alg) + return _user_import.UserImportResult(result, len(users)) + + def generate_password_reset_link(self, email, action_code_settings=None): + """Generates the out-of-band email action link for password reset flows for the specified + email address. + + Args: + email: The email of the user whose password is to be reset. + action_code_settings: ``ActionCodeSettings`` instance (optional). Defines whether + the link is to be handled by a mobile app and the additional state information to + be passed in the deep link. + + Returns: + link: The password reset link created by the API + + Raises: + ValueError: If the provided arguments are invalid + FirebaseError: If an error occurs while generating the link + """ + return self._user_manager.generate_email_action_link( + 'PASSWORD_RESET', email, action_code_settings=action_code_settings) + + def generate_email_verification_link(self, email, action_code_settings=None): + """Generates the out-of-band email action link for email verification flows for the + specified email address. + + Args: + email: The email of the user to be verified. + action_code_settings: ``ActionCodeSettings`` instance (optional). Defines whether + the link is to be handled by a mobile app and the additional state information to + be passed in the deep link. + + Returns: + link: The email verification link created by the API + + Raises: + ValueError: If the provided arguments are invalid + FirebaseError: If an error occurs while generating the link + """ + return self._user_manager.generate_email_action_link( + 'VERIFY_EMAIL', email, action_code_settings=action_code_settings) + + def generate_sign_in_with_email_link(self, email, action_code_settings): + """Generates the out-of-band email action link for email link sign-in flows, using the + action code settings provided. + + Args: + email: The email of the user signing in. + action_code_settings: ``ActionCodeSettings`` instance. Defines whether + the link is to be handled by a mobile app and the additional state information to be + passed in the deep link. + + Returns: + link: The email sign-in link created by the API + + Raises: + ValueError: If the provided arguments are invalid + FirebaseError: If an error occurs while generating the link + """ + return self._user_manager.generate_email_action_link( + 'EMAIL_SIGNIN', email, action_code_settings=action_code_settings) + + def get_oidc_provider_config(self, provider_id): + """Returns the ``OIDCProviderConfig`` with the given ID. + + Args: + provider_id: Provider ID string. + + Returns: + SAMLProviderConfig: An OIDC provider config instance. + + Raises: + ValueError: If the provider ID is invalid, empty or does not have ``oidc.`` prefix. + ConfigurationNotFoundError: If no OIDC provider is available with the given identifier. + FirebaseError: If an error occurs while retrieving the OIDC provider. + """ + return self._provider_manager.get_oidc_provider_config(provider_id) + + def create_oidc_provider_config( + self, provider_id, client_id, issuer, display_name=None, enabled=None): + """Creates a new OIDC provider config from the given parameters. + + OIDC provider support requires Google Cloud's Identity Platform (GCIP). To learn more about + GCIP, including pricing and features, see https://cloud.google.com/identity-platform. + + Args: + provider_id: Provider ID string. Must have the prefix ``oidc.``. + client_id: Client ID of the new config. + issuer: Issuer of the new config. Must be a valid URL. + display_name: The user-friendly display name to the current configuration (optional). + This name is also used as the provider label in the Cloud Console. + enabled: A boolean indicating whether the provider configuration is enabled or disabled + (optional). A user cannot sign in using a disabled provider. + + Returns: + OIDCProviderConfig: The newly created OIDC provider config instance. + + Raises: + ValueError: If any of the specified input parameters are invalid. + FirebaseError: If an error occurs while creating the new OIDC provider config. + """ + return self._provider_manager.create_oidc_provider_config( + provider_id, client_id=client_id, issuer=issuer, display_name=display_name, + enabled=enabled) + + def update_oidc_provider_config( + self, provider_id, client_id=None, issuer=None, display_name=None, enabled=None): + """Updates an existing OIDC provider config with the given parameters. + + Args: + provider_id: Provider ID string. Must have the prefix ``oidc.``. + client_id: Client ID of the new config (optional). + issuer: Issuer of the new config (optional). Must be a valid URL. + display_name: The user-friendly display name to the current configuration (optional). + Pass ``auth.DELETE_ATTRIBUTE`` to delete the current display name. + enabled: A boolean indicating whether the provider configuration is enabled or disabled + (optional). + + Returns: + OIDCProviderConfig: The updated OIDC provider config instance. + + Raises: + ValueError: If any of the specified input parameters are invalid. + FirebaseError: If an error occurs while updating the OIDC provider config. + """ + return self._provider_manager.update_oidc_provider_config( + provider_id, client_id=client_id, issuer=issuer, display_name=display_name, + enabled=enabled) + + def delete_oidc_provider_config(self, provider_id): + """Deletes the ``OIDCProviderConfig`` with the given ID. + + Args: + provider_id: Provider ID string. + + Raises: + ValueError: If the provider ID is invalid, empty or does not have ``oidc.`` prefix. + ConfigurationNotFoundError: If no OIDC provider is available with the given identifier. + FirebaseError: If an error occurs while deleting the OIDC provider. + """ + self._provider_manager.delete_oidc_provider_config(provider_id) + + def list_oidc_provider_configs( + self, page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS): + """Retrieves a page of OIDC provider configs from a Firebase project. + + The ``page_token`` argument governs the starting point of the page. The ``max_results`` + argument governs the maximum number of configs that may be included in the returned + page. This function never returns ``None``. If there are no OIDC configs in the Firebase + project, this returns an empty page. + + Args: + page_token: A non-empty page token string, which indicates the starting point of the + page (optional). Defaults to ``None``, which will retrieve the first page of users. + max_results: A positive integer indicating the maximum number of users to include in + the returned page (optional). Defaults to 100, which is also the maximum number + allowed. + + Returns: + ListProviderConfigsPage: A page of OIDC provider config instances. + + Raises: + ValueError: If ``max_results`` or ``page_token`` are invalid. + FirebaseError: If an error occurs while retrieving the OIDC provider configs. + """ + return self._provider_manager.list_oidc_provider_configs(page_token, max_results) + + def get_saml_provider_config(self, provider_id): + """Returns the ``SAMLProviderConfig`` with the given ID. + + Args: + provider_id: Provider ID string. + + Returns: + SAMLProviderConfig: A SAML provider config instance. + + Raises: + ValueError: If the provider ID is invalid, empty or does not have ``saml.`` prefix. + ConfigurationNotFoundError: If no SAML provider is available with the given identifier. + FirebaseError: If an error occurs while retrieving the SAML provider. + """ + return self._provider_manager.get_saml_provider_config(provider_id) + + def create_saml_provider_config( + self, provider_id, idp_entity_id, sso_url, x509_certificates, rp_entity_id, + callback_url, display_name=None, enabled=None): + """Creates a new SAML provider config from the given parameters. + + SAML provider support requires Google Cloud's Identity Platform (GCIP). To learn more about + GCIP, including pricing and features, see https://cloud.google.com/identity-platform. + + Args: + provider_id: Provider ID string. Must have the prefix ``saml.``. + idp_entity_id: The SAML IdP entity identifier. + sso_url: The SAML IdP SSO URL. Must be a valid URL. + x509_certificates: The list of SAML IdP X.509 certificates issued by CA for this + provider. Multiple certificates are accepted to prevent outages during IdP key + rotation (for example ADFS rotates every 10 days). When the Auth server receives a + SAML response, it will match the SAML response with the certificate on record. + Otherwise the response is rejected. Developers are expected to manage the + certificate updates as keys are rotated. + rp_entity_id: The SAML relying party (service provider) entity ID. This is defined by + the developer but needs to be provided to the SAML IdP. + callback_url: Callback URL string. This is fixed and must always be the same as the + OAuth redirect URL provisioned by Firebase Auth, unless a custom authDomain is + used. + display_name: The user-friendly display name to the current configuration (optional). + This name is also used as the provider label in the Cloud Console. + enabled: A boolean indicating whether the provider configuration is enabled or disabled + (optional). A user cannot sign in using a disabled provider. + + Returns: + SAMLProviderConfig: The newly created SAML provider config instance. + + Raises: + ValueError: If any of the specified input parameters are invalid. + FirebaseError: If an error occurs while creating the new SAML provider config. + """ + return self._provider_manager.create_saml_provider_config( + provider_id, idp_entity_id=idp_entity_id, sso_url=sso_url, + x509_certificates=x509_certificates, rp_entity_id=rp_entity_id, + callback_url=callback_url, display_name=display_name, enabled=enabled) + + def update_saml_provider_config( + self, provider_id, idp_entity_id=None, sso_url=None, x509_certificates=None, + rp_entity_id=None, callback_url=None, display_name=None, enabled=None): + """Updates an existing SAML provider config with the given parameters. + + Args: + provider_id: Provider ID string. Must have the prefix ``saml.``. + idp_entity_id: The SAML IdP entity identifier (optional). + sso_url: The SAML IdP SSO URL. Must be a valid URL (optional). + x509_certificates: The list of SAML IdP X.509 certificates issued by CA for this + provider (optional). + rp_entity_id: The SAML relying party entity ID (optional). + callback_url: Callback URL string (optional). + display_name: The user-friendly display name of the current configuration (optional). + Pass ``auth.DELETE_ATTRIBUTE`` to delete the current display name. + enabled: A boolean indicating whether the provider configuration is enabled or disabled + (optional). + + Returns: + SAMLProviderConfig: The updated SAML provider config instance. + + Raises: + ValueError: If any of the specified input parameters are invalid. + FirebaseError: If an error occurs while updating the SAML provider config. + """ + return self._provider_manager.update_saml_provider_config( + provider_id, idp_entity_id=idp_entity_id, sso_url=sso_url, + x509_certificates=x509_certificates, rp_entity_id=rp_entity_id, + callback_url=callback_url, display_name=display_name, enabled=enabled) + + def delete_saml_provider_config(self, provider_id): + """Deletes the ``SAMLProviderConfig`` with the given ID. + + Args: + provider_id: Provider ID string. + + Raises: + ValueError: If the provider ID is invalid, empty or does not have ``saml.`` prefix. + ConfigurationNotFoundError: If no SAML provider is available with the given identifier. + FirebaseError: If an error occurs while deleting the SAML provider. + """ + self._provider_manager.delete_saml_provider_config(provider_id) + + def list_saml_provider_configs( + self, page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS): + """Retrieves a page of SAML provider configs from a Firebase project. + + The ``page_token`` argument governs the starting point of the page. The ``max_results`` + argument governs the maximum number of configs that may be included in the returned + page. This function never returns ``None``. If there are no SAML configs in the Firebase + project, this returns an empty page. + + Args: + page_token: A non-empty page token string, which indicates the starting point of the + page (optional). Defaults to ``None``, which will retrieve the first page of users. + max_results: A positive integer indicating the maximum number of users to include in + the returned page (optional). Defaults to 100, which is also the maximum number + allowed. + + Returns: + ListProviderConfigsPage: A page of SAML provider config instances. + + Raises: + ValueError: If ``max_results`` or ``page_token`` are invalid. + FirebaseError: If an error occurs while retrieving the SAML provider configs. + """ + return self._provider_manager.list_saml_provider_configs(page_token, max_results) + + def _check_jwt_revoked(self, verified_claims, exc_type, label): + user = self.get_user(verified_claims.get('uid')) + if verified_claims.get('iat') * 1000 < user.tokens_valid_after_timestamp: + raise exc_type('The Firebase {0} has been revoked.'.format(label)) diff --git a/firebase_admin/_auth_providers.py b/firebase_admin/_auth_providers.py new file mode 100644 index 000000000..96f1b5348 --- /dev/null +++ b/firebase_admin/_auth_providers.py @@ -0,0 +1,390 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Firebase auth providers management sub module.""" + +from urllib import parse + +import requests + +from firebase_admin import _auth_utils +from firebase_admin import _user_mgt + + +MAX_LIST_CONFIGS_RESULTS = 100 + + +class ProviderConfig: + """Parent type for all authentication provider config types.""" + + def __init__(self, data): + self._data = data + + @property + def provider_id(self): + name = self._data['name'] + return name.split('/')[-1] + + @property + def display_name(self): + return self._data.get('displayName') + + @property + def enabled(self): + return self._data.get('enabled', False) + + +class OIDCProviderConfig(ProviderConfig): + """Represents the OIDC auth provider configuration. + + See https://openid.net/specs/openid-connect-core-1_0-final.html. + """ + + @property + def issuer(self): + return self._data['issuer'] + + @property + def client_id(self): + return self._data['clientId'] + + +class SAMLProviderConfig(ProviderConfig): + """Represents he SAML auth provider configuration. + + See http://docs.oasis-open.org/security/saml/Post2.0/sstc-saml-tech-overview-2.0.html. + """ + + @property + def idp_entity_id(self): + return self._data.get('idpConfig', {})['idpEntityId'] + + @property + def sso_url(self): + return self._data.get('idpConfig', {})['ssoUrl'] + + @property + def x509_certificates(self): + certs = self._data.get('idpConfig', {})['idpCertificates'] + return [c['x509Certificate'] for c in certs] + + @property + def callback_url(self): + return self._data.get('spConfig', {})['callbackUri'] + + @property + def rp_entity_id(self): + return self._data.get('spConfig', {})['spEntityId'] + + +class ListProviderConfigsPage: + """Represents a page of AuthProviderConfig instances retrieved from a Firebase project. + + Provides methods for traversing the provider configs included in this page, as well as + retrieving subsequent pages. The iterator returned by ``iterate_all()`` can be used to iterate + through all provider configs in the Firebase project starting from this page. + """ + + def __init__(self, download, page_token, max_results): + self._download = download + self._max_results = max_results + self._current = download(page_token, max_results) + + @property + def provider_configs(self): + """A list of ``AuthProviderConfig`` instances available in this page.""" + raise NotImplementedError + + @property + def next_page_token(self): + """Page token string for the next page (empty string indicates no more pages).""" + return self._current.get('nextPageToken', '') + + @property + def has_next_page(self): + """A boolean indicating whether more pages are available.""" + return bool(self.next_page_token) + + def get_next_page(self): + """Retrieves the next page of provider configs, if available. + + Returns: + ListProviderConfigsPage: Next page of provider configs, or None if this is the last + page. + """ + if self.has_next_page: + return self.__class__(self._download, self.next_page_token, self._max_results) + return None + + def iterate_all(self): + """Retrieves an iterator for provider configs. + + Returned iterator will iterate through all the provider configs in the Firebase project + starting from this page. The iterator will never buffer more than one page of configs + in memory at a time. + + Returns: + iterator: An iterator of AuthProviderConfig instances. + """ + return _ProviderConfigIterator(self) + + +class _ListOIDCProviderConfigsPage(ListProviderConfigsPage): + + @property + def provider_configs(self): + return [OIDCProviderConfig(data) for data in self._current.get('oauthIdpConfigs', [])] + + +class _ListSAMLProviderConfigsPage(ListProviderConfigsPage): + + @property + def provider_configs(self): + return [SAMLProviderConfig(data) for data in self._current.get('inboundSamlConfigs', [])] + + +class _ProviderConfigIterator(_auth_utils.PageIterator): + + @property + def items(self): + return self._current_page.provider_configs + + +class ProviderConfigClient: + """Client for managing Auth provider configurations.""" + + PROVIDER_CONFIG_URL = 'https://identitytoolkit.googleapis.com/v2beta1' + + def __init__(self, http_client, project_id, tenant_id=None): + self.http_client = http_client + self.base_url = '{0}/projects/{1}'.format(self.PROVIDER_CONFIG_URL, project_id) + if tenant_id: + self.base_url += '/tenants/{0}'.format(tenant_id) + + def get_oidc_provider_config(self, provider_id): + _validate_oidc_provider_id(provider_id) + body = self._make_request('get', '/oauthIdpConfigs/{0}'.format(provider_id)) + return OIDCProviderConfig(body) + + def create_oidc_provider_config( + self, provider_id, client_id, issuer, display_name=None, enabled=None): + """Creates a new OIDC provider config from the given parameters.""" + _validate_oidc_provider_id(provider_id) + req = { + 'clientId': _validate_non_empty_string(client_id, 'client_id'), + 'issuer': _validate_url(issuer, 'issuer'), + } + if display_name is not None: + req['displayName'] = _auth_utils.validate_string(display_name, 'display_name') + if enabled is not None: + req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled') + + params = 'oauthIdpConfigId={0}'.format(provider_id) + body = self._make_request('post', '/oauthIdpConfigs', json=req, params=params) + return OIDCProviderConfig(body) + + def update_oidc_provider_config( + self, provider_id, client_id=None, issuer=None, display_name=None, enabled=None): + """Updates an existing OIDC provider config with the given parameters.""" + _validate_oidc_provider_id(provider_id) + req = {} + if display_name is not None: + if display_name == _user_mgt.DELETE_ATTRIBUTE: + req['displayName'] = None + else: + req['displayName'] = _auth_utils.validate_string(display_name, 'display_name') + if enabled is not None: + req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled') + if client_id: + req['clientId'] = _validate_non_empty_string(client_id, 'client_id') + if issuer: + req['issuer'] = _validate_url(issuer, 'issuer') + + if not req: + raise ValueError('At least one parameter must be specified for update.') + + update_mask = _auth_utils.build_update_mask(req) + params = 'updateMask={0}'.format(','.join(update_mask)) + url = '/oauthIdpConfigs/{0}'.format(provider_id) + body = self._make_request('patch', url, json=req, params=params) + return OIDCProviderConfig(body) + + def delete_oidc_provider_config(self, provider_id): + _validate_oidc_provider_id(provider_id) + self._make_request('delete', '/oauthIdpConfigs/{0}'.format(provider_id)) + + def list_oidc_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + return _ListOIDCProviderConfigsPage( + self._fetch_oidc_provider_configs, page_token, max_results) + + def _fetch_oidc_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + return self._fetch_provider_configs('/oauthIdpConfigs', page_token, max_results) + + def get_saml_provider_config(self, provider_id): + _validate_saml_provider_id(provider_id) + body = self._make_request('get', '/inboundSamlConfigs/{0}'.format(provider_id)) + return SAMLProviderConfig(body) + + def create_saml_provider_config( + self, provider_id, idp_entity_id, sso_url, x509_certificates, + rp_entity_id, callback_url, display_name=None, enabled=None): + """Creates a new SAML provider config from the given parameters.""" + _validate_saml_provider_id(provider_id) + req = { + 'idpConfig': { + 'idpEntityId': _validate_non_empty_string(idp_entity_id, 'idp_entity_id'), + 'ssoUrl': _validate_url(sso_url, 'sso_url'), + 'idpCertificates': _validate_x509_certificates(x509_certificates), + }, + 'spConfig': { + 'spEntityId': _validate_non_empty_string(rp_entity_id, 'rp_entity_id'), + 'callbackUri': _validate_url(callback_url, 'callback_url'), + }, + } + if display_name is not None: + req['displayName'] = _auth_utils.validate_string(display_name, 'display_name') + if enabled is not None: + req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled') + + params = 'inboundSamlConfigId={0}'.format(provider_id) + body = self._make_request('post', '/inboundSamlConfigs', json=req, params=params) + return SAMLProviderConfig(body) + + def update_saml_provider_config( + self, provider_id, idp_entity_id=None, sso_url=None, x509_certificates=None, + rp_entity_id=None, callback_url=None, display_name=None, enabled=None): + """Updates an existing SAML provider config with the given parameters.""" + _validate_saml_provider_id(provider_id) + idp_config = {} + if idp_entity_id is not None: + idp_config['idpEntityId'] = _validate_non_empty_string(idp_entity_id, 'idp_entity_id') + if sso_url is not None: + idp_config['ssoUrl'] = _validate_url(sso_url, 'sso_url') + if x509_certificates is not None: + idp_config['idpCertificates'] = _validate_x509_certificates(x509_certificates) + + sp_config = {} + if rp_entity_id is not None: + sp_config['spEntityId'] = _validate_non_empty_string(rp_entity_id, 'rp_entity_id') + if callback_url is not None: + sp_config['callbackUri'] = _validate_url(callback_url, 'callback_url') + + req = {} + if display_name is not None: + if display_name == _user_mgt.DELETE_ATTRIBUTE: + req['displayName'] = None + else: + req['displayName'] = _auth_utils.validate_string(display_name, 'display_name') + if enabled is not None: + req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled') + if idp_config: + req['idpConfig'] = idp_config + if sp_config: + req['spConfig'] = sp_config + + if not req: + raise ValueError('At least one parameter must be specified for update.') + + update_mask = _auth_utils.build_update_mask(req) + params = 'updateMask={0}'.format(','.join(update_mask)) + url = '/inboundSamlConfigs/{0}'.format(provider_id) + body = self._make_request('patch', url, json=req, params=params) + return SAMLProviderConfig(body) + + def delete_saml_provider_config(self, provider_id): + _validate_saml_provider_id(provider_id) + self._make_request('delete', '/inboundSamlConfigs/{0}'.format(provider_id)) + + def list_saml_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + return _ListSAMLProviderConfigsPage( + self._fetch_saml_provider_configs, page_token, max_results) + + def _fetch_saml_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + return self._fetch_provider_configs('/inboundSamlConfigs', page_token, max_results) + + def _fetch_provider_configs(self, path, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + """Fetches a page of auth provider configs""" + if page_token is not None: + if not isinstance(page_token, str) or not page_token: + raise ValueError('Page token must be a non-empty string.') + if not isinstance(max_results, int): + raise ValueError('Max results must be an integer.') + if max_results < 1 or max_results > MAX_LIST_CONFIGS_RESULTS: + raise ValueError( + 'Max results must be a positive integer less than or equal to ' + '{0}.'.format(MAX_LIST_CONFIGS_RESULTS)) + + params = 'pageSize={0}'.format(max_results) + if page_token: + params += '&pageToken={0}'.format(page_token) + return self._make_request('get', path, params=params) + + def _make_request(self, method, path, **kwargs): + url = '{0}{1}'.format(self.base_url, path) + try: + return self.http_client.body(method, url, **kwargs) + except requests.exceptions.RequestException as error: + raise _auth_utils.handle_auth_backend_error(error) + + +def _validate_oidc_provider_id(provider_id): + if not isinstance(provider_id, str): + raise ValueError( + 'Invalid OIDC provider ID: {0}. Provider ID must be a non-empty string.'.format( + provider_id)) + if not provider_id.startswith('oidc.'): + raise ValueError('Invalid OIDC provider ID: {0}.'.format(provider_id)) + return provider_id + + +def _validate_saml_provider_id(provider_id): + if not isinstance(provider_id, str): + raise ValueError( + 'Invalid SAML provider ID: {0}. Provider ID must be a non-empty string.'.format( + provider_id)) + if not provider_id.startswith('saml.'): + raise ValueError('Invalid SAML provider ID: {0}.'.format(provider_id)) + return provider_id + + +def _validate_non_empty_string(value, label): + """Validates that the given value is a non-empty string.""" + if not isinstance(value, str): + raise ValueError('Invalid type for {0}: {1}.'.format(label, value)) + if not value: + raise ValueError('{0} must not be empty.'.format(label)) + return value + + +def _validate_url(url, label): + """Validates that the given value is a well-formed URL string.""" + if not isinstance(url, str) or not url: + raise ValueError( + 'Invalid photo URL: "{0}". {1} must be a non-empty ' + 'string.'.format(url, label)) + try: + parsed = parse.urlparse(url) + if not parsed.netloc: + raise ValueError('Malformed {0}: "{1}".'.format(label, url)) + return url + except Exception: + raise ValueError('Malformed {0}: "{1}".'.format(label, url)) + + +def _validate_x509_certificates(x509_certificates): + if not isinstance(x509_certificates, list) or not x509_certificates: + raise ValueError('x509_certificates must be a non-empty list.') + if not all([isinstance(cert, str) and cert for cert in x509_certificates]): + raise ValueError('x509_certificates must only contain non-empty strings.') + return [{'x509Certificate': cert} for cert in x509_certificates] diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index 2f7383c0b..f1ce97dee 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -30,6 +30,42 @@ VALID_EMAIL_ACTION_TYPES = set(['VERIFY_EMAIL', 'EMAIL_SIGNIN', 'PASSWORD_RESET']) +class PageIterator: + """An iterator that allows iterating over a sequence of items, one at a time. + + This implementation loads a page of items into memory, and iterates on them. When the whole + page has been traversed, it loads another page. This class never keeps more than one page + of entries in memory. + """ + + def __init__(self, current_page): + if not current_page: + raise ValueError('Current page must not be None.') + self._current_page = current_page + self._index = 0 + + def next(self): + if self._index == len(self.items): + if self._current_page.has_next_page: + self._current_page = self._current_page.get_next_page() + self._index = 0 + if self._index < len(self.items): + result = self.items[self._index] + self._index += 1 + return result + raise StopIteration + + @property + def items(self): + raise NotImplementedError + + def __next__(self): + return self.next() + + def __iter__(self): + return self + + def validate_uid(uid, required=False): if uid is None and not required: return None @@ -157,6 +193,18 @@ def validate_int(value, label, low=None, high=None): raise ValueError('{0} must not be larger than {1}.'.format(label, high)) return val_int +def validate_string(value, label): + """Validates that the given value is a string.""" + if not isinstance(value, str): + raise ValueError('Invalid type for {0}: {1}.'.format(label, value)) + return value + +def validate_boolean(value, label): + """Validates that the given value is a boolean.""" + if not isinstance(value, bool): + raise ValueError('Invalid type for {0}: {1}.'.format(label, value)) + return value + def validate_custom_claims(custom_claims, required=False): """Validates the specified custom claims. @@ -192,6 +240,19 @@ def validate_action_type(action_type): Valid values are {1}'.format(action_type, ', '.join(VALID_EMAIL_ACTION_TYPES))) return action_type +def build_update_mask(params): + """Creates an update mask list from the given dictionary.""" + mask = [] + for key, value in params.items(): + if isinstance(value, dict): + child_mask = build_update_mask(value) + for child in child_mask: + mask.append('{0}.{1}'.format(key, child)) + else: + mask.append(key) + + return sorted(mask) + class UidAlreadyExistsError(exceptions.AlreadyExistsError): """The user with the provided uid already exists.""" @@ -266,7 +327,33 @@ def __init__(self, message, cause=None, http_response=None): exceptions.NotFoundError.__init__(self, message, cause, http_response) +class TenantNotFoundError(exceptions.NotFoundError): + """No tenant found for the specified identifier.""" + + default_message = 'No tenant found for the given identifier' + + def __init__(self, message, cause=None, http_response=None): + exceptions.NotFoundError.__init__(self, message, cause, http_response) + + +class TenantIdMismatchError(exceptions.InvalidArgumentError): + """Missing or invalid tenant ID field in the given JWT.""" + + def __init__(self, message): + exceptions.InvalidArgumentError.__init__(self, message) + + +class ConfigurationNotFoundError(exceptions.NotFoundError): + """No auth provider found for the specified identifier.""" + + default_message = 'No auth provider found for the given identifier' + + def __init__(self, message, cause=None, http_response=None): + exceptions.NotFoundError.__init__(self, message, cause, http_response) + + _CODE_TO_EXC_TYPE = { + 'CONFIGURATION_NOT_FOUND': ConfigurationNotFoundError, 'DUPLICATE_EMAIL': EmailAlreadyExistsError, 'DUPLICATE_LOCAL_ID': UidAlreadyExistsError, 'EMAIL_EXISTS': EmailAlreadyExistsError, @@ -274,6 +361,7 @@ def __init__(self, message, cause=None, http_response=None): 'INVALID_DYNAMIC_LINK_DOMAIN': InvalidDynamicLinkDomainError, 'INVALID_ID_TOKEN': InvalidIdTokenError, 'PHONE_NUMBER_EXISTS': PhoneNumberAlreadyExistsError, + 'TENANT_NOT_FOUND': TenantNotFoundError, 'USER_NOT_FOUND': UserNotFoundError, } @@ -281,12 +369,12 @@ def __init__(self, message, cause=None, http_response=None): def handle_auth_backend_error(error): """Converts a requests error received from the Firebase Auth service into a FirebaseError.""" if error.response is None: - raise _utils.handle_requests_error(error) + return _utils.handle_requests_error(error) code, custom_message = _parse_error_body(error.response) if not code: msg = 'Unexpected error response: {0}'.format(error.response.content.decode()) - raise _utils.handle_requests_error(error, message=msg) + return _utils.handle_requests_error(error, message=msg) exc_type = _CODE_TO_EXC_TYPE.get(code) msg = _build_error_message(code, exc_type, custom_message) diff --git a/firebase_admin/_token_gen.py b/firebase_admin/_token_gen.py index 4234bfa7b..18a8008c7 100644 --- a/firebase_admin/_token_gen.py +++ b/firebase_admin/_token_gen.py @@ -82,10 +82,13 @@ def from_iam(cls, request, google_cred, service_account): class TokenGenerator: """Generates custom tokens and session cookies.""" - def __init__(self, app, client): + ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1' + + def __init__(self, app, http_client): self.app = app - self.client = client + self.http_client = http_client self.request = transport.requests.Request() + self.base_url = '{0}/projects/{1}'.format(self.ID_TOOLKIT_URL, app.project_id) self._signing_provider = None def _init_signing_provider(self): @@ -130,7 +133,7 @@ def signing_provider(self): 'details on creating custom tokens.'.format(error, url)) return self._signing_provider - def create_custom_token(self, uid, developer_claims=None): + def create_custom_token(self, uid, developer_claims=None, tenant_id=None): """Builds and signs a Firebase custom auth token.""" if developer_claims is not None: if not isinstance(developer_claims, dict): @@ -161,6 +164,8 @@ def create_custom_token(self, uid, developer_claims=None): 'iat': now, 'exp': now + MAX_TOKEN_LIFETIME_SECONDS, } + if tenant_id: + payload['tenant_id'] = tenant_id if developer_claims is not None: payload['claims'] = developer_claims @@ -190,13 +195,13 @@ def create_session_cookie(self, id_token, expires_in): raise ValueError('Illegal expiry duration: {0}. Duration must be at most {1} ' 'seconds.'.format(expires_in, MAX_SESSION_COOKIE_DURATION_SECONDS)) + url = '{0}:createSessionCookie'.format(self.base_url) payload = { 'idToken': id_token, 'validDuration': expires_in, } try: - body, http_resp = self.client.body_and_response( - 'post', ':createSessionCookie', json=payload) + body, http_resp = self.http_client.body_and_response('post', url, json=payload) except requests.exceptions.RequestException as error: raise _auth_utils.handle_auth_backend_error(error) else: diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index 533259e70..0b0c5ddb6 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -244,6 +244,15 @@ def custom_claims(self): return parsed return None + @property + def tenant_id(self): + """Returns the tenant ID of this user. + + Returns: + string: A tenant ID string or None. + """ + return self._data.get('tenantId') + class ExportedUserRecord(UserRecord): """Contains metadata associated with a user including password hash and salt.""" @@ -454,8 +463,13 @@ def encode_action_code_settings(settings): class UserManager: """Provides methods for interacting with the Google Identity Toolkit.""" - def __init__(self, client): - self._client = client + ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1' + + def __init__(self, http_client, project_id, tenant_id=None): + self.http_client = http_client + self.base_url = '{0}/projects/{1}'.format(self.ID_TOOLKIT_URL, project_id) + if tenant_id: + self.base_url += '/tenants/{0}'.format(tenant_id) def get_user(self, **kwargs): """Gets the user data corresponding to the provided key.""" @@ -471,17 +485,12 @@ def get_user(self, **kwargs): else: raise TypeError('Unsupported keyword arguments: {0}.'.format(kwargs)) - try: - body, http_resp = self._client.body_and_response( - 'post', '/accounts:lookup', json=payload) - except requests.exceptions.RequestException as error: - raise _auth_utils.handle_auth_backend_error(error) - else: - if not body or not body.get('users'): - raise _auth_utils.UserNotFoundError( - 'No user record found for the provided {0}: {1}.'.format(key_type, key), - http_response=http_resp) - return body['users'][0] + body, http_resp = self._make_request('post', '/accounts:lookup', json=payload) + if not body or not body.get('users'): + raise _auth_utils.UserNotFoundError( + 'No user record found for the provided {0}: {1}.'.format(key_type, key), + http_response=http_resp) + return body['users'][0] def list_users(self, page_token=None, max_results=MAX_LIST_USERS_RESULTS): """Retrieves a batch of users.""" @@ -498,10 +507,8 @@ def list_users(self, page_token=None, max_results=MAX_LIST_USERS_RESULTS): payload = {'maxResults': max_results} if page_token: payload['nextPageToken'] = page_token - try: - return self._client.body('get', '/accounts:batchGet', params=payload) - except requests.exceptions.RequestException as error: - raise _auth_utils.handle_auth_backend_error(error) + body, _ = self._make_request('get', '/accounts:batchGet', params=payload) + return body def create_user(self, uid=None, display_name=None, email=None, phone_number=None, photo_url=None, password=None, disabled=None, email_verified=None): @@ -517,15 +524,11 @@ def create_user(self, uid=None, display_name=None, email=None, phone_number=None 'disabled': bool(disabled) if disabled is not None else None, } payload = {k: v for k, v in payload.items() if v is not None} - try: - body, http_resp = self._client.body_and_response('post', '/accounts', json=payload) - except requests.exceptions.RequestException as error: - raise _auth_utils.handle_auth_backend_error(error) - else: - if not body or not body.get('localId'): - raise _auth_utils.UnexpectedResponseError( - 'Failed to create new user.', http_response=http_resp) - return body.get('localId') + body, http_resp = self._make_request('post', '/accounts', json=payload) + if not body or not body.get('localId'): + raise _auth_utils.UnexpectedResponseError( + 'Failed to create new user.', http_response=http_resp) + return body.get('localId') def update_user(self, uid, display_name=None, email=None, phone_number=None, photo_url=None, password=None, disabled=None, email_verified=None, @@ -568,29 +571,19 @@ def update_user(self, uid, display_name=None, email=None, phone_number=None, payload['customAttributes'] = _auth_utils.validate_custom_claims(json_claims) payload = {k: v for k, v in payload.items() if v is not None} - try: - body, http_resp = self._client.body_and_response( - 'post', '/accounts:update', json=payload) - except requests.exceptions.RequestException as error: - raise _auth_utils.handle_auth_backend_error(error) - else: - if not body or not body.get('localId'): - raise _auth_utils.UnexpectedResponseError( - 'Failed to update user: {0}.'.format(uid), http_response=http_resp) - return body.get('localId') + body, http_resp = self._make_request('post', '/accounts:update', json=payload) + if not body or not body.get('localId'): + raise _auth_utils.UnexpectedResponseError( + 'Failed to update user: {0}.'.format(uid), http_response=http_resp) + return body.get('localId') def delete_user(self, uid): """Deletes the user identified by the specified user ID.""" _auth_utils.validate_uid(uid, required=True) - try: - body, http_resp = self._client.body_and_response( - 'post', '/accounts:delete', json={'localId' : uid}) - except requests.exceptions.RequestException as error: - raise _auth_utils.handle_auth_backend_error(error) - else: - if not body or not body.get('kind'): - raise _auth_utils.UnexpectedResponseError( - 'Failed to delete user: {0}.'.format(uid), http_response=http_resp) + body, http_resp = self._make_request('post', '/accounts:delete', json={'localId' : uid}) + if not body or not body.get('kind'): + raise _auth_utils.UnexpectedResponseError( + 'Failed to delete user: {0}.'.format(uid), http_response=http_resp) def import_users(self, users, hash_alg=None): """Imports the given list of users to Firebase Auth.""" @@ -609,16 +602,11 @@ def import_users(self, users, hash_alg=None): if not isinstance(hash_alg, _user_import.UserImportHash): raise ValueError('A UserImportHash is required to import users with passwords.') payload.update(hash_alg.to_dict()) - try: - body, http_resp = self._client.body_and_response( - 'post', '/accounts:batchCreate', json=payload) - except requests.exceptions.RequestException as error: - raise _auth_utils.handle_auth_backend_error(error) - else: - if not isinstance(body, dict): - raise _auth_utils.UnexpectedResponseError( - 'Failed to import users.', http_response=http_resp) - return body + body, http_resp = self._make_request('post', '/accounts:batchCreate', json=payload) + if not isinstance(body, dict): + raise _auth_utils.UnexpectedResponseError( + 'Failed to import users.', http_response=http_resp) + return body def generate_email_action_link(self, action_type, email, action_code_settings=None): """Fetches the email action links for types @@ -646,45 +634,22 @@ def generate_email_action_link(self, action_type, email, action_code_settings=No if action_code_settings: payload.update(encode_action_code_settings(action_code_settings)) + body, http_resp = self._make_request('post', '/accounts:sendOobCode', json=payload) + if not body or not body.get('oobLink'): + raise _auth_utils.UnexpectedResponseError( + 'Failed to generate email action link.', http_response=http_resp) + return body.get('oobLink') + + def _make_request(self, method, path, **kwargs): + url = '{0}{1}'.format(self.base_url, path) try: - body, http_resp = self._client.body_and_response( - 'post', '/accounts:sendOobCode', json=payload) + return self.http_client.body_and_response(method, url, **kwargs) except requests.exceptions.RequestException as error: raise _auth_utils.handle_auth_backend_error(error) - else: - if not body or not body.get('oobLink'): - raise _auth_utils.UnexpectedResponseError( - 'Failed to generate email action link.', http_response=http_resp) - return body.get('oobLink') - -class _UserIterator: - """An iterator that allows iterating over user accounts, one at a time. - This implementation loads a page of users into memory, and iterates on them. When the whole - page has been traversed, it loads another page. This class never keeps more than one page - of entries in memory. - """ +class _UserIterator(_auth_utils.PageIterator): - def __init__(self, current_page): - if not current_page: - raise ValueError('Current page must not be None.') - self._current_page = current_page - self._index = 0 - - def next(self): - if self._index == len(self._current_page.users): - if self._current_page.has_next_page: - self._current_page = self._current_page.get_next_page() - self._index = 0 - if self._index < len(self._current_page.users): - result = self._current_page.users[self._index] - self._index += 1 - return result - raise StopIteration - - def __next__(self): - return self.next() - - def __iter__(self): - return self + @property + def items(self): + return self._current_page.users diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index 6f85e622c..cb8782ea7 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -19,11 +19,9 @@ creating and managing user accounts in Firebase projects. """ -import time - -import firebase_admin +from firebase_admin import _auth_client +from firebase_admin import _auth_providers from firebase_admin import _auth_utils -from firebase_admin import _http_client from firebase_admin import _token_gen from firebase_admin import _user_import from firebase_admin import _user_mgt @@ -36,6 +34,7 @@ __all__ = [ 'ActionCodeSettings', 'CertificateFetchError', + 'Client', 'DELETE_ATTRIBUTE', 'EmailAlreadyExistsError', 'ErrorInfo', @@ -47,10 +46,13 @@ 'InvalidDynamicLinkDomainError', 'InvalidIdTokenError', 'InvalidSessionCookieError', + 'ListProviderConfigsPage', 'ListUsersPage', 'PhoneNumberAlreadyExistsError', + 'ProviderConfig', 'RevokedIdTokenError', 'RevokedSessionCookieError', + 'SAMLProviderConfig', 'TokenSignError', 'UidAlreadyExistsError', 'UnexpectedResponseError', @@ -63,19 +65,24 @@ 'UserRecord', 'create_custom_token', + 'create_saml_provider_config', 'create_session_cookie', 'create_user', + 'delete_saml_provider_config', 'delete_user', 'generate_email_verification_link', 'generate_password_reset_link', 'generate_sign_in_with_email_link', + 'get_saml_provider_config', 'get_user', 'get_user_by_email', 'get_user_by_phone_number', 'import_users', + 'list_saml_provider_configs', 'list_users', 'revoke_refresh_tokens', 'set_custom_user_claims', + 'update_saml_provider_config', 'update_user', 'verify_id_token', 'verify_session_cookie', @@ -83,6 +90,8 @@ ActionCodeSettings = _user_mgt.ActionCodeSettings CertificateFetchError = _token_gen.CertificateFetchError +Client = _auth_client.Client +ConfigurationNotFoundError = _auth_utils.ConfigurationNotFoundError DELETE_ATTRIBUTE = _user_mgt.DELETE_ATTRIBUTE EmailAlreadyExistsError = _auth_utils.EmailAlreadyExistsError ErrorInfo = _user_import.ErrorInfo @@ -94,10 +103,14 @@ InvalidDynamicLinkDomainError = _auth_utils.InvalidDynamicLinkDomainError InvalidIdTokenError = _auth_utils.InvalidIdTokenError InvalidSessionCookieError = _token_gen.InvalidSessionCookieError +ListProviderConfigsPage = _auth_providers.ListProviderConfigsPage ListUsersPage = _user_mgt.ListUsersPage +OIDCProviderConfig = _auth_providers.OIDCProviderConfig PhoneNumberAlreadyExistsError = _auth_utils.PhoneNumberAlreadyExistsError +ProviderConfig = _auth_providers.ProviderConfigClient RevokedIdTokenError = _token_gen.RevokedIdTokenError RevokedSessionCookieError = _token_gen.RevokedSessionCookieError +SAMLProviderConfig = _auth_providers.SAMLProviderConfig TokenSignError = _token_gen.TokenSignError UidAlreadyExistsError = _auth_utils.UidAlreadyExistsError UnexpectedResponseError = _auth_utils.UnexpectedResponseError @@ -110,23 +123,23 @@ UserRecord = _user_mgt.UserRecord -def _get_auth_service(app): - """Returns an _AuthService instance for an App. +def _get_client(app): + """Returns a client instance for an App. - If the App already has an _AuthService associated with it, simply returns - it. Otherwise creates a new _AuthService, and adds it to the App before + If the App already has a client associated with it, simply returns + it. Otherwise creates a new client, and adds it to the App before returning it. Args: - app: A Firebase App instance (or None to use the default App). + app: A Firebase App instance (or ``None`` to use the default App). Returns: - _AuthService: An _AuthService for the specified App instance. + Client: A client for the specified App instance. Raises: ValueError: If the app argument is invalid. """ - return _utils.get_app_service(app, _AUTH_ATTRIBUTE, _AuthService) + return _utils.get_app_service(app, _AUTH_ATTRIBUTE, Client) def create_custom_token(uid, developer_claims=None, app=None): @@ -145,8 +158,8 @@ def create_custom_token(uid, developer_claims=None, app=None): ValueError: If input parameters are invalid. TokenSignError: If an error occurs while signing the token using the remote IAM service. """ - token_generator = _get_auth_service(app).token_generator - return token_generator.create_custom_token(uid, developer_claims) + client = _get_client(app) + return client.create_custom_token(uid, developer_claims) def verify_id_token(id_token, app=None, check_revoked=False): @@ -171,15 +184,8 @@ def verify_id_token(id_token, app=None, check_revoked=False): CertificateFetchError: If an error occurs while fetching the public key certificates required to verify the ID token. """ - if not isinstance(check_revoked, bool): - # guard against accidental wrong assignment. - raise ValueError('Illegal check_revoked argument. Argument must be of type ' - ' bool, but given "{0}".'.format(type(check_revoked))) - token_verifier = _get_auth_service(app).token_verifier - verified_claims = token_verifier.verify_id_token(id_token) - if check_revoked: - _check_jwt_revoked(verified_claims, RevokedIdTokenError, 'ID token', app) - return verified_claims + client = _get_client(app) + return client.verify_id_token(id_token, check_revoked=check_revoked) def create_session_cookie(id_token, expires_in, app=None): @@ -200,8 +206,9 @@ def create_session_cookie(id_token, expires_in, app=None): ValueError: If input parameters are invalid. FirebaseError: If an error occurs while creating the cookie. """ - token_generator = _get_auth_service(app).token_generator - return token_generator.create_session_cookie(id_token, expires_in) + client = _get_client(app) + # pylint: disable=protected-access + return client._token_generator.create_session_cookie(id_token, expires_in) def verify_session_cookie(session_cookie, check_revoked=False, app=None): @@ -226,17 +233,18 @@ def verify_session_cookie(session_cookie, check_revoked=False, app=None): CertificateFetchError: If an error occurs while fetching the public key certificates required to verify the session cookie. """ - token_verifier = _get_auth_service(app).token_verifier - verified_claims = token_verifier.verify_session_cookie(session_cookie) + client = _get_client(app) + # pylint: disable=protected-access + verified_claims = client._token_verifier.verify_session_cookie(session_cookie) if check_revoked: - _check_jwt_revoked(verified_claims, RevokedSessionCookieError, 'session cookie', app) + client._check_jwt_revoked(verified_claims, RevokedSessionCookieError, 'session cookie') return verified_claims def revoke_refresh_tokens(uid, app=None): """Revokes all refresh tokens for an existing user. - revoke_refresh_tokens updates the user's tokens_valid_after_timestamp to the current UTC + This function updates the user's ``tokens_valid_after_timestamp`` to the current UTC in seconds since the epoch. It is important that the server on which this is called has its clock set correctly and synchronized. @@ -244,9 +252,17 @@ def revoke_refresh_tokens(uid, app=None): existing sessions from getting minted, existing ID tokens may remain active until their natural expiration (one hour). To verify that ID tokens are revoked, use ``verify_id_token(idToken, check_revoked=True)``. + + Args: + uid: A user ID string. + app: An App instance (optional). + + Raises: + ValueError: If the user ID is None, empty or malformed. + FirebaseError: If an error occurs while revoking the refresh token. """ - user_manager = _get_auth_service(app).user_manager - user_manager.update_user(uid, valid_since=int(time.time())) + client = _get_client(app) + client.revoke_refresh_tokens(uid) def get_user(uid, app=None): @@ -257,16 +273,15 @@ def get_user(uid, app=None): app: An App instance (optional). Returns: - UserRecord: A UserRecord instance. + UserRecord: A user record instance. Raises: ValueError: If the user ID is None, empty or malformed. UserNotFoundError: If the specified user ID does not exist. FirebaseError: If an error occurs while retrieving the user. """ - user_manager = _get_auth_service(app).user_manager - response = user_manager.get_user(uid=uid) - return UserRecord(response) + client = _get_client(app) + return client.get_user(uid=uid) def get_user_by_email(email, app=None): @@ -277,16 +292,15 @@ def get_user_by_email(email, app=None): app: An App instance (optional). Returns: - UserRecord: A UserRecord instance. + UserRecord: A user record instance. Raises: ValueError: If the email is None, empty or malformed. UserNotFoundError: If no user exists by the specified email address. FirebaseError: If an error occurs while retrieving the user. """ - user_manager = _get_auth_service(app).user_manager - response = user_manager.get_user(email=email) - return UserRecord(response) + client = _get_client(app) + return client.get_user_by_email(email=email) def get_user_by_phone_number(phone_number, app=None): @@ -297,16 +311,15 @@ def get_user_by_phone_number(phone_number, app=None): app: An App instance (optional). Returns: - UserRecord: A UserRecord instance. + UserRecord: A user record instance. Raises: ValueError: If the phone number is None, empty or malformed. UserNotFoundError: If no user exists by the specified phone number. FirebaseError: If an error occurs while retrieving the user. """ - user_manager = _get_auth_service(app).user_manager - response = user_manager.get_user(phone_number=phone_number) - return UserRecord(response) + client = _get_client(app) + return client.get_user_by_phone_number(phone_number=phone_number) def list_users(page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESULTS, app=None): @@ -325,16 +338,14 @@ def list_users(page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESULTS, ap app: An App instance (optional). Returns: - ListUsersPage: A ListUsersPage instance. + ListUsersPage: A page of user accounts. Raises: - ValueError: If max_results or page_token are invalid. + ValueError: If ``max_results`` or ``page_token`` are invalid. FirebaseError: If an error occurs while retrieving the user accounts. """ - user_manager = _get_auth_service(app).user_manager - def download(page_token, max_results): - return user_manager.list_users(page_token, max_results) - return ListUsersPage(download, page_token, max_results) + client = _get_client(app) + return client.list_users(page_token=page_token, max_results=max_results) def create_user(**kwargs): # pylint: disable=differing-param-doc @@ -356,16 +367,15 @@ def create_user(**kwargs): # pylint: disable=differing-param-doc app: An App instance (optional). Returns: - UserRecord: A UserRecord instance for the newly created user. + UserRecord: A user record instance for the newly created user. Raises: ValueError: If the specified user properties are invalid. FirebaseError: If an error occurs while creating the user account. """ app = kwargs.pop('app', None) - user_manager = _get_auth_service(app).user_manager - uid = user_manager.create_user(**kwargs) - return UserRecord(user_manager.get_user(uid=uid)) + client = _get_client(app) + return client.create_user(**kwargs) def update_user(uid, **kwargs): # pylint: disable=differing-param-doc @@ -389,20 +399,20 @@ def update_user(uid, **kwargs): # pylint: disable=differing-param-doc disabled: A boolean indicating whether or not the user account is disabled (optional). custom_claims: A dictionary or a JSON string contining the custom claims to be set on the user account (optional). To remove all custom claims, pass ``auth.DELETE_ATTRIBUTE``. - valid_since: An integer signifying the seconds since the epoch. This field is set by - ``revoke_refresh_tokens`` and it is discouraged to set this field directly. + valid_since: An integer signifying the seconds since the epoch (optional). This field is + set by ``revoke_refresh_tokens`` and it is discouraged to set this field directly. + app: An App instance (optional). Returns: - UserRecord: An updated UserRecord instance for the user. + UserRecord: An updated user record instance for the user. Raises: ValueError: If the specified user ID or properties are invalid. FirebaseError: If an error occurs while updating the user account. """ app = kwargs.pop('app', None) - user_manager = _get_auth_service(app).user_manager - user_manager.update_user(uid, **kwargs) - return UserRecord(user_manager.get_user(uid=uid)) + client = _get_client(app) + return client.update_user(uid, **kwargs) def set_custom_user_claims(uid, custom_claims, app=None): @@ -425,10 +435,8 @@ def set_custom_user_claims(uid, custom_claims, app=None): ValueError: If the specified user ID or the custom claims are invalid. FirebaseError: If an error occurs while updating the user account. """ - user_manager = _get_auth_service(app).user_manager - if custom_claims is None: - custom_claims = DELETE_ATTRIBUTE - user_manager.update_user(uid, custom_claims=custom_claims) + client = _get_client(app) + client.set_custom_user_claims(uid, custom_claims=custom_claims) def delete_user(uid, app=None): @@ -442,8 +450,8 @@ def delete_user(uid, app=None): ValueError: If the user ID is None, empty or malformed. FirebaseError: If an error occurs while deleting the user account. """ - user_manager = _get_auth_service(app).user_manager - user_manager.delete_user(uid) + client = _get_client(app) + client.delete_user(uid) def import_users(users, hash_alg=None, app=None): @@ -468,9 +476,8 @@ def import_users(users, hash_alg=None, app=None): ValueError: If the provided arguments are invalid. FirebaseError: If an error occurs while importing users. """ - user_manager = _get_auth_service(app).user_manager - result = user_manager.import_users(users, hash_alg) - return UserImportResult(result, len(users)) + client = _get_client(app) + return client.import_users(users, hash_alg) def generate_password_reset_link(email, action_code_settings=None, app=None): @@ -490,9 +497,8 @@ def generate_password_reset_link(email, action_code_settings=None, app=None): ValueError: If the provided arguments are invalid FirebaseError: If an error occurs while generating the link """ - user_manager = _get_auth_service(app).user_manager - return user_manager.generate_email_action_link( - 'PASSWORD_RESET', email, action_code_settings=action_code_settings) + client = _get_client(app) + return client.generate_password_reset_link(email, action_code_settings=action_code_settings) def generate_email_verification_link(email, action_code_settings=None, app=None): @@ -512,9 +518,9 @@ def generate_email_verification_link(email, action_code_settings=None, app=None) ValueError: If the provided arguments are invalid FirebaseError: If an error occurs while generating the link """ - user_manager = _get_auth_service(app).user_manager - return user_manager.generate_email_action_link( - 'VERIFY_EMAIL', email, action_code_settings=action_code_settings) + client = _get_client(app) + return client.generate_email_verification_link( + email, action_code_settings=action_code_settings) def generate_sign_in_with_email_link(email, action_code_settings, app=None): @@ -527,6 +533,7 @@ def generate_sign_in_with_email_link(email, action_code_settings, app=None): the link is to be handled by a mobile app and the additional state information to be passed in the deep link. app: An App instance (optional). + Returns: link: The email sign-in link created by the API @@ -534,47 +541,263 @@ def generate_sign_in_with_email_link(email, action_code_settings, app=None): ValueError: If the provided arguments are invalid FirebaseError: If an error occurs while generating the link """ - user_manager = _get_auth_service(app).user_manager - return user_manager.generate_email_action_link( - 'EMAIL_SIGNIN', email, action_code_settings=action_code_settings) + client = _get_client(app) + return client.generate_sign_in_with_email_link( + email, action_code_settings=action_code_settings) + + +def get_oidc_provider_config(provider_id, app=None): + """Returns the ``OIDCProviderConfig`` with the given ID. + + Args: + provider_id: Provider ID string. + app: An App instance (optional). + + Returns: + OIDCProviderConfig: An OIDC provider config instance. + + Raises: + ValueError: If the provider ID is invalid, empty or does not have ``oidc.`` prefix. + ConfigurationNotFoundError: If no OIDC provider is available with the given identifier. + FirebaseError: If an error occurs while retrieving the OIDC provider. + """ + client = _get_client(app) + return client.get_oidc_provider_config(provider_id) + +def create_oidc_provider_config( + provider_id, client_id, issuer, display_name=None, enabled=None, app=None): + """Creates a new OIDC provider config from the given parameters. + + OIDC provider support requires Google Cloud's Identity Platform (GCIP). To learn more about + GCIP, including pricing and features, see https://cloud.google.com/identity-platform. + + Args: + provider_id: Provider ID string. Must have the prefix ``oidc.``. + client_id: Client ID of the new config. + issuer: Issuer of the new config. Must be a valid URL. + display_name: The user-friendly display name to the current configuration (optional). + This name is also used as the provider label in the Cloud Console. + enabled: A boolean indicating whether the provider configuration is enabled or disabled + (optional). A user cannot sign in using a disabled provider. + app: An App instance (optional). + + Returns: + OIDCProviderConfig: The newly created OIDC provider config instance. + + Raises: + ValueError: If any of the specified input parameters are invalid. + FirebaseError: If an error occurs while creating the new OIDC provider config. + """ + client = _get_client(app) + return client.create_oidc_provider_config( + provider_id, client_id=client_id, issuer=issuer, display_name=display_name, + enabled=enabled) + + +def update_oidc_provider_config( + provider_id, client_id=None, issuer=None, display_name=None, enabled=None, app=None): + """Updates an existing OIDC provider config with the given parameters. + + Args: + provider_id: Provider ID string. Must have the prefix ``oidc.``. + client_id: Client ID of the new config (optional). + issuer: Issuer of the new config (optional). Must be a valid URL. + display_name: The user-friendly display name of the current configuration (optional). + Pass ``auth.DELETE_ATTRIBUTE`` to delete the current display name. + enabled: A boolean indicating whether the provider configuration is enabled or disabled + (optional). + app: An App instance (optional). + + Returns: + OIDCProviderConfig: The updated OIDC provider config instance. + + Raises: + ValueError: If any of the specified input parameters are invalid. + FirebaseError: If an error occurs while updating the OIDC provider config. + """ + client = _get_client(app) + return client.update_oidc_provider_config( + provider_id, client_id=client_id, issuer=issuer, display_name=display_name, + enabled=enabled) + + +def delete_oidc_provider_config(provider_id, app=None): + """Deletes the ``OIDCProviderConfig`` with the given ID. + + Args: + provider_id: Provider ID string. + app: An App instance (optional). + + Raises: + ValueError: If the provider ID is invalid, empty or does not have ``oidc.`` prefix. + ConfigurationNotFoundError: If no OIDC provider is available with the given identifier. + FirebaseError: If an error occurs while deleting the OIDC provider. + """ + client = _get_client(app) + client.delete_oidc_provider_config(provider_id) + + +def list_oidc_provider_configs( + page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS, app=None): + """Retrieves a page of OIDC provider configs from a Firebase project. + + The ``page_token`` argument governs the starting point of the page. The ``max_results`` + argument governs the maximum number of configs that may be included in the returned + page. This function never returns ``None``. If there are no OIDC configs in the Firebase + project, this returns an empty page. + + Args: + page_token: A non-empty page token string, which indicates the starting point of the + page (optional). Defaults to ``None``, which will retrieve the first page of users. + max_results: A positive integer indicating the maximum number of users to include in + the returned page (optional). Defaults to 100, which is also the maximum number + allowed. + app: An App instance (optional). + + Returns: + ListProviderConfigsPage: A page of OIDC provider config instances. + + Raises: + ValueError: If ``max_results`` or ``page_token`` are invalid. + FirebaseError: If an error occurs while retrieving the OIDC provider configs. + """ + client = _get_client(app) + return client.list_oidc_provider_configs(page_token, max_results) + + +def get_saml_provider_config(provider_id, app=None): + """Returns the ``SAMLProviderConfig`` with the given ID. + + Args: + provider_id: Provider ID string. + app: An App instance (optional). + + Returns: + SAMLProviderConfig: A SAML provider config instance. + + Raises: + ValueError: If the provider ID is invalid, empty or does not have ``saml.`` prefix. + ConfigurationNotFoundError: If no SAML provider is available with the given identifier. + FirebaseError: If an error occurs while retrieving the SAML provider. + """ + client = _get_client(app) + return client.get_saml_provider_config(provider_id) + + +def create_saml_provider_config( + provider_id, idp_entity_id, sso_url, x509_certificates, rp_entity_id, callback_url, + display_name=None, enabled=None, app=None): + """Creates a new SAML provider config from the given parameters. + + SAML provider support requires Google Cloud's Identity Platform (GCIP). To learn more about + GCIP, including pricing and features, see https://cloud.google.com/identity-platform. + + Args: + provider_id: Provider ID string. Must have the prefix ``saml.``. + idp_entity_id: The SAML IdP entity identifier. + sso_url: The SAML IdP SSO URL. Must be a valid URL. + x509_certificates: The list of SAML IdP X.509 certificates issued by CA for this provider. + Multiple certificates are accepted to prevent outages during IdP key rotation (for + example ADFS rotates every 10 days). When the Auth server receives a SAML response, it + will match the SAML response with the certificate on record. Otherwise the response is + rejected. Developers are expected to manage the certificate updates as keys are + rotated. + rp_entity_id: The SAML relying party (service provider) entity ID. This is defined by the + developer but needs to be provided to the SAML IdP. + callback_url: Callback URL string. This is fixed and must always be the same as the OAuth + redirect URL provisioned by Firebase Auth, unless a custom authDomain is used. + display_name: The user-friendly display name to the current configuration (optional). This + name is also used as the provider label in the Cloud Console. + enabled: A boolean indicating whether the provider configuration is enabled or disabled + (optional). A user cannot sign in using a disabled provider. + app: An App instance (optional). + + Returns: + SAMLProviderConfig: The newly created SAML provider config instance. + + Raises: + ValueError: If any of the specified input parameters are invalid. + FirebaseError: If an error occurs while creating the new SAML provider config. + """ + client = _get_client(app) + return client.create_saml_provider_config( + provider_id, idp_entity_id=idp_entity_id, sso_url=sso_url, + x509_certificates=x509_certificates, rp_entity_id=rp_entity_id, callback_url=callback_url, + display_name=display_name, enabled=enabled) + + +def update_saml_provider_config( + provider_id, idp_entity_id=None, sso_url=None, x509_certificates=None, + rp_entity_id=None, callback_url=None, display_name=None, enabled=None, app=None): + """Updates an existing SAML provider config with the given parameters. + + Args: + provider_id: Provider ID string. Must have the prefix ``saml.``. + idp_entity_id: The SAML IdP entity identifier (optional). + sso_url: The SAML IdP SSO URL. Must be a valid URL (optional). + x509_certificates: The list of SAML IdP X.509 certificates issued by CA for this + provider (optional). + rp_entity_id: The SAML relying party entity ID (optional). + callback_url: Callback URL string (optional). + display_name: The user-friendly display name of the current configuration (optional). + Pass ``auth.DELETE_ATTRIBUTE`` to delete the current display name. + enabled: A boolean indicating whether the provider configuration is enabled or disabled + (optional). + app: An App instance (optional). + + Returns: + SAMLProviderConfig: The updated SAML provider config instance. + Raises: + ValueError: If any of the specified input parameters are invalid. + FirebaseError: If an error occurs while updating the SAML provider config. + """ + client = _get_client(app) + return client.update_saml_provider_config( + provider_id, idp_entity_id=idp_entity_id, sso_url=sso_url, + x509_certificates=x509_certificates, rp_entity_id=rp_entity_id, + callback_url=callback_url, display_name=display_name, enabled=enabled) -def _check_jwt_revoked(verified_claims, exc_type, label, app): - user = get_user(verified_claims.get('uid'), app=app) - if verified_claims.get('iat') * 1000 < user.tokens_valid_after_timestamp: - raise exc_type('The Firebase {0} has been revoked.'.format(label)) +def delete_saml_provider_config(provider_id, app=None): + """Deletes the ``SAMLProviderConfig`` with the given ID. -class _AuthService: - """Firebase Authentication service.""" + Args: + provider_id: Provider ID string. + app: An App instance (optional). - ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1/projects/' + Raises: + ValueError: If the provider ID is invalid, empty or does not have ``saml.`` prefix. + ConfigurationNotFoundError: If no SAML provider is available with the given identifier. + FirebaseError: If an error occurs while deleting the SAML provider. + """ + client = _get_client(app) + client.delete_saml_provider_config(provider_id) - def __init__(self, app): - credential = app.credential.get_credential() - version_header = 'Python/Admin/{0}'.format(firebase_admin.__version__) - if not app.project_id: - raise ValueError("""Project ID is required to access the auth service. - 1. Use a service account credential, or - 2. set the project ID explicitly via Firebase App options, or - 3. set the project ID via the GOOGLE_CLOUD_PROJECT environment variable.""") +def list_saml_provider_configs( + page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS, app=None): + """Retrieves a page of SAML provider configs from a Firebase project. - client = _http_client.JsonHttpClient( - credential=credential, base_url=self.ID_TOOLKIT_URL + app.project_id, - headers={'X-Client-Version': version_header}) - self._token_generator = _token_gen.TokenGenerator(app, client) - self._token_verifier = _token_gen.TokenVerifier(app) - self._user_manager = _user_mgt.UserManager(client) + The ``page_token`` argument governs the starting point of the page. The ``max_results`` + argument governs the maximum number of configs that may be included in the returned + page. This function never returns ``None``. If there are no SAML configs in the Firebase + project, this returns an empty page. - @property - def token_generator(self): - return self._token_generator + Args: + page_token: A non-empty page token string, which indicates the starting point of the + page (optional). Defaults to ``None``, which will retrieve the first page of users. + max_results: A positive integer indicating the maximum number of users to include in + the returned page (optional). Defaults to 100, which is also the maximum number + allowed. + app: An App instance (optional). - @property - def token_verifier(self): - return self._token_verifier + Returns: + ListProviderConfigsPage: A page of SAML provider config instances. - @property - def user_manager(self): - return self._user_manager + Raises: + ValueError: If ``max_results`` or ``page_token`` are invalid. + FirebaseError: If an error occurs while retrieving the SAML provider configs. + """ + client = _get_client(app) + return client.list_saml_provider_configs(page_token, max_results) diff --git a/firebase_admin/tenant_mgt.py b/firebase_admin/tenant_mgt.py new file mode 100644 index 000000000..396a819fb --- /dev/null +++ b/firebase_admin/tenant_mgt.py @@ -0,0 +1,445 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Firebase tenant management module. + +This module contains functions for creating and configuring authentication tenants within a +Google Cloud Identity Platform (GCIP) instance. +""" + +import re +import threading + +import requests + +import firebase_admin +from firebase_admin import auth +from firebase_admin import _auth_utils +from firebase_admin import _http_client +from firebase_admin import _utils + + +_TENANT_MGT_ATTRIBUTE = '_tenant_mgt' +_MAX_LIST_TENANTS_RESULTS = 100 +_DISPLAY_NAME_PATTERN = re.compile('^[a-zA-Z][a-zA-Z0-9-]{3,19}$') + + +__all__ = [ + 'ListTenantsPage', + 'Tenant', + 'TenantIdMismatchError', + 'TenantNotFoundError', + + 'auth_for_tenant', + 'create_tenant', + 'delete_tenant', + 'get_tenant', + 'list_tenants', + 'update_tenant', +] + + +TenantIdMismatchError = _auth_utils.TenantIdMismatchError +TenantNotFoundError = _auth_utils.TenantNotFoundError + + +def auth_for_tenant(tenant_id, app=None): + """Gets an Auth Client instance scoped to the given tenant ID. + + Args: + tenant_id: A tenant ID string. + app: An App instance (optional). + + Returns: + auth.Client: An ``auth.Client`` object. + + Raises: + ValueError: If the tenant ID is None, empty or not a string. + """ + tenant_mgt_service = _get_tenant_mgt_service(app) + return tenant_mgt_service.auth_for_tenant(tenant_id) + + +def get_tenant(tenant_id, app=None): + """Gets the tenant corresponding to the given ``tenant_id``. + + Args: + tenant_id: A tenant ID string. + app: An App instance (optional). + + Returns: + Tenant: A tenant object. + + Raises: + ValueError: If the tenant ID is None, empty or not a string. + TenantNotFoundError: If no tenant exists by the given ID. + FirebaseError: If an error occurs while retrieving the tenant. + """ + tenant_mgt_service = _get_tenant_mgt_service(app) + return tenant_mgt_service.get_tenant(tenant_id) + + +def create_tenant( + display_name, allow_password_sign_up=None, enable_email_link_sign_in=None, app=None): + """Creates a new tenant from the given options. + + Args: + display_name: Display name string for the new tenant. Must begin with a letter and contain + only letters, digits and hyphens. Length must be between 4 and 20. + allow_password_sign_up: A boolean indicating whether to enable or disable the email sign-in + provider (optional). + enable_email_link_sign_in: A boolean indicating whether to enable or disable email link + sign-in (optional). Disabling this makes the password required for email sign-in. + app: An App instance (optional). + + Returns: + Tenant: A tenant object. + + Raises: + ValueError: If any of the given arguments are invalid. + FirebaseError: If an error occurs while creating the tenant. + """ + tenant_mgt_service = _get_tenant_mgt_service(app) + return tenant_mgt_service.create_tenant( + display_name=display_name, allow_password_sign_up=allow_password_sign_up, + enable_email_link_sign_in=enable_email_link_sign_in) + + +def update_tenant( + tenant_id, display_name=None, allow_password_sign_up=None, enable_email_link_sign_in=None, + app=None): + """Updates an existing tenant with the given options. + + Args: + tenant_id: ID of the tenant to update. + display_name: Updated display name string for the tenant (optional). + allow_password_sign_up: A boolean indicating whether to enable or disable the email sign-in + provider. + enable_email_link_sign_in: A boolean indicating whether to enable or disable email link + sign-in. Disabling this makes the password required for email sign-in. + app: An App instance (optional). + + Returns: + Tenant: The updated tenant object. + + Raises: + ValueError: If any of the given arguments are invalid. + TenantNotFoundError: If no tenant exists by the given ID. + FirebaseError: If an error occurs while creating the tenant. + """ + tenant_mgt_service = _get_tenant_mgt_service(app) + return tenant_mgt_service.update_tenant( + tenant_id, display_name=display_name, allow_password_sign_up=allow_password_sign_up, + enable_email_link_sign_in=enable_email_link_sign_in) + + +def delete_tenant(tenant_id, app=None): + """Deletes the tenant corresponding to the given ``tenant_id``. + + Args: + tenant_id: A tenant ID string. + app: An App instance (optional). + + Raises: + ValueError: If the tenant ID is None, empty or not a string. + TenantNotFoundError: If no tenant exists by the given ID. + FirebaseError: If an error occurs while retrieving the tenant. + """ + tenant_mgt_service = _get_tenant_mgt_service(app) + tenant_mgt_service.delete_tenant(tenant_id) + + +def list_tenants(page_token=None, max_results=_MAX_LIST_TENANTS_RESULTS, app=None): + """Retrieves a page of tenants from a Firebase project. + + The ``page_token`` argument governs the starting point of the page. The ``max_results`` + argument governs the maximum number of tenants that may be included in the returned page. + This function never returns None. If there are no user accounts in the Firebase project, this + returns an empty page. + + Args: + page_token: A non-empty page token string, which indicates the starting point of the page + (optional). Defaults to ``None``, which will retrieve the first page of users. + max_results: A positive integer indicating the maximum number of users to include in the + returned page (optional). Defaults to 100, which is also the maximum number allowed. + app: An App instance (optional). + + Returns: + ListTenantsPage: A page of tenants. + + Raises: + ValueError: If ``max_results`` or ``page_token`` are invalid. + FirebaseError: If an error occurs while retrieving the user accounts. + """ + tenant_mgt_service = _get_tenant_mgt_service(app) + def download(page_token, max_results): + return tenant_mgt_service.list_tenants(page_token, max_results) + return ListTenantsPage(download, page_token, max_results) + + +def _get_tenant_mgt_service(app): + return _utils.get_app_service(app, _TENANT_MGT_ATTRIBUTE, _TenantManagementService) + + +class Tenant: + """Represents a tenant in a multi-tenant application. + + Multi-tenancy support requires Google Cloud Identity Platform (GCIP). To learn more about + GCIP including pricing and features, see https://cloud.google.com/identity-platform. + + Before multi-tenancy can be used in a Google Cloud Identity Platform project, tenants must be + enabled in that project via the Cloud Console UI. A Tenant instance provides information + such as the display name, tenant identifier and email authentication configuration. + """ + + def __init__(self, data): + if not isinstance(data, dict): + raise ValueError('Invalid data argument in Tenant constructor: {0}'.format(data)) + if not 'name' in data: + raise ValueError('Tenant response missing required keys.') + + self._data = data + + @property + def tenant_id(self): + name = self._data['name'] + return name.split('/')[-1] + + @property + def display_name(self): + return self._data.get('displayName') + + @property + def allow_password_sign_up(self): + return self._data.get('allowPasswordSignup', False) + + @property + def enable_email_link_sign_in(self): + return self._data.get('enableEmailLinkSignin', False) + + +class _TenantManagementService: + """Firebase tenant management service.""" + + TENANT_MGT_URL = 'https://identitytoolkit.googleapis.com/v2beta1' + + def __init__(self, app): + credential = app.credential.get_credential() + version_header = 'Python/Admin/{0}'.format(firebase_admin.__version__) + base_url = '{0}/projects/{1}'.format(self.TENANT_MGT_URL, app.project_id) + self.app = app + self.client = _http_client.JsonHttpClient( + credential=credential, base_url=base_url, headers={'X-Client-Version': version_header}) + self.tenant_clients = {} + self.lock = threading.RLock() + + def auth_for_tenant(self, tenant_id): + """Gets an Auth Client instance scoped to the given tenant ID.""" + if not isinstance(tenant_id, str) or not tenant_id: + raise ValueError( + 'Invalid tenant ID: {0}. Tenant ID must be a non-empty string.'.format(tenant_id)) + + with self.lock: + if tenant_id in self.tenant_clients: + return self.tenant_clients[tenant_id] + + client = auth.Client(self.app, tenant_id=tenant_id) + self.tenant_clients[tenant_id] = client + return client + + def get_tenant(self, tenant_id): + """Gets the tenant corresponding to the given ``tenant_id``.""" + if not isinstance(tenant_id, str) or not tenant_id: + raise ValueError( + 'Invalid tenant ID: {0}. Tenant ID must be a non-empty string.'.format(tenant_id)) + + try: + body = self.client.body('get', '/tenants/{0}'.format(tenant_id)) + except requests.exceptions.RequestException as error: + raise _auth_utils.handle_auth_backend_error(error) + else: + return Tenant(body) + + def create_tenant( + self, display_name, allow_password_sign_up=None, enable_email_link_sign_in=None): + """Creates a new tenant from the given parameters.""" + + payload = {'displayName': _validate_display_name(display_name)} + if allow_password_sign_up is not None: + payload['allowPasswordSignup'] = _auth_utils.validate_boolean( + allow_password_sign_up, 'allowPasswordSignup') + if enable_email_link_sign_in is not None: + payload['enableEmailLinkSignin'] = _auth_utils.validate_boolean( + enable_email_link_sign_in, 'enableEmailLinkSignin') + + try: + body = self.client.body('post', '/tenants', json=payload) + except requests.exceptions.RequestException as error: + raise _auth_utils.handle_auth_backend_error(error) + else: + return Tenant(body) + + def update_tenant( + self, tenant_id, display_name=None, allow_password_sign_up=None, + enable_email_link_sign_in=None): + """Updates the specified tenant with the given parameters.""" + if not isinstance(tenant_id, str) or not tenant_id: + raise ValueError('Tenant ID must be a non-empty string.') + + payload = {} + if display_name is not None: + payload['displayName'] = _validate_display_name(display_name) + if allow_password_sign_up is not None: + payload['allowPasswordSignup'] = _auth_utils.validate_boolean( + allow_password_sign_up, 'allowPasswordSignup') + if enable_email_link_sign_in is not None: + payload['enableEmailLinkSignin'] = _auth_utils.validate_boolean( + enable_email_link_sign_in, 'enableEmailLinkSignin') + + if not payload: + raise ValueError('At least one parameter must be specified for update.') + + url = '/tenants/{0}'.format(tenant_id) + update_mask = ','.join(_auth_utils.build_update_mask(payload)) + params = 'updateMask={0}'.format(update_mask) + try: + body = self.client.body('patch', url, json=payload, params=params) + except requests.exceptions.RequestException as error: + raise _auth_utils.handle_auth_backend_error(error) + else: + return Tenant(body) + + def delete_tenant(self, tenant_id): + """Deletes the tenant corresponding to the given ``tenant_id``.""" + if not isinstance(tenant_id, str) or not tenant_id: + raise ValueError( + 'Invalid tenant ID: {0}. Tenant ID must be a non-empty string.'.format(tenant_id)) + + try: + self.client.request('delete', '/tenants/{0}'.format(tenant_id)) + except requests.exceptions.RequestException as error: + raise _auth_utils.handle_auth_backend_error(error) + + def list_tenants(self, page_token=None, max_results=_MAX_LIST_TENANTS_RESULTS): + """Retrieves a batch of tenants.""" + if page_token is not None: + if not isinstance(page_token, str) or not page_token: + raise ValueError('Page token must be a non-empty string.') + if not isinstance(max_results, int): + raise ValueError('Max results must be an integer.') + if max_results < 1 or max_results > _MAX_LIST_TENANTS_RESULTS: + raise ValueError( + 'Max results must be a positive integer less than or equal to ' + '{0}.'.format(_MAX_LIST_TENANTS_RESULTS)) + + payload = {'pageSize': max_results} + if page_token: + payload['pageToken'] = page_token + try: + return self.client.body('get', '/tenants', params=payload) + except requests.exceptions.RequestException as error: + raise _auth_utils.handle_auth_backend_error(error) + + +class ListTenantsPage: + """Represents a page of tenants fetched from a Firebase project. + + Provides methods for traversing tenants included in this page, as well as retrieving + subsequent pages of tenants. The iterator returned by ``iterate_all()`` can be used to iterate + through all tenants in the Firebase project starting from this page. + """ + + def __init__(self, download, page_token, max_results): + self._download = download + self._max_results = max_results + self._current = download(page_token, max_results) + + @property + def tenants(self): + """A list of ``ExportedUserRecord`` instances available in this page.""" + return [Tenant(data) for data in self._current.get('tenants', [])] + + @property + def next_page_token(self): + """Page token string for the next page (empty string indicates no more pages).""" + return self._current.get('nextPageToken', '') + + @property + def has_next_page(self): + """A boolean indicating whether more pages are available.""" + return bool(self.next_page_token) + + def get_next_page(self): + """Retrieves the next page of tenants, if available. + + Returns: + ListTenantsPage: Next page of tenants, or None if this is the last page. + """ + if self.has_next_page: + return ListTenantsPage(self._download, self.next_page_token, self._max_results) + return None + + def iterate_all(self): + """Retrieves an iterator for tenants. + + Returned iterator will iterate through all the tenants in the Firebase project + starting from this page. The iterator will never buffer more than one page of tenants + in memory at a time. + + Returns: + iterator: An iterator of Tenant instances. + """ + return _TenantIterator(self) + + +class _TenantIterator: + """An iterator that allows iterating over tenants. + + This implementation loads a page of tenants into memory, and iterates on them. When the whole + page has been traversed, it loads another page. This class never keeps more than one page + of entries in memory. + """ + + def __init__(self, current_page): + if not current_page: + raise ValueError('Current page must not be None.') + self._current_page = current_page + self._index = 0 + + def next(self): + if self._index == len(self._current_page.tenants): + if self._current_page.has_next_page: + self._current_page = self._current_page.get_next_page() + self._index = 0 + if self._index < len(self._current_page.tenants): + result = self._current_page.tenants[self._index] + self._index += 1 + return result + raise StopIteration + + def __next__(self): + return self.next() + + def __iter__(self): + return self + + +def _validate_display_name(display_name): + if not isinstance(display_name, str): + raise ValueError('Invalid type for displayName') + if not _DISPLAY_NAME_PATTERN.search(display_name): + raise ValueError( + 'displayName must start with a letter and only consist of letters, digits and ' + 'hyphens with 4-20 characters.') + return display_name diff --git a/integration/test_auth.py b/integration/test_auth.py index 5d26dd9f1..cfd775016 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -16,6 +16,7 @@ import base64 import datetime import random +import string import time from urllib import parse import uuid @@ -38,6 +39,30 @@ ACTION_LINK_CONTINUE_URL = 'http://localhost?a=1&b=5#f=1' +X509_CERTIFICATES = [ + ('-----BEGIN CERTIFICATE-----\nMIICZjCCAc+gAwIBAgIBADANBgkqhkiG9w0BAQ0FADBQMQswCQYDVQQGEwJ1czE' + 'L\nMAkGA1UECAwCQ0ExDTALBgNVBAoMBEFjbWUxETAPBgNVBAMMCGFjbWUuY29tMRIw\nEAYDVQQHDAlTdW5ueXZhbGU' + 'wHhcNMTgxMjA2MDc1MTUxWhcNMjgxMjAzMDc1MTUx\nWjBQMQswCQYDVQQGEwJ1czELMAkGA1UECAwCQ0ExDTALBgNVB' + 'AoMBEFjbWUxETAP\nBgNVBAMMCGFjbWUuY29tMRIwEAYDVQQHDAlTdW5ueXZhbGUwgZ8wDQYJKoZIhvcN\nAQEBBQADg' + 'Y0AMIGJAoGBAKphmggjiVgqMLXyzvI7cKphscIIQ+wcv7Dld6MD4aKv\n7Jqr8ltujMxBUeY4LFEKw8Terb01snYpDot' + 'filaG6NxpF/GfVVmMalzwWp0mT8+H\nyzyPj89mRcozu17RwuooR6n1ofXjGcBE86lqC21UhA3WVgjPOLqB42rlE9gPn' + 'ZLB\nAgMBAAGjUDBOMB0GA1UdDgQWBBS0iM7WnbCNOnieOP1HIA+Oz/ML+zAfBgNVHSME\nGDAWgBS0iM7WnbCNOnieO' + 'P1HIA+Oz/ML+zAMBgNVHRMEBTADAQH/MA0GCSqGSIb3\nDQEBDQUAA4GBAF3jBgS+wP+K/jTupEQur6iaqS4UvXd//d4' + 'vo1MV06oTLQMTz+rP\nOSMDNwxzfaOn6vgYLKP/Dcy9dSTnSzgxLAxfKvDQZA0vE3udsw0Bd245MmX4+GOp\nlbrN99X' + 'P1u+lFxCSdMUzvQ/jW4ysw/Nq4JdJ0gPAyPvL6Qi/3mQdIQwx\n-----END CERTIFICATE-----\n'), + ('-----BEGIN CERTIFICATE-----\nMIICZjCCAc+gAwIBAgIBADANBgkqhkiG9w0BAQ0FADBQMQswCQYDVQQGEwJ1czE' + 'L\nMAkGA1UECAwCQ0ExDTALBgNVBAoMBEFjbWUxETAPBgNVBAMMCGFjbWUuY29tMRIw\nEAYDVQQHDAlTdW5ueXZhbGU' + 'wHhcNMTgxMjA2MDc1ODE4WhcNMjgxMjAzMDc1ODE4\nWjBQMQswCQYDVQQGEwJ1czELMAkGA1UECAwCQ0ExDTALBgNVB' + 'AoMBEFjbWUxETAP\nBgNVBAMMCGFjbWUuY29tMRIwEAYDVQQHDAlTdW5ueXZhbGUwgZ8wDQYJKoZIhvcN\nAQEBBQADg' + 'Y0AMIGJAoGBAKuzYKfDZGA6DJgQru3wNUqv+S0hMZfP/jbp8ou/8UKu\nrNeX7cfCgt3yxoGCJYKmF6t5mvo76JY0MWw' + 'A53BxeP/oyXmJ93uHG5mFRAsVAUKs\ncVVb0Xi6ujxZGVdDWFV696L0BNOoHTfXmac6IBoZQzNNK4n1AATqwo+z7a0pf' + 'RrJ\nAgMBAAGjUDBOMB0GA1UdDgQWBBSKmi/ZKMuLN0ES7/jPa7q7jAjPiDAfBgNVHSME\nGDAWgBSKmi/ZKMuLN0ES7' + '/jPa7q7jAjPiDAMBgNVHRMEBTADAQH/MA0GCSqGSIb3\nDQEBDQUAA4GBAAg2a2kSn05NiUOuWOHwPUjW3wQRsGxPXtb' + 'hWMhmNdCfKKteM2+/\nLd/jz5F3qkOgGQ3UDgr3SHEoWhnLaJMF4a2tm6vL2rEIfPEK81KhTTRxSsAgMVbU\nJXBz1md' + '6Ur0HlgQC7d1CHC8/xi2DDwHopLyxhogaZUxy9IaRxUEa2vJW\n-----END CERTIFICATE-----\n'), +] + + def _sign_in(custom_token, api_key): body = {'token' : custom_token.decode(), 'returnSecureToken' : True} params = {'key' : api_key} @@ -52,6 +77,10 @@ def _sign_in_with_password(email, password, api_key): resp.raise_for_status() return resp.json().get('idToken') +def _random_string(length=10): + letters = string.ascii_lowercase + return ''.join(random.choice(letters) for i in range(length)) + def _random_id(): random_id = str(uuid.uuid4()).lower().replace('-', '') email = 'test{0}@example.{1}.com'.format(random_id[:12], random_id[12:]) @@ -477,6 +506,163 @@ def test_email_sign_in_with_settings(new_user_email_unverified, api_key): assert id_token is not None and len(id_token) > 0 assert auth.get_user(new_user_email_unverified.uid).email_verified + +@pytest.fixture(scope='module') +def oidc_provider(): + provider_config = _create_oidc_provider_config() + yield provider_config + auth.delete_oidc_provider_config(provider_config.provider_id) + + +def test_create_oidc_provider_config(oidc_provider): + assert isinstance(oidc_provider, auth.OIDCProviderConfig) + assert oidc_provider.client_id == 'OIDC_CLIENT_ID' + assert oidc_provider.issuer == 'https://oidc.com/issuer' + assert oidc_provider.display_name == 'OIDC_DISPLAY_NAME' + assert oidc_provider.enabled is True + + +def test_get_oidc_provider_config(oidc_provider): + provider_config = auth.get_oidc_provider_config(oidc_provider.provider_id) + assert isinstance(provider_config, auth.OIDCProviderConfig) + assert provider_config.provider_id == oidc_provider.provider_id + assert provider_config.client_id == 'OIDC_CLIENT_ID' + assert provider_config.issuer == 'https://oidc.com/issuer' + assert provider_config.display_name == 'OIDC_DISPLAY_NAME' + assert provider_config.enabled is True + + +def test_list_oidc_provider_configs(oidc_provider): + page = auth.list_oidc_provider_configs() + result = None + for provider_config in page.iterate_all(): + if provider_config.provider_id == oidc_provider.provider_id: + result = provider_config + break + + assert result is not None + + +def test_update_oidc_provider_config(): + provider_config = _create_oidc_provider_config() + try: + provider_config = auth.update_oidc_provider_config( + provider_config.provider_id, + client_id='UPDATED_OIDC_CLIENT_ID', + issuer='https://oidc.com/updated_issuer', + display_name='UPDATED_OIDC_DISPLAY_NAME', + enabled=False) + assert provider_config.client_id == 'UPDATED_OIDC_CLIENT_ID' + assert provider_config.issuer == 'https://oidc.com/updated_issuer' + assert provider_config.display_name == 'UPDATED_OIDC_DISPLAY_NAME' + assert provider_config.enabled is False + finally: + auth.delete_oidc_provider_config(provider_config.provider_id) + + +def test_delete_oidc_provider_config(): + provider_config = _create_oidc_provider_config() + auth.delete_oidc_provider_config(provider_config.provider_id) + with pytest.raises(auth.ConfigurationNotFoundError): + auth.get_oidc_provider_config(provider_config.provider_id) + + +@pytest.fixture(scope='module') +def saml_provider(): + provider_config = _create_saml_provider_config() + yield provider_config + auth.delete_saml_provider_config(provider_config.provider_id) + + +def test_create_saml_provider_config(saml_provider): + assert isinstance(saml_provider, auth.SAMLProviderConfig) + assert saml_provider.idp_entity_id == 'IDP_ENTITY_ID' + assert saml_provider.sso_url == 'https://example.com/login' + assert saml_provider.x509_certificates == [X509_CERTIFICATES[0]] + assert saml_provider.rp_entity_id == 'RP_ENTITY_ID' + assert saml_provider.callback_url == 'https://projectId.firebaseapp.com/__/auth/handler' + assert saml_provider.display_name == 'SAML_DISPLAY_NAME' + assert saml_provider.enabled is True + + +def test_get_saml_provider_config(saml_provider): + provider_config = auth.get_saml_provider_config(saml_provider.provider_id) + assert isinstance(provider_config, auth.SAMLProviderConfig) + assert provider_config.provider_id == saml_provider.provider_id + assert provider_config.idp_entity_id == 'IDP_ENTITY_ID' + assert provider_config.sso_url == 'https://example.com/login' + assert provider_config.x509_certificates == [X509_CERTIFICATES[0]] + assert provider_config.rp_entity_id == 'RP_ENTITY_ID' + assert provider_config.callback_url == 'https://projectId.firebaseapp.com/__/auth/handler' + assert provider_config.display_name == 'SAML_DISPLAY_NAME' + assert provider_config.enabled is True + + +def test_list_saml_provider_configs(saml_provider): + page = auth.list_saml_provider_configs() + result = None + for provider_config in page.iterate_all(): + if provider_config.provider_id == saml_provider.provider_id: + result = provider_config + break + + assert result is not None + + +def test_update_saml_provider_config(): + provider_config = _create_saml_provider_config() + try: + provider_config = auth.update_saml_provider_config( + provider_config.provider_id, + idp_entity_id='UPDATED_IDP_ENTITY_ID', + sso_url='https://example.com/updated_login', + x509_certificates=[X509_CERTIFICATES[1]], + rp_entity_id='UPDATED_RP_ENTITY_ID', + callback_url='https://updatedProjectId.firebaseapp.com/__/auth/handler', + display_name='UPDATED_SAML_DISPLAY_NAME', + enabled=False) + assert provider_config.idp_entity_id == 'UPDATED_IDP_ENTITY_ID' + assert provider_config.sso_url == 'https://example.com/updated_login' + assert provider_config.x509_certificates == [X509_CERTIFICATES[1]] + assert provider_config.rp_entity_id == 'UPDATED_RP_ENTITY_ID' + assert provider_config.callback_url == ('https://updatedProjectId.firebaseapp.com/' + '__/auth/handler') + assert provider_config.display_name == 'UPDATED_SAML_DISPLAY_NAME' + assert provider_config.enabled is False + finally: + auth.delete_saml_provider_config(provider_config.provider_id) + + +def test_delete_saml_provider_config(): + provider_config = _create_saml_provider_config() + auth.delete_saml_provider_config(provider_config.provider_id) + with pytest.raises(auth.ConfigurationNotFoundError): + auth.get_saml_provider_config(provider_config.provider_id) + + +def _create_oidc_provider_config(): + provider_id = 'oidc.{0}'.format(_random_string()) + return auth.create_oidc_provider_config( + provider_id=provider_id, + client_id='OIDC_CLIENT_ID', + issuer='https://oidc.com/issuer', + display_name='OIDC_DISPLAY_NAME', + enabled=True) + + +def _create_saml_provider_config(): + provider_id = 'saml.{0}'.format(_random_string()) + return auth.create_saml_provider_config( + provider_id=provider_id, + idp_entity_id='IDP_ENTITY_ID', + sso_url='https://example.com/login', + x509_certificates=[X509_CERTIFICATES[0]], + rp_entity_id='RP_ENTITY_ID', + callback_url='https://projectId.firebaseapp.com/__/auth/handler', + display_name='SAML_DISPLAY_NAME', + enabled=True) + + class CredentialWrapper(credentials.Base): """A custom Firebase credential that wraps an OAuth2 token.""" diff --git a/integration/test_tenant_mgt.py b/integration/test_tenant_mgt.py new file mode 100644 index 000000000..c9eefd96e --- /dev/null +++ b/integration/test_tenant_mgt.py @@ -0,0 +1,417 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for firebase_admin.tenant_mgt module.""" + +import random +import string +import time +from urllib import parse +import uuid + +import requests +import pytest + +from firebase_admin import auth +from firebase_admin import tenant_mgt +from integration import test_auth + + +ACTION_LINK_CONTINUE_URL = 'http://localhost?a=1&b=5#f=1' +ACTION_CODE_SETTINGS = auth.ActionCodeSettings(ACTION_LINK_CONTINUE_URL) +VERIFY_TOKEN_URL = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyCustomToken' + + +@pytest.fixture(scope='module') +def sample_tenant(): + tenant = tenant_mgt.create_tenant( + display_name='admin-python-tenant', + allow_password_sign_up=True, + enable_email_link_sign_in=True) + yield tenant + tenant_mgt.delete_tenant(tenant.tenant_id) + + +@pytest.fixture(scope='module') +def tenant_user(sample_tenant): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + email = _random_email() + user = client.create_user(email=email) + yield user + client.delete_user(user.uid) + + +def test_get_tenant(sample_tenant): + tenant = tenant_mgt.get_tenant(sample_tenant.tenant_id) + assert isinstance(tenant, tenant_mgt.Tenant) + assert tenant.tenant_id == sample_tenant.tenant_id + assert tenant.display_name == 'admin-python-tenant' + assert tenant.allow_password_sign_up is True + assert tenant.enable_email_link_sign_in is True + + +def test_list_tenants(sample_tenant): + page = tenant_mgt.list_tenants() + result = None + for tenant in page.iterate_all(): + if tenant.tenant_id == sample_tenant.tenant_id: + result = tenant + break + assert isinstance(result, tenant_mgt.Tenant) + assert result.tenant_id == sample_tenant.tenant_id + assert result.display_name == 'admin-python-tenant' + assert result.allow_password_sign_up is True + assert result.enable_email_link_sign_in is True + + +def test_update_tenant(): + tenant = tenant_mgt.create_tenant( + display_name='py-update-test', allow_password_sign_up=True, enable_email_link_sign_in=True) + try: + tenant = tenant_mgt.update_tenant( + tenant.tenant_id, display_name='updated-py-tenant', allow_password_sign_up=False, + enable_email_link_sign_in=False) + assert isinstance(tenant, tenant_mgt.Tenant) + assert tenant.tenant_id == tenant.tenant_id + assert tenant.display_name == 'updated-py-tenant' + assert tenant.allow_password_sign_up is False + assert tenant.enable_email_link_sign_in is False + finally: + tenant_mgt.delete_tenant(tenant.tenant_id) + + +def test_delete_tenant(): + tenant = tenant_mgt.create_tenant(display_name='py-delete-test') + tenant_mgt.delete_tenant(tenant.tenant_id) + with pytest.raises(tenant_mgt.TenantNotFoundError): + tenant_mgt.get_tenant(tenant.tenant_id) + + +def test_auth_for_client(sample_tenant): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + assert isinstance(client, auth.Client) + assert client.tenant_id == sample_tenant.tenant_id + + +def test_custom_token(sample_tenant, api_key): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + custom_token = client.create_custom_token('user1') + id_token = _sign_in(custom_token, sample_tenant.tenant_id, api_key) + claims = client.verify_id_token(id_token) + assert claims['uid'] == 'user1' + assert claims['firebase']['tenant'] == sample_tenant.tenant_id + + +def test_custom_token_with_claims(sample_tenant, api_key): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + custom_token = client.create_custom_token('user1', {'premium': True}) + id_token = _sign_in(custom_token, sample_tenant.tenant_id, api_key) + claims = client.verify_id_token(id_token) + assert claims['uid'] == 'user1' + assert claims['premium'] is True + assert claims['firebase']['tenant'] == sample_tenant.tenant_id + + +def test_create_user(sample_tenant, tenant_user): + assert tenant_user.tenant_id == sample_tenant.tenant_id + + +def test_update_user(sample_tenant): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + user = client.create_user() + try: + email = _random_email() + phone = _random_phone() + user = client.update_user(user.uid, email=email, phone_number=phone) + assert user.tenant_id == sample_tenant.tenant_id + assert user.email == email + assert user.phone_number == phone + finally: + client.delete_user(user.uid) + + +def test_get_user(sample_tenant, tenant_user): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + user = client.get_user(tenant_user.uid) + assert user.uid == tenant_user.uid + assert user.tenant_id == sample_tenant.tenant_id + + +def test_list_users(sample_tenant, tenant_user): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + page = client.list_users() + result = None + for user in page.iterate_all(): + if user.uid == tenant_user.uid: + result = user + break + assert result.tenant_id == sample_tenant.tenant_id + + +def test_set_custom_user_claims(sample_tenant, tenant_user): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + client.set_custom_user_claims(tenant_user.uid, {'premium': True}) + user = client.get_user(tenant_user.uid) + assert user.custom_claims == {'premium': True} + + +def test_delete_user(sample_tenant): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + user = client.create_user() + client.delete_user(user.uid) + with pytest.raises(auth.UserNotFoundError): + client.get_user(user.uid) + + +def test_revoke_refresh_tokens(sample_tenant, tenant_user): + valid_since = int(time.time()) + time.sleep(1) + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + client.revoke_refresh_tokens(tenant_user.uid) + user = client.get_user(tenant_user.uid) + assert user.tokens_valid_after_timestamp > valid_since + + +def test_password_reset_link(sample_tenant, tenant_user): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + link = client.generate_password_reset_link(tenant_user.email, ACTION_CODE_SETTINGS) + assert _tenant_id_from_link(link) == sample_tenant.tenant_id + + +def test_email_verification_link(sample_tenant, tenant_user): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + link = client.generate_email_verification_link(tenant_user.email, ACTION_CODE_SETTINGS) + assert _tenant_id_from_link(link) == sample_tenant.tenant_id + + +def test_sign_in_with_email_link(sample_tenant, tenant_user): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + link = client.generate_sign_in_with_email_link(tenant_user.email, ACTION_CODE_SETTINGS) + assert _tenant_id_from_link(link) == sample_tenant.tenant_id + + +def test_import_users(sample_tenant): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + user = auth.ImportUserRecord( + uid=_random_uid(), email=_random_email()) + result = client.import_users([user]) + try: + assert result.success_count == 1 + assert result.failure_count == 0 + saved_user = client.get_user(user.uid) + assert saved_user.email == user.email + finally: + client.delete_user(user.uid) + + +@pytest.fixture(scope='module') +def oidc_provider(sample_tenant): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + provider_config = _create_oidc_provider_config(client) + yield provider_config + client.delete_oidc_provider_config(provider_config.provider_id) + + +def test_create_oidc_provider_config(oidc_provider): + assert isinstance(oidc_provider, auth.OIDCProviderConfig) + assert oidc_provider.client_id == 'OIDC_CLIENT_ID' + assert oidc_provider.issuer == 'https://oidc.com/issuer' + assert oidc_provider.display_name == 'OIDC_DISPLAY_NAME' + assert oidc_provider.enabled is True + + +def test_get_oidc_provider_config(sample_tenant, oidc_provider): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + provider_config = client.get_oidc_provider_config(oidc_provider.provider_id) + assert isinstance(provider_config, auth.OIDCProviderConfig) + assert provider_config.provider_id == oidc_provider.provider_id + assert provider_config.client_id == 'OIDC_CLIENT_ID' + assert provider_config.issuer == 'https://oidc.com/issuer' + assert provider_config.display_name == 'OIDC_DISPLAY_NAME' + assert provider_config.enabled is True + + +def test_list_oidc_provider_configs(sample_tenant, oidc_provider): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + page = client.list_oidc_provider_configs() + result = None + for provider_config in page.iterate_all(): + if provider_config.provider_id == oidc_provider.provider_id: + result = provider_config + break + + assert result is not None + + +def test_update_oidc_provider_config(sample_tenant): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + provider_config = _create_oidc_provider_config(client) + try: + provider_config = client.update_oidc_provider_config( + provider_config.provider_id, + client_id='UPDATED_OIDC_CLIENT_ID', + issuer='https://oidc.com/updated_issuer', + display_name='UPDATED_OIDC_DISPLAY_NAME', + enabled=False) + assert provider_config.client_id == 'UPDATED_OIDC_CLIENT_ID' + assert provider_config.issuer == 'https://oidc.com/updated_issuer' + assert provider_config.display_name == 'UPDATED_OIDC_DISPLAY_NAME' + assert provider_config.enabled is False + finally: + client.delete_oidc_provider_config(provider_config.provider_id) + + +def test_delete_oidc_provider_config(sample_tenant): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + provider_config = _create_oidc_provider_config(client) + client.delete_oidc_provider_config(provider_config.provider_id) + with pytest.raises(auth.ConfigurationNotFoundError): + client.get_oidc_provider_config(provider_config.provider_id) + + +@pytest.fixture(scope='module') +def saml_provider(sample_tenant): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + provider_config = _create_saml_provider_config(client) + yield provider_config + client.delete_saml_provider_config(provider_config.provider_id) + + +def test_create_saml_provider_config(saml_provider): + assert isinstance(saml_provider, auth.SAMLProviderConfig) + assert saml_provider.idp_entity_id == 'IDP_ENTITY_ID' + assert saml_provider.sso_url == 'https://example.com/login' + assert saml_provider.x509_certificates == [test_auth.X509_CERTIFICATES[0]] + assert saml_provider.rp_entity_id == 'RP_ENTITY_ID' + assert saml_provider.callback_url == 'https://projectId.firebaseapp.com/__/auth/handler' + assert saml_provider.display_name == 'SAML_DISPLAY_NAME' + assert saml_provider.enabled is True + + +def test_get_saml_provider_config(sample_tenant, saml_provider): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + provider_config = client.get_saml_provider_config(saml_provider.provider_id) + assert isinstance(provider_config, auth.SAMLProviderConfig) + assert provider_config.provider_id == saml_provider.provider_id + assert provider_config.idp_entity_id == 'IDP_ENTITY_ID' + assert provider_config.sso_url == 'https://example.com/login' + assert provider_config.x509_certificates == [test_auth.X509_CERTIFICATES[0]] + assert provider_config.rp_entity_id == 'RP_ENTITY_ID' + assert provider_config.callback_url == 'https://projectId.firebaseapp.com/__/auth/handler' + assert provider_config.display_name == 'SAML_DISPLAY_NAME' + assert provider_config.enabled is True + + +def test_list_saml_provider_configs(sample_tenant, saml_provider): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + page = client.list_saml_provider_configs() + result = None + for provider_config in page.iterate_all(): + if provider_config.provider_id == saml_provider.provider_id: + result = provider_config + break + + assert result is not None + + +def test_update_saml_provider_config(sample_tenant): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + provider_config = _create_saml_provider_config(client) + try: + provider_config = client.update_saml_provider_config( + provider_config.provider_id, + idp_entity_id='UPDATED_IDP_ENTITY_ID', + sso_url='https://example.com/updated_login', + x509_certificates=[test_auth.X509_CERTIFICATES[1]], + rp_entity_id='UPDATED_RP_ENTITY_ID', + callback_url='https://updatedProjectId.firebaseapp.com/__/auth/handler', + display_name='UPDATED_SAML_DISPLAY_NAME', + enabled=False) + assert provider_config.idp_entity_id == 'UPDATED_IDP_ENTITY_ID' + assert provider_config.sso_url == 'https://example.com/updated_login' + assert provider_config.x509_certificates == [test_auth.X509_CERTIFICATES[1]] + assert provider_config.rp_entity_id == 'UPDATED_RP_ENTITY_ID' + assert provider_config.callback_url == ('https://updatedProjectId.firebaseapp.com/' + '__/auth/handler') + assert provider_config.display_name == 'UPDATED_SAML_DISPLAY_NAME' + assert provider_config.enabled is False + finally: + client.delete_saml_provider_config(provider_config.provider_id) + + +def test_delete_saml_provider_config(sample_tenant): + client = tenant_mgt.auth_for_tenant(sample_tenant.tenant_id) + provider_config = _create_saml_provider_config(client) + client.delete_saml_provider_config(provider_config.provider_id) + with pytest.raises(auth.ConfigurationNotFoundError): + client.get_saml_provider_config(provider_config.provider_id) + + +def _create_oidc_provider_config(client): + provider_id = 'oidc.{0}'.format(_random_string()) + return client.create_oidc_provider_config( + provider_id=provider_id, + client_id='OIDC_CLIENT_ID', + issuer='https://oidc.com/issuer', + display_name='OIDC_DISPLAY_NAME', + enabled=True) + + +def _create_saml_provider_config(client): + provider_id = 'saml.{0}'.format(_random_string()) + return client.create_saml_provider_config( + provider_id=provider_id, + idp_entity_id='IDP_ENTITY_ID', + sso_url='https://example.com/login', + x509_certificates=[test_auth.X509_CERTIFICATES[0]], + rp_entity_id='RP_ENTITY_ID', + callback_url='https://projectId.firebaseapp.com/__/auth/handler', + display_name='SAML_DISPLAY_NAME', + enabled=True) + + +def _random_uid(): + return str(uuid.uuid4()).lower().replace('-', '') + + +def _random_email(): + random_id = str(uuid.uuid4()).lower().replace('-', '') + return 'test{0}@example.{1}.com'.format(random_id[:12], random_id[12:]) + + +def _random_phone(): + return '+1' + ''.join([str(random.randint(0, 9)) for _ in range(0, 10)]) + + +def _random_string(length=10): + letters = string.ascii_lowercase + return ''.join(random.choice(letters) for i in range(length)) + + +def _tenant_id_from_link(link): + query = parse.urlparse(link).query + parsed_query = parse.parse_qs(query) + return parsed_query['tenantId'][0] + + +def _sign_in(custom_token, tenant_id, api_key): + body = { + 'token' : custom_token.decode(), + 'returnSecureToken' : True, + 'tenantId': tenant_id, + } + params = {'key' : api_key} + resp = requests.request('post', VERIFY_TOKEN_URL, params=params, json=body) + resp.raise_for_status() + return resp.json().get('idToken') diff --git a/snippets/auth/index.py b/snippets/auth/index.py index b1c091064..428c54e09 100644 --- a/snippets/auth/index.py +++ b/snippets/auth/index.py @@ -25,6 +25,7 @@ from firebase_admin import credentials from firebase_admin import auth from firebase_admin import exceptions +from firebase_admin import tenant_mgt sys.path.append("lib") @@ -634,6 +635,418 @@ def send_custom_email(email, link): del email del link +def create_saml_provider_config(): + # [START create_saml_provider] + saml = auth.create_saml_provider_config( + display_name='SAML provider name', + enabled=True, + provider_id='saml.myProvider', + idp_entity_id='IDP_ENTITY_ID', + sso_url='https://example.com/saml/sso/1234/', + x509_certificates=[ + '-----BEGIN CERTIFICATE-----\nCERT1...\n-----END CERTIFICATE-----', + '-----BEGIN CERTIFICATE-----\nCERT2...\n-----END CERTIFICATE-----', + ], + rp_entity_id='P_ENTITY_ID', + callback_url='https://project-id.firebaseapp.com/__/auth/handler') + + print('Created new SAML provider:', saml.provider_id) + # [END create_saml_provider] + +def update_saml_provider_config(): + # [START update_saml_provider] + saml = auth.update_saml_provider_config( + 'saml.myProvider', + x509_certificates=[ + '-----BEGIN CERTIFICATE-----\nCERT2...\n-----END CERTIFICATE-----', + '-----BEGIN CERTIFICATE-----\nCERT3...\n-----END CERTIFICATE-----', + ]) + + print('Updated SAML provider:', saml.provider_id) + # [END update_saml_provider] + +def get_saml_provider_config(): + # [START get_saml_provider] + saml = auth.get_saml_provider_config('saml.myProvider') + print(saml.display_name, saml.enabled) + # [END get_saml_provider] + +def delete_saml_provider_config(): + # [START delete_saml_provider] + auth.delete_saml_provider_config('saml.myProvider') + # [END delete_saml_provider] + +def list_saml_provider_configs(): + # [START list_saml_providers] + for saml in auth.list_saml_provider_configs('nextPageToken').iterate_all(): + print(saml.provider_id) + # [END list_saml_providers] + +def create_oidc_provider_config(): + # [START create_oidc_provider] + oidc = auth.create_oidc_provider_config( + display_name='OIDC provider name', + enabled=True, + provider_id='oidc.myProvider', + client_id='CLIENT_ID2', + issuer='https://oidc.com/CLIENT_ID2') + + print('Created new OIDC provider:', oidc.provider_id) + # [END create_oidc_provider] + +def update_oidc_provider_config(): + # [START update_oidc_provider] + oidc = auth.update_oidc_provider_config( + 'oidc.myProvider', + client_id='CLIENT_ID', + issuer='https://oidc.com') + + print('Updated OIDC provider:', oidc.provider_id) + # [END update_oidc_provider] + +def get_oidc_provider_config(): + # [START get_oidc_provider] + oidc = auth.get_oidc_provider_config('oidc.myProvider') + + print(oidc.display_name, oidc.enabled) + # [END get_oidc_provider] + +def delete_oidc_provider_config(): + # [START delete_oidc_provider] + auth.delete_oidc_provider_config('oidc.myProvider') + # [END delete_oidc_provider] + +def list_oidc_provider_configs(): + # [START list_oidc_providers] + for oidc in auth.list_oidc_provider_configs('nextPageToken').iterate_all(): + print(oidc.provider_id) + # [END list_oidc_providers] + +def get_tenant_client(tenant_id): + # [START get_tenant_client] + from firebase_admin import tenant_mgt + + tenant_client = tenant_mgt.auth_for_tenant(tenant_id) + # [END get_tenant_client] + return tenant_client + +def get_tenant(tenant_id): + # [START get_tenant] + tenant = tenant_mgt.get_tenant(tenant_id) + + print('Retreieved tenant:', tenant.tenant_id) + # [END get_tenant] + +def create_tenant(): + # [START create_tenant] + tenant = tenant_mgt.create_tenant( + display_name='myTenant1', + enable_email_link_sign_in=True, + allow_password_sign_up=True) + + print('Created tenant:', tenant.tenant_id) + # [END create_tenant] + +def update_tenant(tenant_id): + # [START update_tenant] + tenant = tenant_mgt.update_tenant( + tenant_id, + display_name='updatedName', + allow_password_sign_up=False) # Disable email provider + + print('Updated tenant:', tenant.tenant_id) + # [END update_tenant] + +def delete_tenant(tenant_id): + # [START delete_tenant] + tenant_mgt.delete_tenant(tenant_id) + # [END delete_tenant] + +def list_tenants(): + # [START list_tenants] + for tenant in tenant_mgt.list_tenants().iterate_all(): + print('Retrieved tenant:', tenant.tenant_id) + # [END list_tenants] + +def create_provider_tenant(): + # [START get_tenant_client_short] + tenant_client = tenant_mgt.auth_for_tenant('TENANT-ID') + # [END get_tenant_client_short] + + # [START create_saml_provider_tenant] + saml = tenant_client.create_saml_provider_config( + display_name='SAML provider name', + enabled=True, + provider_id='saml.myProvider', + idp_entity_id='IDP_ENTITY_ID', + sso_url='https://example.com/saml/sso/1234/', + x509_certificates=[ + '-----BEGIN CERTIFICATE-----\nCERT1...\n-----END CERTIFICATE-----', + '-----BEGIN CERTIFICATE-----\nCERT2...\n-----END CERTIFICATE-----', + ], + rp_entity_id='P_ENTITY_ID', + callback_url='https://project-id.firebaseapp.com/__/auth/handler') + + print('Created new SAML provider:', saml.provider_id) + # [END create_saml_provider_tenant] + +def update_provider_tenant(tenant_client): + # [START update_saml_provider_tenant] + saml = tenant_client.update_saml_provider_config( + 'saml.myProvider', + x509_certificates=[ + '-----BEGIN CERTIFICATE-----\nCERT2...\n-----END CERTIFICATE-----', + '-----BEGIN CERTIFICATE-----\nCERT3...\n-----END CERTIFICATE-----', + ]) + + print('Updated SAML provider:', saml.provider_id) + # [END update_saml_provider_tenant] + +def get_provider_tenant(tennat_client): + # [START get_saml_provider_tenant] + saml = tennat_client.get_saml_provider_config('saml.myProvider') + print(saml.display_name, saml.enabled) + # [END get_saml_provider_tenant] + +def list_provider_configs_tenant(tenant_client): + # [START list_saml_providers_tenant] + for saml in tenant_client.list_saml_provider_configs('nextPageToken').iterate_all(): + print(saml.provider_id) + # [END list_saml_providers_tenant] + +def delete_provider_config_tenant(tenant_client): + # [START delete_saml_provider_tenant] + tenant_client.delete_saml_provider_config('saml.myProvider') + # [END delete_saml_provider_tenant] + +def get_user_tenant(tenant_client): + uid = 'some_string_uid' + + # [START get_user_tenant] + # Get an auth.Client from tenant_mgt.auth_for_tenant() + user = tenant_client.get_user(uid) + print('Successfully fetched user data:', user.uid) + # [END get_user_tenant] + +def get_user_by_email_tenant(tenant_client): + email = 'some@email.com' + # [START get_user_by_email_tenant] + user = tenant_client.get_user_by_email(email) + print('Successfully fetched user data:', user.uid) + # [END get_user_by_email_tenant] + +def create_user_tenant(tenant_client): + # [START create_user_tenant] + user = tenant_client.create_user( + email='user@example.com', + email_verified=False, + phone_number='+15555550100', + password='secretPassword', + display_name='John Doe', + photo_url='http://www.example.com/12345678/photo.png', + disabled=False) + print('Sucessfully created new user:', user.uid) + # [END create_user_tenant] + +def update_user_tenant(tenant_client, uid): + # [START update_user_tenant] + user = tenant_client.update_user( + uid, + email='user@example.com', + phone_number='+15555550100', + email_verified=True, + password='newPassword', + display_name='John Doe', + photo_url='http://www.example.com/12345678/photo.png', + disabled=True) + print('Sucessfully updated user:', user.uid) + # [END update_user_tenant] + +def delete_user_tenant(tenant_client, uid): + # [START delete_user_tenant] + tenant_client.delete_user(uid) + print('Successfully deleted user') + # [END delete_user_tenant] + +def list_users_tenant(tenant_client): + # [START list_all_users_tenant] + # Note, behind the scenes, the iterator will retrive 1000 users at a time through the API + for user in tenant_client.list_users().iterate_all(): + print('User: ' + user.uid) + + # Iterating by pages of 1000 users at a time. + page = tenant_client.list_users() + while page: + for user in page.users: + print('User: ' + user.uid) + # Get next batch of users. + page = page.get_next_page() + # [END list_all_users_tenant] + +def import_with_hmac_tenant(tenant_client): + # [START import_with_hmac_tenant] + users = [ + auth.ImportUserRecord( + uid='uid1', + email='user1@example.com', + password_hash=b'password_hash_1', + password_salt=b'salt1' + ), + auth.ImportUserRecord( + uid='uid2', + email='user2@example.com', + password_hash=b'password_hash_2', + password_salt=b'salt2' + ), + ] + + hash_alg = auth.UserImportHash.hmac_sha256(key=b'secret') + try: + result = tenant_client.import_users(users, hash_alg=hash_alg) + for err in result.errors: + print('Failed to import user:', err.reason) + except exceptions.FirebaseError as error: + print('Error importing users:', error) + # [END import_with_hmac_tenant] + +def import_without_password_tenant(tenant_client): + # [START import_without_password_tenant] + users = [ + auth.ImportUserRecord( + uid='some-uid', + display_name='John Doe', + email='johndoe@gmail.com', + photo_url='http://www.example.com/12345678/photo.png', + email_verified=True, + phone_number='+11234567890', + custom_claims={'admin': True}, # set this user as admin + provider_data=[ # user with SAML provider + auth.UserProvider( + uid='saml-uid', + email='johndoe@gmail.com', + display_name='John Doe', + photo_url='http://www.example.com/12345678/photo.png', + provider_id='saml.acme' + ) + ], + ), + ] + try: + result = tenant_client.import_users(users) + for err in result.errors: + print('Failed to import user:', err.reason) + except exceptions.FirebaseError as error: + print('Error importing users:', error) + # [END import_without_password_tenant] + +def verify_id_token_tenant(tenant_client, id_token): + # [START verify_id_token_tenant] + # id_token comes from the client app + try: + decoded_token = tenant_client.verify_id_token(id_token) + + # This should be set to TENANT-ID. Otherwise TenantIdMismatchError error raised. + print('Verified ID token from tenant:', decoded_token['firebase']['tenant']) + except tenant_mgt.TenantIdMismatchError: + # Token revoked, inform the user to reauthenticate or signOut(). + pass + # [END verify_id_token_tenant] + +def verify_id_token_access_control_tenant(id_token): + # [START id_token_access_control_tenant] + decoded_token = auth.verify_id_token(id_token) + + tenant = decoded_token['firebase']['tenant'] + if tenant == 'TENANT-ID1': + # Allow appropriate level of access for TENANT-ID1. + pass + elif tenant == 'TENANT-ID2': + # Allow appropriate level of access for TENANT-ID2. + pass + else: + # Access not allowed -- Handle error + pass + # [END id_token_access_control_tenant] + +def revoke_refresh_tokens_tenant(tenant_client, uid): + # [START revoke_tokens_tenant] + # Revoke all refresh tokens for a specified user in a specified tenant for whatever reason. + # Retrieve the timestamp of the revocation, in seconds since the epoch. + tenant_client.revoke_refresh_tokens(uid) + + user = tenant_client.get_user(uid) + # Convert to seconds as the auth_time in the token claims is in seconds. + revocation_second = user.tokens_valid_after_timestamp / 1000 + print('Tokens revoked at: {0}'.format(revocation_second)) + # [END revoke_tokens_tenant] + +def verify_id_token_and_check_revoked_tenant(tenant_client, id_token): + # [START verify_id_token_and_check_revoked_tenant] + # Verify the ID token for a specific tenant while checking if the token is revoked. + try: + # Verify the ID token while checking if the token is revoked by + # passing check_revoked=True. + decoded_token = tenant_client.verify_id_token(id_token, check_revoked=True) + # Token is valid and not revoked. + uid = decoded_token['uid'] + except tenant_mgt.TenantIdMismatchError: + # Token belongs to a different tenant. + pass + except auth.RevokedIdTokenError: + # Token revoked, inform the user to reauthenticate or signOut(). + pass + except auth.InvalidIdTokenError: + # Token is invalid + pass + # [END verify_id_token_and_check_revoked_tenant] + +def custom_claims_set_tenant(tenant_client, uid): + # [START set_custom_user_claims_tenant] + # Set admin privilege on the user corresponding to uid. + tenant_client.set_custom_user_claims(uid, {'admin': True}) + # The new custom claims will propagate to the user's ID token the + # next time a new one is issued. + # [END set_custom_user_claims_tenant] + +def custom_claims_verify_tenant(tenant_client, id_token): + # [START verify_custom_claims_tenant] + # Verify the ID token first. + claims = tenant_client.verify_id_token(id_token) + if claims['admin'] is True: + # Allow access to requested admin resource. + pass + # [END verify_custom_claims_tenant] + +def custom_claims_read_tenant(tenant_client, uid): + # [START read_custom_user_claims_tenant] + # Lookup the user associated with the specified uid. + user = tenant_client.get_user(uid) + + # The claims can be accessed on the user record. + print(user.custom_claims.get('admin')) + # [END read_custom_user_claims_tenant] + +def generate_email_verification_link_tenant(tenant_client): + # [START email_verification_link_tenant] + action_code_settings = auth.ActionCodeSettings( + url='https://www.example.com/checkout?cartId=1234', + handle_code_in_app=True, + ios_bundle_id='com.example.ios', + android_package_name='com.example.android', + android_install_app=True, + android_minimum_version='12', + # FDL custom domain. + dynamic_link_domain='coolapp.page.link', + ) + + email = 'user@example.com' + link = tenant_client.generate_email_verification_link(email, action_code_settings) + # Construct email from a template embedding the link, and send + # using a custom SMTP server. + send_custom_email(email, link) + # [END email_verification_link_tenant] + + initialize_sdk_with_service_account() initialize_sdk_with_application_default() #initialize_sdk_with_refresh_token() diff --git a/tests/data/list_oidc_provider_configs.json b/tests/data/list_oidc_provider_configs.json new file mode 100644 index 000000000..b2b381304 --- /dev/null +++ b/tests/data/list_oidc_provider_configs.json @@ -0,0 +1,18 @@ +{ + "oauthIdpConfigs": [ + { + "name":"projects/mock-project-id/oauthIdpConfigs/oidc.provider0", + "clientId": "CLIENT_ID", + "issuer": "https://oidc.com/issuer", + "displayName": "oidcProviderName", + "enabled": true + }, + { + "name":"projects/mock-project-id/oauthIdpConfigs/oidc.provider1", + "clientId": "CLIENT_ID", + "issuer": "https://oidc.com/issuer", + "displayName": "oidcProviderName", + "enabled": true + } + ] +} diff --git a/tests/data/list_saml_provider_configs.json b/tests/data/list_saml_provider_configs.json new file mode 100644 index 000000000..b568e1e09 --- /dev/null +++ b/tests/data/list_saml_provider_configs.json @@ -0,0 +1,40 @@ +{ + "inboundSamlConfigs": [ + { + "name": "projects/mock-project-id/inboundSamlConfigs/saml.provider0", + "idpConfig": { + "idpEntityId": "IDP_ENTITY_ID", + "ssoUrl": "https://example.com/login", + "signRequest": true, + "idpCertificates": [ + {"x509Certificate": "CERT1"}, + {"x509Certificate": "CERT2"} + ] + }, + "spConfig": { + "spEntityId": "RP_ENTITY_ID", + "callbackUri": "https://projectId.firebaseapp.com/__/auth/handler" + }, + "displayName": "samlProviderName", + "enabled": true + }, + { + "name": "projects/mock-project-id/inboundSamlConfigs/saml.provider1", + "idpConfig": { + "idpEntityId": "IDP_ENTITY_ID", + "ssoUrl": "https://example.com/login", + "signRequest": true, + "idpCertificates": [ + {"x509Certificate": "CERT1"}, + {"x509Certificate": "CERT2"} + ] + }, + "spConfig": { + "spEntityId": "RP_ENTITY_ID", + "callbackUri": "https://projectId.firebaseapp.com/__/auth/handler" + }, + "displayName": "samlProviderName", + "enabled": true + } + ] +} diff --git a/tests/data/oidc_provider_config.json b/tests/data/oidc_provider_config.json new file mode 100644 index 000000000..89cf3eacf --- /dev/null +++ b/tests/data/oidc_provider_config.json @@ -0,0 +1,7 @@ +{ + "name":"projects/mock-project-id/oauthIdpConfigs/oidc.provider", + "clientId": "CLIENT_ID", + "issuer": "https://oidc.com/issuer", + "displayName": "oidcProviderName", + "enabled": true +} diff --git a/tests/data/saml_provider_config.json b/tests/data/saml_provider_config.json new file mode 100644 index 000000000..577340f2a --- /dev/null +++ b/tests/data/saml_provider_config.json @@ -0,0 +1,18 @@ +{ + "name": "projects/mock-project-id/inboundSamlConfigs/saml.provider", + "idpConfig": { + "idpEntityId": "IDP_ENTITY_ID", + "ssoUrl": "https://example.com/login", + "signRequest": true, + "idpCertificates": [ + {"x509Certificate": "CERT1"}, + {"x509Certificate": "CERT2"} + ] + }, + "spConfig": { + "spEntityId": "RP_ENTITY_ID", + "callbackUri": "https://projectId.firebaseapp.com/__/auth/handler" + }, + "displayName": "samlProviderName", + "enabled": true +} \ No newline at end of file diff --git a/tests/test_auth_providers.py b/tests/test_auth_providers.py new file mode 100644 index 000000000..124aea3cc --- /dev/null +++ b/tests/test_auth_providers.py @@ -0,0 +1,732 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test cases for the firebase_admin._auth_providers module.""" + +import json + +import pytest + +import firebase_admin +from firebase_admin import auth +from firebase_admin import exceptions +from firebase_admin import _auth_providers +from tests import testutils + +USER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v2beta1/projects/mock-project-id' +OIDC_PROVIDER_CONFIG_RESPONSE = testutils.resource('oidc_provider_config.json') +SAML_PROVIDER_CONFIG_RESPONSE = testutils.resource('saml_provider_config.json') +LIST_OIDC_PROVIDER_CONFIGS_RESPONSE = testutils.resource('list_oidc_provider_configs.json') +LIST_SAML_PROVIDER_CONFIGS_RESPONSE = testutils.resource('list_saml_provider_configs.json') + +CONFIG_NOT_FOUND_RESPONSE = """{ + "error": { + "message": "CONFIGURATION_NOT_FOUND" + } +}""" + +INVALID_PROVIDER_IDS = [None, True, False, 1, 0, list(), tuple(), dict(), ''] + + +@pytest.fixture(scope='module') +def user_mgt_app(): + app = firebase_admin.initialize_app(testutils.MockCredential(), name='providerConfig', + options={'projectId': 'mock-project-id'}) + yield app + firebase_admin.delete_app(app) + + +def _instrument_provider_mgt(app, status, payload): + client = auth._get_client(app) + provider_manager = client._provider_manager + recorder = [] + provider_manager.http_client.session.mount( + _auth_providers.ProviderConfigClient.PROVIDER_CONFIG_URL, + testutils.MockAdapter(payload, status, recorder)) + return recorder + + +class TestOIDCProviderConfig: + + VALID_CREATE_OPTIONS = { + 'provider_id': 'oidc.provider', + 'client_id': 'CLIENT_ID', + 'issuer': 'https://oidc.com/issuer', + 'display_name': 'oidcProviderName', + 'enabled': True, + } + + OIDC_CONFIG_REQUEST = { + 'displayName': 'oidcProviderName', + 'enabled': True, + 'clientId': 'CLIENT_ID', + 'issuer': 'https://oidc.com/issuer', + } + + @pytest.mark.parametrize('provider_id', INVALID_PROVIDER_IDS + ['saml.provider']) + def test_get_invalid_provider_id(self, user_mgt_app, provider_id): + with pytest.raises(ValueError) as excinfo: + auth.get_oidc_provider_config(provider_id, app=user_mgt_app) + + assert str(excinfo.value).startswith('Invalid OIDC provider ID') + + def test_get(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, OIDC_PROVIDER_CONFIG_RESPONSE) + + provider_config = auth.get_oidc_provider_config('oidc.provider', app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/oauthIdpConfigs/oidc.provider') + + @pytest.mark.parametrize('invalid_opts', [ + {'provider_id': None}, {'provider_id': ''}, {'provider_id': 'saml.provider'}, + {'client_id': None}, {'client_id': ''}, + {'issuer': None}, {'issuer': ''}, {'issuer': 'not a url'}, + {'display_name': True}, + {'enabled': 'true'}, + ]) + def test_create_invalid_args(self, user_mgt_app, invalid_opts): + options = dict(self.VALID_CREATE_OPTIONS) + options.update(invalid_opts) + with pytest.raises(ValueError): + auth.create_oidc_provider_config(**options, app=user_mgt_app) + + def test_create(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, OIDC_PROVIDER_CONFIG_RESPONSE) + + provider_config = auth.create_oidc_provider_config( + **self.VALID_CREATE_OPTIONS, app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'POST' + assert req.url == '{0}/oauthIdpConfigs?oauthIdpConfigId=oidc.provider'.format( + USER_MGT_URL_PREFIX) + got = json.loads(req.body.decode()) + assert got == self.OIDC_CONFIG_REQUEST + + def test_create_minimal(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, OIDC_PROVIDER_CONFIG_RESPONSE) + options = dict(self.VALID_CREATE_OPTIONS) + del options['display_name'] + del options['enabled'] + want = dict(self.OIDC_CONFIG_REQUEST) + del want['displayName'] + del want['enabled'] + + provider_config = auth.create_oidc_provider_config(**options, app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'POST' + assert req.url == '{0}/oauthIdpConfigs?oauthIdpConfigId=oidc.provider'.format( + USER_MGT_URL_PREFIX) + got = json.loads(req.body.decode()) + assert got == want + + def test_create_empty_values(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, OIDC_PROVIDER_CONFIG_RESPONSE) + options = dict(self.VALID_CREATE_OPTIONS) + options['display_name'] = '' + options['enabled'] = False + want = dict(self.OIDC_CONFIG_REQUEST) + want['displayName'] = '' + want['enabled'] = False + + provider_config = auth.create_oidc_provider_config(**options, app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'POST' + assert req.url == '{0}/oauthIdpConfigs?oauthIdpConfigId=oidc.provider'.format( + USER_MGT_URL_PREFIX) + got = json.loads(req.body.decode()) + assert got == want + + @pytest.mark.parametrize('invalid_opts', [ + {}, + {'provider_id': None}, {'provider_id': ''}, {'provider_id': 'saml.provider'}, + {'client_id': ''}, + {'issuer': ''}, {'issuer': 'not a url'}, + {'display_name': True}, + {'enabled': 'true'}, + ]) + def test_update_invalid_args(self, user_mgt_app, invalid_opts): + options = {'provider_id': 'oidc.provider'} + options.update(invalid_opts) + with pytest.raises(ValueError): + auth.update_oidc_provider_config(**options, app=user_mgt_app) + + def test_update(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, OIDC_PROVIDER_CONFIG_RESPONSE) + + provider_config = auth.update_oidc_provider_config( + **self.VALID_CREATE_OPTIONS, app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'PATCH' + mask = ['clientId', 'displayName', 'enabled', 'issuer'] + assert req.url == '{0}/oauthIdpConfigs/oidc.provider?updateMask={1}'.format( + USER_MGT_URL_PREFIX, ','.join(mask)) + got = json.loads(req.body.decode()) + assert got == self.OIDC_CONFIG_REQUEST + + def test_update_minimal(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, OIDC_PROVIDER_CONFIG_RESPONSE) + + provider_config = auth.update_oidc_provider_config( + 'oidc.provider', display_name='oidcProviderName', app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'PATCH' + assert req.url == '{0}/oauthIdpConfigs/oidc.provider?updateMask=displayName'.format( + USER_MGT_URL_PREFIX) + got = json.loads(req.body.decode()) + assert got == {'displayName': 'oidcProviderName'} + + def test_update_empty_values(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, OIDC_PROVIDER_CONFIG_RESPONSE) + + provider_config = auth.update_oidc_provider_config( + 'oidc.provider', display_name=auth.DELETE_ATTRIBUTE, enabled=False, app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'PATCH' + mask = ['displayName', 'enabled'] + assert req.url == '{0}/oauthIdpConfigs/oidc.provider?updateMask={1}'.format( + USER_MGT_URL_PREFIX, ','.join(mask)) + got = json.loads(req.body.decode()) + assert got == {'displayName': None, 'enabled': False} + + @pytest.mark.parametrize('provider_id', INVALID_PROVIDER_IDS + ['saml.provider']) + def test_delete_invalid_provider_id(self, user_mgt_app, provider_id): + with pytest.raises(ValueError) as excinfo: + auth.delete_oidc_provider_config(provider_id, app=user_mgt_app) + + assert str(excinfo.value).startswith('Invalid OIDC provider ID') + + def test_delete(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, '{}') + + auth.delete_oidc_provider_config('oidc.provider', app=user_mgt_app) + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'DELETE' + assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/oauthIdpConfigs/oidc.provider') + + @pytest.mark.parametrize('arg', [None, 'foo', list(), dict(), 0, -1, 101, False]) + def test_invalid_max_results(self, user_mgt_app, arg): + with pytest.raises(ValueError): + auth.list_oidc_provider_configs(max_results=arg, app=user_mgt_app) + + @pytest.mark.parametrize('arg', ['', list(), dict(), 0, -1, 101, False]) + def test_invalid_page_token(self, user_mgt_app, arg): + with pytest.raises(ValueError): + auth.list_oidc_provider_configs(page_token=arg, app=user_mgt_app) + + def test_list_single_page(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, LIST_OIDC_PROVIDER_CONFIGS_RESPONSE) + page = auth.list_oidc_provider_configs(app=user_mgt_app) + + self._assert_page(page) + provider_configs = list(config for config in page.iterate_all()) + assert len(provider_configs) == 2 + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/oauthIdpConfigs?pageSize=100') + + def test_list_multiple_pages(self, user_mgt_app): + sample_response = json.loads(OIDC_PROVIDER_CONFIG_RESPONSE) + configs = _create_list_response(sample_response) + + # Page 1 + response = { + 'oauthIdpConfigs': configs[:2], + 'nextPageToken': 'token' + } + recorder = _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + page = auth.list_oidc_provider_configs(max_results=10, app=user_mgt_app) + + self._assert_page(page, next_page_token='token') + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/oauthIdpConfigs?pageSize=10'.format(USER_MGT_URL_PREFIX) + + # Page 2 (also the last page) + response = {'oauthIdpConfigs': configs[2:]} + recorder = _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + page = page.get_next_page() + + self._assert_page(page, count=1, start=2) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/oauthIdpConfigs?pageSize=10&pageToken=token'.format( + USER_MGT_URL_PREFIX) + + def test_paged_iteration(self, user_mgt_app): + sample_response = json.loads(OIDC_PROVIDER_CONFIG_RESPONSE) + configs = _create_list_response(sample_response) + + # Page 1 + response = { + 'oauthIdpConfigs': configs[:2], + 'nextPageToken': 'token' + } + recorder = _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + page = auth.list_oidc_provider_configs(app=user_mgt_app) + iterator = page.iterate_all() + + for index in range(2): + provider_config = next(iterator) + assert provider_config.provider_id == 'oidc.provider{0}'.format(index) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/oauthIdpConfigs?pageSize=100'.format(USER_MGT_URL_PREFIX) + + # Page 2 (also the last page) + response = {'oauthIdpConfigs': configs[2:]} + recorder = _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + + provider_config = next(iterator) + assert provider_config.provider_id == 'oidc.provider2' + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/oauthIdpConfigs?pageSize=100&pageToken=token'.format( + USER_MGT_URL_PREFIX) + + with pytest.raises(StopIteration): + next(iterator) + + def test_list_empty_response(self, user_mgt_app): + response = {'oauthIdpConfigs': []} + _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + page = auth.list_oidc_provider_configs(app=user_mgt_app) + assert len(page.provider_configs) == 0 + provider_configs = list(config for config in page.iterate_all()) + assert len(provider_configs) == 0 + + def test_list_error(self, user_mgt_app): + _instrument_provider_mgt(user_mgt_app, 500, '{"error":"test"}') + with pytest.raises(exceptions.InternalError) as excinfo: + auth.list_oidc_provider_configs(app=user_mgt_app) + assert str(excinfo.value) == 'Unexpected error response: {"error":"test"}' + + def test_config_not_found(self, user_mgt_app): + _instrument_provider_mgt(user_mgt_app, 500, CONFIG_NOT_FOUND_RESPONSE) + + with pytest.raises(auth.ConfigurationNotFoundError) as excinfo: + auth.get_oidc_provider_config('oidc.provider', app=user_mgt_app) + + error_msg = 'No auth provider found for the given identifier (CONFIGURATION_NOT_FOUND).' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + def _assert_provider_config(self, provider_config, want_id='oidc.provider'): + assert isinstance(provider_config, auth.OIDCProviderConfig) + assert provider_config.provider_id == want_id + assert provider_config.display_name == 'oidcProviderName' + assert provider_config.enabled is True + assert provider_config.issuer == 'https://oidc.com/issuer' + assert provider_config.client_id == 'CLIENT_ID' + + def _assert_page(self, page, count=2, start=0, next_page_token=''): + assert isinstance(page, auth.ListProviderConfigsPage) + index = start + assert len(page.provider_configs) == count + for provider_config in page.provider_configs: + self._assert_provider_config(provider_config, want_id='oidc.provider{0}'.format(index)) + index += 1 + + if next_page_token: + assert page.next_page_token == next_page_token + assert page.has_next_page is True + else: + assert page.next_page_token == '' + assert page.has_next_page is False + assert page.get_next_page() is None + + +class TestSAMLProviderConfig: + + VALID_CREATE_OPTIONS = { + 'provider_id': 'saml.provider', + 'idp_entity_id': 'IDP_ENTITY_ID', + 'sso_url': 'https://example.com/login', + 'x509_certificates': ['CERT1', 'CERT2'], + 'rp_entity_id': 'RP_ENTITY_ID', + 'callback_url': 'https://projectId.firebaseapp.com/__/auth/handler', + 'display_name': 'samlProviderName', + 'enabled': True, + } + + SAML_CONFIG_REQUEST = { + 'displayName': 'samlProviderName', + 'enabled': True, + 'idpConfig': { + 'idpEntityId': 'IDP_ENTITY_ID', + 'ssoUrl': 'https://example.com/login', + 'idpCertificates': [{'x509Certificate': 'CERT1'}, {'x509Certificate': 'CERT2'}] + }, + 'spConfig': { + 'spEntityId': 'RP_ENTITY_ID', + 'callbackUri': 'https://projectId.firebaseapp.com/__/auth/handler', + } + } + + @pytest.mark.parametrize('provider_id', INVALID_PROVIDER_IDS + ['oidc.provider']) + def test_get_invalid_provider_id(self, user_mgt_app, provider_id): + with pytest.raises(ValueError) as excinfo: + auth.get_saml_provider_config(provider_id, app=user_mgt_app) + + assert str(excinfo.value).startswith('Invalid SAML provider ID') + + def test_get(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, SAML_PROVIDER_CONFIG_RESPONSE) + + provider_config = auth.get_saml_provider_config('saml.provider', app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/inboundSamlConfigs/saml.provider') + + @pytest.mark.parametrize('invalid_opts', [ + {'provider_id': None}, {'provider_id': ''}, {'provider_id': 'oidc.provider'}, + {'idp_entity_id': None}, {'idp_entity_id': ''}, + {'sso_url': None}, {'sso_url': ''}, {'sso_url': 'not a url'}, + {'x509_certificates': None}, {'x509_certificates': []}, {'x509_certificates': 'cert'}, + {'x509_certificates': [None]}, {'x509_certificates': ['foo', {}]}, + {'rp_entity_id': None}, {'rp_entity_id': ''}, + {'callback_url': None}, {'callback_url': ''}, {'callback_url': 'not a url'}, + {'display_name': True}, + {'enabled': 'true'}, + ]) + def test_create_invalid_args(self, user_mgt_app, invalid_opts): + options = dict(self.VALID_CREATE_OPTIONS) + options.update(invalid_opts) + with pytest.raises(ValueError): + auth.create_saml_provider_config(**options, app=user_mgt_app) + + def test_create(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, SAML_PROVIDER_CONFIG_RESPONSE) + + provider_config = auth.create_saml_provider_config( + **self.VALID_CREATE_OPTIONS, app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'POST' + assert req.url == '{0}/inboundSamlConfigs?inboundSamlConfigId=saml.provider'.format( + USER_MGT_URL_PREFIX) + got = json.loads(req.body.decode()) + assert got == self.SAML_CONFIG_REQUEST + + def test_create_minimal(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, SAML_PROVIDER_CONFIG_RESPONSE) + options = dict(self.VALID_CREATE_OPTIONS) + del options['display_name'] + del options['enabled'] + want = dict(self.SAML_CONFIG_REQUEST) + del want['displayName'] + del want['enabled'] + + provider_config = auth.create_saml_provider_config(**options, app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'POST' + assert req.url == '{0}/inboundSamlConfigs?inboundSamlConfigId=saml.provider'.format( + USER_MGT_URL_PREFIX) + got = json.loads(req.body.decode()) + assert got == want + + def test_create_empty_values(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, SAML_PROVIDER_CONFIG_RESPONSE) + options = dict(self.VALID_CREATE_OPTIONS) + options['display_name'] = '' + options['enabled'] = False + want = dict(self.SAML_CONFIG_REQUEST) + want['displayName'] = '' + want['enabled'] = False + + provider_config = auth.create_saml_provider_config(**options, app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'POST' + assert req.url == '{0}/inboundSamlConfigs?inboundSamlConfigId=saml.provider'.format( + USER_MGT_URL_PREFIX) + got = json.loads(req.body.decode()) + assert got == want + + @pytest.mark.parametrize('invalid_opts', [ + {}, + {'provider_id': None}, {'provider_id': ''}, {'provider_id': 'oidc.provider'}, + {'idp_entity_id': ''}, + {'sso_url': ''}, {'sso_url': 'not a url'}, + {'x509_certificates': []}, {'x509_certificates': 'cert'}, + {'x509_certificates': [None]}, {'x509_certificates': ['foo', {}]}, + {'rp_entity_id': ''}, + {'callback_url': ''}, {'callback_url': 'not a url'}, + {'display_name': True}, + {'enabled': 'true'}, + ]) + def test_update_invalid_args(self, user_mgt_app, invalid_opts): + options = {'provider_id': 'saml.provider'} + options.update(invalid_opts) + with pytest.raises(ValueError): + auth.update_saml_provider_config(**options, app=user_mgt_app) + + def test_update(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, SAML_PROVIDER_CONFIG_RESPONSE) + + provider_config = auth.update_saml_provider_config( + **self.VALID_CREATE_OPTIONS, app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'PATCH' + mask = [ + 'displayName', 'enabled', 'idpConfig.idpCertificates', 'idpConfig.idpEntityId', + 'idpConfig.ssoUrl', 'spConfig.callbackUri', 'spConfig.spEntityId', + ] + assert req.url == '{0}/inboundSamlConfigs/saml.provider?updateMask={1}'.format( + USER_MGT_URL_PREFIX, ','.join(mask)) + got = json.loads(req.body.decode()) + assert got == self.SAML_CONFIG_REQUEST + + def test_update_minimal(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, SAML_PROVIDER_CONFIG_RESPONSE) + + provider_config = auth.update_saml_provider_config( + 'saml.provider', display_name='samlProviderName', app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'PATCH' + assert req.url == '{0}/inboundSamlConfigs/saml.provider?updateMask=displayName'.format( + USER_MGT_URL_PREFIX) + got = json.loads(req.body.decode()) + assert got == {'displayName': 'samlProviderName'} + + def test_update_empty_values(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, SAML_PROVIDER_CONFIG_RESPONSE) + + provider_config = auth.update_saml_provider_config( + 'saml.provider', display_name=auth.DELETE_ATTRIBUTE, enabled=False, app=user_mgt_app) + + self._assert_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'PATCH' + mask = ['displayName', 'enabled'] + assert req.url == '{0}/inboundSamlConfigs/saml.provider?updateMask={1}'.format( + USER_MGT_URL_PREFIX, ','.join(mask)) + got = json.loads(req.body.decode()) + assert got == {'displayName': None, 'enabled': False} + + @pytest.mark.parametrize('provider_id', INVALID_PROVIDER_IDS + ['oidc.provider']) + def test_delete_invalid_provider_id(self, user_mgt_app, provider_id): + with pytest.raises(ValueError) as excinfo: + auth.delete_saml_provider_config(provider_id, app=user_mgt_app) + + assert str(excinfo.value).startswith('Invalid SAML provider ID') + + def test_delete(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, '{}') + + auth.delete_saml_provider_config('saml.provider', app=user_mgt_app) + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'DELETE' + assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/inboundSamlConfigs/saml.provider') + + def test_config_not_found(self, user_mgt_app): + _instrument_provider_mgt(user_mgt_app, 500, CONFIG_NOT_FOUND_RESPONSE) + + with pytest.raises(auth.ConfigurationNotFoundError) as excinfo: + auth.get_saml_provider_config('saml.provider', app=user_mgt_app) + + error_msg = 'No auth provider found for the given identifier (CONFIGURATION_NOT_FOUND).' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + @pytest.mark.parametrize('arg', [None, 'foo', list(), dict(), 0, -1, 101, False]) + def test_invalid_max_results(self, user_mgt_app, arg): + with pytest.raises(ValueError): + auth.list_saml_provider_configs(max_results=arg, app=user_mgt_app) + + @pytest.mark.parametrize('arg', ['', list(), dict(), 0, -1, 101, False]) + def test_invalid_page_token(self, user_mgt_app, arg): + with pytest.raises(ValueError): + auth.list_saml_provider_configs(page_token=arg, app=user_mgt_app) + + def test_list_single_page(self, user_mgt_app): + recorder = _instrument_provider_mgt(user_mgt_app, 200, LIST_SAML_PROVIDER_CONFIGS_RESPONSE) + page = auth.list_saml_provider_configs(app=user_mgt_app) + + self._assert_page(page) + provider_configs = list(config for config in page.iterate_all()) + assert len(provider_configs) == 2 + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, '/inboundSamlConfigs?pageSize=100') + + def test_list_multiple_pages(self, user_mgt_app): + sample_response = json.loads(SAML_PROVIDER_CONFIG_RESPONSE) + configs = _create_list_response(sample_response) + + # Page 1 + response = { + 'inboundSamlConfigs': configs[:2], + 'nextPageToken': 'token' + } + recorder = _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + page = auth.list_saml_provider_configs(max_results=10, app=user_mgt_app) + + self._assert_page(page, next_page_token='token') + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/inboundSamlConfigs?pageSize=10'.format(USER_MGT_URL_PREFIX) + + # Page 2 (also the last page) + response = {'inboundSamlConfigs': configs[2:]} + recorder = _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + page = page.get_next_page() + + self._assert_page(page, count=1, start=2) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/inboundSamlConfigs?pageSize=10&pageToken=token'.format( + USER_MGT_URL_PREFIX) + + def test_paged_iteration(self, user_mgt_app): + sample_response = json.loads(SAML_PROVIDER_CONFIG_RESPONSE) + configs = _create_list_response(sample_response) + + # Page 1 + response = { + 'inboundSamlConfigs': configs[:2], + 'nextPageToken': 'token' + } + recorder = _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + page = auth.list_saml_provider_configs(app=user_mgt_app) + iterator = page.iterate_all() + + for index in range(2): + provider_config = next(iterator) + assert provider_config.provider_id == 'saml.provider{0}'.format(index) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/inboundSamlConfigs?pageSize=100'.format(USER_MGT_URL_PREFIX) + + # Page 2 (also the last page) + response = {'inboundSamlConfigs': configs[2:]} + recorder = _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + + provider_config = next(iterator) + assert provider_config.provider_id == 'saml.provider2' + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/inboundSamlConfigs?pageSize=100&pageToken=token'.format( + USER_MGT_URL_PREFIX) + + with pytest.raises(StopIteration): + next(iterator) + + def test_list_empty_response(self, user_mgt_app): + response = {'inboundSamlConfigs': []} + _instrument_provider_mgt(user_mgt_app, 200, json.dumps(response)) + page = auth.list_saml_provider_configs(app=user_mgt_app) + assert len(page.provider_configs) == 0 + provider_configs = list(config for config in page.iterate_all()) + assert len(provider_configs) == 0 + + def test_list_error(self, user_mgt_app): + _instrument_provider_mgt(user_mgt_app, 500, '{"error":"test"}') + with pytest.raises(exceptions.InternalError) as excinfo: + auth.list_saml_provider_configs(app=user_mgt_app) + assert str(excinfo.value) == 'Unexpected error response: {"error":"test"}' + + def _assert_provider_config(self, provider_config, want_id='saml.provider'): + assert isinstance(provider_config, auth.SAMLProviderConfig) + assert provider_config.provider_id == want_id + assert provider_config.display_name == 'samlProviderName' + assert provider_config.enabled is True + assert provider_config.idp_entity_id == 'IDP_ENTITY_ID' + assert provider_config.sso_url == 'https://example.com/login' + assert provider_config.x509_certificates == ['CERT1', 'CERT2'] + assert provider_config.rp_entity_id == 'RP_ENTITY_ID' + assert provider_config.callback_url == 'https://projectId.firebaseapp.com/__/auth/handler' + + def _assert_page(self, page, count=2, start=0, next_page_token=''): + assert isinstance(page, auth.ListProviderConfigsPage) + index = start + assert len(page.provider_configs) == count + for provider_config in page.provider_configs: + self._assert_provider_config(provider_config, want_id='saml.provider{0}'.format(index)) + index += 1 + + if next_page_token: + assert page.next_page_token == next_page_token + assert page.has_next_page is True + else: + assert page.next_page_token == '' + assert page.has_next_page is False + assert page.get_next_page() is None + + +def _create_list_response(sample_response, count=3): + configs = [] + for idx in range(count): + config = dict(sample_response) + config['name'] += str(idx) + configs.append(config) + return configs diff --git a/tests/test_tenant_mgt.py b/tests/test_tenant_mgt.py new file mode 100644 index 000000000..f92dd2a83 --- /dev/null +++ b/tests/test_tenant_mgt.py @@ -0,0 +1,1004 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test cases for the firebase_admin.tenant_mgt module.""" + +import json +from urllib import parse + +import pytest + +import firebase_admin +from firebase_admin import auth +from firebase_admin import credentials +from firebase_admin import exceptions +from firebase_admin import tenant_mgt +from firebase_admin import _auth_providers +from firebase_admin import _user_mgt +from tests import testutils +from tests import test_token_gen + + +GET_TENANT_RESPONSE = """{ + "name": "projects/mock-project-id/tenants/tenant-id", + "displayName": "Test Tenant", + "allowPasswordSignup": true, + "enableEmailLinkSignin": true +}""" + +TENANT_NOT_FOUND_RESPONSE = """{ + "error": { + "message": "TENANT_NOT_FOUND" + } +}""" + +LIST_TENANTS_RESPONSE = """{ + "tenants": [ + { + "name": "projects/mock-project-id/tenants/tenant0", + "displayName": "Test Tenant", + "allowPasswordSignup": true, + "enableEmailLinkSignin": true + }, + { + "name": "projects/mock-project-id/tenants/tenant1", + "displayName": "Test Tenant", + "allowPasswordSignup": true, + "enableEmailLinkSignin": true + } + ] +}""" + +LIST_TENANTS_RESPONSE_WITH_TOKEN = """{ + "tenants": [ + { + "name": "projects/mock-project-id/tenants/tenant0" + }, + { + "name": "projects/mock-project-id/tenants/tenant1" + }, + { + "name": "projects/mock-project-id/tenants/tenant2" + } + ], + "nextPageToken": "token" +}""" + +MOCK_GET_USER_RESPONSE = testutils.resource('get_user.json') +MOCK_LIST_USERS_RESPONSE = testutils.resource('list_users.json') + +OIDC_PROVIDER_CONFIG_RESPONSE = testutils.resource('oidc_provider_config.json') +OIDC_PROVIDER_CONFIG_REQUEST = { + 'displayName': 'oidcProviderName', + 'enabled': True, + 'clientId': 'CLIENT_ID', + 'issuer': 'https://oidc.com/issuer', +} + +SAML_PROVIDER_CONFIG_RESPONSE = testutils.resource('saml_provider_config.json') +SAML_PROVIDER_CONFIG_REQUEST = body = { + 'displayName': 'samlProviderName', + 'enabled': True, + 'idpConfig': { + 'idpEntityId': 'IDP_ENTITY_ID', + 'ssoUrl': 'https://example.com/login', + 'idpCertificates': [{'x509Certificate': 'CERT1'}, {'x509Certificate': 'CERT2'}] + }, + 'spConfig': { + 'spEntityId': 'RP_ENTITY_ID', + 'callbackUri': 'https://projectId.firebaseapp.com/__/auth/handler', + } +} + +LIST_OIDC_PROVIDER_CONFIGS_RESPONSE = testutils.resource('list_oidc_provider_configs.json') +LIST_SAML_PROVIDER_CONFIGS_RESPONSE = testutils.resource('list_saml_provider_configs.json') + +INVALID_TENANT_IDS = [None, '', 0, 1, True, False, list(), tuple(), dict()] +INVALID_BOOLEANS = ['', 1, 0, list(), tuple(), dict()] + +USER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v1/projects/mock-project-id' +PROVIDER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v2beta1/projects/mock-project-id' +TENANT_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v2beta1/projects/mock-project-id' + + +@pytest.fixture(scope='module') +def tenant_mgt_app(): + app = firebase_admin.initialize_app( + testutils.MockCredential(), name='tenantMgt', options={'projectId': 'mock-project-id'}) + yield app + firebase_admin.delete_app(app) + + +def _instrument_tenant_mgt(app, status, payload): + service = tenant_mgt._get_tenant_mgt_service(app) + recorder = [] + service.client.session.mount( + tenant_mgt._TenantManagementService.TENANT_MGT_URL, + testutils.MockAdapter(payload, status, recorder)) + return service, recorder + + +def _instrument_user_mgt(client, status, payload): + recorder = [] + user_manager = client._user_manager + user_manager.http_client.session.mount( + _user_mgt.UserManager.ID_TOOLKIT_URL, + testutils.MockAdapter(payload, status, recorder)) + return recorder + + +def _instrument_provider_mgt(client, status, payload): + recorder = [] + provider_manager = client._provider_manager + provider_manager.http_client.session.mount( + _auth_providers.ProviderConfigClient.PROVIDER_CONFIG_URL, + testutils.MockAdapter(payload, status, recorder)) + return recorder + + +class TestTenant: + + @pytest.mark.parametrize('data', [None, 'foo', 0, 1, True, False, list(), tuple(), dict()]) + def test_invalid_data(self, data): + with pytest.raises(ValueError): + tenant_mgt.Tenant(data) + + def test_tenant(self): + data = { + 'name': 'projects/test-project/tenants/tenant-id', + 'displayName': 'Test Tenant', + 'allowPasswordSignup': True, + 'enableEmailLinkSignin': True, + } + tenant = tenant_mgt.Tenant(data) + assert tenant.tenant_id == 'tenant-id' + assert tenant.display_name == 'Test Tenant' + assert tenant.allow_password_sign_up is True + assert tenant.enable_email_link_sign_in is True + + def test_tenant_optional_params(self): + data = { + 'name': 'projects/test-project/tenants/tenant-id', + } + tenant = tenant_mgt.Tenant(data) + assert tenant.tenant_id == 'tenant-id' + assert tenant.display_name is None + assert tenant.allow_password_sign_up is False + assert tenant.enable_email_link_sign_in is False + + +class TestGetTenant: + + @pytest.mark.parametrize('tenant_id', INVALID_TENANT_IDS) + def test_invalid_tenant_id(self, tenant_id, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.get_tenant(tenant_id, app=tenant_mgt_app) + assert str(excinfo.value).startswith('Invalid tenant ID') + + def test_get_tenant(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, GET_TENANT_RESPONSE) + tenant = tenant_mgt.get_tenant('tenant-id', app=tenant_mgt_app) + + _assert_tenant(tenant) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/tenants/tenant-id'.format(TENANT_MGT_URL_PREFIX) + + def test_tenant_not_found(self, tenant_mgt_app): + _instrument_tenant_mgt(tenant_mgt_app, 500, TENANT_NOT_FOUND_RESPONSE) + with pytest.raises(tenant_mgt.TenantNotFoundError) as excinfo: + tenant_mgt.get_tenant('tenant-id', app=tenant_mgt_app) + + error_msg = 'No tenant found for the given identifier (TENANT_NOT_FOUND).' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + +class TestCreateTenant: + + @pytest.mark.parametrize('display_name', [True, False, 1, 0, list(), tuple(), dict()]) + def test_invalid_display_name_type(self, display_name, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.create_tenant(display_name=display_name, app=tenant_mgt_app) + assert str(excinfo.value).startswith('Invalid type for displayName') + + @pytest.mark.parametrize('display_name', ['', 'foo', '1test', 'foo bar', 'a'*21]) + def test_invalid_display_name_value(self, display_name, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.create_tenant(display_name=display_name, app=tenant_mgt_app) + assert str(excinfo.value).startswith('displayName must start') + + @pytest.mark.parametrize('allow', INVALID_BOOLEANS) + def test_invalid_allow_password_sign_up(self, allow, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.create_tenant( + display_name='test', allow_password_sign_up=allow, app=tenant_mgt_app) + assert str(excinfo.value).startswith('Invalid type for allowPasswordSignup') + + @pytest.mark.parametrize('enable', INVALID_BOOLEANS) + def test_invalid_enable_email_link_sign_in(self, enable, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.create_tenant( + display_name='test', enable_email_link_sign_in=enable, app=tenant_mgt_app) + assert str(excinfo.value).startswith('Invalid type for enableEmailLinkSignin') + + def test_create_tenant(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, GET_TENANT_RESPONSE) + tenant = tenant_mgt.create_tenant( + display_name='My-Tenant', allow_password_sign_up=True, enable_email_link_sign_in=True, + app=tenant_mgt_app) + + _assert_tenant(tenant) + self._assert_request(recorder, { + 'displayName': 'My-Tenant', + 'allowPasswordSignup': True, + 'enableEmailLinkSignin': True, + }) + + def test_create_tenant_false_values(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, GET_TENANT_RESPONSE) + tenant = tenant_mgt.create_tenant( + display_name='test', allow_password_sign_up=False, enable_email_link_sign_in=False, + app=tenant_mgt_app) + + _assert_tenant(tenant) + self._assert_request(recorder, { + 'displayName': 'test', + 'allowPasswordSignup': False, + 'enableEmailLinkSignin': False, + }) + + def test_create_tenant_minimal(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, GET_TENANT_RESPONSE) + tenant = tenant_mgt.create_tenant(display_name='test', app=tenant_mgt_app) + + _assert_tenant(tenant) + self._assert_request(recorder, {'displayName': 'test'}) + + def test_error(self, tenant_mgt_app): + _instrument_tenant_mgt(tenant_mgt_app, 500, '{}') + with pytest.raises(exceptions.InternalError) as excinfo: + tenant_mgt.create_tenant(display_name='test', app=tenant_mgt_app) + + error_msg = 'Unexpected error response: {}' + assert excinfo.value.code == exceptions.INTERNAL + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + def _assert_request(self, recorder, body): + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'POST' + assert req.url == '{0}/tenants'.format(TENANT_MGT_URL_PREFIX) + got = json.loads(req.body.decode()) + assert got == body + + +class TestUpdateTenant: + + @pytest.mark.parametrize('tenant_id', INVALID_TENANT_IDS) + def test_invalid_tenant_id(self, tenant_id, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.update_tenant(tenant_id, display_name='My Tenant', app=tenant_mgt_app) + assert str(excinfo.value).startswith('Tenant ID must be a non-empty string') + + @pytest.mark.parametrize('display_name', [True, False, 1, 0, list(), tuple(), dict()]) + def test_invalid_display_name_type(self, display_name, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.update_tenant('tenant-id', display_name=display_name, app=tenant_mgt_app) + assert str(excinfo.value).startswith('Invalid type for displayName') + + @pytest.mark.parametrize('display_name', ['', 'foo', '1test', 'foo bar', 'a'*21]) + def test_invalid_display_name_value(self, display_name, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.update_tenant('tenant-id', display_name=display_name, app=tenant_mgt_app) + assert str(excinfo.value).startswith('displayName must start') + + @pytest.mark.parametrize('allow', INVALID_BOOLEANS) + def test_invalid_allow_password_sign_up(self, allow, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.update_tenant('tenant-id', allow_password_sign_up=allow, app=tenant_mgt_app) + assert str(excinfo.value).startswith('Invalid type for allowPasswordSignup') + + @pytest.mark.parametrize('enable', INVALID_BOOLEANS) + def test_invalid_enable_email_link_sign_in(self, enable, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.update_tenant( + 'tenant-id', enable_email_link_sign_in=enable, app=tenant_mgt_app) + assert str(excinfo.value).startswith('Invalid type for enableEmailLinkSignin') + + def test_update_tenant_no_args(self, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.update_tenant('tenant-id', app=tenant_mgt_app) + assert str(excinfo.value).startswith('At least one parameter must be specified for update') + + def test_update_tenant(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, GET_TENANT_RESPONSE) + tenant = tenant_mgt.update_tenant( + 'tenant-id', display_name='My-Tenant', allow_password_sign_up=True, + enable_email_link_sign_in=True, app=tenant_mgt_app) + + _assert_tenant(tenant) + body = { + 'displayName': 'My-Tenant', + 'allowPasswordSignup': True, + 'enableEmailLinkSignin': True, + } + mask = ['allowPasswordSignup', 'displayName', 'enableEmailLinkSignin'] + self._assert_request(recorder, body, mask) + + def test_update_tenant_false_values(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, GET_TENANT_RESPONSE) + tenant = tenant_mgt.update_tenant( + 'tenant-id', allow_password_sign_up=False, + enable_email_link_sign_in=False, app=tenant_mgt_app) + + _assert_tenant(tenant) + body = { + 'allowPasswordSignup': False, + 'enableEmailLinkSignin': False, + } + mask = ['allowPasswordSignup', 'enableEmailLinkSignin'] + self._assert_request(recorder, body, mask) + + def test_update_tenant_minimal(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, GET_TENANT_RESPONSE) + tenant = tenant_mgt.update_tenant( + 'tenant-id', display_name='My-Tenant', app=tenant_mgt_app) + + _assert_tenant(tenant) + body = {'displayName': 'My-Tenant'} + mask = ['displayName'] + self._assert_request(recorder, body, mask) + + def test_tenant_not_found_error(self, tenant_mgt_app): + _instrument_tenant_mgt(tenant_mgt_app, 500, TENANT_NOT_FOUND_RESPONSE) + with pytest.raises(tenant_mgt.TenantNotFoundError) as excinfo: + tenant_mgt.update_tenant('tenant', display_name='My-Tenant', app=tenant_mgt_app) + + error_msg = 'No tenant found for the given identifier (TENANT_NOT_FOUND).' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + def _assert_request(self, recorder, body, mask): + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'PATCH' + assert req.url == '{0}/tenants/tenant-id?updateMask={1}'.format( + TENANT_MGT_URL_PREFIX, ','.join(mask)) + got = json.loads(req.body.decode()) + assert got == body + + +class TestDeleteTenant: + + @pytest.mark.parametrize('tenant_id', INVALID_TENANT_IDS) + def test_invalid_tenant_id(self, tenant_id, tenant_mgt_app): + with pytest.raises(ValueError) as excinfo: + tenant_mgt.delete_tenant(tenant_id, app=tenant_mgt_app) + assert str(excinfo.value).startswith('Invalid tenant ID') + + def test_delete_tenant(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, '{}') + tenant_mgt.delete_tenant('tenant-id', app=tenant_mgt_app) + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'DELETE' + assert req.url == '{0}/tenants/tenant-id'.format(TENANT_MGT_URL_PREFIX) + + def test_tenant_not_found(self, tenant_mgt_app): + _instrument_tenant_mgt(tenant_mgt_app, 500, TENANT_NOT_FOUND_RESPONSE) + with pytest.raises(tenant_mgt.TenantNotFoundError) as excinfo: + tenant_mgt.delete_tenant('tenant-id', app=tenant_mgt_app) + + error_msg = 'No tenant found for the given identifier (TENANT_NOT_FOUND).' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + +class TestListTenants: + + @pytest.mark.parametrize('arg', [None, 'foo', list(), dict(), 0, -1, 101, False]) + def test_invalid_max_results(self, tenant_mgt_app, arg): + with pytest.raises(ValueError): + tenant_mgt.list_tenants(max_results=arg, app=tenant_mgt_app) + + @pytest.mark.parametrize('arg', ['', list(), dict(), 0, -1, True, False]) + def test_invalid_page_token(self, tenant_mgt_app, arg): + with pytest.raises(ValueError): + tenant_mgt.list_tenants(page_token=arg, app=tenant_mgt_app) + + def test_list_single_page(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, LIST_TENANTS_RESPONSE) + page = tenant_mgt.list_tenants(app=tenant_mgt_app) + self._assert_tenants_page(page) + assert page.next_page_token == '' + assert page.has_next_page is False + assert page.get_next_page() is None + tenants = [tenant for tenant in page.iterate_all()] + assert len(tenants) == 2 + self._assert_request(recorder) + + def test_list_multiple_pages(self, tenant_mgt_app): + # Page 1 + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, LIST_TENANTS_RESPONSE_WITH_TOKEN) + page = tenant_mgt.list_tenants(app=tenant_mgt_app) + assert len(page.tenants) == 3 + assert page.next_page_token == 'token' + assert page.has_next_page is True + self._assert_request(recorder) + + # Page 2 (also the last page) + response = {'tenants': [{'name': 'projects/mock-project-id/tenants/tenant3'}]} + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, json.dumps(response)) + page = page.get_next_page() + assert len(page.tenants) == 1 + assert page.next_page_token == '' + assert page.has_next_page is False + assert page.get_next_page() is None + self._assert_request(recorder, {'pageSize': '100', 'pageToken': 'token'}) + + def test_list_tenants_paged_iteration(self, tenant_mgt_app): + # Page 1 + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, LIST_TENANTS_RESPONSE_WITH_TOKEN) + page = tenant_mgt.list_tenants(app=tenant_mgt_app) + iterator = page.iterate_all() + for index in range(3): + tenant = next(iterator) + assert tenant.tenant_id == 'tenant{0}'.format(index) + self._assert_request(recorder) + + # Page 2 (also the last page) + response = {'tenants': [{'name': 'projects/mock-project-id/tenants/tenant3'}]} + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, json.dumps(response)) + tenant = next(iterator) + assert tenant.tenant_id == 'tenant3' + + with pytest.raises(StopIteration): + next(iterator) + self._assert_request(recorder, {'pageSize': '100', 'pageToken': 'token'}) + + def test_list_tenants_iterator_state(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, LIST_TENANTS_RESPONSE) + page = tenant_mgt.list_tenants(app=tenant_mgt_app) + + # Advance iterator. + iterator = page.iterate_all() + tenant = next(iterator) + assert tenant.tenant_id == 'tenant0' + + # Iterator should resume from where left off. + tenant = next(iterator) + assert tenant.tenant_id == 'tenant1' + + with pytest.raises(StopIteration): + next(iterator) + self._assert_request(recorder) + + def test_list_tenants_stop_iteration(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, LIST_TENANTS_RESPONSE) + page = tenant_mgt.list_tenants(app=tenant_mgt_app) + iterator = page.iterate_all() + tenants = [tenant for tenant in iterator] + assert len(tenants) == 2 + + with pytest.raises(StopIteration): + next(iterator) + self._assert_request(recorder) + + def test_list_tenants_no_tenants_response(self, tenant_mgt_app): + response = {'tenants': []} + _instrument_tenant_mgt(tenant_mgt_app, 200, json.dumps(response)) + page = tenant_mgt.list_tenants(app=tenant_mgt_app) + assert len(page.tenants) == 0 + tenants = [tenant for tenant in page.iterate_all()] + assert len(tenants) == 0 + + def test_list_tenants_with_max_results(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, LIST_TENANTS_RESPONSE) + page = tenant_mgt.list_tenants(max_results=50, app=tenant_mgt_app) + self._assert_tenants_page(page) + self._assert_request(recorder, {'pageSize' : '50'}) + + def test_list_tenants_with_all_args(self, tenant_mgt_app): + _, recorder = _instrument_tenant_mgt(tenant_mgt_app, 200, LIST_TENANTS_RESPONSE) + page = tenant_mgt.list_tenants(page_token='foo', max_results=50, app=tenant_mgt_app) + self._assert_tenants_page(page) + self._assert_request(recorder, {'pageToken' : 'foo', 'pageSize' : '50'}) + + def test_list_tenants_error(self, tenant_mgt_app): + _instrument_tenant_mgt(tenant_mgt_app, 500, '{"error":"test"}') + with pytest.raises(exceptions.InternalError) as excinfo: + tenant_mgt.list_tenants(app=tenant_mgt_app) + assert str(excinfo.value) == 'Unexpected error response: {"error":"test"}' + + def _assert_tenants_page(self, page): + assert isinstance(page, tenant_mgt.ListTenantsPage) + assert len(page.tenants) == 2 + for idx, tenant in enumerate(page.tenants): + _assert_tenant(tenant, 'tenant{0}'.format(idx)) + + def _assert_request(self, recorder, expected=None): + if expected is None: + expected = {'pageSize' : '100'} + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + request = dict(parse.parse_qsl(parse.urlsplit(req.url).query)) + assert request == expected + + +class TestAuthForTenant: + + @pytest.mark.parametrize('tenant_id', INVALID_TENANT_IDS) + def test_invalid_tenant_id(self, tenant_id, tenant_mgt_app): + with pytest.raises(ValueError): + tenant_mgt.auth_for_tenant(tenant_id, app=tenant_mgt_app) + + def test_client(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant1', app=tenant_mgt_app) + assert client.tenant_id == 'tenant1' + + def test_client_reuse(self, tenant_mgt_app): + client1 = tenant_mgt.auth_for_tenant('tenant1', app=tenant_mgt_app) + client2 = tenant_mgt.auth_for_tenant('tenant1', app=tenant_mgt_app) + client3 = tenant_mgt.auth_for_tenant('tenant2', app=tenant_mgt_app) + assert client1 is client2 + assert client1 is not client3 + + +class TestTenantAwareUserManagement: + + def test_get_user(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, MOCK_GET_USER_RESPONSE) + + user = client.get_user('testuser') + + assert isinstance(user, auth.UserRecord) + assert user.uid == 'testuser' + assert user.email == 'testuser@example.com' + self._assert_request(recorder, '/accounts:lookup', {'localId': ['testuser']}) + + def test_get_user_by_email(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, MOCK_GET_USER_RESPONSE) + + user = client.get_user_by_email('testuser@example.com') + + assert isinstance(user, auth.UserRecord) + assert user.uid == 'testuser' + assert user.email == 'testuser@example.com' + self._assert_request(recorder, '/accounts:lookup', {'email': ['testuser@example.com']}) + + def test_get_user_by_phone_number(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, MOCK_GET_USER_RESPONSE) + + user = client.get_user_by_phone_number('+1234567890') + + assert isinstance(user, auth.UserRecord) + assert user.uid == 'testuser' + assert user.email == 'testuser@example.com' + self._assert_request(recorder, '/accounts:lookup', {'phoneNumber': ['+1234567890']}) + + def test_create_user(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{"localId":"testuser"}') + + uid = client._user_manager.create_user() + + assert uid == 'testuser' + self._assert_request(recorder, '/accounts', {}) + + def test_update_user(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{"localId":"testuser"}') + + uid = client._user_manager.update_user('testuser', email='testuser@example.com') + + assert uid == 'testuser' + self._assert_request(recorder, '/accounts:update', { + 'localId': 'testuser', + 'email': 'testuser@example.com', + }) + + def test_delete_user(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{"kind":"deleteresponse"}') + + client.delete_user('testuser') + + self._assert_request(recorder, '/accounts:delete', {'localId': 'testuser'}) + + def test_set_custom_user_claims(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{"localId":"testuser"}') + claims = {'admin': True} + + client.set_custom_user_claims('testuser', claims) + + self._assert_request(recorder, '/accounts:update', { + 'localId': 'testuser', + 'customAttributes': json.dumps(claims), + }) + + def test_revoke_refresh_tokens(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{"localId":"testuser"}') + + client.revoke_refresh_tokens('testuser') + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'POST' + assert req.url == '{0}/tenants/tenant-id/accounts:update'.format( + USER_MGT_URL_PREFIX) + body = json.loads(req.body.decode()) + assert body['localId'] == 'testuser' + assert 'validSince' in body + + def test_list_users(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, MOCK_LIST_USERS_RESPONSE) + + page = client.list_users() + + assert isinstance(page, auth.ListUsersPage) + assert page.next_page_token == '' + assert page.has_next_page is False + assert page.get_next_page() is None + users = list(user for user in page.iterate_all()) + assert len(users) == 2 + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/tenants/tenant-id/accounts:batchGet?maxResults=1000'.format( + USER_MGT_URL_PREFIX) + + def test_import_users(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{}') + users = [ + auth.ImportUserRecord(uid='user1'), + auth.ImportUserRecord(uid='user2'), + ] + + result = client.import_users(users) + + assert isinstance(result, auth.UserImportResult) + assert result.success_count == 2 + assert result.failure_count == 0 + assert result.errors == [] + self._assert_request(recorder, '/accounts:batchCreate', { + 'users': [{'localId': 'user1'}, {'localId': 'user2'}], + }) + + def test_generate_password_reset_link(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{"oobLink":"https://testlink"}') + + link = client.generate_password_reset_link('test@test.com') + + assert link == 'https://testlink' + self._assert_request(recorder, '/accounts:sendOobCode', { + 'email': 'test@test.com', + 'requestType': 'PASSWORD_RESET', + 'returnOobLink': True, + }) + + def test_generate_email_verification_link(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{"oobLink":"https://testlink"}') + + link = client.generate_email_verification_link('test@test.com') + + assert link == 'https://testlink' + self._assert_request(recorder, '/accounts:sendOobCode', { + 'email': 'test@test.com', + 'requestType': 'VERIFY_EMAIL', + 'returnOobLink': True, + }) + + def test_generate_sign_in_with_email_link(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_user_mgt(client, 200, '{"oobLink":"https://testlink"}') + settings = auth.ActionCodeSettings(url='http://localhost') + + link = client.generate_sign_in_with_email_link('test@test.com', settings) + + assert link == 'https://testlink' + self._assert_request(recorder, '/accounts:sendOobCode', { + 'email': 'test@test.com', + 'requestType': 'EMAIL_SIGNIN', + 'returnOobLink': True, + 'continueUrl': 'http://localhost', + }) + + def test_get_oidc_provider_config(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_provider_mgt(client, 200, OIDC_PROVIDER_CONFIG_RESPONSE) + + provider_config = client.get_oidc_provider_config('oidc.provider') + + self._assert_oidc_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/tenants/tenant-id/oauthIdpConfigs/oidc.provider'.format( + PROVIDER_MGT_URL_PREFIX) + + def test_create_oidc_provider_config(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_provider_mgt(client, 200, OIDC_PROVIDER_CONFIG_RESPONSE) + + provider_config = client.create_oidc_provider_config( + 'oidc.provider', client_id='CLIENT_ID', issuer='https://oidc.com/issuer', + display_name='oidcProviderName', enabled=True) + + self._assert_oidc_provider_config(provider_config) + self._assert_request( + recorder, '/oauthIdpConfigs?oauthIdpConfigId=oidc.provider', + OIDC_PROVIDER_CONFIG_REQUEST, prefix=PROVIDER_MGT_URL_PREFIX) + + def test_update_oidc_provider_config(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_provider_mgt(client, 200, OIDC_PROVIDER_CONFIG_RESPONSE) + + provider_config = client.update_oidc_provider_config( + 'oidc.provider', client_id='CLIENT_ID', issuer='https://oidc.com/issuer', + display_name='oidcProviderName', enabled=True) + + self._assert_oidc_provider_config(provider_config) + mask = ['clientId', 'displayName', 'enabled', 'issuer'] + url = '/oauthIdpConfigs/oidc.provider?updateMask={0}'.format(','.join(mask)) + self._assert_request( + recorder, url, OIDC_PROVIDER_CONFIG_REQUEST, method='PATCH', + prefix=PROVIDER_MGT_URL_PREFIX) + + def test_delete_oidc_provider_config(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_provider_mgt(client, 200, '{}') + + client.delete_oidc_provider_config('oidc.provider') + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'DELETE' + assert req.url == '{0}/tenants/tenant-id/oauthIdpConfigs/oidc.provider'.format( + PROVIDER_MGT_URL_PREFIX) + + def test_list_oidc_provider_configs(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_provider_mgt(client, 200, LIST_OIDC_PROVIDER_CONFIGS_RESPONSE) + + page = client.list_oidc_provider_configs() + + assert isinstance(page, auth.ListProviderConfigsPage) + index = 0 + assert len(page.provider_configs) == 2 + for provider_config in page.provider_configs: + self._assert_oidc_provider_config( + provider_config, want_id='oidc.provider{0}'.format(index)) + index += 1 + + assert page.next_page_token == '' + assert page.has_next_page is False + assert page.get_next_page() is None + provider_configs = list(config for config in page.iterate_all()) + assert len(provider_configs) == 2 + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}{1}'.format( + PROVIDER_MGT_URL_PREFIX, '/tenants/tenant-id/oauthIdpConfigs?pageSize=100') + + def test_get_saml_provider_config(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_provider_mgt(client, 200, SAML_PROVIDER_CONFIG_RESPONSE) + + provider_config = client.get_saml_provider_config('saml.provider') + + self._assert_saml_provider_config(provider_config) + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}/tenants/tenant-id/inboundSamlConfigs/saml.provider'.format( + PROVIDER_MGT_URL_PREFIX) + + def test_create_saml_provider_config(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_provider_mgt(client, 200, SAML_PROVIDER_CONFIG_RESPONSE) + + provider_config = client.create_saml_provider_config( + 'saml.provider', idp_entity_id='IDP_ENTITY_ID', sso_url='https://example.com/login', + x509_certificates=['CERT1', 'CERT2'], rp_entity_id='RP_ENTITY_ID', + callback_url='https://projectId.firebaseapp.com/__/auth/handler', + display_name='samlProviderName', enabled=True) + + self._assert_saml_provider_config(provider_config) + self._assert_request( + recorder, '/inboundSamlConfigs?inboundSamlConfigId=saml.provider', + SAML_PROVIDER_CONFIG_REQUEST, prefix=PROVIDER_MGT_URL_PREFIX) + + def test_update_saml_provider_config(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_provider_mgt(client, 200, SAML_PROVIDER_CONFIG_RESPONSE) + + provider_config = client.update_saml_provider_config( + 'saml.provider', idp_entity_id='IDP_ENTITY_ID', sso_url='https://example.com/login', + x509_certificates=['CERT1', 'CERT2'], rp_entity_id='RP_ENTITY_ID', + callback_url='https://projectId.firebaseapp.com/__/auth/handler', + display_name='samlProviderName', enabled=True) + + self._assert_saml_provider_config(provider_config) + mask = [ + 'displayName', 'enabled', 'idpConfig.idpCertificates', 'idpConfig.idpEntityId', + 'idpConfig.ssoUrl', 'spConfig.callbackUri', 'spConfig.spEntityId', + ] + url = '/inboundSamlConfigs/saml.provider?updateMask={0}'.format(','.join(mask)) + self._assert_request( + recorder, url, SAML_PROVIDER_CONFIG_REQUEST, method='PATCH', + prefix=PROVIDER_MGT_URL_PREFIX) + + def test_delete_saml_provider_config(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_provider_mgt(client, 200, '{}') + + client.delete_saml_provider_config('saml.provider') + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'DELETE' + assert req.url == '{0}/tenants/tenant-id/inboundSamlConfigs/saml.provider'.format( + PROVIDER_MGT_URL_PREFIX) + + def test_list_saml_provider_configs(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + recorder = _instrument_provider_mgt(client, 200, LIST_SAML_PROVIDER_CONFIGS_RESPONSE) + + page = client.list_saml_provider_configs() + + assert isinstance(page, auth.ListProviderConfigsPage) + index = 0 + assert len(page.provider_configs) == 2 + for provider_config in page.provider_configs: + self._assert_saml_provider_config( + provider_config, want_id='saml.provider{0}'.format(index)) + index += 1 + + assert page.next_page_token == '' + assert page.has_next_page is False + assert page.get_next_page() is None + provider_configs = list(config for config in page.iterate_all()) + assert len(provider_configs) == 2 + + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'GET' + assert req.url == '{0}{1}'.format( + PROVIDER_MGT_URL_PREFIX, '/tenants/tenant-id/inboundSamlConfigs?pageSize=100') + + def test_tenant_not_found(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('tenant-id', app=tenant_mgt_app) + _instrument_user_mgt(client, 500, TENANT_NOT_FOUND_RESPONSE) + with pytest.raises(tenant_mgt.TenantNotFoundError) as excinfo: + client.get_user('testuser') + + error_msg = 'No tenant found for the given identifier (TENANT_NOT_FOUND).' + assert excinfo.value.code == exceptions.NOT_FOUND + assert str(excinfo.value) == error_msg + assert excinfo.value.http_response is not None + assert excinfo.value.cause is not None + + def _assert_request( + self, recorder, want_url, want_body, method='POST', prefix=USER_MGT_URL_PREFIX): + assert len(recorder) == 1 + req = recorder[0] + assert req.method == method + assert req.url == '{0}/tenants/tenant-id{1}'.format(prefix, want_url) + body = json.loads(req.body.decode()) + assert body == want_body + + def _assert_oidc_provider_config(self, provider_config, want_id='oidc.provider'): + assert isinstance(provider_config, auth.OIDCProviderConfig) + assert provider_config.provider_id == want_id + assert provider_config.display_name == 'oidcProviderName' + assert provider_config.enabled is True + assert provider_config.client_id == 'CLIENT_ID' + assert provider_config.issuer == 'https://oidc.com/issuer' + + def _assert_saml_provider_config(self, provider_config, want_id='saml.provider'): + assert isinstance(provider_config, auth.SAMLProviderConfig) + assert provider_config.provider_id == want_id + assert provider_config.display_name == 'samlProviderName' + assert provider_config.enabled is True + assert provider_config.idp_entity_id == 'IDP_ENTITY_ID' + assert provider_config.sso_url == 'https://example.com/login' + assert provider_config.x509_certificates == ['CERT1', 'CERT2'] + assert provider_config.rp_entity_id == 'RP_ENTITY_ID' + assert provider_config.callback_url == 'https://projectId.firebaseapp.com/__/auth/handler' + + +class TestVerifyIdToken: + + def test_valid_token(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('test-tenant', app=tenant_mgt_app) + client._token_verifier.request = test_token_gen.MOCK_REQUEST + + claims = client.verify_id_token(test_token_gen.TEST_ID_TOKEN_WITH_TENANT) + + assert claims['admin'] is True + assert claims['uid'] == claims['sub'] + assert claims['firebase']['tenant'] == 'test-tenant' + + def test_invalid_tenant_id(self, tenant_mgt_app): + client = tenant_mgt.auth_for_tenant('other-tenant', app=tenant_mgt_app) + client._token_verifier.request = test_token_gen.MOCK_REQUEST + + with pytest.raises(tenant_mgt.TenantIdMismatchError) as excinfo: + client.verify_id_token(test_token_gen.TEST_ID_TOKEN_WITH_TENANT) + + assert 'Invalid tenant ID: test-tenant' in str(excinfo.value) + assert isinstance(excinfo.value, exceptions.InvalidArgumentError) + assert excinfo.value.cause is None + assert excinfo.value.http_response is None + + +@pytest.fixture(scope='module') +def tenant_aware_custom_token_app(): + cred = credentials.Certificate(testutils.resource_filename('service_account.json')) + app = firebase_admin.initialize_app(cred, name='tenantAwareCustomToken') + yield app + firebase_admin.delete_app(app) + + +class TestCreateCustomToken: + + def test_custom_token(self, tenant_aware_custom_token_app): + client = tenant_mgt.auth_for_tenant('test-tenant', app=tenant_aware_custom_token_app) + + custom_token = client.create_custom_token('user1') + + test_token_gen.verify_custom_token( + custom_token, expected_claims=None, tenant_id='test-tenant') + + def test_custom_token_with_claims(self, tenant_aware_custom_token_app): + client = tenant_mgt.auth_for_tenant('test-tenant', app=tenant_aware_custom_token_app) + claims = {'admin': True} + + custom_token = client.create_custom_token('user1', claims) + + test_token_gen.verify_custom_token( + custom_token, expected_claims=claims, tenant_id='test-tenant') + + +def _assert_tenant(tenant, tenant_id='tenant-id'): + assert isinstance(tenant, tenant_mgt.Tenant) + assert tenant.tenant_id == tenant_id + assert tenant.display_name == 'Test Tenant' + assert tenant.allow_password_sign_up is True + assert tenant.enable_email_link_sign_in is True diff --git a/tests/test_token_gen.py b/tests/test_token_gen.py index 439c1ba6e..f88c87ff4 100644 --- a/tests/test_token_gen.py +++ b/tests/test_token_gen.py @@ -66,7 +66,7 @@ def _merge_jwt_claims(defaults, overrides): del defaults[key] return defaults -def _verify_custom_token(custom_token, expected_claims): +def verify_custom_token(custom_token, expected_claims, tenant_id=None): assert isinstance(custom_token, bytes) token = google.oauth2.id_token.verify_token( custom_token, @@ -75,6 +75,11 @@ def _verify_custom_token(custom_token, expected_claims): assert token['uid'] == MOCK_UID assert token['iss'] == MOCK_SERVICE_ACCOUNT_EMAIL assert token['sub'] == MOCK_SERVICE_ACCOUNT_EMAIL + if tenant_id is None: + assert 'tenant_id' not in token + else: + assert token['tenant_id'] == tenant_id + header = jwt.decode_header(custom_token) assert header.get('typ') == 'JWT' assert header.get('alg') == 'RS256' @@ -94,6 +99,9 @@ def _get_id_token(payload_overrides=None, header_overrides=None): 'exp': int(time.time()) + 3600, 'sub': '1234567890', 'admin': True, + 'firebase': { + 'sign_in_provider': 'provider', + }, } if header_overrides: headers = _merge_jwt_claims(headers, header_overrides) @@ -109,21 +117,21 @@ def _get_session_cookie(payload_overrides=None, header_overrides=None): return _get_id_token(payload_overrides, header_overrides) def _instrument_user_manager(app, status, payload): - auth_service = auth._get_auth_service(app) - user_manager = auth_service.user_manager + client = auth._get_client(app) + user_manager = client._user_manager recorder = [] - user_manager._client.session.mount( - auth._AuthService.ID_TOOLKIT_URL, + user_manager.http_client.session.mount( + _token_gen.TokenGenerator.ID_TOOLKIT_URL, testutils.MockAdapter(payload, status, recorder)) return user_manager, recorder def _overwrite_cert_request(app, request): - auth_service = auth._get_auth_service(app) - auth_service.token_verifier.request = request + client = auth._get_client(app) + client._token_verifier.request = request def _overwrite_iam_request(app, request): - auth_service = auth._get_auth_service(app) - auth_service.token_generator.request = request + client = auth._get_client(app) + client._token_generator.request = request @pytest.fixture(scope='module') def auth_app(): @@ -195,7 +203,7 @@ class TestCreateCustomToken: def test_valid_params(self, auth_app, values): user, claims = values custom_token = auth.create_custom_token(user, claims, app=auth_app) - _verify_custom_token(custom_token, claims) + verify_custom_token(custom_token, claims) @pytest.mark.parametrize('values', invalid_args.values(), ids=list(invalid_args)) def test_invalid_params(self, auth_app, values): @@ -245,8 +253,9 @@ def test_sign_with_discovered_service_account(self): try: _overwrite_iam_request(app, request) # Force initialization of the signing provider. This will invoke the Metadata service. - auth_service = auth._get_auth_service(app) - assert auth_service.token_generator.signing_provider is not None + client = auth._get_client(app) + assert client._token_generator.signing_provider is not None + # Now invoke the IAM signer. signature = base64.b64encode(b'test').decode() request.response = testutils.MockResponse( @@ -346,6 +355,11 @@ def test_unexpected_response(self, user_mgt_app): MOCK_GET_USER_RESPONSE = testutils.resource('get_user.json') TEST_ID_TOKEN = _get_id_token() +TEST_ID_TOKEN_WITH_TENANT = _get_id_token({ + 'firebase': { + 'tenant': 'test-tenant', + } +}) TEST_SESSION_COOKIE = _get_session_cookie() @@ -380,6 +394,14 @@ def test_valid_token(self, user_mgt_app, id_token): claims = auth.verify_id_token(id_token, app=user_mgt_app) assert claims['admin'] is True assert claims['uid'] == claims['sub'] + assert claims['firebase']['sign_in_provider'] == 'provider' + + def test_valid_token_with_tenant(self, user_mgt_app): + _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) + claims = auth.verify_id_token(TEST_ID_TOKEN_WITH_TENANT, app=user_mgt_app) + assert claims['admin'] is True + assert claims['uid'] == claims['sub'] + assert claims['firebase']['tenant'] == 'test-tenant' @pytest.mark.parametrize('id_token', valid_tokens.values(), ids=list(valid_tokens)) def test_valid_token_check_revoked(self, user_mgt_app, id_token): diff --git a/tests/test_user_mgt.py b/tests/test_user_mgt.py index 958bbf9c4..c7b2de496 100644 --- a/tests/test_user_mgt.py +++ b/tests/test_user_mgt.py @@ -50,6 +50,9 @@ } MOCK_ACTION_CODE_SETTINGS = auth.ActionCodeSettings(**MOCK_ACTION_CODE_DATA) +USER_MGT_URL_PREFIX = 'https://identitytoolkit.googleapis.com/v1/projects/mock-project-id' + + @pytest.fixture(scope='module') def user_mgt_app(): app = firebase_admin.initialize_app(testutils.MockCredential(), name='userMgt', @@ -58,11 +61,11 @@ def user_mgt_app(): firebase_admin.delete_app(app) def _instrument_user_manager(app, status, payload): - auth_service = auth._get_auth_service(app) - user_manager = auth_service.user_manager + client = auth._get_client(app) + user_manager = client._user_manager recorder = [] - user_manager._client.session.mount( - auth._AuthService.ID_TOOLKIT_URL, + user_manager.http_client.session.mount( + _user_mgt.UserManager.ID_TOOLKIT_URL, testutils.MockAdapter(payload, status, recorder)) return user_manager, recorder @@ -78,6 +81,7 @@ def _check_user_record(user, expected_uid='testuser'): assert user.user_metadata.creation_timestamp == 1234567890000 assert user.user_metadata.last_sign_in_timestamp is None assert user.provider_id == 'firebase' + assert user.tenant_id is None claims = user.custom_claims assert claims['admin'] is True @@ -101,17 +105,27 @@ def _check_user_record(user, expected_uid='testuser'): assert provider.provider_id == 'phone' +def _check_request(recorder, want_url, want_body=None): + assert len(recorder) == 1 + req = recorder[0] + assert req.method == 'POST' + assert req.url == '{0}{1}'.format(USER_MGT_URL_PREFIX, want_url) + if want_body: + body = json.loads(req.body.decode()) + assert body == want_body + + class TestAuthServiceInitialization: def test_default_timeout(self, user_mgt_app): - auth_service = auth._get_auth_service(user_mgt_app) - user_manager = auth_service.user_manager - assert user_manager._client.timeout == _http_client.DEFAULT_TIMEOUT_SECONDS + client = auth._get_client(user_mgt_app) + user_manager = client._user_manager + assert user_manager.http_client.timeout == _http_client.DEFAULT_TIMEOUT_SECONDS def test_fail_on_no_project_id(self): app = firebase_admin.initialize_app(testutils.MockCredential(), name='userMgt2') with pytest.raises(ValueError): - auth._get_auth_service(app) + auth._get_client(app) firebase_admin.delete_app(app) @@ -194,6 +208,10 @@ def test_no_tokens_valid_after_time(self): user = auth.UserRecord({'localId' : 'user'}) assert user.tokens_valid_after_timestamp == 0 + def test_tenant_id(self): + user = auth.UserRecord({'localId' : 'user', 'tenantId': 'test-tenant'}) + assert user.tenant_id == 'test-tenant' + class TestGetUser: @@ -203,8 +221,9 @@ def test_invalid_get_user(self, arg, user_mgt_app): auth.get_user(arg, app=user_mgt_app) def test_get_user(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 200, MOCK_GET_USER_RESPONSE) + _, recorder = _instrument_user_manager(user_mgt_app, 200, MOCK_GET_USER_RESPONSE) _check_user_record(auth.get_user('testuser', user_mgt_app)) + _check_request(recorder, '/accounts:lookup', {'localId': ['testuser']}) @pytest.mark.parametrize('arg', INVALID_STRINGS + ['not-an-email']) def test_invalid_get_user_by_email(self, arg, user_mgt_app): @@ -212,8 +231,9 @@ def test_invalid_get_user_by_email(self, arg, user_mgt_app): auth.get_user_by_email(arg, app=user_mgt_app) def test_get_user_by_email(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 200, MOCK_GET_USER_RESPONSE) + _, recorder = _instrument_user_manager(user_mgt_app, 200, MOCK_GET_USER_RESPONSE) _check_user_record(auth.get_user_by_email('testuser@example.com', user_mgt_app)) + _check_request(recorder, '/accounts:lookup', {'email': ['testuser@example.com']}) @pytest.mark.parametrize('arg', INVALID_STRINGS + ['not-a-phone']) def test_invalid_get_user_by_phone(self, arg, user_mgt_app): @@ -221,8 +241,9 @@ def test_invalid_get_user_by_phone(self, arg, user_mgt_app): auth.get_user_by_phone_number(arg, app=user_mgt_app) def test_get_user_by_phone(self, user_mgt_app): - _instrument_user_manager(user_mgt_app, 200, MOCK_GET_USER_RESPONSE) + _, recorder = _instrument_user_manager(user_mgt_app, 200, MOCK_GET_USER_RESPONSE) _check_user_record(auth.get_user_by_phone_number('+1234567890', user_mgt_app)) + _check_request(recorder, '/accounts:lookup', {'phoneNumber': ['+1234567890']}) def test_get_user_non_existing(self, user_mgt_app): _instrument_user_manager(user_mgt_app, 200, '{"users":[]}') @@ -1050,7 +1071,7 @@ def test_import_users(self, user_mgt_app): assert result.failure_count == 0 assert result.errors == [] expected = {'users': [{'localId': 'user1'}, {'localId': 'user2'}]} - self._check_rpc_calls(recorder, expected) + _check_request(recorder, '/accounts:batchCreate', expected) def test_import_users_error(self, user_mgt_app): _, recorder = _instrument_user_manager(user_mgt_app, 200, """{"error": [ @@ -1073,7 +1094,7 @@ def test_import_users_error(self, user_mgt_app): assert err.index == 2 assert err.reason == 'Another error occured in user3' expected = {'users': [{'localId': 'user1'}, {'localId': 'user2'}, {'localId': 'user3'}]} - self._check_rpc_calls(recorder, expected) + _check_request(recorder, '/accounts:batchCreate', expected) def test_import_users_missing_required_hash(self, user_mgt_app): users = [ @@ -1106,7 +1127,7 @@ def test_import_users_with_hash(self, user_mgt_app): 'memoryCost': 14, 'saltSeparator': _user_import.b64_encode(b'sep'), } - self._check_rpc_calls(recorder, expected) + _check_request(recorder, '/accounts:batchCreate', expected) def test_import_users_http_error(self, user_mgt_app): _instrument_user_manager(user_mgt_app, 401, '{"error": {"message": "ERROR_CODE"}}') @@ -1127,11 +1148,6 @@ def test_import_users_unexpected_response(self, user_mgt_app): with pytest.raises(auth.UnexpectedResponseError): auth.import_users(users, app=user_mgt_app) - def _check_rpc_calls(self, recorder, expected): - assert len(recorder) == 1 - request = json.loads(recorder[0].body.decode()) - assert request == expected - class TestRevokeRefreshTokkens: @@ -1301,8 +1317,8 @@ def test_bad_settings_data(self, user_mgt_app, func): def test_bad_action_type(self, user_mgt_app): with pytest.raises(ValueError): - auth._get_auth_service(user_mgt_app) \ - .user_manager \ + auth._get_client(user_mgt_app) \ + ._user_manager \ .generate_email_action_link('BAD_TYPE', 'test@test.com', action_code_settings=MOCK_ACTION_CODE_SETTINGS)