Python functional style iterative algorithm?

In Haskell there is a simple list function available:

iterate :: (a -> a) -> a -> [a]
iterate f x =  x : iterate f (f x)

In Python it could be implemented as follows:

def iterate(f, init):
  while True:
    yield init
    init = f(init)
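Because this generator never terminates, a caller has to take a finite prefix of it. A minimal usage sketch with itertools.islice (the doubling function is just an example):

```python
from itertools import islice

def iterate(f, init):
    # Infinite generator: init, f(init), f(f(init)), ...
    while True:
        yield init
        init = f(init)

# Take only the first five values; islice stops pulling after that.
powers = list(islice(iterate(lambda x: 2 * x, 1), 5))
print(powers)  # [1, 2, 4, 8, 16]
```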

I was kinda surprised that something basic like this is not part of the functools/itertools modules. Could it simply be constructed in functional style (i.e. without the loop) using the tools provided in these libraries? (Mostly code golf, trying to learn about functional style in Python.)



Solution 1:[1]

You can do it using some of the functions in itertools:

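from itertools import accumulate, repeat

def iterate(func, initial):
    # accumulate emits `initial` first, then applies the lambda to the running
    # value and each (ignored) None from repeat(None). `initial=` needs Python 3.8+.
    return accumulate(repeat(None), func=lambda tot, _: func(tot), initial=initial)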

It's clearly not very clean, though. itertools is missing some fundamental functions for constructing streams, like unfoldr. As it happens, most of the itertools functions could be defined in terms of unfoldr, but functional programming is a little uncomfortable in Python anyway, so that might not be much of a benefit.
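To make the unfoldr remark concrete, here is a minimal Python sketch. unfoldr below is a hypothetical helper modelled on Haskell's Data.List.unfoldr (it is not part of itertools or functools): step(seed) returns None to stop, or a (value, next_seed) pair. With it, iterate becomes a one-liner, and finite streams such as countdown (also hypothetical) come from the same primitive.

```python
from itertools import islice

def unfoldr(step, seed):
    # Hypothetical analogue of Haskell's unfoldr (not in itertools):
    # step(seed) -> None to stop, or (value, next_seed) to continue.
    while (result := step(seed)) is not None:
        value, seed = result
        yield value

def iterate(f, init):
    # Never stops: emit the current seed and continue from f(seed).
    return unfoldr(lambda x: (x, f(x)), init)

def countdown(n):
    # A finite stream: stop once the seed reaches 0.
    return unfoldr(lambda k: (k, k - 1) if k > 0 else None, n)

print(list(islice(iterate(lambda x: 3 * x, 1), 5)))  # [1, 3, 9, 27, 81]
print(list(countdown(3)))                             # [3, 2, 1]
```

One difference from Haskell: the tuple (x, f(x)) is built eagerly here, so this version calls f one step before the corresponding value is actually requested.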

Solution 2:[2]

There is a third-party "extension" to the itertools module, more-itertools, that includes (among many other things) an iterate function defined exactly like the one in the question:

# Exact definition, minus the doc string...
def iterate(func, start):
    while True:
        yield start
        start = func(start)
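A minimal usage sketch, assuming more-itertools is installed (pip install more-itertools); if it isn't, the fallback below, a copy of the definition above, keeps the example runnable:

```python
from itertools import islice

try:
    # Third-party package: pip install more-itertools
    from more_itertools import iterate
except ImportError:
    # Fallback with the same behaviour, in case the package isn't installed.
    def iterate(func, start):
        while True:
            yield start
            start = func(start)

doubling = list(islice(iterate(lambda x: x * 2, 1), 6))
print(doubling)  # [1, 2, 4, 8, 16, 32]
```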

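Python lacks the optimization necessary to make a recursive definition like

from itertools import chain

def iterate(func, start):
    yield from chain([start], iterate(func, func(start)))

feasible: each element adds another level of nested generators, so consuming enough of them eventually exceeds the recursion limit.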


If you are curious, Coconut is a superset of Python that does support things like tail-call optimization. Try the following code at https://cs121-team-panda.github.io/coconut-interpreter/:

@recursive_iterator
def iterate(f, s) = (s,) :: iterate(f, f(s))

for x in iterate(x -> x + 1, 0)$[1000:1010]:
    print(x)

(I'm not entirely sure the recursive_iterator decorator is necessary. The slice starting at index 1000 demonstrates, I think, that this avoids the recursion-depth error that similar code in Python would produce.)

Solution 3:[3]

You could use the walrus operator* in a generator expression to get the desired output.

from itertools import chain, repeat

def f(x):
    return x + 10

x = 0
it = chain([x], (x:=f(x) for _ in repeat(None)))

>>> next(it)
0
>>> next(it)
10
>>> next(it)
20

* The walrus operator is available in Python 3.8 and later.
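In the snippet above, x is a module-level variable that the generator expression keeps rebinding (after the three next() calls, x is 20). An assignment expression inside a generator expression binds its target in the enclosing scope, so wrapping the code in a function, as the with_walrus benchmark entry below also does, keeps that state local to each iterator. A minimal sketch (iterate is just an illustrative name):

```python
from itertools import chain, islice, repeat

def iterate(f, init):
    # The walrus target in the generator expression binds in the enclosing
    # scope, i.e. this function's local `x`, not a module-level global.
    x = init
    return chain([x], (x := f(x) for _ in repeat(None)))

# Each call gets its own `x`, so two iterators don't interfere.
a = iterate(lambda v: v + 10, 0)
b = iterate(lambda v: v * 2, 1)
print(list(islice(a, 3)))  # [0, 10, 20]
print(list(islice(b, 4)))  # [1, 2, 4, 8]
```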

Solution 4:[4]

Yes, we can construct it "loop-less" with map and itertools, and it's faster than the others:

from itertools import tee, chain, islice

def iterate(f, init):
    parts = [[init]]
    values1, values2 = tee(chain.from_iterable(parts))
    parts.append(map(f, values2))
    return values1

def f(x):
    return 3 * x

print(*islice(iterate(f, 1), 10))

Output:

1 3 9 27 81 243 729 2187 6561 19683

The first problem is that we need each value twice: once for the outside user, and once fed back into f to compute further values. We can use tee to duplicate the values.
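A small sketch of what tee does: the two iterators it returns are independent, and each one sees every value even if the other has already consumed it (tee buffers values internally until both copies have passed them).

```python
from itertools import tee

a, b = tee(iter([1, 2, 3]))
first_two = [next(a), next(a)]  # advance only `a`
all_b = list(b)                 # `b` still starts from the beginning
rest_a = list(a)                # `a` resumes where it left off
print(first_two, all_b, rest_a)  # [1, 2] [1, 2, 3] [3]
```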

Next we have a chicken-and-egg problem: we want to use map(f, values2) to get the function values, where values2 comes from the map iterator that we're only about to create! Fortunately, chain.from_iterable consumes its argument lazily, so we can extend the parts list after creating the chain.
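The trick relies on chain.from_iterable reading its argument lazily: it only asks the list for the next part when the previous one is exhausted, and a list iterator sees items appended before it reaches the end. A minimal demonstration with plain lists:

```python
from itertools import chain

parts = [[1, 2]]
flat = chain.from_iterable(parts)  # create the chain; nothing is read yet
parts.append([3, 4])               # extend the list after the chain exists
result = list(flat)
print(result)  # [1, 2, 3, 4]
```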

Alternatively, we can make parts a generator; its body only runs (and looks up values2) once the chain starts consuming it, which happens after values2 has been created:

def iterate(f, init):
    def parts():
        yield init,              # a one-element tuple: (init,)
        yield map(f, values2)    # values2 exists by the time this line runs
    values1, values2 = tee(chain.from_iterable(parts()))
    return values1
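To check that this feedback loop calls f only once per new value (tee buffers each value, so nothing is recomputed), the sketch below counts the calls. triple and the calls counter are hypothetical helpers added for illustration; iterate repeats the generator-based version above.

```python
from itertools import chain, islice, tee

calls = 0

def triple(x):
    # Hypothetical f that records how often it is called.
    global calls
    calls += 1
    return 3 * x

def iterate(f, init):
    def parts():
        yield (init,)           # first part: just the initial value
        yield map(f, values2)   # second part: f applied to the fed-back copy
    values1, values2 = tee(chain.from_iterable(parts()))
    return values1

first = list(islice(iterate(triple, 1), 5))
print(first)  # [1, 3, 9, 27, 81]
print(calls)  # 4: one call per value after the initial one
```

Since values2 never falls more than one element behind values1, the backlog tee has to buffer stays small no matter how far the caller iterates.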

Benchmark for computing 100,000 values with f = abs and init = 0 (chosen because abs is fast, so the cost of f doesn't dilute the speed differences between the solutions):

CPython 3.8.0b4 on tio.run:

    mean  stdev  (from best 5 of 20 attempts)
 2.75 ms  0.03 ms  with_itertools1  (my first solution)
 2.76 ms  0.02 ms  with_itertools2  (my second solution)
 5.29 ms  0.02 ms  with_generator   (the question's solution)
 5.73 ms  0.04 ms  with_walrus
 9.00 ms  0.09 ms  with_accumulate

CPython 3.10.4 on my Windows laptop:

    mean  stdev  (from best 5 of 20 attempts)
 8.37 ms  0.02 ms  with_itertools2
 8.37 ms  0.00 ms  with_itertools1
17.86 ms  0.00 ms  with_generator
20.73 ms  0.01 ms  with_walrus
26.03 ms  0.24 ms  with_accumulate

CPython 3.10.4 on a Debian Google Compute Engine instance:

    mean  stdev  (from best 5 of 20 attempts)
 2.25 ms  0.00 ms  with_itertools1
 2.26 ms  0.00 ms  with_itertools2
 3.91 ms  0.00 ms  with_generator
 4.43 ms  0.00 ms  with_walrus
 7.14 ms  0.01 ms  with_accumulate

Benchmark code:

from itertools import accumulate, tee, chain, islice, repeat
import timeit
from bisect import insort
from random import shuffle
from statistics import mean, stdev
import sys

def with_accumulate(f, init):
    return accumulate(repeat(None), func=lambda tot, _: f(tot), initial=init)

def with_generator(f, init):
    while True:
        yield init
        init = f(init)

def with_walrus(f, init):
    return chain([x:=init], (x:=f(x) for _ in repeat(None)))

def with_itertools1(f, init):
    parts = [[init]]
    values1, values2 = tee(chain.from_iterable(parts))
    parts.append(map(f, values2))
    return values1

def with_itertools2(f, init):
    def parts():
        yield init,
        yield map(f, values2)
    values1, values2 = tee(chain.from_iterable(parts()))
    return values1

solutions = [
    with_accumulate,
    with_generator,
    with_walrus,
    with_itertools1,
    with_itertools2,
]

for solution in solutions:
    iterator = solution(lambda x: 3 * x, 1)
    print(*islice(iterator, 10), solution.__name__)

def consume(iterator, n):
    next(islice(iterator, n, n), None)

attempts, best = 20, 5
times = {solution: [] for solution in solutions}
for _ in range(attempts):
    shuffle(solutions)
    for solution in solutions:
        time = min(timeit.repeat(lambda: consume(solution(abs, 0), 10**5), number=1))
        insort(times[solution], time)
print(f'    mean  stdev  (from best {best} of {attempts} attempts)')
for solution in sorted(solutions, key=times.get):
    ts = times[solution][:best]
    print('%5.2f ms ' * 2 % (mean(ts) * 1e3, stdev(ts) * 1e3), solution.__name__)

print()
print(sys.implementation.name, sys.version)

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 oisdk
Solution 2
Solution 3
Solution 4