Skip to content

Commit 58ceb25

Browse files
ES|QL query builder robustness fixes (#3017) (#3025)
* Add note on how to prevent ES|QL injection attacks
* Various additional query builder fixes
* linter fixes

(cherry picked from commit e3e85ed)

Co-authored-by: Miguel Grinberg <[email protected]>
1 parent 41b2064 commit 58ceb25

File tree

7 files changed

+175
-63
lines changed

7 files changed

+175
-63
lines changed

docs/reference/esql-query-builder.md

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,26 @@ query = (
203203
)
204204
```
205205

206+
### Preventing injection attacks
207+
208+
ES|QL, like most query languages, is vulnerable to [code injection attacks](https://en.wikipedia.org/wiki/Code_injection) if untrusted data provided by users is added to a query. To eliminate this risk, ES|QL allows untrusted data to be given separately from the query as parameters.
209+
210+
Continuing with the example above, let's assume that the application needs a `find_employee_by_name()` function that searches for the name given as an argument. If this argument is received by the application from users, then it is considered untrusted and should not be added to the query directly. Here is how to code the function in a secure manner:
211+
212+
```python
213+
def find_employee_by_name(name):
214+
query = (
215+
ESQL.from_("employees")
216+
.keep("first_name", "last_name", "height")
217+
.where(E("first_name") == E("?"))
218+
)
219+
return client.esql.query(query=str(query), params=[name])
220+
```
221+
222+
Here the part of the query where the untrusted data would be inserted is replaced with a parameter, which in ES|QL is written as a question mark (`?`). When using Python expressions, the parameter must be given as `E("?")` so that it is treated as an expression and not as a literal string.
223+
224+
The values given in the `params` argument to the query endpoint are assigned, in order, to the parameters defined in the query.
225+
206226
## Using ES|QL functions
207227

208228
The ES|QL language includes a rich set of functions that can be used in expressions and conditionals. These can be included in expressions given as strings, as shown in the example below:
@@ -235,6 +255,6 @@ query = (
235255
)
236256
```
237257

238-
Note that arguments passed to functions are assumed to be literals. When passing field names, it is necessary to wrap them with the `E()` helper function so that they are interpreted correctly.
258+
Note that arguments passed to functions are assumed to be literals. When passing field names, parameters or other ES|QL expressions, it is necessary to wrap them with the `E()` helper function so that they are interpreted correctly.
239259

240260
You can find the complete list of available functions in the Python client's [ES|QL API reference documentation](https://elasticsearch-py.readthedocs.io/en/stable/esql.html#module-elasticsearch.esql.functions).

elasticsearch/esql/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
from ..dsl import E # noqa: F401
1819
from .esql import ESQL, and_, not_, or_ # noqa: F401

elasticsearch/esql/esql.py

Lines changed: 85 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717

1818
import json
19+
import re
1920
from abc import ABC, abstractmethod
2021
from typing import Any, Dict, Optional, Tuple, Type, Union
2122

@@ -111,6 +112,29 @@ def render(self) -> str:
111112
def _render_internal(self) -> str:
112113
pass
113114

115+
@staticmethod
116+
def _format_index(index: IndexType) -> str:
117+
return index._index._name if hasattr(index, "_index") else str(index)
118+
119+
@staticmethod
def _format_id(id: FieldType, allow_patterns: bool = False) -> str:
    """Render an identifier (field or column name) for inclusion in a query.

    Identifiers made only of letters, digits, underscores and dots (and
    optionally starting with ``@``) are returned unchanged. Anything else is
    wrapped in backticks, with any backtick inside the name doubled so that
    it cannot terminate the quoted identifier early.

    :param id: The identifier to render. It is converted with ``str()``, so
               any object with a suitable string form is accepted.
    :param allow_patterns: When ``True``, an identifier containing ``*`` is
                           treated as a wildcard pattern and returned as is.
    :returns: The identifier, quoted and escaped when necessary.
    """
    s = str(id)  # in case it is an InstrumentedField
    if allow_patterns and "*" in s:
        return s  # patterns cannot be escaped
    if re.fullmatch(r"[a-zA-Z_@][a-zA-Z0-9_\.]*", s):
        return s
    # this identifier needs to be escaped; str.replace() returns a new
    # string, so its result must be kept for the doubling to take effect
    escaped = s.replace("`", "``")
    return f"`{escaped}`"
129+
130+
@staticmethod
131+
def _format_expr(expr: ExpressionType) -> str:
132+
return (
133+
json.dumps(expr)
134+
if not isinstance(expr, (str, InstrumentedExpression))
135+
else str(expr)
136+
)
137+
114138
def _is_forked(self) -> bool:
115139
if self.__class__.__name__ == "Fork":
116140
return True
@@ -427,7 +451,7 @@ def sample(self, probability: float) -> "Sample":
427451
"""
428452
return Sample(self, probability)
429453

430-
def sort(self, *columns: FieldType) -> "Sort":
454+
def sort(self, *columns: ExpressionType) -> "Sort":
431455
"""The ``SORT`` processing command sorts a table on one or more columns.
432456
433457
:param columns: The columns to sort on.
@@ -570,15 +594,12 @@ def metadata(self, *fields: FieldType) -> "From":
570594
return self
571595

572596
def _render_internal(self) -> str:
573-
indices = [
574-
index if isinstance(index, str) else index._index._name
575-
for index in self._indices
576-
]
597+
indices = [self._format_index(index) for index in self._indices]
577598
s = f'{self.__class__.__name__.upper()} {", ".join(indices)}'
578599
if self._metadata_fields:
579600
s = (
580601
s
581-
+ f' METADATA {", ".join([str(field) for field in self._metadata_fields])}'
602+
+ f' METADATA {", ".join([self._format_id(field) for field in self._metadata_fields])}'
582603
)
583604
return s
584605

@@ -594,7 +615,11 @@ class Row(ESQLBase):
594615
def __init__(self, **params: ExpressionType):
595616
super().__init__()
596617
self._params = {
597-
k: json.dumps(v) if not isinstance(v, InstrumentedExpression) else v
618+
self._format_id(k): (
619+
json.dumps(v)
620+
if not isinstance(v, InstrumentedExpression)
621+
else self._format_expr(v)
622+
)
598623
for k, v in params.items()
599624
}
600625

@@ -615,7 +640,7 @@ def __init__(self, item: str):
615640
self._item = item
616641

617642
def _render_internal(self) -> str:
618-
return f"SHOW {self._item}"
643+
return f"SHOW {self._format_id(self._item)}"
619644

620645

621646
class Branch(ESQLBase):
@@ -667,11 +692,11 @@ def as_(self, type_name: str, pvalue_name: str) -> "ChangePoint":
667692
return self
668693

669694
def _render_internal(self) -> str:
670-
key = "" if not self._key else f" ON {self._key}"
695+
key = "" if not self._key else f" ON {self._format_id(self._key)}"
671696
names = (
672697
""
673698
if not self._type_name and not self._pvalue_name
674-
else f' AS {self._type_name or "type"}, {self._pvalue_name or "pvalue"}'
699+
else f' AS {self._format_id(self._type_name or "type")}, {self._format_id(self._pvalue_name or "pvalue")}'
675700
)
676701
return f"CHANGE_POINT {self._value}{key}{names}"
677702

@@ -709,12 +734,13 @@ def with_(self, inference_id: str) -> "Completion":
709734
def _render_internal(self) -> str:
710735
if self._inference_id is None:
711736
raise ValueError("The completion command requires an inference ID")
737+
with_ = {"inference_id": self._inference_id}
712738
if self._named_prompt:
713739
column = list(self._named_prompt.keys())[0]
714740
prompt = list(self._named_prompt.values())[0]
715-
return f"COMPLETION {column} = {prompt} WITH {self._inference_id}"
741+
return f"COMPLETION {self._format_id(column)} = {self._format_id(prompt)} WITH {json.dumps(with_)}"
716742
else:
717-
return f"COMPLETION {self._prompt[0]} WITH {self._inference_id}"
743+
return f"COMPLETION {self._format_id(self._prompt[0])} WITH {json.dumps(with_)}"
718744

719745

720746
class Dissect(ESQLBase):
@@ -742,9 +768,13 @@ def append_separator(self, separator: str) -> "Dissect":
742768

743769
def _render_internal(self) -> str:
744770
sep = (
745-
"" if self._separator is None else f' APPEND_SEPARATOR="{self._separator}"'
771+
""
772+
if self._separator is None
773+
else f" APPEND_SEPARATOR={json.dumps(self._separator)}"
774+
)
775+
return (
776+
f"DISSECT {self._format_id(self._input)} {json.dumps(self._pattern)}{sep}"
746777
)
747-
return f"DISSECT {self._input} {json.dumps(self._pattern)}{sep}"
748778

749779

750780
class Drop(ESQLBase):
@@ -760,7 +790,7 @@ def __init__(self, parent: ESQLBase, *columns: FieldType):
760790
self._columns = columns
761791

762792
def _render_internal(self) -> str:
763-
return f'DROP {", ".join([str(col) for col in self._columns])}'
793+
return f'DROP {", ".join([self._format_id(col, allow_patterns=True) for col in self._columns])}'
764794

765795

766796
class Enrich(ESQLBase):
@@ -814,12 +844,18 @@ def with_(self, *fields: FieldType, **named_fields: FieldType) -> "Enrich":
814844
return self
815845

816846
def _render_internal(self) -> str:
817-
on = "" if self._match_field is None else f" ON {self._match_field}"
847+
on = (
848+
""
849+
if self._match_field is None
850+
else f" ON {self._format_id(self._match_field)}"
851+
)
818852
with_ = ""
819853
if self._named_fields:
820-
with_ = f' WITH {", ".join([f"{name} = {field}" for name, field in self._named_fields.items()])}'
854+
with_ = f' WITH {", ".join([f"{self._format_id(name)} = {self._format_id(field)}" for name, field in self._named_fields.items()])}'
821855
elif self._fields is not None:
822-
with_ = f' WITH {", ".join([str(field) for field in self._fields])}'
856+
with_ = (
857+
f' WITH {", ".join([self._format_id(field) for field in self._fields])}'
858+
)
823859
return f"ENRICH {self._policy}{on}{with_}"
824860

825861

@@ -832,7 +868,10 @@ class Eval(ESQLBase):
832868
"""
833869

834870
def __init__(
835-
self, parent: ESQLBase, *columns: FieldType, **named_columns: FieldType
871+
self,
872+
parent: ESQLBase,
873+
*columns: ExpressionType,
874+
**named_columns: ExpressionType,
836875
):
837876
if columns and named_columns:
838877
raise ValueError(
@@ -844,10 +883,13 @@ def __init__(
844883
def _render_internal(self) -> str:
845884
if isinstance(self._columns, dict):
846885
cols = ", ".join(
847-
[f"{name} = {value}" for name, value in self._columns.items()]
886+
[
887+
f"{self._format_id(name)} = {self._format_expr(value)}"
888+
for name, value in self._columns.items()
889+
]
848890
)
849891
else:
850-
cols = ", ".join([f"{col}" for col in self._columns])
892+
cols = ", ".join([f"{self._format_expr(col)}" for col in self._columns])
851893
return f"EVAL {cols}"
852894

853895

@@ -900,7 +942,7 @@ def __init__(self, parent: ESQLBase, input: FieldType, pattern: str):
900942
self._pattern = pattern
901943

902944
def _render_internal(self) -> str:
903-
return f"GROK {self._input} {json.dumps(self._pattern)}"
945+
return f"GROK {self._format_id(self._input)} {json.dumps(self._pattern)}"
904946

905947

906948
class Keep(ESQLBase):
@@ -916,7 +958,7 @@ def __init__(self, parent: ESQLBase, *columns: FieldType):
916958
self._columns = columns
917959

918960
def _render_internal(self) -> str:
919-
return f'KEEP {", ".join([f"{col}" for col in self._columns])}'
961+
return f'KEEP {", ".join([f"{self._format_id(col, allow_patterns=True)}" for col in self._columns])}'
920962

921963

922964
class Limit(ESQLBase):
@@ -932,7 +974,7 @@ def __init__(self, parent: ESQLBase, max_number_of_rows: int):
932974
self._max_number_of_rows = max_number_of_rows
933975

934976
def _render_internal(self) -> str:
935-
return f"LIMIT {self._max_number_of_rows}"
977+
return f"LIMIT {json.dumps(self._max_number_of_rows)}"
936978

937979

938980
class LookupJoin(ESQLBase):
@@ -967,7 +1009,9 @@ def _render_internal(self) -> str:
9671009
if isinstance(self._lookup_index, str)
9681010
else self._lookup_index._index._name
9691011
)
970-
return f"LOOKUP JOIN {index} ON {self._field}"
1012+
return (
1013+
f"LOOKUP JOIN {self._format_index(index)} ON {self._format_id(self._field)}"
1014+
)
9711015

9721016

9731017
class MvExpand(ESQLBase):
@@ -983,7 +1027,7 @@ def __init__(self, parent: ESQLBase, column: FieldType):
9831027
self._column = column
9841028

9851029
def _render_internal(self) -> str:
986-
return f"MV_EXPAND {self._column}"
1030+
return f"MV_EXPAND {self._format_id(self._column)}"
9871031

9881032

9891033
class Rename(ESQLBase):
@@ -999,7 +1043,7 @@ def __init__(self, parent: ESQLBase, **columns: FieldType):
9991043
self._columns = columns
10001044

10011045
def _render_internal(self) -> str:
1002-
return f'RENAME {", ".join([f"{old_name} AS {new_name}" for old_name, new_name in self._columns.items()])}'
1046+
return f'RENAME {", ".join([f"{self._format_id(old_name)} AS {self._format_id(new_name)}" for old_name, new_name in self._columns.items()])}'
10031047

10041048

10051049
class Sample(ESQLBase):
@@ -1015,7 +1059,7 @@ def __init__(self, parent: ESQLBase, probability: float):
10151059
self._probability = probability
10161060

10171061
def _render_internal(self) -> str:
1018-
return f"SAMPLE {self._probability}"
1062+
return f"SAMPLE {json.dumps(self._probability)}"
10191063

10201064

10211065
class Sort(ESQLBase):
@@ -1026,12 +1070,16 @@ class Sort(ESQLBase):
10261070
in a single expression.
10271071
"""
10281072

1029-
def __init__(self, parent: ESQLBase, *columns: FieldType):
1073+
def __init__(self, parent: ESQLBase, *columns: ExpressionType):
10301074
super().__init__(parent)
10311075
self._columns = columns
10321076

10331077
def _render_internal(self) -> str:
1034-
return f'SORT {", ".join([f"{col}" for col in self._columns])}'
1078+
sorts = [
1079+
" ".join([self._format_id(term) for term in str(col).split(" ")])
1080+
for col in self._columns
1081+
]
1082+
return f'SORT {", ".join([f"{sort}" for sort in sorts])}'
10351083

10361084

10371085
class Stats(ESQLBase):
@@ -1062,14 +1110,17 @@ def by(self, *grouping_expressions: ExpressionType) -> "Stats":
10621110

10631111
def _render_internal(self) -> str:
10641112
if isinstance(self._expressions, dict):
1065-
exprs = [f"{key} = {value}" for key, value in self._expressions.items()]
1113+
exprs = [
1114+
f"{self._format_id(key)} = {self._format_expr(value)}"
1115+
for key, value in self._expressions.items()
1116+
]
10661117
else:
1067-
exprs = [f"{expr}" for expr in self._expressions]
1118+
exprs = [f"{self._format_expr(expr)}" for expr in self._expressions]
10681119
expression_separator = ",\n "
10691120
by = (
10701121
""
10711122
if self._grouping_expressions is None
1072-
else f'\n BY {", ".join([f"{expr}" for expr in self._grouping_expressions])}'
1123+
else f'\n BY {", ".join([f"{self._format_expr(expr)}" for expr in self._grouping_expressions])}'
10731124
)
10741125
return f'STATS {expression_separator.join([f"{expr}" for expr in exprs])}{by}'
10751126

@@ -1087,7 +1138,7 @@ def __init__(self, parent: ESQLBase, *expressions: ExpressionType):
10871138
self._expressions = expressions
10881139

10891140
def _render_internal(self) -> str:
1090-
return f'WHERE {" AND ".join([f"{expr}" for expr in self._expressions])}'
1141+
return f'WHERE {" AND ".join([f"{self._format_expr(expr)}" for expr in self._expressions])}'
10911142

10921143

10931144
def and_(*expressions: InstrumentedExpression) -> "InstrumentedExpression":

0 commit comments

Comments
 (0)