Skip to content

Commit

Permalink
[SPARK-45279][PYTHON][CONNECT] Attach plan_id for all logical plans
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Attach plan_id for all logical plans, except `CachedRelation`

### Why are the changes needed?
1. All logical plans should contain their plan IDs in their protos.
2. Catalog plans also contain the plan ID in the Scala client, e.g.

https://github.com/apache/spark/blob/05f5dccbd34218c7d399228529853bdb1595f3a2/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala#L63-L67

The `newDataset` method will set the plan ID.

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
CI

### Was this patch authored or co-authored using generative AI tooling?
no

Closes apache#43055 from zhengruifeng/connect_plan_id.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed Sep 25, 2023
1 parent a881438 commit 609552e
Showing 1 changed file with 40 additions and 39 deletions.
79 changes: 40 additions & 39 deletions python/pyspark/sql/connect/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,9 +1190,7 @@ def col_to_expr(self, col: "ColumnOrName", session: "SparkConnectClient") -> pro

def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None

plan = proto.Relation()
plan.common.plan_id = self._child._plan_id
plan = self._create_proto_relation()
plan.collect_metrics.input.CopyFrom(self._child.plan(session))
plan.collect_metrics.name = self._name
plan.collect_metrics.metrics.extend([self.col_to_expr(x, session) for x in self._exprs])
Expand Down Expand Up @@ -1689,7 +1687,9 @@ def __init__(self) -> None:
super().__init__(None)

def plan(self, session: "SparkConnectClient") -> proto.Relation:
return proto.Relation(catalog=proto.Catalog(current_database=proto.CurrentDatabase()))
plan = self._create_proto_relation()
plan.catalog.current_database.SetInParent()
return plan


class SetCurrentDatabase(LogicalPlan):
Expand All @@ -1698,7 +1698,7 @@ def __init__(self, db_name: str) -> None:
self._db_name = db_name

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation()
plan = self._create_proto_relation()
plan.catalog.set_current_database.db_name = self._db_name
return plan

Expand All @@ -1709,7 +1709,8 @@ def __init__(self, pattern: Optional[str] = None) -> None:
self._pattern = pattern

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation(catalog=proto.Catalog(list_databases=proto.ListDatabases()))
plan = self._create_proto_relation()
plan.catalog.list_databases.SetInParent()
if self._pattern is not None:
plan.catalog.list_databases.pattern = self._pattern
return plan
Expand All @@ -1722,7 +1723,8 @@ def __init__(self, db_name: Optional[str] = None, pattern: Optional[str] = None)
self._pattern = pattern

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation(catalog=proto.Catalog(list_tables=proto.ListTables()))
plan = self._create_proto_relation()
plan.catalog.list_tables.SetInParent()
if self._db_name is not None:
plan.catalog.list_tables.db_name = self._db_name
if self._pattern is not None:
Expand All @@ -1737,7 +1739,8 @@ def __init__(self, db_name: Optional[str] = None, pattern: Optional[str] = None)
self._pattern = pattern

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation(catalog=proto.Catalog(list_functions=proto.ListFunctions()))
plan = self._create_proto_relation()
plan.catalog.list_functions.SetInParent()
if self._db_name is not None:
plan.catalog.list_functions.db_name = self._db_name
if self._pattern is not None:
Expand All @@ -1752,7 +1755,7 @@ def __init__(self, table_name: str, db_name: Optional[str] = None) -> None:
self._db_name = db_name

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation(catalog=proto.Catalog(list_columns=proto.ListColumns()))
plan = self._create_proto_relation()
plan.catalog.list_columns.table_name = self._table_name
if self._db_name is not None:
plan.catalog.list_columns.db_name = self._db_name
Expand All @@ -1765,7 +1768,7 @@ def __init__(self, db_name: str) -> None:
self._db_name = db_name

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation(catalog=proto.Catalog(get_database=proto.GetDatabase()))
plan = self._create_proto_relation()
plan.catalog.get_database.db_name = self._db_name
return plan

Expand All @@ -1777,7 +1780,7 @@ def __init__(self, table_name: str, db_name: Optional[str] = None) -> None:
self._db_name = db_name

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation(catalog=proto.Catalog(get_table=proto.GetTable()))
plan = self._create_proto_relation()
plan.catalog.get_table.table_name = self._table_name
if self._db_name is not None:
plan.catalog.get_table.db_name = self._db_name
Expand All @@ -1791,7 +1794,7 @@ def __init__(self, function_name: str, db_name: Optional[str] = None) -> None:
self._db_name = db_name

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation(catalog=proto.Catalog(get_function=proto.GetFunction()))
plan = self._create_proto_relation()
plan.catalog.get_function.function_name = self._function_name
if self._db_name is not None:
plan.catalog.get_function.db_name = self._db_name
Expand All @@ -1804,7 +1807,7 @@ def __init__(self, db_name: str) -> None:
self._db_name = db_name

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation(catalog=proto.Catalog(database_exists=proto.DatabaseExists()))
plan = self._create_proto_relation()
plan.catalog.database_exists.db_name = self._db_name
return plan

