Implementing recursive function with `__iter__` method in a Python class
So I'm working on a problem in which I have to create a Python class that generates all permutations of a list, and I've run into the following questions:
- I can do this easily with a simple recursive function, but in a class it seems like I should use the `__iter__` method. My `__iter__` calls a recursive function (`list_all`) that is almost identical to `__iter__` itself, which is very unsettling. How do I modify my recursive function to follow best practices for `__iter__`?
- I wrote this code, saw that it worked, and I feel like I don't understand it! I tried to trace the code line by line on a test case, but to me it looks like the first element in the list is frozen each time and the rest of the list is shuffled. Instead, the output comes out in an unexpected order. I'm not understanding something!
Thanks!
```python
class permutations():
    def __init__(self, ls):
        self.list = ls

    def __iter__(self):
        ls = self.list
        length = len(ls)
        if length <= 1:
            yield ls
        else:
            for p in self.list_all(ls[1:]):
                for x in range(length):
                    yield p[:x] + ls[0:1] + p[x:]

    def list_all(self, ls):
        length = len(ls)
        if length <= 1:
            yield ls
        else:
            for p in self.list_all(ls[1:]):
                for x in range(length):
                    yield p[:x] + ls[0:1] + p[x:]
```
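On the second question: the first element is not frozen. `list_all` first produces every permutation of `ls[1:]`, and only then inserts `ls[0]` into each of the `length` slots of that shorter permutation, from the front (`x = 0`) to the end. A minimal tracing sketch of the same algorithm (the name `traced_perms` and its `depth` parameter are hypothetical, added only to indent the printout by recursion level):

```python
def traced_perms(ls, depth=0):
    """Same insertion algorithm as list_all, printing each step.

    `depth` is a hypothetical helper parameter used only to indent the trace.
    """
    length = len(ls)
    if length <= 1:
        yield ls
        return
    # Recurse on everything except the first element ...
    for p in traced_perms(ls[1:], depth + 1):
        # ... then slot ls[0] into each of the `length` gaps of p
        # (p has length - 1 items, so there are exactly `length` gaps).
        for x in range(length):
            result = p[:x] + ls[0:1] + p[x:]
            print("  " * depth + f"insert {ls[0]!r} into {p} at {x} -> {result}")
            yield result


perms = list(traced_perms([1, 2, 3]))
print(perms)
# [[1, 2, 3], [2, 1, 3], [2, 3, 1], [1, 3, 2], [3, 1, 2], [3, 2, 1]]
```

The trace lines interleave because generators are lazy: the inner call produces `[2, 3]`, the outer loop immediately inserts `1` at positions 0, 1 and 2, and only then does the inner call resume to produce `[3, 2]`. So `1` walks through every position of `[2, 3]` before `2` and `3` ever swap, which is the order seen in the output.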
Solution 1:[1]
Just call self.list_all from __iter__:
```python
class permutations():
    def __init__(self, ls):
        self.list = ls

    def __iter__(self):
        for item in self.list_all(self.list):
            yield item

    def list_all(self, ls):
        length = len(ls)
        if length <= 1:
            yield ls
        else:
            for p in self.list_all(ls[1:]):
                for x in range(length):
                    yield p[:x] + ls[0:1] + p[x:]
```
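On Python 3.3 and later, the `for item in ...: yield item` loop in `__iter__` can be shortened with `yield from`, which delegates directly to the inner generator. A sketch of that variant, keeping the class and method names from the code above (only `__iter__` changes):

```python
class permutations():
    def __init__(self, ls):
        self.list = ls

    def __iter__(self):
        # Delegate to the recursive generator instead of re-yielding by hand.
        yield from self.list_all(self.list)

    def list_all(self, ls):
        length = len(ls)
        if length <= 1:
            yield ls
        else:
            # Insert ls[0] at every position of each permutation of ls[1:].
            for p in self.list_all(ls[1:]):
                for x in range(length):
                    yield p[:x] + ls[0:1] + p[x:]


for perm in permutations([1, 2, 3]):
    print(perm)
```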
Solution 2:[2]
Your `list_all` method is already a generator function, so `__iter__` can simply return the generator it creates:
```python
class permutations():
    def __init__(self, ls):
        self.list = ls

    def __iter__(self):
        return self.list_all(self.list)

    def list_all(self, ls):
        length = len(ls)
        if length <= 1:
            yield ls
        else:
            for p in self.list_all(ls[1:]):
                for x in range(length):
                    yield p[:x] + ls[0:1] + p[x:]
```
This is cleaner to read, and it should also run slightly faster, since it skips the extra generator layer that re-yields every item one at a time.
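Because `__iter__` creates a new generator on every call, a `permutations` object can be looped over more than once, whereas a bare generator is exhausted after a single pass. A small sketch demonstrating the difference (it repeats the class from this answer so it runs on its own):

```python
class permutations():
    def __init__(self, ls):
        self.list = ls

    def __iter__(self):
        # A fresh generator is created each time iteration starts.
        return self.list_all(self.list)

    def list_all(self, ls):
        length = len(ls)
        if length <= 1:
            yield ls
        else:
            for p in self.list_all(ls[1:]):
                for x in range(length):
                    yield p[:x] + ls[0:1] + p[x:]


perms = permutations([1, 2, 3])
first_pass = list(perms)
second_pass = list(perms)  # __iter__ runs again, so the same 6 results
print(len(first_pass), len(second_pass))  # 6 6

gen = perms.list_all([1, 2, 3])  # a generator object, usable only once
print(len(list(gen)), len(list(gen)))  # 6 0 (already exhausted)
```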
Another option is to define `list_all` inside `__iter__` as a nested function:
```python
class permutations2():
    def __init__(self, ls):
        self.list = ls

    def __iter__(self):
        def list_all(ls):
            length = len(ls)
            if length <= 1:
                yield ls
            else:
                for p in list_all(ls[1:]):
                    for x in range(length):
                        yield p[:x] + ls[0:1] + p[x:]
        return list_all(self.list)
```
Timing permutations vs my permutations2 gives almost identical results.
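That comparison can be reproduced with the standard `timeit` module. A rough sketch, repeating both classes from this answer so it runs on its own; the input size and repeat count are arbitrary choices, and absolute numbers depend on the machine and Python version, so none are quoted here:

```python
import timeit


class permutations():
    def __init__(self, ls):
        self.list = ls

    def __iter__(self):
        return self.list_all(self.list)

    def list_all(self, ls):
        length = len(ls)
        if length <= 1:
            yield ls
        else:
            for p in self.list_all(ls[1:]):
                for x in range(length):
                    yield p[:x] + ls[0:1] + p[x:]


class permutations2():
    def __init__(self, ls):
        self.list = ls

    def __iter__(self):
        # Nested helper is called directly, not looked up as self.list_all.
        def list_all(ls):
            length = len(ls)
            if length <= 1:
                yield ls
            else:
                for p in list_all(ls[1:]):
                    for x in range(length):
                        yield p[:x] + ls[0:1] + p[x:]
        return list_all(self.list)


data = list(range(7))  # 7! = 5040 permutations per pass
assert list(permutations(data)) == list(permutations2(data))

for cls in (permutations, permutations2):
    # Time 20 full passes over all permutations.
    seconds = timeit.timeit(lambda: list(cls(data)), number=20)
    print(f"{cls.__name__}: {seconds:.3f} s")
```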
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | bruno desthuilliers |
| Solution 2 | |
