Skip to content

Instantly share code, notes, and snippets.

@sklam
Forked from rossant/raytracing.py
Last active May 25, 2021 05:14
Show Gist options
  • Save sklam/362c883eff73d297134c to your computer and use it in GitHub Desktop.
Save sklam/362c883eff73d297134c to your computer and use it in GitHub Desktop.
Numba-optimized ray tracing: render time reduced from 23 seconds to 9 seconds.
from __future__ import print_function, division
from timeit import default_timer as timer
import numpy as np
import matplotlib.pyplot as plt
from numba import njit
# Output image size in pixels (width, height).
w = 400
h = 300
# Object type tags; stored in the 'type' field of each scene_array record
# and used by the jitted kernels to dispatch intersection/normal logic.
TYPE_PLANE = 1
TYPE_SPHERE = 2
@njit
def normalize(x):
    """Scale the 1-D vector ``x`` to unit Euclidean length, in place.

    The input array is modified and also returned, so calls can be written
    as ``toL = normalize(L - M)``.  A zero-length vector has no direction;
    it is returned unchanged instead of being divided by zero.
    """
    n = 0.0
    for i in range(x.size):
        xi = x[i]
        n += xi * xi
    n = np.sqrt(n)
    if n == 0.0:
        # Under Numba's default ('python') error model, 0/0 would raise
        # ZeroDivisionError and abort the render.
        return x
    for i in range(x.size):
        x[i] /= n
    return x
@njit
def dot(a, b):
    """Return the scalar dot product of the vectors ``a`` and ``b``."""
    products = a * b
    return products.sum()
@njit
def intersect_plane(O, D, P, N):
    """Distance from O along direction D to the plane through P with normal N.

    O and P are 3D points; D and N (the normal) are expected to be unit
    vectors.  Returns +inf when the ray is (nearly) parallel to the plane,
    i.e. |D.N| < 1e-6, or when the intersection lies behind O.
    """
    d_dot_n = dot(D, N)
    if np.abs(d_dot_n) < 1e-6:
        return np.inf
    dist = dot(P - O, N) / d_dot_n
    return np.inf if dist < 0 else dist
@njit
def intersect_sphere(O, D, S, R):
    """Distance from O along ray direction D to the sphere (S, R), or +inf.

    O and S are 3D points, D is the ray direction (expected normalized; a is
    computed as D.D so any nonzero D works), R is the scalar radius.
    Solves a*t**2 + b*t + c = 0 and returns the smallest non-negative root.
    """
    a = dot(D, D)
    OS = O - S
    b = 2 * dot(D, OS)
    c = dot(OS, OS) - R * R
    disc = b * b - 4 * a * c
    if disc > 0:
        distSqrt = np.sqrt(disc)
        # Numerically stable roots: combine -b and sqrt(disc) with the SAME
        # sign so there is no cancellation.  With this choice q is never 0
        # when disc > 0, so c / q below cannot divide by zero.
        q = (-b + distSqrt) / 2.0 if b < 0 else (-b - distSqrt) / 2.0
        t0 = q / a
        t1 = c / q
        t0, t1 = min(t0, t1), max(t0, t1)
        if t1 >= 0:
            # t0 < 0 <= t1 means O is inside the sphere: use the far root.
            return t1 if t0 < 0 else t0
    return np.inf
@njit
def intersect(O, D, obj):
    """Dispatch a ray/object intersection test on ``obj.type``.

    Returns the distance along the ray (O, D) to ``obj``, or +inf on a miss.
    Any type other than TYPE_PLANE must be TYPE_SPHERE (asserted).
    """
    if obj.type != TYPE_PLANE:
        assert obj.type == TYPE_SPHERE
        return intersect_sphere(O, D, obj.position, obj.radius)
    return intersect_plane(O, D, obj.position, obj.normal)
@njit
def get_normal(obj, M):
    """Return a new array holding the surface normal of ``obj`` at point ``M``.

    Spheres: the unit vector from the center to M.  Planes: the stored
    normal, copied as-is (not re-normalized here).  Any other type trips
    an assertion.
    """
    out = np.zeros_like(obj.position)
    if obj.type == TYPE_PLANE:
        for k in range(out.size):
            out[k] = obj.normal[k]
    elif obj.type == TYPE_SPHERE:
        offset = M - obj.position
        for k in range(offset.size):
            out[k] = offset[k]
        normalize(out)
    else:
        assert False
    return out
def get_color(obj, M):
    """Return the surface color of the scene dict ``obj`` at point ``M``.

    ``obj['color']`` is either a fixed color (anything with a length, such
    as an array) or a callable evaluated at the hit point, e.g. the plane's
    checkerboard.
    """
    color = obj['color']
    return color if hasattr(color, '__len__') else color(M)
