Skip to content

Instantly share code, notes, and snippets.

@cowmix
Forked from damico/test-rocm.py
Created November 22, 2024 15:35
Show Gist options
  • Save cowmix/96ebe94f72897bb968155aa36d2ff4e5 to your computer and use it in GitHub Desktop.
Save cowmix/96ebe94f72897bb968155aa36d2ff4e5 to your computer and use it in GitHub Desktop.
Script for testing PyTorch support on AMD GPUs using ROCm
import torch
import grp
import pwd
import os
import subprocess
# Names of the agents reported by ``rocminfo``; filled in by main().
devices = []


def parse_rocminfo_agents(rocminfo_output):
    """Return the ``Name:`` value of every agent listed in ``rocminfo`` output.

    The text is scanned line by line: a line of the form ``Agent <number>``
    opens an agent section, and the first following line that starts with
    ``Name:`` supplies that agent's name. Later ``Name:`` lines in the same
    section (for example those nested under ISA info) are ignored because
    only one name is taken per agent header.

    Args:
        rocminfo_output: Decoded standard output of the ``rocminfo`` command.

    Returns:
        A list of agent names in the order they appear; empty when no agent
        section is found.
    """
    names = []
    awaiting_name = False
    for raw_line in rocminfo_output.splitlines():
        line = raw_line.strip()
        if line.startswith('Agent ') and line[len('Agent '):].strip().isdigit():
            awaiting_name = True
        elif awaiting_name and line.startswith('Name:'):
            names.append(line[len('Name:'):].strip())
            awaiting_name = False
    return names


def pytorch_works():
    """Return True when ``torch.rand(5, 3)`` yields 5 rows of 3 elements each."""
    sample = torch.rand(5, 3)
    # Every row must have the expected width, not just the last one.
    return len(sample) == 5 and all(len(row) == 3 for row in sample)


def current_user_name():
    """Return the current user's name without relying on a login session.

    Prefers the ``USER`` environment variable, then the passwd entry for the
    real UID. When the UID has no passwd entry (``pwd.getpwuid`` raises
    ``KeyError``), the numeric UID is returned as a string so the caller
    always has something printable.
    """
    user = os.environ.get('USER')
    if user:
        return user
    uid = os.getuid()
    try:
        return pwd.getpwuid(uid).pw_name
    except KeyError:
        return str(uid)


def current_group_names(user):
    """Return the sorted names of the groups the current process/user is in.

    Combines two sources:
      * the primary and supplementary group IDs of the running process
        (``os.getgid`` / ``os.getgroups``), resolved to names via ``grp``;
      * groups in the group database that list ``user`` as a member.

    Args:
        user: User name to look for in each group's member list.

    Returns:
        A sorted list of group names; group IDs that have no entry in the
        group database are left out because they have no name to report.
    """
    names = set()
    for gid in {os.getgid(), *os.getgroups()}:
        try:
            names.add(grp.getgrgid(gid).gr_name)
        except KeyError:
            # A gid without a group database entry cannot be matched by name.
            continue
    try:
        names.update(g.gr_name for g in grp.getgrall() if user in g.gr_mem)
    except (KeyError, OSError):
        print("WARNING: Unable to get complete group information")
    return sorted(names)


def main():
    """Run the ROCm, PyTorch, user-group and GPU tensor checks in order.

    Each check reports its own GOOD/BAD line, and a failure in one check does
    not prevent the following checks from running.
    """
    print("\n\nChecking ROCM support...")
    try:
        result = subprocess.run(['rocminfo'], stdout=subprocess.PIPE)
    except OSError as e:
        # Raised when the rocminfo executable is missing or cannot be run.
        print('Cannot find rocminfo command information. Unable to determine if AMDGPU drivers with ROCM support were installed.')
        print(f'Error details: {str(e)}')
    else:
        output = result.stdout.decode('utf-8', errors='replace')
        # Slice-assign so the module-level list object is updated in place.
        devices[:] = parse_rocminfo_agents(output)
        if devices:
            print('GOOD: ROCM devices found: ', len(devices))
        else:
            print('BAD: No ROCM devices found.')

    print("Checking PyTorch...")
    try:
        has_torch = pytorch_works()
    except Exception as e:
        # Diagnostic boundary: report the failure and keep checking.
        has_torch = False
        print(f'Error details: {str(e)}')
    if has_torch:
        print('GOOD: PyTorch is working fine.')
    else:
        print('BAD: PyTorch is NOT working.')

    print("Checking user groups...")
    user = current_user_name()
    groups = current_group_names(user)
    if 'render' in groups and 'video' in groups:
        print('GOOD: The user', user, 'is in RENDER and VIDEO groups.')
    else:
        print('BAD: The user', user, 'is NOT in RENDER and VIDEO groups. This is necessary in order to PyTorch use HIP resources')

    if torch.cuda.is_available():
        print("GOOD: PyTorch ROCM support found.")
        print('Testing PyTorch ROCM support...')
        try:
            t = torch.tensor([5, 5, 5], dtype=torch.int64, device='cuda')
            # Compare placement and values directly instead of the repr,
            # which embeds the device index.
            tensor_ok = t.is_cuda and t.tolist() == [5, 5, 5]
        except RuntimeError as e:
            tensor_ok = False
            print(f'Error details: {str(e)}')
        if tensor_ok:
            print('Everything fine! You can run PyTorch code inside of: ')
            for device in devices:
                print('---> ', device)
        else:
            print('BAD: PyTorch ROCM tensor test failed.')
    else:
        print("BAD: PyTorch ROCM support NOT found.")


if __name__ == "__main__":
    main()
@cowmix
Copy link
Author

cowmix commented Nov 22, 2024

My fix: improved user detection for Docker/headless environments

Replaced os.getlogin() with more resilient user/group detection methods to support
containerized environments where standard login info might not be available. Now falls back
to env vars and UID-based lookups when traditional methods fail.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment