Skip to content

Instantly share code, notes, and snippets.

@antoinebrl
Last active March 6, 2025 12:01
Show Gist options
  • Save antoinebrl/36e6297d08444e425f54f55e6968ac01 to your computer and use it in GitHub Desktop.
Save antoinebrl/36e6297d08444e425f54f55e6968ac01 to your computer and use it in GitHub Desktop.
Benchmark attribute access - `dict` vs `dataclass` vs `namedtuple` vs `tensordict` vs `tensorclass`

Introduction

This benchmark assesses the efficiency of attribute access across various data structures in Python. The script introduces several ways to represent 3D points, using built-in structures (namedtuples, dataclasses both with and without slots, standard classes, enums, and dictionaries) as well as `TensorDict` and `TensorClass` from the `tensordict` library. For TensorDict and TensorClass, certain tests separate the object creation from the timed segments.

Each data structure undergoes evaluation through functions made to access the coordinates within a loop. These functions are executed multiple times, and the entire test suite is repeated multiple times too.

Inspired by https://gist.github.com/wolph/02fae0b20b914354734aaac01c06d23b

Python 3.10

$ uv run --no-project --python 3.10 --with tensordict==0.7.2 benchmark-attribute-access.py
slots                    0.267s
dataclass_slots          0.269s
namedtuple_index         0.316s
dict                     0.317s
dict_tensor              0.320s
dataclass                0.321s
namedtuple_attr          0.518s
namedtuple_unpack        0.798s
enum_attr                2.014s
enum_item                2.892s
enum_call                7.421s
tensordict_int_noinit    21.804s
tensordict_int           21.975s
tensorclass_int_noinit   22.072s
tensorclass              22.228s
tensorclass_int          22.290s
tensorclass_autocast     23.070s

Python 3.12

$ uv run --no-project --python 3.12 --with tensordict==0.7.2 benchmark-attribute-access.py
slots                    0.164s
dataclass_slots          0.165s
dataclass                0.176s
dict                     0.239s
dict_tensor              0.244s
enum_attr                0.262s
namedtuple_index         0.289s
namedtuple_attr          0.293s
namedtuple_unpack        0.686s
enum_item                0.737s
enum_call                3.570s
tensorclass_int_noinit   7.878s
tensorclass_autocast     8.091s
tensorclass_int          8.118s
tensorclass              8.118s
tensordict_int_noinit    11.428s
tensordict_int           11.587s
import enum
import random
import timeit
import typing
import dataclasses
import collections
from tensordict import TensorDict, TensorClass
from torch import Tensor
# Number of times the whole benchmark suite is run; totals are averaged over this.
repeat = 5
# Number of times timeit invokes each test function per run.
calls = 1000
# Number of loop iterations performed inside every test function call.
iterations = 5000
###
# Data structures
###
class PointTuple(typing.NamedTuple):
    """3D point represented as a typed named tuple."""

    x: int
    y: int
    z: int
@dataclasses.dataclass
class PointDataclass:
    """3D point represented as a regular (dict-backed) dataclass."""

    x: int
    y: int
    z: int
@dataclasses.dataclass(slots=True)
class PointDataclassSlots:
x: int
y: int
z: int
class PointObject:
    """3D point represented as a hand-written class using ``__slots__``.

    No ``__init__`` is defined; callers assign ``x``, ``y`` and ``z`` after
    constructing the instance.
    """

    __slots__ = ("x", "y", "z")

    x: int
    y: int
    z: int
class PointEnum(enum.Enum):
    """Enum whose members ``x``, ``y`` and ``z`` stand in for the coordinates."""

    x = 1
    y = 2
    z = 3
class PointTensorClass(TensorClass):
    """3D point declared as a ``TensorClass`` with ``Tensor``-annotated fields."""

    x: Tensor
    y: Tensor
    z: Tensor
class PointTensorClassAutocast(TensorClass, autocast=True):
    """3D point declared as a ``TensorClass`` with ``autocast=True``.

    Fields carry the same ``Tensor`` annotations as ``PointTensorClass``; only
    the ``autocast`` class option differs.
    """

    x: Tensor
    y: Tensor
    z: Tensor
class PointTensorClassInt(TensorClass):
    """3D point declared as a ``TensorClass`` with ``int``-annotated fields."""

    x: int
    y: int
    z: int
###
# Test functions
###
def test_namedtuple_attr():
    """Read the three fields of a ``PointTuple`` by attribute name in a loop."""
    pt = PointTuple(1234, 5678, 9012)
    for _ in range(iterations):
        # Timed statement: kept verbatim so the measurement is unchanged.
        x, y, z = pt.x, pt.y, pt.z
def test_namedtuple_index():
    """Unpack a ``PointTuple`` into three names in a loop.

    NOTE: despite the "index" in the name, the loop body performs plain
    sequence unpacking rather than ``pt[0]``-style indexing.
    """
    pt = PointTuple(1234, 5678, 9012)
    for _ in range(iterations):
        x, y, z = pt
def test_namedtuple_unpack():
    """Star-unpack a ``PointTuple`` in a loop.

    The first field is bound to ``x`` and the remaining fields are collected
    into the list ``y``.
    """
    pt = PointTuple(1234, 5678, 9012)
    for _ in range(iterations):
        x, *y = pt
def test_dataclass():
    """Read the three attributes of a ``PointDataclass`` instance in a loop."""
    pt = PointDataclass(1234, 5678, 9012)
    for _ in range(iterations):
        x, y, z = pt.x, pt.y, pt.z
