Skip to content

Instantly share code, notes, and snippets.

@DEKHTIARJonathan
Created February 12, 2026 06:58
Show Gist options
  • Select an option

  • Save DEKHTIARJonathan/dfd559e7966df7dbeb494abfc947dbd3 to your computer and use it in GitHub Desktop.

Select an option

Save DEKHTIARJonathan/dfd559e7966df7dbeb494abfc947dbd3 to your computer and use it in GitHub Desktop.
from __future__ import annotations
from pprint import pprint
from variantlib.models.variant import VariantDescription
from variantlib.models.variant import VariantProperty
from variantlib.resolver.filtering import filter_variants_by_property
from variantlib.resolver.sorting import sort_variants_descriptions
def _nvidia_prop(feature: str, value: str) -> VariantProperty:
    """Return a ``VariantProperty`` in the ``nvidia`` namespace."""
    return VariantProperty(namespace="nvidia", feature=feature, value=value)


def _priority_prop(order: int) -> VariantProperty:
    """Return the ``priority :: order`` property for the given rank."""
    return VariantProperty(namespace="priority", feature="order", value=str(order))


def _torch_variant(
    cuda_lower_bound: str, sm_archs: tuple[str, ...], order: int
) -> VariantDescription:
    """Describe one build: CUDA lower bound, its ``sm_arch`` values, then its order."""
    return VariantDescription(
        [
            _nvidia_prop("cuda_version_lower_bound", cuda_lower_bound),
            *(_nvidia_prop("sm_arch", arch) for arch in sm_archs),
            _priority_prop(order),
        ]
    )


def _banner(title: str) -> None:
    """Print ``title`` framed by 25 ``=`` characters on each side."""
    rule = "=" * 25
    print(f"{rule} {title} {rule}")


def main() -> None:
    """Filter, then sort, three hand-written variant descriptions.

    The list of compatible properties is derived from a hard-coded SM number
    and is used both as the allow-list for filtering and as the priority
    order for sorting. Each stage prints a banner followed by its result.
    """
    # Author's note: 85 -> "CUDA 12.6 & 12.8"; switch to 104 for "Only CUDA 12.8".
    local_sm = 85

    # The last decimal digit is the minor revision; everything before it is
    # the major one.
    sm_digits = str(local_sm)
    sm_major = int(sm_digits[:-1])
    sm_minor = int(sm_digits[-1])

    compatible_props = [
        # NS `nvidia`: CUDA lower bounds 12.9 down to 12.0.
        *(
            _nvidia_prop("cuda_version_lower_bound", f"12.{minor}")
            for minor in reversed(range(10))
        ),
        # NS `nvidia`: real architectures from the local minor down to 0 ...
        *(
            _nvidia_prop("sm_arch", f"{sm_major}{minor}_real")
            for minor in reversed(range(sm_minor + 1))
        ),
        # ... followed by the virtual architecture of the major family.
        _nvidia_prop("sm_arch", f"{sm_major}0_virtual"),
        # NS `priority`
        *(_priority_prop(order) for order in (1, 2, 3)),
    ]

    torch_12_6 = _torch_variant(
        "12.0",
        ("50_real", "60_real", "70_real", "75_real", "80_real", "86_real", "90_real"),
        3,
    )
    torch_12_8 = _torch_variant(
        "12.0",
        ("70_real", "75_real", "80_real", "86_real", "90_real", "100_real", "120_real"),
        2,
    )
    torch_13_0 = _torch_variant(
        "13.0",
        (
            "75_real",
            "80_real",
            "86_real",
            "90_real",
            "100_real",
            "120_real",
            "120_virtual",
        ),
        1,
    )

    # Human-readable names, keyed by each description's hexdigest.
    labels_by_digest = {
        vdesc.hexdigest: label
        for vdesc, label in (
            (torch_13_0, "cuda13.0"),
            (torch_12_8, "cuda12.8"),
            (torch_12_6, "cuda12.6"),
        )
    }

    _banner("Compatible Variant Properties")
    pprint(compatible_props)

    _banner("Filtered Variant Descriptions")
    candidates = [torch_12_6, torch_12_8, torch_13_0]
    kept = list(
        filter_variants_by_property(
            vdescs=candidates,
            allowed_properties=compatible_props,
            forbidden_properties=[],
        )
    )
    # Should drop CUDA 13.0 (the allow-list above only carries 12.x lower bounds).
    kept_labels = [labels_by_digest[vdesc.hexdigest] for vdesc in kept]
    print(f"Filtered Variants: {kept_labels}")

    _banner("Sorted Variant Descriptions")
    ranked = sort_variants_descriptions(
        vdescs=kept,
        property_priorities=compatible_props,
    )
    ranked_labels = [labels_by_digest[vdesc.hexdigest] for vdesc in ranked]
    print(f"Sorted & Filtered Variants: {ranked_labels}")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment