Skip to content

Instantly share code, notes, and snippets.

@shabbir-hasan
Forked from eugeneyan/mandelbrot-mojo.md
Created May 8, 2023 09:46
Show Gist options
  • Save shabbir-hasan/568f2235b56e413eeca821aa0e829fa7 to your computer and use it in GitHub Desktop.
Save shabbir-hasan/568f2235b56e413eeca821aa0e829fa7 to your computer and use it in GitHub Desktop.
Benchmarking Mojo vs. Python on Mandelbrot sets

Mandelbrot in Mojo with Python plots

Not only is Mojo great for writing high-performance code, but it also allows us to leverage the huge Python ecosystem of libraries and tools. With seamless Python interoperability, Mojo can use Python for what it's good at, especially GUIs, without sacrificing performance in critical code. Let's take the classic Mandelbrot set algorithm and implement it in Mojo.

We'll introduce a Complex type and use it in our implementation.

Mandelbrot in Python

%%python
import numpy as np
import numba
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import time
%%python
# Constants: the sampled rectangle of the complex plane (real axis
# xmin..xmax, imaginary axis ymin..ymax), the image resolution along each
# axis (xn columns, yn rows), and the per-pixel iteration cap.
xmin, xmax, xn = -2.25, 0.75, 450
ymin, ymax, yn = -1.25, 1.25, 375
max_iter = 200

# Compute the number of steps to escape
def mandelbrot_kernel(c):
    z = c
    for i in range(max_iter):
        z = z * z + c
        if abs(z) > 2:
            return i
    return max_iter

def mandelbrot():
    """Return a (yn, xn) uint32 array of escape counts over the module grid.

    Row j samples imaginary part ymin + j*dy and column i samples real part
    xmin + i*dx. Both coordinates are accumulated by repeated addition, so the
    grid stops one step short of xmax/ymax.
    """
    step_x = (xmax - xmin) / xn
    step_y = (ymax - ymin) / yn
    # One matrix element per pixel.
    counts = np.zeros((yn, xn), dtype=np.uint32)

    im = ymin
    for row in range(yn):
        re = xmin
        for col in range(xn):
            counts[row, col] = mandelbrot_kernel(complex(re, im))
            re += step_x
        im += step_y
    return counts

def make_plot_python(m):
    """Render the escape-count matrix *m* as a hill-shaded Mandelbrot image.

    The figure is 5 inches wide, and its height is chosen to match the yn/xn
    aspect ratio of the grid. The image fills the whole figure with no frame
    and no axis decorations.
    """
    dpi = 32
    width = 5
    # True division: the previous floor division (5 * yn // xn) truncated
    # 375 * 5 / 450 ~= 4.17 down to 4, so the figure did not match the grid.
    height = width * yn / xn

    fig = plt.figure(1, [width, height], dpi=dpi)
    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0], frame_on=False, aspect=1)

    light = colors.LightSource(315, 10, 0, 1, 1, 0)

    image = light.shade(m, plt.cm.hot, colors.PowerNorm(0.3), blend_mode='hsv', vert_exag=1.5)
    # Draw explicitly on the axes created above rather than on pyplot's
    # implicit "current" axes.
    ax.imshow(image)
    ax.axis("off")
    plt.show()
%%python
# Time only the computation; plotting happens after the clock stops.
# perf_counter is the high-resolution clock the Python docs recommend for
# measuring short durations (time.time() follows the adjustable wall clock).
start_time = time.perf_counter()
mandelbrot_set = mandelbrot()
end_time = time.perf_counter()
execution_time = (end_time - start_time) * 1000  # Make it milliseconds

make_plot_python(mandelbrot_set)
print(f"Execution time for Python Mandelbrot: {execution_time:.0f} ms")

output_4_0

Execution time for Python Mandelbrot: 1269 ms

Python numba JIT compiler

%%python

# Numba nopython-mode twin of mandelbrot_kernel.
@numba.jit(nopython=True)
def mandelbrot_kernel_numba(c):
    """Return the 0-based iteration at which |z| first exceeds 2.

    Iterates z -> z*z + c starting from z = c, and returns ``max_iter`` if
    the point has not escaped within that many iterations.
    """
    z = c
    n = 0
    while n < max_iter:
        z = z * z + c
        if abs(z) > 2:
            return n
        n += 1
    return max_iter

@numba.jit(nopython=True)
def mandelbrot_numba():
    """Numba-compiled twin of ``mandelbrot``: same grid, same escape counts.

    Returns a (yn, xn) uint32 array. Coordinates are accumulated by repeated
    addition, exactly as in the pure-Python version.
    """
    step_x = (xmax - xmin) / xn
    step_y = (ymax - ymin) / yn
    # One matrix element per pixel.
    counts = np.zeros((yn, xn), dtype=np.uint32)

    im = ymin
    for row in range(yn):
        re = xmin
        for col in range(xn):
            counts[row, col] = mandelbrot_kernel_numba(complex(re, im))
            re += step_x
        im += step_y
    return counts
%%python
# @numba.jit without explicit signatures compiles lazily on the first call,
# so an untimed warm-up call keeps JIT compilation out of the measurement.
mandelbrot_numba()

# perf_counter is the clock the Python docs recommend for short durations.
start_time = time.perf_counter()
mandelbrot_set = mandelbrot_numba()
end_time = time.perf_counter()
execution_time = (end_time - start_time) * 1000  # Make it milliseconds

make_plot_python(mandelbrot_set)
print(f"Execution time for Python Mandelbrot (numba): {execution_time:.0f} ms")

output_7_0

Execution time for Python Mandelbrot (numba): 1026 ms

Python vectorized

%%python
def mandelbrot_vectorized(xn, yn, max_iter=200):
    """Return a (yn, xn) uint32 array of escape counts, computed with NumPy.

    Uses the same per-pixel rule as ``mandelbrot_kernel``: starting from
    z = c, iterate z -> z*z + c. Each pixel stores the 0-based iteration
    index at which |z| first exceeds 2, or ``max_iter`` if it never does.
    The grid spans [-2.25, 0.75] x [-1.25, 1.25], including both ends
    (np.linspace).
    """
    # Define the boundaries of the complex plane
    xmin = -2.25
    xmax = 0.75
    ymin = -1.25
    ymax = 1.25

    # Build the complex grid by broadcasting a row of real parts against a
    # column of imaginary parts; no Python-level loop over pixels.
    x = np.linspace(xmin, xmax, xn)
    y = np.linspace(ymin, ymax, yn)
    c = x[np.newaxis, :] + 1j * y[:, np.newaxis]

    # Points that never escape keep max_iter.
    iter_count = np.full((yn, xn), max_iter, dtype=np.uint32)
    z = c.copy()
    # Only still-bounded points are iterated further. Continuing to square
    # already-escaped points overflows to inf/nan (RuntimeWarnings).
    active = np.ones((yn, xn), dtype=bool)

    for i in range(max_iter):
        z_active = z[active]
        z[active] = z_active * z_active + c[active]
        escaped = active & (np.abs(z) > 2)
        # Record the iteration at which each point first escaped.
        iter_count[escaped] = i
        active &= ~escaped
        if not active.any():
            break

    return iter_count
%%python
# Time only the computation; plotting happens after the clock stops.
# perf_counter is the clock the Python docs recommend for short durations.
start_time = time.perf_counter()
mandelbrot_set = mandelbrot_vectorized(xn, yn)
end_time = time.perf_counter()
execution_time = (end_time - start_time) * 1000

make_plot_python(mandelbrot_set)
print(f"Execution time for Python Mandelbrot (vectorized): {execution_time:.0f} ms")
<string>:47: RuntimeWarning: overflow encountered in square
<string>:47: RuntimeWarning: invalid value encountered in square

output_10_1

Execution time for Python Mandelbrot (vectorized): 353 ms

Python vectorized numba JIT

%%python

@numba.vectorize([numba.uint32(numba.complex128, numba.uint32)], nopython=True)
def mandelbrot_element(c, max_iter):
    """Escape count for one complex point, broadcast over arrays by numba.

    Iterates z -> z**2 + c from z = c and returns the 0-based iteration at
    which |z| first exceeds 2, or ``max_iter`` if it never does. The result
    is a uint32, as declared in the ufunc signature.
    """
    z = c
    for step in range(max_iter):
        z = z**2 + c
        escaped = abs(z) > 2
        if escaped:
            return step
    return max_iter

def mandelbrot_vectorized_numba(xn, yn, max_iter=200):
    """Return a (yn, xn) uint32 array of escape counts via the numba ufunc.

    The grid spans [-2.25, 0.75] x [-1.25, 1.25], including both ends
    (np.linspace). Each element is computed by ``mandelbrot_element``.
    """
    # Define the boundaries of the complex plane
    xmin = -2.25
    xmax = 0.75
    ymin = -1.25
    ymax = 1.25

    x = np.linspace(xmin, xmax, xn)
    y = np.linspace(ymin, ymax, yn)
    # Broadcasting builds the (yn, xn) complex grid in native code. The former
    # nested list comprehension created every element in the interpreter,
    # which added Python-loop overhead to the timed region of this benchmark.
    c = x[np.newaxis, :] + 1j * y[:, np.newaxis]

    # Compute the Mandelbrot set element-wise using the vectorized function
    iter_count = mandelbrot_element(c, max_iter)

    return iter_count
%%python
# numba.vectorize with an explicit signature compiles at decoration time, so
# no warm-up call is needed before timing.
# perf_counter is the clock the Python docs recommend for short durations.
start_time = time.perf_counter()
mandelbrot_set = mandelbrot_vectorized_numba(xn, yn)
end_time = time.perf_counter()
execution_time = (end_time - start_time) * 1000

make_plot_python(mandelbrot_set)
print(f"Execution time for Python Mandelbrot (vectorized-numba): {execution_time:.0f} ms")

output_13_0

Execution time for Python Mandelbrot (vectorized-numba): 93 ms

Cython (can't load extension)

# Setup attempt for the Cython comparison; in this notebook it did not run.
# %%python
# import os
# os.system('pip install cython')
# NOTE(review): %load_ext is an IPython line magic. Here the cell was parsed
# as Mojo and rejected with "unexpected token in expression".
%load_ext cython
error: Expression [11]:39:5: unexpected token in expression
    %load_ext cython
    ^
%%cython

## Try with cython
import numpy as np
cimport numpy as np

# Constants
# Iteration cap applied to every pixel.
cdef int max_iter = 200

# Real-axis extent of the sampled region and number of columns.
cdef double xmin = -2.25
cdef double xmax = 0.75
cdef int xn = 450

# Imaginary-axis extent of the sampled region and number of rows.
cdef double ymin = -1.25
cdef double ymax = 1.25
cdef int yn = 375

# Mandelbrot computation in Cython.
# NOTE: buffer-typed declarations such as np.ndarray[np.uint32_t, ndim=2] are
# only permitted for local variables and arguments, not for return types, so
# the function is declared to return a plain np.ndarray.
cpdef np.ndarray mandelbrot_cython():
    """Return a (yn, xn) uint32 array of escape iteration counts.

    Matches the Python ``mandelbrot_kernel`` semantics: starting from z = c,
    iterate z -> z*z + c. Each pixel stores the 0-based iteration index at
    which |z| first exceeds 2, or ``max_iter`` if it never does.
    """
    cdef double dx = (xmax - xmin) / xn
    cdef double dy = (ymax - ymin) / yn
    cdef np.ndarray[np.uint32_t, ndim=2] result = np.zeros((yn, xn), dtype=np.uint32)
    cdef double x, y, real, imag
    cdef int i, j, k, count
    y = ymin
    for j in range(yn):
        x = xmin
        for i in range(xn):
            real = x
            imag = y
            # Points that never escape keep max_iter. Relying on the loop
            # variable after the loop would give max_iter - 1 instead.
            count = max_iter
            for k in range(max_iter):
                # Update first, then test: same order as the Python kernel.
                real, imag = real * real - imag * imag + x, 2 * real * imag + y
                # |z| > 2  <=>  |z|^2 > 4, which avoids a sqrt.
                if real * real + imag * imag > 4:
                    count = k
                    break
            result[j, i] = count
            x += dx
        y += dy
    return result
error: Expression [12]:28:8: unable to locate module 'numpy'
import numpy as np
       ^
error: Expression [12]:40:5: unexpected token in expression
    %%cython
    ^
# NOTE(review): unlike the other Python cells, this cell has no leading
# %%python magic, so the notebook parsed it as Mojo (hence the Mojo errors
# below). mandelbrot_cython is also undefined because the %%cython cell above
# failed to build.
# perf_counter is the clock the Python docs recommend for short durations.
start_time = time.perf_counter()
mandelbrot_set = mandelbrot_cython()
end_time = time.perf_counter()
execution_time = (end_time - start_time) * 1000  # Make it milliseconds

make_plot_python(mandelbrot_set)
print(f"Execution time for Python Mandelbrot (cython): {execution_time:.0f} ms")
error: Expression [13]:40:22: use of unknown declaration 'mandelbrot_cython'
    mandelbrot_set = mandelbrot_cython()
                     ^~~~~~~~~~~~~~~~~
error: Expression [13]:42:32: 'PythonObject' does not implement the '__sub__' method
    execution_time = (end_time - start_time) * 1000  # Make it milliseconds
                      ~~~~~~~~ ^
error: Expression [13]:45:12: expected ')' in call argument list
    print(f"Execution time for Python Mandelbrot (cython): {execution_time:.0f} ms")
           ^

Mandelbrot in Mojo

from Benchmark import Benchmark
from DType import DType
from Memory import memset_zero
from Object import object, Attr
from Pointer import DTypePointer, Pointer
from Random import rand
from Range import range
from TargetInfo import dtype_sizeof
from Time import now
from Complex import ComplexSIMD as ComplexGenericSIMD
# Row-major 2-D buffer of si64 values with manual reference counting.
# Element (col, row) lives at flat offset row * cols + col.
struct Matrix:
    var data: DTypePointer[DType.si64]
    var rows: Int
    var cols: Int
    # Heap-allocated count of copies sharing `data`; the last copy frees both.
    var rc: Pointer[Int]

    # Note the argument order: cols first, then rows.
    # NOTE(review): the element buffer is not zeroed here (memset_zero is
    # imported but unused); the compute_* callers appear to write every
    # element before it is read.
    fn __init__(self&, cols: Int, rows: Int):
        self.data = DTypePointer[DType.si64].alloc(rows * cols)
        self.rows = rows
        self.cols = cols
        self.rc = Pointer[Int].alloc(1)
        self.rc.store(1)

    # Copies share the same buffer: bump the shared count instead of copying data.
    fn __copyinit__(self&, other: Self):
        other._inc_rc()
        self.data = other.data
        self.rc   = other.rc
        self.rows = other.rows
        self.cols = other.cols

    # Drop this copy's reference; the buffer is freed when it was the last one.
    fn __del__(owned self):
        self._dec_rc()

    fn _get_rc(self) -> Int:
        return self.rc.load()

    # Decrement the shared count, or free everything if this is the last reference.
    fn _dec_rc(self):
        let rc = self._get_rc()
        if rc > 1:
            self.rc.store(rc - 1)
            return
        self._free()

    fn _inc_rc(self):
        let rc = self._get_rc()
        self.rc.store(rc + 1)

    # Release both the element buffer and the reference-count cell.
    fn _free(self):
        self.data.free()
        self.rc.free()

    # Scalar read of element (col, row).
    @always_inline
    fn __getitem__(self, col: Int, row: Int) -> SI64:
        return self.load[1](col, row)

    # Read `nelts` consecutive elements of `row`, starting at column `col`.
    @always_inline
    fn load[nelts:Int](self, col: Int, row: Int) -> SIMD[DType.si64, nelts]:
        return self.data.simd_load[nelts](row * self.cols + col)

    # Scalar write of element (col, row).
    @always_inline
    fn __setitem__(self, col: Int, row: Int, val: SI64):
        return self.store[1](col, row, val)

    # Write `nelts` consecutive elements of `row`, starting at column `col`.
    @always_inline
    fn store[nelts:Int](self, col: Int, row: Int, val: SIMD[DType.si64, nelts]):
        self.data.simd_store[nelts](row * self.cols + col, val)

    # Copy into a new numpy uint32 array of shape (rows, cols), one element at a
    # time through Python interop; values are cast to f32 before being stored.
    # NOTE(review): ndarray.itemset was removed in NumPy 2.0, so this needs an
    # alternative element assignment on newer NumPy — confirm what the Mojo
    # Python interop supports.
    def to_numpy(self) -> PythonObject:
        let np = Python.import_module("numpy")
        let numpy_array = np.zeros((self.rows, self.cols), np.uint32)
        for col in range(self.cols):
            for row in range(self.rows):
                numpy_array.itemset((row, col), self[col, row].cast[DType.f32]())
        return numpy_array
# Minimal scalar complex number with F32 real and imaginary parts, used by
# the scalar Mandelbrot kernel.
@register_passable("trivial")
struct Complex:
    var real: F32
    var imag: F32

    # Register-passable constructor: returns a new value rather than
    # initializing `self` in place.
    fn __init__(real: F32, imag: F32) -> Self:
        return Self {real: real, imag: imag}

    # Component-wise complex addition.
    fn __add__(lhs, rhs: Self) -> Self:
        return Self(lhs.real + rhs.real, lhs.imag + rhs.imag)

    # Complex product: (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
    fn __mul__(lhs, rhs: Self) -> Self:
        return Self(
            lhs.real * rhs.real - lhs.imag * rhs.imag,
            lhs.real * rhs.imag + lhs.imag * rhs.real,
        )

    # Squared magnitude |z|^2 (no sqrt). Callers compare it against 4 rather
    # than comparing |z| against 2.
    fn norm(self) -> F32:
        return self.real * self.real + self.imag * self.imag

Then we can write the core Mandelbrot algorithm, which involves computing an iterative complex function for each pixel until it "escapes" the complex circle of radius 2, counting the number of iterations to escape.

$$z_{i+1} = z_i^2 + c$$

# Compile-time bounds of the sampled region of the complex plane (real axis
# xmin..xmax, imaginary axis ymin..ymax) and the image resolution
# (xn columns, yn rows); same values as the Python constants.
alias xmin: F32 = -2.25
alias xmax: F32 = 0.75
alias xn = 450
alias ymin: F32 = -1.25
alias ymax: F32 = 1.25
alias yn = 375

# Compute the number of steps to escape.
# Returns the 0-based iteration at which |z|^2 first exceeds 4 (i.e. |z| > 2),
# or max_iter if the point has not escaped by then.
def mandelbrot_kernel(c: Complex) -> Int:
    # Iteration cap; same value as max_iter in the Python version.
    max_iter = 200
    z = c
    for i in range(max_iter):
        z = z * z + c
        # norm() is the squared magnitude, hence 4 rather than 2.
        if z.norm() > 4:
            return i
    return max_iter


# Scalar Mandelbrot: fills an xn x yn Matrix with escape counts. The point for
# pixel (i, j) is (xmin + i*dx) + (ymin + j*dy)i, with both coordinates
# accumulated by repeated addition.
def compute_mandelbrot() -> Matrix:
    # create a matrix. Each element of the matrix corresponds to a pixel
    result = Matrix(xn, yn)

    dx = (xmax - xmin) / xn
    dy = (ymax - ymin) / yn

    y = ymin
    for j in range(yn):
        x = xmin
        for i in range(xn):
            # Matrix indexing is [col, row].
            result[i, j] = mandelbrot_kernel(Complex(x, y))
            x += dx
        y += dy
    return result

Plotting the number of iterations to escape with some color gives us the canonical Mandelbrot set plot. To render it we can directly leverage Python's matplotlib right from Mojo!

# Render the Matrix through matplotlib via Python interop; mirrors the Python
# make_plot_python but passes arguments positionally.
def make_plot(m: Matrix):
    np = Python.import_module("numpy")
    plt = Python.import_module("matplotlib.pyplot")
    colors = Python.import_module("matplotlib.colors")
    dpi = 32
    width = 5
    # Floor division: 5 * 375 // 450 == 4, so the height is truncated from ~4.17.
    height = 5 * yn // xn

    fig = plt.figure(1, [width, height], dpi)
    # NOTE(review): the Python version passes frame_on=False, aspect=1 as
    # keywords. Here they are positional; depending on the matplotlib version,
    # extra positional args to add_axes may be ignored or rejected — confirm.
    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0], False, 1)

    light = colors.LightSource(315, 10, 0, 1, 1, 0)

    # NOTE(review): the positional "hsv", 0, 0, 1.5 appear intended as
    # blend_mode, vmin, vmax, vert_exag. The Python version leaves vmin/vmax
    # unset; presumably they are unused when an explicit norm is given — confirm.
    image = light.shade(m.to_numpy(), plt.cm.hot, colors.PowerNorm(0.3), "hsv", 0, 0, 1.5)
    plt.imshow(image)
    plt.axis("off")
    plt.show()
# Time only the computation; plotting happens after the clock stops.
let eval_begin: Int = now()  # This is in nanoseconds
let mandelbrot_set = compute_mandelbrot()
let eval_end: Int = now()
# Integer division: nanoseconds -> whole milliseconds (truncated).
let execution_time = (eval_end - eval_begin) // 1000000

make_plot(mandelbrot_set)
print('Execution time for Mojo Mandelbrot: ', execution_time, 'ms')

output_27_0

Execution time for Mojo Mandelbrot:  27 ms

Vectorizing Mandelbrot

We showed a naive implementation of the Mandelbrot algorithm, but there are two things we can do to speed it up. We can early-stop the loop iteration when a pixel is known to have escaped, and we can leverage Mojo's access to hardware by vectorizing the loop, computing multiple pixels simultaneously. To do that we will use the vectorize higher order generator.

We start by defining our main iteration loop in a vectorized fashion

# Vectorized escape-time kernel: processes simd_width points at once. For each
# lane it returns how many iterations ran before that lane escaped (200 if it
# never did), which is the same count the scalar mandelbrot_kernel returns.
fn mandelbrot_kernel_simd[simd_width:Int](c: ComplexGenericSIMD[DType.f32, simd_width]) -> SIMD[DType.si64, simd_width]:
    var z = c
    # Per-lane iteration counters.
    var nv = SIMD[DType.si64, simd_width](0)
    # Sticky per-lane flag: True once that lane has escaped.
    var escape_mask = SIMD[DType.bool, simd_width](0)

    # Iteration cap; duplicates the 200 used by the scalar kernel.
    var i = 200
    # NOTE(review): relies on SIMD[DType.bool, N] truthiness meaning "all lanes
    # set", so the loop keeps running until every lane has escaped. If it meant
    # "any lane set", lanes would stop early — an explicit reduction
    # (e.g. reduce_and) would make the intent unambiguous.
    while i != 0 and not escape_mask:
        z = z*z + c
        # Only update elements that haven't escaped yet
        escape_mask = escape_mask.select(escape_mask, z.norm() > 4)
        # Count this iteration only for lanes that are still bounded.
        nv = escape_mask.select(nv, nv + 1) 
        i -= 1
    
    return nv

The above function is parameterized on `simd_width` and processes `simd_width` pixels at once. It only exits its loop once every pixel in the SIMD vector has escaped (or the iteration limit is reached). We can use the same iteration loop as above, but this time we vectorize within each row instead. We use the `vectorize` generator to make this a simple function call.

from Functional import vectorize
from Math import iota
from TargetInfo import dtype_simd_width


# Vectorized Mandelbrot: same grid as compute_mandelbrot, but each row is
# processed in chunks of SIMD lanes using mandelbrot_kernel_simd.
def compute_mandelbrot_simd() -> Matrix:
    # create a matrix. Each element of the matrix corresponds to a pixel
    var result = Matrix(xn, yn)

    let dx = (xmax - xmin) / xn
    let dy = (ymax - ymin) / yn

    var y = ymin
    # Presumably the target's native vector width for f32 — confirm.
    alias simd_width = dtype_simd_width[DType.f32]()

    for row in range(yn):
        # x is the real part of the first pixel of the next chunk; the closure
        # below reads it and advances it.
        var x = xmin
        # Fills the `simd_width` pixels starting at column `col`. The inner
        # simd_width parameter shadows the outer alias, so x advances by the
        # width actually processed. Presumably vectorize calls this with smaller
        # widths for the xn % simd_width tail — confirm.
        @parameter
        fn _process_simd_element[simd_width:Int](col: Int):
            # Real parts x, x+dx, ... (assuming iota yields 0, 1, 2, ...); imaginary part y in every lane.
            let c = ComplexGenericSIMD[DType.f32, simd_width](dx*iota[simd_width, DType.f32]() + x, 
                                                              SIMD[DType.f32, simd_width](y))
            result.store[simd_width](col, row, mandelbrot_kernel_simd[simd_width](c))
            x += simd_width*dx

        vectorize[simd_width, _process_simd_element](xn)
        y += dy
    return result
# Time only the computation; now() is in nanoseconds (see the scalar timing cell).
let eval_begin: Int = now()
let mandelbrot_set = compute_mandelbrot_simd()
let eval_end: Int = now()
# Integer division: nanoseconds -> whole milliseconds (truncated).
let execution_time = (eval_end - eval_begin) // 1000000

make_plot(mandelbrot_set)
print('Execution time for Mojo Mandelbrot (vectorized): ', execution_time, 'ms')

output_34_0

Execution time for Mojo Mandelbrot (vectorized):  2 ms

Parallelizing Mandelbrot

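While the vectorized implementation above is efficient, we can potentially get better performance by parallelizing over the rows. This again is simple in Mojo using the `parallelize` higher-order function. Only the function that performs the invocation needs to change. (For an image this small, thread start-up and scheduling overhead can outweigh the gain. In the run below, the parallel version was actually slower than the vectorized one.)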

from Functional import parallelize 

# Vectorized and row-parallel Mandelbrot: each row is an independent task
# that runs the same SIMD chunking as compute_mandelbrot_simd.
def compute_mandelbrot_simd_parallel() -> Matrix:
    # create a matrix. Each element of the matrix corresponds to a pixel
    var result = Matrix(xn, yn)

    let dx = (xmax - xmin) / xn
    let dy = (ymax - ymin) / yn

    alias simd_width = dtype_simd_width[DType.f32]()

    # Processes one full row. y is computed from the row index (not
    # accumulated) and x is local to the call, so rows share no mutable state
    # apart from `result`, where each row writes only its own elements.
    @parameter
    fn _process_row(row:Int):
        var y = ymin + dy*row
        var x = xmin
        # Fills the `simd_width` pixels starting at column `col`, then advances x.
        @parameter
        fn _process_simd_element[simd_width:Int](col: Int):
            let c = ComplexGenericSIMD[DType.f32, simd_width](dx*iota[simd_width, DType.f32]() + x, 
                                                              SIMD[DType.f32, simd_width](y))
            result.store[simd_width](col, row, mandelbrot_kernel_simd[simd_width](c))
            x += simd_width*dx
            
        vectorize[simd_width, _process_simd_element](xn)

    # Presumably runs _process_row for row = 0 ..< yn across worker threads — confirm.
    parallelize[_process_row](yn)
    return result
# Time only the computation; now() is in nanoseconds (see the scalar timing cell).
let eval_begin: Int = now()
let mandelbrot_set = compute_mandelbrot_simd_parallel()
let eval_end: Int = now()
# Integer division: nanoseconds -> whole milliseconds (truncated).
let execution_time = (eval_end - eval_begin) // 1000000

make_plot(mandelbrot_set)
print('Execution time for Mojo Mandelbrot (vectorized-parallelized): ', execution_time, 'ms')

output_38_0

Execution time for Mojo Mandelbrot (vectorized-parallelized):  4 ms
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment