@UmerHA
Last active January 18, 2025 23:55
Memory safety in Triton
{
"cells": [
{
"cell_type": "markdown",
"id": "db52a164-16a0-43a2-a107-5dba7868fa32",
"metadata": {},
"source": [
"**Question** (edited for brevity)"
]
},
{
"cell_type": "markdown",
"id": "2ec37f77-0430-4279-8048-e9fefd77e7e8",
"metadata": {},
"source": [
"Dear Umer,\n",
"\n",
"Thank you for your awesome Triton tutorial! I learned a lot from watching your video. However, I have a question that I’ve been thinking about but haven’t found much discussion on. It’s regarding the `naive_matmul_k` function in the notebook. To make it easier for you to read, I’ve included the relevant code below:\n",
"\n",
"```python\n",
"@triton.jit\n",
"def naive_matmul_k(\n",
"    a_ptr, b_ptr, c_ptr,\n",
"    m, n, k,\n",
"    stride_am, stride_ak,\n",
"    stride_bk, stride_bn,\n",
"    stride_cm, stride_cn,\n",
"    bm: tl.constexpr, bn: tl.constexpr, bk: tl.constexpr\n",
"):\n",
"    pid_m, pid_n = tl.program_id(0), tl.program_id(1)\n",
"    # chunks along m/n/k dimensions\n",
"    rm = get_1d_offset(size=bm, n_prev_chunks=pid_m)\n",
"    rn = get_1d_offset(size=bn, n_prev_chunks=pid_n)\n",
"    rk = get_1d_offset(size=bk, n_prev_chunks=0)\n",
"    # relevant offsets of a, b\n",
"    offs_a = a_ptr + get_2d_offset(rm, rk, stride_am, stride_ak)\n",
"    offs_b = b_ptr + get_2d_offset(rk, rn, stride_bk, stride_bn)\n",
"    # initialize and iteratively update accumulator\n",
"    acc = tl.zeros((bm, bn), dtype=tl.float32)\n",
"    for _ in range(0, k, bk):\n",
"        # todo umer: don't we need mask when loading a & b?\n",
"        a = tl.load(offs_a)\n",
"        b = tl.load(offs_b)\n",
"        acc += tl.dot(a, b, allow_tf32=False) # matmul in block; Weirdness: allow_tf32 must be set to False for older GPUs, otherwise won't compile\n",
"        # increase offsets, so next iteration loads next chunks\n",
"        offs_a += bk * stride_ak\n",
"        offs_b += bk * stride_bk\n",
"    c = c_ptr + get_2d_offset(rm, rn, stride_cm, stride_cn)\n",
"    mask = get_2d_mask(rm, rn, m, n)\n",
"    tl.store(c, acc, mask=mask)\n",
"```\n",
"\n",
"You raise the question: \"Don't we need a mask when loading a and b?\" I believe we do need a mask when loading a and b, because otherwise we may go out of bounds and access memory beyond the matrix boundary.\n",
"\n",
"My understanding is that the mask used in the `tl.store` prevents writing results to out-of-bounds memory. This would ensure that we don't write any unintended values to memory, right?\n",
"\n",
"However, Triton doesn't seem to throw an error when memory is accessed out of bounds, so I’m a bit confused.\n",
"\n",
"I’d really appreciate it if you could clarify this for me.\n",
"\n",
"Thank you so much!\n",
"\n",
"Best regards,\n",
"Jiaxing"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ca0a1c5f-05c0-4d02-bede-18c2876ef28c",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "cc727a4f-db0b-49f5-9e36-e76125734103",
"metadata": {},
"source": [
"**Answer**"
]
},
{
"cell_type": "markdown",
"id": "11dc5941-338b-4bee-b204-7ba09e13e0fd",
"metadata": {},
"source": [
"Hi Jiaxing, very good question!\n",
"\n",
"You're correct that the mask in `tl.store` prevents writing to out-of-bounds memory.\n",
"\n",
"Triton is a low-level language and, like other low-level languages (e.g. CUDA, C), doesn't enforce memory safety.<br/>\n",
"Physically, computer memory is nothing more than a very long row of locations (\"addresses\") that contain values.<br/>\n",
"In memory-safe languages (e.g. Python), each time you use a memory location, the language first checks whether that location actually belongs to your program. If not, it throws an error. This is super convenient, but also slower because of the extra checks.<br/>\n",
"In non-memory-safe languages there are no such checks, so your code is faster, but it's your responsibility to only access locations belonging to your program."
]
},
{
"cell_type": "markdown",
"id": "2d7e425e-182e-4e3b-835e-70bc03ba7342",
"metadata": {},
"source": [
"To give a toy example: Say we have a tiny memory with only 4 locations, where the first 2 belong to your program, and the last 2 to another program.\n",
"```\n",
"Memory address   [    0,     1,     2,     3]\n",
"Memory value     [   27,    18,    -9,   3.4]\n",
"Memory ownership [  you,   you, other, other]\n",
"```\n",
"And let's say you have a variable `x` which points to location `0`, so `x[0]` is `27` and `x[1]` is `18`.\n",
"\n",
"You _can_ read from address 2, i.e. do `x[2]`, and you will get a value (`-9`), but you should never do this, as that value is effectively random from your program's point of view.\n",
"And you _can_ write to address 2, but you should never do this either, as you'd destroy another program's data."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "81a08228-b980-4e71-bbdc-14c2eb41c912",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ['TRITON_INTERPRET'] = '1' # run kernels in Triton's interpreter, so we can use print() inside them\n",
"import torch, triton, triton.language as tl"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2a9151ab-c6a2-4368-87be-bb77723e6fb3",
"metadata": {},
"outputs": [],
"source": [
"@triton.jit\n",
"def show(x_ptr, n: tl.constexpr):\n",
"    offs = tl.arange(0, n+1) # we're reading 1 value out of bounds! (tl.arange needs a power-of-2 length, so n+1 = 4 works here)\n",
"    print(tl.load(x_ptr + offs))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9013b706-2bf7-4c53-89bf-f371d19e1940",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([2, 2, 2], device='cuda:0', dtype=torch.int16)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = torch.tensor([2,2,2], device='cuda', dtype=torch.int16); x"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "324a329d-4846-4147-bcb9-318d574a5e33",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2 2 2 0]\n",
"[2 2 2 0]\n",
"[2 2 2 0]\n",
"[2 2 2 0]\n",
"[ 2 2 2 10095]\n",
"[2 2 2 0]\n",
"[2 2 2 0]\n",
"[2 2 2 0]\n",
"[2 2 2 0]\n",
"[2 2 2 0]\n",
"[2 2 2 0]\n",
"[2 2 2 0]\n",
"[2 2 2 0]\n",
"[2 2 2 0]\n",
"[2 2 2 0]\n",
"[2 2 2 0]\n",
"[2 2 2 0]\n",
"[2 2 2 0]\n",
"[2 2 2 0]\n",
"[2 2 2 0]\n"
]
}
],
"source": [
"for _ in range(20): show[(1,)](x,x.numel())"
]
},
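{
"cell_type": "markdown",
"id": "35fee7cc-f035-476c-b46a-c1022c70faa7",
"metadata": {},
"source": [
"We see: the 4th value can be read without any error, but it is effectively random (here mostly `0`, once `10095`). So yes, the loads in `naive_matmul_k` need a mask too."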
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b916fd3-4699-48e2-b353-22d51cd127ed",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
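To make the toy memory example concrete, here is a minimal pure-Python sketch of the same 4-slot memory. It's a model under simplifying assumptions, not how Triton or the GPU work internally: memory is one flat list, a "pointer" is just an index into it, and `unchecked_load`, `masked_load` and `masked_store` are hypothetical helper names for illustration. In actual Triton code you get the masked behaviour by passing `mask=` to `tl.load` / `tl.store` (plus `other=` on loads as the fill value for masked-out lanes), e.g. `tl.load(x_ptr + offs, mask=offs < n, other=0)` in `show`.

```python
# Toy model (hypothetical, not Triton's implementation): memory is one flat
# list and a "pointer" is an index into it. Nothing ties a pointer to a length.
MEMORY = [27, 18, -9, 3.4]   # addresses 0..3; addresses 2 and 3 belong to "other"

def unchecked_load(base, offsets):
    """Mimics an unmasked tl.load: reads every requested address."""
    return [MEMORY[base + o] for o in offsets]

def masked_load(base, offsets, n, other=0):
    """Mimics tl.load(ptr + offs, mask=offs < n, other=other): lanes whose
    offset is out of bounds are not read and get `other` instead."""
    return [MEMORY[base + o] if o < n else other for o in offsets]

def masked_store(base, offsets, values, n):
    """Mimics tl.store(ptr + offs, values, mask=offs < n): out-of-bounds
    lanes are skipped, so the other program's data stays untouched."""
    for o, v in zip(offsets, values):
        if o < n:
            MEMORY[base + o] = v

x, n = 0, 2                  # x points to address 0; we own n = 2 slots
offs = range(n + 1)          # one offset too many, like in `show`

print(unchecked_load(x, offs))       # [27, 18, -9]  (-9 belongs to "other")
print(masked_load(x, offs, n))       # [27, 18, 0]
masked_store(x, offs, [5, 5, 5], n)
print(MEMORY)                        # [5, 5, -9, 3.4]
```

For `naive_matmul_k` this suggests masking both loads, along m/n and along k, with `other=0.` for the masked-out lanes, since a zero-padded entry contributes nothing to `tl.dot`. The k-direction mask matters for correctness even though the store is masked: with an unmasked k-tail, garbage from beyond column `k` of `a` (and row `k` of `b`) would be summed into entries of `acc` that the store mask does write out.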
@tomasruizt

Nice demonstration! @UmerHA
