diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index 788875048..e4e223091 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -372,7 +372,8 @@ def batch_callback(_, response, error): send_response = SendResponse(response, exception) responses.append(send_response) - batch = http.BatchHttpRequest(batch_callback, _MessagingService.FCM_BATCH_URL) + batch = http.BatchHttpRequest( + callback=batch_callback, batch_uri=_MessagingService.FCM_BATCH_URL) for message in messages: body = json.dumps(self._message_data(message, dry_run)) req = http.HttpRequest( diff --git a/tests/test_messaging.py b/tests/test_messaging.py index f2ef47cf8..6e776cc5f 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -17,7 +17,8 @@ import json import numbers -from googleapiclient.http import HttpMockSequence +from googleapiclient import http +from googleapiclient import _helpers import pytest import firebase_admin @@ -1810,7 +1811,7 @@ def _instrument_batch_messaging_service(self, app=None, status=200, payload=''): content_type = 'multipart/mixed; boundary=boundary' else: content_type = 'application/json' - fcm_service._transport = HttpMockSequence([ + fcm_service._transport = http.HttpMockSequence([ ({'status': str(status), 'content-type': content_type}, payload), ]) return fcm_service @@ -1867,6 +1868,20 @@ def test_send_all(self): assert all([r.success for r in batch_response.responses]) assert not any([r.exception for r in batch_response.responses]) + def test_send_all_with_positional_param_enforcement(self): + payload = json.dumps({'name': 'message-id'}) + _ = self._instrument_batch_messaging_service( + payload=self._batch_payload([(200, payload), (200, payload)])) + msg = messaging.Message(topic='foo') + + enforcement = _helpers.positional_parameters_enforcement + _helpers.positional_parameters_enforcement = _helpers.POSITIONAL_EXCEPTION + try: + batch_response = messaging.send_all([msg, msg], dry_run=True) + assert batch_response.success_count == 2 + finally: + _helpers.positional_parameters_enforcement = enforcement + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_all_detailed_error(self, status): success_payload = json.dumps({'name': 'message-id'})