
Better Numba calculation of inter-particle distance matrices

Recently, I've been looking for efficient ways to compute a distance matrix in Python. As practice, I'm deliberately implementing a naive n-body simulation and then hunting for optimized ways of calculating the inter-particle distances it needs. Let's do that using Numba.

As usual, we're going to be using the standard Python scientific stack... and we'll also use Numba, transitioning onto the GPU next week. Let's get those imports prepped:

In [1]:
import numpy as np
import scipy, scipy.spatial
import numba
import sys
np.__version__, scipy.__version__, numba.__version__, sys.version
from numpy.testing import assert_allclose

Let's get ourselves some sample 3D position data for ten thousand particles:

In [2]:
N = int(1e4)
np.random.seed(743)
r = np.random.random(size=(N, 3))
r
Out[2]:
array([[0.83244056, 0.94442527, 0.57451672],
       [0.09049263, 0.08428888, 0.43300003],
       [0.29973189, 0.11463598, 0.27817412],
       ...,
       [0.49628111, 0.1462252 , 0.18381982],
       [0.80535628, 0.07900376, 0.19831322],
       [0.75236151, 0.02655101, 0.54791037]])

Part I: CPU distance matrix calculations

Let's start out by following up on the 2013 results of Jake Vanderplas:

Direct numpy summation

This is the classic broadcasting approach, but it has a major flaw: it allocates a large (N, N, 3) temporary array along the way, and that costs both memory and time.

In [3]:
def pairwise_numpy(X):
    """
    Reproduced from https://jakevdp.github.io/blog/2013/06/15/numba-vs-cython-take-2/
    """
    return np.sqrt(((X[:, None, :] - X) ** 2).sum(-1))
pairwise_numpy_timing = %timeit -o pairwise_numpy(r)
pairwise_numpy_result = pairwise_numpy(r)
5.02 s ± 43.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

It's nice to have it for comparison, though.
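
As an aside of mine (not benchmarked in the original post), SciPy already ships a compiled routine for exactly this job, scipy.spatial.distance.cdist; it makes for a handy extra correctness check on the results below:

from scipy.spatial.distance import cdist

# Full N x N Euclidean distance matrix, computed in SciPy's compiled C code.
scipy_cdist_result = cdist(r, r)
assert_allclose(pairwise_numpy_result, scipy_cdist_result)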

Direct (slow) Python loop

We'll now switch over to doing things Numba-style. This means we'll use math instead of numpy, so that the $\sqrt{x}$ we'll be doing is explicitly a scalar operation.

In [4]:
import math
def scalar_distance(r, output):
    N, M = r.shape
    for i in range(N):
        for j in range(N):
            tmp = 0.0
            for k in range(M):
                tmp += (r[i, k] - r[j, k])**2
            output[i,j] = math.sqrt(tmp)
output = np.zeros((N, N), dtype=float)
In [5]:
# warning: LONG
direct_summation_timeit = %timeit -o -n1 -r1 scalar_distance(r, output)

# sanity check!
assert_allclose(pairwise_numpy_result, output)

print(f"The direct summation implementation is {direct_summation_timeit.average / pairwise_numpy_timing.average:.2f} slower than NumPy.")
4min 6s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
The direct summation implementation is 49.11 times slower than NumPy.

And now, let's simply wrap this in numba.njit.

Note that the call below is equivalent to the decorator form

@numba.njit
def scalar_distance(...):
    ...
In [6]:
numba_jit_scalar_distance = numba.njit(scalar_distance)
numba_jit_timing = %timeit -o numba_jit_scalar_distance(r, output)

assert_allclose(pairwise_numpy_result, output)

print(f"Our Numba implementation is {pairwise_numpy_timing.average/numba_jit_timing.average:.2f} times faster than NumPy!")
408 ms ± 16.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Our Numba implementation is 12.31 times faster than NumPy!
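
One small benchmarking aside of mine (not from the original post): the very first call to a freshly jitted function also pays the LLVM compilation cost. %timeit's calibration pass generally absorbs that here, but you can make it explicit by warming the function up on a tiny input before timing:

# Compile scalar_distance for 2D float64 arrays ahead of the real benchmark;
# later calls with the same argument types reuse the cached machine code.
numba_jit_scalar_distance(r[:10], np.zeros((10, 10)))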

A 12x speedup over NumPy is not bad at all! But we can still go faster by replacing range with numba.prange, which tells Numba that "yes, this loop is trivially parallelizable". To enable it, we also pass the parallel=True flag to njit:

Optimal numba solution

In [7]:
@numba.njit(parallel=True)
def numba_jit_scalar_distance_parallel(r, output):
    N, M = r.shape
    for i in numba.prange(N):
        for j in numba.prange(N):
            tmp = 0.0
            for k in range(M):
                tmp += (r[i, k] - r[j, k])**2
            output[i,j] = math.sqrt(tmp)

numba_jit_parallel_timing = %timeit -o numba_jit_scalar_distance_parallel(r, output)

assert_allclose(pairwise_numpy_result, output)

print(f"Using `parallel=True` grants us a further {numba_jit_timing.average/numba_jit_parallel_timing.average:.2f}x speedup.")
105 ms ± 5.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Using `parallel=True` grants us a further 3.90x speedup.
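
How large that factor is depends on how many threads Numba spreads the prange over; to check on your own machine (my aside), the default thread count is exposed in Numba's config:

# Numba parallelizes prange loops across this many threads by default
# (normally the number of CPU cores it detects).
print(numba.config.NUMBA_NUM_THREADS)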

Note that I've got four cores on this laptop and this problem really is trivially parallelizable, so a ~4x speedup is about what we'd expect. Another nice property of numba.prange is that it behaves like a plain range when the function isn't compiled by Numba:

In [8]:
def scalar_distance_prange(r, output):
    N, M = r.shape
    for i in numba.prange(N):
        for j in numba.prange(N):
            tmp = 0.0
            for k in range(M):
                tmp += (r[i, k] - r[j, k])**2
            output[i,j] = math.sqrt(tmp)

direct_summation_prange_timeit = %timeit -o -n1 -r1 scalar_distance_prange(r, output)
assert_allclose(pairwise_numpy_result, output)
print(f"{direct_summation_prange_timeit.average:.5f}s vs {direct_summation_timeit.average:.5f}s.")
4min 2s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
242.71444s vs 246.70353s.

In other words, prange is something you can throw in essentially "for free": the plain-Python version stays just as easy to debug, and once you turn on parallel=True, the speedup kicks in.
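
Before we turn to the GPU, one more CPU-side idea worth sketching (this variant is my addition and isn't benchmarked here): the distance matrix is symmetric, so we only need to compute the upper triangle and mirror it, roughly halving the arithmetic:

@numba.njit(parallel=True)
def numba_jit_scalar_distance_symmetric(r, output):
    N, M = r.shape
    for i in numba.prange(N):
        # Only the j >= i half is computed; the lower triangle is mirrored.
        for j in range(i, N):
            tmp = 0.0
            for k in range(M):
                tmp += (r[i, k] - r[j, k])**2
            d = math.sqrt(tmp)
            output[i, j] = d
            output[j, i] = d

Since the outer iterations now do unequal amounts of work, the wall-clock gain under prange may well be less than the factor of two the operation count suggests.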

But suppose we wanted this to run really fast. The next step is to turn to the GPU, and that's exactly what we'll be doing next week!
