Commit 1e5c9a3
[ADD] util/update_table_from_dict
A recurrent challenge in writing upgrade scripts is updating values in a table based on some already available mapping from the id (or another identifier) to the new value. This is often addressed with an iterative solution of the form:

```python
for key, value in mapping.items():
    cr.execute(
        """
        UPDATE table
           SET col = %s
         WHERE key_col = %s
        """,
        [value, key],
    )
```

or in a more efficient (only issuing a single query) but hacky way:

```python
cr.execute(
    """
    UPDATE table
       SET col = (%s::jsonb)->>(key_col::text)
     WHERE key_col = ANY(%s)
    """,
    [json.dumps(mapping), list(mapping)],
)
```

The former is inefficient for big mappings, and the latter often requires some comments at review time to get it right. This commit introduces a util meant to make it easier to perform such updates efficiently.
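With the new util, both patterns above collapse into a single call. A minimal usage sketch, reusing the illustrative `table`, `key_col`, and `col` names from the snippets above:

```python
# One efficient bulk update for the whole mapping, matching rows on key_col.
# The helper expects the shape {key: {column_name: new_value}}.
util.update_table_from_dict(
    cr,
    "table",
    {key: {"col": value} for key, value in mapping.items()},
    key_col="key_col",
)
```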
1 parent f1dd8f8 commit 1e5c9a3

File tree

1 file changed (+88 −1 lines changed)

src/util/pg.py

Lines changed: 88 additions & 1 deletion
```diff
@@ -2,6 +2,7 @@
 """Utility functions for interacting with PostgreSQL."""
 
 import collections
+import json
 import logging
 import os
 import re
@@ -43,7 +44,7 @@
 
 from .exceptions import MigrationError, SleepyDeveloperError
 from .helpers import _validate_table, model_of_table
-from .misc import Sentinel, log_progress, version_gte
+from .misc import Sentinel, chunks, log_progress, version_gte
 
 _logger = logging.getLogger(__name__)
 
@@ -1621,3 +1622,89 @@ def create_id_sequence(cr, table, set_as_default=True):
             table=table_sql,
         )
     )
+
+
+def update_table_from_dict(cr, table, mapping, key_col="id", bucket_size=DEFAULT_BUCKET_SIZE):
+    """
+    Update table's rows based on mapping.
+
+    Efficiently updates rows in a table by mapping an identifier column (`key_col`) value
+    to the new values for the provided set of columns.
+
+    .. example::
+
+        .. code-block:: python
+
+            util.update_table_from_dict(
+                cr,
+                "account_move",
+                {
+                    1: {"closing_return_id": 2, "always_tax_eligible": True},
+                    2: {"closing_return_id": 3, "always_tax_eligible": False},
+                },
+            )
+
+    :param str table: the table to update
+    :param dict[any, dict[str, any]] mapping: mapping of `key_col` identifiers to maps of
+        column names to their new value
+
+        .. example::
+
+            .. code-block:: python
+
+                mapping = {
+                    1: {"col1": 123, "col2": "foo"},
+                    2: {"col1": 456, "col2": "bar"},
+                }
+
+        .. warning::
+
+            All maps should have the exact same set of keys (column names). The following
+            example would behave unpredictably:
+
+            .. code-block:: python
+
+                # WRONG
+                mapping = {
+                    1: {"col1": 123, "col2": "foo"},
+                    2: {"col1": 456},
+                }
+
+            Either resulting in `col2` updates being ignored or setting it to NULL for row 2.
+
+    :param str key_col: The column to match the key against (`id` by default)
+    :param int bucket_size: maximum number of rows to update per single query
+    """
+    if not mapping:
+        return
+
+    _validate_table(table)
+
+    column_names = list(next(iter(mapping.values())).keys())
+    query = cr.mogrify(
+        format_query(
+            cr,
+            """
+            UPDATE {table} t
+               SET ({columns_list}) = ROW({values_list})
+              FROM JSONB_EACH(%%s) m
+             WHERE t.{key_col}::varchar = m.key
+            """,
+            table=table,
+            columns_list=ColumnList.from_unquoted(cr, column_names),
+            values_list=sql.SQL(", ").join(
+                sql.SQL("(m.value->>%s)::{}").format(sql.SQL(column_type(cr, table, col))) for col in column_names
+            ),
+            key_col=key_col,
+        ),
+        column_names,
+    )
+
+    if len(mapping) <= 1.1 * bucket_size:
+        cr.execute(query, [json.dumps(mapping)])
+    else:
+        parallel_execute(
+            cr,
+            [
+                cr.mogrify(query, [json.dumps(mapping_chunk)]).decode()
+                for mapping_chunk in chunks(mapping.items(), bucket_size, fmt=dict)
+            ],
        )
```
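For intuition, this is roughly the query the util builds and executes for the `account_move` example in the docstring; a hand-expanded sketch, assuming `closing_return_id` is an `int4` column and `always_tax_eligible` a `bool` (in the actual code each cast is resolved through `column_type`):

```python
import json

mapping = {
    1: {"closing_return_id": 2, "always_tax_eligible": True},
    2: {"closing_return_id": 3, "always_tax_eligible": False},
}
# JSONB_EACH turns the serialized mapping into (key, value) rows, so a single
# UPDATE can join against it and assign both columns at once.
cr.execute(
    """
    UPDATE account_move t
       SET (closing_return_id, always_tax_eligible)
         = ROW((m.value->>'closing_return_id')::int4,
               (m.value->>'always_tax_eligible')::bool)
      FROM JSONB_EACH(%s) m
     WHERE t.id::varchar = m.key
    """,
    [json.dumps(mapping)],
)
```

Note that `json.dumps` serializes the integer ids into JSON object keys, i.e. strings, which is why the query compares `t.id::varchar = m.key`. For mappings larger than about `1.1 * bucket_size`, the util instead splits the dict into chunks and runs one such query per chunk through `parallel_execute`.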
