Spark Streaming code to read from a Kafka broker and calculate the average of numbers with mapWithState

I'm receiving values from Kafka in the form character,number, in random order, generated by another program. An example of the values I receive:

a,4
b,3
d,7
f,5
b,2
...
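Each record is just a lowercase letter and an integer separated by a comma; the Spark job further down splits it on that comma, so a line like b,3 ends up as the pair ("b", 3). A tiny illustration, with throwaway variable names, just to make the format explicit:

// How one record is interpreted (the same split the Spark job below uses):
val record = "b,3"
val Array(letter, number) = record.split(",")
val pair = (letter, number.toInt)   // ("b", 3)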

Here is the program that generates these values and sends them over a Kafka topic:

package generator

import java.util.Properties
import org.apache.kafka.clients.producer.{KafkaProducer, ProducerConfig, ProducerRecord}
import scala.util.Random

object ScalaProducerExample extends App {
    val alphabet = 'a' to 'z'
    val events = 10000
    val topic = "avg"
    val brokers = "localhost:9092"

    // Builds a record like "b,3": a random lowercase letter and a random number in 0..25.
    def getRandomVal: String = {
        val key = alphabet(Random.nextInt(alphabet.size))
        val value = Random.nextInt(alphabet.size)
        key + "," + value
    }

    val props = new Properties()
    props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, brokers)
    props.put(ProducerConfig.CLIENT_ID_CONFIG, "ScalaProducerExample")
    props.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.StringSerializer")
    props.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.StringSerializer")
    val producer = new KafkaProducer[String, String](props)

    // Send `events` records with a null key and the "letter,number" string as the value.
    for (_ <- 1 to events) {
        val data = new ProducerRecord[String, String](topic, null, getRandomVal)
        producer.send(data)
        println(data)
    }

    producer.close()
}
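For reference, a minimal sketch of the sbt dependencies such a setup needs; the exact versions below are only illustrative and may need adjusting to match your environment:

// build.sbt (sketch; versions are illustrative)
scalaVersion := "2.11.12"

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core"                 % "2.4.8",
  "org.apache.spark" %% "spark-streaming"            % "2.4.8",
  "org.apache.spark" %% "spark-streaming-kafka-0-10" % "2.4.8",
  "org.apache.kafka"  % "kafka-clients"              % "2.4.1"
)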

My task is to show the up-to-date average for every character, based on the sum and count of all values received for it from the beginning until now. I wrote the following code for this task, and I am receiving data from Kafka successfully:

package DirectKafkaWordCount

import org.apache.kafka.clients.consumer.ConsumerConfig
import org.apache.kafka.common.serialization.StringDeserializer

import org.apache.spark.SparkConf
import org.apache.spark.streaming._
import org.apache.spark.streaming.kafka010._

case class Data(key: String, count: Int)

object DirectKafkaWordCount {
  def main(args: Array[String]): Unit = {
    val Array(brokers, topics) = args
    val sparkConf = new SparkConf().setMaster("local[4]").setAppName("DirectKafkaWordCount")
    val ssc = new StreamingContext(sparkConf, Seconds(2))
    ssc.checkpoint("_checkpoint")
    val topicsSet = topics.split(",").toSet
    val kafkaParams = Map[String, Object](
      ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG -> brokers,
      ConsumerConfig.GROUP_ID_CONFIG -> "1",
      ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG -> classOf[StringDeserializer],
      ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[StringDeserializer])
    val messages = KafkaUtils.createDirectStream[String, String](
      ssc,
      LocationStrategies.PreferConsistent,
      ConsumerStrategies.Subscribe[String, String](topicsSet, kafkaParams))

    // Parse each "letter,number" record into a (key, value) pair.
    val pairs = messages.map(_.value).map(x => (x.split(",")(0), x.split(",")(1).toInt))

    // Per-key state is kept as the string "key,count,sum"; for every incoming record
    // the count and sum are updated and the new state string is also emitted downstream.
    val wc = pairs.mapWithState(StateSpec.function((key: String, value: Option[Int], state: State[String]) => {
      val newNum = value.getOrElse(0)
      val sData = state.getOption.getOrElse("a,0,0")
      var count = sData.split(",")(1).toInt
      var sum = sData.split(",")(2).toInt
      sum = sum + newNum
      count = count + 1
      val output = key + "," + count.toString + "," + sum.toString
      state.update(output)
      output
    }))
    wc.map(process _).print()
    ssc.start()
    ssc.awaitTermination()
  }

  // Turn a "key,count,sum" state string into "key,average" (integer division).
  def process(s: String): String = {
    val count = s.split(",")(1).toInt
    val sum = s.split(",")(2).toInt
    s.split(",")(0) + "," + (sum / count).toString
  }
}
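To make the intent concrete, here is the same per-key bookkeeping written as plain Scala, with no Kafka or Spark involved (a minimal sketch using the sample values from the top of the post):

object AverageSketch extends App {
  // The same kind of records as the sample at the top.
  val received = Seq(("a", 4), ("b", 3), ("d", 7), ("f", 5), ("b", 2))

  // For each key keep (count, sum), updated once per received value,
  // mirroring what the mapWithState function above is meant to do.
  val state = received.foldLeft(Map.empty[String, (Int, Int)]) {
    case (acc, (key, value)) =>
      val (count, sum) = acc.getOrElse(key, (0, 0))
      acc.updated(key, (count + 1, sum + value))
  }

  // Average = sum / count, with the same integer division as `process`.
  state.foreach { case (key, (count, sum)) => println(s"$key,${sum / count}") }
  // Prints (in some order): a,4  b,2  d,7  f,5
}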

My problem is that the average for every character becomes the constant number 12. Is there something wrong with my mapWithState function? How can I fix it? Something else that makes me suspicious is that there isn't just one entry per character in the output; there may be 3 or 4 entries for the same character. A sample of the output:

-------------------------------------------
Time: 1651560488000 ms
-------------------------------------------
d,12
t,12
h,12
t,12
h,12
x,12
d,12
h,12
p,12
p,12
...

