Skip to content

Instantly share code, notes, and snippets.

@sueszli
Created January 3, 2026 16:58
Show Gist options
  • Select an option

  • Save sueszli/e3aced3ecd0392608ac6e2d47742b5eb to your computer and use it in GitHub Desktop.

Select an option

Save sueszli/e3aced3ecd0392608ac6e2d47742b5eb to your computer and use it in GitHub Desktop.
I think it might make sense to memoize `convert_type` with `functools.cache` (`@cache`).
benchmark results (cached vs. uncached `convert_type`):
- mixed types: 1.57x speedup
- repeated types: 2.31x speedup
- unique types: 0.68x speedup (i.e. a slowdown)
- avg speedup: 1.52x (arithmetic mean of the three ratios)
❯ uv run perf_convert_type.py
scenario: mixed types (realistic)
items: 9
uncached: mean 5.88us, max 6.54us
cached: mean 3.71us, max 4.04us
speedup: 1.58x
info: CacheInfo(hits=218, misses=9, maxsize=None, currsize=9)
scenario: repeated types (high cache hit rate)
items: 30
uncached: mean 12.60us, max 12.96us
cached: mean 5.13us, max 5.33us
speedup: 2.46x
info: CacheInfo(hits=747, misses=3, maxsize=None, currsize=3)
scenario: unique types (low cache hit rate)
items: 30
uncached: mean 5.44us, max 5.58us
cached: mean 8.37us, max 8.75us
speedup: 0.65x
info: CacheInfo(hits=720, misses=30, maxsize=None, currsize=30)
mixed types (realistic) - speedup: 1.58x (+36.8%)
repeated types (high cache hit rate) - speedup: 2.46x (+59.3%)
unique types (low cache hit rate) - speedup: 0.65x (-54.0%)
conclusion:
avg speedup: 1.56x
import time
from functools import cache
import llvmlite.ir as ir
from xdsl.dialects.builtin import (
Float32Type,
Float64Type,
IndexType,
IntegerType,
VectorType,
)
from xdsl.dialects.llvm import LLVMArrayType, LLVMPointerType, LLVMStructType
from xdsl.ir import Attribute
from xdsl.utils.exceptions import LLVMTranslationException
#
# uncached version
#
def convert_type_uncached(type_attr: Attribute) -> ir.Type:
    """Translate an xDSL type attribute into the corresponding llvmlite IR type.

    This is the baseline, non-memoized variant. Element/member types of
    vectors, arrays and structs are converted by recursing into this same
    function.

    Raises:
        LLVMTranslationException: for vectors with scalable dimensions, for
            vectors with other than exactly one dimension, and for any
            attribute kind not listed below.
    """
    # Checks run in the same order as the original match/case arms.
    if isinstance(type_attr, IntegerType):
        return ir.IntType(type_attr.bitwidth)
    if isinstance(type_attr, IndexType):
        # index is always lowered to a 64-bit integer here
        return ir.IntType(64)
    if isinstance(type_attr, Float32Type):
        return ir.FloatType()
    if isinstance(type_attr, Float64Type):
        return ir.DoubleType()
    if isinstance(type_attr, LLVMPointerType):
        # constructed without any pointee type argument
        return ir.PointerType()
    if isinstance(type_attr, VectorType):
        if type_attr.get_num_scalable_dims() > 0:
            raise LLVMTranslationException("Scalable vectors not supported")
        if type_attr.get_num_dims() != 1:
            raise LLVMTranslationException(
                "Multi-dimensional vectors not supported"
            )
        lane_type = convert_type_uncached(type_attr.element_type)
        return ir.VectorType(lane_type, type_attr.get_shape()[0])
    if isinstance(type_attr, LLVMArrayType):
        member_type = convert_type_uncached(type_attr.type)
        return ir.ArrayType(member_type, type_attr.size.data)
    if isinstance(type_attr, LLVMStructType):
        field_types = [convert_type_uncached(t) for t in type_attr.types.data]
        return ir.LiteralStructType(field_types)
    raise LLVMTranslationException(f"Type not supported: {type_attr}")
#
# cached version
#
@cache
def convert_type_cached(type_attr: Attribute) -> ir.Type:
    """Memoized twin of the uncached converter: xDSL type attribute -> llvmlite type.

    Wrapped in ``functools.cache`` (unbounded, ``maxsize=None``), so each
    distinct hashable attribute is translated once per cache lifetime. Equal
    inputs then receive the very same llvmlite type object. Nested element
    and member types recurse through this cached wrapper, so they are
    memoized too. ``cache_info()`` / ``cache_clear()`` are available on the
    wrapper.

    Raises:
        LLVMTranslationException: for vectors with scalable dimensions, for
            vectors with other than exactly one dimension, and for any
            attribute kind not listed below.
    """
    # Checks run in the same order as the original match/case arms.
    if isinstance(type_attr, IntegerType):
        return ir.IntType(type_attr.bitwidth)
    if isinstance(type_attr, IndexType):
        # index is always lowered to a 64-bit integer here
        return ir.IntType(64)
    if isinstance(type_attr, Float32Type):
        return ir.FloatType()
    if isinstance(type_attr, Float64Type):
        return ir.DoubleType()
    if isinstance(type_attr, LLVMPointerType):
        # constructed without any pointee type argument
        return ir.PointerType()
    if isinstance(type_attr, VectorType):
        if type_attr.get_num_scalable_dims() > 0:
            raise LLVMTranslationException("Scalable vectors not supported")
        if type_attr.get_num_dims() != 1:
            raise LLVMTranslationException(
                "Multi-dimensional vectors not supported"
            )
        lane_type = convert_type_cached(type_attr.element_type)
        return ir.VectorType(lane_type, type_attr.get_shape()[0])
    if isinstance(type_attr, LLVMArrayType):
        member_type = convert_type_cached(type_attr.type)
        return ir.ArrayType(member_type, type_attr.size.data)
    if isinstance(type_attr, LLVMStructType):
        field_types = [convert_type_cached(t) for t in type_attr.types.data]
        return ir.LiteralStructType(field_types)
    raise LLVMTranslationException(f"Type not supported: {type_attr}")
