Complex aggregation

I have data in a topic that needs to be counted at multiple levels, but all the code samples and articles I can find only cover the word-count example.

An example of the data would be:

serial: 123 country: us date: 01/05/2018 state: new york city: nyc visitors: 5

serial: 123 country: us date: 01/06/2018 state: new york city: queens visitors: 10

serial: 456 date: 01/06/2018 country: us state: new york city: queens visitors: 27

serial: 123 date: 01/06/2018 country: us state: new york city: nyc visitors: 867

I have done the filter and the groupBy, but how do I do the aggregate? Sorry for the mix of Java 8 lambdas and older anonymous classes; I prefer Java 8 but am learning it at the same time.

KTable<String, CountryVisitorModel> countryStream1 = inStream
    .filter((key, value) -> value.status.equalsIgnoreCase("TEST_DATA"))
    .groupBy((key, value) -> value.serial)
            new Initializer<CountryVisitorModel>() {

            public CountryVisitorModelapply() {
                return new CountryVisitorModel();
        new Aggregator<String, InputModel, CountryVisitorModel>() {

            public CountryVisitorModelapply(String key, InputModel value, CountryVisitorModel aggregate) {

    aggregate.serial = value.serial;
    aggregate.country_name = value.country_name;
    aggregate.city_name = value.city_name;


    return aggregate;
Materialized.with(stringSerde, visitorSerde));

For all records with the same serial (this would be the groupBy key), count the total number of visitors per:

serial country state city total_num_visitors

Solution 1:[1]

If each record contributes to exactly one count, I would recommend splitting the stream with branch() and counting per sub-stream:

KStream stream = builder.stream(...);
// Note: in Kafka Streams 2.8+ branch() is deprecated in favor of split().
KStream[] subStreams = stream.branch(...);

// each record of `stream` will be contained in exactly _one_ of the `subStreams`
subStreams[0].groupByKey().count(); // or aggregate() instead of count()
// ...

If branching does not work because a single record needs to go into multiple counts, you can "broadcast" and filter:

KStream stream = builder.stream(...)

// each record in `stream` will be "duplicated" and sent to all `filters`
stream.filter(...).grouByKey().count(); // or aggregate() instead of count()
// ...

If you use the same KStream object multiple times and apply multiple operators to it (in this case, filter()), each record will be "broadcast" to all of those operators. Note that records (i.e., objects) are not physically copied in this case; the same input record object is passed to each filter().

Solution 2:[2]

You could keep the distinct field values in a `Set` and get the count per key with the `Set#size` method.

import com.google.gson.Gson;
import org.apache.kafka.common.serialization.Deserializer;
import org.apache.kafka.common.serialization.Serde;
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.common.serialization.Serializer;
import org.apache.kafka.common.utils.Bytes;
import org.apache.kafka.streams.*;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.Grouped;
import org.apache.kafka.streams.kstream.Materialized;
import org.apache.kafka.streams.state.KeyValueStore;
import org.junit.jupiter.api.Test;

import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.HashSet;
import java.util.Properties;
import java.util.Set;

import static org.assertj.core.api.Assertions.assertThat;

public class SomeTest {

    public static class VisitorDetails {
        public long serial;
        public String country;
        public long date;
        public String state;
        public String city;
        public long visitors;

    public static class Aggregate {
        public Set <String> countrySet = new HashSet <>();
        public long countryCounter;

        public Set <String> citySet = new HashSet <>();
        public long cityCounter;

        public long totalVisitorCounter = 0;

    public static class CustomSerializer<T> implements Serializer <T> {
        private static final Charset CHARSET = StandardCharsets.UTF_8;
        static private final Gson gson = new Gson();

        public byte[] serialize(String topic, T data) {
            String line = gson.toJson(data);
            return line.getBytes(CHARSET);

    public static class CustomDeserializer<T> implements Deserializer <T> {
        private static final Charset CHARSET = StandardCharsets.UTF_8;
        static private final Gson gson = new Gson();

        private final Class <T> tClass;

        public CustomDeserializer(Class <T> tClass) {
            this.tClass = tClass;

        public T deserialize(String topic, byte[] data) {
            try {
                String person = new String(data, CHARSET);
                return gson.fromJson(person, tClass);
            } catch (Exception e) {
                throw new IllegalArgumentException("Deserialization failed:", e);

    public static class AggregateSerde implements Serde <Aggregate> {

        public Serializer <Aggregate> serializer() {
            return new CustomSerializer <Aggregate>();

        public Deserializer <Aggregate> deserializer() {
            return new CustomDeserializer <Aggregate>(Aggregate.class);

    public static class VisitorDetailsSerde implements Serde <VisitorDetails> {

        public Serializer <VisitorDetails> serializer() {
            return new CustomSerializer <VisitorDetails>();

        public Deserializer <VisitorDetails> deserializer() {
            return new CustomDeserializer <VisitorDetails>(VisitorDetails.class);

    void test() {

        StreamsBuilder builder = new StreamsBuilder();
        builder.stream("input", Consumed.with(Serdes.Long(), new VisitorDetailsSerde()))
                .groupByKey(Grouped.with(Serdes.Long(), new VisitorDetailsSerde()))
                        (key, value, agg) -> {

                            agg.countryCounter = agg.countrySet.size();

                            agg.cityCounter = agg.citySet.size();

                            agg.totalVisitorCounter += value.visitors;
                            return agg;

                        Materialized. <Long, Aggregate, KeyValueStore <Bytes, byte[]>>as("store-name-2")
                                .withValueSerde(new AggregateSerde())
                                .withLoggingDisabled() // only for testing,
                                // recommended to not disable on prod as it provides fault tolerance


        Topology topology = builder.build();
        Properties properties = new Properties();
        properties.put(StreamsConfig.APPLICATION_ID_CONFIG, "test");
        properties.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234");

        TopologyTestDriver testDriver = new TopologyTestDriver(topology, properties);

        TestInputTopic <Long, VisitorDetails> inputTopic = testDriver.createInputTopic("input",
                Serdes.Long().serializer(), new CustomSerializer <VisitorDetails>());

        inputTopic.pipeInput(123L, visitorDetail(123L, "usa", "ny", 10L));
        inputTopic.pipeInput(123L, visitorDetail(123L, "usa", "la", 20L));
        inputTopic.pipeInput(123L, visitorDetail(123L, "pl", "krk", 30L));
        inputTopic.pipeInput(123L, visitorDetail(123L, "pl", "wrs", 40L));
        inputTopic.pipeInput(123L, visitorDetail(123L, "pl", "krk", 50L));

        KeyValueStore <Long, Aggregate> keyValueStore = testDriver. <Long, Aggregate>getKeyValueStore("store-name-2");




    private VisitorDetails visitorDetail(long serial, String country, String city, long visitors) {
        VisitorDetails visitorDetails = new VisitorDetails();
        visitorDetails.serial = serial;
        visitorDetails.country = country;
        visitorDetails.city = city;
        visitorDetails.visitors = visitors;
        return visitorDetails;


