diff --git a/src/base/tests/test_util.py b/src/base/tests/test_util.py index b0a15b456..f3356421d 100644 --- a/src/base/tests/test_util.py +++ b/src/base/tests/test_util.py @@ -881,6 +881,115 @@ def test_parallel_execute_retry_on_serialization_failure(self): cr.execute(util.format_query(cr, "SELECT 1 FROM {}", TEST_TABLE_NAME)) self.assertFalse(cr.rowcount) + def test_update_one_col_from_dict(self): + TEST_TABLE_NAME = "_upgrade_bulk_update_one_col_test_table" + N_ROWS = 10 + + cr = self._get_cr() + + cr.execute( + util.format_query( + cr, + """ + DROP TABLE IF EXISTS {table}; + + CREATE TABLE {table} ( + id SERIAL PRIMARY KEY, + col1 INTEGER, + col2 INTEGER + ); + + INSERT INTO {table} (col1, col2) SELECT v, v FROM GENERATE_SERIES(1, %s) as v; + """, + table=TEST_TABLE_NAME, + ), + [N_ROWS], + ) + mapping = {id: id * 2 for id in range(1, N_ROWS + 1, 2)} + util.bulk_update_table(cr, TEST_TABLE_NAME, "col1", mapping) + + cr.execute( + util.format_query( + cr, + "SELECT id FROM {table} WHERE col2 != id", + table=TEST_TABLE_NAME, + ) + ) + self.assertFalse(cr.rowcount, "unintended column 'col2' is affected") + + cr.execute( + util.format_query( + cr, + "SELECT id FROM {table} WHERE col1 != id AND MOD(id, 2) = 0", + table=TEST_TABLE_NAME, + ) + ) + self.assertFalse(cr.rowcount, "unintended rows are affected") + + cr.execute( + util.format_query( + cr, + "SELECT id FROM {table} WHERE col1 != 2 * id AND MOD(id, 2) = 1", + table=TEST_TABLE_NAME, + ) + ) + self.assertFalse(cr.rowcount, "partial/incorrect updates are performed") + + def test_update_multiple_cols_from_dict(self): + TEST_TABLE_NAME = "_upgrade_bulk_update_multiple_cols_test_table" + N_ROWS = 10 + + cr = self._get_cr() + + cr.execute( + util.format_query( + cr, + """ + DROP TABLE IF EXISTS {table}; + + CREATE TABLE {table} ( + id SERIAL PRIMARY KEY, + col1 INTEGER, + col2 INTEGER, + col3 INTEGER + ); + + INSERT INTO {table} (col1, col2, col3) SELECT v, v, v FROM GENERATE_SERIES(1, %s) as v; + """, + table=TEST_TABLE_NAME, + ), + [N_ROWS], + ) + mapping = {id: [id * 2, id * 3] for id in range(1, N_ROWS + 1, 2)} + util.bulk_update_table(cr, TEST_TABLE_NAME, ["col1", "col2"], mapping) + + cr.execute( + util.format_query( + cr, + "SELECT id FROM {table} WHERE col3 != id", + table=TEST_TABLE_NAME, + ) + ) + self.assertFalse(cr.rowcount, "unintended column 'col3' is affected") + + cr.execute( + util.format_query( + cr, + "SELECT id FROM {table} WHERE col1 != id AND MOD(id, 2) = 0", + table=TEST_TABLE_NAME, + ) + ) + self.assertFalse(cr.rowcount, "unintended rows are affected") + + cr.execute( + util.format_query( + cr, + "SELECT id FROM {table} WHERE (col1 != 2 * id OR col2 != 3 * id) AND MOD(id, 2) = 1", + table=TEST_TABLE_NAME, + ) + ) + self.assertFalse(cr.rowcount, "partial/incorrect updates are performed") + def test_create_column_with_fk(self): cr = self.env.cr self.assertFalse(util.column_exists(cr, "res_partner", "_test_lang_id")) diff --git a/src/util/pg.py b/src/util/pg.py index e7522a20e..26c544f02 100644 --- a/src/util/pg.py +++ b/src/util/pg.py @@ -32,6 +32,7 @@ import psycopg2 from psycopg2 import errorcodes, sql from psycopg2.extensions import quote_ident +from psycopg2.extras import Json try: from odoo.modules import module as odoo_module @@ -1621,3 +1622,76 @@ def create_id_sequence(cr, table, set_as_default=True): table=table_sql, ) ) + + +def bulk_update_table(cr, table, columns, mapping, key_col="id"): + """ + Update table based on mapping. + + Each `mapping` entry defines the new values for the specified `columns` for the row(s) + whose `key_col` value matches the key. + + .. example:: + + .. code-block:: python + + # single column update + util.bulk_update_table(cr, "res_users", "active", {42: False, 27: True}) + + # multi-column update + util.bulk_update_table( + cr, + "res_users", + ["active", "password"], + { + "admin": [True, "1234"], + "demo": [True, "5678"], + }, + key_col="login", + ) + + :param str table: table to update. + :param str | list(str) columns: columns spec for the update. It could be a single + column name or a list of column names. The `mapping` + must match the spec. + :param dict mapping: values to set, which must match the spec in `columns`, + following the **same** order + :param str key_col: column used as key to get the values from `mapping` during the + update. + + .. warning:: + + The values in the mapping will be casted to the type of the target column. + This function is designed to update scalar values, avoid setting arrays or json + data via the mapping. + """ + _validate_table(table) + if not columns or not mapping: + return + + assert isinstance(mapping, dict) + if isinstance(columns, str): + columns = [columns] + else: + n_columns = len(columns) + assert all(isinstance(value, (list, tuple)) and len(value) == n_columns for value in mapping.values()) + + query = format_query( + cr, + """ + UPDATE {table} t + SET ({cols}) = ROW({cols_values}) + FROM JSONB_EACH(%s) m + WHERE t.{key_col}::text = m.key + """, + table=table, + cols=ColumnList.from_unquoted(cr, columns), + cols_values=SQLStr( + ", ".join( + "(m.value->>{:d})::{}".format(col_idx, column_type(cr, table, col_name)) + for col_idx, col_name in enumerate(columns) + ) + ), + key_col=key_col, + ) + cr.execute(query, [Json(mapping)])