heapq with custom compare predicate
I am trying to build a heap with a custom sort predicate. Since the values going into it are of 'user-defined' type, I cannot modify their built-in comparison predicate.
Is there a way to do something like:
h = heapq.heapify([...], key=my_lt_pred)
h = heapq.heappush(h, key=my_lt_pred)
Or even better, I could wrap the heapq functions in my own container so I don't need to keep passing the predicate.
Solution 1:[1]
Define a class that overrides the __lt__() method. See the example below (works in Python 3.7):
import heapq

class Node(object):
    def __init__(self, val: int):
        self.val = val

    def __repr__(self):
        return f'Node value: {self.val}'

    def __lt__(self, other):
        return self.val < other.val

heap = [Node(2), Node(0), Node(1), Node(4), Node(2)]
heapq.heapify(heap)
print(heap)  # output: [Node value: 0, Node value: 2, Node value: 1, Node value: 4, Node value: 2]
heapq.heappop(heap)
print(heap)  # output: [Node value: 1, Node value: 2, Node value: 2, Node value: 4]
Solution 2:[2]
The heapq documentation suggests that heap elements could be tuples in which the first element is the priority and defines the sort order.
More pertinent to your question, however, is that the documentation includes a discussion with sample code of how one could implement their own heapq wrapper functions to deal with the problems of sort stability and elements with equal priority (among other issues).
In a nutshell, their solution is to make each heap element a triple containing the priority, an entry count, and the element to be inserted. The entry count ensures that elements with the same priority are sorted in the order they were added to the heap.
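The triple layout described above can be sketched as a small wrapper class, which also covers the question's wish to avoid passing the predicate on every call. This is a minimal sketch, not the documentation's code: the class name KeyedHeap and its methods are hypothetical, and it assumes the key function returns values that can be compared with <.

```python
import heapq
import itertools


class KeyedHeap:
    """Min-heap ordered by key(item); ties pop in insertion order.

    Hypothetical helper for illustration, not part of heapq.
    """

    def __init__(self, key, items=()):
        self._key = key
        self._counter = itertools.count()  # tie-breaker: insertion order
        # Each entry is (priority, entry count, item). The count is unique,
        # so a comparison never falls through to the item itself.
        self._heap = [(key(item), next(self._counter), item) for item in items]
        heapq.heapify(self._heap)

    def push(self, item):
        heapq.heappush(self._heap, (self._key(item), next(self._counter), item))

    def pop(self):
        # Return only the stored item, dropping priority and count.
        return heapq.heappop(self._heap)[-1]

    def __len__(self):
        return len(self._heap)


heap = KeyedHeap(key=len, items=["pear", "fig", "banana"])
heap.push("kiwi")
popped = [heap.pop() for _ in range(len(heap))]
print(popped)
```

Because the entry count is always distinct, the stored items never need to be comparable themselves, which is the point of the triple.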
Solution 3:[3]
setattr(ListNode, "__lt__", lambda self, other: self.val < other.val)
This attaches a __lt__ method to an existing class (here ListNode) so that heapq can compare its instances by val.
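A self-contained sketch of this monkey-patching approach follows. ListNode here is a hypothetical stand-in for a class you cannot edit (the answer does not define it), and the sketch uses a strict < so that __lt__ keeps its usual meaning.

```python
import heapq


class ListNode:
    """Hypothetical third-party class with no ordering defined."""

    def __init__(self, val):
        self.val = val


# Attach __lt__ after the fact; CPython's heapq compares elements with <.
ListNode.__lt__ = lambda self, other: self.val < other.val

heap = [ListNode(3), ListNode(1), ListNode(2)]
heapq.heapify(heap)
values = [heapq.heappop(heap).val for _ in range(3)]
print(values)
```

Note that the patch changes the class for every user of it in the process, so it is best kept to cases where no other ordering is expected.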
Solution 4:[4]
The limitation of both answers is that they do not let ties remain ties. In the first, ties are broken by comparing the items themselves; in the second, by comparing input order. It is faster to just let ties be ties, and if there are many of them the difference could be significant. Based on the above and on the docs, it is not clear whether this can be achieved with heapq. It does seem strange that the heapq push and pop functions do not accept a key, while other functions in the same module (such as nlargest and nsmallest) do.
P.S.:
If you follow the link in the first comment ("possible duplicate...") there is another suggestion, defining __le__, which seems like a solution.
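For context on the remark that other functions in the module accept a key: heapq.nsmallest and heapq.nlargest take a key argument, as in this short sketch (the sample data is made up for illustration). They return sorted lists rather than maintaining a heap, so they do not replace heappush and heappop.

```python
import heapq

# Made-up sample data: (name, score) pairs.
records = [("a", 5), ("b", 2), ("c", 9), ("d", 2)]

# key= selects the comparison value, much like sorted(..., key=...).
lowest = heapq.nsmallest(2, records, key=lambda r: r[1])
highest = heapq.nlargest(1, records, key=lambda r: r[1])
print(lowest, highest)
```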
Solution 5:[5]
In Python 3, you can use cmp_to_key from the functools module (see the CPython source code).
Suppose you need a priority queue of triplets whose priority is given by the last element.
from heapq import heappush
from functools import cmp_to_key

def mycmp(triplet_left, triplet_right):
    key_l, key_r = triplet_left[2], triplet_right[2]
    if key_l > key_r:
        return -1  # larger first
    elif key_l == key_r:
        return 0  # equal
    else:
        return 1

WrapperCls = cmp_to_key(mycmp)
pq = []
myobj = (1, 2, "anystring")
# to push an object myobj into pq
heappush(pq, WrapperCls(myobj))
# to get the heap top use the `obj` attribute
inner = pq[0].obj
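To see the resulting order, the following sketch pushes a few made-up triplets through a comparator with the same behavior and pops them all. The sample values and the name Wrapper are illustrative only.

```python
from functools import cmp_to_key
from heapq import heappush, heappop


def mycmp(left, right):
    # Same idea as above: a larger last element sorts first.
    key_l, key_r = left[2], right[2]
    return (key_l < key_r) - (key_l > key_r)  # -1, 0, or 1


Wrapper = cmp_to_key(mycmp)

pq = []
for triplet in [(1, 2, 10), (3, 4, 30), (5, 6, 20)]:
    heappush(pq, Wrapper(triplet))

# Unwrap each popped entry through its `obj` attribute.
order = [heappop(pq).obj for _ in range(len(pq))]
print(order)
```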
Performance Test:
Environment
python 3.10.2
Code
from functools import cmp_to_key
from timeit import default_timer as time
from random import randint
from heapq import heappush

class WrapperCls1:
    __slots__ = 'obj'

    def __init__(self, obj):
        self.obj = obj

    def __lt__(self, other):
        kl, kr = self.obj[2], other.obj[2]
        return True if kl > kr else False

def cmp_class2(obj1, obj2):
    kl, kr = obj1[2], obj2[2]
    return -1 if kl > kr else 0 if kl == kr else 1

WrapperCls2 = cmp_to_key(cmp_class2)

triplets = [[randint(-1000000, 1000000) for _ in range(3)] for _ in range(100000)]
# tuple_triplets = [tuple(randint(-1000000, 1000000) for _ in range(3)) for _ in range(100000)]

def test_cls1():
    pq = []
    for triplet in triplets:
        heappush(pq, WrapperCls1(triplet))

def test_cls2():
    pq = []
    for triplet in triplets:
        heappush(pq, WrapperCls2(triplet))

def test_cls3():
    pq = []
    for triplet in triplets:
        heappush(pq, (-triplet[2], triplet))

start = time()
for _ in range(10):
    test_cls1()
    # test_cls2()
    # test_cls3()
print("total running time (seconds): ", time() - start)
Results
Using lists instead of tuples, per function:
- WrapperCls1: 16.2ms
- WrapperCls1 with __slots__: 9.8ms
- WrapperCls2: 8.6ms
- move the priority attribute into the first position (does not support a custom predicate): 6.0ms
Therefore, the cmp_to_key approach is slightly faster than using a custom class with an overridden __lt__() method and the __slots__ attribute.
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Fanchen Bao |
| Solution 2 | srgerg |
| Solution 3 | Ritika Gupta |
| Solution 4 | |
| Solution 5 | |