def create_test_types():
    """Build the "mixed types" fixture: nine distinct type attributes.

    Seven scalar/pointer types are followed by two one-dimensional vectors
    (4 x f32 and 8 x i32), in that order.
    """
    scalars = [
        IntegerType(32),
        IntegerType(64),
        IntegerType(1),
        IndexType(),
        Float32Type(),
        Float64Type(),
        LLVMPointerType(),
    ]
    vectors = [
        VectorType(Float32Type(), [4]),
        VectorType(IntegerType(32), [8]),
    ]
    return scalars + vectors
def create_repeated_types():
    """Build the "repeated types" fixture: 30 items cycling i32, f32, ptr.

    Only three attribute objects are created. The returned list references
    those same three objects ten times over, in the fixed order
    i32, f32, ptr, i32, f32, ptr, ...
    """
    pattern = (IntegerType(32), Float32Type(), LLVMPointerType())
    return [attr for _ in range(10) for attr in pattern]
def create_unique_types():
    """Build the "unique types" fixture: one IntegerType per bitwidth 1..30 inclusive."""
    return list(map(IntegerType, range(1, 31)))
def benchmark(convert_fn, test_types, warmup_rounds=5, iterations=20):
    """Time repeated full passes of ``convert_fn`` over ``test_types``.

    First ``warmup_rounds`` untimed passes are made. For a memoizing
    ``convert_fn`` (e.g. one wrapped in ``functools.cache``), this means every
    input that converted successfully is already cached before timing
    starts. Then ``iterations`` timed passes follow.

    Returns:
        A list with one float per timed pass: the ``time.perf_counter()``
        difference (seconds) for calling ``convert_fn`` once on every item.
    """
    for _round in range(warmup_rounds):
        for item in test_types:
            convert_fn(item)

    durations = []
    for _iteration in range(iterations):
        begin = time.perf_counter()
        for item in test_types:
            convert_fn(item)
        durations.append(time.perf_counter() - begin)
    return durations
def run_scenario(name, test_types):
    """Benchmark the uncached and cached converters on one list of types.

    Prints a short report: item count, mean/max pass time for each variant,
    the speedup ratio and the ``functools`` cache statistics. Returns a
    summary dict with keys ``name``, ``uncached_mean``, ``cached_mean``,
    ``speedup`` and ``diff_pct``.

    ``speedup`` is ``uncached_mean / cached_mean``. ``diff_pct`` is the
    percentage reduction in mean pass time. Each is 0 when its denominator
    is not positive.

    NOTE(review): the cache is cleared only once, before ``benchmark`` runs
    its warmup passes. Every *timed* cached pass is therefore served from the
    cache (compare hits vs. misses in the printed CacheInfo). Cache-miss cost
    is never part of ``cached_mean``, even in scenarios named for a low hit
    rate.
    """
    print(f"scenario: {name}")
    print(f"items: {len(test_types)}")

    def mean_and_max(samples):
        return sum(samples) / len(samples), max(samples)

    # baseline: no memoization
    uncached_mean, uncached_max = mean_and_max(
        benchmark(convert_type_uncached, test_types)
    )

    # memoized variant starts from an empty cache
    convert_type_cached.cache_clear()
    cached_mean, cached_max = mean_and_max(
        benchmark(convert_type_cached, test_types)
    )

    print(f"uncached: mean {uncached_mean*1e6:.2f}us, max {uncached_max*1e6:.2f}us")
    print(f"cached: mean {cached_mean*1e6:.2f}us, max {cached_max*1e6:.2f}us")

    speedup = 0
    if cached_mean > 0:
        speedup = uncached_mean / cached_mean
    diff_pct = 0
    if uncached_mean > 0:
        diff_pct = (uncached_mean - cached_mean) / uncached_mean * 100

    print(f"speedup: {speedup:.2f}x")
    print(f"info: {convert_type_cached.cache_info()}")
    print()

    return dict(
        name=name,
        uncached_mean=uncached_mean,
        cached_mean=cached_mean,
        speedup=speedup,
        diff_pct=diff_pct,
    )
def main():
scenarios = [
("mixed types (realistic)", create_test_types()),
("repeated types (high cache hit rate)", create_repeated_types()),
("unique types (low cache hit rate)", create_unique_types()),
]
results = []
for name, test_types in scenarios:
result = run_scenario(name, test_types)
results.append(result)
for r in results:
print(
f"{r['name']:40s} - speedup: {r['speedup']:.2f}x ({r['diff_pct']:+.1f}%)"
)
print()
print("conclusion:")
avg_speedup = sum(r["speedup"] for r in results) / len(results)
print(f"avg speedup: {avg_speedup:.2f}x")
if __name__ == "__main__":
    # Run the benchmark scenarios only when executed as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment