Change SCD functions to take DeltaTable instead of Path (#67)
* Change SCD functions to take DeltaTable instead of path
PadenZach authored Jan 21, 2023
1 parent 4c81240 commit 4505af3
Showing 3 changed files with 30 additions and 16 deletions.
README.md (2 changes: 1 addition & 1 deletion)
@@ -53,7 +53,7 @@ You'd like to perform an upsert with this data:
 Here's how to perform the upsert:
 
 ```scala
-mack.type_2_scd_upsert(path, updatesDF, "pkey", ["attr1", "attr2"])
+mack.type_2_scd_upsert(delta_table, updatesDF, "pkey", ["attr1", "attr2"])
 ```
 
 Here's the table after the upsert:
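From the caller's side the change is small: instead of handing mack a path string, you resolve the `DeltaTable` yourself and pass it in. A minimal before/after sketch, assuming a Delta-enabled `SparkSession` named `spark`, an existing Delta table at `path`, and an updates DataFrame `updates_df` (names illustrative, not taken from this commit):

```python
from delta.tables import DeltaTable

import mack

# Before this commit, mack resolved the table from a path string internally:
# mack.type_2_scd_upsert(path, updates_df, "pkey", ["attr1", "attr2"])

# After this commit, the caller constructs the DeltaTable and passes it in.
delta_table = DeltaTable.forPath(spark, path)
mack.type_2_scd_upsert(delta_table, updates_df, "pkey", ["attr1", "attr2"])
```

One practical benefit: callers are no longer limited to path-backed tables, since a `DeltaTable` can also be obtained by other means, such as `DeltaTable.forName` for metastore-registered tables.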
mack/__init__.py (20 changes: 11 additions & 9 deletions)
@@ -9,13 +9,16 @@


 def type_2_scd_upsert(
-    path: str, updates_df: DataFrame, primary_key: str, attr_col_names: List[str]
+    delta_table: DeltaTable,
+    updates_df: DataFrame,
+    primary_key: str,
+    attr_col_names: List[str],
 ) -> None:
     """
     <description>
     :param path: <description>
-    :type path: str
+    :type path: DeltaTable
     :param updates_df: <description>
     :type updates_df: DataFrame
     :param primary_key: <description>
@@ -27,7 +30,7 @@ def type_2_scd_upsert(
     :rtype: None
     """
     return type_2_scd_generic_upsert(
-        path,
+        delta_table,
         updates_df,
         primary_key,
         attr_col_names,
@@ -38,7 +41,7 @@ def type_2_scd_upsert(


 def type_2_scd_generic_upsert(
-    path: str,
+    delta_table: DeltaTable,
     updates_df: DataFrame,
     primary_key: str,
     attr_col_names: List[str],
@@ -49,7 +52,7 @@ def type_2_scd_generic_upsert(
     """
     <description>
-    :param path: <description>
+    :param delta_table: DeltaTable
     :type path: str
     :param updates_df: <description>
     :type updates_df: DataFrame
@@ -70,10 +73,9 @@ def type_2_scd_generic_upsert(
     :returns: <description>
     :rtype: None
     """
-    base_table = DeltaTable.forPath(pyspark.sql.SparkSession.getActiveSession(), path)
 
     # validate the existing Delta table
-    base_col_names = base_table.toDF().columns
+    base_col_names = delta_table.toDF().columns
     required_base_col_names = (
         [primary_key]
         + attr_col_names
@@ -104,7 +106,7 @@ def type_2_scd_generic_upsert(
     staged_updates_attrs = " OR ".join(staged_updates_attrs)
     staged_part_1 = (
         updates_df.alias("updates")
-        .join(base_table.toDF().alias("base"), primary_key)
+        .join(delta_table.toDF().alias("base"), primary_key)
         .where(f"base.{is_current_col_name} = true AND ({updates_attrs})")
         .selectExpr("NULL as mergeKey", "updates.*")
     )
@@ -121,7 +123,7 @@ def type_2_scd_generic_upsert(
     }
     res_thing = {**thing, **thing2}
     res = (
-        base_table.alias("base")
+        delta_table.alias("base")
         .merge(
             source=staged_updates.alias("staged_updates"),
             condition=pyspark.sql.functions.expr(
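The hunks above show only fragments of `type_2_scd_generic_upsert`. For orientation, the function follows the standard Type 2 SCD merge pattern for Delta Lake: each update that changes a current row is staged twice, once with a NULL merge key so the MERGE cannot match it (it falls through to the insert clause and becomes the new current row), and once under its real key so it matches and expires the old current row. A condensed sketch of that pattern, with hypothetical literal column names (`pkey`, `attr`, `is_current`, `effective_time`, `end_time`) standing in for mack's parameters, and assuming `spark`, `delta_table`, and `updates_df` already exist:

```python
from pyspark.sql import functions as F

# Assumes updates_df has columns: pkey, attr, effective_time.

# Updates that change the current row get a NULL mergeKey: the MERGE below
# cannot match them, so they fall through to the insert clause.
staged_part_1 = (
    updates_df.alias("updates")
    .join(delta_table.toDF().alias("base"), "pkey")
    .where("base.is_current = true AND updates.attr <> base.attr")
    .selectExpr("NULL as mergeKey", "updates.*")
)

# Every update also appears once under its real key, so it can match the
# existing current row and expire it.
staged_part_2 = updates_df.selectExpr("pkey as mergeKey", "*")

staged_updates = staged_part_1.union(staged_part_2)

(
    delta_table.alias("base")
    .merge(
        source=staged_updates.alias("staged_updates"),
        condition=F.expr("base.pkey = mergeKey"),
    )
    .whenMatchedUpdate(
        condition="base.is_current = true AND staged_updates.attr <> base.attr",
        set={
            "is_current": "false",
            "end_time": "staged_updates.effective_time",
        },
    )
    .whenNotMatchedInsert(
        values={
            "pkey": "staged_updates.pkey",
            "attr": "staged_updates.attr",
            "is_current": "true",
            "effective_time": "staged_updates.effective_time",
            "end_time": "null",
        }
    )
    .execute()
)
```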
tests/test_public_interface.py (24 changes: 18 additions & 6 deletions)
@@ -60,7 +60,8 @@ def test_upserts_with_single_attribute(tmp_path):
     )
     updates_df = spark.createDataFrame(data=updates_data, schema=updates_schema)
 
-    mack.type_2_scd_upsert(path, updates_df, "pkey", ["attr"])
+    delta_table = DeltaTable.forPath(spark, path)
+    mack.type_2_scd_upsert(delta_table, updates_df, "pkey", ["attr"])
 
     actual_df = spark.read.format("delta").load(path)
@@ -110,8 +111,9 @@ def test_errors_out_if_base_df_does_not_have_all_required_columns(tmp_path):
     )
     updates_df = spark.createDataFrame(data=updates_data, schema=updates_schema)
 
+    delta_table = DeltaTable.forPath(spark, path)
     with pytest.raises(TypeError):
-        mack.type_2_scd_upsert(path, updates_df, "pkey", ["attr"])
+        mack.type_2_scd_upsert(delta_table, updates_df, "pkey", ["attr"])
 
 
 def test_errors_out_if_updates_table_does_not_contain_all_required_columns(tmp_path):
@@ -146,8 +148,9 @@ def test_errors_out_if_updates_table_does_not_contain_all_required_columns(tmp_path):
     )
     updates_df = spark.createDataFrame(data=updates_data, schema=updates_schema)
 
+    delta_table = DeltaTable.forPath(spark, path)
     with pytest.raises(TypeError):
-        mack.type_2_scd_upsert(path, updates_df, "pkey", ["attr"])
+        mack.type_2_scd_upsert(delta_table, updates_df, "pkey", ["attr"])
 
 
 def test_upserts_based_on_multiple_attributes(tmp_path):
@@ -184,7 +187,8 @@ def test_upserts_based_on_multiple_attributes(tmp_path):
     )
     updates_df = spark.createDataFrame(data=updates_data, schema=updates_schema)
 
-    mack.type_2_scd_upsert(path, updates_df, "pkey", ["attr1", "attr2"])
+    delta_table = DeltaTable.forPath(spark, path)
+    mack.type_2_scd_upsert(delta_table, updates_df, "pkey", ["attr1", "attr2"])
 
     actual_df = spark.read.format("delta").load(path)
@@ -235,8 +239,9 @@ def test_upserts_based_on_date_columns(tmp_path):
     ).toDF("pkey", "attr", "effective_date")
 
     # perform upsert
+    delta_table = DeltaTable.forPath(spark, path)
     mack.type_2_scd_generic_upsert(
-        path, updates_df, "pkey", ["attr"], "cur", "effective_date", "end_date"
+        delta_table, updates_df, "pkey", ["attr"], "cur", "effective_date", "end_date"
     )
 
     actual_df = spark.read.format("delta").load(path)
@@ -287,8 +292,15 @@ def test_upserts_based_on_version_number(tmp_path):
     ).toDF("pkey", "attr", "effective_ver")
 
     # perform upsert
+    delta_table = DeltaTable.forPath(spark, path)
     mack.type_2_scd_generic_upsert(
-        path, updates_df, "pkey", ["attr"], "is_current", "effective_ver", "end_ver"
+        delta_table,
+        updates_df,
+        "pkey",
+        ["attr"],
+        "is_current",
+        "effective_ver",
+        "end_ver",
     )
 
     # show result
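Every updated test follows the same two-line conversion: build the `DeltaTable` from the temp path, then hand it to mack. A condensed sketch of that shared setup, assuming a Delta-enabled `SparkSession` named `spark` (the helper name and column names are illustrative, not from the test file):

```python
from delta.tables import DeltaTable

import mack


def seed_and_upsert(spark, tmp_path, base_df, updates_df):
    # Write the seed rows out as the base Delta table.
    path = f"{tmp_path}/base"
    base_df.write.format("delta").save(path)
    # New calling convention: wrap the path, pass the DeltaTable to mack.
    delta_table = DeltaTable.forPath(spark, path)
    mack.type_2_scd_upsert(delta_table, updates_df, "pkey", ["attr"])
    # Assertions read the merged result back from the same location.
    return spark.read.format("delta").load(path)
```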
