How to sort a Python list by multiple keys efficiently?

When Python sorts a list by multiple keys, does it evaluate the second key even when it isn't needed? (It does evaluate the second key, for every element.) If so, is there a way to stop it from evaluating the second key when the first key already decides the order? Is there a module that can do this easily without extra code?

import random
import itertools
alist=[random.randint(0,10000000) for i in range(10000)]
def cheap(x):
    return x%100000
    
def expensive(x):
    def primes():
        D = {}
        yield 2
        for q in itertools.count(3, 2):
            p = D.pop(q, None)
            if p is None:
                yield q
                D[q*q] = q
            else:
                x = p + q
                while x in D or x % 2 == 0:
                    x += p
                D[x] = p
    
    def nth_prime(n):
        if n < 1:
            raise ValueError("n must be >= 1 for nth_prime")
        for i, p in enumerate(primes(), 1):
            if i == n:
                return p
    return nth_prime(x%99999+1)

alist.sort(key=lambda x: (cheap(x),expensive(x)))
print(alist)
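To check the parenthetical claim above, here is a small sketch (not part of the original question) that counts calls. `list.sort` computes the key once per element, and because the tuple is built eagerly, the second key runs for every element even when all first-key values are distinct. The names `cheap_demo`, `expensive_demo` and the `calls` counter are hypothetical stand-ins for the question's functions.

```python
from collections import Counter

calls = Counter()  # hypothetical counter: number of calls per key function

def cheap_demo(x):
    calls["cheap"] += 1
    return x % 10

def expensive_demo(x):
    calls["expensive"] += 1
    return -x  # stand-in for a costly computation

data = list(range(10))  # every cheap key is distinct, so no tie-breaking is needed
data.sort(key=lambda x: (cheap_demo(x), expensive_demo(x)))

# The key is evaluated once per element, and both parts of the tuple are computed.
print(calls["cheap"], calls["expensive"])  # 10 10
```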


Solution 1:[1]

As you've noticed, putting the expensive call in the lambda you pass as the key function makes the sort run the expensive calculation eagerly, once for every value. If that's undesirable, you might need to write your own object to be returned by the key function, one that computes values lazily, only when they're needed. Most values won't need the expensive key at all, since their cheap key value will usually be unique. As long as you cache the result of each call, performance shouldn't suffer too badly (probably much less than running the expensive computation for every value).

Here's how I'd do it. Note that the top-level interface is a class-factory function.

def lazy_keys(*keyfuncs):
    class LazyKeyList:
        def __init__(self, value):
            self.value = value
            self.cache = {}           # maps from keyfunc to keyfunc(value)

        def __iter__(self):           # lazily produces values as needed
            for keyfunc in keyfuncs:
                if keyfunc not in self.cache:
                    self.cache[keyfunc] = keyfunc(self.value)
                yield self.cache[keyfunc]

        def __eq__(self, other):
            for x, y in zip(self, other):
                if x != y:
                    return False
            return True

        def __lt__(self, other):
            for x, y in zip(self, other):
                if x < y:
                    return True
                if x > y:
                    return False
            return False

    return LazyKeyList

Now your sort would be:

alist.sort(key=lazy_keys(cheap, expensive))
print(alist)

Here's a smaller, simpler example with a fast and a slow key function, showing that the slow one only runs when necessary, for values whose fast key results match:

import random
from time import sleep

def fast(value):
    return value % 10

def slow(value):
    print("slow", value)
    sleep(1)
    return value

x = [random.randrange(20) for _ in range(20)]
print(x)
print(sorted(x, key=lazy_keys(fast, slow)))

Sample output (it varies between runs, since the input list is random):

[6, 3, 7, 3, 2, 11, 6, 8, 15, 10, 12, 16, 2, 7, 19, 4, 5, 7, 2, 17]
slow 3
slow 3
slow 6
slow 6
slow 12
slow 2
slow 16
slow 2
slow 7
slow 7
slow 5
slow 15
slow 7
slow 2
slow 17
[10, 11, 2, 2, 2, 12, 3, 3, 4, 5, 15, 6, 6, 16, 7, 7, 7, 17, 8, 19]
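The same lazy-and-cached idea can also be sketched with only the standard library. This is a swapped-in technique, not Blckknght's code: `functools.cmp_to_key` turns a two-argument comparison function into a key, so the comparison can consult the slow key only when the fast keys tie, and `functools.lru_cache` makes each distinct value's slow key computed at most once (assuming the values are hashable and the key is pure). `compare`, `slow_calls` and the stand-in `fast`/`slow` are hypothetical.

```python
from functools import cmp_to_key, lru_cache

slow_calls = []  # hypothetical log of values whose slow key was really computed

def fast(value):
    return value % 10

@lru_cache(maxsize=None)  # each distinct value is computed at most once
def slow(value):
    slow_calls.append(value)
    return value

def compare(a, b):
    # Compare on the cheap key first; only consult the slow key on a tie.
    fa, fb = fast(a), fast(b)
    if fa != fb:
        return -1 if fa < fb else 1
    sa, sb = slow(a), slow(b)
    return (sa > sb) - (sa < sb)

x = [6, 3, 7, 3, 2, 11, 6, 8, 15, 10, 12, 16, 2, 7, 19, 4, 5, 7, 2, 17]
print(sorted(x, key=cmp_to_key(compare)))
# [10, 11, 2, 2, 2, 12, 3, 3, 4, 5, 15, 6, 6, 16, 7, 7, 7, 17, 8, 19]
```

Because the cache is keyed by value, duplicates such as the three 2s share one slow computation, whereas the per-object cache in `lazy_keys` computes once per element.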

Solution 2:[2]

Option 1: Separate sorts

You could sort and group by cheap, then sort each group of more than one element by expensive:

alist.sort(key=cheap)
result = []
for _, [*g] in itertools.groupby(alist, cheap):
    if len(g) > 1:
        g.sort(key=expensive)
    result += g
print(result)
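Packaged as a reusable function, the separate-sorts approach might look like the sketch below (the name `sort_by_two_keys` and the logging stand-ins are hypothetical, not from the answer). It relies on `sorted` being stable and on `itertools.groupby` grouping consecutive equal keys; the log shows that the expensive key only runs for elements whose cheap key is shared.

```python
import itertools

def sort_by_two_keys(items, cheap, expensive):
    """Sort by cheap, breaking ties with expensive; expensive only runs on ties."""
    ordered = sorted(items, key=cheap)  # stable, so equal cheap keys keep input order
    result = []
    for _, group in itertools.groupby(ordered, key=cheap):
        group = list(group)
        if len(group) > 1:  # singletons need no tie-breaking
            group.sort(key=expensive)
        result.extend(group)
    return result

expensive_calls = []  # hypothetical log of expensive-key evaluations

def last_digit(x):
    return x % 10

def negated(x):  # stand-in for an expensive key
    expensive_calls.append(x)
    return -x

result = sort_by_two_keys([21, 3, 11, 7, 1], last_digit, negated)
print(result)           # [21, 11, 1, 3, 7]
print(expensive_calls)  # only the three values ending in 1
```

One trade-off of this design is that the cheap key is evaluated twice per element (once for the sort, once for `groupby`), which is harmless as long as it really is cheap.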

Option 2: Decorator

I like my above solution best: it's simple, I think it's fast, and it uses little memory. But here's another one: a decorator that can be applied to the expensive/slow function to make it lazy and caching. Instead of computing the key value right away, the decorated key function returns a proxy object, which only computes the real key value if it ever gets compared and stores the computed value for later comparisons. Full demo, with parts borrowed from Blckknght:

from time import sleep
import random

def lazy(keyfunc):
    def lazied(x):
        class Lazy:
            def __lt__(self, other):
                return self.y() < other.y()
            def y(self):
                y = keyfunc(x)
                self.y = lambda: y
                return y
        return Lazy()
    return lazied

def fast(value):
    return value

@lazy
def slow(value):
    print("slow", value)
    sleep(1)
    return value

random.seed(42)
x = [random.randrange(50) for _ in range(20)]
print(x)
print(sorted(x, key=lambda x: (fast(x), slow(x))))

Output:

[40, 7, 1, 47, 17, 15, 14, 8, 47, 6, 43, 47, 34, 5, 37, 27, 2, 1, 5, 13]
slow 47
slow 47
slow 47
slow 1
slow 1
slow 5
slow 5
[1, 1, 2, 5, 5, 6, 7, 8, 13, 14, 15, 17, 27, 34, 37, 40, 43, 47, 47, 47]

Note that 47 appears three times in the input, so each of those three causes an expensive calculation the first time it gets compared. The same goes for 1 and 5. The other values appear only once and thus never cause an expensive calculation.
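The proxy's compute-once trick (replacing `self.y` with a lambda) can also be written with `functools.cached_property` (Python 3.8+), which stores the computed value on the instance after first access. The sketch below is a hypothetical variant of the `lazy` decorator, not the answerer's code; it defines the proxy class once instead of once per call. Like the original, it only defines `__lt__`: two distinct proxies are never `==` under default identity equality, so tuple comparison falls through to `__lt__` whenever the fast keys tie.

```python
from functools import cached_property

class LazyKey:
    """Proxy that computes keyfunc(value) only when first compared, then caches it."""

    def __init__(self, keyfunc, value):
        self._keyfunc = keyfunc
        self._value = value

    @cached_property
    def key(self):
        # Runs at most once per proxy; the result is stored on the instance.
        return self._keyfunc(self._value)

    def __lt__(self, other):
        return self.key < other.key

def lazy(keyfunc):
    """Decorator: make keyfunc return a LazyKey proxy instead of the real key."""
    def lazied(value):
        return LazyKey(keyfunc, value)
    return lazied

computed = []  # hypothetical log of values whose slow key was actually computed

def fast(value):
    return value

@lazy
def slow(value):
    computed.append(value)
    return value

data = [40, 7, 1, 47, 17, 15, 14, 8, 47, 6, 43, 47, 34, 5, 37, 27, 2, 1, 5, 13]
result = sorted(data, key=lambda v: (fast(v), slow(v)))
print(result)
print(computed)  # only duplicated values (here 1, 5 and 47) can appear
```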

Solution 3:[3]

It does run the second function. One way around this is to sort by the first key, and then by the second key only within groups that tie:

values = sorted(set(map(cheap, alist)))  # distinct cheap keys, in order
newlist = [[y for y in alist if cheap(y) == x] for x in values]  # one group per cheap key

I'm not sure where to go past this point; I really just wanted to open a discussion.
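One way to finish this grouping idea is sketched below; it is an assumption about where the answer was heading, not the poster's code. It builds the groups in a single pass with a dictionary keyed by the cheap value (instead of rescanning the list once per key), then emits the groups in cheap-key order, sorting a group by the expensive key only when it holds more than one element. `group_then_sort` is a hypothetical name.

```python
from collections import defaultdict

def group_then_sort(items, cheap, expensive):
    groups = defaultdict(list)  # cheap key -> items sharing that key, in input order
    for item in items:
        groups[cheap(item)].append(item)

    result = []
    for key in sorted(groups):  # visit groups in cheap-key order
        group = groups[key]
        if len(group) > 1:  # the expensive key is only needed to break ties
            group.sort(key=expensive)
        result.extend(group)
    return result

print(group_then_sort([14, 5, 24, 9, 4], lambda x: x % 10, lambda x: x))
# [4, 14, 24, 5, 9]
```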

Solution 4:[4]

You can subclass int and implement a new comparison method:

class Comparer(int):
    def __lt__(self, other):
        if not isinstance(other, Comparer):
            return NotImplemented

        diff = cheap(self) - cheap(other)
        if diff < 0:
            return True
        elif diff > 0:
            return False
        else:
            return expensive(self) < expensive(other)

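Test (note that the first timing covers a single sort, the second 10,000 sorts):

>>> from timeit import timeit
>>> lst = [random.randint(0, 10000000) for i in range(100)]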
>>> timeit(lambda: sorted(lst, key=lambda x: (cheap(x), expensive(x))), number=1)
13.85503659999813
>>> timeit(lambda: sorted(lst, key=Comparer), number=10000)
1.5208626000094227
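Note that `Comparer.__lt__` calls `expensive` again every time two tied elements are compared, so an element involved in several tie-breaking comparisons pays the cost several times. A minimal sketch of one fix, assuming the key functions are pure and their arguments hashable, is to memoize the expensive key with `functools.lru_cache`; the stand-in `cheap`, `expensive` and the `expensive_calls` log below are hypothetical.

```python
from functools import lru_cache

expensive_calls = []  # hypothetical log of real (uncached) expensive evaluations

def cheap(x):
    return x % 10

@lru_cache(maxsize=None)  # memoize: each distinct value is computed at most once
def expensive(x):
    expensive_calls.append(x)
    return -x  # stand-in for a costly computation

class Comparer(int):
    def __lt__(self, other):
        if not isinstance(other, Comparer):
            return NotImplemented
        diff = cheap(self) - cheap(other)
        if diff != 0:
            return diff < 0
        # Tie on the cheap key: fall back to the (cached) expensive key.
        return expensive(self) < expensive(other)

data = [13, 3, 23, 3, 8, 18, 1]
result = sorted(data, key=Comparer)
print(result)  # [1, 23, 13, 3, 3, 18, 8]
```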

More general approach:

def chain_key(*keys):
    class Comparer(int):
        def __lt__(self, other):
            for key in keys:
                k1, k2 = key(self), key(other)
                if k1 < k2:
                    return True
                elif k1 > k2:
                    return False
            return False
    return Comparer

Test:

>>> timeit(lambda: sorted(lst, key=chain_key(cheap, expensive)), number=10000)
1.583277800003998

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution sources:
Solution 1: Blckknght
Solution 2
Solution 3: Emmanuel Lopez
Solution 4