Expand All @@ -1816,7 +1819,7 @@ def __init__(self, table_name: str, db_name: Optional[str] = None) -> None:
self._db_name = db_name

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation(catalog=proto.Catalog(table_exists=proto.TableExists()))
plan = self._create_proto_relation()
plan.catalog.table_exists.table_name = self._table_name
if self._db_name is not None:
plan.catalog.table_exists.db_name = self._db_name
Expand All @@ -1830,7 +1833,7 @@ def __init__(self, function_name: str, db_name: Optional[str] = None) -> None:
self._db_name = db_name

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation(catalog=proto.Catalog(function_exists=proto.FunctionExists()))
plan = self._create_proto_relation()
plan.catalog.function_exists.function_name = self._function_name
if self._db_name is not None:
plan.catalog.function_exists.db_name = self._db_name
Expand All @@ -1854,9 +1857,7 @@ def __init__(
self._options = options

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation(
catalog=proto.Catalog(create_external_table=proto.CreateExternalTable())
)
plan = self._create_proto_relation()
plan.catalog.create_external_table.table_name = self._table_name
if self._path is not None:
plan.catalog.create_external_table.path = self._path
Expand Down Expand Up @@ -1892,7 +1893,7 @@ def __init__(
self._options = options

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation(catalog=proto.Catalog(create_table=proto.CreateTable()))
plan = self._create_proto_relation()
plan.catalog.create_table.table_name = self._table_name
if self._path is not None:
plan.catalog.create_table.path = self._path
Expand All @@ -1915,7 +1916,7 @@ def __init__(self, view_name: str) -> None:
self._view_name = view_name

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation(catalog=proto.Catalog(drop_temp_view=proto.DropTempView()))
plan = self._create_proto_relation()
plan.catalog.drop_temp_view.view_name = self._view_name
return plan

Expand All @@ -1926,9 +1927,7 @@ def __init__(self, view_name: str) -> None:
self._view_name = view_name

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation(
catalog=proto.Catalog(drop_global_temp_view=proto.DropGlobalTempView())
)
plan = self._create_proto_relation()
plan.catalog.drop_global_temp_view.view_name = self._view_name
return plan

Expand All @@ -1939,11 +1938,8 @@ def __init__(self, table_name: str) -> None:
self._table_name = table_name

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation(
catalog=proto.Catalog(
recover_partitions=proto.RecoverPartitions(table_name=self._table_name)
)
)
plan = self._create_proto_relation()
plan.catalog.recover_partitions.table_name = self._table_name
return plan


Expand All @@ -1953,9 +1949,8 @@ def __init__(self, table_name: str) -> None:
self._table_name = table_name

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation(
catalog=proto.Catalog(is_cached=proto.IsCached(table_name=self._table_name))
)
plan = self._create_proto_relation()
plan.catalog.is_cached.table_name = self._table_name
return plan


Expand All @@ -1966,10 +1961,11 @@ def __init__(self, table_name: str, storage_level: Optional[StorageLevel] = None
self._storage_level = storage_level

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
_cache_table = proto.CacheTable(table_name=self._table_name)
if self._storage_level:
_cache_table.storage_level.CopyFrom(storage_level_to_proto(self._storage_level))
plan = proto.Relation(catalog=proto.Catalog(cache_table=_cache_table))
plan.catalog.cache_table.CopyFrom(_cache_table)
return plan


Expand All @@ -1979,7 +1975,7 @@ def __init__(self, table_name: str) -> None:
self._table_name = table_name

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation(catalog=proto.Catalog(uncache_table=proto.UncacheTable()))
plan = self._create_proto_relation()
plan.catalog.uncache_table.table_name = self._table_name
return plan

Expand All @@ -1989,7 +1985,9 @@ def __init__(self) -> None:
super().__init__(None)

def plan(self, session: "SparkConnectClient") -> proto.Relation:
return proto.Relation(catalog=proto.Catalog(clear_cache=proto.ClearCache()))
plan = self._create_proto_relation()
plan.catalog.clear_cache.SetInParent()
return plan


class RefreshTable(LogicalPlan):
Expand All @@ -1998,7 +1996,7 @@ def __init__(self, table_name: str) -> None:
self._table_name = table_name

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation(catalog=proto.Catalog(refresh_table=proto.RefreshTable()))
plan = self._create_proto_relation()
plan.catalog.refresh_table.table_name = self._table_name
return plan

Expand All @@ -2009,7 +2007,7 @@ def __init__(self, path: str) -> None:
self._path = path

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation(catalog=proto.Catalog(refresh_by_path=proto.RefreshByPath()))
plan = self._create_proto_relation()
plan.catalog.refresh_by_path.path = self._path
return plan

Expand All @@ -2019,7 +2017,9 @@ def __init__(self) -> None:
super().__init__(None)

def plan(self, session: "SparkConnectClient") -> proto.Relation:
return proto.Relation(catalog=proto.Catalog(current_catalog=proto.CurrentCatalog()))
plan = self._create_proto_relation()
plan.catalog.current_catalog.SetInParent()
return plan


class SetCurrentCatalog(LogicalPlan):
Expand All @@ -2028,7 +2028,7 @@ def __init__(self, catalog_name: str) -> None:
self._catalog_name = catalog_name

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation(catalog=proto.Catalog(set_current_catalog=proto.SetCurrentCatalog()))
plan = self._create_proto_relation()
plan.catalog.set_current_catalog.catalog_name = self._catalog_name
return plan

Expand All @@ -2039,7 +2039,8 @@ def __init__(self, pattern: Optional[str] = None) -> None:
self._pattern = pattern

def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation(catalog=proto.Catalog(list_catalogs=proto.ListCatalogs()))
plan = self._create_proto_relation()
plan.catalog.list_catalogs.SetInParent()
if self._pattern is not None:
plan.catalog.list_catalogs.pattern = self._pattern
return plan
Expand Down

0 comments on commit 609552e

Please sign in to comment.