diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 548bcfc37..95fc03e04 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -330,9 +330,9 @@ def __init__(self, app): 'X-FIREBASE-CLIENT': 'fire-admin-python/{0}'.format(firebase_admin.__version__), } timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) - self._client = _http_client.JsonHttpClient( - credential=app.credential.get_credential(), timeout=timeout) - self._transport = _auth.authorized_http(app.credential.get_credential()) + self._credential = app.credential.get_credential() + self._client = _http_client.JsonHttpClient(credential=self._credential, timeout=timeout) + self._build_transport = _auth.authorized_http @classmethod def encode_message(cls, message): @@ -373,10 +373,11 @@ def batch_callback(_, response, error): batch = http.BatchHttpRequest( callback=batch_callback, batch_uri=_MessagingService.FCM_BATCH_URL) + transport = self._build_transport(self._credential) for message in messages: body = json.dumps(self._message_data(message, dry_run)) req = http.HttpRequest( - http=self._transport, + http=transport, postproc=self._postproc, uri=self._fcm_url, method='POST', diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 8eb24c0a9..3d8740cc1 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -1813,20 +1813,23 @@ def teardown_class(cls): testutils.cleanup_apps() def _instrument_batch_messaging_service(self, app=None, status=200, payload='', exc=None): - if not app: - app = firebase_admin.get_app() + def build_mock_transport(_): + if exc: + return _HttpMockException(exc) - fcm_service = messaging._get_messaging_service(app) - if exc: - fcm_service._transport = _HttpMockException(exc) - else: if status == 200: content_type = 'multipart/mixed; boundary=boundary' else: content_type = 'application/json' - fcm_service._transport = http.HttpMockSequence([ + return http.HttpMockSequence([ ({'status': str(status), 'content-type': content_type}, payload), ]) + + if not app: + app = firebase_admin.get_app() + + fcm_service = messaging._get_messaging_service(app) + fcm_service._build_transport = build_mock_transport return fcm_service def _batch_payload(self, payloads): @@ -2053,6 +2056,29 @@ def test_send_all_runtime_exception(self): assert excinfo.value.cause is exc assert excinfo.value.http_response is None + def test_send_transport_init(self): + def track_call_count(build_transport): + def wrapper(credential): + wrapper.calls += 1 + return build_transport(credential) + wrapper.calls = 0 + return wrapper + + payload = json.dumps({'name': 'message-id'}) + fcm_service = self._instrument_batch_messaging_service( + payload=self._batch_payload([(200, payload), (200, payload)])) + build_mock_transport = fcm_service._build_transport + fcm_service._build_transport = track_call_count(build_mock_transport) + msg = messaging.Message(topic='foo') + + batch_response = messaging.send_all([msg, msg], dry_run=True) + assert batch_response.success_count == 2 + assert fcm_service._build_transport.calls == 1 + + batch_response = messaging.send_all([msg, msg], dry_run=True) + assert batch_response.success_count == 2 + assert fcm_service._build_transport.calls == 2 + class TestSendMulticast(TestBatch):