@njit
def find_intersect_object(rayO, rayD):
    """Find the nearest object in ``scene_array`` hit by the ray (rayO, rayD).

    Returns ``(t, obj_idx)``: the hit distance and the index of the object,
    or ``(inf, -1)`` when the ray hits nothing.
    """
    best_t = np.inf
    best_idx = -1
    for idx, candidate in enumerate(scene_array):
        dist = intersect(rayO, rayD, candidate)
        if dist < best_t:
            best_t = dist
            best_idx = idx
    return best_t, best_idx
@njit
def find_shadow(M, N, toL, obj_idx):
    """Cast a shadow ray from hit point M toward the light direction toL.

    The ray starts slightly above the surface (M offset by .0001 along N)
    so it does not re-hit the surface it leaves.  It is tested against
    every object except ``obj_idx``.

    Returns ``(ls, ct)``: per-object hit distances (inf for misses and for
    ``obj_idx`` itself) and the number of objects actually tested.
    """
    # Loop-invariant: compute the offset origin once, not once per object.
    origin = M + N * .0001
    ls = np.zeros(scene_array.size, dtype=np.float64)
    ct = 0
    for k, obj_sh in enumerate(scene_array):
        if k != obj_idx:
            ls[k] = intersect(origin, toL, obj_sh)
            ct += 1
        else:
            ls[k] = np.inf
    return ls, ct
class NoIntersection(Exception):
    """Raised by fast_trace_ray when a ray produces no lit hit.

    It signals either that the ray misses every object or that the hit
    point is in shadow.  trace_ray catches it and returns None.
    """
@njit(nogil=True)
def fast_trace_ray(rayO, rayD):
    """Jitted geometry part of ray tracing: hit point, normal and shadow test.

    Returns ``(t, obj_idx, M, N, toL, toO, l, ct)``:
      t, obj_idx -- hit distance and index of the hit object in scene_array
      M, N       -- hit point and surface normal there
      toL, toO   -- unit vectors from M toward the light L and the ray origin
      l, ct      -- shadow-ray distances and object count from find_shadow

    Raises:
        NoIntersection: the ray misses every object, or M is shadowed.
    """
    # Find first point of intersection with the scene.
    t, obj_idx = find_intersect_object(rayO, rayD)
    if t == np.inf:
        raise NoIntersection
    # Find the point of intersection on the object.
    M = rayO + rayD * t
    # Find properties of the object.
    N = get_normal(scene_array[obj_idx], M)
    toL = L - M
    # Distance to the light, taken before toL is normalized in place.
    light_dist = np.sqrt(dot(toL, toL))
    normalize(toL)
    toO = normalize(rayO - M)
    # Shadow: only an occluder strictly between M and the light casts a
    # shadow; objects hit beyond the light must not darken the point.
    l, ct = find_shadow(M, N, toL, obj_idx)
    if ct and np.min(l) < light_dist:
        raise NoIntersection
    return t, obj_idx, M, N, toL, toO, l, ct
def trace_ray(rayO, rayD):
    """Shade the first lit hit of the ray (rayO, rayD).

    Returns ``(obj, M, N, col_ray)``: the scene dict that was hit, the hit
    point, the surface normal, and the shaded color.  The color is ambient
    plus Lambert diffuse plus Blinn-Phong specular.  Returns None when
    fast_trace_ray raises NoIntersection (a miss or a shadowed point).
    """
    try:
        hit = fast_trace_ray(rayO, rayD)
    except NoIntersection:
        return None
    _, obj_idx, M, N, toL, toO, _, _ = hit
    obj = scene[obj_idx]
    surface = get_color(obj, M)
    # Lambert (diffuse) and Blinn-Phong (specular, via the half vector) factors.
    lambert = max(dot(N, toL), 0)
    half_vector = normalize(toL + toO)
    blinn = max(dot(N, half_vector), 0) ** specular_k
    col_ray = ambient + obj.get('diffuse_c', diffuse_c) * lambert * surface
    col_ray += obj.get('specular_c', specular_c) * blinn * color_light
    return obj, M, N, col_ray
def add_sphere(position, radius, color):
    """Build the scene dict for a sphere with a fixed color and reflection 0.5."""
    return {
        'type': TYPE_SPHERE,
        'position': np.array(position),
        'radius': np.array(radius),
        'color': np.array(color),
        'reflection': .5,
    }
