From 32c4a857499ead1b5036eae9781fd9177b47dc3e Mon Sep 17 00:00:00 2001
From: wahid atoui
Date: Sun, 15 Oct 2023 20:01:41 +0200
Subject: [PATCH] add type three scd upserts functionality

---
 README.md                      |  57 +++++++
 mack/__init__.py               | 117 +++++++++++++
 tests/test_public_interface.py | 303 +++++++++++++++++++++++++++++++++
 3 files changed, 477 insertions(+)

diff --git a/README.md b/README.md
index 97a5293..bdab427 100644
--- a/README.md
+++ b/README.md
@@ -578,6 +578,63 @@ This function is designed to rename a Delta table. It can operate either within
 rename_delta_table(existing_delta_table, "new_table_name")
 ```
 
+## Type 3 SCD Upserts
+
+Perform a Type 3 SCD upsert on a target Delta table.
+
+Parameters:
+
+- `delta_table` (`DeltaTable`): An object representing the Delta table to be upserted.
+- `updates_df` (`DataFrame`): The data used to upsert the target Delta table.
+- `primary_key` (`str`): The primary key (i.e. business key) that uniquely identifies each row in the target Delta table.
+- `curr_prev_col_names` (`Dict[str, str]`): A dictionary of column names that store current and previous values.
+  `Key`: Column name holding the current value.
+  `Value`: Column name holding the previous value.
+
+Suppose you have the following Delta table:
+
+```
++----+----+---+--------+-------+------------+-------------+--------------+
+|pkey|name|job|prev_job|country|prev_country|    continent|prev_continent|
++----+----+---+--------+-------+------------+-------------+--------------+
+|   1|   A| AA|    null|  Japan|        null|         Asia|          null|
+|   2|   B| BB|    null|England|        null|       Europe|          null|
+|   3|   C| CC|    null| Canada|        null|North America|          null|
++----+----+---+--------+-------+------------+-------------+--------------+
+```
+
+The source data to be upserted into the target Delta table:
+
+```
++----+----+---+------------+-------------+
+|pkey|name|job|     country|    continent|
++----+----+---+------------+-------------+
+|   1|  A1| AA|       Japan|         Asia| // update on name
+|   2|  B1|BBB|        Peru|South America| // updates on name, job, country, continent --> previous values stored in prev_job, prev_country, prev_continent
+|   3|   C| CC| New Zealand|      Oceania| // updates on country, continent --> previous values stored in prev_country, prev_continent
+|   5|   D| DD|South Africa|       Africa| // new row
++----+----+---+------------+-------------+
+```
+
+Here's how to perform the Type 3 SCD upsert:
+
+```python
+mack.type_3_scd_upsert(delta_table, updates_df, "pkey", {"country": "prev_country", "job": "prev_job", "continent": "prev_continent"})
+```
+
+Here's the table after the upsert:
+
+```
++----+----+---+--------+------------+------------+-------------+--------------+
+|pkey|name|job|prev_job|     country|prev_country|    continent|prev_continent|
++----+----+---+--------+------------+------------+-------------+--------------+
+|   1|  A1| AA|    null|       Japan|        null|         Asia|          null|
+|   2|  B1|BBB|      BB|        Peru|     England|South America|        Europe|
+|   3|   C| CC|    null| New Zealand|      Canada|      Oceania| North America|
+|   5|   D| DD|    null|South Africa|        null|       Africa|          null|
++----+----+---+--------+------------+------------+-------------+--------------+
+```
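+
+Under the hood, each previous-value column is populated during the merge with an expression equivalent to the following (shown here for the `country`/`prev_country` pair):
+
+```
+coalesce(nullif(trg.country, src.country), trg.prev_country)
+```
+
+When the current value is unchanged, `nullif` yields `NULL` and the existing previous value is kept, so re-applying the same updates is a no-op and the upsert is idempotent.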
+
 ## Dictionary
 
diff --git a/mack/__init__.py b/mack/__init__.py
index f9c9ef7..b9be3cd 100644
--- a/mack/__init__.py
+++ b/mack/__init__.py
@@ -1,5 +1,6 @@
 from itertools import combinations
 from typing import List, Union, Dict, Optional
+from collections import Counter
 
 from delta import DeltaTable
 import pyspark
@@ -735,3 +736,119 @@ def rename_delta_table(
     delta_table.toDF().write.format("delta").mode("overwrite").saveAsTable(
         new_table_name
     )
+
+
+def type_3_scd_upsert(
+    delta_table: DeltaTable,
+    updates_df: DataFrame,
+    primary_key: str,
+    curr_prev_col_names: Dict[str, str],
+) -> None:
+    """
+    Apply SCD Type 3 updates to a target Delta table.
+
+    :param delta_table: The target Delta table.
+    :type delta_table: DeltaTable
+
+    :param updates_df: The source DataFrame used to apply the SCD Type 3 updates to the target Delta table.
+    :type updates_df: DataFrame
+
+    :param primary_key: The primary key (i.e. business key) that uniquely identifies each row in the target Delta table.
+    :type primary_key: str
+
+    :param curr_prev_col_names: A dictionary of column names that store current and previous values.
+        Key: Column name holding the current value.
+        Value: Column name holding the previous value.
+    :type curr_prev_col_names: Dict[str, str]
+
+    :raises TypeError: Raised when the values of the dictionary curr_prev_col_names contain duplicates.
+    :raises TypeError: Raised when a key equals its value in the dictionary curr_prev_col_names.
+    :raises TypeError: Raised when required columns are missing in the Delta table.
+    :raises TypeError: Raised when required columns are missing in the updates DataFrame.
+    """
+    # Validate the curr_prev_col_names parameter.
+    # Raise an error when the dictionary values contain duplicates.
+    count_dict = Counter(curr_prev_col_names.values())
+    prev_col_name_duplicates = [
+        (key, value)
+        for key, value in curr_prev_col_names.items()
+        if count_dict[value] > 1
+    ]
+    if prev_col_name_duplicates:
+        raise TypeError(
+            f"Found duplicate values in the dictionary curr_prev_col_names: {prev_col_name_duplicates!r}"
+        )
+    # Raise an error when a key equals its value.
+    keys_equal_to_values = [
+        (key, value) for key, value in curr_prev_col_names.items() if key == value
+    ]
+    if keys_equal_to_values:
+        raise TypeError(
+            f"Keys cannot be equal to values in the dictionary curr_prev_col_names: {keys_equal_to_values!r}"
+        )
+
+    # Validate the existing Delta table: the primary key and every current/previous
+    # column pair must be present.
+    base_col_names = delta_table.toDF().columns
+    required_base_col_names = [primary_key] + [
+        col for pair in curr_prev_col_names.items() for col in pair
+    ]
+    missing_col_names = [
+        col for col in required_base_col_names if col not in base_col_names
+    ]
+    if missing_col_names:
+        raise TypeError(
+            f"Cannot find these columns {missing_col_names!r} in the base table {base_col_names!r}"
+        )
+
+    # Validate the updates DataFrame: it must have every base column except the
+    # previous-value columns.
+    updates_col_names = updates_df.columns
+    prev_col_names = list(curr_prev_col_names.values())
+    required_updates_col_names = [
+        col for col in base_col_names if col not in prev_col_names
+    ]
+    if sorted(updates_col_names) != sorted(required_updates_col_names):
+        raise TypeError(
+            f"The updates DataFrame has these columns {updates_col_names!r}, but these columns are required {required_updates_col_names!r}"
+        )
+
+    # Merge condition: match target and source rows on the primary key.
+    merge_condition = f"trg.{primary_key} = src.{primary_key}"
+
+    # Update condition: fire only when at least one tracked attribute changed.
+    # Exclude the primary key and the previous-value columns from the comparison.
+    updates_attr = [
+        attr for attr in base_col_names if attr not in [primary_key] + prev_col_names
+    ]
+    updates_condition = " OR ".join(
+        f"trg.{attr} <> src.{attr}" for attr in updates_attr
+    )
+
+    # Rows to be inserted: previous-value columns start out as NULL.
+    previous_state_for_inserts = [f"NULL as {prev}" for prev in prev_col_names]
+
+    staged_inserts_df = (
+        updates_df.alias("inserts")
+        .join(delta_table.toDF().alias("trg"), primary_key, "leftanti")
+        .selectExpr(["inserts.*"] + previous_state_for_inserts)
+    )
+
+    # Rows to be updated: when a current value changes, move the old value into the
+    # previous-value column; otherwise keep the existing previous value.
+    previous_state_for_updates = [
+        f"coalesce(nullif(trg.{curr}, updates.{curr}), trg.{prev}) as {prev}"
+        for curr, prev in curr_prev_col_names.items()
+    ]
+
+    staged_updates_df = (
+        updates_df.alias("updates")
+        .join(delta_table.toDF().alias("trg"), primary_key)
+        .selectExpr(["updates.*"] + previous_state_for_updates)
+    )
+
+    # Input data = staged updates + staged inserts.
+    staged_inputs_df = staged_updates_df.union(staged_inserts_df)
+
+    # Perform the merge.
+    (
+        delta_table.alias("trg")
+        .merge(source=staged_inputs_df.alias("src"), condition=merge_condition)
+        .whenMatchedUpdateAll(condition=updates_condition)
+        .whenNotMatchedInsertAll()
+        .execute()
+    )
diff --git a/tests/test_public_interface.py b/tests/test_public_interface.py
index 8226e5b..d020254 100644
--- a/tests/test_public_interface.py
+++ b/tests/test_public_interface.py
@@ -1168,3 +1168,306 @@ def test_rename_delta_table(tmp_path):
 
     # Clean up: Drop the new table
     spark.sql(f"DROP TABLE IF EXISTS {new_table_name}")
+
+
+def test_type_three_scd_upsert_with_single_prev_column(tmp_path):
+    path = f"{tmp_path}/tmp/delta-type-three-scd-upsert-with-single-column"
+    input = [
+        (1, "A", "Canada", None),
+        (2, "B", "Germany", None),
+        (3, "C", "Japan", None),
+        (4, "D", "South Africa", None),
+    ]
+    schema = StructType(
+        [
+            StructField("pkey", IntegerType(), True),
+            StructField("name", StringType(), True),
+            StructField("country", StringType(), True),
+            StructField("prev_country", StringType(), True),
+        ]
+    )
+    df = spark.createDataFrame(data=input, schema=schema)
+    df.write.format("delta").save(path)
+
+    updates_data = [
+        (2, "B", "France"),  # value to upsert
+        (4, "D", "Nigeria"),  # value to upsert
+        (5, "E", "Mexico"),  # new value
+    ]
+    updates_schema = StructType(
+        [
+            StructField("pkey", IntegerType(), True),
+            StructField("name", StringType(), True),
+            StructField("country", StringType(), True),
+        ]
+    )
+    updates_df = spark.createDataFrame(data=updates_data, schema=updates_schema)
+
+    delta_table = DeltaTable.forPath(spark, path)
+    mack.type_3_scd_upsert(delta_table, updates_df, "pkey", {"country": "prev_country"})
+
+    actual_df = spark.read.format("delta").load(path)
+
+    expected_df = spark.createDataFrame(
+        [
+            (1, "A", "Canada", None),
+            (2, "B", "France", "Germany"),
+            (3, "C", "Japan", None),
+            (4, "D", "Nigeria", "South Africa"),
+            (5, "E", "Mexico", None),
+        ],
+        schema,
+    )
+
+    chispa.assert_df_equality(actual_df, expected_df, ignore_row_order=True)
+
+
+def test_type_three_scd_upsert_with_multiple_prev_columns(tmp_path):
+    path = f"{tmp_path}/tmp/delta-type-three-scd-upsert-with-multiple-columns"
+    input = [
+        (1, "A", "Canada", None, "North America", None),
+        (2, "B", "Germany", None, "Europe", None),
+        (3, "C", "Japan", None, "Asia", None),
+        (4, "D", "Nigeria", None, "Africa", None),
+        (5, "E", "Argentina", None, "South America", None),
+    ]
+    schema = StructType(
+        [
+            StructField("pkey", IntegerType(), True),
+            StructField("name", StringType(), True),
+            StructField("country", StringType(), True),
+            StructField("prev_country", StringType(), True),
+            StructField("continent", StringType(), True),
+            StructField("prev_continent", StringType(), True),
+        ]
+    )
+    df = spark.createDataFrame(data=input, schema=schema)
+    df.write.format("delta").save(path)
+
+    updates_data = [
+        (2, "B", "Brazil", "South America"),  # value to upsert
+        (4, "D", "Italy", "Europe"),  # value to upsert
"Canada", "North America"), # value to upsert + (6, "F", "New Zealand","Oceania"), # new value + ] + updates_schema = StructType( + [ + StructField("pkey", IntegerType(), True), + StructField("name", StringType(), True), + StructField("country", StringType(), True), + StructField("continent", StringType(), True), + ] + ) + updates_df = spark.createDataFrame(data=updates_data, schema=updates_schema) + + delta_table = DeltaTable.forPath(spark, path) + mack.type_3_scd_upsert(delta_table, updates_df, "pkey", {"country":"prev_country","continent":"prev_continent"}) + + actual_df = spark.read.format("delta").load(path) + + expected_df = spark.createDataFrame( + [ + (1, "A", "Canada", None, "North America", None), + (2, "B", "Brazil", "Germany", "South America","Europe"), + (3, "C", "Japan", None, "Asia", None), + (4, "D", "Italy", "Nigeria", "Europe","Africa"), + (5, "E", "Canada", "Argentina", "North America", "South America"), + (6, "F", "New Zealand", None, "Oceania", None), + ], + schema, + ) + + chispa.assert_df_equality(actual_df, expected_df, ignore_row_order=True) + + +def test_type_three_scd_upsert_apply_multiple_times_using_the_same_scope(tmp_path): + path = f"{tmp_path}/tmp/delta-type-three-scd-upsert-with-single-column" + input = [ + (1, "A", "Canada", None), + (2, "B", "Germany", None), + (3, "C", "Japan", None), + (4, "D", "South Africa", None), + ] + schema = StructType( + [ + StructField("pkey", IntegerType(), True), + StructField("name", StringType(), True), + StructField("country", StringType(), True), + StructField("prev_country", StringType(), True), + ] + ) + df = spark.createDataFrame(data=input, schema=schema) + df.write.format("delta").save(path) + + updates_data = [ + (2, "B", "France"), # value to upsert + (4, "D", "Nigeria"), # value to upsert + (5, "E", "Mexico"), # new value + ] + updates_schema = StructType( + [ + StructField("pkey", IntegerType(), True), + StructField("name", StringType(), True), + StructField("country", StringType(), True), + ] + ) + updates_df = spark.createDataFrame(data=updates_data, schema=updates_schema) + + delta_table = DeltaTable.forPath(spark, path) + mack.type_3_scd_upsert(delta_table, updates_df, "pkey", {"country":"prev_country"}) + mack.type_3_scd_upsert(delta_table, updates_df, "pkey", {"country":"prev_country"}) + mack.type_3_scd_upsert(delta_table, updates_df, "pkey", {"country":"prev_country"}) + + actual_df = spark.read.format("delta").load(path) + + expected_df = spark.createDataFrame( + [ + (1, "A", "Canada", None), + (2, "B", "France", "Germany"), + (3, "C", "Japan", None), + (4, "D", "Nigeria", "South Africa"), + (5, "E", "Mexico", None), + ], + schema, + ) + + chispa.assert_df_equality(actual_df, expected_df, ignore_row_order=True) + +def test_errors_out_type_three_scd_duplication_on_dictionary_values(tmp_path): + path = f"{tmp_path}/tmp/delta-type-three-scd-upsert-error-duplication-dict-values" + input = [] + schema = StructType( + [ + StructField("pkey", IntegerType(), True), + StructField("name", StringType(), True), + StructField("job", StringType(), True), + StructField("prev_job", StringType(), True), + StructField("country", StringType(), True), + StructField("prev_country", StringType(), True), + StructField("continent", StringType(), True), + StructField("prev_continent", StringType(), True), + ] + ) + df = spark.createDataFrame(data=input, schema=schema) + df.write.format("delta").save(path) + + updates_data = [] + updates_schema = StructType( + [ + StructField("pkey", IntegerType(), True), + StructField("name", 
+            StructField("job", StringType(), True),
+            StructField("country", StringType(), True),
+            StructField("continent", StringType(), True),
+        ]
+    )
+    updates_df = spark.createDataFrame(data=updates_data, schema=updates_schema)
+
+    delta_table = DeltaTable.forPath(spark, path)
+
+    with pytest.raises(TypeError):  # duplicate dict value 'prev_job'
+        mack.type_3_scd_upsert(
+            delta_table,
+            updates_df,
+            "pkey",
+            {"country": "prev_country", "job": "prev_job", "continent": "prev_job"},
+        )
+
+
+def test_errors_out_type_three_scd_dictionary_keys_equal_to_values(tmp_path):
+    path = f"{tmp_path}/tmp/delta-type-three-scd-upsert-error-dict-keys-equal-to-value"
+    input = []
+    schema = StructType(
+        [
+            StructField("pkey", IntegerType(), True),
+            StructField("name", StringType(), True),
+            StructField("job", StringType(), True),
+            StructField("prev_job", StringType(), True),
+            StructField("country", StringType(), True),
+            StructField("prev_country", StringType(), True),
+            StructField("continent", StringType(), True),
+            StructField("prev_continent", StringType(), True),
+        ]
+    )
+    df = spark.createDataFrame(data=input, schema=schema)
+    df.write.format("delta").save(path)
+
+    updates_data = []
+    updates_schema = StructType(
+        [
+            StructField("pkey", IntegerType(), True),
+            StructField("name", StringType(), True),
+            StructField("job", StringType(), True),
+            StructField("country", StringType(), True),
+            StructField("continent", StringType(), True),
+        ]
+    )
+    updates_df = spark.createDataFrame(data=updates_data, schema=updates_schema)
+
+    delta_table = DeltaTable.forPath(spark, path)
+
+    with pytest.raises(TypeError):  # the 'country' and 'continent' keys equal their values
+        mack.type_3_scd_upsert(
+            delta_table,
+            updates_df,
+            "pkey",
+            {"country": "country", "job": "prev_job", "continent": "continent"},
+        )
+
+
+def test_errors_out_type_three_scd_provided_columns_do_not_exist_in_delta_table(tmp_path):
+    path = f"{tmp_path}/tmp/delta-type-three-scd-upsert-error-columns-not-found-in-delta-table"
+    input = []
+    schema = StructType(
+        [
+            StructField("pkey", IntegerType(), True),
+            StructField("name", StringType(), True),
+            StructField("job", StringType(), True),
+            StructField("country", StringType(), True),
+            StructField("prev_country", StringType(), True),
+            StructField("continent", StringType(), True),
+            StructField("prev_continent", StringType(), True),
+        ]
+    )
+    df = spark.createDataFrame(data=input, schema=schema)
+    df.write.format("delta").save(path)
+
+    updates_data = []
+    updates_schema = StructType(
+        [
+            StructField("pkey", IntegerType(), True),
+            StructField("name", StringType(), True),
+            StructField("job", StringType(), True),
+            StructField("country", StringType(), True),
+            StructField("continent", StringType(), True),
+        ]
+    )
+    updates_df = spark.createDataFrame(data=updates_data, schema=updates_schema)
+
+    delta_table = DeltaTable.forPath(spark, path)
+
+    with pytest.raises(TypeError):  # 'prev_job' does not exist in the target delta table
+        mack.type_3_scd_upsert(
+            delta_table,
+            updates_df,
+            "pkey",
+            {"country": "prev_country", "job": "prev_job", "continent": "prev_continent"},
+        )
+
+
+def test_errors_out_type_three_scd_missing_columns_in_update_dataframe(tmp_path):
+    path = f"{tmp_path}/tmp/delta-type-three-scd-upsert-error-missing-columns-in-update-dataframe"
+    input = []
+    schema = StructType(
+        [
+            StructField("pkey", IntegerType(), True),
+            StructField("name", StringType(), True),
+            StructField("job", StringType(), True),
+            StructField("prev_job", StringType(), True),
+            StructField("country", StringType(), True),
+            StructField("prev_country", StringType(), True),
StructField("continent", StringType(), True), + StructField("prev_continent", StringType(), True), + ] + ) + df = spark.createDataFrame(data=input, schema=schema) + df.write.format("delta").save(path) + + updates_data = [] + updates_schema = StructType( + [ + # missing column -> pkey + StructField("name", StringType(), True), + StructField("job", StringType(), True), + StructField("country", StringType(), True), + StructField("continent", StringType(), True), + ] + ) + updates_df = spark.createDataFrame(data=updates_data, schema=updates_schema) + + delta_table = DeltaTable.forPath(spark, path) + + with pytest.raises(TypeError): + mack.type_3_scd_upsert(delta_table, updates_df, "pkey", {"country":"prev_country", "job":"prev_job", "continent":"prev_continent"})