def test_dataclass_slots():
    """Read the three attributes of a ``PointDataclassSlots`` instance in a loop."""
    pt = PointDataclassSlots(1234, 5678, 9012)
    for _ in range(iterations):
        x, y, z = pt.x, pt.y, pt.z
def test_dict():
    """Look up the keys ``"x"``, ``"y"`` and ``"z"`` of a plain dict in a loop."""
    pt = dict(x=1234, y=5678, z=9012)
    for _ in range(iterations):
        x, y, z = pt["x"], pt["y"], pt["z"]
def test_dict_tensor():
    """Look up three keys of a plain dict whose values are one-element tensors."""
    pt = dict(x=Tensor([1234]), y=Tensor([5678]), z=Tensor([9012]))
    for _ in range(iterations):
        x, y, z = pt["x"], pt["y"], pt["z"]
def test_slots():
    """Read the three attributes of a ``PointObject`` (``__slots__`` class) in a loop."""
    pt = PointObject()
    # PointObject defines no __init__, so the coordinates are assigned here.
    pt.x = 1234
    pt.y = 5678
    pt.z = 9012
    for _ in range(iterations):
        x, y, z = pt.x, pt.y, pt.z
def test_enum_attr():
    """Fetch the ``PointEnum`` members by attribute access on the enum class."""
    members = PointEnum
    for _ in range(iterations):
        x, y, z = members.x, members.y, members.z
def test_enum_call():
    """Fetch the ``PointEnum`` members by calling the enum class with their values."""
    members = PointEnum
    for _ in range(iterations):
        x, y, z = members(1), members(2), members(3)
def test_enum_item():
    """Fetch the ``PointEnum`` members by subscripting the enum class with their names."""
    members = PointEnum
    for _ in range(iterations):
        x, y, z = members["x"], members["y"], members["z"]
def test_tensordict_int():
    """Build a ``TensorDict`` from plain integers, then read its three keys in a loop.

    The construction happens inside the function, so it is part of every
    timed call.
    """
    pt = TensorDict({"x": 1234, "y": 5678, "z": 9012})
    for _ in range(iterations):
        x, y, z = pt["x"], pt["y"], pt["z"]
def test_tensorclass():
    """Build a ``PointTensorClass`` from integers, then read its fields in a loop.

    The construction happens inside the function, so it is part of every
    timed call.
    """
    pt = PointTensorClass(x=1234, y=5678, z=9012)
    for _ in range(iterations):
        x, y, z = pt.x, pt.y, pt.z
def test_tensorclass_autocast():
    """Build a ``PointTensorClassAutocast`` from integers, then read its fields in a loop.

    The construction happens inside the function, so it is part of every
    timed call.
    """
    pt = PointTensorClassAutocast(x=1234, y=5678, z=9012)
    for _ in range(iterations):
        x, y, z = pt.x, pt.y, pt.z
def test_tensorclass_int():
    """Build a ``PointTensorClassInt`` from integers, then read its fields in a loop.

    The construction happens inside the function, so it is part of every
    timed call.
    """
    pt = PointTensorClassInt(x=1234, y=5678, z=9012)
    for _ in range(iterations):
        x, y, z = pt.x, pt.y, pt.z
# Built once at import time so that construction is excluded from the timing
# of test_tensorclass_int_noinit.
point1 = PointTensorClassInt(x=1234, y=5678, z=9012)


def test_tensorclass_int_noinit():
    """Read the fields of the module-level ``point1`` instance in a loop."""
    for _ in range(iterations):
        x, y, z = point1.x, point1.y, point1.z
# Built once at import time so that construction is excluded from the timing
# of test_tensordict_int_noinit.
point2 = TensorDict({"x": 1234, "y": 5678, "z": 9012})


def test_tensordict_int_noinit():
    """Read the three keys of the module-level ``point2`` TensorDict in a loop."""
    for _ in range(iterations):
        x, y, z = point2["x"], point2["y"], point2["z"]
if __name__ == "__main__":
    # Every benchmark function taking part in the comparison.
    tests = [
        test_namedtuple_attr,
        test_namedtuple_index,
        test_namedtuple_unpack,
        test_dataclass,
        test_dataclass_slots,
        test_dict,
        test_dict_tensor,
        test_slots,
        test_enum_attr,
        test_enum_call,
        test_enum_item,
        test_tensordict_int,
        test_tensordict_int_noinit,
        test_tensorclass,
        test_tensorclass_autocast,
        test_tensorclass_int,
        test_tensorclass_int_noinit,
    ]
    print(f"Running tests {repeat} times with {calls} calls, using {iterations} iterations in the loop")

    # Seconds accumulated per benchmark name, summed over all runs.
    results = collections.defaultdict(int)
    for run_index in range(repeat):
        # Shuffling tests to prevent skewed results due to CPU boosting or throttling
        random.shuffle(tests)
        print(f"Run {run_index}", flush=True)
        for benchmark in tests:
            name = benchmark.__name__
            # The statement is passed as a string and the function imported from
            # __main__, so timeit calls it exactly as a caller would.
            timer = timeit.Timer(f"{name}()", f"from __main__ import {name}")
            results[name] += timer.timeit(calls)

    # Report fastest first (ties broken by name), averaged over the runs.
    ranking = sorted(results.items(), key=lambda entry: (entry[1], entry[0]))
    for name, total in ranking:
        label = name.replace("test_", "")
        print(f"{label:24} {total / repeat:.3f}s")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment