Iterating over RelationalGroupedDataset to find average and count of each key in Java

I have a Dataset&lt;Row&gt; built by reading a CSV file. I want to group by one of the fields in the CSV, merge all records with the same name, and then perform some further computation over the merged Dataset.

My input CSV file looks like this:

name,math_marks,science_marks
Ajay,10,20
Ram,15,25
Sita,18,30
Ajay,20,30
Sita,12,10
Sita,20,20
Ram,25,45

I want the final output to be something like this:

name,math_avg,science_avg,count_of_records
Ajay,15,25,2
Ram,20,35,2
Sita,16.67,20,3

My initial code in Java is below:

import lombok.extern.slf4j.Slf4j;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.RelationalGroupedDataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

import java.util.List;
import java.util.stream.Collectors;
@Slf4j
public class ReadCSVFiles {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName(ReadCSVFiles.class.getName()).setMaster("local");
        // create Spark Context
        SparkContext context = new SparkContext(conf);
        // create spark Session
        SparkSession sparkSession = new SparkSession(context);
        context.setLogLevel("INFO");

        Dataset<Row> df = sparkSession.read()
                .format("csv")
                .option("header", true)
                .option("inferSchema", true)
                .load("/Users/ajaychoudhary/Downloads/marksInputFile.csv");

        System.out.println("========== Print Schema ============");
        df.printSchema();
        System.out.println("========== Print Data ==============");
        df.show();
        System.out.println("========== Print name of dataframe ==============");
        df.select("name").show();
        RelationalGroupedDataset relationalGroupedDataset = df.groupBy("name");
        List<String> relationalGroupedDatasetRows = relationalGroupedDataset.count().collectAsList().stream()
                .map(a -> a.mkString("::")).collect(Collectors.toList());
        log.info("relationalGroupedDatasetRows is = {} ", relationalGroupedDatasetRows);

    }
}

The code currently produces the output below, which correctly gives the count of records per user, but I am unable to compute the average of the marks.

relationalGroupedDatasetRows is = [Ram::2, Ajay::2, Sita::3]

Also, I would like to know whether the above approach of using groupBy is appropriate, or whether there is a better alternative for achieving this.



Solution 1:[1]

You are calling the `count` method, which only "counts the number of rows for each group". To get the averages as well, call `agg` on the `RelationalGroupedDataset` instead: it lets you compute several aggregates per group at once, such as `avg` for each marks column alongside `count`.
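As a sketch of that idea (the class name, input path, and method name here are placeholders, not from the question), the aggregation can be expressed with the static `avg` and `count` helpers from `org.apache.spark.sql.functions`:

```java
import static org.apache.spark.sql.functions.avg;
import static org.apache.spark.sql.functions.count;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class GroupByMarks {

    // One pass over each group computes all three aggregates together.
    public static Dataset<Row> aggregateMarks(Dataset<Row> df) {
        return df.groupBy("name")
                .agg(avg("math_marks").alias("math_avg"),
                     avg("science_marks").alias("science_avg"),
                     count("*").alias("count_of_records"));
    }

    public static void main(String[] args) {
        // builder() is the usual way to obtain a SparkSession directly,
        // without constructing a SparkContext first.
        SparkSession spark = SparkSession.builder()
                .appName(GroupByMarks.class.getName())
                .master("local[*]")
                .getOrCreate();

        // Placeholder path; point this at your own CSV file.
        Dataset<Row> df = spark.read()
                .format("csv")
                .option("header", true)
                .option("inferSchema", true)
                .load("/path/to/marksInputFile.csv");

        aggregateMarks(df).show();
        spark.stop();
    }
}
```

The `alias` calls just give the result columns the names from your expected output. If you only need the averages and no custom column names, the shortcut `relationalGroupedDataset.avg("math_marks", "science_marks")` also exists, but `agg` is the more general form since it combines different aggregate functions in a single grouped pass.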

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Arun Kumar M