Finding the middle values with the min distance in PySpark

I need some help, please. I have this dataframe with an even number of values in column 'b':

df1 = spark.createDataFrame([
    ('c',1),
    ('c',2),
    ('c',4),
    ('c',6),
    ('c',7),
    ('c',8),
], ['a', 'b'])
df1.show()
+---+---+
|  a|  b|
+---+---+
|  c|  1|
|  c|  2|
|  c|  4|
|  c|  6|
|  c|  7|
|  c|  8|
+---+---+

I want to take the two middle values: mid[4,6]

calculate the distance between each middle value and its outer neighbor: dis[2,4]=2, dis[6,7]=1

take the pair with the minimum distance: min_dis[6,7]=1

and finally display the value that is in both min_dis[6,7] and mid[4,6]:

result = spark.createDataFrame([
    ('c',6),
], ['a', 'b'])
result.show()
+---+---+
|  a|  b|
+---+---+
|  c|  6|
+---+---+
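The rule described above can be checked in plain Python on a single sorted list before moving to Spark. This is a minimal sketch, assuming the values are already sorted and there is an even number of them (at least four); the helper name `pick_middle` is hypothetical, and ties are resolved toward the lower middle value, which the question does not specify.

```python
def pick_middle(values):
    """Return the middle value that sits next to the closest neighbor.

    Hypothetical helper (not part of PySpark). Assumes `values` is sorted,
    has an even length, and has at least four elements.
    """
    n = len(values)
    if n < 4 or n % 2:
        raise ValueError("expected an even number of values, at least 4")
    k = n // 2
    lower_mid, upper_mid = values[k - 1], values[k]   # mid[4,6] in the question
    lower_pair = (values[k - 2], lower_mid)           # dis[2,4] in the question
    upper_pair = (upper_mid, values[k + 1])           # dis[6,7] in the question
    lower_dis = lower_pair[1] - lower_pair[0]         # sorted, so non-negative
    upper_dis = upper_pair[1] - upper_pair[0]
    # min_dis: the outer pair with the smaller gap (assumption: ties -> lower pair)
    closest = lower_pair if lower_dis <= upper_dis else upper_pair
    # the answer is the value shared by the closest pair and the middle pair
    return (set(closest) & {lower_mid, upper_mid}).pop()


print(pick_middle([1, 2, 4, 6, 7, 8]))  # 6
```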

Is there a way to do this in PySpark?



Solution 1

Try this:

Assume df1 is as follows (a second group, 'd', is added to show that the logic is applied per group):

df1 = spark.createDataFrame(
    [
        ("c", 1),
        ("c", 2),
        ("c", 4),
        ("c", 6),
        ("c", 7),
        ("c", 8),
        ("d", 1),
        ("d", 2),
        ("d", 4),
        ("d", 6),
        ("d", 7),
        ("d", 8),
        ("d", 9),
        ("d", 10),
    ],
    ["a", "b"],
)

First, define a window that covers every row of each group 'a', ordered by 'b':

from pyspark.sql import Window
import pyspark.sql.functions as F

w = (
    Window.partitionBy("a")
    .orderBy(F.asc("b"))
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
)
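The explicit `rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)` matters: when a window has an `orderBy`, Spark's default frame runs from the start of the partition up to the current row, so `collect_list` would return a growing prefix on each row instead of the whole group. The plain-Python sketch below only imitates the two frames on one partition to show the difference; `prefix_frames` and `full_frames` are hypothetical names, not Spark functions.

```python
def prefix_frames(sorted_values):
    """Imitate the default ordered frame (unbounded preceding .. current row).

    Assumes distinct values; Spark's default RANGE frame would also pull in
    rows that tie with the current row.
    """
    return [sorted_values[: i + 1] for i in range(len(sorted_values))]


def full_frames(sorted_values):
    """Imitate rowsBetween(unboundedPreceding, unboundedFollowing)."""
    return [list(sorted_values) for _ in sorted_values]


partition_c = [1, 2, 4, 6, 7, 8]
print(prefix_frames(partition_c)[0])  # [1]
print(full_frames(partition_c)[0])    # [1, 2, 4, 6, 7, 8]
```

With the full frame, every row in a partition carries the same list, which is why the next step can collapse each group to one row with `distinct()`.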

Then build the result step by step. First, collect the sorted 'b' values of each group into a list and keep one row per group:

df2 = (
    df1.withColumn("list", F.collect_list("b").over(w))
    .select("a", "list")
    .distinct()
)
df2.show(2, False)
+---+-------------------------+
|a  |list                     |
+---+-------------------------+
|c  |[1, 2, 4, 6, 7, 8]       |
|d  |[1, 2, 4, 6, 7, 8, 9, 10]|
+---+-------------------------+
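As a rough cross-check outside Spark, the same per-group sorted lists can be built with the standard library. This sketch assumes the rows fit in memory and is only meant for verifying results on small samples; `collect_sorted_lists` is a hypothetical helper.

```python
from collections import defaultdict


def collect_sorted_lists(rows):
    """Group (a, b) tuples by `a` and sort each group's `b` values."""
    groups = defaultdict(list)
    for a, b in rows:
        groups[a].append(b)
    return {a: sorted(bs) for a, bs in groups.items()}


rows = [("c", 1), ("c", 2), ("c", 4), ("c", 6), ("c", 7), ("c", 8),
        ("d", 1), ("d", 2), ("d", 4), ("d", 6), ("d", 7), ("d", 8),
        ("d", 9), ("d", 10)]
print(collect_sorted_lists(rows))
# {'c': [1, 2, 4, 6, 7, 8], 'd': [1, 2, 4, 6, 7, 8, 9, 10]}
```

In Spark itself, `df1.groupBy("a").agg(F.sort_array(F.collect_list("b")).alias("list"))` should give the same lists without needing the window and the `distinct()` step.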

Then find the middle index and the distance on each side of the two middle values:

df3 = (
    df2.withColumn("mid_ind", (F.size("list") / 2).cast("int"))
    # list[i] is zero-based: the middle values are list[mid_ind - 1] and list[mid_ind]
    .withColumn("lower_mid_diff", F.abs(F.col("list")[F.col("mid_ind") - 1] - F.col("list")[F.col("mid_ind") - 2]))
    .withColumn("upper_mid_diff", F.abs(F.col("list")[F.col("mid_ind")] - F.col("list")[F.col("mid_ind") + 1]))
)
df3.show(2, False)
+---+-------------------------+-------+--------------+--------------+
|a  |list                     |mid_ind|lower_mid_diff|upper_mid_diff|
+---+-------------------------+-------+--------------+--------------+
|c  |[1, 2, 4, 6, 7, 8]       |3      |2             |1             |
|d  |[1, 2, 4, 6, 7, 8, 9, 10]|4      |2             |1             |
+---+-------------------------+-------+--------------+--------------+
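Spark's `list[index]` is zero-based, so with `mid_ind = size / 2` the two middle values sit at `mid_ind - 1` and `mid_ind`, and their outer neighbors at `mid_ind - 2` and `mid_ind + 1`. The sketch below repeats that arithmetic in plain Python to make the indices explicit; `mid_diffs` is a hypothetical name, and it assumes each list has an even length of at least four, since the Spark expression reads `mid_ind - 2` and `mid_ind + 1` without checking bounds.

```python
def mid_diffs(values):
    """Return (mid_ind, lower_mid_diff, upper_mid_diff) like df3's columns."""
    if len(values) < 4 or len(values) % 2:
        raise ValueError("expected an even number of values, at least 4")
    mid_ind = len(values) // 2  # matches (size / 2).cast("int") for positive sizes
    lower_mid_diff = abs(values[mid_ind - 1] - values[mid_ind - 2])
    upper_mid_diff = abs(values[mid_ind] - values[mid_ind + 1])
    return mid_ind, lower_mid_diff, upper_mid_diff


print(mid_diffs([1, 2, 4, 6, 7, 8]))         # (3, 2, 1)
print(mid_diffs([1, 2, 4, 6, 7, 8, 9, 10]))  # (4, 2, 1)
```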

Finally, apply a when/otherwise to keep the middle value that sits next to the smaller gap:

# keep the lower middle value when its gap is smaller or equal, otherwise the upper one
df3.withColumn("b", F.when(F.col("lower_mid_diff") <= F.col("upper_mid_diff"),
                           F.col("list")[F.col("mid_ind") - 1])
                    .otherwise(F.col("list")[F.col("mid_ind")])).select("a", "b").show()
+---+---+
|  a|  b|
+---+---+
|  c|  6|
|  d|  7|
+---+---+
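Because the condition is `lower_mid_diff <= upper_mid_diff`, ties keep the lower middle value. The question does not say what should happen in a tie, so changing `<=` to `<` would prefer the upper value instead. A plain-Python sketch of that final choice, assuming a sorted list with an even number of values (at least four); `choose_middle` and its `prefer_lower_on_tie` flag are hypothetical:

```python
def choose_middle(values, prefer_lower_on_tie=True):
    """Mirror the final F.when(...).otherwise(...) on one sorted list."""
    if len(values) < 4 or len(values) % 2:
        raise ValueError("expected an even number of values, at least 4")
    mid_ind = len(values) // 2
    # values are sorted, so both differences are already non-negative
    lower_mid_diff = values[mid_ind - 1] - values[mid_ind - 2]
    upper_mid_diff = values[mid_ind + 1] - values[mid_ind]
    if prefer_lower_on_tie:
        take_lower = lower_mid_diff <= upper_mid_diff  # the answer's `<=`
    else:
        take_lower = lower_mid_diff < upper_mid_diff
    return values[mid_ind - 1] if take_lower else values[mid_ind]


print(choose_middle([1, 2, 4, 6, 7, 8]))                        # 6
print(choose_middle([1, 3, 5, 7]))                              # 3 (tie: both gaps are 2)
print(choose_middle([1, 3, 5, 7], prefer_lower_on_tie=False))   # 5
```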

Hope this helps. Also note that you can combine several of these steps to make the code shorter :)

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
