From 3e26073ff129428f2f0a33010bb053f41eddd364 Mon Sep 17 00:00:00 2001
From: Seth Michael Larson
Date: Mon, 12 Jul 2021 12:47:49 -0500
Subject: [PATCH] Don't swallow unexpected errors during product check

---
 elasticsearch/_async/transport.py    | 137 +++++++++---------
 elasticsearch/transport.py           | 131 ++++++++---------
 .../test_async/test_transport.py     |  55 ++++++-
 test_elasticsearch/test_transport.py |  61 +++++++-
 4 files changed, 245 insertions(+), 139 deletions(-)

diff --git a/elasticsearch/_async/transport.py b/elasticsearch/_async/transport.py
index 3b969b1ea..e5ed4b876 100644
--- a/elasticsearch/_async/transport.py
+++ b/elasticsearch/_async/transport.py
@@ -120,7 +120,7 @@ async def _async_init(self):
 
         # Set our 'verified_once' implementation to one that
         # works with 'asyncio' instead of 'threading'
-        self._verified_once = Once()
+        self._verify_elasticsearch_lock = asyncio.Lock()
 
         # Now that we have a loop we can create all our HTTP connections...
         self.set_connections(self.hosts)
@@ -338,9 +338,7 @@ async def perform_request(self, method, url, headers=None, params=None, body=Non
 
         # Before we make the actual API call we verify the Elasticsearch instance.
         if self._verified_elasticsearch is None:
-            await self._verified_once.call(
-                self._do_verify_elasticsearch, headers=headers, timeout=timeout
-            )
+            await self._do_verify_elasticsearch(headers=headers, timeout=timeout)
 
         # If '_verified_elasticsearch' is False we know we're not connected to Elasticsearch.
         if self._verified_elasticsearch is False:
@@ -431,74 +429,73 @@ async def _do_verify_elasticsearch(self, headers, timeout):
         but we're also unable to rule it out due to a permission error
         we instead emit an 'ElasticsearchWarning'.
         """
-        # Product check has already been done, no need to do again.
-        if self._verified_elasticsearch:
-            return
-
-        headers = {header.lower(): value for header, value in (headers or {}).items()}
-        # We know we definitely want JSON so request it via 'accept'
-        headers.setdefault("accept", "application/json")
+        # Ensure that there's only one async task within this section
+        # at a time so we don't emit unnecessary info API calls.
+        async with self._verify_elasticsearch_lock:
 
-        info_headers = {}
-        info_response = {}
-        error = None
-
-        for conn in chain(self.connection_pool.connections, self.seed_connections):
-            try:
-                _, info_headers, info_response = await conn.perform_request(
-                    "GET", "/", headers=headers, timeout=timeout
-                )
-
-                # Lowercase all the header names for consistency in accessing them.
-                info_headers = {
-                    header.lower(): value for header, value in info_headers.items()
-                }
-
-                info_response = self.deserializer.loads(
-                    info_response, mimetype="application/json"
-                )
-                break
-
-            # Previous versions of 7.x Elasticsearch required a specific
-            # permission so if we receive HTTP 401/403 we should warn
-            # instead of erroring out.
-            except (AuthenticationException, AuthorizationException):
-                warnings.warn(
-                    (
-                        "The client is unable to verify that the server is "
-                        "Elasticsearch due security privileges on the server side"
-                    ),
-                    ElasticsearchWarning,
-                    stacklevel=4,
-                )
-                self._verified_elasticsearch = True
+            # Product check has already been completed while we were
+            # waiting our turn, no need to do again.
+            if self._verified_elasticsearch is not None:
                 return
 
-            # This connection didn't work, we'll try another.
-            except (ConnectionError, SerializationError) as err:
-                if error is None:
-                    error = err
-
-        # If we received a connection error and weren't successful
-        # anywhere then we reraise the more appropriate error.
-        if error and not info_response:
-            raise error
-
-        # Check the information we got back from the index request.
-        self._verified_elasticsearch = _verify_elasticsearch(
-            info_headers, info_response
-        )
-
-
-class Once:
-    """Simple class which forces an async function to only execute once."""
+            headers = {
+                header.lower(): value for header, value in (headers or {}).items()
+            }
+            # We know we definitely want JSON so request it via 'accept'
+            headers.setdefault("accept", "application/json")
+
+            info_headers = {}
+            info_response = {}
+            error = None
+
+            attempted_conns = []
+            for conn in chain(self.connection_pool.connections, self.seed_connections):
+                # Only attempt each connection once.
+                if conn in attempted_conns:
+                    continue
+                attempted_conns.append(conn)
+
+                try:
+                    _, info_headers, info_response = await conn.perform_request(
+                        "GET", "/", headers=headers, timeout=timeout
+                    )
 
-    def __init__(self):
-        self._lock = asyncio.Lock()
-        self._called = False
+                    # Lowercase all the header names for consistency in accessing them.
+                    info_headers = {
+                        header.lower(): value for header, value in info_headers.items()
+                    }
 
-    async def call(self, func, *args, **kwargs):
-        async with self._lock:
-            if not self._called:
-                self._called = True
-                await func(*args, **kwargs)
+                    info_response = self.deserializer.loads(
+                        info_response, mimetype="application/json"
+                    )
+                    break
+
+                # Previous versions of 7.x Elasticsearch required a specific
+                # permission so if we receive HTTP 401/403 we should warn
+                # instead of erroring out.
+                except (AuthenticationException, AuthorizationException):
+                    warnings.warn(
+                        (
+                            "The client is unable to verify that the server is "
+                            "Elasticsearch due to security privileges on the server side"
+                        ),
+                        ElasticsearchWarning,
+                        stacklevel=4,
+                    )
+                    self._verified_elasticsearch = True
+                    return
+
+                # This connection didn't work, we'll try another.
+                except (ConnectionError, SerializationError, TransportError) as err:
+                    if error is None:
+                        error = err
+
+            # If we received a connection error and weren't successful
+            # anywhere then we re-raise the more appropriate error.
+            if error and not info_response:
+                raise error
+
+            # Check the information we got back from the info request.
+            self._verified_elasticsearch = _verify_elasticsearch(
+                info_headers, info_response
+            )
diff --git a/elasticsearch/transport.py b/elasticsearch/transport.py
index 6e5fb0124..0ee68106a 100644
--- a/elasticsearch/transport.py
+++ b/elasticsearch/transport.py
@@ -220,7 +220,7 @@ def __init__(
 
         # Ensures that the ES verification request only fires once and that
         # all requests block until this request returns back.
-        self._verified_once = Once()
+        self._verify_elasticsearch_lock = Lock()
 
     def add_connection(self, host):
         """
@@ -406,9 +406,7 @@ def perform_request(self, method, url, headers=None, params=None, body=None):
 
         # Before we make the actual API call we verify the Elasticsearch instance.
         if self._verified_elasticsearch is None:
-            self._verified_once.call(
-                self._do_verify_elasticsearch, headers=headers, timeout=timeout
-            )
+            self._do_verify_elasticsearch(headers=headers, timeout=timeout)
 
         # If '_verified_elasticsearch' is False we know we're not connected to Elasticsearch.
         if self._verified_elasticsearch is False:
@@ -536,63 +534,76 @@ def _do_verify_elasticsearch(self, headers, timeout):
         but we're also unable to rule it out due to a permission error
         we instead emit an 'ElasticsearchWarning'.
         """
-        # Product check has already been done, no need to do again.
-        if self._verified_elasticsearch is not None:
-            return
+        # Ensure that there's only one thread within this section
+        # at a time so we don't emit unnecessary info API calls.
+        with self._verify_elasticsearch_lock:
 
-        headers = {header.lower(): value for header, value in (headers or {}).items()}
-        # We know we definitely want JSON so request it via 'accept'
-        headers.setdefault("accept", "application/json")
+            # Product check has already been completed while we were
+            # waiting our turn, no need to do again.
+            if self._verified_elasticsearch is not None:
+                return
 
-        info_headers = {}
-        info_response = {}
-        error = None
+            headers = {
+                header.lower(): value for header, value in (headers or {}).items()
+            }
+            # We know we definitely want JSON so request it via 'accept'
+            headers.setdefault("accept", "application/json")
 
-        for conn in chain(self.connection_pool.connections, self.seed_connections):
-            try:
-                _, info_headers, info_response = conn.perform_request(
-                    "GET", "/", headers=headers, timeout=timeout
-                )
+            info_headers = {}
+            info_response = {}
+            error = None
 
-                # Lowercase all the header names for consistency in accessing them.
-                info_headers = {
-                    header.lower(): value for header, value in info_headers.items()
-                }
+            attempted_conns = []
+            for conn in chain(self.connection_pool.connections, self.seed_connections):
+                # Only attempt each connection once.
+                if conn in attempted_conns:
+                    continue
+                attempted_conns.append(conn)
 
-                info_response = self.deserializer.loads(
-                    info_response, mimetype="application/json"
-                )
-                break
-
-            # Previous versions of 7.x Elasticsearch required a specific
-            # permission so if we receive HTTP 401/403 we should warn
-            # instead of erroring out.
-            except (AuthenticationException, AuthorizationException):
-                warnings.warn(
-                    (
-                        "The client is unable to verify that the server is "
-                        "Elasticsearch due security privileges on the server side"
-                    ),
-                    ElasticsearchWarning,
-                    stacklevel=5,
-                )
-                self._verified_elasticsearch = True
-                return
+                try:
+                    _, info_headers, info_response = conn.perform_request(
+                        "GET", "/", headers=headers, timeout=timeout
+                    )
 
-            # This connection didn't work, we'll try another.
-            except (ConnectionError, SerializationError) as err:
-                if error is None:
-                    error = err
+                    # Lowercase all the header names for consistency in accessing them.
+                    info_headers = {
+                        header.lower(): value for header, value in info_headers.items()
+                    }
 
-            # If we received a connection error and weren't successful
-            # anywhere then we reraise the more appropriate error.
-            if error and not info_response:
-                raise error
+                    info_response = self.deserializer.loads(
+                        info_response, mimetype="application/json"
+                    )
+                    break
 
-        # Check the information we got back from the index request.
-        self._verified_elasticsearch = _verify_elasticsearch(
-            info_headers, info_response
-        )
+                # Previous versions of 7.x Elasticsearch required a specific
+                # permission so if we receive HTTP 401/403 we should warn
+                # instead of erroring out.
+                except (AuthenticationException, AuthorizationException):
+                    warnings.warn(
+                        (
+                            "The client is unable to verify that the server is "
+                            "Elasticsearch due to security privileges on the server side"
+                        ),
+                        ElasticsearchWarning,
+                        stacklevel=5,
+                    )
+                    self._verified_elasticsearch = True
+                    return
+
+                # This connection didn't work, we'll try another.
+                except (ConnectionError, SerializationError, TransportError) as err:
+                    if error is None:
+                        error = err
+
+            # If we received a connection error and weren't successful
+            # anywhere then we re-raise the more appropriate error.
+            if error and not info_response:
+                raise error
+
+            # Check the information we got back from the info request.
+            self._verified_elasticsearch = _verify_elasticsearch(
+                info_headers, info_response
+            )
 
 
 def _verify_elasticsearch(headers, response):
@@ -640,17 +651,3 @@ def _verify_elasticsearch(headers, response):
             return False
 
     return True
-
-
-class Once:
-    """Simple class which forces a function to only execute once."""
-
-    def __init__(self):
-        self._lock = Lock()
-        self._called = False
-
-    def call(self, func, *args, **kwargs):
-        with self._lock:
-            if not self._called:
-                self._called = True
-                func(*args, **kwargs)
diff --git a/test_elasticsearch/test_async/test_transport.py b/test_elasticsearch/test_async/test_transport.py
index 30db9401b..fb93fefa1 100644
--- a/test_elasticsearch/test_async/test_transport.py
+++ b/test_elasticsearch/test_async/test_transport.py
@@ -34,6 +34,7 @@
     ConnectionError,
     ElasticsearchWarning,
     NotElasticsearchError,
+    NotFoundError,
     TransportError,
 )
 
@@ -770,7 +771,9 @@ async def request_task():
         # The rest of the requests are 'GET /_search' afterwards
         assert all(call[0][:2] == ("GET", "/_search") for call in calls[1:])
 
-    async def test_multiple_requests_verify_elasticsearch_errors(self, event_loop):
+    async def test_multiple_requests_verify_elasticsearch_product_error(
+        self, event_loop
+    ):
         t = AsyncTransport(
             [
                 {
@@ -823,3 +826,53 @@ async def request_task():
 
         # The rest of the requests are 'GET /_search' afterwards
         assert all(call[0][:2] == ("GET", "/_search") for call in calls[1:])
+
+    @pytest.mark.parametrize("error_cls", [ConnectionError, NotFoundError])
+    async def test_multiple_requests_verify_elasticsearch_retry_on_errors(
+        self, event_loop, error_cls
+    ):
+        t = AsyncTransport(
+            [
+                {
+                    "exception": error_cls(),
+                    "delay": 0.1,
+                }
+            ],
+            connection_class=DummyConnection,
+        )
+
+        results = []
+        completed_at = []
+
+        async def request_task():
+            try:
+                results.append(await t.perform_request("GET", "/_search"))
+            except Exception as e:
+                results.append(e)
+            completed_at.append(event_loop.time())
+
+        # Execute a bunch of requests concurrently.
+        tasks = []
+        start_time = event_loop.time()
+        for _ in range(5):
+            tasks.append(event_loop.create_task(request_task()))
+        await asyncio.gather(*tasks)
+        end_time = event_loop.time()
+
+        # Exactly 5 results completed
+        assert len(results) == 5
+
+        # All results were errors and not wrapped in 'NotElasticsearchError'
+        assert all(isinstance(result, error_cls) for result in results)
+
+        # Assert that 5 requests were made in total (5 attempts x 0.1s delay each, serialized by the lock)
+        duration = end_time - start_time
+        assert 0.5 <= duration <= 0.6
+
+        # Assert that the cluster is still in the unknown/unverified stage.
+        assert t._verified_elasticsearch is None
+
+        # See that the search API isn't hit, it's the product check 'GET /' requests that are failing.
+        calls = t.connection_pool.connections[0].calls
+        assert len(calls) == 5
+        assert all(call[0] == ("GET", "/") for call in calls)
diff --git a/test_elasticsearch/test_transport.py b/test_elasticsearch/test_transport.py
index 17c308a81..150a6ebab 100644
--- a/test_elasticsearch/test_transport.py
+++ b/test_elasticsearch/test_transport.py
@@ -32,6 +32,7 @@
     ConnectionError,
     ElasticsearchWarning,
     NotElasticsearchError,
+    NotFoundError,
     TransportError,
 )
 from elasticsearch.transport import Transport
@@ -748,7 +749,7 @@ def run(self):
     assert all(call[0][:2] == ("GET", "/_search") for call in calls[1:])
 
 
-def test_multiple_requests_verify_elasticsearch_errors():
+def test_multiple_requests_verify_elasticsearch_product_error():
     try:
         import threading
     except ImportError:
@@ -810,3 +811,61 @@ def run(self):
 
     # The rest of the requests are 'GET /_search' afterwards
     assert all(call[0][:2] == ("GET", "/_search") for call in calls[1:])
+
+
+@pytest.mark.parametrize("error_cls", [ConnectionError, NotFoundError])
+def test_multiple_requests_verify_elasticsearch_retry_on_errors(error_cls):
+    try:
+        import threading
+    except ImportError:
+        return pytest.skip("Requires the 'threading' module")
+
+    t = Transport(
+        [
+            {
+                "exception": error_cls(),
+                "delay": 0.1,
+            }
+        ],
+        connection_class=DummyConnection,
+    )
+
+    results = []
+    completed_at = []
+
+    class RequestThread(threading.Thread):
+        def run(self):
+            try:
+                results.append(t.perform_request("GET", "/_search"))
+            except Exception as e:
+                results.append(e)
+            completed_at.append(time.time())
+
+    # Execute a bunch of requests concurrently.
+    threads = []
+    start_time = time.time()
+    for _ in range(5):
+        thread = RequestThread()
+        thread.start()
+        threads.append(thread)
+    for thread in threads:
+        thread.join()
+    end_time = time.time()
+
+    # Exactly 5 results completed
+    assert len(results) == 5
+
+    # All results were errors and not wrapped in 'NotElasticsearchError'
+    assert all(isinstance(result, error_cls) for result in results)
+
+    # Assert that 5 requests were made in total (5 attempts x 0.1s delay each, serialized by the lock)
+    duration = end_time - start_time
+    assert 0.5 <= duration <= 0.6
+
+    # Assert that the cluster is still in the unknown/unverified stage.
+    assert t._verified_elasticsearch is None
+
+    # See that the search API isn't hit, it's the product check 'GET /' requests that are failing.
+    calls = t.connection_pool.connections[0].calls
+    assert len(calls) == 5
+    assert all(call[0] == ("GET", "/") for call in calls)
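
Reviewer note: the behavioral core of the patch is easier to see outside the transport plumbing. Below is a minimal, self-contained sketch of the pattern the diff moves to, double-checked locking around the product check where only a decisive outcome is cached and unexpected errors propagate. The names here (ProductChecker, probe) are hypothetical stand-ins for the transport and its 'GET /' info call, not library API.

import threading


class ProductChecker:
    """Hypothetical stand-in for the transport's product check."""

    def __init__(self, probe):
        self._probe = probe  # callable standing in for 'GET /'; may raise
        self._verified = None  # None = unknown, True/False = decided
        self._lock = threading.Lock()

    def verify(self):
        # Fast path: a previous request already decided the answer.
        if self._verified is not None:
            return self._verified
        with self._lock:
            # Double-check: another caller may have finished the probe
            # while we were waiting on the lock.
            if self._verified is not None:
                return self._verified
            # Unlike the removed 'Once' helper, which marked itself as
            # "called" even when the wrapped function raised, an unexpected
            # error propagates here *without* caching a result, so the next
            # request retries the probe instead of silently proceeding.
            info = self._probe()
            self._verified = info.get("tagline") == "You Know, for Search"
            return self._verified

That un-cached re-raise is exactly what the new test_multiple_requests_verify_elasticsearch_retry_on_errors tests pin down: after five failing 'GET /' attempts, _verified_elasticsearch is still None and every attempt actually hit the connection.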