From 31e8186a48b94861d2bbc5a94887b586ee1aad0c Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Tue, 8 Dec 2020 17:36:23 -0800 Subject: [PATCH 1/2] fix(rtdb): Support parsing non-US RTDB instance URLs --- firebase_admin/db.py | 81 ++++++++++++++++++-------------------------- tests/test_db.py | 66 +++++++++++++++++++++++------------- 2 files changed, 75 insertions(+), 72 deletions(-) diff --git a/firebase_admin/db.py b/firebase_admin/db.py index be2b9c917..e60704f5c 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -768,10 +768,10 @@ def __init__(self, app): self._credential = app.credential db_url = app.options.get('databaseURL') if db_url: - _DatabaseService._parse_db_url(db_url) # Just for validation. self._db_url = db_url else: self._db_url = None + auth_override = _DatabaseService._get_auth_override(app) if auth_override not in (self._DEFAULT_AUTH_OVERRIDE, {}): self._auth_override = json.dumps(auth_override, separators=(',', ':')) @@ -795,15 +795,27 @@ def get_client(self, db_url=None): if db_url is None: db_url = self._db_url - base_url, namespace = _DatabaseService._parse_db_url(db_url, self._emulator_host) - if base_url == 'https://{0}.firebaseio.com'.format(namespace): - # Production base_url. No need to specify namespace in query params. - params = {} - credential = self._credential.get_credential() - else: - # Emulator base_url. Use fake credentials and specify ?ns=foo in query params. + if not db_url or not isinstance(db_url, str): + raise ValueError( + 'Invalid database URL: "{0}". Database URL must be a non-empty ' + 'URL string.'.format(db_url)) + + parsed_url = parse.urlparse(db_url) + if not parsed_url.netloc: + raise ValueError( + 'Invalid database URL: "{0}". Database URL must be a wellformed ' + 'URL string.'.format(db_url)) + + base_url = 'https://{0}'.format(parsed_url.netloc) + params = {} + credential = self._credential.get_credential() + + emulator_config = self._get_emulator_config(parsed_url) + if emulator_config: + base_url = emulator_config.base_url + params['ns'] = emulator_config.namespace credential = _EmulatorAdminCredentials() - params = {'ns': namespace} + if self._auth_override: params['auth_variable_override'] = self._auth_override @@ -813,47 +825,20 @@ def get_client(self, db_url=None): self._clients[client_cache_key] = client return self._clients[client_cache_key] - @classmethod - def _parse_db_url(cls, url, emulator_host=None): - """Parses (base_url, namespace) from a database URL. - - The input can be either a production URL (https://foo-bar.firebaseio.com/) - or an Emulator URL (http://localhost:8080/?ns=foo-bar). In case of Emulator - URL, the namespace is extracted from the query param ns. The resulting - base_url never includes query params. - - If url is a production URL and emulator_host is specified, the result - base URL will use emulator_host instead. emulator_host is ignored - if url is already an emulator URL. - """ - if not url or not isinstance(url, str): - raise ValueError( - 'Invalid database URL: "{0}". Database URL must be a non-empty ' - 'URL string.'.format(url)) - parsed_url = parse.urlparse(url) - if parsed_url.netloc.endswith('.firebaseio.com'): - return cls._parse_production_url(parsed_url, emulator_host) - - return cls._parse_emulator_url(parsed_url) - - @classmethod - def _parse_production_url(cls, parsed_url, emulator_host): - """Parses production URL like https://foo-bar.firebaseio.com/""" + def _get_emulator_config(self, parsed_url): + """Checks whether the SDK should connect to the RTDB emulator.""" + EmulatorConfig = collections.namedtuple('EmulatorConfig', ['base_url', 'namespace']) if parsed_url.scheme != 'https': - raise ValueError( - 'Invalid database URL scheme: "{0}". Database URL must be an HTTPS URL.'.format( - parsed_url.scheme)) - namespace = parsed_url.netloc.split('.')[0] - if not namespace: - raise ValueError( - 'Invalid database URL: "{0}". Database URL must be a valid URL to a ' - 'Firebase Realtime Database instance.'.format(parsed_url.geturl())) + # Emulator mode enabled by passing http URL via AppOptions + base_url, namespace = _DatabaseService._parse_emulator_url(parsed_url) + return EmulatorConfig(base_url, namespace) + if self._emulator_host: + # Emulator mode enabled via environment variable + base_url = 'http://{0}'.format(self._emulator_host) + namespace = parsed_url.netloc.split('.')[0] + return EmulatorConfig(base_url, namespace) - if emulator_host: - base_url = 'http://{0}'.format(emulator_host) - else: - base_url = 'https://{0}'.format(parsed_url.netloc) - return base_url, namespace + return None @classmethod def _parse_emulator_url(cls, parsed_url): diff --git a/tests/test_db.py b/tests/test_db.py index 2989fc030..2ba5b9b29 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -15,6 +15,7 @@ """Tests for firebase_admin.db.""" import collections import json +import os import sys import time @@ -28,6 +29,9 @@ from tests import testutils +_EMULATOR_HOST_ENV_VAR = 'FIREBASE_DATABASE_EMULATOR_HOST' + + class MockAdapter(testutils.MockAdapter): """A mock HTTP adapter that mimics RTDB server behavior.""" @@ -702,52 +706,66 @@ def test_no_db_url(self): 'url,emulator_host,expected_base_url,expected_namespace', [ # Production URLs with no override: - ('https://test.firebaseio.com', None, 'https://test.firebaseio.com', 'test'), - ('https://test.firebaseio.com/', None, 'https://test.firebaseio.com', 'test'), + ('https://test.firebaseio.com', None, 'https://test.firebaseio.com', None), + ('https://test.firebaseio.com/', None, 'https://test.firebaseio.com', None), # Production URLs with emulator_host override: ('https://test.firebaseio.com', 'localhost:9000', 'http://localhost:9000', 'test'), ('https://test.firebaseio.com/', 'localhost:9000', 'http://localhost:9000', 'test'), - # Emulator URLs with no override. + # Emulator URL with no override. ('http://localhost:8000/?ns=test', None, 'http://localhost:8000', 'test'), + # emulator_host is ignored when the original URL is already emulator. ('http://localhost:8000/?ns=test', 'localhost:9999', 'http://localhost:8000', 'test'), ] ) def test_parse_db_url(self, url, emulator_host, expected_base_url, expected_namespace): - base_url, namespace = db._DatabaseService._parse_db_url(url, emulator_host) - assert base_url == expected_base_url - assert namespace == expected_namespace - - @pytest.mark.parametrize('url,emulator_host', [ - ('', None), - (None, None), - (42, None), - ('test.firebaseio.com', None), # Not a URL. - ('http://test.firebaseio.com', None), # Use of non-HTTPs in production URLs. - ('ftp://test.firebaseio.com', None), # Use of non-HTTPs in production URLs. - ('https://example.com', None), # Invalid RTDB URL. - ('http://localhost:9000/', None), # No ns specified. - ('http://localhost:9000/?ns=', None), # No ns specified. - ('http://localhost:9000/?ns=test1&ns=test2', None), # Two ns parameters specified. - ('ftp://localhost:9000/?ns=test', None), # Neither HTTP or HTTPS. + if emulator_host: + os.environ[_EMULATOR_HOST_ENV_VAR] = emulator_host + + try: + firebase_admin.initialize_app(testutils.MockCredential(), {'databaseURL' : url}) + ref = db.reference() + assert ref._client._base_url == expected_base_url + assert ref._client.params.get('ns') == expected_namespace + finally: + if _EMULATOR_HOST_ENV_VAR in os.environ: + del os.environ[_EMULATOR_HOST_ENV_VAR] + + @pytest.mark.parametrize('url', [ + '', + None, + 42, + 'test.firebaseio.com', # Not a URL. + 'http://test.firebaseio.com', # Use of non-HTTPs in production URLs. + 'ftp://test.firebaseio.com', # Use of non-HTTPs in production URLs. + 'http://localhost:9000/', # No ns specified. + 'http://localhost:9000/?ns=', # No ns specified. + 'http://localhost:9000/?ns=test1&ns=test2', # Two ns parameters specified. + 'ftp://localhost:9000/?ns=test', # Neither HTTP or HTTPS. ]) - def test_parse_db_url_errors(self, url, emulator_host): + def test_parse_db_url_errors(self, url): + firebase_admin.initialize_app(testutils.MockCredential(), {'databaseURL' : url}) with pytest.raises(ValueError): - db._DatabaseService._parse_db_url(url, emulator_host) + db.reference() @pytest.mark.parametrize('url', [ - 'https://test.firebaseio.com', 'https://test.firebaseio.com/' + 'https://test.firebaseio.com', 'https://test.firebaseio.com/', + 'https://test.eu-west1.firebasdatabase.app', 'https://test.eu-west1.firebasdatabase.app/' ]) def test_valid_db_url(self, url): firebase_admin.initialize_app(testutils.MockCredential(), {'databaseURL' : url}) ref = db.reference() - assert ref._client.base_url == 'https://test.firebaseio.com' + expected_url = url + if url.endswith('/'): + expected_url = url[:-1] + assert ref._client.base_url == expected_url assert 'auth_variable_override' not in ref._client.params + assert 'ns' not in ref._client.params @pytest.mark.parametrize('url', [ - None, '', 'foo', 'http://test.firebaseio.com', 'https://google.com', + None, '', 'foo', 'http://test.firebaseio.com', 'http://test.firebasedatabase.app', True, False, 1, 0, dict(), list(), tuple(), _Object() ]) def test_invalid_db_url(self, url): From a8259acc2fce0c316c323319d4bf57c833b20990 Mon Sep 17 00:00:00 2001 From: Hiranya Jayathilaka Date: Wed, 9 Dec 2020 13:04:42 -0800 Subject: [PATCH 2/2] fix: Deferred credential loading until emulator URL is determined --- firebase_admin/db.py | 14 ++++++++------ tests/test_db.py | 4 ++++ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/firebase_admin/db.py b/firebase_admin/db.py index e60704f5c..3384bd440 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -806,15 +806,17 @@ def get_client(self, db_url=None): 'Invalid database URL: "{0}". Database URL must be a wellformed ' 'URL string.'.format(db_url)) - base_url = 'https://{0}'.format(parsed_url.netloc) - params = {} - credential = self._credential.get_credential() - emulator_config = self._get_emulator_config(parsed_url) if emulator_config: - base_url = emulator_config.base_url - params['ns'] = emulator_config.namespace credential = _EmulatorAdminCredentials() + base_url = emulator_config.base_url + params = {'ns': emulator_config.namespace} + else: + # Defer credential lookup until we are certain it's going to be prod connection. + credential = self._credential.get_credential() + base_url = 'https://{0}'.format(parsed_url.netloc) + params = {} + if self._auth_override: params['auth_variable_override'] = self._auth_override diff --git a/tests/test_db.py b/tests/test_db.py index 2ba5b9b29..5f8ba4b51 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -729,6 +729,10 @@ def test_parse_db_url(self, url, emulator_host, expected_base_url, expected_name ref = db.reference() assert ref._client._base_url == expected_base_url assert ref._client.params.get('ns') == expected_namespace + if expected_base_url.startswith('http://localhost'): + assert isinstance(ref._client.credential, db._EmulatorAdminCredentials) + else: + assert isinstance(ref._client.credential, testutils.MockGoogleCredential) finally: if _EMULATOR_HOST_ENV_VAR in os.environ: del os.environ[_EMULATOR_HOST_ENV_VAR]