@sueszli
Last active March 4, 2026 12:33
================================================================================
memref explained through C
================================================================================
Docs: https://mlir.llvm.org/docs/Dialects/MemRef/
SUMMARY
-------
memref = fat pointer (data pointer + shape + strides)
memref.cast = erase/assert static size. Same pointer, same data.
memref.reinterpret_cast = reinterpret shape/strides. Same pointer, same data.
memref.subview = slice into a sub-region. Pointer moves, new shape/strides.
+------------------+----------------------------------------------+------+
| operation        | what changes                                 | cost |
+------------------+----------------------------------------------+------+
| memref.cast      | type only (static <-> dynamic sizes)         | zero |
| memref.reinterp  | .sizes + .strides (anything goes)            | zero |
| memref.subview   | .data (moves pointer) + .sizes + .strides    | zero |
+------------------+----------------------------------------------+------+
After lowering to pointers, each op becomes:
    memref.cast     -> erased (just use the source directly)
    memref.reinterp -> to_ptr + unrealized_conversion_cast (change type)
    memref.subview  -> to_ptr + ptradd + from_ptr (move pointer)
================================================================================
1. What is a memref?
================================================================================
A memref is a fat pointer: a raw data pointer bundled with shape and stride
metadata. We'll use this C struct throughout the entire document:
    typedef struct {
        void *data;          // pointer to the first element (includes any offset)
        int64_t sizes[N];    // size of each dimension (N = rank)
        int64_t strides[N];  // elements to skip per step in each dimension
    } memref_t;
Note: real MLIR memref descriptors also carry an offset and separate
allocated/aligned pointers. The convert-memref-to-ptr pass simplifies
this by baking the offset into the data pointer, so our struct omits it.
Examples:
    // memref<10xi32>
    memref_t src = {
        .data = malloc(10 * sizeof(int32_t)),
        .sizes = {10},
        .strides = {1},  // contiguous
    };

    // memref<3x4xf64>
    memref_t mat = {
        .data = malloc(3 * 4 * sizeof(double)),
        .sizes = {3, 4},
        .strides = {4, 1},  // row-major: skip 4 elements to go to next row
    };
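The {4, 1} strides above are not arbitrary: for a contiguous row-major buffer, each stride is the product of all sizes to its right. A minimal sketch of that rule follows; the helper name row_major_strides is hypothetical and not part of MLIR.

```c
#include <stdint.h>

/* Fill strides[] with contiguous row-major strides for the given sizes.
   The innermost dimension gets stride 1; each outer stride is the product
   of all inner sizes. (Hypothetical helper, for illustration only.) */
void row_major_strides(const int64_t *sizes, int rank, int64_t *strides) {
    int64_t running = 1;
    for (int i = rank - 1; i >= 0; i--) {
        strides[i] = running;
        running *= sizes[i];
    }
}
```

For sizes {3, 4} this yields strides {4, 1}, matching the memref<3x4xf64> example.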
================================================================================
2. Static vs dynamic dimensions
================================================================================
    // memref<10xi32>
    memref_t a = { .data=buf, .sizes={10}, .strides={1} };
    // memref<?xi32>
    memref_t b = { .data=buf, .sizes={n}, .strides={1} };
    //                                ^
    //                                runtime variable instead of constant
The '?' erases the static size from the type. The .data pointer is identical.
You just lose the compile-time guarantee about .sizes.
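The same trade-off shows up in plain C function signatures. The sketch below is only an analogy (the function names are made up): one function bakes the length into its parameter type, the other receives it as a runtime value.

```c
#include <stdint.h>

/* Static shape: the length 10 is part of the parameter type,
   loosely like memref<10xi32>. */
int64_t sum_static(const int32_t (*a)[10]) {
    int64_t total = 0;
    for (int i = 0; i < 10; i++) total += (*a)[i];
    return total;
}

/* Dynamic shape: the length travels as a runtime value,
   loosely like memref<?xi32>. */
int64_t sum_dynamic(const int32_t *a, int64_t n) {
    int64_t total = 0;
    for (int64_t i = 0; i < n; i++) total += a[i];
    return total;
}
```

Both functions read the same memory; only the place where the size lives differs (in the type vs. in an argument).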
================================================================================
3. Strides
================================================================================
Strides = "how many elements to skip when advancing one step along each axis."
// memref<3x4xi32> -- default strides (row-major)
// element(r,c) = data[r*4 + c*1]
memref_t a = { .data=buf, .sizes={3,4}, .strides={4,1} };
// memref<3x4xi32, strided<[1,3]>> -- column-major strides
// element(r,c) = data[r*1 + c*3]
memref_t b = { .data=buf, .sizes={3,4}, .strides={1,3} };
General formula: address = data + sum(index_i * stride_i) * sizeof(element)
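The general formula can be sketched directly in C. The function names below are hypothetical, and the rank is passed explicitly because our memref_t leaves N abstract.

```c
#include <stddef.h>
#include <stdint.h>

/* Linear offset in elements: sum(index_i * stride_i). */
int64_t linear_offset(const int64_t *indices, const int64_t *strides, int rank) {
    int64_t off = 0;
    for (int i = 0; i < rank; i++) off += indices[i] * strides[i];
    return off;
}

/* Byte address: data + offset * sizeof(element). */
void *element_address(void *data, const int64_t *indices,
                      const int64_t *strides, int rank, size_t elem_size) {
    int64_t off = linear_offset(indices, strides, rank);
    return (char *)data + off * (int64_t)elem_size;
}
```

With row-major strides {4, 1}, indices (2, 1) give 2*4 + 1*1 = 9 elements, i.e. 36 bytes past data for a 4-byte element.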
================================================================================
4. memref.cast
================================================================================
%dst = memref.cast %src : memref<10xi32> to memref<?xi32>
The memref_t struct analogy doesn't work here. Nothing in the struct changes.
Cast is purely a type annotation: it toggles static <-> dynamic sizes.
The closest C equivalent is int[10] vs int*:
int32_t buf[10]; // memref<10xi32> -- compiler knows size
int32_t *p = buf; // memref<?xi32> -- same pointer, size unknown
Why does memref.cast exist if we just drop it? It satisfies MLIR's type checker
(you can't pass memref<10xi32> where memref<?xi32> is expected). Once we lower to
raw pointers the shape is no longer part of the type, so the cast has nothing
left to do.
Lowering:
// Python:
rewriter.replace_matched_op((), (op.source,))
    // BEFORE:                                // C:
    %src = ... : memref<10xi32>               int32_t buf[10];
    %dst = memref.cast %src                   int32_t *p = buf;
    use(%dst)                                 use(p);

    // AFTER:                                 // C:
    %src = ... : memref<10xi32>               int32_t buf[10];
    use(%src)                                 use(buf);  // cast gone, use buf directly
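The example above only shows the "erase" direction. In the other direction (dynamic to static) the cast acts as a claim about the runtime size. The sketch below makes that claim explicit with assert(); this is an illustration of the semantics, not what the lowering in this document emits (it drops the cast without a check). The types dyn_i32_t and i32x10_t and both function names are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Rank-1 descriptor with a runtime size, standing in for memref<?xi32>. */
typedef struct {
    int32_t *data;
    int64_t size;
} dyn_i32_t;

/* Array type with the size in the type, standing in for memref<10xi32>. */
typedef int32_t i32x10_t[10];

/* memref<10xi32> to memref<?xi32>: forget the static size. */
dyn_i32_t erase_static_size(i32x10_t *buf) {
    dyn_i32_t d = { .data = *buf, .size = 10 };
    return d;
}

/* memref<?xi32> to memref<10xi32>: claim the size is 10.
   The pointer is returned unchanged; only the type gets stricter. */
i32x10_t *assert_static_size(dyn_i32_t d) {
    assert(d.size == 10);
    return (i32x10_t *)d.data;
}
```

In both directions the pointer value is untouched, which is why the cast can disappear once everything is a raw pointer.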
================================================================================
5. memref.reinterpret_cast
================================================================================
%dst = memref.reinterpret_cast %src
    to offset: [0], sizes: [3, 4], strides: [4, 1]
    : memref<12xindex> to memref<3x4xindex, strided<[4, 1]>>
What it does: takes the SAME .data pointer and builds a new struct with
completely different .sizes and .strides. No data moves.
C equivalent:
    memref_t src = { .data=buf, .sizes={12}, .strides={1} };

    // reinterpret: treat 12 flat elements as a 3x4 row-major matrix
    memref_t dst = {
        .data = src.data,   // same pointer
        .sizes = {3, 4},    // new shape
        .strides = {4, 1},  // new strides
    };
Lowering:
// Python:
    rewriter.replace_matched_op((
        ptr_cast := ptr.ToPtrOp(op.source),
        builtin.UnrealizedConversionCastOp.get([ptr_cast.res], [op.result.type]),
    ))
    // BEFORE:                                    // C:
    %src = ... : memref<12xindex>                 memref_t src = { .data=buf, .sizes={12}, .strides={1} };
    %dst = memref.reinterpret_cast %src ...       memref_t dst = { .data=src.data, .sizes={3,4}, .strides={4,1} };
    use(%dst)                                     use(dst);

    // AFTER:                                     // C:
    %src = ... : memref<12xindex>                 memref_t src = { .data=buf, .sizes={12}, .strides={1} };
    %ptr = to_ptr %src                            void *raw = src.data;  // extract .data, discard .sizes/.strides
    %dst = unrealized_conversion_cast %ptr        memref_t dst = { .data=raw, .sizes={3,4}, .strides={4,1} };
        -> memref<3x4xindex, strided<[4,1]>>
    use(%dst)                                     use(dst);
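To check that the reinterpreted view really aliases the flat buffer, here is a compilable sketch. It makes several assumptions that are not in the text above: the rank is capped by a made-up MAX_RANK so the struct has a concrete size, a rank field is added, `index` is modeled as int64_t, and the names reinterpret_cast_i64 and load2d are hypothetical.

```c
#include <stdint.h>

#define MAX_RANK 4  /* assumption: fixed upper bound so the struct compiles */

typedef struct {
    void *data;
    int rank;
    int64_t sizes[MAX_RANK];
    int64_t strides[MAX_RANK];
} memref_t;

/* Build a new descriptor over the same buffer. A non-zero offset
   (in elements) advances .data; offset 0 keeps the pointer as-is. */
memref_t reinterpret_cast_i64(memref_t src, int64_t offset, int rank,
                              const int64_t *sizes, const int64_t *strides) {
    memref_t dst = { .data = (int64_t *)src.data + offset, .rank = rank };
    for (int i = 0; i < rank; i++) {
        dst.sizes[i] = sizes[i];
        dst.strides[i] = strides[i];
    }
    return dst;
}

/* Load element (r, c) of a rank-2 view using its strides. */
int64_t load2d(memref_t m, int64_t r, int64_t c) {
    return ((int64_t *)m.data)[r * m.strides[0] + c * m.strides[1]];
}
```

If the flat buffer holds 0..11, element (2, 1) of the 3x4 view reads flat index 2*4 + 1 = 9. No element was copied to get there.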
================================================================================
6. memref.subview
================================================================================
%sub = memref.subview %buf[5][5][1] : memref<10xi32> to memref<5xi32, strided<[1], offset: 5>>
Of the three ops, this is the one built for slicing: it moves .data to the start
of the window and shrinks .sizes.
(reinterpret_cast can also move .data via a non-zero offset, but our example used offset=0.)
C equivalent:
    // buf.data: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    //                           ^-----------^
    //                           sub sees this window
    memref_t buf = { .data=malloc(10 * sizeof(int32_t)), .sizes={10}, .strides={1} };
    memref_t sub = {
        .data = (int32_t*)buf.data + 5,  // .data moved forward by 5 elements
        .sizes = {5},
        .strides = {1},
    };
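The example hard-codes offset 5 and step 1. A slightly more general rank-1 sketch is below, assuming the usual subview arithmetic: the pointer advances by offset * source stride, and the new stride is source stride * step. The struct memref1d_i32_t and the function subview1d are hypothetical names used only for this illustration.

```c
#include <stdint.h>

/* Rank-1 descriptor specialized to i32 so pointer arithmetic is in elements. */
typedef struct {
    int32_t *data;
    int64_t size;
    int64_t stride;
} memref1d_i32_t;

/* Models memref.subview %src[offset][size][step] for rank 1. */
memref1d_i32_t subview1d(memref1d_i32_t src, int64_t offset,
                         int64_t size, int64_t step) {
    memref1d_i32_t sub = {
        .data = src.data + offset * src.stride,  /* pointer moves */
        .size = size,                            /* window length */
        .stride = src.stride * step,             /* steps compose */
    };
    return sub;
}
```

subview1d(buf, 5, 5, 1) reproduces the window above; subview1d(buf, 0, 5, 2) would instead view every second element without moving the pointer.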