'How can I implement a numba jitted priority queue?

I am failing to implement a numba jitted priority queue.

Heavily plagiarized from the python docs, I am fairly happy with this class.

import itertools

import numba as nb
from numba.experimental import jitclass
from typing import List, Tuple, Dict
from heapq import heappush, heappop


class PurePythonPriorityQueue:
    def __init__(self):
        self.pq = [] # list of entries arranged in a heap
        self.entry_finder = {}  # mapping of indices to entries
        self.REMOVED = -1 # placeholder for a removed item
        self.counter = itertools.count() # unique sequence count

    def put(self, item: Tuple[int, int], priority: float = 0.0):
        """Add a new item or update the priority of an existing item"""
        if item in self.entry_finder:
            self.remove_item(item)
        count = next(self.counter)
        entry = [priority, count, item]
        self.entry_finder[item] = entry
        heappush(self.pq, entry)

    def remove_item(self, item: Tuple[int, int]):
        """Mark an existing item as REMOVED.  Raise KeyError if not found."""
        entry = self.entry_finder.pop(item)
        entry[-1] = self.REMOVED

    def pop(self):
        """Remove and return the lowest priority item. Raise KeyError if empty."""
        while self.pq:
            priority, count, item = heappop(self.pq)
            if item is not self.REMOVED:
                del self.entry_finder[item]
                return item
        raise KeyError("pop from an empty priority queue")

Now I would like to call this from a numba jitted function doing heavy numerical work, so I tried to make this a numba jitclass. Since entries are heterogeneous list in the vanilla python implementation, I figured I should implement other jitclasses as well. However, I am getting a Failed in nopython mode pipeline (step: nopython frontend) (full trace below).

Here is my attempt:

@jitclass
class Item:
    i: int
    j: int

    def __init__(self, i, j):
        self.i = i
        self.j = j


@jitclass
class Entry:
    priority: float
    count: int
    item: Item
    removed: bool

    def __init__(self, p: float, c: int, i: Item):
        self.priority = p
        self.count = c
        self.item = i
        self.removed = False


@jitclass
class PriorityQueue:
    pq: List[Entry]
    entry_finder: Dict[Item, Entry]
    counter: int

    def __init__(self):
        self.pq = nb.typed.List.empty_list(Entry(0.0, 0, Item(0, 0)))
        self.entry_finder = nb.typed.Dict.empty(Item(0, 0), Entry(0, 0, Item(0, 0)))
        self.counter = 0

    def put(self, item: Item, priority: float = 0.0):
        """Add a new item or update the priority of an existing item"""
        if item in self.entry_finder:
            self.remove_item(item)
        self.counter += 1
        entry = Entry(priority, self.counter, item)
        self.entry_finder[item] = entry
        heappush(self.pq, entry)

    def remove_item(self, item: Item):
        """Mark an existing item as REMOVED.  Raise KeyError if not found."""
        entry = self.entry_finder.pop(item)
        entry.removed = True

    def pop(self):
        """Remove and return the lowest priority item. Raise KeyError if empty."""
        while self.pq:
            priority, count, item = heappop(self.pq)
            entry = heappop(self.pq)
            if not entry.removed:
                del self.entry_finder[entry.item]
                return item
        raise KeyError("pop from an empty priority queue")


if __name__ == "__main__":
    queue1 = PurePythonPriorityQueue()
    queue1.put((4, 5), 5.4)
    queue1.put((5, 6), 1.0)
    print(queue1.pop())  # Yay this works!

    queue2 = PriorityQueue()  # Nope
    queue2.put(Item(4, 5), 5.4)
    queue2.put(Item(5, 6), 1.0)
    print(queue2.pop())

Is this type of data structure implementable with numba? What is wrong with my current implementation?

Full trace:

(5, 6)
Traceback (most recent call last):
  File "/home/nicoco/src/work/work-research/scripts/thickness/priorityqueue.py", line 106, in <module>
    queue2 = PriorityQueue()  # Nope
  File "/home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/experimental/jitclass/base.py", line 122, in __call__
    return cls._ctor(*bind.args[1:], **bind.kwargs)
  File "/home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/core/dispatcher.py", line 420, in _compile_for_args
    error_rewrite(e, 'typing')
  File "/home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/core/dispatcher.py", line 361, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
- Resolution failure for literal arguments:
No implementation of function Function(<function typeddict_empty at 0x7fead8c3f8b0>) found for signature:

 >>> typeddict_empty(typeref[<class 'numba.core.types.containers.DictType'>], instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>)

