Recently, I’ve been looking for efficient ways to compute a distance matrix in Python. I’m deliberately trying to implement a naive n-body simulation, as practice, to find optimized ways of calculating those. Let’s do that using Numba.
As usual, we’re going to be using the standard Python scientific stack… and we’ll also use Numba, transitioning onto the GPU next week. Let’s get those imports prepped:
import numpy as np
import scipy, scipy.spatial
import numba
import sys
from numpy.testing import assert_allclose

np.__version__, scipy.__version__, numba.__version__, sys.version
Let’s get ourselves some sample 3D position data, for ten thousand particles:
N = int(1e4)
np.random.seed(743)
r = np.random.random(size=(N, 3))
r
array([[0.83244056, 0.94442527, 0.57451672],
[0.09049263, 0.08428888, 0.43300003],
[0.29973189, 0.11463598, 0.27817412],
...,
[0.49628111, 0.1462252 , 0.18381982],
[0.80535628, 0.07900376, 0.19831322],
[0.75236151, 0.02655101, 0.54791037]])
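By the way, since we’ve already imported scipy.spatial, we can keep an independent reference result around for cross-checking. This is my addition, not part of the original benchmarks; scipy.spatial.distance.cdist computes the same (N, N) Euclidean distance matrix:

# Reference result (my addition), handy for sanity-checking the implementations below
reference_result = scipy.spatial.distance.cdist(r, r)
reference_result.shape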
Part I: CPU distance matrix calculations
Let’s start out by following up on the 2013 results of Jake Vanderplas:
Direct numpy summation
This is the classic approach, but it has a major flaw: it allocates a lot of temporary arrays along the way, and that takes a while.
def pairwise_numpy(X):
    """
    Reproduced from https://jakevdp.github.io/blog/2013/06/15/numba-vs-cython-take-2/
    """
    return np.sqrt(((X[:, None, :] - X) ** 2).sum(-1))
pairwise_numpy_timing = %timeit -o pairwise_numpy(r)
pairwise_numpy_result = pairwise_numpy(r)
5.02 s ± 43.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
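Where does the time go? A rough back-of-the-envelope sketch (my addition, assuming float64 throughout): broadcasting X[:, None, :] - X materializes a full (N, N, 3) temporary array before anything gets squared or summed.

# Rough estimate (my addition): the (N, N, 3) float64 temporary alone takes
N * N * 3 * 8 / 1e9  # ≈ 2.4 GB for N = 10,000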
It’s nice to have it for comparison, though.
Direct (slow) Python loop
We’ll now switch over to doing things Numba-style. This means that we’ll use math instead of numpy, so that the \(\sqrt{x}\) we’ll be doing is explicitly a scalar operation.
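As a quick illustration of the difference (my addition): math.sqrt accepts exactly one scalar, while np.sqrt is a ufunc that also handles whole arrays, with dispatch overhead we don’t need inside a tight loop.

import math
math.sqrt(2.0)           # scalar in, scalar out
np.sqrt(np.arange(4.0))  # ufunc: elementwise over a whole array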
import math

def scalar_distance(r, output):
    N, M = r.shape
    for i in range(N):
        for j in range(N):
            tmp = 0.0
            for k in range(M):
                tmp += (r[i, k] - r[j, k])**2
            output[i, j] = math.sqrt(tmp)

output = np.zeros((N, N), dtype=float)
# warning: LONG
direct_summation_timeit = %timeit -o -n1 -r1 scalar_distance(r, output)
# sanity check!
assert_allclose(pairwise_numpy_result, output)
print(f"The direct summation implementation is {direct_summation_timeit.average / pairwise_numpy_timing.average:.2f} slower than NumPy.")
4min 6s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
The direct summation implementation is 49.11 times slower than NumPy.
And now, let’s simply wrap this in numba.njit. Note that the below is equivalent to

@numba.njit
def scalar_distance(...):
    ...

numba_jit_scalar_distance = numba.njit(scalar_distance)
numba_jit_timing = %timeit -o numba_jit_scalar_distance(r, output)
assert_allclose(pairwise_numpy_result, output)
print(f"Our Numba implementation is {pairwise_numpy_timing.average/numba_jit_timing.average:.2f} times faster than NumPy!")
408 ms ± 16.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Our Numba implementation is 12.31 times faster than NumPy!
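One caveat worth flagging (my addition): njit compiles lazily, so the very first call pays the compilation cost. %timeit amortizes that over many runs, but if you want to keep compilation out of the timing entirely, a sketch of eager compilation with an explicit signature:

# Eager compilation (my addition): passing a signature to njit compiles
# immediately, so no call pays the JIT cost at runtime.
numba_jit_eager = numba.njit("void(float64[:, :], float64[:, :])")(scalar_distance)
numba_jit_eager(r, output)  # already compiled, runs at full speed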
Not bad! But we can still get speedups by replacing range with numba.prange, which tells Numba that “yes, this loop is trivially parallelizable”. To do so, we use the parallel=True flag to njit:
Optimal numba solution
@numba.njit(parallel=True)
def numba_jit_scalar_distance_parallel(r, output):
    N, M = r.shape
    for i in numba.prange(N):
        for j in numba.prange(N):
            tmp = 0.0
            for k in range(M):
                tmp += (r[i, k] - r[j, k])**2
            output[i, j] = math.sqrt(tmp)
numba_jit_parallel_timing = %timeit -o numba_jit_scalar_distance_parallel(r, output)
assert_allclose(pairwise_numpy_result, output)
print(f"Using `parallel=True` grants us a further {numba_jit_timing.average/numba_jit_parallel_timing.average:.2f}x speedup.")
105 ms ± 5.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Using `parallel=True` grants us a further 3.90x speedup.
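For reference (my addition), you can check the size of the thread pool Numba’s parallel backend uses, which should line up with the roughly 4x speedup seen here:

# My addition: Numba's thread count defaults to the number of CPU cores
print(numba.config.NUMBA_NUM_THREADS)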
Note that I’ve got four cores on this laptop, so this problem really is trivially parallelizable. This is nice because numba.prange is actually a no-op when you’re not using it from within numba:
def scalar_distance_prange(r, output):
    N, M = r.shape
    for i in numba.prange(N):
        for j in numba.prange(N):
            tmp = 0.0
            for k in range(M):
                tmp += (r[i, k] - r[j, k])**2
            output[i, j] = math.sqrt(tmp)
direct_summation_prange_timeit = %timeit -o -n1 -r1 scalar_distance_prange(r, output)
assert_allclose(pairwise_numpy_result, output)
print(f"{direct_summation_prange_timeit.average:.5f}s vs {direct_summation_timeit.average:.5f}s.")
4min 2s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
242.71444s vs 246.70353s.
It’s something you can just throw in “for free”: it lets you debug your code just as easily, and once you end up turning on parallel=True, the speedups kick in.
However, suppose we wanted this to run really fast. What we could then do is turn to the GPU, and this is exactly what we’ll be doing next week!