@panther03
Created August 13, 2024 17:28
conv1d in exo
from __future__ import annotations
import os
import sys
from exo import proc
from exo.libs.memories import *
from exo.platforms.x86 import *
from exo.stdlib.scheduling import *
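
# Schedules a generic 1-D convolution in Exo down to 4x4 matrix-tile
# instructions (mld.w/mst.w/mzero/mmasa.w) for an RVM-style RISC-V matrix
# extension, via an im2col-style rewrite.
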
class RVM_TILE(StaticMemory):
    NUM_RVM_TILES = 8
    StaticMemory.init_state(NUM_RVM_TILES)
    tile_dict = {}

    # TODO: have a better way of doing this rather than manually
    # calling this after each test that fails to compile.
    @classmethod
    def reset_allocations(cls):
        cls.init_state(cls.NUM_RVM_TILES)
        cls.tile_dict = {}

    @classmethod
    def can_read(cls):
        return False

    @classmethod
    def alloc(cls, new_name, prim_type, shape, srcinfo):
        if not (shape[0].isdecimal() and int(shape[0]) == 4):
            raise MemGenError("Number of tile rows must be 4.")
        if not (shape[1].isdecimal() and int(shape[1]) == 4):
            raise MemGenError("Number of tile columns must be 4.")
        tile_num = cls.find_free_chunk()
        cls.mark(tile_num)
        cls.tile_dict[new_name] = tile_num
        return f'#define {new_name} "m{7-tile_num}"'

    @classmethod
    def free(cls, new_name, prim_type, shape, srcinfo):
        tile_num = cls.tile_dict[new_name]
        del cls.tile_dict[new_name]
        cls.unmark(tile_num)
        return f"#undef {new_name}"
class DRAM_INTERLEAVED(DRAM):
    @classmethod
    def alloc(cls, new_name, prim_type, shape, srcinfo):
        # Error checking only
        for extent in shape:
            try:
                int(extent)
            except ValueError as e:
                raise MemGenError(
                    f"DRAM_INTERLEAVED requires constant shapes. Saw: {extent}"
                ) from e
        return f'static {prim_type} __attribute__((section(".xheep_data_interleaved"))) {new_name}[{" * ".join(shape)}];'

    @classmethod
    def free(cls, new_name, prim_type, shape, srcinfo):
        return ""
@instr('asm volatile("mld.w "{dst_int}", (%1), %0" :: "r"(4*({src}.strides[0])), "r"(&{src_data}));')
def rvm_mld(dst: [i32][4,4] @ RVM_TILE, src: [i32][4,4] @ DRAM):
assert stride(src, 1) == 1
assert stride(dst, 1) == 1
for i in seq(0, 4):
for j in seq(0,4):
dst[i,j] = src[i,j]
@instr('asm volatile("mzero "{dst_int});')
def rvm_mzero(dst: [i32][4,4] @ RVM_TILE):
assert stride(dst, 1) == 1
for i in seq(0, 4):
for j in seq(0,4):
dst[i,j] = 0
@instr('asm volatile("mst.w "{src_int}", (%1), %0" :: "r"(4*({dst}.strides[0])), "r"(&{dst_data}));')
def rvm_mst(src: [i32][4,4] @ RVM_TILE, dst: [i32][4,4] @ DRAM):
assert stride(src, 1) == 1
assert stride(dst, 1) == 1
for i in seq(0, 4):
for j in seq(0,4):
dst[i,j] = src[i,j]
@instr('asm volatile("mmasa.w "{md_int}", "{ms1_int}", "{ms2_int});')
def rvm_mmasa(md: [i32][4,4] @ RVM_TILE, ms1: [i32][4,4] @ RVM_TILE, ms2: [i32][4,4] @ RVM_TILE):
assert stride(md, 1) == 1
assert stride(ms1, 1) == 1
assert stride(ms2, 1) == 1
for i in seq(0,4):
for j in seq(0,4):
for k in seq(0,4):
md[i,j] += ms2[i,k] * ms1[j,k]
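
# Note: mmasa accumulates md += ms2 @ ms1^T, i.e. the second source operand
# is read transposed (md[i, j] += ms2[i, k] * ms1[j, k]).
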
# TODO: convert the if/else to a bitwise select using these instructions.
# Look at the _Select class in builtins and try to extend it;
# there are some examples.

IW = 16
IC = 4
KW = 4
ICKW = IC * KW
OC = 16
TILE = 4
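
# Problem size: a 16-wide input with 4 channels, kernel width 4, and 16
# output channels; TILE matches the 4x4 RVM tile shape.
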
def gen_conv1d():
    K = OC
    W = IW
    C = IC
    R = KW

    @proc
    def generic_conv1d(
        data: i32[C, W],
        kernels: i32[K, C, R],
        out: i32[K, W],
    ):
        # zero out the result memory
        for k_init in seq(0, K):
            for i_init in seq(0, W):
                out[k_init, i_init] = 0
        # do the convolution
        for k in seq(0, K):
            for c in seq(0, C):
                for i in seq(0, W):
                    for r in seq(0, R):
                        y: i32
                        y = 0
                        # reads past the right edge stay zero (implicit padding)
                        if i + r < W:
                            y = data[c, i + r]
                        out[k, i] += kernels[k, c, r] * y

    return generic_conv1d
# def make_less_generic(p):
#     p = bind_expr(p, "kernels[_]", "y")
#     print(p)
#     exit(0)
def make_im2col_from_generic(p):
    # Let's start applying scheduling
    p = rename(p, "im2col_conv")
    p = reorder_loops(p, "c i")
    p = fuse(p, "for k_init in _:_", "for k in _:_")
    p = fuse(p, "for i_init in _:_", "for i in _:_")

    # add the tiles corresponding to the size of our systolic array
    p = divide_loop(p, "for k_init in _:_", TILE, new_iters=["tile_i", "i"], perfect=True)
    p = divide_loop(p, "for i_init in _:_", TILE, new_iters=["tile_j", "j"], perfect=True)
    p = reorder_loops(p, "i tile_j")

    # tile once again: we have 4 accumulator registers, so we want to
    # compute on 4 different output-channel tiles at once
    p = divide_loop(p, "for tile_i in _:_", 4, new_iters=["hi", "lo"], perfect=True)
    p = reorder_loops(p, "lo tile_j")

    # channels should be on the outside of each of these inner loops;
    # these will all be part of the RVM instructions themselves
    p = autofission(p, p.find("for c in _:_").before(), 3)
    p = reorder_loops(p, "j c")
    p = reorder_loops(p, "i c")
    p = reorder_loops(p, "lo c")

    # next, we can lift the y allocation up and out of the loop
    p = autolift_alloc(p, "y:i32", 4, keep_dims=True)
    p = set_memory(p, "y: _", DRAM_INTERLEAVED)

    # then, we can fission the loop correspondingly, separating what is
    # now a data-marshalling statement from the actual compute statement
    # into two subsequent loop nests
    p = autofission(p, p.find("out[_] += _").before(), 4)
    p = simplify(p)
    return p
def rvm_optimize(p):
    # Set up the data tile load
    p = stage_mem(p, "for lo in _:_ #1", f"y[0:{TILE}, 0:{TILE}]", "data_tile")
    p = set_memory(p, "data_tile", RVM_TILE)
    p = replace(p, "for i0 in _:_", rvm_mld)

    # Set up the kernel tile load
    p = stage_mem(
        p,
        "for i in _:_ #1",
        f"kernels[hi*{TILE}*4 + 4*lo:hi*{TILE}*4 + 4*lo+{TILE}, c, 0:{TILE}]",
        "kernel_tile",
    )
    p = set_memory(p, "kernel_tile", RVM_TILE)
    p = replace(p, "for i0 in _:_", rvm_mld)

    # Set up the output tile.
    # Here we are specifying all 4 output registers at once.
    p = stage_mem(
        p,
        "for c in _:_",
        f"out[hi*{TILE}*4: hi*{TILE}*4+{TILE}*4, tile_j*{TILE}:tile_j*{TILE}+{TILE}]",
        "output_tile",
    )
    # That means we need to divide the buffer further.
    p = set_memory(p, "output_tile", RVM_TILE)
    p = divide_dim(p, "output_tile:_", 0, 4)

    # In addition, the loops which do the clearing and storing of this
    # buffer should be tiled, so that the inner nests can be replaced
    # with the appropriate matrix instructions.
    p = divide_loop(p, "for i0 in _:_", 4, ["d_lo", "d_i"], perfect=True)
    p = divide_loop(p, "for i0 in _:_", 4, ["s_lo", "s_i"], perfect=True)

    # Fuse the output-zeroing loop with the generated stage_mem loop,
    # so that we can replace it with mzero.
    p = simplify(p)
    # The loops need to be adjacent, so get this alloc out of the way;
    # we need to simplify first, otherwise it's dependent on the index.
    p = lift_alloc(p, "output_tile: _")
    p = fuse(p, "for lo in _:_ #0", "for d_lo in _:_ #0")
    p = fuse(p, "for i in _:_ #0", "for d_i in _:_ #0")
    p = fuse(p, "for j in _:_ #0", "for i1 in _:_")
    p = sink_alloc(p, "output_tile: _")

    # Remove the assignment to out[].
    # TODO: Not a bug, but it's interesting that this is correct?
    # Exo must be able to prove that I don't read from out[] later on
    # in the program.
    p = inline_assign(p, "out[i + 4 * lo + 16 * hi, j + 4 * tile_j] = 0")
    p = replace(p, "for i in _:_ #0", rvm_mzero)

    # Set up the output tile store
    p = replace(p, "for s_i in _:_", rvm_mst)
    # Replace the compute loop with the matmul instruction
    p = replace(p, "for i in _:_", rvm_mmasa)

    # unroll everything!
    p = unroll_loop(p, "for lo in _:_")
    p = unroll_loop(p, "for s_lo in _:_")
    p = simplify(p)

    # some gymnastics to reuse the output of the load so we don't exceed
    # the 8 available tile registers
    p = autolift_alloc(p, "kernel_tile: _", keep_dims=True)
    p = unroll_loop(p, "for lo in _:_")
    p = unroll_buffer(p, "kernel_tile: _", 0)
    p = reorder_stmts(p, "kernel_tile_3: _; rvm_mld(_)")
    p = reuse_buffer(p, "kernel_tile_0: _", "kernel_tile_3: _")
    p = unroll_buffer(p, "output_tile: _", 0)
    p = simplify(p)

    # Done!
    p = simplify(p)
    p = rename(p, "exo_conv1d_tile_lt_kw")
    return p
def make_routine():
    generic_conv1d = gen_conv1d()
    # less_generic = make_less_generic(generic_conv1d)
    im2col_cpu = make_im2col_from_generic(generic_conv1d)
    rvm_optimized = rvm_optimize(im2col_cpu)
    return rvm_optimized


exo_conv1d_tile_lt_kw = make_routine()
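
# Usage sketch (hedged: the exocc invocation below is illustrative of the Exo
# compiler driver; the output directory and stem are assumptions, not taken
# from this gist):
#
#   exocc -o out --stem conv1d conv1d.py
#
# From Python, the scheduled procedure can also be inspected directly:
if __name__ == "__main__":
    print(exo_conv1d_tile_lt_kw)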