The main inspiration comes from here.
""" Here is what a deep learning system stack looks like nowadays:
- Operator-level graph description languages: name whatever DL frameworks you care about, plus ONNX
- Tensor primitive level graph description languages: NNVM, HLO/XLA, nGraph. This layer is close enough to the first that you can also build graph optimizations on the first layer and bypass this one.
- DSLs for description and codegen: TVM, and image processing languages like Halide and Darkroom
- Hardcoded optimized kernel libraries: NNPACK, cuDNN, libdnn
- Device dependent libraries: maxas (an assembler for the NVIDIA Maxwell architecture) """
Now let's consider convolving an averaging kernel over an image, i.e. blurring. The following shows what the code looks like at each level.
Level 1 and 2: operator/tensor primitive level. The conv operator is already provided, so a blur is a single call (pseudocode):

image = load_image()
average_kernel = np.full((3, 3), 1.0 / 9)
blurred = conv(image, average_kernel)
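Since `load_image` and `conv` above are framework pseudocode, here is a minimal runnable sketch of the same 3x3 box blur using only NumPy; the edge-replication boundary policy is an assumption, since the pseudocode leaves boundaries unspecified:

```python
import numpy as np

def box_blur(image):
    """3x3 averaging (box) filter; edges are handled by replicating border pixels."""
    padded = np.pad(image, 1, mode='edge')
    out = np.zeros_like(image, dtype=np.float64)
    # Sum the nine shifted copies of the image, then divide by 9
    for dy in (-1, 0, 1):
        for dx in (-1, 0, 1):
            out += padded[1 + dy:1 + dy + image.shape[0],
                          1 + dx:1 + dx + image.shape[1]]
    return out / 9.0

image = np.arange(25, dtype=np.float64).reshape(5, 5)
blurred = box_blur(image)
```

At this level the user never sees loops or memory layout; that is exactly what the lower levels of the stack decide.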
Level 3: DSL for description and codegen. Take Halide for example: a user writes both
- the definition of the algorithm
- the schedule: how storage is laid out (tile, vectorize) and in what order computation runs (parallel)
Func halide_blur(Func in) {
  Func tmp, blurred;
  Var x, y, xi, yi;

  // The algorithm: a 3x3 box blur, split into a horizontal and a vertical pass
  tmp(x, y) = (in(x-1, y) + in(x, y) + in(x+1, y)) / 3;
  blurred(x, y) = (tmp(x, y-1) + tmp(x, y) + tmp(x, y+1)) / 3;

  // The schedule: tile the output, vectorize within tiles, parallelize across rows
  blurred.tile(x, y, xi, yi, 256, 32)
         .vectorize(xi, 8)
         .parallel(y);
  tmp.chunk(x).vectorize(x, 8);  // chunk() is the older Halide API; roughly compute_at() today

  return blurred;
}
Level 4: Hardcoded optimized kernel. A user needs to hand-code vectorization, multithreading, tiling, and fusion, as in this SSE implementation of the same blur:
void fast_blur(const Image &in, Image &blurred) {
  __m128i one_third = _mm_set1_epi16(21846);  // fixed-point 1/3: 21846 ≈ 2^16 / 3
  #pragma omp parallel for
  for (int yTile = 0; yTile < in.height(); yTile += 32) {
    __m128i a, b, c, sum, avg;
    __m128i tmp[(256/8)*(32+2)];
    for (int xTile = 0; xTile < in.width(); xTile += 256) {
      __m128i *tmpPtr = tmp;
      // Horizontal pass over the tile, plus one row above and below
      for (int y = -1; y < 32+1; y++) {
        const uint16_t *inPtr = &(in(xTile, yTile+y));
        for (int x = 0; x < 256; x += 8) {
          a = _mm_loadu_si128((__m128i*)(inPtr-1));
          b = _mm_loadu_si128((__m128i*)(inPtr+1));
          c = _mm_load_si128((__m128i*)(inPtr));
          sum = _mm_add_epi16(_mm_add_epi16(a, b), c);
          avg = _mm_mulhi_epi16(sum, one_third);
          _mm_store_si128(tmpPtr++, avg);
          inPtr += 8;
        }
      }
      tmpPtr = tmp;
      // Vertical pass over the temporary buffer
      for (int y = 0; y < 32; y++) {
        __m128i *outPtr = (__m128i *)(&(blurred(xTile, yTile+y)));
        for (int x = 0; x < 256; x += 8) {
          a = _mm_load_si128(tmpPtr+(2*256)/8);
          b = _mm_load_si128(tmpPtr+256/8);
          c = _mm_load_si128(tmpPtr++);
          sum = _mm_add_epi16(_mm_add_epi16(a, b), c);
          avg = _mm_mulhi_epi16(sum, one_third);
          _mm_store_si128(outPtr++, avg);
        }
      }
    }
  }
}
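One non-obvious trick above: rather than dividing each 16-bit sum by 3, the kernel multiplies by the fixed-point constant 21846 ≈ 2^16/3 and keeps the high 16 bits, which is what _mm_mulhi_epi16 computes. A scalar sketch of the same trick:

```python
def div3_fixed(x):
    # (x * 21846) >> 16 approximates x // 3. Since 3 * 21846 = 65538,
    # the result is x/3 + 2x/(3 * 65536); the error term stays below 1
    # for the sums that occur here, so the floor matches x // 3 exactly.
    return (x * 21846) >> 16
```

This replaces an integer division (which SSE2 has no instruction for) with one multiply per 8 lanes, and it is exact for every possible sum of three 8-bit pixels (0 through 765).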
Level 5: Device dependent library, usually written directly in assembly. One example here.