Scala: How Does flatMap Work In This Case?

// State holds a run function: given a state, it returns an output value
// together with the next state
case class State[S, +A](run: S => (A, S)) {
  // Returns a State whose run function yields (f(a), s) instead of (a, s)
  def map[B](f: A => B): State[S, B] =
    flatMap(a => State.unit(f(a)))

  // Returns a State whose run function yields (f(a, b), s)
  def map2[B, C](sb: State[S, B])(f: (A, B) => C): State[S, C] =
    flatMap(a => sb.map(b => f(a, b)))

  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State(s => {
      val (a, s1) = run(s)
      val sB = f(a)
      sB.run(s1)
    })
}

// unit is not shown in the original snippet; assumed definition (as in the book):
// it yields a and leaves the state unchanged
object State {
  def unit[S, A](a: A): State[S, A] = State(s => (a, s))
}

I am new to functional programming and this code makes me dizzy. I tried to understand what flatMap really does, but it looks so abstract that I couldn't follow it. For example, here map is defined in terms of flatMap, but I can't explain the implementation. How does it actually work?



Solution 1:[1]

This chapter is about pure functional state. It describes how state is updated through a sequence of operations in a functional way, that is, by means of function composition. So you don't have to create a global state variable that is mutated by a sequence of method calls, as you probably would in an OOP style.

Having said that, this wonderful book explains each concept by declaring an abstraction that represents an effect, plus some primitives that are used to build more complex functions. In this chapter, the effect is State. It is represented as a function that takes a previous state and yields a value and a new state, as you can see in the declaration of run. So if you want to compose functions that propagate some state in a purely functional way, all of them need to return a State.
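To make "a function that takes a previous state and yields a value and a new state" concrete, here is a minimal sketch that keeps only the run field of the question's State class. The counter value is a hypothetical example, not from the book: it yields the current state as its value and moves to the next state by adding one.

```scala
// Minimal State type: only the run function from the question's class
case class State[S, +A](run: S => (A, S))

// Hypothetical example: yields the current state as the value
// and produces the next state by adding one
val counter: State[Int, Int] = State(s => (s, s + 1))

// Running it twice by hand, threading the new state into the next call
val (v1, s1) = counter.run(5)
val (v2, s2) = counter.run(s1)

println((v1, s1)) // (5,6)
println((v2, s2)) // (6,7)
```

Notice that the caller has to pass s1 into the second call manually; that threading of the state from one step to the next is what flatMap automates.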

With flatMap, which is a primitive of the State type, you can compose functions and, at the same time, propagate some internal state.

For example, if you have two functions:

case class Intermediate(v: Double = 0)

def execute1(a: Int): State[Int, Intermediate] = {
  State { s =>
    /** Do whatever */
    (Intermediate(a * a), s + 1)
  }
}

def execute2(a: Int): State[Int, Intermediate] = {
  State { s =>
    /** Do whatever */
    (Intermediate(a * a), s + 1)
  }
}

As you can see, each function consists of a State that takes the previous state, executes some logic, and returns a tuple of an Intermediate value, which holds the result of each function, and an Int, which tracks how many functions have been executed.

Now, if you use flatMap to execute both functions sequentially:

val result: State[Int, Intermediate] =
      execute1(2) flatMap (r => execute2(r.v.toInt))

The result val is another State that you can see as the entry point for running the pipeline:

/** 0 would be the initial state **/
result.run(0)

The result is:

(Intermediate(16.0),2)
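The whole example can be run as a single script. This sketch assumes the question's State class trimmed down to flatMap, and mirrors the execute1 and execute2 definitions above (with the placeholder comments dropped).

```scala
// State from the question, trimmed to what this example needs
case class State[S, +A](run: S => (A, S)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State(s => {
      val (a, s1) = run(s)
      f(a).run(s1)
    })
}

case class Intermediate(v: Double = 0)

// Both functions square their input and increment the counter state by one
def execute1(a: Int): State[Int, Intermediate] =
  State(s => (Intermediate(a * a), s + 1))

def execute2(a: Int): State[Int, Intermediate] =
  State(s => (Intermediate(a * a), s + 1))

// Nothing runs yet: result is only a description of the pipeline
val result: State[Int, Intermediate] =
  execute1(2) flatMap (r => execute2(r.v.toInt))

// Supplying the initial state 0 runs both steps in order
println(result.run(0)) // (Intermediate(16.0),2)
```

Building result does no work by itself; both functions only execute when run is called with an initial state.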

How this works under the hood:

  1. First you have the first function, execute1. This function returns a State[Int, Intermediate]. I use Intermediate to store the result of each execution (the yielded value) and an Int to count the number of functions executed in the sequential flatMap composition.

  2. Then I apply flatMap to the State returned by that call; this is the flatMap method of the State type.

Remember the definition:

def flatMap[B](f: A => State[S, B]): State[S, B] =
  State(s => {
    /* This function is executed in a State context */
    val (a, s1) = run(s)
    val sB = f(a)
    sB.run(s1)
  })
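One way to see what flatMap does is to substitute the concrete arguments into its body by hand. This sketch assumes the question's State class trimmed to flatMap plus the execute1 and execute2 definitions from this answer; expanded is a hypothetical name for the hand-inlined version, in which run(s) becomes execute1(2).run(s) and f becomes r => execute2(r.v.toInt).

```scala
case class State[S, +A](run: S => (A, S)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State(s => {
      val (a, s1) = run(s)
      f(a).run(s1)
    })
}

case class Intermediate(v: Double = 0)

def execute1(a: Int): State[Int, Intermediate] =
  State(s => (Intermediate(a * a), s + 1))

def execute2(a: Int): State[Int, Intermediate] =
  State(s => (Intermediate(a * a), s + 1))

val viaFlatMap: State[Int, Intermediate] =
  execute1(2).flatMap(r => execute2(r.v.toInt))

// The same pipeline with flatMap's body written out by hand:
// `this` is execute1(2) and `f` is r => execute2(r.v.toInt)
val expanded: State[Int, Intermediate] =
  State(s => {
    val (a, s1) = execute1(2).run(s) // first step: value a, next state s1
    val sB = execute2(a.v.toInt)     // f(a): build the second State
    sB.run(s1)                       // run it starting from s1
  })

println(viaFlatMap.run(0)) // (Intermediate(16.0),2)
println(expanded.run(0))   // (Intermediate(16.0),2)
```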

The execution happens inside a State context (note that flatMap is a method of the State type), namely the State returned by execute1(2). Maybe this representation is easier:

val st1 = execute1(2)
st1.flatMap(r => execute2(r.v.toInt))

So when run(s) is executed, it actually runs the state function returned by execute1(2), which produces the first (a, s1) result.

Then the second function, execute2, has to run. You can think of it as the input parameter of flatMap (strictly, the parameter f is the lambda r => execute2(r.v.toInt), which calls execute2). It is called with the value yielded by execute1(2) (the Intermediate, left side of the tuple). That call returns another State, whose run function is applied to the state produced by execute1(2) (s1, the right side of the tuple) to build the final tuple: the value yielded by execute2 and the Int counting the number of executions. Note that each function increments the internal state by one.
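The two steps described above can also be checked one at a time. This sketch assumes a State type with only the run field, plus the execute1 and execute2 definitions from this answer, and prints the tuple produced at each stage starting from state 0.

```scala
case class State[S, +A](run: S => (A, S))

case class Intermediate(v: Double = 0)

def execute1(a: Int): State[Int, Intermediate] =
  State(s => (Intermediate(a * a), s + 1))

def execute2(a: Int): State[Int, Intermediate] =
  State(s => (Intermediate(a * a), s + 1))

// Step 1: run(s) inside flatMap runs the State returned by execute1(2)
val (a, s1) = execute1(2).run(0)
println((a, s1)) // (Intermediate(4.0),1)

// Step 2: f(a) calls execute2 with the yielded value (4),
// and the new State is run with s1, not with the initial state
val (b, s2) = execute2(a.v.toInt).run(s1)
println((b, s2)) // (Intermediate(16.0),2)
```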

Note that the result is itself another State, whose run method waits for the initial state.
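The question also asks how map works when it is defined in terms of flatMap. Here is a sketch, assuming the book's unit lives in a companion object and simply yields a value without touching the state. map then calls flatMap with a function whose second State returns f(a) and passes s1 through unchanged, so only the first State moves the state forward. The squared name is hypothetical; running it from two initial states also shows that a State can be started from any state.

```scala
case class State[S, +A](run: S => (A, S)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State(s => {
      val (a, s1) = run(s)
      f(a).run(s1)
    })

  // map from the question: the second State is built by unit,
  // so it returns f(a) and leaves s1 unchanged
  def map[B](f: A => B): State[S, B] =
    flatMap(a => State.unit(f(a)))
}

object State {
  // Assumed definition of unit: yields a and keeps the state as is
  def unit[S, A](a: A): State[S, A] = State(s => (a, s))
}

case class Intermediate(v: Double = 0)

def execute1(a: Int): State[Int, Intermediate] =
  State(s => (Intermediate(a * a), s + 1))

// Extract the Double from the Intermediate value yielded by execute1(2)
val squared: State[Int, Double] = execute1(2).map(r => r.v)

// The same State can be run from different initial states
println(squared.run(0))  // (4.0,1)
println(squared.run(10)) // (4.0,11)
```

The state advances by exactly one in both runs: execute1 adds one, and the unit step inside map leaves the state alone.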

I have tried to explain the workflow of the function calls. It is a little hard if you are not used to thinking in a functional style, but it is only function composition, and it becomes familiar after a while.

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Emiliano Martinez