From 19d0dfe0e69f4fc65907498ad8ad43a5331ffcb8 Mon Sep 17 00:00:00 2001 From: Wai Lee Chin Feman Date: Mon, 24 Jun 2024 17:43:00 -0400 Subject: [PATCH 1/8] Use temp-tables to enable us to pass in a million IDs for a given 'select' statement --- django_bulk_load/bulk_load.py | 29 +++-------------------------- django_bulk_load/queries.py | 27 +++++++++++++++++++-------- 2 files changed, 22 insertions(+), 34 deletions(-) diff --git a/django_bulk_load/bulk_load.py b/django_bulk_load/bulk_load.py index 296b34d..408b9e5 100644 --- a/django_bulk_load/bulk_load.py +++ b/django_bulk_load/bulk_load.py @@ -513,10 +513,7 @@ def bulk_select_model_dicts( the filter_field_names keys in addition to any fields in select_field_names :param filter_data: Values (normally tuples) of the filter_field_names. For instance if filter_field_names=["field1", "field2"], filter_data may be [(12, "hello"), (23, "world"), (35, "fun"), ...] - :param skip_filter_transform: Normally the function converts the filter_data into DB specific values. This is useful - for datetimes or other complex values that have different representation in the DB. The downside is the transform - can be slow. If you know your data is simple values (strings, integers, etc.) and don't need - transformation, you can pass True. + :param skip_filter_transform: DEPRECATED. :param select_for_update: Use `FOR UPDATE` clause in select query. This will lock the rows. :return: List of dictionaries that match the model_data. 
Returns dictionaries for performance reasons @@ -539,32 +536,13 @@ def bulk_select_model_dicts( connection = connections[db_name] with connection.cursor() as cursor: - # Grab all the filter data, so we can know the length - filter_data = list(filter_data) - if not skip_filter_transform: - filter_data_transformed = [] - for filter_vals in filter_data: - filter_data_transformed.append( - [ - django_field_to_query_value(filter_fields[i], value) - for i, value in enumerate(filter_vals) - ] - ) - filter_data = filter_data_transformed - - sql = generate_values_select_query( - table_name=table_name, - select_fields=select_fields, - filter_fields=filter_fields, - select_for_update=select_for_update - ) - sql_string = sql.as_string(cursor.connection) + models = [model_class(**dict(zip(filter_field_names, x))) for x in filter_data] + cursor.execute(generate_select_query(table_name, create_temp_table_and_load(models, connection, cursor, filter_field_names), filter_fields, select_fields, for_update=select_for_update)) logger.info( "Starting selecting models", extra=dict(query_dict_count=len(filter_data), table_name=table_name), ) - execute_values(cursor, sql_string, filter_data, page_size=len(filter_data)) columns = [col[0] for col in cursor.description] # Map columns to fields so we can later correctly interpret column values @@ -590,5 +568,4 @@ def bulk_select_model_dicts( duration=monotonic() - start_time, ), ) - return results diff --git a/django_bulk_load/queries.py b/django_bulk_load/queries.py index 6bd9cbb..dcc683d 100644 --- a/django_bulk_load/queries.py +++ b/django_bulk_load/queries.py @@ -307,6 +307,7 @@ def generate_select_query( loading_table_name: str, join_fields: Sequence[models.Field], select_fields: Sequence[models.Field] = None, + for_update=False ) -> Composable: join_clause = generate_join_condition( source_table_name=loading_table_name, @@ -326,14 +327,24 @@ def generate_select_query( else: fields = 
SQL("{table_name}.*").format(table_name=Identifier(table_name)) - return SQL( - "SELECT {fields} FROM {table_name} INNER JOIN {loading_table_name} ON {join_clause}" - ).format( - loading_table_name=Identifier(loading_table_name), - join_clause=join_clause, - fields=fields, - table_name=Identifier(table_name), - ) + if for_update: + return SQL( + "SELECT {fields} FROM {table_name} INNER JOIN {loading_table_name} ON {join_clause} FOR UPDATE" + ).format( + loading_table_name=Identifier(loading_table_name), + join_clause=join_clause, + fields=fields, + table_name=Identifier(table_name), + ) + else: + return SQL( + "SELECT {fields} FROM {table_name} INNER JOIN {loading_table_name} ON {join_clause}" + ).format( + loading_table_name=Identifier(loading_table_name), + join_clause=join_clause, + fields=fields, + table_name=Identifier(table_name), + ) def generate_values_select_query( From f847deba0adfff2e8035dea4819f2977fb3b2fd9 Mon Sep 17 00:00:00 2001 From: Wai Lee Chin Feman Date: Mon, 24 Jun 2024 18:50:11 -0400 Subject: [PATCH 2/8] Formatting --- django_bulk_load/bulk_load.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/django_bulk_load/bulk_load.py b/django_bulk_load/bulk_load.py index 408b9e5..637313d 100644 --- a/django_bulk_load/bulk_load.py +++ b/django_bulk_load/bulk_load.py @@ -537,7 +537,15 @@ def bulk_select_model_dicts( with connection.cursor() as cursor: models = [model_class(**dict(zip(filter_field_names, x))) for x in filter_data] - cursor.execute(generate_select_query(table_name, create_temp_table_and_load(models, connection, cursor, filter_field_names), filter_fields, select_fields, for_update=select_for_update)) + cursor.execute( + generate_select_query( + table_name, + create_temp_table_and_load(models, connection, cursor, filter_field_names), + filter_fields, + select_fields, + for_update=select_for_update + ) + ) logger.info( "Starting selecting models", From 97f0270c6f3538dc3e187800a85fcfcfb212b3ec Mon Sep 17 00:00:00 
2001 From: Wai Lee Chin Feman Date: Tue, 25 Jun 2024 09:50:21 -0400 Subject: [PATCH 3/8] Joining can be perilous - let's not return extra rows --- django_bulk_load/bulk_load.py | 7 ++++- django_bulk_load/queries.py | 42 +++++++++++++++++++++++++-- tests/test_bulk_select_model_dicts.py | 35 ++++++++++++++++++++++ 3 files changed, 80 insertions(+), 4 deletions(-) diff --git a/django_bulk_load/bulk_load.py b/django_bulk_load/bulk_load.py index 637313d..fc15e16 100644 --- a/django_bulk_load/bulk_load.py +++ b/django_bulk_load/bulk_load.py @@ -26,6 +26,7 @@ generate_insert_for_update_query, generate_select_latest, generate_select_query, + generate_distinct_select_query, generate_update_query, generate_values_select_query, copy_query @@ -96,6 +97,10 @@ def bulk_load_models_with_queries( field_names: Sequence[str] = None, return_models: bool = False, ): + """ + This could be called, "bulk-load-models-into-temp-table AND THEN execute queries" + Or, perhaps "execute_queries_with_temp_table_as_helper". 
+ """ start_time = monotonic() model = models[0] db_name = router.db_for_write(model.__class__) @@ -538,7 +543,7 @@ def bulk_select_model_dicts( with connection.cursor() as cursor: models = [model_class(**dict(zip(filter_field_names, x))) for x in filter_data] cursor.execute( - generate_select_query( + generate_distinct_select_query( table_name, create_temp_table_and_load(models, connection, cursor, filter_field_names), filter_fields, diff --git a/django_bulk_load/queries.py b/django_bulk_load/queries.py index dcc683d..0da7e59 100644 --- a/django_bulk_load/queries.py +++ b/django_bulk_load/queries.py @@ -302,7 +302,7 @@ def generate_update_query( ) -def generate_select_query( +def generate_distinct_select_query( table_name: str, loading_table_name: str, join_fields: Sequence[models.Field], @@ -329,7 +329,7 @@ def generate_select_query( if for_update: return SQL( - "SELECT {fields} FROM {table_name} INNER JOIN {loading_table_name} ON {join_clause} FOR UPDATE" + "SELECT {fields} FROM {table_name} where exists (select 1 from {loading_table_name} where {join_clause}) FOR UPDATE" ).format( loading_table_name=Identifier(loading_table_name), join_clause=join_clause, @@ -338,7 +338,7 @@ def generate_select_query( ) else: return SQL( - "SELECT {fields} FROM {table_name} INNER JOIN {loading_table_name} ON {join_clause}" + "SELECT {fields} FROM {table_name} where exists (select 1 from {loading_table_name} where {join_clause})" ).format( loading_table_name=Identifier(loading_table_name), join_clause=join_clause, @@ -347,6 +347,42 @@ def generate_select_query( ) + + +def generate_select_query( + table_name: str, + loading_table_name: str, + join_fields: Sequence[models.Field], + select_fields: Sequence[models.Field] = None +) -> Composable: + join_clause = generate_join_condition( + source_table_name=loading_table_name, + destination_table_name=table_name, + fields=join_fields, + ) + if select_fields: + fields = SQL(", ").join( + [ + SQL("{table_name}.{column_name}").format( + 
table_name=Identifier(table_name), + column_name=Identifier(field.column), + ) + for field in select_fields + ] + ) + else: + fields = SQL("{table_name}.*").format(table_name=Identifier(table_name)) + + return SQL( + "SELECT {fields} FROM {table_name} INNER JOIN {loading_table_name} ON {join_clause}" + ).format( + loading_table_name=Identifier(loading_table_name), + join_clause=join_clause, + fields=fields, + table_name=Identifier(table_name), + ) + + def generate_values_select_query( table_name: str, filter_fields: Sequence[models.Field], diff --git a/tests/test_bulk_select_model_dicts.py b/tests/test_bulk_select_model_dicts.py index 343333e..712522f 100644 --- a/tests/test_bulk_select_model_dicts.py +++ b/tests/test_bulk_select_model_dicts.py @@ -21,6 +21,41 @@ def test_empty_get(self): [], ) + def test_ignores_duplicates_in_input(self): + saved_model = TestComplexModel( + integer_field=123, + string_field="hello", + ) + saved_model.save() + result_dicts = bulk_select_model_dicts( + model_class=TestComplexModel, + filter_field_names=["integer_field"], + filter_data=[(123,), (123,)], + select_field_names=["string_field", "integer_field"], + ) + + self.assertEqual(len(result_dicts), 1) + + def test_finds_duplicates_when_they_exist(self): + TestComplexModel( + integer_field=123, + string_field="hello", + ).save() + secod_saved_model = TestComplexModel( + integer_field=123, + string_field="hello", + ).save() + result_dicts = bulk_select_model_dicts( + model_class=TestComplexModel, + filter_field_names=["integer_field"], + filter_data=[(123,), (123,)], + select_field_names=["string_field", "integer_field"], + ) + + self.assertEqual(len(result_dicts), 2) + + + def test_single_select(self): foreign = TestForeignKeyModel() foreign.save() From 0e7fb1e907dee018435ec93ff9d634a6319ae11a Mon Sep 17 00:00:00 2001 From: Wai Lee Chin Feman Date: Tue, 25 Jun 2024 10:44:49 -0400 Subject: [PATCH 4/8] Dead code now --- django_bulk_load/bulk_load.py | 1 - django_bulk_load/queries.py 
| 27 --------------------------- 2 files changed, 28 deletions(-) diff --git a/django_bulk_load/bulk_load.py b/django_bulk_load/bulk_load.py index fc15e16..275dec6 100644 --- a/django_bulk_load/bulk_load.py +++ b/django_bulk_load/bulk_load.py @@ -28,7 +28,6 @@ generate_select_query, generate_distinct_select_query, generate_update_query, - generate_values_select_query, copy_query ) from .utils import generate_table_name diff --git a/django_bulk_load/queries.py b/django_bulk_load/queries.py index 0da7e59..c9cae35 100644 --- a/django_bulk_load/queries.py +++ b/django_bulk_load/queries.py @@ -382,30 +382,3 @@ def generate_select_query( table_name=Identifier(table_name), ) - -def generate_values_select_query( - table_name: str, - filter_fields: Sequence[models.Field], - select_fields: Sequence[models.Field], - select_for_update: bool -): - select_fields_sql = SQL(", ").join( - [Identifier(field.column) for field in select_fields] - ) - - filter_fields_sql = SQL(", ").join( - [Identifier(field.column) for field in filter_fields] - ) - - for_update = SQL("") - if select_for_update: - for_update = SQL(" FOR UPDATE") - - return SQL( - "SELECT {select_fields_sql} from {table_name} where ({filter_fields_sql}) IN (VALUES %s){for_update}" - ).format( - table_name=Identifier(table_name), - select_fields_sql=select_fields_sql, - filter_fields_sql=filter_fields_sql, - for_update=for_update, - ) From 71907147cd358f94d555498b9c351a8a4741ab02 Mon Sep 17 00:00:00 2001 From: Wai Lee Chin Feman Date: Tue, 25 Jun 2024 10:46:09 -0400 Subject: [PATCH 5/8] More unused code --- django_bulk_load/bulk_load.py | 1 - 1 file changed, 1 deletion(-) diff --git a/django_bulk_load/bulk_load.py b/django_bulk_load/bulk_load.py index 275dec6..07e651e 100644 --- a/django_bulk_load/bulk_load.py +++ b/django_bulk_load/bulk_load.py @@ -6,7 +6,6 @@ from django.db.backends.base.base import BaseDatabaseWrapper from django.db.backends.utils import CursorWrapper from django.db.models import AutoField, Model, 
Field -from psycopg2.extras import execute_values from psycopg2.sql import Composable, SQL from .django import ( From f2386a325b4a75fe6597cfb688c2a231ea70bbb0 Mon Sep 17 00:00:00 2001 From: Wai Lee Chin Feman Date: Tue, 25 Jun 2024 10:49:27 -0400 Subject: [PATCH 6/8] More unused code --- django_bulk_load/bulk_load.py | 1 - 1 file changed, 1 deletion(-) diff --git a/django_bulk_load/bulk_load.py b/django_bulk_load/bulk_load.py index 07e651e..b5a6efb 100644 --- a/django_bulk_load/bulk_load.py +++ b/django_bulk_load/bulk_load.py @@ -9,7 +9,6 @@ from psycopg2.sql import Composable, SQL from .django import ( - django_field_to_query_value, get_fields_and_names, get_fields_from_names, get_model_fields, From 6fd7628acd4a271d8491e0e176565a5403dcb260 Mon Sep 17 00:00:00 2001 From: Wai Lee Chin Feman Date: Tue, 25 Jun 2024 11:03:54 -0400 Subject: [PATCH 7/8] Add a transaction to ensure that the temp table gets cleaned up --- django_bulk_load/bulk_load.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/django_bulk_load/bulk_load.py b/django_bulk_load/bulk_load.py index b5a6efb..7f63bd1 100644 --- a/django_bulk_load/bulk_load.py +++ b/django_bulk_load/bulk_load.py @@ -537,7 +537,7 @@ def bulk_select_model_dicts( db_name = router.db_for_read(model_class) connection = connections[db_name] - with connection.cursor() as cursor: + with connection.cursor() as cursor, transaction.atomic(using=db_name): models = [model_class(**dict(zip(filter_field_names, x))) for x in filter_data] cursor.execute( generate_distinct_select_query( From ecd48d5fa18168d6fcfc6b0ae6fc6db918d33011 Mon Sep 17 00:00:00 2001 From: Wai Lee Chin Feman Date: Tue, 25 Jun 2024 16:26:04 -0400 Subject: [PATCH 8/8] leverage the Composable concept --- django_bulk_load/queries.py | 29 +++++++++++------------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/django_bulk_load/queries.py b/django_bulk_load/queries.py index c9cae35..35b3233 100644 --- a/django_bulk_load/queries.py +++ 
b/django_bulk_load/queries.py @@ -327,25 +327,18 @@ def generate_distinct_select_query( else: fields = SQL("{table_name}.*").format(table_name=Identifier(table_name)) - if for_update: - return SQL( - "SELECT {fields} FROM {table_name} where exists (select 1 from {loading_table_name} where {join_clause}) FOR UPDATE" - ).format( - loading_table_name=Identifier(loading_table_name), - join_clause=join_clause, - fields=fields, - table_name=Identifier(table_name), - ) - else: - return SQL( - "SELECT {fields} FROM {table_name} where exists (select 1 from {loading_table_name} where {join_clause})" - ).format( - loading_table_name=Identifier(loading_table_name), - join_clause=join_clause, - fields=fields, - table_name=Identifier(table_name), - ) + base_query = SQL( + "SELECT {fields} FROM {table_name} where exists (select 1 from {loading_table_name} where {join_clause})" + ).format( + loading_table_name=Identifier(loading_table_name), + join_clause=join_clause, + fields=fields, + table_name=Identifier(table_name), + ) + if for_update: + return Composed([base_query, SQL("FOR UPDATE")]) + return base_query