diff --git a/Readme.md b/Readme.md index 4869aa27..1de18707 100644 --- a/Readme.md +++ b/Readme.md @@ -37,7 +37,8 @@ compiles, the performance is on par or slightly worse than the original NumPy. ## Mandelbrot fractal Results strongly depend on an implementation: a straighforward NumPy implementation -uses a data-dependent loop, which does not compile. +uses complex-valued arrays which are not supported by triton. +Working around this and several other dynamo issues, leads to speedups of about x3 to x5. The implementation based on the [Mojo benchmark](https://shashankprasanna.com/benchmarking-modular-mojo-and-pytorch-torch.compile-on-mandelbrot-function/index.html#benchmarking-pytorch-cpu-with-torchcompile) allows to compile the inner loop. The performance increase relative to numpy is substantial and strongly data size and machine diff --git a/e2e/kmeans/kmeans.py b/e2e/kmeans/kmeans.py index 204d9ce2..a3fbec9f 100644 --- a/e2e/kmeans/kmeans.py +++ b/e2e/kmeans/kmeans.py @@ -7,15 +7,9 @@ cfg.numpy_ndarray_as_tensor = True -# np.linalg.norm replacement (2-norm only), https://github.com/pytorch/pytorch/issues/105269 -def norm(a, axis): - s = (a.conj() * a).real - return np.sqrt(s.sum(axis=axis)) - - -#@torch.compile +# this will be compiled def get_labels(X, centroids) -> np.ndarray: - return np.argmin(norm(X - centroids[:, None], axis=2), + return np.argmin(np.linalg.norm(X - centroids[:, None, :], ord=2, axis=2), axis=0) @@ -31,7 +25,7 @@ def init(npts): import time # ### numpy ### -npts = int(2e7) +npts = int(1e8) X, centroids = init(npts) start_time = time.time() @@ -53,6 +47,8 @@ def init(npts): start_time = time.time() labels = get_labels_c(X, centroids) end_time = time.time() +torch.cuda.synchronize() compiled_time = end_time - start_time print("compiled: elapsed=", compiled_time, ' speedup = ', numpy_time / compiled_time) + diff --git a/e2e/mandelbrot/mandelbrot.png b/e2e/mandelbrot/mandelbrot.png index ef654186..0a5aa621 100644 Binary files a/e2e/mandelbrot/mandelbrot.png and b/e2e/mandelbrot/mandelbrot.png differ diff --git a/e2e/mandelbrot/mandelbrot.py b/e2e/mandelbrot/mandelbrot.py index 9a58aa4f..6264ba2e 100644 --- a/e2e/mandelbrot/mandelbrot.py +++ b/e2e/mandelbrot/mandelbrot.py @@ -3,16 +3,21 @@ # Copyright (2017) Nicolas P. Rougier - BSD license # More information at https://github.com/rougier/numpy-book # ----------------------------------------------------------------------------- -#import numpy as np -import torch_np as np +import math +import numpy as np +import time +# need to import before torch +from matplotlib import colors +import matplotlib.pyplot as plt -# To run on CUDA, change "cpu" to "cuda" below. import torch torch.set_default_device("cpu") +import torch._dynamo.config as cfg +cfg.numpy_ndarray_as_tensor = True -# from mandelbrot_numpy_1 import mandelbrot # copy-paste below +# ### Original NumPy version. ### def mandelbrot(xmin, xmax, ymin, ymax, xn, yn, maxiter, horizon=2.0): # Adapted from https://www.ibm.com/developerworks/community/blogs/jfp/... @@ -30,41 +35,75 @@ def mandelbrot(xmin, xmax, ymin, ymax, xn, yn, maxiter, horizon=2.0): return Z, N -if __name__ == '__main__': - from matplotlib import colors - import matplotlib.pyplot as plt - ## from timeit import timeit - # Benchmark - xmin, xmax, xn = -2.25, +0.75, int(3000/3) - ymin, ymax, yn = -1.25, +1.25, int(2500/3) - maxiter = 200 - ## timeit("mandelbrot_1(xmin, xmax, ymin, ymax, xn, yn, maxiter)", globals()) - ## timeit("mandelbrot_2(xmin, xmax, ymin, ymax, xn, yn, maxiter)", globals()) - ## timeit("mandelbrot_3(xmin, xmax, ymin, ymax, xn, yn, maxiter)", globals()) +# ### Compiled analog. ### - # Visualization - xmin, xmax, xn = -2.25, +0.75, int(3000/2) - ymin, ymax, yn = -1.25, +1.25, int(2500/2) - maxiter = 200 - horizon = 2.0 ** 40 - log_horizon = np.log(np.log(horizon))/np.log(2) - Z, N = mandelbrot(xmin, xmax, ymin, ymax, xn, yn, maxiter, horizon) +# For torch.Dynamo, need to work around +# 1. Complex numbers: add a trailing length-2 dimension for Re and Im parts. +# 2. Avoid fancy indexing: use with np.where instead to avoid data dependency +# +# Also: +# 1. Only compile the inner loop, to keep compile time and memory consumption +# under control (otherwise, can run into OOM while compiling) + +def abs2(a): + r"""abs(a) replacement.""" + return a[..., 0]**2 + a[..., 1]**2 + + +def sq2(a): + """a**2 replacement.""" + z = np.empty_like(a) + z[..., 0] = a[..., 0]**2 - a[..., 1]**2 + z[..., 1] = 2 * a[..., 0] * a[..., 1] + return z + + +@torch.compile(dynamic=True) +def step(n0, c, Z, N, horizon, chunksize): + for j in range(chunksize): + n = n0 + j + I = abs2(Z) < horizon**2 + N = np.where(I, n, N) # N[I] = n + Z = np.where(I[..., None], sq2(Z) + c, Z) # Z[I] = Z[I]**2 + C[I] + return Z, N + + +def mandelbrot_c(xmin, xmax, ymin, ymax, xn, yn, horizon=2**10, maxiter=5): + x = np.linspace(xmin, xmax, xn, dtype='float32') + y = np.linspace(ymin, ymax, yn, dtype='float32') + c = np.stack(np.broadcast_arrays(x[None, :], y[:, None]), axis=-1) + + N = np.zeros(c.shape[:-1], dtype='int') + Z = np.zeros_like(c, dtype='float32') + + chunksize=10 + n_chunks = maxiter // chunksize + + for i_chunk in range(n_chunks): + n0 = i_chunk*chunksize + Z, N = step(n0, c, Z, N, horizon, chunksize) + + N = np.where(N == maxiter-1, 0, N) # N[N == maxiter-1] = 0 + return Z, N - # Normalized recount as explained in: - # http://linas.org/art-gallery/escape/smooth.html + + +# plot a nice figure +def visualize(Z, N, horizon, xn, yn): + log_horizon = math.log(horizon, 2) M = np.nan_to_num(N + 1 - np.log(np.log(abs(Z)))/np.log(2) + log_horizon) dpi = 72 width = 10 height = 10*yn/xn - + fig = plt.figure(figsize=(width, height), dpi=dpi) ax = fig.add_axes([0.0, 0.0, 1.0, 1.0], frameon=False, aspect=1) light = colors.LightSource(azdeg=315, altdeg=10) - plt.imshow(light.shade(M.tensor.cpu().numpy(), cmap=plt.cm.hot, vert_exag=1.5, + plt.imshow(light.shade(M, cmap=plt.cm.hot, vert_exag=1.5, norm = colors.PowerNorm(0.3), blend_mode='hsv'), extent=[xmin, xmax, ymin, ymax], interpolation="bicubic") ax.set_xticks([]) @@ -72,3 +111,37 @@ def mandelbrot(xmin, xmax, ymin, ymax, xn, yn, maxiter, horizon=2.0): plt.savefig("mandelbrot.png") plt.show() + + +if __name__ == '__main__': + # start up + xmax, xmin, xn = -2.25, 0.75, 3000 // 2 + ymax, ymin, yn = -1.25, 1.25, 2500 // 2 + + maxiter = 200 + horizon = 2**10 + + # time numpy + start_time = time.time() + Z, N = mandelbrot(xmin, xmax, ymin, ymax, xn, yn, horizon=horizon, maxiter=maxiter) + end_time = time.time() + numpy_time = end_time - start_time + print("\n\nnumpy: elapsed=", numpy_time) + + + # compile, warm up, time + for _ in range(3): + mandelbrot_c(xmin, xmax, ymin, ymax, xn, yn, horizon=horizon, maxiter=maxiter) + + # measure + start_time = time.time() + Z, N = mandelbrot_c(xmin, xmax, ymin, ymax, xn, yn, horizon=horizon, maxiter=maxiter) + end_time = time.time() + compiled_time = end_time - start_time + print("compiled: elapsed=", compiled_time, ' speedup = ', numpy_time / compiled_time) + + # Visualization + Z = Z[..., 0] + 1j*Z[..., 1] + visualize(Z, N, horizon, xn, yn) + + diff --git a/e2e/mandelbrot/mandelbrot_eager.py b/e2e/mandelbrot/mandelbrot_eager.py new file mode 100644 index 00000000..9a58aa4f --- /dev/null +++ b/e2e/mandelbrot/mandelbrot_eager.py @@ -0,0 +1,74 @@ +# ----------------------------------------------------------------------------- +# From Numpy to Python +# Copyright (2017) Nicolas P. Rougier - BSD license +# More information at https://github.com/rougier/numpy-book +# ----------------------------------------------------------------------------- +#import numpy as np +import torch_np as np + + +# To run on CUDA, change "cpu" to "cuda" below. +import torch +torch.set_default_device("cpu") + + +# from mandelbrot_numpy_1 import mandelbrot # copy-paste below + +def mandelbrot(xmin, xmax, ymin, ymax, xn, yn, maxiter, horizon=2.0): + # Adapted from https://www.ibm.com/developerworks/community/blogs/jfp/... + # .../entry/How_To_Compute_Mandelbrodt_Set_Quickly?lang=en + X = np.linspace(xmin, xmax, xn, dtype=np.float32) + Y = np.linspace(ymin, ymax, yn, dtype=np.float32) + C = X + Y[:,None]*1j + N = np.zeros(C.shape, dtype=int) + Z = np.zeros(C.shape, np.complex64) + for n in range(maxiter): + I = np.less(abs(Z), horizon) + N[I] = n + Z[I] = Z[I]**2 + C[I] + N[N == maxiter-1] = 0 + return Z, N + + +if __name__ == '__main__': + from matplotlib import colors + import matplotlib.pyplot as plt + ## from timeit import timeit + + # Benchmark + xmin, xmax, xn = -2.25, +0.75, int(3000/3) + ymin, ymax, yn = -1.25, +1.25, int(2500/3) + maxiter = 200 + ## timeit("mandelbrot_1(xmin, xmax, ymin, ymax, xn, yn, maxiter)", globals()) + ## timeit("mandelbrot_2(xmin, xmax, ymin, ymax, xn, yn, maxiter)", globals()) + ## timeit("mandelbrot_3(xmin, xmax, ymin, ymax, xn, yn, maxiter)", globals()) + + # Visualization + xmin, xmax, xn = -2.25, +0.75, int(3000/2) + ymin, ymax, yn = -1.25, +1.25, int(2500/2) + maxiter = 200 + horizon = 2.0 ** 40 + log_horizon = np.log(np.log(horizon))/np.log(2) + Z, N = mandelbrot(xmin, xmax, ymin, ymax, xn, yn, maxiter, horizon) + + # Normalized recount as explained in: + # http://linas.org/art-gallery/escape/smooth.html + M = np.nan_to_num(N + 1 - np.log(np.log(abs(Z)))/np.log(2) + log_horizon) + + dpi = 72 + width = 10 + height = 10*yn/xn + + fig = plt.figure(figsize=(width, height), dpi=dpi) + ax = fig.add_axes([0.0, 0.0, 1.0, 1.0], frameon=False, aspect=1) + + light = colors.LightSource(azdeg=315, altdeg=10) + + plt.imshow(light.shade(M.tensor.cpu().numpy(), cmap=plt.cm.hot, vert_exag=1.5, + norm = colors.PowerNorm(0.3), blend_mode='hsv'), + extent=[xmin, xmax, ymin, ymax], interpolation="bicubic") + ax.set_xticks([]) + ax.set_yticks([]) + plt.savefig("mandelbrot.png") + plt.show() +