Last active
March 4, 2026 12:33
| ================================================================================ | |
| memref explained through C | |
| ================================================================================ | |
| Docs: https://mlir.llvm.org/docs/Dialects/MemRef/ | |
| SUMMARY | |
| ------- | |
| memref = fat pointer (data pointer + shape + strides) | |
| memref.cast = erase/assert static size. Same pointer, same data. | |
| memref.reinterpret_cast = reinterpret shape/strides. Same pointer (at offset 0), same data. | |
| memref.subview = slice into a sub-region. Pointer moves, new shape/strides. | |
| +------------------+----------------------------------------------+------+ | |
| | operation | what changes | cost | | |
| +------------------+----------------------------------------------+------+ | |
| | memref.cast | .sizes (static <-> dynamic) | zero | | |
| | memref.reinterp | .sizes + .strides (anything goes) | zero | | |
| | memref.subview | .data (moves pointer) + .sizes + .strides | zero | | |
| +------------------+----------------------------------------------+------+ | |
| After lowering to pointers, each op becomes: | |
| memref.cast -> erased (just use the source directly) | |
| memref.reinterp -> to_ptr + unrealized_conversion_cast (change type) | |
| memref.subview -> to_ptr + ptradd + from_ptr (move pointer) | |
| ================================================================================ | |
| 1. What is a memref? | |
| ================================================================================ | |
| A memref is a fat pointer: a raw data pointer bundled with shape and stride | |
| metadata. We'll use this C struct throughout the entire document: | |
| typedef struct { | |
| void *data; // pointer to the first element (includes any offset) | |
| int64_t sizes[N]; // size of each dimension | |
| int64_t strides[N]; // elements to skip per step in each dimension | |
| } memref_t; | |
| Note: real MLIR memref descriptors also carry an offset and separate | |
| allocated/aligned pointers. The convert-memref-to-ptr pass simplifies | |
| this by baking the offset into the data pointer, so our struct omits it. | |
| Examples: | |
| // memref<10xi32> | |
| memref_t src = { | |
| .data = malloc(10 * sizeof(int32_t)), | |
| .sizes = {10}, | |
| .strides = {1}, // contiguous | |
| }; | |
| // memref<3x4xf64> | |
| memref_t mat = { | |
| .data = malloc(3 * 4 * sizeof(double)), | |
| .sizes = {3, 4}, | |
| .strides = {4, 1}, // row-major: skip 4 elements to go to next row | |
| }; | |
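The struct above leaves N open. A minimal compilable sketch, assuming a fixed maximum rank and an explicit rank field (MEMREF_MAX_RANK, the rank member, and memref_row_major are names invented here for illustration, not part of MLIR):

```c
#include <assert.h>
#include <stdint.h>

#define MEMREF_MAX_RANK 4  /* assumption: fixed upper bound on rank */

typedef struct {
    void   *data;                      /* pointer to the first element */
    int     rank;                      /* number of dimensions in use (added for this sketch) */
    int64_t sizes[MEMREF_MAX_RANK];    /* size of each dimension */
    int64_t strides[MEMREF_MAX_RANK];  /* elements to skip per step in each dimension */
} memref_t;

/* Build a contiguous row-major descriptor over an existing buffer. */
memref_t memref_row_major(void *data, int rank, const int64_t *sizes) {
    assert(rank >= 1 && rank <= MEMREF_MAX_RANK);
    memref_t m = { .data = data, .rank = rank };
    int64_t stride = 1;
    for (int d = rank - 1; d >= 0; --d) {  /* innermost dimension first */
        m.sizes[d] = sizes[d];
        m.strides[d] = stride;
        stride *= sizes[d];                /* next-outer stride spans this whole dimension */
    }
    return m;
}
```

Calling memref_row_major(buf, 2, (int64_t[]){3, 4}) fills in strides {4, 1}, matching the memref<3x4xf64> example.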
| ================================================================================ | |
| 2. Static vs dynamic dimensions | |
| ================================================================================ | |
| // memref<10xi32> | |
| memref_t a = { .data=buf, .sizes={10}, .strides={1} }; | |
| // memref<?xi32> | |
| memref_t b = { .data=buf, .sizes={n}, .strides={1} }; | |
| // ^ | |
| // runtime variable instead of constant | |
| The '?' erases the static size from the type. The .data pointer is identical. | |
| You just lose the compile-time guarantee about .sizes. | |
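The same trade-off shows up in plain C function signatures. A small sketch (sum_static and sum_dynamic are hypothetical helpers): the first signature carries the extent in its type, the second needs it passed alongside the pointer.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Like memref<10xi32>: the extent is part of the parameter type. */
int32_t sum_static(int32_t (*a)[10]) {
    int32_t total = 0;
    for (size_t i = 0; i < 10; ++i)  /* bound is a compile-time constant */
        total += (*a)[i];
    return total;
}

/* Like memref<?xi32>: the extent travels as a runtime value. */
int32_t sum_dynamic(const int32_t *a, size_t n) {
    int32_t total = 0;
    for (size_t i = 0; i < n; ++i)   /* bound is only known at runtime */
        total += a[i];
    return total;
}
```

Both functions read the same memory; only the place where the size lives differs.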
| ================================================================================ | |
| 3. Strides | |
| ================================================================================ | |
| Strides = "how many elements to skip when advancing one step along each axis." | |
| // memref<3x4xi32> -- default strides (row-major) | |
| // element(r,c) = data[r*4 + c*1] | |
| memref_t a = { .data=buf, .sizes={3,4}, .strides={4,1} }; | |
| // memref<3x4xi32, strided<[1,3]>> -- column-major strides | |
| // element(r,c) = data[r*1 + c*3] (one column holds 3 elements, so the next column starts 3 further on) | |
| memref_t b = { .data=buf, .sizes={3,4}, .strides={1,3} }; | |
| General formula: address = data + sum(index_i * stride_i) * sizeof(element) | |
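The general formula can be written out directly. A sketch, with linear_offset and element_address as hypothetical helper names; the offset is computed in elements and scaled to bytes only at the end:

```c
#include <assert.h>
#include <stdint.h>

/* Offset of an index tuple, in elements (not bytes). */
int64_t linear_offset(int rank, const int64_t *strides, const int64_t *index) {
    int64_t off = 0;
    for (int d = 0; d < rank; ++d)
        off += index[d] * strides[d];  /* sum(index_i * stride_i) */
    return off;
}

/* Byte address = data + element offset * element size. */
void *element_address(void *data, int rank, const int64_t *strides,
                      const int64_t *index, int64_t elem_size) {
    return (char *)data + linear_offset(rank, strides, index) * elem_size;
}
```

With strides {4, 1}, index (2, 1) gives offset 2*4 + 1*1 = 9.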
| ================================================================================ | |
| 4. memref.cast | |
| ================================================================================ | |
| %dst = memref.cast %src : memref<10xi32> -> memref<?xi32> | |
| The memref_t struct analogy doesn't work here. Nothing in the struct changes. | |
| Cast is purely a type annotation: it toggles static <-> dynamic sizes. | |
| The closest C equivalent is int[10] vs int*: | |
| int32_t buf[10]; // memref<10xi32> -- compiler knows size | |
| int32_t *p = buf; // memref<?xi32> -- same pointer, size unknown | |
| Why does memref.cast exist if we just drop it? It satisfies MLIR's type checker | |
| (can't pass memref<10xi32> where memref<?xi32> is expected). Once we lower to | |
| raw pointers the shape is no longer part of the type, so the cast has nothing left to do. | |
| Lowering: | |
| // Python: | |
| rewriter.replace_matched_op((), (op.source,)) | |
| // BEFORE: // C: | |
| %src = ... : memref<10xi32> int32_t buf[10]; | |
| %dst = memref.cast %src int32_t *p = buf; | |
| use(%dst) use(p); | |
| // AFTER: // C: | |
| %src = ... : memref<10xi32> int32_t buf[10]; | |
| use(%src) use(buf); // cast gone, use buf directly | |
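Both directions of the cast can be mimicked in C. A sketch under the int[10] vs int* analogy (vec10_t, cast_to_dynamic, and cast_to_static are names made up here); the assert stands in for the runtime check that the memref.cast documentation describes for the dynamic-to-static direction:

```c
#include <assert.h>
#include <stdint.h>

typedef int32_t vec10_t[10];  /* stands in for memref<10xi32> */

/* memref<10xi32> -> memref<?xi32>: forget the static extent. */
int32_t *cast_to_dynamic(vec10_t *buf) {
    return *buf;  /* array decays to a pointer; the address is unchanged */
}

/* memref<?xi32> -> memref<10xi32>: assert the static extent. */
vec10_t *cast_to_static(int32_t *p, int64_t runtime_size) {
    assert(runtime_size == 10);  /* sketch of the runtime size check */
    return (vec10_t *)p;         /* same address, richer type */
}
```

Neither function touches the data or computes a new address, which is why the lowering can simply forward the source value.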
| ================================================================================ | |
| 5. memref.reinterpret_cast | |
| ================================================================================ | |
| %dst = memref.reinterpret_cast %src | |
| offsets: [0], sizes: [3, 4], strides: [4, 1] | |
| : memref<12xindex> -> memref<3x4xindex, strided<[4, 1]>> | |
| What it does: takes the SAME .data pointer and builds a new struct with | |
| completely different .sizes and .strides. No data moves. | |
| C equivalent: | |
| memref_t src = { .data=buf, .sizes={12}, .strides={1} }; | |
| // reinterpret: treat 12 flat elements as a 3x4 row-major matrix | |
| memref_t dst = { | |
| .data = src.data, // same pointer | |
| .sizes = {3, 4}, // new shape | |
| .strides = {4, 1}, // new strides | |
| }; | |
| Lowering: | |
| // Python: | |
| rewriter.replace_matched_op(( | |
| ptr_cast := ptr.ToPtrOp(op.source), | |
| builtin.UnrealizedConversionCastOp.get([ptr_cast.res], [op.result.type]), | |
| )) | |
| // BEFORE: // C: | |
| %src = ... : memref<12xindex> memref_t src = { .data=buf, .sizes={12}, .strides={1} }; | |
| %dst = memref.reinterpret_cast %src ... memref_t dst = { .data=src.data, .sizes={3,4}, .strides={4,1} }; | |
| use(%dst) use(dst); | |
| // AFTER: // C: | |
| %src = ... : memref<12xindex> memref_t src = { .data=buf, .sizes={12}, .strides={1} }; | |
| %ptr = to_ptr %src void *raw = src.data; // extract .data, discard .sizes/.strides | |
| %dst = unrealized_conversion_cast %ptr memref_t dst = { .data=raw, .sizes={3,4}, .strides={4,1} }; | |
| -> memref<3x4xindex, strided<[4,1]>> | |
| use(%dst) use(dst); | |
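To check that the reinterpreted view really aliases the flat buffer, here is a runnable sketch. It assumes a rank-2 variant of the struct (memref2d_t) and uses int64_t in place of MLIR's index type; reinterpret_2d and load_2d are hypothetical helpers.

```c
#include <assert.h>
#include <stdint.h>

typedef struct {
    void   *data;
    int64_t sizes[2];
    int64_t strides[2];
} memref2d_t;  /* rank-2 instance of the memref_t sketch */

/* Reuse `data` unchanged; only the shape/stride metadata is new. */
memref2d_t reinterpret_2d(void *data, int64_t rows, int64_t cols,
                          int64_t row_stride, int64_t col_stride) {
    memref2d_t dst = {
        .data    = data,                      /* same pointer */
        .sizes   = {rows, cols},              /* new shape */
        .strides = {row_stride, col_stride},  /* new strides */
    };
    return dst;
}

/* Load element (r, c) from a view over int64_t elements. */
int64_t load_2d(memref2d_t m, int64_t r, int64_t c) {
    return ((int64_t *)m.data)[r * m.strides[0] + c * m.strides[1]];
}
```

For a flat buffer holding 0..11, reinterpret_2d(flat, 3, 4, 4, 1) makes element (1, 2) read flat[1*4 + 2] = flat[6].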
| ================================================================================ | |
| 6. memref.subview | |
| ================================================================================ | |
| %sub = memref.subview %buf[5][5][1] : memref<10xi32> to memref<5xi32> | |
| The view op whose job is to move .data and shrink .sizes. | |
| (reinterpret_cast can also move .data via a non-zero offset, but our example used offset=0.) | |
| C equivalent: | |
| // buf.data: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] | |
| // ^-----------^ | |
| // sub sees this window | |
| memref_t buf = { .data=malloc(10 * sizeof(int32_t)), .sizes={10}, .strides={1} }; | |
| memref_t sub = { | |
| .data = (int32_t*)buf.data + 5, // .data moved forward by 5 elements | |
| .sizes = {5}, | |
| .strides = {1}, | |
| }; | |
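The same computation as a reusable function, generalized to an arbitrary offset, size, and step. This is a sketch for rank-1 i32 buffers only; memref1d_t and subview_1d_i32 are names invented here, and the bounds asserts are this sketch's own sanity checks rather than something the op is claimed to perform.

```c
#include <assert.h>
#include <stdint.h>

typedef struct {
    void   *data;
    int64_t sizes[1];
    int64_t strides[1];
} memref1d_t;  /* rank-1 instance of the memref_t sketch */

/* Models memref.subview %src[offset][size][step] on an i32 buffer. */
memref1d_t subview_1d_i32(memref1d_t src, int64_t offset, int64_t size, int64_t step) {
    assert(offset >= 0 && size >= 0 && step >= 1);
    assert(size == 0 || offset + (size - 1) * step < src.sizes[0]);
    memref1d_t sub = {
        /* pointer moves forward by `offset` source elements */
        .data    = (int32_t *)src.data + offset * src.strides[0],
        .sizes   = {size},
        /* one step in the view skips `step` source elements */
        .strides = {src.strides[0] * step},
    };
    return sub;
}
```

subview_1d_i32(buf, 5, 5, 1) reproduces the example: .data points at element 5, .sizes is {5}, and .strides stays {1}.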