Counting the number of leading zero bits in a sha256 hash

I'm having trouble counting the number of leading zero bits in the output of a sha256 hash function, as I don't have a lot of experience with 'low-level' stuff in Python:

import hashlib

hex = hashlib.sha256((some_data_from_file).encode('ascii')).hexdigest()
# hex = 0000094e7cc7303a3e33aaeaba76ad937603d4d040064f473a12f10ab30a879f
# this has 20 leading zero bits
hex_digits = int.from_bytes(bytes(hex.encode('ascii')), 'big') #convert str to int

#count num of leading zeroes 
def countZeros(x):
   total_bits = 256
   res = 0
   while ((x & (1 << (total_bits - 1))) == 0):
       x = (x << 1)
       res += 1
   return res
print(countZeros(hex_digits)) # returns 2

I've also tried converting it using bin(), but that didn't give me any leading zeros.



Solution 1:[1]

Instead of getting the hex digest and analyzing that hex string, you could just get the digest, interpret it as an int, ask for its bit-length, and subtract that from 256:

digest = hashlib.sha256(some_data_from_file.encode('ascii')).digest()
print(256 - int.from_bytes(digest, 'big').bit_length())
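If you already have the hex string, as in the question, the same bit-length idea still works; the key is to parse the string as a base-16 number with int(h, 16). The question's code instead encoded the string as ASCII, so it counted bits of the characters' 8-bit codes rather than of the hash value. A minimal sketch (the helper name leading_zero_bits_from_hex is illustrative, not from the answer):

```python
import hashlib


def leading_zero_bits_from_hex(h: str) -> int:
    """Count leading zero bits of a hex digest string (4 bits per hex digit)."""
    # int(h, 16) parses the hex digits themselves, not their ASCII codes.
    return len(h) * 4 - int(h, 16).bit_length()


h = "0000094e7cc7303a3e33aaeaba76ad937603d4d040064f473a12f10ab30a879f"
print(leading_zero_bits_from_hex(h))  # 20

# Cross-check against the raw-digest approach for another hash.
sha = hashlib.sha256(b"665782")
assert leading_zero_bits_from_hex(sha.hexdigest()) == 256 - int.from_bytes(sha.digest(), "big").bit_length()
```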

Demo:

import hashlib

some_data_from_file = '665782'

# Show hex digest for clarity
hex = hashlib.sha256(some_data_from_file.encode('ascii')).hexdigest()
print(hex)

# Show number of leading zero bits
digest = hashlib.sha256(some_data_from_file.encode('ascii')).digest()
print(256 - int.from_bytes(digest, 'big').bit_length())

Output:

0000000399c6aea5ad0c709a9bc331a3ed6494702bd1d129d8c817a0257a1462
30

Benchmark including Pranav's solution (I'm not sure how to handle mtraceur's), starting from sha256 objects (i.e., before calling hexdigest() or digest()). Times are per hash; each row shows the three fastest of ten rounds:

 462 ns   464 ns   471 ns  just_digest
 510 ns   518 ns   519 ns  just_hexdigest
 566 ns   568 ns   574 ns  Kelly3
 608 ns   608 ns   611 ns  Kelly2
 688 ns   688 ns   692 ns  Kelly
1139 ns  1139 ns  1140 ns  Pranav

Benchmark code:

def Kelly(sha256):
    return 256 - int.from_bytes(sha256.digest(), 'big').bit_length()

def Kelly2(sha256):
    zeros = 0
    for byte in sha256.digest():
        if byte:
            return zeros + 8 - byte.bit_length()
        zeros += 8
    return zeros

def Kelly3(sha256):
    digest = sha256.digest()
    if byte := digest[0]:
        return 8 - byte.bit_length()
    zeros = 0
    for byte in digest:
        if byte:
            return zeros + 8 - byte.bit_length()
        zeros += 8
    return zeros

def Pranav(sha256):
    nzeros = 0
    for c in sha256.hexdigest():
        if c == "0": nzeros += 4
        else: 
            digit = int(c, base=16)
            nzeros += 4 - (math.floor(math.log2(digit)) + 1)
            break
    return nzeros

def just_digest(sha256):
    return sha256.digest()

def just_hexdigest(sha256):
    return sha256.hexdigest()

funcs = just_digest, just_hexdigest, Kelly3, Kelly2, Kelly, Pranav

from timeit import repeat
import hashlib, math
from collections import deque

sha256s = [hashlib.sha256(str(i).encode('ascii'))
           for i in range(10_000)]

expect = list(map(Kelly, sha256s))
for func in funcs:
    result = list(map(func, sha256s))
    print(result == expect, func.__name__)

tss = [[] for _ in funcs]
for _ in range(10):
    print()
    for func, ts in zip(funcs, tss):
        t = min(repeat(lambda: deque(map(func, sha256s), 0), number=1))
        ts.append(t)
    for func, ts in zip(funcs, tss):
        print(*('%4d ns ' % (t / len(sha256s) * 1e9) for t in sorted(ts)[:3]), func.__name__)

Solution 2:[2]

.hexdigest() returns a string, so your hex variable is a string.

I'm going to call it h instead, because hex is a builtin Python function.

So you have:

h = "0000094e7cc7303a3e33aaeaba76ad937603d4d040064f473a12f10ab30a879f"

Now, this is a hexadecimal string. Each hexadecimal digit corresponds to four binary digits (bits). Since this string has five leading zero digits, you already have 5 * 4 = 20 leading zero bits.

nzeros = 0
for c in h:
    if c == "0": nzeros += 4
    else: break

Then you need to count the leading zeros in the binary representation of the first non-zero hexadecimal digit. This is easy to get: a positive integer n has math.floor(math.log2(n)) + 1 binary digits, and since a hexadecimal digit has at most 4 bits, it has 4 - (math.floor(math.log2(n)) + 1) leading zeros. In this case the digit is 9 (1001 in binary), so there are no additional leading zeros.

So, modify the previous loop:

import math

nzeros = 0
for c in h:
    if c == "0": nzeros += 4
    else: 
        digit = int(c, base=16)
        nzeros += 4 - (math.floor(math.log2(digit)) + 1)
        break

print(nzeros) # 20
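The floating-point log2 step can also be replaced by int.bit_length(), which returns the number of significant binary digits of a non-negative integer directly (and 0 for zero), so every digit can be handled by the same formula. A sketch of the same digit-by-digit scan written that way (the function name leading_zero_bits_hex is illustrative):

```python
def leading_zero_bits_hex(h: str) -> int:
    """Count leading zero bits of a hex string, one hex digit (4 bits) at a time."""
    nzeros = 0
    for c in h:
        digit = int(c, 16)
        # A hex digit spans 4 bits; bit_length() says how many of them are significant,
        # so a "0" digit contributes 4 and a "9" (1001) contributes 0.
        nzeros += 4 - digit.bit_length()
        if digit:  # stop at the first non-zero digit
            break
    return nzeros


print(leading_zero_bits_hex("0000094e7cc7303a3e33aaeaba76ad937603d4d040064f473a12f10ab30a879f"))  # 20
```

Besides avoiding the import, this sidesteps floating point entirely, which is a small readability win for such a tiny range of inputs.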

Solution 3:[3]

Danger!!!

Is this security-sensitive code? Can this hash ever be the result of hashing secret/private data?

If so, then you should probably implement something in C or similar, while taking care to protect against leaking information about the hash through side-channels.

Otherwise, I suggest picking the version (from any of these answers) that you and the people working on your code find the most intuitive, clear, and so on, unless performance matters more than readability, in which case pick the fastest one.


If your hashes are never of security-sensitive inputs, then:

If you just want a good balance of simplicity and low effort:

def count_leading_zeroes(value, max_bits=256):
    value &= (1 << max_bits) - 1  # truncate; treat negatives as two's complement
    if value == 0:
        return max_bits
    significant_bits = len(bin(value)) - 2  # has "0b" prefix
    return max_bits - significant_bits
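Since the question mentions that bin() dropped the leading zeros: format() with a zero-padded width spec such as '0256b' keeps them, so you can count the leading '0' characters directly. A simple sketch in the same spirit as the version above (the name count_leading_zeroes_padded is illustrative; like bin(), this is not side-channel-safe):

```python
def count_leading_zeroes_padded(value: int, max_bits: int = 256) -> int:
    value &= (1 << max_bits) - 1  # truncate; treat negatives as two's complement
    bits = format(value, f"0{max_bits}b")  # binary string padded to exactly max_bits digits
    # Number of characters removed by stripping the leading "0"s.
    return len(bits) - len(bits.lstrip("0"))


h = "0000094e7cc7303a3e33aaeaba76ad937603d4d040064f473a12f10ab30a879f"
print(count_leading_zeroes_padded(int(h, 16)))  # 20
```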

If you want to really embrace the bit twiddling you were trying in your question:

def count_leading_zeroes(value, max_bits=256):
    top_bit = 1 << (max_bits - 1)
    count = 0
    value &= (1 << max_bits) - 1
    # stop after max_bits steps; otherwise a value of 0 would loop forever
    while count < max_bits and not value & top_bit:
        count += 1
        value <<= 1
    return count

If you're doing manual bit twiddling, I think a loop that counts from the top is the most justified option here, because hash output is close to uniformly random, so each bit has roughly an even chance of being 1.

So you have a good chance of exiting the loop early, and thus running faster, if you start from the top (if you start from the bottom, you have to check every bit).
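To put a rough number on "exiting early": if each bit is independently 0 or 1 with probability 1/2, the chance of at least k leading zeros is 2^-k, so the expected count is 1/2 + 1/4 + ... ≈ 1, and a top-down loop usually stops after about two checks. A quick empirical sketch over SHA-256 hashes of the strings "0" to "9999" (an arbitrary sample, mirroring the benchmark inputs in the first answer; the helper name is illustrative):

```python
import hashlib


def leading_zero_bits(digest: bytes) -> int:
    # Width in bits minus the number of significant bits.
    return len(digest) * 8 - int.from_bytes(digest, "big").bit_length()


counts = [leading_zero_bits(hashlib.sha256(str(i).encode("ascii")).digest())
          for i in range(10_000)]
mean = sum(counts) / len(counts)
print(f"mean leading zero bits: {mean:.3f}")  # expected to be close to 1
```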

You could alternatively do a bit-twiddling approach inspired by binary search. That way, instead of O(n) steps you do O(log n) steps. However, this arguably isn't an optimization worth doing in CPython, and for a JIT implementation like PyPy this manual optimization can actually make it harder for the JIT to realize that it can just use a native "count leading zeroes" CPU instruction. If I ever get the time, I'll edit this answer with an example of that later.
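A sketch of the binary-search idea (my own illustration, not the answerer's promised example): check whether the top half of the remaining window is all zeros, and if so, count it and shift it out; then halve the window. For 256 bits that is 8 steps (128, 64, ..., 1). It assumes max_bits is a power of two, and its branches depend on the data, so it is not side-channel-safe either:

```python
def count_leading_zeroes_bsearch(value: int, max_bits: int = 256) -> int:
    """Count leading zero bits in a max_bits-wide field by halving the search window.

    Assumes max_bits is a power of two. Not constant-time.
    """
    value &= (1 << max_bits) - 1  # truncate; treat negatives as two's complement
    if value == 0:
        return max_bits
    count = 0
    step = max_bits // 2
    while step:
        # If the top `step` bits are all zero, they are leading zeros. Shifting
        # them out cannot overflow max_bits, precisely because they were zero.
        if value >> (max_bits - step) == 0:
            count += step
            value <<= step
        step //= 2
    return count


h = "0000094e7cc7303a3e33aaeaba76ad937603d4d040064f473a12f10ab30a879f"
print(count_leading_zeroes_bsearch(int(h, 16)))  # 20
```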


Now about those side-channel attacks: basically any time you have code that works on secret data (or on any result derived from secret data that you can't prove, the way a cryptographer would, has irretrievably lost all information about that secret data), you should make sure your code performs exactly the same operations and takes the same branches regardless of the data!

Explaining why you should do this is outside the scope of this answer, but if you don't, your code could be harming users by leaking their secret information in ways that hackers could access.

So!

You might be tempted to modify the simple version that uses bin, but bin is inherently data-dependent: it produces a string whose length depends on the number of leading zeroes, and as far as I know it doesn't (and logically can't, at least not in the general case) guarantee that it does so in constant time without data-dependent branches. So we should assume that merely running bin on an integer leaks information about the integer through side channels like runtime, branch predictor state, amount of memory allocated, and so on.

For illustrative purposes, if we did have a side-channel-safe bin, which I'll call "bin_", we could do:

def count_leading_zeroes(value, max_bits=256):
    value &= (1 << max_bits) - 1  # truncate; treat negatives as two's complement
    value <<= 1  # securely compensate for 0
    significant_bits = len(bin_(value)) - 3  # has "0b" prefix and "0" suffix
    return max_bits - significant_bits
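The arithmetic behind the shift-by-one trick can be checked with the ordinary bin() standing in for the hypothetical bin_. This sketch only verifies that the formula gives the right answers, including for 0; it is not side-channel-safe (the name count_leading_zeroes_shifted is illustrative):

```python
def count_leading_zeroes_shifted(value, max_bits=256):
    # Ordinary bin() replaces the hypothetical bin_, so this is NOT side-channel-safe.
    value &= (1 << max_bits) - 1  # truncate; treat negatives as two's complement
    # After the shift, a non-zero value has bit_length() + 1 binary digits, and 0
    # stays 0, whose bin() is "0b0", also bit_length() + 1 digits. So the same
    # "- 3" works for every value and 0 needs no special case.
    value <<= 1
    significant_bits = len(bin(value)) - 3  # 2 for the "0b" prefix, 1 for the extra digit
    return max_bits - significant_bits


print(count_leading_zeroes_shifted(0))  # 256
print(count_leading_zeroes_shifted(1))  # 255
```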

In principle, a bit-twiddling loop could count leading zero bits in constant time, free of input-dependent branches.

The challenge is writing this neatly in Python, because Python is so abstracted from the underlying machine.

The core problem is this: at the CPU level, it's really easy to take a 1 or 0 bit and turn it, branchlessly, into something more useful, like a mask with all bits set to 1 or all bits set to 0. That mask then lets you conditionally but branchlessly select one of two numbers or clear a number, which you can use to implement something like "if the lowest bit is set, reset the counter to zero". At the Python level, implementing stuff like this is a struggle through a fog of uncertainty about how the Python runtime is implemented: there are many places where it might be reasonable to have data-dependent branches under the covers. So we really want to minimize the number of Python steps and conversions between the digest that hashlib gives us and our leading-zeroes answer.
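A small illustration of that mask trick at the Python level (the function name select is illustrative). Negating a 0/1 bit gives 0 or -1, and in Python -1 acts like an unbounded run of 1 bits under &, so it can pick one of two values without an if. As noted above, this only removes branches from your own Python code; it guarantees nothing about what the interpreter does underneath:

```python
def select(bit: int, if_one: int, if_zero: int) -> int:
    """Return if_one when bit == 1 and if_zero when bit == 0, without an if statement.

    bit must be exactly 0 or 1.
    """
    mask = -bit  # 0 -> 0 (all bits clear), 1 -> -1 (all bits set in two's complement)
    return (if_one & mask) | (if_zero & ~mask)


# "If the lowest bit is set, reset the counter to zero", written branchlessly:
counter = 7
for value in (0b10, 0b11):
    counter = select(value & 1, 0, counter)
    print(counter)  # prints 7, then 0
```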

So the best option is actually to never even reach for human-readable stuff like hex or integer forms of the digest at all! Just stick to the raw digest. Something like this, conceptually:

def count_leading_zeroes_in_bytes(data):
   count = 0
   # branchless "latch" mask to stop counting:
   still_in_leading_zeroes = 1
   for byte in data:
       for index in reversed(range(8)):
           bit = (byte >> index) & 1
           # branchless "conditional" if bit is zero:
           is_zero = bit ^ 1
           # branchlessly increment count if needed:
           count += is_zero & still_in_leading_zeroes
           # branchlessly latch count on first 1 bit:
           still_in_leading_zeroes &= is_zero
   return count

This is the best I was able to think of in pure Python. And it still failed.

But some quick testing by both @KellyBundy and me (see the comments and Kelly's answer for some examples) shows that this version is both extremely slow and does not actually achieve input-independent execution time (because there's yet another relevant data-dependent optimization inside Python, and possibly for other reasons we're missing).

So if you're going to try to implement anything in Python, test it thoroughly before relying on it to actually be secure, or just take the general gist and implement it in C or assembly. Something like this:

/* _clz.c */

#include <limits.h>  /* CHAR_BIT */
#include <stddef.h>  /* size_t */

int count_leading_zeroes_bytes(const unsigned char *bytes, size_t length)
{
    int still_in_leading_zeroes = 1;
    int count = 0;
    while(length--)
    {
        unsigned char byte = *bytes++;  /* unsigned: avoid sign-extension surprises */
        int bits = CHAR_BIT;
        while(bits--)
        {
            int bit = (byte >> bits) & 1;
            int is_zero = bit ^ 1;
            count += is_zero & still_in_leading_zeroes;
            still_in_leading_zeroes &= is_zero;
        }
    }
    return count;
}

# clz.py

import ctypes


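# This is just a quick example. It assumes _clz.c was
# compiled to ./_clz.so, e.g. on Linux with
# `cc -shared -fPIC -o _clz.so _clz.c`. A mature version
# would load the library as appropriate for each
# platform.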
_library = ctypes.CDLL('./_clz.so')
_count_leading_zeroes_bytes = _library.count_leading_zeroes_bytes


def count_leading_zeroes_bytes(data):
    return _count_leading_zeroes_bytes(
        ctypes.c_char_p(data),
        ctypes.c_size_t(len(data)),
    )

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1
Solution 2 Pranav Hosangadi
Solution 3