#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
from typing import Dict, Optional, TYPE_CHECKING

from pyspark.sql.column import Column
from pyspark.sql.utils import to_scala_map

if TYPE_CHECKING:
    from pyspark.sql.dataframe import DataFrame

__all__ = ["MergeIntoWriter"]


class MergeIntoWriter:
    """
    `MergeIntoWriter` provides methods to define and execute merge actions based
    on specified conditions.

    .. versionadded:: 4.0.0
    """

    def __init__(self, df: "DataFrame", table: str, condition: Column):
        self._spark = df.sparkSession

        from pyspark.sql.classic.column import _to_java_column

        self._jwriter = df._jdf.mergeInto(table, _to_java_column(condition))

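    # How a writer is obtained in practice (a hedged sketch, not part of this
    # module: the table name "target" and the column names are hypothetical,
    # and the target must be backed by a data source that supports MERGE INTO):
    #
    #     from pyspark.sql.functions import col
    #
    #     source = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"]).alias("source")
    #     writer = source.mergeInto("target", col("source.id") == col("target.id"))
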
    def whenMatched(self, condition: Optional[Column] = None) -> "MergeIntoWriter.WhenMatched":
        """
        Initialize a `WhenMatched` action with a condition.

        This `WhenMatched` action will be executed when a source row matches a target table
        row based on the merge condition and the specified `condition` is satisfied.

        This `WhenMatched` can be followed by one of the following merge actions:
          - `updateAll`: Update all the matched target table rows with source dataset rows.
          - `update(Dict)`: Update all the matched target table rows while changing only
            a subset of columns based on the provided assignment.
          - `delete`: Delete all target rows that have a match in the source table.
        """
        return self.WhenMatched(self, condition)
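
    # A hedged sketch of `whenMatched`, continuing the hypothetical `writer`
    # above: the conditional form touches only a subset of columns, while the
    # unconditional form copies every matched row:
    #
    #     from pyspark.sql.functions import col
    #
    #     writer.whenMatched(col("source.value") != col("target.value")) \
    #         .update({"value": col("source.value")})
    #     # or, without a condition:
    #     writer.whenMatched().updateAll()
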
    def whenNotMatched(
        self, condition: Optional[Column] = None
    ) -> "MergeIntoWriter.WhenNotMatched":
        """
        Initialize a `WhenNotMatched` action with a condition.

        This `WhenNotMatched` action will be executed when a source row does not match any
        target row based on the merge condition and the specified `condition` is satisfied.

        This `WhenNotMatched` can be followed by one of the following merge actions:
          - `insertAll`: Insert all rows from the source that are not already in the
            target table.
          - `insert(Dict)`: Insert all rows from the source that are not already in the
            target table, with the specified columns based on the provided assignment.
        """
        return self.WhenNotMatched(self, condition)
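
    # A hedged sketch of `whenNotMatched` with the same hypothetical names;
    # `insert` maps target column names to expressions over the source row:
    #
    #     from pyspark.sql.functions import col
    #
    #     writer.whenNotMatched().insert({"id": col("source.id"), "value": col("source.value")})
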
    def whenNotMatchedBySource(
        self, condition: Optional[Column] = None
    ) -> "MergeIntoWriter.WhenNotMatchedBySource":
        """
        Initialize a `WhenNotMatchedBySource` action with a condition.

        This `WhenNotMatchedBySource` action will be executed when a target row does not
        match any rows in the source table based on the merge condition and the specified
        `condition` is satisfied.

        This `WhenNotMatchedBySource` can be followed by one of the following merge actions:
          - `updateAll`: Update all the not matched target table rows with source dataset
            rows.
          - `update(Dict)`: Update all the not matched target table rows while changing only
            the specified columns based on the provided assignment.
          - `delete`: Delete all target rows that have no matches in the source table.
        """
        return self.WhenNotMatchedBySource(self, condition)
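
    # A hedged sketch of `whenNotMatchedBySource`, e.g. removing target rows
    # that no longer appear in the (hypothetical) source:
    #
    #     writer.whenNotMatchedBySource().delete()
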
    def withSchemaEvolution(self) -> "MergeIntoWriter":
        """
        Enable automatic schema evolution for this merge operation.
        """
        self._jwriter = self._jwriter.withSchemaEvolution()
        return self
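
    # A hedged sketch: `withSchemaEvolution` returns the writer, so it can be
    # chained with the action clauses before `merge()` is called; whether e.g.
    # new source columns are actually added to the target depends on the
    # target data source:
    #
    #     writer.withSchemaEvolution().whenMatched().updateAll().merge()
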
    def merge(self) -> None:
        """
        Execute the merge operation.
        """
        self._jwriter.merge()
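
    # Putting it together (a hedged end-to-end sketch; "target" and the column
    # names are hypothetical, and nothing is executed until `merge()` runs):
    #
    #     from pyspark.sql.functions import col
    #
    #     (source.mergeInto("target", col("source.id") == col("target.id"))
    #         .whenMatched().updateAll()
    #         .whenNotMatched().insertAll()
    #         .whenNotMatchedBySource().delete()
    #         .merge())
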
    class WhenMatched:
        """
        A class for defining actions to be taken when matching rows in a DataFrame during
        a merge operation.
        """

        def __init__(self, writer: "MergeIntoWriter", condition: Optional[Column]):
            self.writer = writer
            if condition is None:
                self.when_matched = writer._jwriter.whenMatched()
            else:
                from pyspark.sql.classic.column import _to_java_column

                self.when_matched = writer._jwriter.whenMatched(_to_java_column(condition))

        def updateAll(self) -> "MergeIntoWriter":
            """
            Specifies an action to update all matched rows in the DataFrame.
            """
            self.writer._jwriter = self.when_matched.updateAll()
            return self.writer

        def update(self, assignments: Dict[str, Column]) -> "MergeIntoWriter":
            """
            Specifies an action to update matched rows in the DataFrame with the provided
            column assignments.
            """
            jvm = self.writer._spark._jvm
            from pyspark.sql.classic.column import _to_java_column

            # Convert the Python dict of assignments into a Scala map of Java columns.
            jmap = to_scala_map(jvm, {k: _to_java_column(v) for k, v in assignments.items()})
            self.writer._jwriter = self.when_matched.update(jmap)
            return self.writer

        def delete(self) -> "MergeIntoWriter":
            """
            Specifies an action to delete matched rows from the DataFrame.
            """
            self.writer._jwriter = self.when_matched.delete()
            return self.writer

    class WhenNotMatched:
        """
        A class for defining actions to be taken when no matching rows are found in a
        DataFrame during a merge operation.
        """

        def __init__(self, writer: "MergeIntoWriter", condition: Optional[Column]):
            self.writer = writer
            if condition is None:
                self.when_not_matched = writer._jwriter.whenNotMatched()
            else:
                from pyspark.sql.classic.column import _to_java_column

                self.when_not_matched = writer._jwriter.whenNotMatched(
                    _to_java_column(condition)
                )

        def insertAll(self) -> "MergeIntoWriter":
            """
            Specifies an action to insert all non-matched rows into the DataFrame.
            """
            self.writer._jwriter = self.when_not_matched.insertAll()
            return self.writer

        def insert(self, assignments: Dict[str, Column]) -> "MergeIntoWriter":
            """
            Specifies an action to insert non-matched rows into the DataFrame with the
            provided column assignments.
            """
            jvm = self.writer._spark._jvm
            from pyspark.sql.classic.column import _to_java_column

            # Convert the Python dict of assignments into a Scala map of Java columns.
            jmap = to_scala_map(jvm, {k: _to_java_column(v) for k, v in assignments.items()})
            self.writer._jwriter = self.when_not_matched.insert(jmap)
            return self.writer

    class WhenNotMatchedBySource:
        """
        A class for defining actions to be performed when there is no match by source
        during a merge operation in a MergeIntoWriter.
        """

        def __init__(self, writer: "MergeIntoWriter", condition: Optional[Column]):
            self.writer = writer
            if condition is None:
                self.when_not_matched_by_source = writer._jwriter.whenNotMatchedBySource()
            else:
                from pyspark.sql.classic.column import _to_java_column

                self.when_not_matched_by_source = writer._jwriter.whenNotMatchedBySource(
                    _to_java_column(condition)
                )

        def updateAll(self) -> "MergeIntoWriter":
            """
            Specifies an action to update all non-matched rows in the target DataFrame when
            not matched by the source.
            """
            self.writer._jwriter = self.when_not_matched_by_source.updateAll()
            return self.writer

        def update(self, assignments: Dict[str, Column]) -> "MergeIntoWriter":
            """
            Specifies an action to update non-matched rows in the target DataFrame with the
            provided column assignments when not matched by the source.
            """
            jvm = self.writer._spark._jvm
            from pyspark.sql.classic.column import _to_java_column

            # Convert the Python dict of assignments into a Scala map of Java columns.
            jmap = to_scala_map(jvm, {k: _to_java_column(v) for k, v in assignments.items()})
            self.writer._jwriter = self.when_not_matched_by_source.update(jmap)
            return self.writer

        def delete(self) -> "MergeIntoWriter":
            """
            Specifies an action to delete target rows that have no match in the source
            table.
            """
            self.writer._jwriter = self.when_not_matched_by_source.delete()
            return self.writer


def _test() -> None:
    import doctest
    import os
    import py4j
    from pyspark.core.context import SparkContext
    from pyspark.sql import SparkSession
    import pyspark.sql.merge

    os.chdir(os.environ["SPARK_HOME"])

    globs = pyspark.sql.merge.__dict__.copy()
    sc = SparkContext("local[4]", "PythonTest")
    try:
        spark = SparkSession._getActiveSessionOrCreate()
    except py4j.protocol.Py4JError:
        spark = SparkSession(sc)

    globs["spark"] = spark
    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.merge,
        globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF,
    )
    spark.stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()