def add_plane(position, normal):
return dict(type=TYPE_PLANE, position=np.array(position),
normal=np.array(normal),
color=lambda M: (color_plane0
if (int(M[0] * 2) % 2) == (
int(M[2] * 2) % 2) else color_plane1),
diffuse_c=.75, specular_c=.5, reflection=.25)
# Aligned record layout for one scene object, as read by the jitted kernels.
# Only geometry lives here; per-object shading parameters (diffuse_c,
# specular_c, reflection) stay in the Python-side scene dicts.
object_struct = np.dtype(
    [
        ('type', np.int8),                # TYPE_PLANE or TYPE_SPHERE
        ('position', np.float64, (3,)),   # sphere center / point on plane
        ('normal', np.float64, (3,)),     # plane normal (unused for spheres)
        ('radius', np.float64),           # sphere radius (unused for planes)
    ],
    align=True,
)
# List of objects.
# Checkerboard tile colors, looked up by the plane's color callback.
color_plane0 = np.ones(3)
color_plane1 = np.zeros(3)
scene = [
    add_sphere([.75, .1, 1.], .6, [0., 0., 1.]),
    add_sphere([-.75, .1, 2.25], .6, [.5, .223, .5]),
    add_sphere([-2.75, .1, 3.5], .6, [1., .572, .184]),
    add_plane([0., -.5, 0.], [0., 1., 0.]),
]
# Fill array: pack the scene dicts into a record array the jitted kernels
# can read.  Start from zeros rather than uninitialized memory, so fields
# that do not apply to an object (a sphere's normal, a plane's radius)
# hold 0 instead of garbage.
scene_array = np.zeros(len(scene), dtype=object_struct).view(np.recarray)
for idx, obj in enumerate(scene):
    item = scene_array[idx]
    item.type = obj['type']
    item.position[:] = obj['position']
    if item.type == TYPE_PLANE:
        item.normal[:] = obj['normal']
    else:
        item.radius = obj['radius']
# Light position and color.
L = np.array([5., 5., -10.])
color_light = np.ones(3)
# Default light and material parameters; scene dicts may override
# diffuse_c / specular_c per object (see trace_ray).
ambient = .05
diffuse_c = 1.
specular_c = 1.
specular_k = 50   # Blinn-Phong specular exponent.
depth_max = 5     # Maximum number of light reflections.
# Aspect ratio, then the screen rectangle (x0, y0, x1, y1), shifted up by .25.
r = w / h
S = (-1., -1. / r + .25, 1., 1. / r + .25)
def inner_loop(img, col, O, x, y, i, j):
    """Render pixel (i, j) into ``img`` by tracing a primary ray and its reflections.

    ``col`` is a 3-element scratch buffer that is reset and then used to
    accumulate the reflection-weighted color.  The result, clipped to
    [0, 1], is stored at row ``h - j - 1`` (so j = 0 fills the last row),
    column ``i``.  At most ``depth_max`` bounces are traced.
    """
    target = np.array([0., 0., 0.])  # Camera pointing to.
    target[:2] = (x, y)
    col[:] = 0
    rayO = O
    rayD = normalize(target - O)
    weight = 1.
    for _ in range(depth_max):
        traced = trace_ray(rayO, rayD)
        if traced is None:
            break
        obj, M, N, col_ray = traced
        # Bounce: mirror rayD about N and nudge the origin off the surface.
        rayO = M + N * .0001
        rayD = normalize(rayD - 2 * dot(rayD, N) * N)
        col += weight * col_ray
        weight *= obj.get('reflection', 1.)
    img[h - j - 1, i, :] = np.clip(col, 0, 1)
def main():
    """Render the full w x h image, print progress and timing, and save fig.png."""
    img = np.zeros((h, w, 3))
    col = np.zeros(3)  # Current color.
    O = np.array([0., 0.35, -1.])  # Camera.
    # Print every 5% increment.  Clamp the step to at least 1: for w < 100,
    # w // 100 * 5 is 0 and the modulo below would raise ZeroDivisionError.
    progress_step = max(1, w // 100 * 5)
    # The row coordinates are the same for every column; build them once.
    ys = np.linspace(S[1], S[3], h)
    ts = timer()
    for i, x in enumerate(np.linspace(S[0], S[2], w)):
        if i % progress_step == 0:
            print("{0:.1f}%".format(i / w * 100))
        for j, y in enumerate(ys):
            inner_loop(img, col.copy(), O, x, y, i, j)
    te = timer()
    print("Time", te - ts)
    print("Write to fig.png")
    plt.imsave('fig.png', img)
# Reference timings for a 400x300 render:
#   original version:        ~23 s
#   Numba-optimized version: ~9 s
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment