PySpark regexp_replace and replace it with captured group

I have a pyspark table like this.

|col                                   |
|--------------------------------------|
|a function createdf and function roll |
|ground function power                 |

I am trying to capture the pattern `function <word>` and create a new column.
My expected output is:

|col                                   | new_col                            |
|--------------------------------------|------------------------------------|
|a function createdf and function roll | [function createdf, function roll] |
|ground function power                 | [function power]                   |

Code I tried:

from pyspark.sql import functions as f

pat_function_definition = '((function)\s+(\w+))'
df.withColumn('temp', f.split(f.regexp_replace("col", pat_function_definition, "$1"), ",")).show()

Spark version is 2.4.



Solution 1 [1]

For Spark 3.1+, you can use the regexp_extract_all function (via expr) with a lookbehind regex (?<=\bfunction\s)(\w+):

from pyspark.sql import functions as F

df = spark.createDataFrame([("a function createdf and function roll",), ("ground function power",)], ["col"])

df.withColumn(
    "new_col",
    F.expr(r"regexp_extract_all(col, '(?<=\\bfunction\\s)(\\w+)')")
).show(truncate=False)

#+-------------------------------------+----------------+
#|col                                  |new_col         |
#+-------------------------------------+----------------+
#|a function createdf and function roll|[createdf, roll]|
#|ground function power                |[power]         |
#+-------------------------------------+----------------+
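
If the expected output from the question should be reproduced exactly (keeping the word function in each element), a small variation is to match the full phrase and pass the optional group index 0, so regexp_extract_all returns the whole match instead of a captured group. A minimal sketch, assuming the same df as above:

df.withColumn(
    "new_col",
    # group index 0 returns the entire match, e.g. "function createdf"
    F.expr(r"regexp_extract_all(col, '\\bfunction\\s+\\w+', 0)")
).show(truncate=False)

This would produce [function createdf, function roll] and [function power] for the two rows.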

For older versions, you can replace every word that is not preceded by the word function with a comma, then split the resulting string and filter out the empty values:

df.withColumn(
    "new_col",
    F.split(F.regexp_replace("col", r"(?<!\bfunction\s)\b\w+", ","), ",")
).withColumn(
    "new_col",
    F.expr("filter(new_col, x -> trim(x) != '')")
).show(truncate=False)
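
For Spark 2.4, if the function prefix should also be kept (as in the question's expected output), one option is to wrap every function <word> match in a delimiter that does not occur in the data, split on that delimiter, and keep only the pieces that start with function. This is a minimal sketch, assuming @@ never appears in col; the higher-order filter function it relies on is available from Spark 2.4:

df.withColumn(
    "new_col",
    F.expr(r"""
        filter(
            -- wrap each "function <word>" match in @@ markers, then split on them
            split(regexp_replace(col, '(function\\s+\\w+)', '@@$1@@'), '@@'),
            -- keep only the pieces that are actual matches
            x -> x rlike '^function\\s'
        )
    """)
).show(truncate=False)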

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

[1] Solution 1