diff --git a/docs/conf.py b/docs/conf.py index 38b0b26d4..edd8be947 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -16,34 +16,34 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -#sys.path.insert(0, os.path.abspath('.')) +# sys.path.insert(0, os.path.abspath('.')) # -- General configuration ----------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +# needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. -extensions = ['sphinx.ext.autodoc', 'sphinx.ext.doctest'] +extensions = ["sphinx.ext.autodoc", "sphinx.ext.doctest"] autoclass_content = "both" # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. -#source_encoding = 'utf-8-sig' +# source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'Elasticsearch' -copyright = u'2013, Honza Král' +project = u"Elasticsearch" +copyright = u"2013, Honza Král" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -51,6 +51,7 @@ # import elasticsearch + # The short X.Y version. version = elasticsearch.__versionstr__ # The full version, including alpha/beta/rc tags. @@ -58,40 +59,40 @@ # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. -#language = None +# language = None # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['_build'] +exclude_patterns = ["_build"] # The reST default role (used for this markup: `text`) to use for all documents. -#default_role = None +# default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] +# modindex_common_prefix = [] # If true, keep warnings as "system message" paragraphs in the built documents. -#keep_warnings = False +# keep_warnings = False # -- Options for HTML output --------------------------------------------------- @@ -99,11 +100,12 @@ # The theme to use for HTML and HTML Help pages. 
See the documentation for # a list of builtin themes. -on_rtd = os.environ.get('READTHEDOCS', None) == 'True' +on_rtd = os.environ.get("READTHEDOCS", None) == "True" if not on_rtd: # only import and set the theme if we're building docs locally import sphinx_rtd_theme - html_theme = 'sphinx_rtd_theme' + + html_theme = "sphinx_rtd_theme" html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] # Theme options are theme-specific and customize the look and feel of a theme @@ -113,116 +115,119 @@ # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -#html_theme_options = {} +# html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] +# html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +# html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. -#html_favicon = None +# html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +# html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +# html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +# html_domain_indices = True # If false, no index is generated. -#html_use_index = True +# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True +# html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True +# html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True +# html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None +# html_file_suffix = None # Output file base name for HTML help builder. 
-htmlhelp_basename = 'Elasticsearchdoc' +htmlhelp_basename = "Elasticsearchdoc" # -- Options for LaTeX output -------------------------------------------------- latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', - -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', - -# Additional stuff for the LaTeX preamble. -#'preamble': '', + # The paper size ('letterpaper' or 'a4paper'). + #'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + #'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + #'preamble': '', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). latex_documents = [ - ('index', 'Elasticsearch.tex', u'Elasticsearch Documentation', - u'Honza Král', 'manual'), + ( + "index", + "Elasticsearch.tex", + u"Elasticsearch Documentation", + u"Honza Král", + "manual", + ) ] # The name of an image file (relative to this directory) to place at the top of # the title page. -#latex_logo = None +# latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # If true, show page references after internal links. -#latex_show_pagerefs = False +# latex_show_pagerefs = False # If true, show URL addresses after external links. -#latex_show_urls = False +# latex_show_urls = False # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. -#latex_domain_indices = True +# latex_domain_indices = True # -- Options for manual page output -------------------------------------------- @@ -230,12 +235,11 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - ('index', 'elasticsearch-py', u'Elasticsearch Documentation', - [u'Honza Král'], 1) + ("index", "elasticsearch-py", u"Elasticsearch Documentation", [u"Honza Král"], 1) ] # If true, show URL addresses after external links. -#man_show_urls = False +# man_show_urls = False # -- Options for Texinfo output ------------------------------------------------ @@ -244,19 +248,25 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - ('index', 'Elasticsearch', u'Elasticsearch Documentation', - u'Honza Král', 'Elasticsearch', 'One line description of project.', - 'Miscellaneous'), + ( + "index", + "Elasticsearch", + u"Elasticsearch Documentation", + u"Honza Král", + "Elasticsearch", + "One line description of project.", + "Miscellaneous", + ) ] # Documents to append as an appendix to all manuals. -#texinfo_appendices = [] +# texinfo_appendices = [] # If false, no module index is generated. -#texinfo_domain_indices = True +# texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' +# texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. 
-#texinfo_no_detailmenu = False +# texinfo_no_detailmenu = False diff --git a/elasticsearch/client/cat.py b/elasticsearch/client/cat.py index 717f99026..f20d8af40 100644 --- a/elasticsearch/client/cat.py +++ b/elasticsearch/client/cat.py @@ -1,7 +1,8 @@ from .utils import NamespacedClient, query_params, _make_path, SKIP_IN_PATH + class CatClient(NamespacedClient): - @query_params('format', 'h', 'help', 'local', 'master_timeout', 's', 'v') + @query_params("format", "h", "help", "local", "master_timeout", "s", "v") def aliases(self, name=None, params=None): """ @@ -19,11 +20,13 @@ def aliases(self, name=None, params=None): by :arg v: Verbose mode. Display column headers, default False """ - return self.transport.perform_request('GET', _make_path('_cat', - 'aliases', name), params=params) + return self.transport.perform_request( + "GET", _make_path("_cat", "aliases", name), params=params + ) - @query_params('bytes', 'size', 'format', 'h', 'help', 'local', 'master_timeout', - 's', 'v') + @query_params( + "bytes", "size", "format", "h", "help", "local", "master_timeout", "s", "v" + ) def allocation(self, node_id=None, params=None): """ Allocation provides a snapshot of how shards have located around the @@ -45,10 +48,11 @@ def allocation(self, node_id=None, params=None): by :arg v: Verbose mode. Display column headers, default False """ - return self.transport.perform_request('GET', _make_path('_cat', - 'allocation', node_id), params=params) + return self.transport.perform_request( + "GET", _make_path("_cat", "allocation", node_id), params=params + ) - @query_params('size', 'format', 'h', 'help', 'local', 'master_timeout', 's', 'v') + @query_params("size", "format", "h", "help", "local", "master_timeout", "s", "v") def count(self, index=None, params=None): """ Count provides quick access to the document count of the entire cluster, @@ -68,11 +72,11 @@ def count(self, index=None, params=None): by :arg v: Verbose mode. Display column headers, default False """ - return self.transport.perform_request('GET', _make_path('_cat', 'count', - index), params=params) + return self.transport.perform_request( + "GET", _make_path("_cat", "count", index), params=params + ) - @query_params('format', 'h', 'help', 'local', 'master_timeout', 's', 'ts', - 'v') + @query_params("format", "h", "help", "local", "master_timeout", "s", "ts", "v") def health(self, params=None): """ health is a terse, one-line representation of the same information from @@ -91,10 +95,9 @@ def health(self, params=None): :arg ts: Set to false to disable timestamping, default True :arg v: Verbose mode. Display column headers, default False """ - return self.transport.perform_request('GET', '/_cat/health', - params=params) + return self.transport.perform_request("GET", "/_cat/health", params=params) - @query_params('help', 's') + @query_params("help", "s") def help(self, params=None): """ A simple help for the cat api. 
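[review note] The cat.py hunks above and below are pure Black reformatting: the @query_params(...) string tuples and perform_request(...) calls are re-wrapped with no behavioural change. For readers new to this client, here is a minimal sketch of what the decorator does, simplified for illustration and not the exact code in elasticsearch/client/utils.py:

    from functools import wraps

    def query_params(*accepted):
        # Simplified: collect the whitelisted keyword arguments into a
        # `params` dict that the wrapped API method sends as the querystring.
        def decorator(func):
            @wraps(func)
            def wrapped(self, *args, **kwargs):
                params = {}
                for name in accepted:
                    if name in kwargs:
                        params[name] = kwargs.pop(name)
                return func(self, *args, params=params, **kwargs)
            return wrapped
        return decorator

Under that sketch, es.cat.count(index="logs", v=True) reduces to perform_request("GET", "/_cat/count/logs", params={"v": True}); the reformatting does not touch this flow.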
@@ -104,10 +107,22 @@ def help(self, params=None): :arg s: Comma-separated list of column names or column aliases to sort by """ - return self.transport.perform_request('GET', '/_cat', params=params) - - @query_params('bytes', 'time', 'size', 'format', 'h', 'health', 'help', 'local', - 'master_timeout', 'pri', 's', 'v') + return self.transport.perform_request("GET", "/_cat", params=params) + + @query_params( + "bytes", + "time", + "size", + "format", + "h", + "health", + "help", + "local", + "master_timeout", + "pri", + "s", + "v", + ) def indices(self, index=None, params=None): """ The indices command provides a cross-section of each index. @@ -133,10 +148,11 @@ def indices(self, index=None, params=None): by :arg v: Verbose mode. Display column headers, default False """ - return self.transport.perform_request('GET', _make_path('_cat', - 'indices', index), params=params) + return self.transport.perform_request( + "GET", _make_path("_cat", "indices", index), params=params + ) - @query_params('format', 'h', 'help', 'local', 'master_timeout', 's', 'v') + @query_params("format", "h", "help", "local", "master_timeout", "s", "v") def master(self, params=None): """ Displays the master's node ID, bound IP address, and node name. @@ -153,11 +169,9 @@ def master(self, params=None): by :arg v: Verbose mode. Display column headers, default False """ - return self.transport.perform_request('GET', '/_cat/master', - params=params) + return self.transport.perform_request("GET", "/_cat/master", params=params) - @query_params('format', 'full_id', 'h', 'help', 'local', 'master_timeout', - 's', 'v') + @query_params("format", "full_id", "h", "help", "local", "master_timeout", "s", "v") def nodes(self, params=None): """ The nodes command shows the cluster topology. @@ -176,10 +190,11 @@ def nodes(self, params=None): by :arg v: Verbose mode. Display column headers, default False """ - return self.transport.perform_request('GET', '/_cat/nodes', - params=params) + return self.transport.perform_request("GET", "/_cat/nodes", params=params) - @query_params('bytes', 'time', 'size', 'format', 'h', 'help', 'master_timeout', 's', 'v') + @query_params( + "bytes", "time", "size", "format", "h", "help", "master_timeout", "s", "v" + ) def recovery(self, index=None, params=None): """ recovery is a view of shard replication. @@ -198,10 +213,22 @@ def recovery(self, index=None, params=None): by :arg v: Verbose mode. Display column headers, default False """ - return self.transport.perform_request('GET', _make_path('_cat', - 'recovery', index), params=params) - - @query_params('bytes', 'time', 'size', 'format', 'h', 'help', 'local', 'master_timeout', 's', 'v') + return self.transport.perform_request( + "GET", _make_path("_cat", "recovery", index), params=params + ) + + @query_params( + "bytes", + "time", + "size", + "format", + "h", + "help", + "local", + "master_timeout", + "s", + "v", + ) def shards(self, index=None, params=None): """ The shards command is the detailed view of what nodes contain which shards. @@ -222,10 +249,11 @@ def shards(self, index=None, params=None): by :arg v: Verbose mode. 
Display column headers, default False """ - return self.transport.perform_request('GET', _make_path('_cat', - 'shards', index), params=params) + return self.transport.perform_request( + "GET", _make_path("_cat", "shards", index), params=params + ) - @query_params('bytes', 'size', 'format', 'h', 'help', 's', 'v') + @query_params("bytes", "size", "format", "h", "help", "s", "v") def segments(self, index=None, params=None): """ The segments command is the detailed view of Lucene segments per index. @@ -242,10 +270,11 @@ def segments(self, index=None, params=None): by :arg v: Verbose mode. Display column headers, default False """ - return self.transport.perform_request('GET', _make_path('_cat', - 'segments', index), params=params) + return self.transport.perform_request( + "GET", _make_path("_cat", "segments", index), params=params + ) - @query_params('format', 'h', 'help', 'local', 'master_timeout', 's', 'v') + @query_params("format", "h", "help", "local", "master_timeout", "s", "v") def pending_tasks(self, params=None): """ pending_tasks provides the same information as the @@ -264,11 +293,11 @@ def pending_tasks(self, params=None): by :arg v: Verbose mode. Display column headers, default False """ - return self.transport.perform_request('GET', '/_cat/pending_tasks', - params=params) + return self.transport.perform_request( + "GET", "/_cat/pending_tasks", params=params + ) - @query_params('format', 'h', 'help', 'local', 'master_timeout', 's', 'size', - 'v') + @query_params("format", "h", "help", "local", "master_timeout", "s", "size", "v") def thread_pool(self, thread_pool_patterns=None, params=None): """ Get information about thread pools. @@ -289,11 +318,13 @@ def thread_pool(self, thread_pool_patterns=None, params=None): '', 'k', 'm', 'g', 't', 'p' :arg v: Verbose mode. Display column headers, default False """ - return self.transport.perform_request('GET', _make_path('_cat', - 'thread_pool', thread_pool_patterns), params=params) + return self.transport.perform_request( + "GET", + _make_path("_cat", "thread_pool", thread_pool_patterns), + params=params, + ) - @query_params('bytes', 'format', 'h', 'help', 'local', 'master_timeout', - 's', 'v') + @query_params("bytes", "format", "h", "help", "local", "master_timeout", "s", "v") def fielddata(self, fields=None, params=None): """ Shows information about currently loaded fielddata on a per-node basis. @@ -314,10 +345,11 @@ def fielddata(self, fields=None, params=None): by :arg v: Verbose mode. Display column headers, default False """ - return self.transport.perform_request('GET', _make_path('_cat', - 'fielddata', fields), params=params) + return self.transport.perform_request( + "GET", _make_path("_cat", "fielddata", fields), params=params + ) - @query_params('format', 'h', 'help', 'local', 'master_timeout', 's', 'v') + @query_params("format", "h", "help", "local", "master_timeout", "s", "v") def plugins(self, params=None): """ @@ -334,10 +366,9 @@ def plugins(self, params=None): by :arg v: Verbose mode. Display column headers, default False """ - return self.transport.perform_request('GET', '/_cat/plugins', - params=params) + return self.transport.perform_request("GET", "/_cat/plugins", params=params) - @query_params('format', 'h', 'help', 'local', 'master_timeout', 's', 'v') + @query_params("format", "h", "help", "local", "master_timeout", "s", "v") def nodeattrs(self, params=None): """ @@ -354,10 +385,9 @@ def nodeattrs(self, params=None): by :arg v: Verbose mode. 
Display column headers, default False """ - return self.transport.perform_request('GET', '/_cat/nodeattrs', - params=params) + return self.transport.perform_request("GET", "/_cat/nodeattrs", params=params) - @query_params('format', 'h', 'help', 'local', 'master_timeout', 's', 'v') + @query_params("format", "h", "help", "local", "master_timeout", "s", "v") def repositories(self, params=None): """ @@ -374,11 +404,13 @@ def repositories(self, params=None): by :arg v: Verbose mode. Display column headers, default False """ - return self.transport.perform_request('GET', '/_cat/repositories', - params=params) + return self.transport.perform_request( + "GET", "/_cat/repositories", params=params + ) - @query_params('format', 'h', 'help', 'ignore_unavailable', 'master_timeout', - 's', 'v') + @query_params( + "format", "h", "help", "ignore_unavailable", "master_timeout", "s", "v" + ) def snapshots(self, repository, params=None): """ @@ -399,11 +431,21 @@ def snapshots(self, repository, params=None): """ if repository in SKIP_IN_PATH: raise ValueError("Empty value passed for a required argument 'repository'.") - return self.transport.perform_request('GET', _make_path('_cat', - 'snapshots', repository), params=params) - - @query_params('actions', 'detailed', 'format', 'h', 'help', 'nodes', - 'parent_task_id', 's', 'v') + return self.transport.perform_request( + "GET", _make_path("_cat", "snapshots", repository), params=params + ) + + @query_params( + "actions", + "detailed", + "format", + "h", + "help", + "nodes", + "parent_task_id", + "s", + "v", + ) def tasks(self, params=None): """ @@ -425,10 +467,9 @@ def tasks(self, params=None): by :arg v: Verbose mode. Display column headers, default False """ - return self.transport.perform_request('GET', '/_cat/tasks', - params=params) + return self.transport.perform_request("GET", "/_cat/tasks", params=params) - @query_params('format', 'h', 'help', 'local', 'master_timeout', 's', 'v') + @query_params("format", "h", "help", "local", "master_timeout", "s", "v") def templates(self, name=None, params=None): """ ``_ @@ -445,6 +486,6 @@ def templates(self, name=None, params=None): by :arg v: Verbose mode. Display column headers, default False """ - return self.transport.perform_request('GET', _make_path('_cat', - 'templates', name), params=params) - + return self.transport.perform_request( + "GET", _make_path("_cat", "templates", name), params=params + ) diff --git a/elasticsearch/client/cluster.py b/elasticsearch/client/cluster.py index d1da9ea9b..be694e3e8 100644 --- a/elasticsearch/client/cluster.py +++ b/elasticsearch/client/cluster.py @@ -1,10 +1,19 @@ from .utils import NamespacedClient, query_params, _make_path + class ClusterClient(NamespacedClient): - @query_params('level', 'local', 'master_timeout', 'timeout', - 'wait_for_active_shards', 'wait_for_events', - 'wait_for_no_relocating_shards', 'wait_for_nodes', - 'wait_for_status', 'wait_for_no_initializing_shards') + @query_params( + "level", + "local", + "master_timeout", + "timeout", + "wait_for_active_shards", + "wait_for_events", + "wait_for_no_relocating_shards", + "wait_for_nodes", + "wait_for_status", + "wait_for_no_initializing_shards", + ) def health(self, index=None, params=None): """ Get a very simple status on the health of the cluster. 
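[review note] Illustrative usage, not part of the patch: each wait_for_* argument listed below travels through @query_params, so a typical readiness check is simply:

    from elasticsearch import Elasticsearch

    es = Elasticsearch()  # assumes a node on localhost:9200
    # block for up to 30 seconds until the cluster reaches at least yellow
    es.cluster.health(wait_for_status="yellow", timeout="30s")

The test helper later in this diff uses the same wait_for_status="yellow" call.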
@@ -30,10 +39,11 @@ def health(self, index=None, params=None): :arg wait_for_status: Wait until cluster is in a specific state, default None, valid choices are: 'green', 'yellow', 'red' """ - return self.transport.perform_request('GET', _make_path('_cluster', - 'health', index), params=params) + return self.transport.perform_request( + "GET", _make_path("_cluster", "health", index), params=params + ) - @query_params('local', 'master_timeout') + @query_params("local", "master_timeout") def pending_tasks(self, params=None): """ The pending cluster tasks API returns a list of any cluster-level @@ -45,11 +55,18 @@ def pending_tasks(self, params=None): master node (default: false) :arg master_timeout: Specify timeout for connection to master """ - return self.transport.perform_request('GET', - '/_cluster/pending_tasks', params=params) - - @query_params('allow_no_indices', 'expand_wildcards', 'flat_settings', - 'ignore_unavailable', 'local', 'master_timeout') + return self.transport.perform_request( + "GET", "/_cluster/pending_tasks", params=params + ) + + @query_params( + "allow_no_indices", + "expand_wildcards", + "flat_settings", + "ignore_unavailable", + "local", + "master_timeout", + ) def state(self, metric=None, index=None, params=None): """ Get a comprehensive state information of the whole cluster. @@ -72,11 +89,12 @@ def state(self, metric=None, index=None, params=None): :arg master_timeout: Specify timeout for connection to master """ if index and not metric: - metric = '_all' - return self.transport.perform_request('GET', _make_path('_cluster', - 'state', metric, index), params=params) + metric = "_all" + return self.transport.perform_request( + "GET", _make_path("_cluster", "state", metric, index), params=params + ) - @query_params('flat_settings', 'timeout') + @query_params("flat_settings", "timeout") def stats(self, node_id=None, params=None): """ The Cluster Stats API allows to retrieve statistics from a cluster wide @@ -91,13 +109,14 @@ def stats(self, node_id=None, params=None): :arg flat_settings: Return settings in flat format (default: false) :arg timeout: Explicit operation timeout """ - url = '/_cluster/stats' + url = "/_cluster/stats" if node_id: - url = _make_path('_cluster/stats/nodes', node_id) - return self.transport.perform_request('GET', url, params=params) + url = _make_path("_cluster/stats/nodes", node_id) + return self.transport.perform_request("GET", url, params=params) - @query_params('dry_run', 'explain', 'master_timeout', 'metric', - 'retry_failed', 'timeout') + @query_params( + "dry_run", "explain", "master_timeout", "metric", "retry_failed", "timeout" + ) def reroute(self, body=None, params=None): """ Explicitly execute a cluster reroute allocation command including specific commands. @@ -117,11 +136,11 @@ def reroute(self, body=None, params=None): too many subsequent allocation failures :arg timeout: Explicit operation timeout """ - return self.transport.perform_request('POST', '/_cluster/reroute', - params=params, body=body) + return self.transport.perform_request( + "POST", "/_cluster/reroute", params=params, body=body + ) - @query_params('flat_settings', 'include_defaults', 'master_timeout', - 'timeout') + @query_params("flat_settings", "include_defaults", "master_timeout", "timeout") def get_settings(self, params=None): """ Get cluster settings. 
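[review note] For context on the settings pair re-wrapped below, a hedged example reusing the es client from the previous note (assumes sufficient privileges; cluster.routing.allocation.enable is a standard Elasticsearch setting, not something introduced by this diff):

    # whether an update is transient or persistent is expressed in the
    # request body, not in the client signature
    es.cluster.put_settings(
        body={"transient": {"cluster.routing.allocation.enable": "none"}}
    )
    settings = es.cluster.get_settings(flat_settings=True)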
@@ -134,10 +153,11 @@ def get_settings(self, params=None): node :arg timeout: Explicit operation timeout """ - return self.transport.perform_request('GET', '/_cluster/settings', - params=params) + return self.transport.perform_request( + "GET", "/_cluster/settings", params=params + ) - @query_params('flat_settings', 'master_timeout', 'timeout') + @query_params("flat_settings", "master_timeout", "timeout") def put_settings(self, body=None, params=None): """ Update cluster wide specific settings. @@ -150,10 +170,11 @@ def put_settings(self, body=None, params=None): node :arg timeout: Explicit operation timeout """ - return self.transport.perform_request('PUT', '/_cluster/settings', - params=params, body=body) + return self.transport.perform_request( + "PUT", "/_cluster/settings", params=params, body=body + ) - @query_params('include_disk_info', 'include_yes_decisions') + @query_params("include_disk_info", "include_yes_decisions") def allocation_explain(self, body=None, params=None): """ ``_ @@ -165,6 +186,6 @@ def allocation_explain(self, body=None, params=None): :arg include_yes_decisions: Return 'YES' decisions in explanation (default: false) """ - return self.transport.perform_request('GET', - '/_cluster/allocation/explain', params=params, body=body) - + return self.transport.perform_request( + "GET", "/_cluster/allocation/explain", params=params, body=body + ) diff --git a/elasticsearch/client/indices.py b/elasticsearch/client/indices.py index 8d33fd38d..d371dde4d 100644 --- a/elasticsearch/client/indices.py +++ b/elasticsearch/client/indices.py @@ -79,7 +79,10 @@ def flush(self, index=None, params=None): ) @query_params( - "master_timeout", "request_timeout", "wait_for_active_shards", "include_type_name" + "master_timeout", + "request_timeout", + "wait_for_active_shards", + "include_type_name", ) def create(self, index, body=None, params=None): """ diff --git a/elasticsearch/client/ingest.py b/elasticsearch/client/ingest.py index 4a2def47e..c3144e6ed 100644 --- a/elasticsearch/client/ingest.py +++ b/elasticsearch/client/ingest.py @@ -1,7 +1,8 @@ from .utils import NamespacedClient, query_params, _make_path, SKIP_IN_PATH + class IngestClient(NamespacedClient): - @query_params('master_timeout') + @query_params("master_timeout") def get_pipeline(self, id=None, params=None): """ ``_ @@ -10,10 +11,11 @@ def get_pipeline(self, id=None, params=None): :arg master_timeout: Explicit operation timeout for connection to master node """ - return self.transport.perform_request('GET', _make_path('_ingest', - 'pipeline', id), params=params) + return self.transport.perform_request( + "GET", _make_path("_ingest", "pipeline", id), params=params + ) - @query_params('master_timeout', 'timeout') + @query_params("master_timeout", "timeout") def put_pipeline(self, id, body, params=None): """ ``_ @@ -27,10 +29,11 @@ def put_pipeline(self, id, body, params=None): for param in (id, body): if param in SKIP_IN_PATH: raise ValueError("Empty value passed for a required argument.") - return self.transport.perform_request('PUT', _make_path('_ingest', - 'pipeline', id), params=params, body=body) + return self.transport.perform_request( + "PUT", _make_path("_ingest", "pipeline", id), params=params, body=body + ) - @query_params('master_timeout', 'timeout') + @query_params("master_timeout", "timeout") def delete_pipeline(self, id, params=None): """ ``_ @@ -42,10 +45,11 @@ def delete_pipeline(self, id, params=None): """ if id in SKIP_IN_PATH: raise ValueError("Empty value passed for a required argument 'id'.") - return 
self.transport.perform_request('DELETE', _make_path('_ingest', - 'pipeline', id), params=params) + return self.transport.perform_request( + "DELETE", _make_path("_ingest", "pipeline", id), params=params + ) - @query_params('verbose') + @query_params("verbose") def simulate(self, body, id=None, params=None): """ ``_ @@ -57,5 +61,9 @@ def simulate(self, body, id=None, params=None): """ if body in SKIP_IN_PATH: raise ValueError("Empty value passed for a required argument 'body'.") - return self.transport.perform_request('GET', _make_path('_ingest', - 'pipeline', id, '_simulate'), params=params, body=body) + return self.transport.perform_request( + "GET", + _make_path("_ingest", "pipeline", id, "_simulate"), + params=params, + body=body, + ) diff --git a/elasticsearch/client/remote.py b/elasticsearch/client/remote.py index 9cbb26070..7f859823e 100644 --- a/elasticsearch/client/remote.py +++ b/elasticsearch/client/remote.py @@ -1,11 +1,10 @@ from .utils import NamespacedClient, query_params, _make_path, SKIP_IN_PATH + class RemoteClient(NamespacedClient): @query_params() def info(self, params=None): """ ``_ """ - return self.transport.perform_request('GET', '/_remote/info', - params=params) - + return self.transport.perform_request("GET", "/_remote/info", params=params) diff --git a/elasticsearch/client/snapshot.py b/elasticsearch/client/snapshot.py index a57cebedf..88aed52c3 100644 --- a/elasticsearch/client/snapshot.py +++ b/elasticsearch/client/snapshot.py @@ -1,7 +1,8 @@ from .utils import NamespacedClient, query_params, _make_path, SKIP_IN_PATH + class SnapshotClient(NamespacedClient): - @query_params('master_timeout', 'wait_for_completion') + @query_params("master_timeout", "wait_for_completion") def create(self, repository, snapshot, body=None, params=None): """ Create a snapshot in repository @@ -18,10 +19,14 @@ def create(self, repository, snapshot, body=None, params=None): for param in (repository, snapshot): if param in SKIP_IN_PATH: raise ValueError("Empty value passed for a required argument.") - return self.transport.perform_request('PUT', _make_path('_snapshot', - repository, snapshot), params=params, body=body) - - @query_params('master_timeout') + return self.transport.perform_request( + "PUT", + _make_path("_snapshot", repository, snapshot), + params=params, + body=body, + ) + + @query_params("master_timeout") def delete(self, repository, snapshot, params=None): """ Deletes a snapshot from a repository. @@ -35,10 +40,11 @@ def delete(self, repository, snapshot, params=None): for param in (repository, snapshot): if param in SKIP_IN_PATH: raise ValueError("Empty value passed for a required argument.") - return self.transport.perform_request('DELETE', - _make_path('_snapshot', repository, snapshot), params=params) + return self.transport.perform_request( + "DELETE", _make_path("_snapshot", repository, snapshot), params=params + ) - @query_params('ignore_unavailable', 'master_timeout', 'verbose') + @query_params("ignore_unavailable", "master_timeout", "verbose") def get(self, repository, snapshot, params=None): """ Retrieve information about a snapshot. 
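[review note] An end-to-end sketch of the snapshot methods touched in this file, again reusing es. The repository name and filesystem location are hypothetical, and an "fs" repository also requires the location to be whitelisted via path.repo in elasticsearch.yml:

    es.snapshot.create_repository(
        repository="backups",
        body={"type": "fs", "settings": {"location": "/mnt/backups"}},
    )
    es.snapshot.create(
        repository="backups", snapshot="snap-1", wait_for_completion=True
    )
    es.snapshot.get(repository="backups", snapshot="snap-1")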
@@ -56,10 +62,11 @@ def get(self, repository, snapshot, params=None): for param in (repository, snapshot): if param in SKIP_IN_PATH: raise ValueError("Empty value passed for a required argument.") - return self.transport.perform_request('GET', _make_path('_snapshot', - repository, snapshot), params=params) + return self.transport.perform_request( + "GET", _make_path("_snapshot", repository, snapshot), params=params + ) - @query_params('master_timeout', 'timeout') + @query_params("master_timeout", "timeout") def delete_repository(self, repository, params=None): """ Removes a shared file system repository. @@ -72,10 +79,11 @@ def delete_repository(self, repository, params=None): """ if repository in SKIP_IN_PATH: raise ValueError("Empty value passed for a required argument 'repository'.") - return self.transport.perform_request('DELETE', - _make_path('_snapshot', repository), params=params) + return self.transport.perform_request( + "DELETE", _make_path("_snapshot", repository), params=params + ) - @query_params('local', 'master_timeout') + @query_params("local", "master_timeout") def get_repository(self, repository=None, params=None): """ Return information about registered repositories. @@ -87,10 +95,11 @@ def get_repository(self, repository=None, params=None): :arg master_timeout: Explicit operation timeout for connection to master node """ - return self.transport.perform_request('GET', _make_path('_snapshot', - repository), params=params) + return self.transport.perform_request( + "GET", _make_path("_snapshot", repository), params=params + ) - @query_params('master_timeout', 'timeout', 'verify') + @query_params("master_timeout", "timeout", "verify") def create_repository(self, repository, body, params=None): """ Registers a shared file system repository. @@ -106,10 +115,11 @@ def create_repository(self, repository, body, params=None): for param in (repository, body): if param in SKIP_IN_PATH: raise ValueError("Empty value passed for a required argument.") - return self.transport.perform_request('PUT', _make_path('_snapshot', - repository), params=params, body=body) + return self.transport.perform_request( + "PUT", _make_path("_snapshot", repository), params=params, body=body + ) - @query_params('master_timeout', 'wait_for_completion') + @query_params("master_timeout", "wait_for_completion") def restore(self, repository, snapshot, body=None, params=None): """ Restore a snapshot. @@ -126,10 +136,14 @@ def restore(self, repository, snapshot, body=None, params=None): for param in (repository, snapshot): if param in SKIP_IN_PATH: raise ValueError("Empty value passed for a required argument.") - return self.transport.perform_request('POST', _make_path('_snapshot', - repository, snapshot, '_restore'), params=params, body=body) - - @query_params('ignore_unavailable', 'master_timeout') + return self.transport.perform_request( + "POST", + _make_path("_snapshot", repository, snapshot, "_restore"), + params=params, + body=body, + ) + + @query_params("ignore_unavailable", "master_timeout") def status(self, repository=None, snapshot=None, params=None): """ Return information about all currently running snapshots. 
By specifying @@ -144,10 +158,13 @@ def status(self, repository=None, snapshot=None, params=None): :arg master_timeout: Explicit operation timeout for connection to master node """ - return self.transport.perform_request('GET', _make_path('_snapshot', - repository, snapshot, '_status'), params=params) + return self.transport.perform_request( + "GET", + _make_path("_snapshot", repository, snapshot, "_status"), + params=params, + ) - @query_params('master_timeout', 'timeout') + @query_params("master_timeout", "timeout") def verify_repository(self, repository, params=None): """ Returns a list of nodes where repository was successfully verified or @@ -161,5 +178,6 @@ def verify_repository(self, repository, params=None): """ if repository in SKIP_IN_PATH: raise ValueError("Empty value passed for a required argument 'repository'.") - return self.transport.perform_request('POST', _make_path('_snapshot', - repository, '_verify'), params=params) + return self.transport.perform_request( + "POST", _make_path("_snapshot", repository, "_verify"), params=params + ) diff --git a/elasticsearch/client/tasks.py b/elasticsearch/client/tasks.py index 8056bfdd9..d98d3bbdc 100644 --- a/elasticsearch/client/tasks.py +++ b/elasticsearch/client/tasks.py @@ -1,8 +1,16 @@ from .utils import NamespacedClient, query_params, _make_path, SKIP_IN_PATH + class TasksClient(NamespacedClient): - @query_params('actions', 'detailed', 'group_by', 'nodes', - 'parent_task_id', 'wait_for_completion', 'timeout') + @query_params( + "actions", + "detailed", + "group_by", + "nodes", + "parent_task_id", + "wait_for_completion", + "timeout", + ) def list(self, params=None): """ ``_ @@ -22,9 +30,9 @@ def list(self, params=None): (default: false) :arg timeout: Maximum waiting time for `wait_for_completion` """ - return self.transport.perform_request('GET', '/_tasks', params=params) + return self.transport.perform_request("GET", "/_tasks", params=params) - @query_params('actions', 'nodes', 'parent_task_id') + @query_params("actions", "nodes", "parent_task_id") def cancel(self, task_id=None, params=None): """ @@ -41,10 +49,11 @@ def cancel(self, task_id=None, params=None): :arg parent_task_id: Cancel tasks with specified parent task id (node_id:task_number). Set to -1 to cancel all. """ - return self.transport.perform_request('POST', _make_path('_tasks', - task_id, '_cancel'), params=params) + return self.transport.perform_request( + "POST", _make_path("_tasks", task_id, "_cancel"), params=params + ) - @query_params('wait_for_completion', 'timeout') + @query_params("wait_for_completion", "timeout") def get(self, task_id=None, params=None): """ Retrieve information for a particular task. 
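[review note] Illustrative only: the node_id:task_number id format comes from the cancel docstring above, and the concrete id below is made up.

    tasks = es.tasks.list(actions="*search", detailed=True)
    es.tasks.cancel(task_id="oTUltX4IQMOUUVeiohTt8A:12345")  # hypothetical id
    info = es.tasks.get(task_id="oTUltX4IQMOUUVeiohTt8A:12345")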
@@ -55,5 +64,6 @@ def get(self, task_id=None, params=None): (default: false) :arg timeout: Maximum waiting time for `wait_for_completion` """ - return self.transport.perform_request('GET', _make_path('_tasks', - task_id), params=params) + return self.transport.perform_request( + "GET", _make_path("_tasks", task_id), params=params + ) diff --git a/elasticsearch/compat.py b/elasticsearch/compat.py index c3d5fa287..247c5d6e0 100644 --- a/elasticsearch/compat.py +++ b/elasticsearch/compat.py @@ -3,13 +3,14 @@ PY2 = sys.version_info[0] == 2 if PY2: - string_types = basestring, + string_types = (basestring,) from urllib import quote_plus, urlencode, unquote - from urlparse import urlparse + from urlparse import urlparse from itertools import imap as map from Queue import Queue else: string_types = str, bytes from urllib.parse import quote_plus, urlencode, urlparse, unquote + map = map from queue import Queue diff --git a/elasticsearch/connection/__init__.py b/elasticsearch/connection/__init__.py index 6eb68967d..2470bd39b 100644 --- a/elasticsearch/connection/__init__.py +++ b/elasticsearch/connection/__init__.py @@ -1,4 +1,3 @@ from .base import Connection from .http_requests import RequestsHttpConnection from .http_urllib3 import Urllib3HttpConnection, create_ssl_context - diff --git a/elasticsearch/connection/base.py b/elasticsearch/connection/base.py index d4fbe33de..4404ed7f1 100644 --- a/elasticsearch/connection/base.py +++ b/elasticsearch/connection/base.py @@ -1,4 +1,5 @@ import logging + try: import simplejson as json except ImportError: @@ -6,12 +7,12 @@ from ..exceptions import TransportError, HTTP_EXCEPTIONS -logger = logging.getLogger('elasticsearch') +logger = logging.getLogger("elasticsearch") # create the elasticsearch.trace logger, but only set propagate to False if the # logger hasn't already been configured -_tracer_already_configured = 'elasticsearch.trace' in logging.Logger.manager.loggerDict -tracer = logging.getLogger('elasticsearch.trace') +_tracer_already_configured = "elasticsearch.trace" in logging.Logger.manager.loggerDict +tracer = logging.getLogger("elasticsearch.trace") if not _tracer_already_configured: tracer.propagate = False @@ -24,32 +25,43 @@ class Connection(object): Also responsible for logging. 
""" - def __init__(self, host='localhost', port=9200, use_ssl=False, url_prefix='', timeout=10, **kwargs): + + def __init__( + self, + host="localhost", + port=9200, + use_ssl=False, + url_prefix="", + timeout=10, + **kwargs + ): """ :arg host: hostname of the node (default: localhost) :arg port: port to use (integer, default: 9200) :arg url_prefix: optional url prefix for elasticsearch :arg timeout: default timeout in seconds (float, default: 10) """ - scheme = kwargs.get('scheme', 'http') - if use_ssl or scheme == 'https': - scheme = 'https' + scheme = kwargs.get("scheme", "http") + if use_ssl or scheme == "https": + scheme = "https" use_ssl = True self.use_ssl = use_ssl - self.host = '%s://%s:%s' % (scheme, host, port) + self.host = "%s://%s:%s" % (scheme, host, port) if url_prefix: - url_prefix = '/' + url_prefix.strip('/') + url_prefix = "/" + url_prefix.strip("/") self.url_prefix = url_prefix self.timeout = timeout def __repr__(self): - return '<%s: %s>' % (self.__class__.__name__, self.host) + return "<%s: %s>" % (self.__class__.__name__, self.host) def _pretty_json(self, data): # pretty JSON in tracer curl logs try: - return json.dumps(json.loads(data), sort_keys=True, indent=2, separators=(',', ': ')).replace("'", r'\u0027') + return json.dumps( + json.loads(data), sort_keys=True, indent=2, separators=(",", ": ") + ).replace("'", r"\u0027") except (ValueError, TypeError): # non-json data or a bulk request return data @@ -59,17 +71,28 @@ def _log_trace(self, method, path, body, status_code, response, duration): return # include pretty in trace curls - path = path.replace('?', '?pretty&', 1) if '?' in path else path + '?pretty' + path = path.replace("?", "?pretty&", 1) if "?" in path else path + "?pretty" if self.url_prefix: - path = path.replace(self.url_prefix, '', 1) - tracer.info("curl %s-X%s 'http://localhost:9200%s' -d '%s'", - "-H 'Content-Type: application/json' " if body else '', - method, path, self._pretty_json(body) if body else '') + path = path.replace(self.url_prefix, "", 1) + tracer.info( + "curl %s-X%s 'http://localhost:9200%s' -d '%s'", + "-H 'Content-Type: application/json' " if body else "", + method, + path, + self._pretty_json(body) if body else "", + ) if tracer.isEnabledFor(logging.DEBUG): - tracer.debug('#[%s] (%.3fs)\n#%s', status_code, duration, self._pretty_json(response).replace('\n', '\n#') if response else '') - - def log_request_success(self, method, full_url, path, body, status_code, response, duration): + tracer.debug( + "#[%s] (%.3fs)\n#%s", + status_code, + duration, + self._pretty_json(response).replace("\n", "\n#") if response else "", + ) + + def log_request_success( + self, method, full_url, path, body, status_code, response, duration + ): """ Log a successful API call. 
""" # TODO: optionally pass in params instead of full_url and do urlencode only when needed @@ -77,43 +100,56 @@ def log_request_success(self, method, full_url, path, body, status_code, respons # TODO: find a better way to avoid (de)encoding the body back and forth if body: try: - body = body.decode('utf-8', 'ignore') + body = body.decode("utf-8", "ignore") except AttributeError: pass logger.info( - '%s %s [status:%s request:%.3fs]', method, full_url, - status_code, duration + "%s %s [status:%s request:%.3fs]", method, full_url, status_code, duration ) - logger.debug('> %s', body) - logger.debug('< %s', response) + logger.debug("> %s", body) + logger.debug("< %s", response) self._log_trace(method, path, body, status_code, response, duration) - def log_request_fail(self, method, full_url, path, body, duration, status_code=None, response=None, exception=None): + def log_request_fail( + self, + method, + full_url, + path, + body, + duration, + status_code=None, + response=None, + exception=None, + ): """ Log an unsuccessful API call. """ # do not log 404s on HEAD requests - if method == 'HEAD' and status_code == 404: + if method == "HEAD" and status_code == 404: return logger.warning( - '%s %s [status:%s request:%.3fs]', method, full_url, - status_code or 'N/A', duration, exc_info=exception is not None + "%s %s [status:%s request:%.3fs]", + method, + full_url, + status_code or "N/A", + duration, + exc_info=exception is not None, ) # body has already been serialized to utf-8, deserialize it for logging # TODO: find a better way to avoid (de)encoding the body back and forth if body: try: - body = body.decode('utf-8', 'ignore') + body = body.decode("utf-8", "ignore") except AttributeError: pass - logger.debug('> %s', body) + logger.debug("> %s", body) self._log_trace(method, path, body, status_code, response, duration) if response is not None: - logger.debug('< %s', response) + logger.debug("< %s", response) def _raise_error(self, status_code, raw_data): """ Locate appropriate exception and raise it. 
""" @@ -122,12 +158,12 @@ def _raise_error(self, status_code, raw_data): try: if raw_data: additional_info = json.loads(raw_data) - error_message = additional_info.get('error', error_message) - if isinstance(error_message, dict) and 'type' in error_message: - error_message = error_message['type'] + error_message = additional_info.get("error", error_message) + if isinstance(error_message, dict) and "type" in error_message: + error_message = error_message["type"] except (ValueError, TypeError) as err: - logger.warning('Undecodable raw error response from server: %s', err) - - raise HTTP_EXCEPTIONS.get(status_code, TransportError)(status_code, error_message, additional_info) - + logger.warning("Undecodable raw error response from server: %s", err) + raise HTTP_EXCEPTIONS.get(status_code, TransportError)( + status_code, error_message, additional_info + ) diff --git a/elasticsearch/connection/http_requests.py b/elasticsearch/connection/http_requests.py index 1e590fbc3..3a0a3c00f 100644 --- a/elasticsearch/connection/http_requests.py +++ b/elasticsearch/connection/http_requests.py @@ -1,15 +1,23 @@ import time import warnings + try: import requests + REQUESTS_AVAILABLE = True except ImportError: REQUESTS_AVAILABLE = False from .base import Connection -from ..exceptions import ConnectionError, ImproperlyConfigured, ConnectionTimeout, SSLError +from ..exceptions import ( + ConnectionError, + ImproperlyConfigured, + ConnectionTimeout, + SSLError, +) from ..compat import urlencode, string_types + class RequestsHttpConnection(Connection): """ Connection using the `requests` library. @@ -27,25 +35,43 @@ class RequestsHttpConnection(Connection): separate cert and key files (client_cert will contain only the cert) :arg headers: any custom http headers to be add to requests """ - def __init__(self, host='localhost', port=9200, http_auth=None, - use_ssl=False, verify_certs=True, ssl_show_warn=True, ca_certs=None, client_cert=None, - client_key=None, headers=None, **kwargs): + + def __init__( + self, + host="localhost", + port=9200, + http_auth=None, + use_ssl=False, + verify_certs=True, + ssl_show_warn=True, + ca_certs=None, + client_cert=None, + client_key=None, + headers=None, + **kwargs + ): if not REQUESTS_AVAILABLE: - raise ImproperlyConfigured("Please install requests to use RequestsHttpConnection.") + raise ImproperlyConfigured( + "Please install requests to use RequestsHttpConnection." 
+ ) - super(RequestsHttpConnection, self).__init__(host=host, port=port, use_ssl=use_ssl, **kwargs) + super(RequestsHttpConnection, self).__init__( + host=host, port=port, use_ssl=use_ssl, **kwargs + ) self.session = requests.Session() self.session.headers = headers or {} - self.session.headers.setdefault('content-type', 'application/json') + self.session.headers.setdefault("content-type", "application/json") if http_auth is not None: if isinstance(http_auth, (tuple, list)): http_auth = tuple(http_auth) elif isinstance(http_auth, string_types): - http_auth = tuple(http_auth.split(':', 1)) + http_auth = tuple(http_auth.split(":", 1)) self.session.auth = http_auth - self.base_url = 'http%s://%s:%d%s' % ( - 's' if self.use_ssl else '', - host, port, self.url_prefix + self.base_url = "http%s://%s:%d%s" % ( + "s" if self.use_ssl else "", + host, + port, + self.url_prefix, ) self.session.verify = verify_certs if not client_key: @@ -55,42 +81,76 @@ def __init__(self, host='localhost', port=9200, http_auth=None, self.session.cert = (client_cert, client_key) if ca_certs: if not verify_certs: - raise ImproperlyConfigured("You cannot pass CA certificates when verify SSL is off.") + raise ImproperlyConfigured( + "You cannot pass CA certificates when verify SSL is off." + ) self.session.verify = ca_certs if self.use_ssl and not verify_certs and ssl_show_warn: warnings.warn( - 'Connecting to %s using SSL with verify_certs=False is insecure.' % self.base_url) + "Connecting to %s using SSL with verify_certs=False is insecure." + % self.base_url + ) - def perform_request(self, method, url, params=None, body=None, timeout=None, ignore=(), headers=None): + def perform_request( + self, method, url, params=None, body=None, timeout=None, ignore=(), headers=None + ): url = self.base_url + url if params: - url = '%s?%s' % (url, urlencode(params or {})) + url = "%s?%s" % (url, urlencode(params or {})) start = time.time() request = requests.Request(method=method, headers=headers, url=url, data=body) prepared_request = self.session.prepare_request(request) - settings = self.session.merge_environment_settings(prepared_request.url, {}, None, None, None) - send_kwargs = {'timeout': timeout or self.timeout} + settings = self.session.merge_environment_settings( + prepared_request.url, {}, None, None, None + ) + send_kwargs = {"timeout": timeout or self.timeout} send_kwargs.update(settings) try: response = self.session.send(prepared_request, **send_kwargs) duration = time.time() - start raw_data = response.text except Exception as e: - self.log_request_fail(method, url, prepared_request.path_url, body, time.time() - start, exception=e) + self.log_request_fail( + method, + url, + prepared_request.path_url, + body, + time.time() - start, + exception=e, + ) if isinstance(e, requests.exceptions.SSLError): - raise SSLError('N/A', str(e), e) + raise SSLError("N/A", str(e), e) if isinstance(e, requests.Timeout): - raise ConnectionTimeout('TIMEOUT', str(e), e) - raise ConnectionError('N/A', str(e), e) + raise ConnectionTimeout("TIMEOUT", str(e), e) + raise ConnectionError("N/A", str(e), e) # raise errors based on http status codes, let the client handle those if needed - if not (200 <= response.status_code < 300) and response.status_code not in ignore: - self.log_request_fail(method, url, response.request.path_url, body, duration, response.status_code, raw_data) + if ( + not (200 <= response.status_code < 300) + and response.status_code not in ignore + ): + self.log_request_fail( + method, + url, + response.request.path_url, + 
body, + duration, + response.status_code, + raw_data, + ) self._raise_error(response.status_code, raw_data) - self.log_request_success(method, url, response.request.path_url, body, response.status_code, raw_data, duration) + self.log_request_success( + method, + url, + response.request.path_url, + body, + response.status_code, + raw_data, + duration, + ) return response.status_code, response.headers, raw_data diff --git a/elasticsearch/connection/pooling.py b/elasticsearch/connection/pooling.py index 3115e7c2c..dd5431e15 100644 --- a/elasticsearch/connection/pooling.py +++ b/elasticsearch/connection/pooling.py @@ -12,6 +12,7 @@ class PoolingConnection(Connection): ``_make_connection`` method that constructs a new connection and returns it. """ + def __init__(self, *args, **kwargs): self._free_connections = queue.Queue() super(PoolingConnection, self).__init__(*args, **kwargs) @@ -30,4 +31,3 @@ def close(self): Explicitly close connection """ pass - diff --git a/elasticsearch/connection_pool.py b/elasticsearch/connection_pool.py index c28bb3ccc..05277632c 100644 --- a/elasticsearch/connection_pool.py +++ b/elasticsearch/connection_pool.py @@ -10,7 +10,8 @@ from .exceptions import ImproperlyConfigured -logger = logging.getLogger('elasticsearch') +logger = logging.getLogger("elasticsearch") + class ConnectionSelector(object): """ @@ -30,6 +31,7 @@ class ConnectionSelector(object): only select connections from it's own zones and only fall back to other connections where there would be none in it's zones. """ + def __init__(self, opts): """ :arg opts: dictionary of connection instances and their options @@ -49,6 +51,7 @@ class RandomSelector(ConnectionSelector): """ Select a connection at random """ + def select(self, connections): return random.choice(connections) @@ -57,15 +60,17 @@ class RoundRobinSelector(ConnectionSelector): """ Selector using round-robin. """ + def __init__(self, opts): super(RoundRobinSelector, self).__init__(opts) self.data = threading.local() def select(self, connections): - self.data.rr = getattr(self.data, 'rr', -1) + 1 + self.data.rr = getattr(self.data, "rr", -1) + 1 self.data.rr %= len(connections) return connections[self.data.rr] + class ConnectionPool(object): """ Container holding the :class:`~elasticsearch.Connection` instances, @@ -88,8 +93,16 @@ class ConnectionPool(object): live pool. A connection that has been previously marked as dead and succeeds will be marked as live (its fail count will be deleted). """ - def __init__(self, connections, dead_timeout=60, timeout_cutoff=5, - selector_class=RoundRobinSelector, randomize_hosts=True, **kwargs): + + def __init__( + self, + connections, + dead_timeout=60, + timeout_cutoff=5, + selector_class=RoundRobinSelector, + randomize_hosts=True, + **kwargs + ): """ :arg connections: list of tuples containing the :class:`~elasticsearch.Connection` instance and it's options @@ -103,8 +116,9 @@ def __init__(self, connections, dead_timeout=60, timeout_cutoff=5, avoid dog piling effect across processes """ if not connections: - raise ImproperlyConfigured("No defined connections, you need to " - "specify at least one host.") + raise ImproperlyConfigured( + "No defined connections, you need to " "specify at least one host." 
+ ) self.connection_opts = connections self.connections = [c for (c, opts) in connections] # remember original connection list for resurrect(force=True) @@ -144,8 +158,10 @@ def mark_dead(self, connection, now=None): timeout = self.dead_timeout * 2 ** min(dead_count - 1, self.timeout_cutoff) self.dead.put((now + timeout, connection)) logger.warning( - 'Connection %r has failed for %i times in a row, putting on %i second timeout.', - connection, dead_count, timeout + "Connection %r has failed for %i times in a row, putting on %i second timeout.", + connection, + dead_count, + timeout, ) def mark_live(self, connection): @@ -200,7 +216,7 @@ def resurrect(self, force=False): # either we were forced or the connection is elligible to be retried self.connections.append(connection) - logger.info('Resurrecting connection %r (force=%s).', connection, force) + logger.info("Resurrecting connection %r (force=%s).", connection, force) return connection def get_connection(self): @@ -235,15 +251,17 @@ def close(self): for conn in self.orig_connections: conn.close() + class DummyConnectionPool(ConnectionPool): def __init__(self, connections, **kwargs): if len(connections) != 1: - raise ImproperlyConfigured("DummyConnectionPool needs exactly one " - "connection defined.") + raise ImproperlyConfigured( + "DummyConnectionPool needs exactly one " "connection defined." + ) # we need connection opts for sniffing logic self.connection_opts = connections self.connection = connections[0][0] - self.connections = (self.connection, ) + self.connections = (self.connection,) def get_connection(self): return self.connection @@ -256,6 +274,5 @@ def close(self): def _noop(self, *args, **kwargs): pass - mark_dead = mark_live = resurrect = _noop - + mark_dead = mark_live = resurrect = _noop diff --git a/elasticsearch/exceptions.py b/elasticsearch/exceptions.py index 00f1d01ab..c7d83b314 100644 --- a/elasticsearch/exceptions.py +++ b/elasticsearch/exceptions.py @@ -1,7 +1,16 @@ __all__ = [ - 'ImproperlyConfigured', 'ElasticsearchException', 'SerializationError', - 'TransportError', 'NotFoundError', 'ConflictError', 'RequestError', 'ConnectionError', - 'SSLError', 'ConnectionTimeout', 'AuthenticationException', 'AuthorizationException' + "ImproperlyConfigured", + "ElasticsearchException", + "SerializationError", + "TransportError", + "NotFoundError", + "ConflictError", + "RequestError", + "ConnectionError", + "SSLError", + "ConnectionTimeout", + "AuthenticationException", + "AuthorizationException", ] @@ -31,6 +40,7 @@ class TransportError(ElasticsearchException): an actual connection error happens; in that case the ``status_code`` will be set to ``'N/A'``. 
""" + @property def status_code(self): """ @@ -53,20 +63,28 @@ def info(self): return self.args[2] def __str__(self): - cause = '' + cause = "" try: - if self.info and 'error' in self.info: - if isinstance(self.info['error'], dict): - root_cause = self.info['error']['root_cause'][0] - cause = ', '.join(filter(None, [repr(root_cause['reason']), root_cause.get('resource.id'), - root_cause.get('resource.type')])) + if self.info and "error" in self.info: + if isinstance(self.info["error"], dict): + root_cause = self.info["error"]["root_cause"][0] + cause = ", ".join( + filter( + None, + [ + repr(root_cause["reason"]), + root_cause.get("resource.id"), + root_cause.get("resource.type"), + ], + ) + ) else: - cause = repr(self.info['error']) + cause = repr(self.info["error"]) except LookupError: pass - msg = ', '.join(filter(None, [str(self.status_code), repr(self.error), cause])) - return '%s(%s)' % (self.__class__.__name__, msg) + msg = ", ".join(filter(None, [str(self.status_code), repr(self.error), cause])) + return "%s(%s)" % (self.__class__.__name__, msg) class ConnectionError(TransportError): @@ -77,8 +95,11 @@ class ConnectionError(TransportError): """ def __str__(self): - return 'ConnectionError(%s) caused by: %s(%s)' % ( - self.error, self.info.__class__.__name__, self.info) + return "ConnectionError(%s) caused by: %s(%s)" % ( + self.error, + self.info.__class__.__name__, + self.info, + ) class SSLError(ConnectionError): @@ -89,8 +110,10 @@ class ConnectionTimeout(ConnectionError): """ A network timeout. Doesn't cause a node retry by default. """ def __str__(self): - return 'ConnectionTimeout caused by - %s(%s)' % ( - self.info.__class__.__name__, self.info) + return "ConnectionTimeout caused by - %s(%s)" % ( + self.info.__class__.__name__, + self.info, + ) class NotFoundError(TransportError): diff --git a/elasticsearch/helpers/__init__.py b/elasticsearch/helpers/__init__.py index c5c81be66..848138bf0 100644 --- a/elasticsearch/helpers/__init__.py +++ b/elasticsearch/helpers/__init__.py @@ -1,8 +1,4 @@ - - from .errors import BulkIndexError, ScanError from .actions import expand_action, streaming_bulk, bulk, parallel_bulk from .actions import scan, reindex from .actions import _chunk_actions, _process_bulk_chunk - - diff --git a/elasticsearch/helpers/actions.py b/elasticsearch/helpers/actions.py index bfdf64187..1d233b464 100644 --- a/elasticsearch/helpers/actions.py +++ b/elasticsearch/helpers/actions.py @@ -347,7 +347,7 @@ def parallel_bulk( class BlockingPool(ThreadPool): def _setup_queues(self): super(BlockingPool, self)._setup_queues() - # The queue must be at least the size of the number of threads to + # The queue must be at least the size of the number of threads to # prevent hanging when inserting sentinel values during teardown. 
             self._inqueue = Queue(max(queue_size, thread_count))
             self._quick_put = self._inqueue.put
@@ -437,7 +437,7 @@ def scan(
     scroll_id = resp.get("_scroll_id")
 
     try:
-        while scroll_id and resp['hits']['hits']:
+        while scroll_id and resp["hits"]["hits"]:
             for hit in resp["hits"]["hits"]:
                 yield hit
 
diff --git a/elasticsearch/helpers/errors.py b/elasticsearch/helpers/errors.py
index 5f4dc7485..6261822e5 100644
--- a/elasticsearch/helpers/errors.py
+++ b/elasticsearch/helpers/errors.py
@@ -1,5 +1,3 @@
-
-
 from ..exceptions import ElasticsearchException
 
 
diff --git a/elasticsearch/helpers/test.py b/elasticsearch/helpers/test.py
index 6fdd2ec05..cdc7a981c 100644
--- a/elasticsearch/helpers/test.py
+++ b/elasticsearch/helpers/test.py
@@ -1,5 +1,6 @@
 import time
 import os
+
 try:
     # python 2.6
     from unittest2 import TestCase, SkipTest
@@ -9,33 +10,37 @@
 from elasticsearch import Elasticsearch
 from elasticsearch.exceptions import ConnectionError
 
+
 def get_test_client(nowait=False, **kwargs):
     # construct kwargs from the environment
-    kw = {'timeout': 30}
-    if 'TEST_ES_CONNECTION' in os.environ:
+    kw = {"timeout": 30}
+    if "TEST_ES_CONNECTION" in os.environ:
         from elasticsearch import connection
-        kw['connection_class'] = getattr(connection, os.environ['TEST_ES_CONNECTION'])
+
+        kw["connection_class"] = getattr(connection, os.environ["TEST_ES_CONNECTION"])
 
     kw.update(kwargs)
-    client = Elasticsearch([os.environ.get('TEST_ES_SERVER', {})], **kw)
+    client = Elasticsearch([os.environ.get("TEST_ES_SERVER", {})], **kw)
 
     # wait for yellow status
     for _ in range(1 if nowait else 100):
         try:
-            client.cluster.health(wait_for_status='yellow')
+            client.cluster.health(wait_for_status="yellow")
             return client
         except ConnectionError:
-            time.sleep(.1)
+            time.sleep(0.1)
     else:
         # timeout
         raise SkipTest("Elasticsearch failed to start.")
 
+
 def _get_version(version_string):
-    if '.' not in version_string:
+    if "." not in version_string:
         return ()
-    version = version_string.strip().split('.')
+    version = version_string.strip().split(".")
     return tuple(int(v) if v.isdigit() else 999 for v in version)
 
+
 class ElasticsearchTestCase(TestCase):
     @staticmethod
     def _get_client():
@@ -48,13 +53,12 @@ def setUpClass(cls):
 
     def tearDown(self):
         super(ElasticsearchTestCase, self).tearDown()
-        self.client.indices.delete(index='*', ignore=404)
-        self.client.indices.delete_template(name='*', ignore=404)
+        self.client.indices.delete(index="*", ignore=404)
+        self.client.indices.delete_template(name="*", ignore=404)
 
     @property
     def es_version(self):
-        if not hasattr(self, '_es_version'):
-            version_string = self.client.info()['version']['number']
+        if not hasattr(self, "_es_version"):
+            version_string = self.client.info()["version"]["number"]
             self._es_version = _get_version(version_string)
         return self._es_version
-
diff --git a/elasticsearch/serializer.py b/elasticsearch/serializer.py
index 12f63f2a3..dd7d0dc87 100644
--- a/elasticsearch/serializer.py
+++ b/elasticsearch/serializer.py
@@ -9,8 +9,9 @@
 from .exceptions import SerializationError, ImproperlyConfigured
 from .compat import string_types
 
+
 class TextSerializer(object):
-    mimetype = 'text/plain'
+    mimetype = "text/plain"
 
     def loads(self, s):
         return s
@@ -19,10 +20,11 @@ def dumps(self, data):
         if isinstance(data, string_types):
             return data
 
-        raise SerializationError('Cannot serialize %r into text.' % data)
+        raise SerializationError("Cannot serialize %r into text." % data)
+
 
 class JSONSerializer(object):
-    mimetype = 'application/json'
+    mimetype = "application/json"
 
     def default(self, data):
         if isinstance(data, (date, datetime)):
@@ -46,25 +48,26 @@ def dumps(self, data):
 
         try:
             return json.dumps(
-                data,
-                default=self.default,
-                ensure_ascii=False,
-                separators=(',', ':'),
+                data, default=self.default, ensure_ascii=False, separators=(",", ":")
             )
         except (ValueError, TypeError) as e:
             raise SerializationError(data, e)
 
+
 DEFAULT_SERIALIZERS = {
     JSONSerializer.mimetype: JSONSerializer(),
     TextSerializer.mimetype: TextSerializer(),
}
 
+
 class Deserializer(object):
-    def __init__(self, serializers, default_mimetype='application/json'):
+    def __init__(self, serializers, default_mimetype="application/json"):
         try:
             self.default = serializers[default_mimetype]
         except KeyError:
-            raise ImproperlyConfigured('Cannot find default serializer (%s)' % default_mimetype)
+            raise ImproperlyConfigured(
+                "Cannot find default serializer (%s)" % default_mimetype
+            )
         self.serializers = serializers
 
     def loads(self, s, mimetype=None):
@@ -72,11 +75,12 @@ def loads(self, s, mimetype=None):
             deserializer = self.default
         else:
             # split out charset
-            mimetype, _, _ = mimetype.partition(';')
+            mimetype, _, _ = mimetype.partition(";")
             try:
                 deserializer = self.serializers[mimetype]
             except KeyError:
-                raise SerializationError('Unknown mimetype, unable to deserialize: %s' % mimetype)
+                raise SerializationError(
+                    "Unknown mimetype, unable to deserialize: %s" % mimetype
+                )
 
         return deserializer.loads(s)
-
diff --git a/elasticsearch/transport.py b/elasticsearch/transport.py
index b7801bc56..37b2c8333 100644
--- a/elasticsearch/transport.py
+++ b/elasticsearch/transport.py
@@ -4,8 +4,12 @@
 from .connection import Urllib3HttpConnection
 from .connection_pool import ConnectionPool, DummyConnectionPool
 from .serializer import JSONSerializer, Deserializer, DEFAULT_SERIALIZERS
-from .exceptions import ConnectionError, TransportError, SerializationError, \
-    ConnectionTimeout
+from .exceptions import (
+    ConnectionError,
+    TransportError,
+    SerializationError,
+    ConnectionTimeout,
+)
 
 
 def get_host_info(node_info, host):
@@ -23,10 +27,11 @@ def get_host_info(node_info, host):
     :arg host: connection information (host, port) extracted from the node info
     """
     # ignore master only nodes
-    if node_info.get('roles', []) == ['master']:
+    if node_info.get("roles", []) == ["master"]:
         return None
     return host
 
+
 class Transport(object):
     """
     Encapsulation of transport-related to logic. Handles instantiation of the
@@ -34,12 +39,26 @@ class Transport(object):
 
     Main interface is the `perform_request` method.
""" - def __init__(self, hosts, connection_class=Urllib3HttpConnection, - connection_pool_class=ConnectionPool, host_info_callback=get_host_info, - sniff_on_start=False, sniffer_timeout=None, sniff_timeout=.1, - sniff_on_connection_fail=False, serializer=JSONSerializer(), serializers=None, - default_mimetype='application/json', max_retries=3, retry_on_status=(502, 503, 504, ), - retry_on_timeout=False, send_get_body_as='GET', **kwargs): + + def __init__( + self, + hosts, + connection_class=Urllib3HttpConnection, + connection_pool_class=ConnectionPool, + host_info_callback=get_host_info, + sniff_on_start=False, + sniffer_timeout=None, + sniff_timeout=0.1, + sniff_on_connection_fail=False, + serializer=JSONSerializer(), + serializers=None, + default_mimetype="application/json", + max_retries=3, + retry_on_status=(502, 503, 504), + retry_on_timeout=False, + send_get_body_as="GET", + **kwargs + ): """ :arg hosts: list of dictionaries, each containing keyword arguments to create a `connection_class` instance @@ -143,7 +162,7 @@ def _create_connection(host): # if this is not the initial setup look at the existing connection # options and identify connections that haven't changed and can be # kept around. - if hasattr(self, 'connection_pool'): + if hasattr(self, "connection_pool"): for (connection, old_host) in self.connection_pool.connection_opts: if old_host == host: return connection @@ -152,6 +171,7 @@ def _create_connection(host): kwargs = self.kwargs.copy() kwargs.update(host) return self.connection_class(**kwargs) + connections = map(_create_connection, hosts) connections = list(zip(connections, hosts)) @@ -159,7 +179,9 @@ def _create_connection(host): self.connection_pool = DummyConnectionPool(connections) else: # pass the hosts dicts to the connection pool to optionally extract parameters from - self.connection_pool = self.connection_pool_class(connections, **self.kwargs) + self.connection_pool = self.connection_pool_class( + connections, **self.kwargs + ) def get_connection(self): """ @@ -194,9 +216,13 @@ def _get_sniff_data(self, initial=False): try: # use small timeout for the sniffing request, should be a fast api call _, headers, node_info = c.perform_request( - 'GET', '/_nodes/_all/http', - timeout=self.sniff_timeout if not initial else None) - node_info = self.deserializer.loads(node_info, headers.get('content-type')) + "GET", + "/_nodes/_all/http", + timeout=self.sniff_timeout if not initial else None, + ) + node_info = self.deserializer.loads( + node_info, headers.get("content-type") + ) break except (ConnectionError, SerializationError): pass @@ -207,18 +233,18 @@ def _get_sniff_data(self, initial=False): self.last_sniff = previous_sniff raise - return list(node_info['nodes'].values()) + return list(node_info["nodes"].values()) def _get_host_info(self, host_info): host = {} - address = host_info.get('http', {}).get('publish_address') + address = host_info.get("http", {}).get("publish_address") # malformed or no address given - if not address or ':' not in address: + if not address or ":" not in address: return None - host['host'], host['port'] = address.rsplit(':', 1) - host['port'] = int(host['port']) + host["host"], host["port"] = address.rsplit(":", 1) + host["port"] = int(host["port"]) return self.host_info_callback(host_info, host) @@ -239,7 +265,9 @@ def sniff_hosts(self, initial=False): # we weren't able to get any nodes or host_info_callback blocked all - # raise error. 
         if not hosts:
-            raise TransportError("N/A", "Unable to sniff hosts - no viable hosts found.")
+            raise TransportError(
+                "N/A", "Unable to sniff hosts - no viable hosts found."
+            )
 
         self.set_connections(hosts)
@@ -280,21 +308,21 @@ def perform_request(self, method, url, headers=None, params=None, body=None):
             body = self.serializer.dumps(body)
 
             # some clients or environments don't support sending GET with body
-            if method in ('HEAD', 'GET') and self.send_get_body_as != 'GET':
+            if method in ("HEAD", "GET") and self.send_get_body_as != "GET":
                 # send it as post instead
-                if self.send_get_body_as == 'POST':
-                    method = 'POST'
+                if self.send_get_body_as == "POST":
+                    method = "POST"
 
                 # or as source parameter
-                elif self.send_get_body_as == 'source':
+                elif self.send_get_body_as == "source":
                     if params is None:
                         params = {}
-                    params['source'] = body
+                    params["source"] = body
                     body = None
 
         if body is not None:
             try:
-                body = body.encode('utf-8', 'surrogatepass')
+                body = body.encode("utf-8", "surrogatepass")
             except (UnicodeDecodeError, AttributeError):
                 # bytes/str - no need to re-encode
                 pass
 
@@ -302,10 +330,10 @@ def perform_request(self, method, url, headers=None, params=None, body=None):
         ignore = ()
         timeout = None
         if params:
-            timeout = params.pop('request_timeout', None)
-            ignore = params.pop('ignore', ())
+            timeout = params.pop("request_timeout", None)
+            ignore = params.pop("ignore", ())
             if isinstance(ignore, int):
-                ignore = (ignore, )
+                ignore = (ignore,)
 
         for attempt in range(self.max_retries + 1):
             connection = self.get_connection()
 
             try:
                 # add a delay before attempting the next retry
                 # 0, 1, 3, 7, etc...
-                delay = 2**attempt - 1
+                delay = 2 ** attempt - 1
                 time.sleep(delay)
-                status, headers_response, data = connection.perform_request(method, url, params, body, headers=headers, ignore=ignore, timeout=timeout)
+                status, headers_response, data = connection.perform_request(
+                    method,
+                    url,
+                    params,
+                    body,
+                    headers=headers,
+                    ignore=ignore,
+                    timeout=timeout,
+                )
             except TransportError as e:
-                if method == 'HEAD' and e.status_code == 404:
+                if method == "HEAD" and e.status_code == 404:
                     return False
 
                 retry = False
@@ -342,11 +378,13 @@ def perform_request(self, method, url, headers=None, params=None, body=None):
                 # connection didn't fail, confirm it's live status
                 self.connection_pool.mark_live(connection)
 
-                if method == 'HEAD':
+                if method == "HEAD":
                     return 200 <= status < 300
 
                 if data:
-                    data = self.deserializer.loads(data, headers_response.get('content-type'))
+                    data = self.deserializer.loads(
+                        data, headers_response.get("content-type")
+                    )
 
                 return data
 
     def close(self):
diff --git a/example/load.py b/example/load.py
index b500f9449..06abe903f 100644
--- a/example/load.py
+++ b/example/load.py
@@ -14,65 +14,62 @@
 from elasticsearch.exceptions import TransportError
 from elasticsearch.helpers import bulk, streaming_bulk
 
+
 def create_git_index(client, index):
     # we will use user on several places
     user_mapping = {
-        'properties': {
-            'name': {
-                'type': 'text',
-                'fields': {
-                    'keyword': {'type': 'keyword'},
-                }
+        "properties": {
+            "name": {"type": "text", "fields": {"keyword": {"type": "keyword"}}}
         }
-        }
     }
 
     create_index_body = {
-        'settings': {
-            # just one shard, no replicas for testing
-            'number_of_shards': 1,
-            'number_of_replicas': 0,
-
-            # custom analyzer for analyzing file paths
-            'analysis': {
-                'analyzer': {
-                    'file_path': {
-                        'type': 'custom',
-                        'tokenizer': 'path_hierarchy',
-                        'filter': ['lowercase']
+        "settings": {
+            # just one shard, no replicas for testing
+            "number_of_shards": 1,
+            "number_of_replicas": 0,
+            # custom analyzer for analyzing file paths
+            "analysis": {
+                "analyzer": {
+                    "file_path": {
+                        "type": "custom",
+                        "tokenizer": "path_hierarchy",
+                        "filter": ["lowercase"],
+                    }
+                }
+            },
+        },
+        "mappings": {
+            "doc": {
+                "properties": {
+                    "repository": {"type": "keyword"},
+                    "author": user_mapping,
+                    "authored_date": {"type": "date"},
+                    "committer": user_mapping,
+                    "committed_date": {"type": "date"},
+                    "parent_shas": {"type": "keyword"},
+                    "description": {"type": "text", "analyzer": "snowball"},
+                    "files": {
+                        "type": "text",
+                        "analyzer": "file_path",
+                        "fielddata": True,
+                    },
+                }
             }
-          }
-        }
-      },
-        'mappings': {
-            'doc': {
-                'properties': {
-                    'repository': {'type': 'keyword'},
-                    'author': user_mapping,
-                    'authored_date': {'type': 'date'},
-                    'committer': user_mapping,
-                    'committed_date': {'type': 'date'},
-                    'parent_shas': {'type': 'keyword'},
-                    'description': {'type': 'text', 'analyzer': 'snowball'},
-                    'files': {'type': 'text', 'analyzer': 'file_path', "fielddata": True}
-                }
-            }
-        }
+        },
     }
 
     # create empty index
     try:
-        client.indices.create(
-            index=index,
-            body=create_index_body,
-        )
+        client.indices.create(index=index, body=create_index_body)
     except TransportError as e:
         # ignore already existing index
-        if e.error == 'index_already_exists_exception':
+        if e.error == "index_already_exists_exception":
             pass
         else:
             raise
 
+
 def parse_commits(head, name):
     """
     Go through the git repository log and generate a document per commit
@@ -80,26 +77,24 @@ def parse_commits(head, name):
     """
     for commit in head.traverse():
         yield {
-            '_id': commit.hexsha,
-            'repository': name,
-            'committed_date': datetime.fromtimestamp(commit.committed_date),
-            'committer': {
-                'name': commit.committer.name,
-                'email': commit.committer.email,
-            },
-            'authored_date': datetime.fromtimestamp(commit.authored_date),
-            'author': {
-                'name': commit.author.name,
-                'email': commit.author.email,
+            "_id": commit.hexsha,
+            "repository": name,
+            "committed_date": datetime.fromtimestamp(commit.committed_date),
+            "committer": {
+                "name": commit.committer.name,
+                "email": commit.committer.email,
             },
-            'description': commit.message,
-            'parent_shas': [p.hexsha for p in commit.parents],
+            "authored_date": datetime.fromtimestamp(commit.authored_date),
+            "author": {"name": commit.author.name, "email": commit.author.email},
+            "description": commit.message,
+            "parent_shas": [p.hexsha for p in commit.parents],
             # we only care about the filenames, not the per-file stats
-            'files': list(commit.stats.files),
-            'stats': commit.stats.total,
+            "files": list(commit.stats.files),
+            "stats": commit.stats.total,
         }
 
-def load_repo(client, path=None, index='git'):
+
+def load_repo(client, path=None, index="git"):
     """
     Parse a git repository with all it's commits and load it into elasticsearch
     using `client`. If the index doesn't exist it will be created.
@@ -114,18 +109,18 @@ def load_repo(client, path=None, index='git'):
     # in - since the `parse_commits` function is a generator this will avoid
     # loading all the commits into memory
     for ok, result in streaming_bulk(
-            client,
-            parse_commits(repo.refs.master.commit, repo_name),
-            index=index,
-            doc_type='doc',
-            chunk_size=50 # keep the batch sizes small for appearances only
-        ):
+        client,
+        parse_commits(repo.refs.master.commit, repo_name),
+        index=index,
+        doc_type="doc",
+        chunk_size=50,  # keep the batch sizes small for appearances only
+    ):
         action, result = result.popitem()
-        doc_id = '/%s/doc/%s' % (index, result['_id'])
+        doc_id = "/%s/doc/%s" % (index, result["_id"])
         # process the information from ES whether the document has been
         # successfully indexed
         if not ok:
-            print('Failed to %s document %s: %r' % (action, doc_id, result))
+            print("Failed to %s document %s: %r" % (action, doc_id, result))
         else:
             print(doc_id)
 
@@ -133,36 +128,40 @@ def load_repo(client, path=None, index='git'):
 # we manually update some documents to add additional information
 UPDATES = [
     {
-        '_type': 'doc',
-        '_id': '20fbba1230cabbc0f4644f917c6c2be52b8a63e8',
-        '_op_type': 'update',
-        'doc': {'initial_commit': True}
+        "_type": "doc",
+        "_id": "20fbba1230cabbc0f4644f917c6c2be52b8a63e8",
+        "_op_type": "update",
+        "doc": {"initial_commit": True},
     },
     {
-        '_type': 'doc',
-        '_id': 'ae0073c8ca7e24d237ffd56fba495ed409081bf4',
-        '_op_type': 'update',
-        'doc': {'release': '5.0.0'}
+        "_type": "doc",
+        "_id": "ae0073c8ca7e24d237ffd56fba495ed409081bf4",
+        "_op_type": "update",
+        "doc": {"release": "5.0.0"},
     },
 ]
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     # get trace logger and set level
-    tracer = logging.getLogger('elasticsearch.trace')
+    tracer = logging.getLogger("elasticsearch.trace")
     tracer.setLevel(logging.INFO)
-    tracer.addHandler(logging.FileHandler('/tmp/es_trace.log'))
+    tracer.addHandler(logging.FileHandler("/tmp/es_trace.log"))
 
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "-H", "--host",
+        "-H",
+        "--host",
         action="store",
         default="localhost:9200",
-        help="The elasticsearch host you wish to connect to. (Default: localhost:9200)")
+        help="The elasticsearch host you wish to connect to. (Default: localhost:9200)",
+    )
     parser.add_argument(
-        "-p", "--path",
+        "-p",
+        "--path",
         action="store",
         default=None,
-        help="Path to git repo. Commits used as data to load into Elasticsearch. (Default: None")
+        help="Path to git repo. Commits used as data to load into Elasticsearch. (Default: None)",
+    )
 
     args = parser.parse_args()
 
@@ -173,15 +172,19 @@ def load_repo(client, path=None, index='git'):
     load_repo(es, path=args.path)
 
     # run the bulk operations
-    success, _ = bulk(es, UPDATES, index='git')
-    print('Performed %d actions' % success)
+    success, _ = bulk(es, UPDATES, index="git")
+    print("Performed %d actions" % success)
 
     # we can now make docs visible for searching
-    es.indices.refresh(index='git')
+    es.indices.refresh(index="git")
 
     # now we can retrieve the documents
-    initial_commit = es.get(index='git', doc_type='doc', id='20fbba1230cabbc0f4644f917c6c2be52b8a63e8')
-    print('%s: %s' % (initial_commit['_id'], initial_commit['_source']['committed_date']))
+    initial_commit = es.get(
+        index="git", doc_type="doc", id="20fbba1230cabbc0f4644f917c6c2be52b8a63e8"
+    )
+    print(
+        "%s: %s" % (initial_commit["_id"], initial_commit["_source"]["committed_date"])
+    )
 
     # and now we can count the documents
-    print(es.count(index='git')['count'], 'documents in index')
+    print(es.count(index="git")["count"], "documents in index")
diff --git a/example/queries.py b/example/queries.py
index 4d61c2623..a08070662 100644
--- a/example/queries.py
+++ b/example/queries.py
@@ -6,95 +6,92 @@
 
 from elasticsearch import Elasticsearch
 
+
 def print_search_stats(results):
-    print('=' * 80)
-    print('Total %d found in %dms' % (results['hits']['total'], results['took']))
-    print('-' * 80)
+    print("=" * 80)
+    print("Total %d found in %dms" % (results["hits"]["total"], results["took"]))
+    print("-" * 80)
+
 
 def print_hits(results):
     " Simple utility function to print results of a search query. "
     print_search_stats(results)
-    for hit in results['hits']['hits']:
+    for hit in results["hits"]["hits"]:
         # get created date for a repo and fallback to authored_date for a commit
-        created_at = parse_date(hit['_source'].get('created_at', hit['_source']['authored_date']))
-        print('/%s/%s/%s (%s): %s' % (
-            hit['_index'], hit['_type'], hit['_id'],
-            created_at.strftime('%Y-%m-%d'),
-            hit['_source']['description'].split('\n')[0]))
+        created_at = parse_date(
+            hit["_source"].get("created_at", hit["_source"]["authored_date"])
+        )
+        print(
+            "/%s/%s/%s (%s): %s"
+            % (
+                hit["_index"],
+                hit["_type"],
+                hit["_id"],
+                created_at.strftime("%Y-%m-%d"),
+                hit["_source"]["description"].split("\n")[0],
+            )
+        )
 
-    print('=' * 80)
+    print("=" * 80)
     print()
 
+
 # get trace logger and set level
-tracer = logging.getLogger('elasticsearch.trace')
+tracer = logging.getLogger("elasticsearch.trace")
 tracer.setLevel(logging.INFO)
-tracer.addHandler(logging.FileHandler('/tmp/es_trace.log'))
+tracer.addHandler(logging.FileHandler("/tmp/es_trace.log"))
 
 # instantiate es client, connects to localhost:9200 by default
 es = Elasticsearch()
 
-print('Empty search:')
-print_hits(es.search(index='git'))
+print("Empty search:")
+print_hits(es.search(index="git"))
 
 print('Find commits that says "fix" without touching tests:')
 result = es.search(
-    index='git',
-    doc_type='doc',
+    index="git",
+    doc_type="doc",
     body={
-        'query': {
-            'bool': {
-                'must': {
-                    'match': {'description': 'fix'}
-                },
-                'must_not': {
-                    'term': {'files': 'test_elasticsearch'}
-                }
+        "query": {
+            "bool": {
+                "must": {"match": {"description": "fix"}},
+                "must_not": {"term": {"files": "test_elasticsearch"}},
+            }
         }
-        }
-    }
+    },
 )
 print_hits(result)
 
-print('Last 8 Commits for elasticsearch-py:')
+print("Last 8 Commits for elasticsearch-py:")
 result = es.search(
-    index='git',
-    doc_type='doc',
+    index="git",
+    doc_type="doc",
     body={
-        'query': {
-            'term': {
-                'repository': 'elasticsearch-py'
-            }
-        },
-        'sort': [
-            {'committed_date': {'order': 'desc'}}
-        ],
-        'size': 8
-    }
+        "query": {"term": {"repository": "elasticsearch-py"}},
+        "sort": [{"committed_date": {"order": "desc"}}],
+        "size": 8,
+    },
 )
 print_hits(result)
 
-print('Stats for top 10 committers:')
+print("Stats for top 10 committers:")
 result = es.search(
-    index='git',
-    doc_type='doc',
+    index="git",
+    doc_type="doc",
     body={
-        'size': 0,
-        'aggs': {
-            'committers': {
-                'terms': {
-                    'field': 'committer.name.keyword',
-                },
-                'aggs': {
-                    'line_stats': {
-                        'stats': {'field': 'stats.lines'}
+        "size": 0,
+        "aggs": {
+            "committers": {
+                "terms": {"field": "committer.name.keyword"},
+                "aggs": {"line_stats": {"stats": {"field": "stats.lines"}}},
             }
-          }
-        }
-      }
-    }
+        },
+    },
 )
 print_search_stats(result)
-for committer in result['aggregations']['committers']['buckets']:
-    print('%15s: %3d commits changing %6d lines' % (
-        committer['key'], committer['doc_count'], committer['line_stats']['sum']))
-print('=' * 80)
+for committer in result["aggregations"]["committers"]["buckets"]:
+    print(
+        "%15s: %3d commits changing %6d lines"
+        % (committer["key"], committer["doc_count"], committer["line_stats"]["sum"])
+    )
+print("=" * 80)
diff --git a/test_elasticsearch/test_cases.py b/test_elasticsearch/test_cases.py
index 43550e5cc..6dec2b3c1 100644
--- a/test_elasticsearch/test_cases.py
+++ b/test_elasticsearch/test_cases.py
@@ -1,4 +1,5 @@
 from collections import defaultdict
+
 try:
     # python 2.6
     from unittest2 import TestCase, SkipTest
@@ -7,6 +8,7 @@
 
 from elasticsearch import Elasticsearch
 
+
 class DummyTransport(object):
     def __init__(self, hosts, responses=None, **kwargs):
         self.hosts = hosts
@@ -46,7 +48,7 @@ def test_start_with_0_call(self):
         self.assert_call_count_equals(0)
 
     def test_each_call_is_recorded(self):
-        self.client.transport.perform_request('GET', '/')
-        self.client.transport.perform_request('DELETE', '/42', params={}, body='body')
+        self.client.transport.perform_request("GET", "/")
+        self.client.transport.perform_request("DELETE", "/42", params={}, body="body")
         self.assert_call_count_equals(2)
-        self.assertEquals([({}, 'body')], self.assert_url_called('DELETE', '/42', 1))
+        self.assertEquals([({}, "body")], self.assert_url_called("DELETE", "/42", 1))
diff --git a/test_elasticsearch/test_client/test_indices.py b/test_elasticsearch/test_client/test_indices.py
index fbb8c23c6..7d80562a7 100644
--- a/test_elasticsearch/test_client/test_indices.py
+++ b/test_elasticsearch/test_client/test_indices.py
@@ -1,19 +1,20 @@
 from test_elasticsearch.test_cases import ElasticsearchTestCase
 
+
 class TestIndices(ElasticsearchTestCase):
     def test_create_one_index(self):
-        self.client.indices.create('test-index')
-        self.assert_url_called('PUT', '/test-index')
+        self.client.indices.create("test-index")
+        self.assert_url_called("PUT", "/test-index")
 
     def test_delete_multiple_indices(self):
-        self.client.indices.delete(['test-index', 'second.index', 'third/index'])
-        self.assert_url_called('DELETE', '/test-index,second.index,third%2Findex')
+        self.client.indices.delete(["test-index", "second.index", "third/index"])
+        self.assert_url_called("DELETE", "/test-index,second.index,third%2Findex")
 
     def test_exists_index(self):
-        self.client.indices.exists('second.index,third/index')
-        self.assert_url_called('HEAD', '/second.index,third%2Findex')
+        self.client.indices.exists("second.index,third/index")
+        self.assert_url_called("HEAD", "/second.index,third%2Findex")
 
     def test_passing_empty_value_for_required_param_raises_exception(self):
         self.assertRaises(ValueError, self.client.indices.exists, index=None)
         self.assertRaises(ValueError, self.client.indices.exists, index=[])
-        self.assertRaises(ValueError, self.client.indices.exists, index='')
+        self.assertRaises(ValueError, self.client.indices.exists, index="")
diff --git a/test_elasticsearch/test_client/test_utils.py b/test_elasticsearch/test_client/test_utils.py
index 66a832930..e30d8b14c 100644
--- a/test_elasticsearch/test_client/test_utils.py
+++ b/test_elasticsearch/test_client/test_utils.py
@@ -6,35 +6,32 @@
 
 from ..test_cases import TestCase, SkipTest
 
+
 class TestMakePath(TestCase):
     def test_handles_unicode(self):
         id = "中文"
-        self.assertEquals('/some-index/type/%E4%B8%AD%E6%96%87', _make_path('some-index', 'type', id))
+        self.assertEquals(
+            "/some-index/type/%E4%B8%AD%E6%96%87", _make_path("some-index", "type", id)
+        )
 
     def test_handles_utf_encoded_string(self):
         if not PY2:
-            raise SkipTest('Only relevant for py2')
-        id = "中文".encode('utf-8')
-        self.assertEquals('/some-index/type/%E4%B8%AD%E6%96%87', _make_path('some-index', 'type', id))
+            raise SkipTest("Only relevant for py2")
+        id = "中文".encode("utf-8")
+        self.assertEquals(
+            "/some-index/type/%E4%B8%AD%E6%96%87", _make_path("some-index", "type", id)
+        )
 
 
 class TestEscape(TestCase):
     def test_handles_ascii(self):
         string = "abc123"
-        self.assertEquals(
-            b'abc123',
-            _escape(string)
-        )
+        self.assertEquals(b"abc123", _escape(string))
+
     def test_handles_unicode(self):
         string = "中文"
-        self.assertEquals(
-            b'\xe4\xb8\xad\xe6\x96\x87',
-            _escape(string)
-        )
+        self.assertEquals(b"\xe4\xb8\xad\xe6\x96\x87", _escape(string))
 
     def test_handles_bytestring(self):
-        string = b'celery-task-meta-c4f1201f-eb7b-41d5-9318-a75a8cfbdaa0'
-        self.assertEquals(
-            string,
-            _escape(string)
-        )
+        string = b"celery-task-meta-c4f1201f-eb7b-41d5-9318-a75a8cfbdaa0"
+        self.assertEquals(string, _escape(string))
diff --git a/test_elasticsearch/test_connection.py b/test_elasticsearch/test_connection.py
index d35387542..478897c2c 100644
--- a/test_elasticsearch/test_connection.py
+++ b/test_elasticsearch/test_connection.py
@@ -6,9 +6,13 @@
 import warnings
 from requests.auth import AuthBase
 
-from elasticsearch.exceptions import TransportError, ConflictError, RequestError, NotFoundError
-from elasticsearch.connection import RequestsHttpConnection, \
-    Urllib3HttpConnection
+from elasticsearch.exceptions import (
+    TransportError,
+    ConflictError,
+    RequestError,
+    NotFoundError,
+)
+from elasticsearch.connection import RequestsHttpConnection, Urllib3HttpConnection
 from elasticsearch.exceptions import ImproperlyConfigured
 
 from .test_cases import TestCase, SkipTest
@@ -22,20 +26,18 @@ def test_ssl_context(self):
             # it means SSLContext is not available for that version of python
             # and we should skip this test.
             raise SkipTest(
-                "Test test_ssl_context is skipped cause SSLContext is not available for this version of ptyhon")
+                "Test test_ssl_context is skipped because SSLContext is not available for this version of python"
+            )
 
         con = Urllib3HttpConnection(use_ssl=True, ssl_context=context)
         self.assertEqual(len(con.pool.conn_kw.keys()), 1)
-        self.assertIsInstance(
-            con.pool.conn_kw['ssl_context'],
-            ssl.SSLContext
-        )
+        self.assertIsInstance(con.pool.conn_kw["ssl_context"], ssl.SSLContext)
         self.assertTrue(con.use_ssl)
 
     def test_http_compression(self):
         con = Urllib3HttpConnection(http_compress=True)
         self.assertTrue(con.http_compress)
-        self.assertEquals(con.headers['content-encoding'], 'gzip')
+        self.assertEquals(con.headers["content-encoding"], "gzip")
 
     def test_timeout_set(self):
         con = Urllib3HttpConnection(timeout=42)
@@ -43,40 +45,60 @@ def test_timeout_set(self):
 
     def test_keep_alive_is_on_by_default(self):
         con = Urllib3HttpConnection()
-        self.assertEquals({'connection': 'keep-alive',
-                           'content-type': 'application/json'}, con.headers)
+        self.assertEquals(
+            {"connection": "keep-alive", "content-type": "application/json"},
+            con.headers,
+        )
 
     def test_http_auth(self):
-        con = Urllib3HttpConnection(http_auth='username:secret')
-        self.assertEquals({
-            'authorization': 'Basic dXNlcm5hbWU6c2VjcmV0',
-            'connection': 'keep-alive',
-            'content-type': 'application/json'
-        }, con.headers)
+        con = Urllib3HttpConnection(http_auth="username:secret")
+        self.assertEquals(
+            {
+                "authorization": "Basic dXNlcm5hbWU6c2VjcmV0",
+                "connection": "keep-alive",
+                "content-type": "application/json",
+            },
+            con.headers,
+        )
 
     def test_http_auth_tuple(self):
-        con = Urllib3HttpConnection(http_auth=('username', 'secret'))
-        self.assertEquals({'authorization': 'Basic dXNlcm5hbWU6c2VjcmV0',
-                           'content-type': 'application/json',
-                           'connection': 'keep-alive'}, con.headers)
+        con = Urllib3HttpConnection(http_auth=("username", "secret"))
+        self.assertEquals(
+            {
+                "authorization": "Basic dXNlcm5hbWU6c2VjcmV0",
+                "content-type": "application/json",
+                "connection": "keep-alive",
+            },
+            con.headers,
+        )
 
     def test_http_auth_list(self):
-        con = Urllib3HttpConnection(http_auth=['username', 'secret'])
-        self.assertEquals({'authorization': 'Basic dXNlcm5hbWU6c2VjcmV0',
-                           'content-type': 'application/json',
-                           'connection': 'keep-alive'}, con.headers)
+        con = Urllib3HttpConnection(http_auth=["username", "secret"])
+        self.assertEquals(
+            {
+                "authorization": "Basic dXNlcm5hbWU6c2VjcmV0",
+                "content-type": "application/json",
+                "connection": "keep-alive",
+            },
+            con.headers,
+        )
 
     def test_uses_https_if_verify_certs_is_off(self):
         with warnings.catch_warnings(record=True) as w:
             con = Urllib3HttpConnection(use_ssl=True, verify_certs=False)
             self.assertEquals(1, len(w))
-            self.assertEquals('Connecting to localhost using SSL with verify_certs=False is insecure.', str(w[0].message))
+            self.assertEquals(
+                "Connecting to localhost using SSL with verify_certs=False is insecure.",
+                str(w[0].message),
+            )
 
         self.assertIsInstance(con.pool, urllib3.HTTPSConnectionPool)
 
     def nowarn_when_test_uses_https_if_verify_certs_is_off(self):
         with warnings.catch_warnings(record=True) as w:
-            con = Urllib3HttpConnection(use_ssl=True, verify_certs=False, ssl_show_warn=False)
+            con = Urllib3HttpConnection(
+                use_ssl=True, verify_certs=False, ssl_show_warn=False
+            )
             self.assertEquals(0, len(w))
 
         self.assertIsInstance(con.pool, urllib3.HTTPSConnectionPool)
 
@@ -85,9 +107,13 @@ def test_doesnt_use_https_if_not_specified(self):
         con = Urllib3HttpConnection()
         self.assertIsInstance(con.pool, urllib3.HTTPConnectionPool)
 
+
 class TestRequestsConnection(TestCase):
-    def _get_mock_connection(self, connection_params={}, status_code=200, response_body='{}'):
+    def _get_mock_connection(
+        self, connection_params={}, status_code=200, response_body="{}"
+    ):
         con = RequestsHttpConnection(**connection_params)
+
         def _dummy_send(*args, **kwargs):
             dummy_response = Mock()
             dummy_response.headers = {}
@@ -97,20 +123,21 @@ def _dummy_send(*args, **kwargs):
             dummy_response.cookies = {}
             _dummy_send.call_args = (args, kwargs)
             return dummy_response
+
         con.session.send = _dummy_send
         return con
 
     def _get_request(self, connection, *args, **kwargs):
-        if 'body' in kwargs:
-            kwargs['body'] = kwargs['body'].encode('utf-8')
+        if "body" in kwargs:
+            kwargs["body"] = kwargs["body"].encode("utf-8")
 
         status, headers, data = connection.perform_request(*args, **kwargs)
         self.assertEquals(200, status)
-        self.assertEquals('{}', data)
+        self.assertEquals("{}", data)
 
-        timeout = kwargs.pop('timeout', connection.timeout)
+        timeout = kwargs.pop("timeout", connection.timeout)
         args, kwargs = connection.session.send.call_args
-        self.assertEquals(timeout, kwargs['timeout'])
+        self.assertEquals(timeout, kwargs["timeout"])
         self.assertEquals(1, len(args))
         return args[0]
 
@@ -126,73 +153,96 @@ def test_timeout_set(self):
 
     def test_uses_https_if_verify_certs_is_off(self):
         with warnings.catch_warnings(record=True) as w:
-            con = self._get_mock_connection({'use_ssl': True, 'url_prefix': 'url', 'verify_certs': False})
+            con = self._get_mock_connection(
+                {"use_ssl": True, "url_prefix": "url", "verify_certs": False}
+            )
             self.assertEquals(1, len(w))
-            self.assertEquals('Connecting to https://localhost:9200/url using SSL with verify_certs=False is insecure.', str(w[0].message))
+            self.assertEquals(
+                "Connecting to https://localhost:9200/url using SSL with verify_certs=False is insecure.",
+                str(w[0].message),
+            )
 
-        request = self._get_request(con, 'GET', '/')
+        request = self._get_request(con, "GET", "/")
 
-        self.assertEquals('https://localhost:9200/url/', request.url)
-        self.assertEquals('GET', request.method)
+        self.assertEquals("https://localhost:9200/url/", request.url)
+        self.assertEquals("GET", request.method)
         self.assertEquals(None, request.body)
 
     def nowarn_when_test_uses_https_if_verify_certs_is_off(self):
         with warnings.catch_warnings(record=True) as w:
-            con = self._get_mock_connection({'use_ssl': True, 'url_prefix': 'url', 'verify_certs': False, 'ssl_show_warn': False})
+            con = self._get_mock_connection(
+                {
+                    "use_ssl": True,
+                    "url_prefix": "url",
+                    "verify_certs": False,
+                    "ssl_show_warn": False,
+                }
+            )
             self.assertEquals(0, len(w))
 
-        request = self._get_request(con, 'GET', '/')
+        request = self._get_request(con, "GET", "/")
 
-        self.assertEquals('https://localhost:9200/url/', request.url)
-        self.assertEquals('GET', request.method)
+        self.assertEquals("https://localhost:9200/url/", request.url)
+        self.assertEquals("GET", request.method)
         self.assertEquals(None, request.body)
 
     def test_merge_headers(self):
-        con = self._get_mock_connection(connection_params={'headers': {'h1': 'v1', 'h2': 'v2'}})
-        req = self._get_request(con, 'GET', '/', headers={'h2': 'v2p', 'h3': 'v3'})
-        self.assertEquals(req.headers['h1'], 'v1')
-        self.assertEquals(req.headers['h2'], 'v2p')
-        self.assertEquals(req.headers['h3'], 'v3')
+        con = self._get_mock_connection(
+            connection_params={"headers": {"h1": "v1", "h2": "v2"}}
+        )
+        req = self._get_request(con, "GET", "/", headers={"h2": "v2p", "h3": "v3"})
+        self.assertEquals(req.headers["h1"], "v1")
+        self.assertEquals(req.headers["h2"], "v2p")
+        self.assertEquals(req.headers["h3"], "v3")
 
     def test_http_auth(self):
-        con = RequestsHttpConnection(http_auth='username:secret')
-        self.assertEquals(('username', 'secret'), con.session.auth)
+        con = RequestsHttpConnection(http_auth="username:secret")
+        self.assertEquals(("username", "secret"), con.session.auth)
 
     def test_http_auth_tuple(self):
-        con = RequestsHttpConnection(http_auth=('username', 'secret'))
-        self.assertEquals(('username', 'secret'), con.session.auth)
+        con = RequestsHttpConnection(http_auth=("username", "secret"))
+        self.assertEquals(("username", "secret"), con.session.auth)
 
     def test_http_auth_list(self):
-        con = RequestsHttpConnection(http_auth=['username', 'secret'])
-        self.assertEquals(('username', 'secret'), con.session.auth)
+        con = RequestsHttpConnection(http_auth=["username", "secret"])
+        self.assertEquals(("username", "secret"), con.session.auth)
 
     def test_repr(self):
         con = self._get_mock_connection({"host": "elasticsearch.com", "port": 443})
-        self.assertEquals('<RequestsHttpConnection: http://elasticsearch.com:443>', repr(con))
+        self.assertEquals(
+            "<RequestsHttpConnection: http://elasticsearch.com:443>", repr(con)
+        )
 
     def test_conflict_error_is_returned_on_409(self):
         con = self._get_mock_connection(status_code=409)
-        self.assertRaises(ConflictError, con.perform_request, 'GET', '/', {}, '')
+        self.assertRaises(ConflictError, con.perform_request, "GET", "/", {}, "")
 
     def test_not_found_error_is_returned_on_404(self):
         con = self._get_mock_connection(status_code=404)
-        self.assertRaises(NotFoundError, con.perform_request, 'GET', '/', {}, '')
+        self.assertRaises(NotFoundError, con.perform_request, "GET", "/", {}, "")
 
     def test_request_error_is_returned_on_400(self):
         con = self._get_mock_connection(status_code=400)
-        self.assertRaises(RequestError, con.perform_request, 'GET', '/', {}, '')
+        self.assertRaises(RequestError, con.perform_request, "GET", "/", {}, "")
 
-    @patch('elasticsearch.connection.base.logger')
+    @patch("elasticsearch.connection.base.logger")
     def test_head_with_404_doesnt_get_logged(self, logger):
         con = self._get_mock_connection(status_code=404)
-        self.assertRaises(NotFoundError, con.perform_request, 'HEAD', '/', {}, '')
+        self.assertRaises(NotFoundError, con.perform_request, "HEAD", "/", {}, "")
         self.assertEquals(0, logger.warning.call_count)
 
-    @patch('elasticsearch.connection.base.tracer')
-    @patch('elasticsearch.connection.base.logger')
+    @patch("elasticsearch.connection.base.tracer")
+    @patch("elasticsearch.connection.base.logger")
     def test_failed_request_logs_and_traces(self, logger, tracer):
         con = self._get_mock_connection(response_body='{"answer": 42}', status_code=500)
-        self.assertRaises(TransportError, con.perform_request, 'GET', '/', {'param': 42}, '{}'.encode('utf-8'))
+        self.assertRaises(
+            TransportError,
+            con.perform_request,
+            "GET",
+            "/",
+            {"param": 42},
+            "{}".encode("utf-8"),
+        )
 
         # trace request
         self.assertEquals(1, tracer.info.call_count)
@@ -200,90 +250,101 @@ def test_failed_request_logs_and_traces(self, logger, tracer):
         self.assertEquals(1, tracer.debug.call_count)
         # log url and duration
         self.assertEquals(1, logger.warning.call_count)
-        self.assertTrue(re.match(
-            '^GET http://localhost:9200/\?param=42 \[status:500 request:0.[0-9]{3}s\]',
-            logger.warning.call_args[0][0] % logger.warning.call_args[0][1:]
-        ))
+        self.assertTrue(
+            re.match(
+                "^GET http://localhost:9200/\?param=42 \[status:500 request:0.[0-9]{3}s\]",
+                logger.warning.call_args[0][0] % logger.warning.call_args[0][1:],
+            )
+        )
 
-    @patch('elasticsearch.connection.base.tracer')
-    @patch('elasticsearch.connection.base.logger')
+    @patch("elasticsearch.connection.base.tracer")
+    @patch("elasticsearch.connection.base.logger")
    def test_success_logs_and_traces(self, logger, tracer):
-        con = self._get_mock_connection(response_body='''{"answer": "that's it!"}''')
-        status, headers, data = con.perform_request('GET', '/', {'param': 42}, '''{"question": "what's that?"}'''.encode('utf-8'))
+        con = self._get_mock_connection(response_body="""{"answer": "that's it!"}""")
+        status, headers, data = con.perform_request(
+            "GET",
+            "/",
+            {"param": 42},
+            """{"question": "what's that?"}""".encode("utf-8"),
+        )
 
         # trace request
         self.assertEquals(1, tracer.info.call_count)
         self.assertEquals(
             """curl -H 'Content-Type: application/json' -XGET 'http://localhost:9200/?pretty&param=42' -d '{\n  "question": "what\\u0027s that?"\n}'""",
-            tracer.info.call_args[0][0] % tracer.info.call_args[0][1:]
+            tracer.info.call_args[0][0] % tracer.info.call_args[0][1:],
         )
         # trace response
         self.assertEquals(1, tracer.debug.call_count)
-        self.assertTrue(re.match(
-            '#\[200\] \(0.[0-9]{3}s\)\n#\{\n#  "answer": "that\\\\u0027s it!"\n#\}',
-            tracer.debug.call_args[0][0] % tracer.debug.call_args[0][1:]
-        ))
+        self.assertTrue(
+            re.match(
+                '#\[200\] \(0.[0-9]{3}s\)\n#\{\n#  "answer": "that\\\\u0027s it!"\n#\}',
+                tracer.debug.call_args[0][0] % tracer.debug.call_args[0][1:],
+            )
+        )
 
         # log url and duration
         self.assertEquals(1, logger.info.call_count)
-        self.assertTrue(re.match(
-            'GET http://localhost:9200/\?param=42 \[status:200 request:0.[0-9]{3}s\]',
-            logger.info.call_args[0][0] % logger.info.call_args[0][1:]
-        ))
+        self.assertTrue(
+            re.match(
+                "GET http://localhost:9200/\?param=42 \[status:200 request:0.[0-9]{3}s\]",
+                logger.info.call_args[0][0] % logger.info.call_args[0][1:],
+            )
+        )
         # log request body and response
         self.assertEquals(2, logger.debug.call_count)
         req, resp = logger.debug.call_args_list
-        self.assertEquals(
-            '> {"question": "what\'s that?"}',
-            req[0][0] % req[0][1:]
-        )
-        self.assertEquals(
-            '< {"answer": "that\'s it!"}',
-            resp[0][0] % resp[0][1:]
-        )
+        self.assertEquals('> {"question": "what\'s that?"}', req[0][0] % req[0][1:])
+        self.assertEquals('< {"answer": "that\'s it!"}', resp[0][0] % resp[0][1:])
 
     def test_defaults(self):
         con = self._get_mock_connection()
-        request = self._get_request(con, 'GET', '/')
+        request = self._get_request(con, "GET", "/")
 
-        self.assertEquals('http://localhost:9200/', request.url)
-        self.assertEquals('GET', request.method)
+        self.assertEquals("http://localhost:9200/", request.url)
+        self.assertEquals("GET", request.method)
         self.assertEquals(None, request.body)
 
     def test_params_properly_encoded(self):
         con = self._get_mock_connection()
-        request = self._get_request(con, 'GET', '/', params={'param': 'value with spaces'})
+        request = self._get_request(
+            con, "GET", "/", params={"param": "value with spaces"}
+        )
 
-        self.assertEquals('http://localhost:9200/?param=value+with+spaces', request.url)
-        self.assertEquals('GET', request.method)
+        self.assertEquals("http://localhost:9200/?param=value+with+spaces", request.url)
+        self.assertEquals("GET", request.method)
         self.assertEquals(None, request.body)
 
     def test_body_attached(self):
         con = self._get_mock_connection()
-        request = self._get_request(con, 'GET', '/', body='{"answer": 42}')
+        request = self._get_request(con, "GET", "/", body='{"answer": 42}')
 
-        self.assertEquals('http://localhost:9200/', request.url)
-        self.assertEquals('GET', request.method)
-        self.assertEquals('{"answer": 42}'.encode('utf-8'), request.body)
+        self.assertEquals("http://localhost:9200/", request.url)
+        self.assertEquals("GET", request.method)
+        self.assertEquals('{"answer": 42}'.encode("utf-8"), request.body)
 
     def test_http_auth_attached(self):
-        con = self._get_mock_connection({'http_auth': 'username:secret'})
-        request = self._get_request(con, 'GET', '/')
+        con = self._get_mock_connection({"http_auth": "username:secret"})
+        request = self._get_request(con, "GET", "/")
 
-        self.assertEquals(request.headers['authorization'], 'Basic dXNlcm5hbWU6c2VjcmV0')
+        self.assertEquals(
+            request.headers["authorization"], "Basic dXNlcm5hbWU6c2VjcmV0"
+        )
 
-    @patch('elasticsearch.connection.base.tracer')
+    @patch("elasticsearch.connection.base.tracer")
     def test_url_prefix(self, tracer):
         con = self._get_mock_connection({"url_prefix": "/some-prefix/"})
-        request = self._get_request(con, 'GET', '/_search', body='{"answer": 42}', timeout=0.1)
+        request = self._get_request(
+            con, "GET", "/_search", body='{"answer": 42}', timeout=0.1
+        )
 
-        self.assertEquals('http://localhost:9200/some-prefix/_search', request.url)
-        self.assertEquals('GET', request.method)
-        self.assertEquals('{"answer": 42}'.encode('utf-8'), request.body)
+        self.assertEquals("http://localhost:9200/some-prefix/_search", request.url)
+        self.assertEquals("GET", request.method)
+        self.assertEquals('{"answer": 42}'.encode("utf-8"), request.body)
 
         # trace request
         self.assertEquals(1, tracer.info.call_count)
         self.assertEquals(
             "curl -H 'Content-Type: application/json' -XGET 'http://localhost:9200/_search?pretty' -d '{\n  \"answer\": 42\n}'",
-            tracer.info.call_args[0][0] % tracer.info.call_args[0][1:]
+            tracer.info.call_args[0][0] % tracer.info.call_args[0][1:],
         )
diff --git a/test_elasticsearch/test_connection_pool.py b/test_elasticsearch/test_connection_pool.py
index fdbe16fdd..923497f13 100644
--- a/test_elasticsearch/test_connection_pool.py
+++ b/test_elasticsearch/test_connection_pool.py
@@ -1,14 +1,21 @@
 import time
 
-from elasticsearch.connection_pool import ConnectionPool, RoundRobinSelector, DummyConnectionPool
+from elasticsearch.connection_pool import (
+    ConnectionPool,
+    RoundRobinSelector,
+    DummyConnectionPool,
+)
 from elasticsearch.exceptions import ImproperlyConfigured
 
 from .test_cases import TestCase
 
+
 class TestConnectionPool(TestCase):
     def test_dummy_cp_raises_exception_on_more_connections(self):
         self.assertRaises(ImproperlyConfigured, DummyConnectionPool, [])
-        self.assertRaises(ImproperlyConfigured, DummyConnectionPool, [object(), object()])
+        self.assertRaises(
+            ImproperlyConfigured, DummyConnectionPool, [object(), object()]
+        )
 
     def test_raises_exception_when_no_connections_defined(self):
         self.assertRaises(ImproperlyConfigured, ConnectionPool, [])
@@ -32,13 +39,20 @@ def test_disable_shuffling(self):
     def test_selectors_have_access_to_connection_opts(self):
         class MySelector(RoundRobinSelector):
             def select(self, connections):
-                return self.connection_opts[super(MySelector, self).select(connections)]["actual"]
-        pool = ConnectionPool([(x, {"actual": x*x}) for x in range(100)], selector_class=MySelector, randomize_hosts=False)
+                return self.connection_opts[
+                    super(MySelector, self).select(connections)
+                ]["actual"]
+
+        pool = ConnectionPool(
+            [(x, {"actual": x * x}) for x in range(100)],
+            selector_class=MySelector,
+            randomize_hosts=False,
+        )
         connections = []
         for _ in range(100):
             connections.append(pool.get_connection())
-        self.assertEquals(connections, [x*x for x in range(100)])
+        self.assertEquals(connections, [x * x for x in range(100)])
 
     def test_dead_nodes_are_removed_from_active_connections(self):
         pool = ConnectionPool([(x, {}) for x in range(100)])
@@ -53,23 +67,26 @@ def test_connection_is_skipped_when_dead(self):
         pool = ConnectionPool([(x, {}) for x in range(2)])
         pool.mark_dead(0)
 
-        self.assertEquals([1, 1, 1], [pool.get_connection(), pool.get_connection(), pool.get_connection(), ])
+        self.assertEquals(
+            [1, 1, 1],
+            [pool.get_connection(), pool.get_connection(), pool.get_connection()],
+        )
 
     def test_connection_is_forcibly_resurrected_when_no_live_ones_are_availible(self):
         pool = ConnectionPool([(x, {}) for x in range(2)])
         pool.dead_count[0] = 1
-        pool.mark_dead(0) # failed twice, longer timeout
-        pool.mark_dead(1) # failed the first time, first to be resurrected
+        pool.mark_dead(0)  # failed twice, longer timeout
+        pool.mark_dead(1)  # failed the first time, first to be resurrected
 
         self.assertEquals([], pool.connections)
         self.assertEquals(1, pool.get_connection())
-        self.assertEquals([1,], pool.connections)
+        self.assertEquals([1], pool.connections)
 
     def test_connection_is_resurrected_after_its_timeout(self):
         pool = ConnectionPool([(x, {}) for x in range(100)])
 
         now = time.time()
-        pool.mark_dead(42, now=now-61)
+        pool.mark_dead(42, now=now - 61)
         pool.get_connection()
         self.assertEquals(42, pool.connections[-1])
         self.assertEquals(100, len(pool.connections))
@@ -89,7 +106,7 @@ def test_already_failed_connection_has_longer_timeout(self):
         pool.mark_dead(42, now=now)
 
         self.assertEquals(3, pool.dead_count[42])
-        self.assertEquals((now + 4*60, 42), pool.dead.get())
+        self.assertEquals((now + 4 * 60, 42), pool.dead.get())
 
     def test_timeout_for_failed_connections_is_limitted(self):
         pool = ConnectionPool([(x, {}) for x in range(100)])
@@ -98,7 +115,7 @@ def test_timeout_for_failed_connections_is_limitted(self):
         pool.mark_dead(42, now=now)
 
         self.assertEquals(246, pool.dead_count[42])
-        self.assertEquals((now + 32*60, 42), pool.dead.get())
+        self.assertEquals((now + 32 * 60, 42), pool.dead.get())
 
     def test_dead_count_is_wiped_clean_for_connection_if_marked_live(self):
         pool = ConnectionPool([(x, {}) for x in range(100)])
@@ -109,4 +126,3 @@ def test_dead_count_is_wiped_clean_for_connection_if_marked_live(self):
         self.assertEquals(3, pool.dead_count[42])
         pool.mark_live(42)
         self.assertNotIn(42, pool.dead_count)
-
diff --git a/test_elasticsearch/test_exceptions.py b/test_elasticsearch/test_exceptions.py
index 985518de9..0bb0fe8b2 100644
--- a/test_elasticsearch/test_exceptions.py
+++ b/test_elasticsearch/test_exceptions.py
@@ -5,19 +5,22 @@
 
 class TestTransformError(TestCase):
     def test_transform_error_parse_with_error_reason(self):
-        e = TransportError(500, 'InternalServerError', {
-            'error': {
-                'root_cause': [
-                    {"type": "error", "reason": "error reason"}
-                ]
-            }
-        })
+        e = TransportError(
+            500,
+            "InternalServerError",
+            {"error": {"root_cause": [{"type": "error", "reason": "error reason"}]}},
+        )
 
-        self.assertEqual(str(e), "TransportError(500, 'InternalServerError', 'error reason')")
+        self.assertEqual(
+            str(e), "TransportError(500, 'InternalServerError', 'error reason')"
+        )
 
     def test_transform_error_parse_with_error_string(self):
-        e = TransportError(500, 'InternalServerError', {
-            'error': 'something error message'
-        })
+        e = TransportError(
+            500, "InternalServerError", {"error": "something error message"}
+        )
 
-        self.assertEqual(str(e), "TransportError(500, 'InternalServerError', 'something error message')")
+        self.assertEqual(
+            str(e),
+            "TransportError(500, 'InternalServerError', 'something error message')",
+        )
diff --git a/test_elasticsearch/test_serializer.py b/test_elasticsearch/test_serializer.py
index 3c7925e51..0309c5f1d 100644
--- a/test_elasticsearch/test_serializer.py
+++ b/test_elasticsearch/test_serializer.py
@@ -5,30 +5,44 @@
 from datetime import datetime
 from decimal import Decimal
 
-from elasticsearch.serializer import JSONSerializer, Deserializer, DEFAULT_SERIALIZERS, TextSerializer
+from elasticsearch.serializer import (
+    JSONSerializer,
+    Deserializer,
+    DEFAULT_SERIALIZERS,
+    TextSerializer,
+)
 from elasticsearch.exceptions import SerializationError, ImproperlyConfigured
 
 from .test_cases import TestCase, SkipTest
 
+
 class TestJSONSerializer(TestCase):
     def test_datetime_serialization(self):
-        self.assertEquals('{"d":"2010-10-01T02:30:00"}', JSONSerializer().dumps({'d': datetime(2010, 10, 1, 2, 30)}))
+        self.assertEquals(
+            '{"d":"2010-10-01T02:30:00"}',
+            JSONSerializer().dumps({"d": datetime(2010, 10, 1, 2, 30)}),
+        )
 
     def test_decimal_serialization(self):
         if sys.version_info[:2] == (2, 6):
             raise SkipTest("Float rounding is broken in 2.6.")
-        self.assertEquals('{"d":3.8}', JSONSerializer().dumps({'d': Decimal('3.8')}))
+        self.assertEquals('{"d":3.8}', JSONSerializer().dumps({"d": Decimal("3.8")}))
 
     def test_uuid_serialization(self):
-        self.assertEquals('{"d":"00000000-0000-0000-0000-000000000003"}', JSONSerializer().dumps({'d': uuid.UUID('00000000-0000-0000-0000-000000000003')}))
+        self.assertEquals(
+            '{"d":"00000000-0000-0000-0000-000000000003"}',
+            JSONSerializer().dumps(
+                {"d": uuid.UUID("00000000-0000-0000-0000-000000000003")}
+            ),
+        )
 
     def test_raises_serialization_error_on_dump_error(self):
         self.assertRaises(SerializationError, JSONSerializer().dumps, object())
 
     def test_raises_serialization_error_on_load_error(self):
         self.assertRaises(SerializationError, JSONSerializer().loads, object())
-        self.assertRaises(SerializationError, JSONSerializer().loads, '')
-        self.assertRaises(SerializationError, JSONSerializer().loads, '{{')
+        self.assertRaises(SerializationError, JSONSerializer().loads, "")
+        self.assertRaises(SerializationError, JSONSerializer().loads, "{{")
 
     def test_strings_are_left_untouched(self):
         self.assertEquals("你好", JSONSerializer().dumps("你好"))
@@ -51,11 +65,18 @@ def test_deserializes_json_by_default(self):
         self.assertEquals({"some": "data"}, self.de.loads('{"some":"data"}'))
 
     def test_deserializes_text_with_correct_ct(self):
-        self.assertEquals('{"some":"data"}', self.de.loads('{"some":"data"}', 'text/plain'))
-        self.assertEquals('{"some":"data"}', self.de.loads('{"some":"data"}', 'text/plain; charset=whatever'))
+        self.assertEquals(
+            '{"some":"data"}', self.de.loads('{"some":"data"}', "text/plain")
+        )
+        self.assertEquals(
+            '{"some":"data"}',
+            self.de.loads('{"some":"data"}', "text/plain; charset=whatever"),
+        )
 
     def test_raises_serialization_error_on_unknown_mimetype(self):
-        self.assertRaises(SerializationError, self.de.loads, '{}', 'text/html')
+        self.assertRaises(SerializationError, self.de.loads, "{}", "text/html")
 
-    def test_raises_improperly_configured_when_default_mimetype_cannot_be_deserialized(self):
+    def test_raises_improperly_configured_when_default_mimetype_cannot_be_deserialized(
+        self
+    ):
         self.assertRaises(ImproperlyConfigured, Deserializer, {})
diff --git a/test_elasticsearch/test_server/__init__.py b/test_elasticsearch/test_server/__init__.py
index 9fb387895..53adbe0fb 100644
--- a/test_elasticsearch/test_server/__init__.py
+++ b/test_elasticsearch/test_server/__init__.py
@@ -1,7 +1,11 @@
-from elasticsearch.helpers.test import get_test_client, ElasticsearchTestCase as BaseTestCase
+from elasticsearch.helpers.test import (
+    get_test_client,
+    ElasticsearchTestCase as BaseTestCase,
+)
 
 client = None
 
+
 def get_client(**kwargs):
     global client
     if client is not None and not kwargs:
@@ -10,6 +14,7 @@ def get_client(**kwargs):
     # try and locate manual override in the local environment
     try:
         from test_elasticsearch.local import get_client as local_get_client
+
         new_client = local_get_client(**kwargs)
     except ImportError:
         # fallback to using vanilla client
@@ -24,6 +29,7 @@ def get_client(**kwargs):
 def setup():
     get_client()
 
+
 class ElasticsearchTestCase(BaseTestCase):
     @staticmethod
     def _get_client(**kwargs):
diff --git a/test_elasticsearch/test_server/test_client.py b/test_elasticsearch/test_server/test_client.py
index 70393fd66..ed79cb4ee 100644
--- a/test_elasticsearch/test_server/test_client.py
+++ b/test_elasticsearch/test_server/test_client.py
@@ -3,6 +3,7 @@
 
 from . import ElasticsearchTestCase
 
+
 class TestUnicode(ElasticsearchTestCase):
     def test_indices_analyze(self):
         self.client.indices.analyze(body='{"text": "привет"}')
diff --git a/test_elasticsearch/test_server/test_helpers.py b/test_elasticsearch/test_server/test_helpers.py
index 81fc4f220..bb679e9e3 100644
--- a/test_elasticsearch/test_server/test_helpers.py
+++ b/test_elasticsearch/test_server/test_helpers.py
@@ -310,20 +310,20 @@ def test_errors_are_collected_properly(self):
 class TestScan(ElasticsearchTestCase):
     mock_scroll_responses = [
         {
-            '_scroll_id': 'dummy_id',
-            '_shards': {'successful': 4, 'total': 5},
-            'hits': {'hits': [{'scroll_data': 42}]},
+            "_scroll_id": "dummy_id",
+            "_shards": {"successful": 4, "total": 5},
+            "hits": {"hits": [{"scroll_data": 42}]},
         },
         {
-            '_scroll_id': 'dummy_id',
-            '_shards': {'successful': 4, 'total': 5},
-            'hits': {'hits': []},
+            "_scroll_id": "dummy_id",
+            "_shards": {"successful": 4, "total": 5},
+            "hits": {"hits": []},
         },
     ]
 
     @classmethod
     def tearDownClass(cls):
-        cls.client.transport.perform_request('DELETE', '/_search/scroll/_all')
+        cls.client.transport.perform_request("DELETE", "/_search/scroll/_all")
         super(TestScan, cls).tearDownClass()
 
     def test_order_can_be_preserved(self):
@@ -366,87 +366,101 @@ def test_scroll_error(self):
             bulk.append({"value": x})
         self.client.bulk(bulk, refresh=True)
 
-        with patch.object(self.client, 'scroll') as scroll_mock:
+        with patch.object(self.client, "scroll") as scroll_mock:
             scroll_mock.side_effect = self.mock_scroll_responses
-            data = list(helpers.scan(
-                self.client,
-                index='test_index',
-                size=2,
-                raise_on_error=False,
-                clear_scroll=False
-            ))
+            data = list(
+                helpers.scan(
+                    self.client,
+                    index="test_index",
+                    size=2,
+                    raise_on_error=False,
+                    clear_scroll=False,
+                )
+            )
             self.assertEqual(len(data), 3)
-            self.assertEqual(data[-1], {'scroll_data': 42})
+            self.assertEqual(data[-1], {"scroll_data": 42})
 
             scroll_mock.side_effect = self.mock_scroll_responses
             with self.assertRaises(ScanError):
-                data = list(helpers.scan(
-                    self.client,
-                    index='test_index',
-                    size=2,
-                    raise_on_error=True,
-                    clear_scroll=False
-                ))
+                data = list(
+                    helpers.scan(
+                        self.client,
+                        index="test_index",
+                        size=2,
+                        raise_on_error=True,
+                        clear_scroll=False,
+                    )
+                )
             self.assertEqual(len(data), 3)
-            self.assertEqual(data[-1], {'scroll_data': 42})
+            self.assertEqual(data[-1], {"scroll_data": 42})
 
     def test_initial_search_error(self):
-        with patch.object(self, 'client') as client_mock:
+        with patch.object(self, "client") as client_mock:
             client_mock.search.return_value = {
-                '_scroll_id': 'dummy_id',
-                '_shards': {'successful': 4, 'total': 5},
-                'hits': {'hits': [{'search_data': 1}]},
+                "_scroll_id": "dummy_id",
+                "_shards": {"successful": 4, "total": 5},
+                "hits": {"hits": [{"search_data": 1}]},
             }
             client_mock.scroll.side_effect = self.mock_scroll_responses
 
-            data = list(helpers.scan(self.client, index='test_index', size=2, raise_on_error=False))
-            self.assertEqual(data, [{'search_data': 1}, {'scroll_data': 42}])
+            data = list(
+                helpers.scan(
+                    self.client, index="test_index", size=2, raise_on_error=False
+                )
+            )
+            self.assertEqual(data, [{"search_data": 1}, {"scroll_data": 42}])
 
             client_mock.scroll.side_effect = self.mock_scroll_responses
             with self.assertRaises(ScanError):
                 data = list(
-                    helpers.scan(self.client, index='test_index', size=2, raise_on_error=True)
+                    helpers.scan(
+                        self.client, index="test_index", size=2, raise_on_error=True
+                    )
                 )
-                self.assertEqual(data, [{'search_data': 1}])
+                self.assertEqual(data, [{"search_data": 1}])
                 client_mock.scroll.assert_not_called()
 
     def test_no_scroll_id_fast_route(self):
-        with patch.object(self, 'client') as client_mock:
-            client_mock.search.return_value = {'no': '_scroll_id'}
-            data = list(helpers.scan(self.client, index='test_index'))
+        with patch.object(self, "client") as client_mock:
+            client_mock.search.return_value = {"no": "_scroll_id"}
+            data = list(helpers.scan(self.client, index="test_index"))
 
             self.assertEqual(data, [])
             client_mock.scroll.assert_not_called()
             client_mock.clear_scroll.assert_not_called()
 
-    @patch('elasticsearch.helpers.actions.logger')
+    @patch("elasticsearch.helpers.actions.logger")
     def test_logger(self, logger_mock):
         bulk = []
         for x in range(4):
-            bulk.append({'index': {'_index': 'test_index', '_type': '_doc'}})
-            bulk.append({'value': x})
+            bulk.append({"index": {"_index": "test_index", "_type": "_doc"}})
+            bulk.append({"value": x})
         self.client.bulk(bulk, refresh=True)
 
-        with patch.object(self.client, 'scroll') as scroll_mock:
+        with patch.object(self.client, "scroll") as scroll_mock:
             scroll_mock.side_effect = self.mock_scroll_responses
-            list(helpers.scan(
-                self.client,
-                index='test_index',
-                size=2,
-                raise_on_error=False,
-                clear_scroll=False
-            ))
+            list(
+                helpers.scan(
+                    self.client,
+                    index="test_index",
+                    size=2,
+                    raise_on_error=False,
+                    clear_scroll=False,
+                )
+            )
             logger_mock.warning.assert_called()
 
             scroll_mock.side_effect = self.mock_scroll_responses
             try:
-                list(helpers.scan(
-                    self.client,
-                    index='test_index',
-                    size=2,
-                    raise_on_error=True,
-                    clear_scroll=False
-                ))
+                list(
+                    helpers.scan(
+                        self.client,
+                        index="test_index",
+                        size=2,
+                        raise_on_error=True,
+                        clear_scroll=False,
+                    )
+                )
             except ScanError:
                 pass
             logger_mock.warning.assert_called()
@@ -454,20 +468,28 @@ def test_logger(self, logger_mock):
     def test_clear_scroll(self):
         bulk = []
         for x in range(4):
-            bulk.append({'index': {'_index': 'test_index', '_type': '_doc'}})
-            bulk.append({'value': x})
+            bulk.append({"index": {"_index": "test_index", "_type": "_doc"}})
+            bulk.append({"value": x})
         self.client.bulk(bulk, refresh=True)
 
-        with patch.object(self.client, 'clear_scroll', wraps=self.client.clear_scroll) as spy:
-            list(helpers.scan(self.client, index='test_index', size=2))
+        with patch.object(
+            self.client, "clear_scroll", wraps=self.client.clear_scroll
+        ) as spy:
+            list(helpers.scan(self.client, index="test_index", size=2))
             spy.assert_called_once()
 
            spy.reset_mock()
-            list(helpers.scan(self.client, index='test_index', size=2, clear_scroll=True))
+            list(
+                helpers.scan(self.client, index="test_index", size=2, clear_scroll=True)
+            )
             spy.assert_called_once()
spy.reset_mock() - list(helpers.scan(self.client, index='test_index', size=2, clear_scroll=False)) + list( + helpers.scan( + self.client, index="test_index", size=2, clear_scroll=False + ) + ) spy.assert_not_called() diff --git a/test_elasticsearch/test_transport.py b/test_elasticsearch/test_transport.py index 50acb7fae..c2b7a10a8 100644 --- a/test_elasticsearch/test_transport.py +++ b/test_elasticsearch/test_transport.py @@ -9,11 +9,12 @@ from .test_cases import TestCase + class DummyConnection(Connection): def __init__(self, **kwargs): - self.exception = kwargs.pop('exception', None) - self.status, self.data = kwargs.pop('status', 200), kwargs.pop('data', '{}') - self.headers = kwargs.pop('headers', {}) + self.exception = kwargs.pop("exception", None) + self.status, self.data = kwargs.pop("status", 200), kwargs.pop("data", "{}") + self.headers = kwargs.pop("headers", {}) self.calls = [] super(DummyConnection, self).__init__(**kwargs) @@ -23,7 +24,8 @@ def perform_request(self, *args, **kwargs): raise self.exception return self.status, self.headers, self.data -CLUSTER_NODES = '''{ + +CLUSTER_NODES = """{ "_nodes" : { "total" : 1, "successful" : 1, @@ -46,18 +48,23 @@ def perform_request(self, *args, **kwargs): } } } -}''' +}""" + class TestHostsInfoCallback(TestCase): def test_master_only_nodes_are_ignored(self): nodes = [ - {'roles': [ "master"]}, - {'roles': [ "master", "data", "ingest"]}, - {'roles': [ "data", "ingest"]}, - {'roles': [ ]}, - {} + {"roles": ["master"]}, + {"roles": ["master", "data", "ingest"]}, + {"roles": ["data", "ingest"]}, + {"roles": []}, + {}, + ] + chosen = [ + i + for i, node_info in enumerate(nodes) + if get_host_info(node_info, i) is not None ] - chosen = [i for i, node_info in enumerate(nodes) if get_host_info(node_info, i) is not None] self.assertEquals([1, 2, 3, 4], chosen) @@ -65,57 +72,70 @@ class TestTransport(TestCase): def test_single_connection_uses_dummy_connection_pool(self): t = Transport([{}]) self.assertIsInstance(t.connection_pool, DummyConnectionPool) - t = Transport([{'host': 'localhost'}]) + t = Transport([{"host": "localhost"}]) self.assertIsInstance(t.connection_pool, DummyConnectionPool) def test_request_timeout_extracted_from_params_and_passed(self): t = Transport([{}], connection_class=DummyConnection) - t.perform_request('GET', '/', params={'request_timeout': 42}) + t.perform_request("GET", "/", params={"request_timeout": 42}) self.assertEquals(1, len(t.get_connection().calls)) - self.assertEquals(('GET', '/', {}, None), t.get_connection().calls[0][0]) - self.assertEquals({'timeout': 42, 'ignore': (), 'headers': None}, t.get_connection().calls[0][1]) + self.assertEquals(("GET", "/", {}, None), t.get_connection().calls[0][0]) + self.assertEquals( + {"timeout": 42, "ignore": (), "headers": None}, + t.get_connection().calls[0][1], + ) def test_send_get_body_as_source(self): - t = Transport([{}], send_get_body_as='source', connection_class=DummyConnection) + t = Transport([{}], send_get_body_as="source", connection_class=DummyConnection) - t.perform_request('GET', '/', body={}) + t.perform_request("GET", "/", body={}) self.assertEquals(1, len(t.get_connection().calls)) - self.assertEquals(('GET', '/', {'source': '{}'}, None), t.get_connection().calls[0][0]) + self.assertEquals( + ("GET", "/", {"source": "{}"}, None), t.get_connection().calls[0][0] + ) def test_send_get_body_as_post(self): - t = Transport([{}], send_get_body_as='POST', connection_class=DummyConnection) + t = Transport([{}], send_get_body_as="POST", 
connection_class=DummyConnection) - t.perform_request('GET', '/', body={}) + t.perform_request("GET", "/", body={}) self.assertEquals(1, len(t.get_connection().calls)) - self.assertEquals(('POST', '/', None, b'{}'), t.get_connection().calls[0][0]) + self.assertEquals(("POST", "/", None, b"{}"), t.get_connection().calls[0][0]) def test_body_gets_encoded_into_bytes(self): t = Transport([{}], connection_class=DummyConnection) - t.perform_request('GET', '/', body='你好') + t.perform_request("GET", "/", body="你好") self.assertEquals(1, len(t.get_connection().calls)) - self.assertEquals(('GET', '/', None, b'\xe4\xbd\xa0\xe5\xa5\xbd'), t.get_connection().calls[0][0]) + self.assertEquals( + ("GET", "/", None, b"\xe4\xbd\xa0\xe5\xa5\xbd"), + t.get_connection().calls[0][0], + ) def test_body_bytes_get_passed_untouched(self): t = Transport([{}], connection_class=DummyConnection) - body = b'\xe4\xbd\xa0\xe5\xa5\xbd' - t.perform_request('GET', '/', body=body) + body = b"\xe4\xbd\xa0\xe5\xa5\xbd" + t.perform_request("GET", "/", body=body) self.assertEquals(1, len(t.get_connection().calls)) - self.assertEquals(('GET', '/', None, body), t.get_connection().calls[0][0]) + self.assertEquals(("GET", "/", None, body), t.get_connection().calls[0][0]) def test_body_surrogates_replaced_encoded_into_bytes(self): t = Transport([{}], connection_class=DummyConnection) - t.perform_request('GET', '/', body='你好\uda6a') + t.perform_request("GET", "/", body="你好\uda6a") self.assertEquals(1, len(t.get_connection().calls)) - self.assertEquals(('GET', '/', None, b'\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa'), t.get_connection().calls[0][0]) - + self.assertEquals( + ("GET", "/", None, b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa"), + t.get_connection().calls[0][0], + ) + def test_kwargs_passed_on_to_connections(self): - t = Transport([{'host': 'google.com'}], port=123) + t = Transport([{"host": "google.com"}], port=123) self.assertEquals(1, len(t.connection_pool.connections)) - self.assertEquals('http://google.com:123', t.connection_pool.connections[0].host) + self.assertEquals( + "http://google.com:123", t.connection_pool.connections[0].host + ) def test_kwargs_passed_on_to_connection_pool(self): dt = object() @@ -126,6 +146,7 @@ def test_custom_connection_class(self): class MyConnection(object): def __init__(self, **kwargs): self.kwargs = kwargs + t = Transport([{}], connection_class=MyConnection) self.assertEquals(1, len(t.connection_pool.connections)) self.assertIsInstance(t.connection_pool.connections[0], MyConnection) @@ -135,18 +156,26 @@ def test_add_connection(self): t.add_connection({"host": "google.com", "port": 1234}) self.assertEquals(2, len(t.connection_pool.connections)) - self.assertEquals('http://google.com:1234', t.connection_pool.connections[1].host) + self.assertEquals( + "http://google.com:1234", t.connection_pool.connections[1].host + ) def test_request_will_fail_after_X_retries(self): - t = Transport([{'exception': ConnectionError('abandon ship')}], connection_class=DummyConnection) + t = Transport( + [{"exception": ConnectionError("abandon ship")}], + connection_class=DummyConnection, + ) - self.assertRaises(ConnectionError, t.perform_request, 'GET', '/') + self.assertRaises(ConnectionError, t.perform_request, "GET", "/") self.assertEquals(4, len(t.get_connection().calls)) def test_failed_connection_will_be_marked_as_dead(self): - t = Transport([{'exception': ConnectionError('abandon ship')}] * 2, connection_class=DummyConnection) + t = Transport( + [{"exception": ConnectionError("abandon ship")}] * 2, + 
connection_class=DummyConnection, + ) - self.assertRaises(ConnectionError, t.perform_request, 'GET', '/') + self.assertRaises(ConnectionError, t.perform_request, "GET", "/") self.assertEquals(0, len(t.connection_pool.connections)) def test_resurrected_connection_will_be_marked_as_live_on_success(self): @@ -156,35 +185,57 @@ def test_resurrected_connection_will_be_marked_as_live_on_success(self): t.connection_pool.mark_dead(con1) t.connection_pool.mark_dead(con2) - t.perform_request('GET', '/') + t.perform_request("GET", "/") self.assertEquals(1, len(t.connection_pool.connections)) self.assertEquals(1, len(t.connection_pool.dead_count)) def test_sniff_will_use_seed_connections(self): - t = Transport([{'data': CLUSTER_NODES}], connection_class=DummyConnection) - t.set_connections([{'data': 'invalid'}]) + t = Transport([{"data": CLUSTER_NODES}], connection_class=DummyConnection) + t.set_connections([{"data": "invalid"}]) t.sniff_hosts() self.assertEquals(1, len(t.connection_pool.connections)) - self.assertEquals('http://1.1.1.1:123', t.get_connection().host) + self.assertEquals("http://1.1.1.1:123", t.get_connection().host) def test_sniff_on_start_fetches_and_uses_nodes_list(self): - t = Transport([{'data': CLUSTER_NODES}], connection_class=DummyConnection, sniff_on_start=True) + t = Transport( + [{"data": CLUSTER_NODES}], + connection_class=DummyConnection, + sniff_on_start=True, + ) self.assertEquals(1, len(t.connection_pool.connections)) - self.assertEquals('http://1.1.1.1:123', t.get_connection().host) + self.assertEquals("http://1.1.1.1:123", t.get_connection().host) def test_sniff_on_start_ignores_sniff_timeout(self): - t = Transport([{'data': CLUSTER_NODES}], connection_class=DummyConnection, sniff_on_start=True, sniff_timeout=12) - self.assertEquals((('GET', '/_nodes/_all/http'), {'timeout': None}), t.seed_connections[0].calls[0]) + t = Transport( + [{"data": CLUSTER_NODES}], + connection_class=DummyConnection, + sniff_on_start=True, + sniff_timeout=12, + ) + self.assertEquals( + (("GET", "/_nodes/_all/http"), {"timeout": None}), + t.seed_connections[0].calls[0], + ) def test_sniff_uses_sniff_timeout(self): - t = Transport([{'data': CLUSTER_NODES}], connection_class=DummyConnection, sniff_timeout=42) + t = Transport( + [{"data": CLUSTER_NODES}], + connection_class=DummyConnection, + sniff_timeout=42, + ) t.sniff_hosts() - self.assertEquals((('GET', '/_nodes/_all/http'), {'timeout': 42}), t.seed_connections[0].calls[0]) - + self.assertEquals( + (("GET", "/_nodes/_all/http"), {"timeout": 42}), + t.seed_connections[0].calls[0], + ) def test_sniff_reuses_connection_instances_if_possible(self): - t = Transport([{'data': CLUSTER_NODES}, {"host": "1.1.1.1", "port": 123}], connection_class=DummyConnection, randomize_hosts=False) + t = Transport( + [{"data": CLUSTER_NODES}, {"host": "1.1.1.1", "port": 123}], + connection_class=DummyConnection, + randomize_hosts=False, + ) connection = t.connection_pool.connections[1] t.sniff_hosts() @@ -192,25 +243,32 @@ def test_sniff_reuses_connection_instances_if_possible(self): self.assertIs(connection, t.get_connection()) def test_sniff_on_fail_triggers_sniffing_on_fail(self): - t = Transport([{'exception': ConnectionError('abandon ship')}, {"data": CLUSTER_NODES}], - connection_class=DummyConnection, sniff_on_connection_fail=True, max_retries=0, randomize_hosts=False) - - self.assertRaises(ConnectionError, t.perform_request, 'GET', '/') + t = Transport( + [{"exception": ConnectionError("abandon ship")}, {"data": CLUSTER_NODES}], + 
connection_class=DummyConnection, + sniff_on_connection_fail=True, + max_retries=0, + randomize_hosts=False, + ) + + self.assertRaises(ConnectionError, t.perform_request, "GET", "/") self.assertEquals(1, len(t.connection_pool.connections)) - self.assertEquals('http://1.1.1.1:123', t.get_connection().host) + self.assertEquals("http://1.1.1.1:123", t.get_connection().host) def test_sniff_after_n_seconds(self): - t = Transport([{"data": CLUSTER_NODES}], - connection_class=DummyConnection, sniffer_timeout=5) + t = Transport( + [{"data": CLUSTER_NODES}], + connection_class=DummyConnection, + sniffer_timeout=5, + ) for _ in range(4): - t.perform_request('GET', '/') + t.perform_request("GET", "/") self.assertEquals(1, len(t.connection_pool.connections)) self.assertIsInstance(t.get_connection(), DummyConnection) t.last_sniff = time.time() - 5.1 - t.perform_request('GET', '/') + t.perform_request("GET", "/") self.assertEquals(1, len(t.connection_pool.connections)) - self.assertEquals('http://1.1.1.1:123', t.get_connection().host) - self.assertTrue(time.time() - 1 < t.last_sniff < time.time() + 0.01 ) - + self.assertEquals("http://1.1.1.1:123", t.get_connection().host) + self.assertTrue(time.time() - 1 < t.last_sniff < time.time() + 0.01)
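
Note: every test_transport.py hunk above drives Transport through DummyConnection, a Connection subclass that records each request and hands back a canned (status, headers, body) triple instead of touching the network. A minimal, self-contained sketch of that test-double pattern follows, under the assumptions the hunks themselves make (an elasticsearch-py of roughly this commit's vintage, with Connection and Transport importable as in the test files); the RecordingConnection name is illustrative, not part of the diff.

# Sketch only -- mirrors the DummyConnection pattern used in test_transport.py.
from elasticsearch.connection import Connection
from elasticsearch.transport import Transport


class RecordingConnection(Connection):
    """Test double: capture every request instead of performing I/O."""

    def __init__(self, **kwargs):
        self.calls = []
        super(RecordingConnection, self).__init__(**kwargs)

    def perform_request(self, *args, **kwargs):
        # Record the call, then return the (status, headers, raw body) triple
        # Transport expects; "{}" deserializes cleanly as the default JSON.
        self.calls.append((args, kwargs))
        return 200, {}, "{}"


t = Transport([{}], connection_class=RecordingConnection)
t.perform_request("GET", "/")
assert len(t.get_connection().calls) == 1  # recorded, never sent

Because the double short-circuits at the connection layer, retry, sniffing, and body-encoding behavior can all be asserted purely from the recorded call tuples, which is exactly what the reformatted tests above do.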