# Better Numba calculation of inter-particle distance matrices

Recently, I've been looking for efficient ways to compute a distance matrix in Python. I'm deliberately implementing a naive n-body simulation, as practice, to find optimized ways of calculating inter-particle distances. Let's do that using Numba.

As usual, we're going to be using the standard Python scientific stack... and we'll also use Numba, transitioning onto the GPU next week. Let's get the imports prepped:

```
import sys
import numpy as np
import scipy, scipy.spatial
import numba
from numpy.testing import assert_allclose

np.__version__, scipy.__version__, numba.__version__, sys.version
```

Let's get ourselves some sample 3D position data, for ten thousand particles:

```
N = int(1e4)
np.random.seed(743)
r = np.random.random(size=(N, 3))
r
```

## Part I: CPU distance matrix calculations

Let's start out by following up on the 2013 results of Jake Vanderplas:

### Direct numpy summation

This is the classic approach, but it has a major flaw: it allocates a lot of large temporary arrays along the way, and that takes a while.

```
def pairwise_numpy(X):
    """
    Reproduced from https://jakevdp.github.io/blog/2013/06/15/numba-vs-cython-take-2/
    """
    return np.sqrt(((X[:, None, :] - X) ** 2).sum(-1))

pairwise_numpy_timing = %timeit -o pairwise_numpy(r)
pairwise_numpy_result = pairwise_numpy(r)
```

It's nice to have it for comparison, though.
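Since we've already imported `scipy.spatial`, we can also sanity-check the broadcasting approach against SciPy's `cdist`, which builds the same Euclidean distance matrix in compiled C. A quick sketch, on a smaller array so the temporaries stay cheap:

```python
import numpy as np
import scipy.spatial

rng = np.random.RandomState(743)
X = rng.random_sample((1000, 3))

# Broadcasting version, identical in spirit to pairwise_numpy above
d_numpy = np.sqrt(((X[:, None, :] - X) ** 2).sum(-1))

# SciPy's compiled Euclidean distance matrix
d_scipy = scipy.spatial.distance.cdist(X, X)

np.testing.assert_allclose(d_numpy, d_scipy)
```

`cdist` is a useful reference point: any hand-rolled implementation below should agree with it to floating-point tolerance.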

### Direct (slow) Python loop

We'll now switch over to doing things Numba-style. This means that we'll use `math` instead of `numpy`, so that the $\sqrt{x}$ we'll be doing is explicitly a scalar operation.

```
import math

def scalar_distance(r, output):
    N, M = r.shape
    for i in range(N):
        for j in range(N):
            tmp = 0.0
            for k in range(M):
                tmp += (r[i, k] - r[j, k])**2
            output[i, j] = math.sqrt(tmp)

output = np.zeros((N, N), dtype=float)
```

```
# warning: LONG
direct_summation_timeit = %timeit -o -n1 -r1 scalar_distance(r, output)
# sanity check!
assert_allclose(pairwise_numpy_result, output)
print(f"The direct summation implementation is {direct_summation_timeit.average / pairwise_numpy_timing.average:.2f} times slower than NumPy.")
```

And now, let's simply wrap this in `numba.njit`.

Note that the below is equivalent to

```
@numba.njit
def scalar_distance(...):
    ...
```

```
numba_jit_scalar_distance = numba.njit(scalar_distance)
numba_jit_timing = %timeit -o numba_jit_scalar_distance(r, output)
assert_allclose(pairwise_numpy_result, output)
print(f"Our Numba implementation is {pairwise_numpy_timing.average/numba_jit_timing.average:.2f} times faster than NumPy!")
```

Not bad! But we can still get speedups by replacing `range` with `numba.prange`, which tells Numba that "yes, this loop is **trivially** parallelizable". To do so, we use the `parallel=True` flag to `njit`:

### Optimal numba solution

```
@numba.njit(parallel=True)
def numba_jit_scalar_distance_parallel(r, output):
    N, M = r.shape
    for i in numba.prange(N):
        for j in numba.prange(N):
            tmp = 0.0
            for k in range(M):
                tmp += (r[i, k] - r[j, k])**2
            output[i, j] = math.sqrt(tmp)

numba_jit_parallel_timing = %timeit -o numba_jit_scalar_distance_parallel(r, output)
assert_allclose(pairwise_numpy_result, output)
print(f"Using `parallel=True` grants us a further {numba_jit_timing.average/numba_jit_parallel_timing.average:.2f}x speedup.")
```

Note that I've got four cores on this laptop, so this problem is truly trivially parallelizable. This is nice because `numba.prange` is actually a no-op when not used from within `numba`:

```
def scalar_distance_prange(r, output):
    N, M = r.shape
    for i in numba.prange(N):
        for j in numba.prange(N):
            tmp = 0.0
            for k in range(M):
                tmp += (r[i, k] - r[j, k])**2
            output[i, j] = math.sqrt(tmp)

direct_summation_prange_timeit = %timeit -o -n1 -r1 scalar_distance_prange(r, output)
assert_allclose(pairwise_numpy_result, output)
print(f"{direct_summation_prange_timeit.average:.5f}s vs {direct_summation_timeit.average:.5f}s.")
```

It's something you can just throw in "for free": it lets you debug your code just as easily, and once you turn on `parallel=True`, the speedups kick in.

However, suppose we wanted to have this run *really* fast. What we then could do is turn to the GPU. And this is exactly what we'll be doing next week!
