Created
February 12, 2026 06:58
-
-
Save DEKHTIARJonathan/dfd559e7966df7dbeb494abfc947dbd3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import annotations | |
| from pprint import pprint | |
| from variantlib.models.variant import VariantDescription | |
| from variantlib.models.variant import VariantProperty | |
| from variantlib.resolver.filtering import filter_variants_by_property | |
| from variantlib.resolver.sorting import sort_variants_descriptions | |
if __name__ == "__main__":
    # Demo script: describe three torch CUDA build variants, filter them
    # against the properties accepted for the local GPU, then sort the
    # survivors using the same property list as the priority order.

    NVIDIA_LOCAL_SM = 85  # CUDA 12.6 & 12.8
    # NVIDIA_LOCAL_SM = 104  # Only CUDA 12.8

    # The last digit is the minor part; whatever precedes it is the major.
    sm_digits = str(NVIDIA_LOCAL_SM)
    NVIDIA_SM_MAJOR = int(sm_digits[:-1])
    NVIDIA_SM_MINOR = int(sm_digits[-1])

    def nvidia_prop(feature, value):
        """Return a property in the `nvidia` namespace."""
        return VariantProperty(namespace="nvidia", feature=feature, value=value)

    def order_prop(value):
        """Return a `priority :: order` property with the given value."""
        return VariantProperty(namespace="priority", feature="order", value=value)

    def torch_variant(cuda_lower_bound, sm_archs, order):
        """Return a variant: CUDA lower bound, then sm_archs, then order."""
        props = [nvidia_prop("cuda_version_lower_bound", cuda_lower_bound)]
        props.extend(nvidia_prop("sm_arch", arch) for arch in sm_archs)
        props.append(order_prop(order))
        return VariantDescription(props)

    def banner(title):
        """Print a section title framed by 25 `=` characters on each side."""
        bar = "=" * 25
        print(f"{bar} {title} {bar}")

    # Accepted properties. The list is also passed to the sorter as
    # `property_priorities`, so the element order below is significant.
    accepted_props = []

    # NS `nvidia`
    # CUDA lower bounds 12.9 down to 12.0.
    for cuda_minor in reversed(range(10)):
        accepted_props.append(
            nvidia_prop("cuda_version_lower_bound", f"12.{cuda_minor}")
        )
    # Real architectures from the local SM down to minor 0 of the same major.
    for sm_minor in reversed(range(NVIDIA_SM_MINOR + 1)):
        accepted_props.append(
            nvidia_prop("sm_arch", f"{NVIDIA_SM_MAJOR}{sm_minor}_real")
        )
    accepted_props.append(nvidia_prop("sm_arch", f"{NVIDIA_SM_MAJOR}0_virtual"))

    # NS `priority`
    for rank in ("1", "2", "3"):
        accepted_props.append(order_prop(rank))

    torch_12_6 = torch_variant(
        "12.0",
        [
            "50_real",
            "60_real",
            "70_real",
            "75_real",
            "80_real",
            "86_real",
            "90_real",
        ],
        "3",
    )
    torch_12_8 = torch_variant(
        "12.0",
        [
            "70_real",
            "75_real",
            "80_real",
            "86_real",
            "90_real",
            "100_real",
            "120_real",
        ],
        "2",
    )
    torch_13_0 = torch_variant(
        "13.0",
        [
            "75_real",
            "80_real",
            "86_real",
            "90_real",
            "100_real",
            "120_real",
            "120_virtual",
        ],
        "1",
    )

    # Human-readable label for each variant, keyed by its hexdigest.
    variant_label_map = {
        torch_13_0.hexdigest: "cuda13.0",
        torch_12_8.hexdigest: "cuda12.8",
        torch_12_6.hexdigest: "cuda12.6",
    }

    banner("Compatible Variant Properties")
    pprint(accepted_props)

    banner("Filtered Variant Descriptions")
    candidates = [torch_12_6, torch_12_8, torch_13_0]
    surviving = list(
        filter_variants_by_property(
            vdescs=candidates,
            allowed_properties=accepted_props,
            forbidden_properties=[],
        )
    )
    # Should drop CUDA 13.0
    surviving_labels = [variant_label_map[vd.hexdigest] for vd in surviving]
    print(f"Filtered Variants: {surviving_labels}")

    banner("Sorted Variant Descriptions")
    ranked = sort_variants_descriptions(
        vdescs=surviving,
        property_priorities=accepted_props,
    )
    ranked_labels = [variant_label_map[vd.hexdigest] for vd in ranked]
    print(f"Sorted & Filtered Variants: {ranked_labels}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment