Skip to content

ES|QL query builder robustness fixes #3017

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion docs/reference/esql-query-builder.md
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,26 @@ query = (
)
```

### Preventing injection attacks

ES|QL, like most query languages, is vulnerable to [code injection attacks](https://en.wikipedia.org/wiki/Code_injection) if untrusted data provided by users is added to a query. To eliminate this risk, ES|QL allows untrusted data to be given separately from the query as parameters.

Continuing with the example above, let's assume that the application needs a `find_employee_by_name()` function that searches for the name given as an argument. If this argument is received by the application from users, then it is considered untrusted and should not be added to the query directly. Here is how to code the function in a secure manner:

```python
def find_employee_by_name(name):
query = (
ESQL.from_("employees")
.keep("first_name", "last_name", "height")
.where(E("first_name") == E("?"))
)
return client.esql.query(query=str(query), params=[name])
```

Here the part of the query where the untrusted data needs to be inserted is replaced with a parameter, which in ES|QL is written as a question mark (`?`). When using Python expressions, the parameter must be given as `E("?")` so that it is treated as an expression and not as a literal string.

The values given in the `params` argument to the query endpoint are assigned, in order, to the parameters defined in the query.

## Using ES|QL functions

The ES|QL language includes a rich set of functions that can be used in expressions and conditionals. These can be included in expressions given as strings, as shown in the example below:
Expand Down Expand Up @@ -235,6 +255,6 @@ query = (
)
```

Note that arguments passed to functions are assumed to be literals. When passing field names, it is necessary to wrap them with the `E()` helper function so that they are interpreted correctly.
Note that arguments passed to functions are assumed to be literals. When passing field names, parameters or other ES|QL expressions, it is necessary to wrap them with the `E()` helper function so that they are interpreted correctly.

You can find the complete list of available functions in the Python client's [ES|QL API reference documentation](https://elasticsearch-py.readthedocs.io/en/stable/esql.html#module-elasticsearch.esql.functions).
1 change: 1 addition & 0 deletions elasticsearch/esql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@
# specific language governing permissions and limitations
# under the License.

from ..dsl import E # noqa: F401
from .esql import ESQL, and_, not_, or_ # noqa: F401
119 changes: 85 additions & 34 deletions elasticsearch/esql/esql.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

import json
import re
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple, Type, Union

Expand Down Expand Up @@ -111,6 +112,29 @@ def render(self) -> str:
def _render_internal(self) -> str:
pass

@staticmethod
def _format_index(index: IndexType) -> str:
    """Return the text used to reference ``index`` in a rendered query.

    :param index: Either a plain value (typically an index name string) or
                  an object exposing ``_index._name``.
    :return: ``index._index._name`` when ``index`` has an ``_index``
             attribute, otherwise ``str(index)``.
    """
    # Anything without an ``_index`` attribute is rendered via ``str()``.
    if not hasattr(index, "_index"):
        return str(index)
    return index._index._name

@staticmethod
def _format_id(id: FieldType, allow_patterns: bool = False) -> str:
    """Render an identifier so it is safe to embed in an ES|QL query.

    :param id: The identifier to render. It is converted with ``str()``, so
               field objects that stringify to their name are accepted.
    :param allow_patterns: When ``True``, identifiers containing ``*`` are
                           returned unchanged, since wildcard patterns
                           cannot be quoted.
    :return: The identifier as given when it only uses characters that need
             no quoting, otherwise the identifier wrapped in backticks with
             any embedded backticks doubled.
    """
    s = str(id)  # in case it is an InstrumentedField
    if allow_patterns and "*" in s:
        return s  # patterns cannot be escaped
    if re.fullmatch(r"[a-zA-Z_@][a-zA-Z0-9_\.]*", s):
        return s
    # This identifier needs to be quoted. ``str.replace`` returns a new
    # string, so the result must be kept: a backtick left undoubled would
    # close the quoted identifier early and let the rest of the value be
    # parsed as query text.
    escaped = s.replace("`", "``")
    return f"`{escaped}`"

@staticmethod
def _format_expr(expr: ExpressionType) -> str:
    """Render an expression for inclusion in a query.

    :param expr: A string, an ``InstrumentedExpression`` or any other
                 JSON-serializable Python value.
    :return: ``str(expr)`` for strings and ``InstrumentedExpression``
             instances; the ``json.dumps()`` encoding for everything else.
    """
    # Strings and instrumented expressions are rendered as they are.
    if isinstance(expr, (str, InstrumentedExpression)):
        return str(expr)
    # Any other value is serialized as a JSON literal.
    return json.dumps(expr)

def _is_forked(self) -> bool:
if self.__class__.__name__ == "Fork":
return True
Expand Down Expand Up @@ -427,7 +451,7 @@ def sample(self, probability: float) -> "Sample":
"""
return Sample(self, probability)

def sort(self, *columns: FieldType) -> "Sort":
def sort(self, *columns: ExpressionType) -> "Sort":
"""The ``SORT`` processing command sorts a table on one or more columns.

:param columns: The columns to sort on.
Expand Down Expand Up @@ -570,15 +594,12 @@ def metadata(self, *fields: FieldType) -> "From":
return self

def _render_internal(self) -> str:
indices = [
index if isinstance(index, str) else index._index._name
for index in self._indices
]
indices = [self._format_index(index) for index in self._indices]
s = f'{self.__class__.__name__.upper()} {", ".join(indices)}'
if self._metadata_fields:
s = (
s
+ f' METADATA {", ".join([str(field) for field in self._metadata_fields])}'
+ f' METADATA {", ".join([self._format_id(field) for field in self._metadata_fields])}'
)
return s

Expand All @@ -594,7 +615,11 @@ class Row(ESQLBase):
def __init__(self, **params: ExpressionType):
super().__init__()
self._params = {
k: json.dumps(v) if not isinstance(v, InstrumentedExpression) else v
self._format_id(k): (
json.dumps(v)
if not isinstance(v, InstrumentedExpression)
else self._format_expr(v)
)
for k, v in params.items()
}

Expand All @@ -615,7 +640,7 @@ def __init__(self, item: str):
self._item = item

def _render_internal(self) -> str:
return f"SHOW {self._item}"
return f"SHOW {self._format_id(self._item)}"


class Branch(ESQLBase):
Expand Down Expand Up @@ -667,11 +692,11 @@ def as_(self, type_name: str, pvalue_name: str) -> "ChangePoint":
return self

def _render_internal(self) -> str:
key = "" if not self._key else f" ON {self._key}"
key = "" if not self._key else f" ON {self._format_id(self._key)}"
names = (
""
if not self._type_name and not self._pvalue_name
else f' AS {self._type_name or "type"}, {self._pvalue_name or "pvalue"}'
else f' AS {self._format_id(self._type_name or "type")}, {self._format_id(self._pvalue_name or "pvalue")}'
)
return f"CHANGE_POINT {self._value}{key}{names}"

Expand Down Expand Up @@ -709,12 +734,13 @@ def with_(self, inference_id: str) -> "Completion":
def _render_internal(self) -> str:
if self._inference_id is None:
raise ValueError("The completion command requires an inference ID")
with_ = {"inference_id": self._inference_id}
if self._named_prompt:
column = list(self._named_prompt.keys())[0]
prompt = list(self._named_prompt.values())[0]
return f"COMPLETION {column} = {prompt} WITH {self._inference_id}"
return f"COMPLETION {self._format_id(column)} = {self._format_id(prompt)} WITH {json.dumps(with_)}"
else:
return f"COMPLETION {self._prompt[0]} WITH {self._inference_id}"
return f"COMPLETION {self._format_id(self._prompt[0])} WITH {json.dumps(with_)}"


class Dissect(ESQLBase):
Expand Down Expand Up @@ -742,9 +768,13 @@ def append_separator(self, separator: str) -> "Dissect":

def _render_internal(self) -> str:
sep = (
"" if self._separator is None else f' APPEND_SEPARATOR="{self._separator}"'
""
if self._separator is None
else f" APPEND_SEPARATOR={json.dumps(self._separator)}"
)
return (
f"DISSECT {self._format_id(self._input)} {json.dumps(self._pattern)}{sep}"
)
return f"DISSECT {self._input} {json.dumps(self._pattern)}{sep}"


class Drop(ESQLBase):
Expand All @@ -760,7 +790,7 @@ def __init__(self, parent: ESQLBase, *columns: FieldType):
self._columns = columns

def _render_internal(self) -> str:
return f'DROP {", ".join([str(col) for col in self._columns])}'
return f'DROP {", ".join([self._format_id(col, allow_patterns=True) for col in self._columns])}'


class Enrich(ESQLBase):
Expand Down Expand Up @@ -814,12 +844,18 @@ def with_(self, *fields: FieldType, **named_fields: FieldType) -> "Enrich":
return self

def _render_internal(self) -> str:
on = "" if self._match_field is None else f" ON {self._match_field}"
on = (
""
if self._match_field is None
else f" ON {self._format_id(self._match_field)}"
)
with_ = ""
if self._named_fields:
with_ = f' WITH {", ".join([f"{name} = {field}" for name, field in self._named_fields.items()])}'
with_ = f' WITH {", ".join([f"{self._format_id(name)} = {self._format_id(field)}" for name, field in self._named_fields.items()])}'
elif self._fields is not None:
with_ = f' WITH {", ".join([str(field) for field in self._fields])}'
with_ = (
f' WITH {", ".join([self._format_id(field) for field in self._fields])}'
)
return f"ENRICH {self._policy}{on}{with_}"


Expand All @@ -832,7 +868,10 @@ class Eval(ESQLBase):
"""

def __init__(
self, parent: ESQLBase, *columns: FieldType, **named_columns: FieldType
self,
parent: ESQLBase,
*columns: ExpressionType,
**named_columns: ExpressionType,
):
if columns and named_columns:
raise ValueError(
Expand All @@ -844,10 +883,13 @@ def __init__(
def _render_internal(self) -> str:
if isinstance(self._columns, dict):
cols = ", ".join(
[f"{name} = {value}" for name, value in self._columns.items()]
[
f"{self._format_id(name)} = {self._format_expr(value)}"
for name, value in self._columns.items()
]
)
else:
cols = ", ".join([f"{col}" for col in self._columns])
cols = ", ".join([f"{self._format_expr(col)}" for col in self._columns])
return f"EVAL {cols}"


Expand Down Expand Up @@ -900,7 +942,7 @@ def __init__(self, parent: ESQLBase, input: FieldType, pattern: str):
self._pattern = pattern

def _render_internal(self) -> str:
return f"GROK {self._input} {json.dumps(self._pattern)}"
return f"GROK {self._format_id(self._input)} {json.dumps(self._pattern)}"


class Keep(ESQLBase):
Expand All @@ -916,7 +958,7 @@ def __init__(self, parent: ESQLBase, *columns: FieldType):
self._columns = columns

def _render_internal(self) -> str:
return f'KEEP {", ".join([f"{col}" for col in self._columns])}'
return f'KEEP {", ".join([f"{self._format_id(col, allow_patterns=True)}" for col in self._columns])}'


class Limit(ESQLBase):
Expand All @@ -932,7 +974,7 @@ def __init__(self, parent: ESQLBase, max_number_of_rows: int):
self._max_number_of_rows = max_number_of_rows

def _render_internal(self) -> str:
return f"LIMIT {self._max_number_of_rows}"
return f"LIMIT {json.dumps(self._max_number_of_rows)}"


class LookupJoin(ESQLBase):
Expand Down Expand Up @@ -967,7 +1009,9 @@ def _render_internal(self) -> str:
if isinstance(self._lookup_index, str)
else self._lookup_index._index._name
)
return f"LOOKUP JOIN {index} ON {self._field}"
return (
f"LOOKUP JOIN {self._format_index(index)} ON {self._format_id(self._field)}"
)


class MvExpand(ESQLBase):
Expand All @@ -983,7 +1027,7 @@ def __init__(self, parent: ESQLBase, column: FieldType):
self._column = column

def _render_internal(self) -> str:
return f"MV_EXPAND {self._column}"
return f"MV_EXPAND {self._format_id(self._column)}"


class Rename(ESQLBase):
Expand All @@ -999,7 +1043,7 @@ def __init__(self, parent: ESQLBase, **columns: FieldType):
self._columns = columns

def _render_internal(self) -> str:
return f'RENAME {", ".join([f"{old_name} AS {new_name}" for old_name, new_name in self._columns.items()])}'
return f'RENAME {", ".join([f"{self._format_id(old_name)} AS {self._format_id(new_name)}" for old_name, new_name in self._columns.items()])}'


class Sample(ESQLBase):
Expand All @@ -1015,7 +1059,7 @@ def __init__(self, parent: ESQLBase, probability: float):
self._probability = probability

def _render_internal(self) -> str:
return f"SAMPLE {self._probability}"
return f"SAMPLE {json.dumps(self._probability)}"


class Sort(ESQLBase):
Expand All @@ -1026,12 +1070,16 @@ class Sort(ESQLBase):
in a single expression.
"""

def __init__(self, parent: ESQLBase, *columns: FieldType):
def __init__(self, parent: ESQLBase, *columns: ExpressionType):
super().__init__(parent)
self._columns = columns

def _render_internal(self) -> str:
return f'SORT {", ".join([f"{col}" for col in self._columns])}'
sorts = [
" ".join([self._format_id(term) for term in str(col).split(" ")])
for col in self._columns
]
return f'SORT {", ".join([f"{sort}" for sort in sorts])}'


class Stats(ESQLBase):
Expand Down Expand Up @@ -1062,14 +1110,17 @@ def by(self, *grouping_expressions: ExpressionType) -> "Stats":

def _render_internal(self) -> str:
if isinstance(self._expressions, dict):
exprs = [f"{key} = {value}" for key, value in self._expressions.items()]
exprs = [
f"{self._format_id(key)} = {self._format_expr(value)}"
for key, value in self._expressions.items()
]
else:
exprs = [f"{expr}" for expr in self._expressions]
exprs = [f"{self._format_expr(expr)}" for expr in self._expressions]
expression_separator = ",\n "
by = (
""
if self._grouping_expressions is None
else f'\n BY {", ".join([f"{expr}" for expr in self._grouping_expressions])}'
else f'\n BY {", ".join([f"{self._format_expr(expr)}" for expr in self._grouping_expressions])}'
)
return f'STATS {expression_separator.join([f"{expr}" for expr in exprs])}{by}'

Expand All @@ -1087,7 +1138,7 @@ def __init__(self, parent: ESQLBase, *expressions: ExpressionType):
self._expressions = expressions

def _render_internal(self) -> str:
return f'WHERE {" AND ".join([f"{expr}" for expr in self._expressions])}'
return f'WHERE {" AND ".join([f"{self._format_expr(expr)}" for expr in self._expressions])}'


def and_(*expressions: InstrumentedExpression) -> "InstrumentedExpression":
Expand Down
Loading
Loading