Understanding the number of partitions that data saved with partitionBy gets read back in

When I save a DataFrame partitioned with partitionBy and then read it back, how many partitions will it be read into? I used the following script to understand the behaviour, but the results do not make sense to me.

I have an 8-core machine

from pyspark.sql import SparkSession
import random

spark = SparkSession.builder.getOrCreate()

def f(num_key):
    # two rows per distinct value of "a"
    data1 = [(i, random.randint(1, 5), random.randint(1, 5)) for t in range(2) for i in range(num_key)]
    df1 = spark.createDataFrame(data1, schema='a int, b int, c int')
    df1.write.partitionBy('a').csv('df.csv', header=True, mode='overwrite')
    df1.write.partitionBy('a').parquet('df.parquet', mode='overwrite')
    print(f"partitions to be saved: {df1.select('a').distinct().count()}")
    print(f"csv_partitions: {spark.read.csv('df.csv', header=True).rdd.getNumPartitions()}, "
          f"parquet_partitions: {spark.read.parquet('df.parquet').rdd.getNumPartitions()}")

The number of partitions the data is read back into does not make sense to me. Example:

f(1)
output:
partitions to be saved: 1
csv_partitions: 2, parquet_partitions: 2
f(2)
output:
partitions to be saved: 2
csv_partitions: 4, parquet_partitions: 4
f(3)
output:
partitions to be saved: 3
csv_partitions: 6, parquet_partitions: 6
f(4)
output:
partitions to be saved: 4
csv_partitions: 8, parquet_partitions: 8
f(5)
output:
partitions to be saved: 5
csv_partitions: 5, parquet_partitions: 5
f(6)
output:
partitions to be saved: 6
csv_partitions: 6, parquet_partitions: 6
f(7)
output:
partitions to be saved: 7
csv_partitions: 7, parquet_partitions: 7
f(700)
output:
partitions to be saved: 700
csv_partitions: 44, parquet_partitions: 44

How is the number of partitions being decided here? What is the logic?



Solution 1:[1]

Your data is generated with range(2), so when it is partitioned on "a", each key holds only 2 items. Example:

f(10) # generate 20 tuples
[(0, 3, 1), (1, 1, 2), (2, 4, 5), (3, 1, 5), (4, 4, 4), (5, 5, 3), (6, 1, 1), (7, 4, 2), (8, 1, 1), (9, 3, 2), (0, 3, 4), (1, 3, 1), (2, 3, 2), (3, 3, 2), (4, 1, 3), (5, 1, 5), (6, 3, 4), (7, 1, 2), (8, 1, 4), (9, 4, 2)]

# check the content of each partition
# (df is the DataFrame whose partitioning we want to inspect)
for i, part in enumerate(df.rdd.glom().collect()):
    print({i: len(part)})

{0: 3}
{1: 3}
{2: 3}
{3: 3}
{4: 3}
{5: 3}
{6: 2}

So you end up with 20 rows divided by 8 (the default parallelism, i.e. the number of cores), which equals 2.5 and is rounded up to 3 rows per partition, except for the last partition.
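
As a quick sanity check on that arithmetic (a small standalone sketch, not part of the original script):

import math

rows = 20   # f(10) generates 2 * 10 = 20 tuples
slots = 8   # default parallelism on an 8-core machine

# 20 / 8 = 2.5, so most partitions hold ceil(2.5) = 3 rows and the last one gets the remainder
print(math.ceil(rows / slots))  # 3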

You can apply the same formula up to roughly 200 keys; somewhere around that threshold the data gets shuffled across executors (cores) so that there are not too many partitions and partitions do not end up too sparse.

Spark's smallest unit of parallelism is the partition, so it balances data between locality and repartitioning to keep a good ratio of parallelism to shuffle.
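
If you want to see the settings that drive how Spark packs files into read partitions, you can inspect them directly. This is a minimal sketch assuming default values; the exact packing logic varies by Spark version:

# Settings that influence how many partitions a file-based read produces
print(spark.sparkContext.defaultParallelism)                # typically the number of cores, e.g. 8
print(spark.conf.get("spark.sql.files.maxPartitionBytes"))  # max bytes packed into one read partition (128 MB by default)
print(spark.conf.get("spark.sql.files.openCostInBytes"))    # estimated cost, in bytes, of opening an extra file (4 MB by default)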

If you change the 2 to 4 (still partitioning on "a"), you get the following:

{0: 4}
{1: 4}
{2: 4}
{3: 4}
{4: 4}
{5: 4}
{6: 4}
{7: 2}

So in both cases, group and join operations on "a" are optimized because the data is on the same partition.

A shuffle occurs when the data gets bigger because, at some point, it becomes impossible to group/join everything on a single node/executor.
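
As a rough way to check this yourself (a sketch that assumes the parquet output written by f() above still exists), read the data back, group on "a", and look at the physical plan for Exchange (shuffle) nodes:

# Read the partitioned parquet output back and group on the partition column "a"
df_back = spark.read.parquet('df.parquet')
df_back.groupBy('a').count().explain()  # Exchange nodes in the plan indicate a shuffle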

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