There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload in function 'typeddict_empty': File: numba/typed/typeddict.py: Line 213.
    With argument(s): '(typeref[<class 'numba.core.types.containers.DictType'>], instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>)':
   Rejected as the implementation raised a specific error:
     TypingError: Failed in nopython mode pipeline (step: nopython frontend)
   No implementation of function Function(<function new_dict at 0x7fead9002a60>) found for signature:

    >>> new_dict(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>)

   There are 2 candidate implementations:
         - Of which 2 did not match due to:
         Overload in function 'impl_new_dict': File: numba/typed/dictobject.py: Line 639.
           With argument(s): '(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>)':
          Rejected as the implementation raised a specific error:
            TypingError: Failed in nopython mode pipeline (step: nopython mode backend)
          No implementation of function Function(<built-in function eq>) found for signature:

           >>> eq(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>)

          There are 30 candidate implementations:
                - Of which 28 did not match due to:
                Overload of function 'eq': File: <numerous>: Line N/A.
                  With argument(s): '(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>)':
                 No match.
                - Of which 2 did not match due to:
                Operator Overload in function 'eq': File: unknown: Line unknown.
                  With argument(s): '(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>)':
                 No match for registered cases:
                  * (bool, bool) -> bool
                  * (int8, int8) -> bool
                  * (int16, int16) -> bool
                  * (int32, int32) -> bool
                  * (int64, int64) -> bool
                  * (uint8, uint8) -> bool
                  * (uint16, uint16) -> bool
                  * (uint32, uint32) -> bool
                  * (uint64, uint64) -> bool
                  * (float32, float32) -> bool
                  * (float64, float64) -> bool
                  * (complex64, complex64) -> bool
                  * (complex128, complex128) -> bool

          During: lowering "$20call_function.8 = call $12load_global.4(dp, $16load_deref.6, $18load_deref.7, func=$12load_global.4, args=[Var(dp, dictobject.py:653), Var($16load_deref.6, dictobject.py:654), Var($18load_deref.7, dictobject.py:654)], kws=(), vararg=None)" at /home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/typed/dictobject.py (654)
     raised from /home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/core/types/functions.py:229

   During: resolving callee type: Function(<function new_dict at 0x7fead9002a60>)
   During: typing of call at /home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/typed/typeddict.py (219)


   File "../../../../../.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/typed/typeddict.py", line 219:
       def impl(cls, key_type, value_type):
           return dictobject.new_dict(key_type, value_type)
           ^

  raised from /home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/core/typeinfer.py:1071

- Resolution failure for non-literal arguments:
None

During: resolving callee type: BoundFunction((<class 'numba.core.types.abstract.TypeRef'>, 'empty') for typeref[<class 'numba.core.types.containers.DictType'>])
During: typing of call at /home/nicoco/src/work/work-research/scripts/thickness/priorityqueue.py (72)


File "priorityqueue.py", line 72:
    def __init__(self):
        <source elided>
        self.pq = nb.typed.List.empty_list(Entry(0.0, 0, Item(0, 0)))
        self.entry_finder = nb.typed.Dict.empty(Item(0, 0), Entry(0, 0, Item(0, 0)))
        ^

During: resolving callee type: jitclass.PriorityQueue#7fead8ba2b20<pq:ListType[instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>],entry_finder:DictType[instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>]<iv=None>,counter:int64>
During: typing of call at <string> (3)

During: resolving callee type: jitclass.PriorityQueue#7fead8ba2b20<pq:ListType[instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>],entry_finder:DictType[instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>]<iv=None>,counter:int64>
During: typing of call at <string> (3)


File "<string>", line 3:
<source missing, REPL/exec in use?>


Process finished with exit code 1


Solution 1:[1]

I was having a similar issue related to the custom class Entry. Basically Numba is unable to use __lt__(self, other) to compare entries, and gave me an No implementation of function Function(< built-in function lt >) error.

So I came up with the following. It works on Numba 0.55.1 on Python 3.8 on Ubuntu 18.04. The trick is to avoid using any custom class object as part of your priority queue item to avoid the aforementioned error.

from typing import List, Dict, Tuple 
from heapq import heappush, heappop
import numba as nb
from numba.experimental import jitclass

# priority, counter, item, removed
entry_def = (0.0, 0, (0,0), nb.typed.List([False]))
entry_type = nb.typeof(entry_def)

@jitclass
class PriorityQueue:
    # The following helps numba infer type of variable
    pq: List[entry_type]
    entry_finder: Dict[Tuple[int, int], entry_type]
    counter: int
    entry: entry_type

    def __init__(self):
        # Must declare types here see https://numba.pydata.org/numba-doc/dev/reference/pysupported.html
        self.pq = nb.typed.List.empty_list((0.0, 0, (0,0), nb.typed.List([False])))
        self.entry_finder = nb.typed.Dict.empty( (0, 0), (0.0, 0, (0,0), nb.typed.List([False])))
        self.counter = 0

    def put(self, item: Tuple[int, int], priority: float = 0.0):
        """Add a new item or update the priority of an existing item"""
        if item in self.entry_finder:
            # Mark duplicate item for deletion
            self.remove_item(item)
    
        self.counter += 1
        entry = (priority, self.counter, item, nb.typed.List([False]))
        self.entry_finder[item] = entry
        heappush(self.pq, entry)

    def remove_item(self, item: Tuple[int, int]):
        """Mark an existing item as REMOVED via True.  Raise KeyError if not found."""
        self.entry = self.entry_finder.pop(item)
        self.entry[3][0] = True
    
    def pop(self):
        """Remove and return the lowest priority item. Raise KeyError if empty."""
        while self.pq:
            priority, count, item, removed = heappop(self.pq)
            if not removed[0]:
                del self.entry_finder[item]
                return priority, item
        raise KeyError("pop from an empty priority queue")

First define a global variable called entry_def which will serve as entries in our priority queue pq. The "removed" sentinel is now replaced with an numba.typed.List([False]) to serve as a way of keeping track of which item to delete in case of priority key changes (lazy deletion). The annoying part is having to type out the definition of pq and entry_finder; I couldn't reuse the entry_def variable.

I can confirm the PriorityQueue works as follows:

    q = PriorityQueue()
    q.put((1,1), 5.0)
    q.put((1,1), 4.0)
    q.put((1,1), 3.0)
    q.put((1,1), 6.0)
    print(q.pq)
    >>  [(3.0, 3, (1, 1), ListType[bool]([True])), (5.0, 1, (1, 1), ListType[bool]([True])), (4.0, 2, (1, 1), ListType[bool]([True])), (6.0, 4, (1, 1), ListType[bool]([False]))]
    print(q.pop())
    >> (6.0, (1, 1))
    print(len(q.entry_finder))
    >> 0

Hopefully someone will find this useful or can provide a better alternative.

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1