diff --git a/docs/reference/esql-query-builder.md b/docs/reference/esql-query-builder.md index 1cdc0c5b3..8390ea983 100644 --- a/docs/reference/esql-query-builder.md +++ b/docs/reference/esql-query-builder.md @@ -203,6 +203,26 @@ query = ( ) ``` +### Preventing injection attacks + +ES|QL, like most query languages, is vulnerable to [code injection attacks](https://en.wikipedia.org/wiki/Code_injection) if untrusted data provided by users is added to a query. To eliminate this risk, ES|QL allows untrusted data to be given separately from the query as parameters. + +Continuing with the example above, let's assume that the application needs a `find_employee_by_name()` function that searches for the name given as an argument. If this argument is received by the application from users, then it is considered untrusted and should not be added to the query directly. Here is how to code the function in a secure manner: + +```python +def find_employee_by_name(name): + query = ( + ESQL.from_("employees") + .keep("first_name", "last_name", "height") + .where(E("first_name") == E("?")) + ) + return client.esql.query(query=str(query), params=[name]) +``` + +Here the part of the query in which the untrusted data needs to be inserted is replaced with a parameter, which in ES|QL is defined by the question mark. When using Python expressions, the parameter must be given as `E("?")` so that it is treated as an expression and not as a literal string. + +The list of values given in the `params` argument to the query endpoint are assigned in order to the parameters defined in the query. + ## Using ES|QL functions The ES|QL language includes a rich set of functions that can be used in expressions and conditionals. These can be included in expressions given as strings, as shown in the example below: @@ -235,6 +255,6 @@ query = ( ) ``` -Note that arguments passed to functions are assumed to be literals. When passing field names, it is necessary to wrap them with the `E()` helper function so that they are interpreted correctly. +Note that arguments passed to functions are assumed to be literals. When passing field names, parameters or other ES|QL expressions, it is necessary to wrap them with the `E()` helper function so that they are interpreted correctly. You can find the complete list of available functions in the Python client's [ES|QL API reference documentation](https://elasticsearch-py.readthedocs.io/en/stable/esql.html#module-elasticsearch.esql.functions). diff --git a/elasticsearch/esql/__init__.py b/elasticsearch/esql/__init__.py index d872c329a..8da8f852a 100644 --- a/elasticsearch/esql/__init__.py +++ b/elasticsearch/esql/__init__.py @@ -15,4 +15,5 @@ # specific language governing permissions and limitations # under the License. +from ..dsl import E # noqa: F401 from .esql import ESQL, and_, not_, or_ # noqa: F401 diff --git a/elasticsearch/esql/esql.py b/elasticsearch/esql/esql.py index 07ccdf839..05f4e3e3e 100644 --- a/elasticsearch/esql/esql.py +++ b/elasticsearch/esql/esql.py @@ -16,6 +16,7 @@ # under the License. import json +import re from abc import ABC, abstractmethod from typing import Any, Dict, Optional, Tuple, Type, Union @@ -111,6 +112,29 @@ def render(self) -> str: def _render_internal(self) -> str: pass + @staticmethod + def _format_index(index: IndexType) -> str: + return index._index._name if hasattr(index, "_index") else str(index) + + @staticmethod + def _format_id(id: FieldType, allow_patterns: bool = False) -> str: + s = str(id) # in case it is an InstrumentedField + if allow_patterns and "*" in s: + return s # patterns cannot be escaped + if re.fullmatch(r"[a-zA-Z_@][a-zA-Z0-9_\.]*", s): + return s + # this identifier needs to be escaped + s.replace("`", "``") + return f"`{s}`" + + @staticmethod + def _format_expr(expr: ExpressionType) -> str: + return ( + json.dumps(expr) + if not isinstance(expr, (str, InstrumentedExpression)) + else str(expr) + ) + def _is_forked(self) -> bool: if self.__class__.__name__ == "Fork": return True @@ -427,7 +451,7 @@ def sample(self, probability: float) -> "Sample": """ return Sample(self, probability) - def sort(self, *columns: FieldType) -> "Sort": + def sort(self, *columns: ExpressionType) -> "Sort": """The ``SORT`` processing command sorts a table on one or more columns. :param columns: The columns to sort on. @@ -570,15 +594,12 @@ def metadata(self, *fields: FieldType) -> "From": return self def _render_internal(self) -> str: - indices = [ - index if isinstance(index, str) else index._index._name - for index in self._indices - ] + indices = [self._format_index(index) for index in self._indices] s = f'{self.__class__.__name__.upper()} {", ".join(indices)}' if self._metadata_fields: s = ( s - + f' METADATA {", ".join([str(field) for field in self._metadata_fields])}' + + f' METADATA {", ".join([self._format_id(field) for field in self._metadata_fields])}' ) return s @@ -594,7 +615,11 @@ class Row(ESQLBase): def __init__(self, **params: ExpressionType): super().__init__() self._params = { - k: json.dumps(v) if not isinstance(v, InstrumentedExpression) else v + self._format_id(k): ( + json.dumps(v) + if not isinstance(v, InstrumentedExpression) + else self._format_expr(v) + ) for k, v in params.items() } @@ -615,7 +640,7 @@ def __init__(self, item: str): self._item = item def _render_internal(self) -> str: - return f"SHOW {self._item}" + return f"SHOW {self._format_id(self._item)}" class Branch(ESQLBase): @@ -667,11 +692,11 @@ def as_(self, type_name: str, pvalue_name: str) -> "ChangePoint": return self def _render_internal(self) -> str: - key = "" if not self._key else f" ON {self._key}" + key = "" if not self._key else f" ON {self._format_id(self._key)}" names = ( "" if not self._type_name and not self._pvalue_name - else f' AS {self._type_name or "type"}, {self._pvalue_name or "pvalue"}' + else f' AS {self._format_id(self._type_name or "type")}, {self._format_id(self._pvalue_name or "pvalue")}' ) return f"CHANGE_POINT {self._value}{key}{names}" @@ -709,12 +734,13 @@ def with_(self, inference_id: str) -> "Completion": def _render_internal(self) -> str: if self._inference_id is None: raise ValueError("The completion command requires an inference ID") + with_ = {"inference_id": self._inference_id} if self._named_prompt: column = list(self._named_prompt.keys())[0] prompt = list(self._named_prompt.values())[0] - return f"COMPLETION {column} = {prompt} WITH {self._inference_id}" + return f"COMPLETION {self._format_id(column)} = {self._format_id(prompt)} WITH {json.dumps(with_)}" else: - return f"COMPLETION {self._prompt[0]} WITH {self._inference_id}" + return f"COMPLETION {self._format_id(self._prompt[0])} WITH {json.dumps(with_)}" class Dissect(ESQLBase): @@ -742,9 +768,13 @@ def append_separator(self, separator: str) -> "Dissect": def _render_internal(self) -> str: sep = ( - "" if self._separator is None else f' APPEND_SEPARATOR="{self._separator}"' + "" + if self._separator is None + else f" APPEND_SEPARATOR={json.dumps(self._separator)}" + ) + return ( + f"DISSECT {self._format_id(self._input)} {json.dumps(self._pattern)}{sep}" ) - return f"DISSECT {self._input} {json.dumps(self._pattern)}{sep}" class Drop(ESQLBase): @@ -760,7 +790,7 @@ def __init__(self, parent: ESQLBase, *columns: FieldType): self._columns = columns def _render_internal(self) -> str: - return f'DROP {", ".join([str(col) for col in self._columns])}' + return f'DROP {", ".join([self._format_id(col, allow_patterns=True) for col in self._columns])}' class Enrich(ESQLBase): @@ -814,12 +844,18 @@ def with_(self, *fields: FieldType, **named_fields: FieldType) -> "Enrich": return self def _render_internal(self) -> str: - on = "" if self._match_field is None else f" ON {self._match_field}" + on = ( + "" + if self._match_field is None + else f" ON {self._format_id(self._match_field)}" + ) with_ = "" if self._named_fields: - with_ = f' WITH {", ".join([f"{name} = {field}" for name, field in self._named_fields.items()])}' + with_ = f' WITH {", ".join([f"{self._format_id(name)} = {self._format_id(field)}" for name, field in self._named_fields.items()])}' elif self._fields is not None: - with_ = f' WITH {", ".join([str(field) for field in self._fields])}' + with_ = ( + f' WITH {", ".join([self._format_id(field) for field in self._fields])}' + ) return f"ENRICH {self._policy}{on}{with_}" @@ -832,7 +868,10 @@ class Eval(ESQLBase): """ def __init__( - self, parent: ESQLBase, *columns: FieldType, **named_columns: FieldType + self, + parent: ESQLBase, + *columns: ExpressionType, + **named_columns: ExpressionType, ): if columns and named_columns: raise ValueError( @@ -844,10 +883,13 @@ def __init__( def _render_internal(self) -> str: if isinstance(self._columns, dict): cols = ", ".join( - [f"{name} = {value}" for name, value in self._columns.items()] + [ + f"{self._format_id(name)} = {self._format_expr(value)}" + for name, value in self._columns.items() + ] ) else: - cols = ", ".join([f"{col}" for col in self._columns]) + cols = ", ".join([f"{self._format_expr(col)}" for col in self._columns]) return f"EVAL {cols}" @@ -900,7 +942,7 @@ def __init__(self, parent: ESQLBase, input: FieldType, pattern: str): self._pattern = pattern def _render_internal(self) -> str: - return f"GROK {self._input} {json.dumps(self._pattern)}" + return f"GROK {self._format_id(self._input)} {json.dumps(self._pattern)}" class Keep(ESQLBase): @@ -916,7 +958,7 @@ def __init__(self, parent: ESQLBase, *columns: FieldType): self._columns = columns def _render_internal(self) -> str: - return f'KEEP {", ".join([f"{col}" for col in self._columns])}' + return f'KEEP {", ".join([f"{self._format_id(col, allow_patterns=True)}" for col in self._columns])}' class Limit(ESQLBase): @@ -932,7 +974,7 @@ def __init__(self, parent: ESQLBase, max_number_of_rows: int): self._max_number_of_rows = max_number_of_rows def _render_internal(self) -> str: - return f"LIMIT {self._max_number_of_rows}" + return f"LIMIT {json.dumps(self._max_number_of_rows)}" class LookupJoin(ESQLBase): @@ -967,7 +1009,9 @@ def _render_internal(self) -> str: if isinstance(self._lookup_index, str) else self._lookup_index._index._name ) - return f"LOOKUP JOIN {index} ON {self._field}" + return ( + f"LOOKUP JOIN {self._format_index(index)} ON {self._format_id(self._field)}" + ) class MvExpand(ESQLBase): @@ -983,7 +1027,7 @@ def __init__(self, parent: ESQLBase, column: FieldType): self._column = column def _render_internal(self) -> str: - return f"MV_EXPAND {self._column}" + return f"MV_EXPAND {self._format_id(self._column)}" class Rename(ESQLBase): @@ -999,7 +1043,7 @@ def __init__(self, parent: ESQLBase, **columns: FieldType): self._columns = columns def _render_internal(self) -> str: - return f'RENAME {", ".join([f"{old_name} AS {new_name}" for old_name, new_name in self._columns.items()])}' + return f'RENAME {", ".join([f"{self._format_id(old_name)} AS {self._format_id(new_name)}" for old_name, new_name in self._columns.items()])}' class Sample(ESQLBase): @@ -1015,7 +1059,7 @@ def __init__(self, parent: ESQLBase, probability: float): self._probability = probability def _render_internal(self) -> str: - return f"SAMPLE {self._probability}" + return f"SAMPLE {json.dumps(self._probability)}" class Sort(ESQLBase): @@ -1026,12 +1070,16 @@ class Sort(ESQLBase): in a single expression. """ - def __init__(self, parent: ESQLBase, *columns: FieldType): + def __init__(self, parent: ESQLBase, *columns: ExpressionType): super().__init__(parent) self._columns = columns def _render_internal(self) -> str: - return f'SORT {", ".join([f"{col}" for col in self._columns])}' + sorts = [ + " ".join([self._format_id(term) for term in str(col).split(" ")]) + for col in self._columns + ] + return f'SORT {", ".join([f"{sort}" for sort in sorts])}' class Stats(ESQLBase): @@ -1062,14 +1110,17 @@ def by(self, *grouping_expressions: ExpressionType) -> "Stats": def _render_internal(self) -> str: if isinstance(self._expressions, dict): - exprs = [f"{key} = {value}" for key, value in self._expressions.items()] + exprs = [ + f"{self._format_id(key)} = {self._format_expr(value)}" + for key, value in self._expressions.items() + ] else: - exprs = [f"{expr}" for expr in self._expressions] + exprs = [f"{self._format_expr(expr)}" for expr in self._expressions] expression_separator = ",\n " by = ( "" if self._grouping_expressions is None - else f'\n BY {", ".join([f"{expr}" for expr in self._grouping_expressions])}' + else f'\n BY {", ".join([f"{self._format_expr(expr)}" for expr in self._grouping_expressions])}' ) return f'STATS {expression_separator.join([f"{expr}" for expr in exprs])}{by}' @@ -1087,7 +1138,7 @@ def __init__(self, parent: ESQLBase, *expressions: ExpressionType): self._expressions = expressions def _render_internal(self) -> str: - return f'WHERE {" AND ".join([f"{expr}" for expr in self._expressions])}' + return f'WHERE {" AND ".join([f"{self._format_expr(expr)}" for expr in self._expressions])}' def and_(*expressions: InstrumentedExpression) -> "InstrumentedExpression": diff --git a/elasticsearch/esql/functions.py b/elasticsearch/esql/functions.py index 515e3ddfc..91f18d2d8 100644 --- a/elasticsearch/esql/functions.py +++ b/elasticsearch/esql/functions.py @@ -19,11 +19,15 @@ from typing import Any from elasticsearch.dsl.document_base import InstrumentedExpression -from elasticsearch.esql.esql import ExpressionType +from elasticsearch.esql.esql import ESQLBase, ExpressionType def _render(v: Any) -> str: - return json.dumps(v) if not isinstance(v, InstrumentedExpression) else str(v) + return ( + json.dumps(v) + if not isinstance(v, InstrumentedExpression) + else ESQLBase._format_expr(v) + ) def abs(number: ExpressionType) -> InstrumentedExpression: @@ -69,7 +73,9 @@ def atan2( :param y_coordinate: y coordinate. If `null`, the function returns `null`. :param x_coordinate: x coordinate. If `null`, the function returns `null`. """ - return InstrumentedExpression(f"ATAN2({y_coordinate}, {x_coordinate})") + return InstrumentedExpression( + f"ATAN2({_render(y_coordinate)}, {_render(x_coordinate)})" + ) def avg(number: ExpressionType) -> InstrumentedExpression: @@ -114,7 +120,7 @@ def bucket( :param to: End of the range. Can be a number, a date or a date expressed as a string. """ return InstrumentedExpression( - f"BUCKET({_render(field)}, {_render(buckets)}, {from_}, {_render(to)})" + f"BUCKET({_render(field)}, {_render(buckets)}, {_render(from_)}, {_render(to)})" ) @@ -169,7 +175,7 @@ def cidr_match(ip: ExpressionType, block_x: ExpressionType) -> InstrumentedExpre :param ip: IP address of type `ip` (both IPv4 and IPv6 are supported). :param block_x: CIDR block to test the IP against. """ - return InstrumentedExpression(f"CIDR_MATCH({_render(ip)}, {block_x})") + return InstrumentedExpression(f"CIDR_MATCH({_render(ip)}, {_render(block_x)})") def coalesce(first: ExpressionType, rest: ExpressionType) -> InstrumentedExpression: @@ -264,7 +270,7 @@ def date_diff( :param end_timestamp: A string representing an end timestamp """ return InstrumentedExpression( - f"DATE_DIFF({_render(unit)}, {start_timestamp}, {end_timestamp})" + f"DATE_DIFF({_render(unit)}, {_render(start_timestamp)}, {_render(end_timestamp)})" ) @@ -285,7 +291,9 @@ def date_extract( the function returns `null`. :param date: Date expression. If `null`, the function returns `null`. """ - return InstrumentedExpression(f"DATE_EXTRACT({date_part}, {_render(date)})") + return InstrumentedExpression( + f"DATE_EXTRACT({_render(date_part)}, {_render(date)})" + ) def date_format( @@ -301,7 +309,7 @@ def date_format( """ if date_format is not None: return InstrumentedExpression( - f"DATE_FORMAT({json.dumps(date_format)}, {_render(date)})" + f"DATE_FORMAT({_render(date_format)}, {_render(date)})" ) else: return InstrumentedExpression(f"DATE_FORMAT({_render(date)})") @@ -317,7 +325,9 @@ def date_parse( :param date_string: Date expression as a string. If `null` or an empty string, the function returns `null`. """ - return InstrumentedExpression(f"DATE_PARSE({date_pattern}, {date_string})") + return InstrumentedExpression( + f"DATE_PARSE({_render(date_pattern)}, {_render(date_string)})" + ) def date_trunc( @@ -929,7 +939,7 @@ def replace( :param new_string: Replacement string. """ return InstrumentedExpression( - f"REPLACE({_render(string)}, {_render(regex)}, {new_string})" + f"REPLACE({_render(string)}, {_render(regex)}, {_render(new_string)})" ) @@ -1004,7 +1014,7 @@ def scalb(d: ExpressionType, scale_factor: ExpressionType) -> InstrumentedExpres :param scale_factor: Numeric expression for the scale factor. If `null`, the function returns `null`. """ - return InstrumentedExpression(f"SCALB({_render(d)}, {scale_factor})") + return InstrumentedExpression(f"SCALB({_render(d)}, {_render(scale_factor)})") def sha1(input: ExpressionType) -> InstrumentedExpression: @@ -1116,7 +1126,7 @@ def st_contains( first. This means it is not possible to combine `geo_*` and `cartesian_*` parameters. """ - return InstrumentedExpression(f"ST_CONTAINS({geom_a}, {geom_b})") + return InstrumentedExpression(f"ST_CONTAINS({_render(geom_a)}, {_render(geom_b)})") def st_disjoint( @@ -1135,7 +1145,7 @@ def st_disjoint( first. This means it is not possible to combine `geo_*` and `cartesian_*` parameters. """ - return InstrumentedExpression(f"ST_DISJOINT({geom_a}, {geom_b})") + return InstrumentedExpression(f"ST_DISJOINT({_render(geom_a)}, {_render(geom_b)})") def st_distance( @@ -1153,7 +1163,7 @@ def st_distance( also have the same coordinate system as the first. This means it is not possible to combine `geo_point` and `cartesian_point` parameters. """ - return InstrumentedExpression(f"ST_DISTANCE({geom_a}, {geom_b})") + return InstrumentedExpression(f"ST_DISTANCE({_render(geom_a)}, {_render(geom_b)})") def st_envelope(geometry: ExpressionType) -> InstrumentedExpression: @@ -1208,7 +1218,7 @@ def st_geohash_to_long(grid_id: ExpressionType) -> InstrumentedExpression: :param grid_id: Input geohash grid-id. The input can be a single- or multi-valued column or an expression. """ - return InstrumentedExpression(f"ST_GEOHASH_TO_LONG({grid_id})") + return InstrumentedExpression(f"ST_GEOHASH_TO_LONG({_render(grid_id)})") def st_geohash_to_string(grid_id: ExpressionType) -> InstrumentedExpression: @@ -1218,7 +1228,7 @@ def st_geohash_to_string(grid_id: ExpressionType) -> InstrumentedExpression: :param grid_id: Input geohash grid-id. The input can be a single- or multi-valued column or an expression. """ - return InstrumentedExpression(f"ST_GEOHASH_TO_STRING({grid_id})") + return InstrumentedExpression(f"ST_GEOHASH_TO_STRING({_render(grid_id)})") def st_geohex( @@ -1254,7 +1264,7 @@ def st_geohex_to_long(grid_id: ExpressionType) -> InstrumentedExpression: :param grid_id: Input geohex grid-id. The input can be a single- or multi-valued column or an expression. """ - return InstrumentedExpression(f"ST_GEOHEX_TO_LONG({grid_id})") + return InstrumentedExpression(f"ST_GEOHEX_TO_LONG({_render(grid_id)})") def st_geohex_to_string(grid_id: ExpressionType) -> InstrumentedExpression: @@ -1264,7 +1274,7 @@ def st_geohex_to_string(grid_id: ExpressionType) -> InstrumentedExpression: :param grid_id: Input Geohex grid-id. The input can be a single- or multi-valued column or an expression. """ - return InstrumentedExpression(f"ST_GEOHEX_TO_STRING({grid_id})") + return InstrumentedExpression(f"ST_GEOHEX_TO_STRING({_render(grid_id)})") def st_geotile( @@ -1300,7 +1310,7 @@ def st_geotile_to_long(grid_id: ExpressionType) -> InstrumentedExpression: :param grid_id: Input geotile grid-id. The input can be a single- or multi-valued column or an expression. """ - return InstrumentedExpression(f"ST_GEOTILE_TO_LONG({grid_id})") + return InstrumentedExpression(f"ST_GEOTILE_TO_LONG({_render(grid_id)})") def st_geotile_to_string(grid_id: ExpressionType) -> InstrumentedExpression: @@ -1310,7 +1320,7 @@ def st_geotile_to_string(grid_id: ExpressionType) -> InstrumentedExpression: :param grid_id: Input geotile grid-id. The input can be a single- or multi-valued column or an expression. """ - return InstrumentedExpression(f"ST_GEOTILE_TO_STRING({grid_id})") + return InstrumentedExpression(f"ST_GEOTILE_TO_STRING({_render(grid_id)})") def st_intersects( @@ -1330,7 +1340,9 @@ def st_intersects( first. This means it is not possible to combine `geo_*` and `cartesian_*` parameters. """ - return InstrumentedExpression(f"ST_INTERSECTS({geom_a}, {geom_b})") + return InstrumentedExpression( + f"ST_INTERSECTS({_render(geom_a)}, {_render(geom_b)})" + ) def st_within(geom_a: ExpressionType, geom_b: ExpressionType) -> InstrumentedExpression: @@ -1346,7 +1358,7 @@ def st_within(geom_a: ExpressionType, geom_b: ExpressionType) -> InstrumentedExp first. This means it is not possible to combine `geo_*` and `cartesian_*` parameters. """ - return InstrumentedExpression(f"ST_WITHIN({geom_a}, {geom_b})") + return InstrumentedExpression(f"ST_WITHIN({_render(geom_a)}, {_render(geom_b)})") def st_x(point: ExpressionType) -> InstrumentedExpression: diff --git a/test_elasticsearch/test_dsl/_async/test_esql.py b/test_elasticsearch/test_dsl/test_integration/_async/test_esql.py similarity index 88% rename from test_elasticsearch/test_dsl/_async/test_esql.py rename to test_elasticsearch/test_dsl/test_integration/_async/test_esql.py index 7aacb833c..27d26ca99 100644 --- a/test_elasticsearch/test_dsl/_async/test_esql.py +++ b/test_elasticsearch/test_dsl/test_integration/_async/test_esql.py @@ -17,7 +17,7 @@ import pytest -from elasticsearch.dsl import AsyncDocument, M +from elasticsearch.dsl import AsyncDocument, E, M from elasticsearch.esql import ESQL, functions @@ -91,3 +91,13 @@ async def test_esql(async_client): ) r = await async_client.esql.query(query=str(query)) assert r.body["values"] == [[1.95]] + + # find employees by name using a parameter + query = ( + ESQL.from_(Employee) + .where(Employee.first_name == E("?")) + .keep(Employee.last_name) + .sort(Employee.last_name.desc()) + ) + r = await async_client.esql.query(query=str(query), params=["Maria"]) + assert r.body["values"] == [["Luna"], ["Cannon"]] diff --git a/test_elasticsearch/test_dsl/_sync/test_esql.py b/test_elasticsearch/test_dsl/test_integration/_sync/test_esql.py similarity index 88% rename from test_elasticsearch/test_dsl/_sync/test_esql.py rename to test_elasticsearch/test_dsl/test_integration/_sync/test_esql.py index 1c4084fc7..85ceee5ae 100644 --- a/test_elasticsearch/test_dsl/_sync/test_esql.py +++ b/test_elasticsearch/test_dsl/test_integration/_sync/test_esql.py @@ -17,7 +17,7 @@ import pytest -from elasticsearch.dsl import Document, M +from elasticsearch.dsl import Document, E, M from elasticsearch.esql import ESQL, functions @@ -91,3 +91,13 @@ def test_esql(client): ) r = client.esql.query(query=str(query)) assert r.body["values"] == [[1.95]] + + # find employees by name using a parameter + query = ( + ESQL.from_(Employee) + .where(Employee.first_name == E("?")) + .keep(Employee.last_name) + .sort(Employee.last_name.desc()) + ) + r = client.esql.query(query=str(query), params=["Maria"]) + assert r.body["values"] == [["Luna"], ["Cannon"]] diff --git a/test_elasticsearch/test_esql.py b/test_elasticsearch/test_esql.py index 70c9ec679..35b026fb5 100644 --- a/test_elasticsearch/test_esql.py +++ b/test_elasticsearch/test_esql.py @@ -84,7 +84,7 @@ def test_completion(): assert ( query.render() == """ROW question = "What is Elasticsearch?" -| COMPLETION question WITH test_completion_model +| COMPLETION question WITH {"inference_id": "test_completion_model"} | KEEP question, completion""" ) @@ -97,7 +97,7 @@ def test_completion(): assert ( query.render() == """ROW question = "What is Elasticsearch?" -| COMPLETION answer = question WITH test_completion_model +| COMPLETION answer = question WITH {"inference_id": "test_completion_model"} | KEEP question, answer""" ) @@ -128,7 +128,7 @@ def test_completion(): "Synopsis: ", synopsis, "\\n", "Actors: ", MV_CONCAT(actors, ", "), "\\n", ) -| COMPLETION summary = prompt WITH test_completion_model +| COMPLETION summary = prompt WITH {"inference_id": "test_completion_model"} | KEEP title, summary, rating""" ) @@ -160,7 +160,7 @@ def test_completion(): | SORT rating DESC | LIMIT 10 | EVAL prompt = CONCAT("Summarize this movie using the following information: \\n", "Title: ", title, "\\n", "Synopsis: ", synopsis, "\\n", "Actors: ", MV_CONCAT(actors, ", "), "\\n") -| COMPLETION summary = prompt WITH test_completion_model +| COMPLETION summary = prompt WITH {"inference_id": "test_completion_model"} | KEEP title, summary, rating""" ) @@ -713,3 +713,11 @@ def test_match_operator(): == """FROM books | WHERE author:"Faulkner\"""" ) + + +def test_parameters(): + query = ESQL.from_("employees").where("name == ?") + assert query.render() == "FROM employees\n| WHERE name == ?" + + query = ESQL.from_("employees").where(E("name") == E("?")) + assert query.render() == "FROM employees\n| WHERE name == ?"