Last active
October 28, 2019 15:29
-
-
Save rejunity/c29123220b00675ab760c350bf9b1e9d to your computer and use it in GitHub Desktop.
Python script to detect Thread Group Shared Memory bank conflicts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Script to detect bank conflicts in Thread Group Shared Memory.
# Supports optional multicast (default: on), configurable wave size (default: 32) and bank count (default: 32):
#   bank_conflicts(ptrs, wavesize=32, banks=32, multicast=True)
# Usage (total number of conflicts):
#   np.sum(bank_conflicts([thread.x*8 for thread in thread_group(8,8)]))
#   np.sum(bank_conflicts([thread.id*2 for thread in thread_group(64)]))
# Conflicts per wave, indexed by bank number:
#   bank_conflicts([thread.x*8 for thread in thread_group(8,8)])
#   bank_conflicts([thread.id*2 for thread in thread_group(64)])
# How it works:
# 1) thread_group() returns a list of objects with `x`, `y`, `z` and `id` properties mimicking SV_GroupThreadID and SV_GroupIndex.
# 2) From that list one can simulate access pointers; for example, [thread.id*2 for thread in ...] strides through memory, accessing every 2nd element.
# 3) bank_conflicts() takes the list of access pointers, splits it into waves and returns the number of conflicts per bank ID.
# 4) Finally, use np.sum() to get the total number of conflicts.
import numpy as np | |
def bank_conflicts(ptrs, wavesize=32, banks=32, multicast=True):
    """Count Thread Group Shared Memory bank conflicts for simulated accesses.

    Args:
        ptrs: Sequence of integer access pointers, one per thread, in thread
            order. Must support ``len()`` and slicing.
        wavesize: Number of consecutive threads that form one wave.
        banks: Number of memory banks; a pointer ``p`` maps to bank
            ``p % banks``.
        multicast: When true, threads of a wave that access the *same*
            pointer do not conflict with each other; only additional distinct
            pointers hitting an already-used bank are counted. When false,
            every access to an already-used bank counts as a conflict.

    Returns:
        If ``ptrs`` fits into a single wave (``len(ptrs) <= wavesize``), a
        list of ``banks`` conflict counters indexed by bank number.
        Otherwise a list with one such per-bank list for every wave.
        ``np.sum()`` of either form gives the total number of conflicts.
    """
    def waves(accesses, size):
        # Split the group workload into consecutive waves of `size` threads.
        for start in range(0, len(accesses), size):
            yield accesses[start:start + size]

    def wave_bank_conflicts(wave_ptrs):
        conflicts = np.zeros(banks)  # per-bank conflict counters, start at 0
        # Every pointer already seen in each bank during this wave. Tracking
        # the full set (instead of only the last pointer) keeps alternating
        # access patterns such as a, b, a, b from being counted repeatedly,
        # and needs no sentinel value that negative pointers could collide
        # with.
        seen = [set() for _ in range(banks)]
        for p in wave_ptrs:
            bank = p % banks
            addresses = seen[bank]
            # The first access to a bank is free. Later accesses conflict
            # unless multicast can serve them from an already seen pointer.
            if addresses and not (multicast and p in addresses):
                conflicts[bank] += 1  # found conflict!
            addresses.add(p)
        return list(conflicts)

    if len(ptrs) <= wavesize:
        return wave_bank_conflicts(ptrs)
    return [wave_bank_conflicts(wave) for wave in waves(ptrs, wavesize)]
def thread_group(x, y=1, z=1):
    """Build the list of thread identifiers for an ``x`` * ``y`` * ``z`` group.

    Each element exposes ``x``, ``y`` and ``z`` coordinates plus a flattened
    ``id`` (``x`` varies fastest, then ``y``, then ``z``), mimicking
    SV_GroupThreadID and SV_GroupIndex. Threads are returned in ``id`` order.
    """
    class TID(object):
        def __init__(self, x, y, z, index):
            self.x = x
            self.y = y
            self.z = z
            self.id = index

        def __repr__(self):
            # Only the 3D coordinates are shown, formatted like a list.
            return str([self.x, self.y, self.z])

        def __str__(self):
            return self.__repr__()

    threads = []
    for layer in range(z):
        for row in range(y):
            for col in range(x):
                # Threads are appended in flattened order, so the running
                # count equals col + row*x + layer*x*y.
                threads.append(TID(col, row, layer, len(threads)))
    return threads
# Example: total number of bank conflicts for an 8x8 thread group in which
# every thread accesses element x*8. As a bare expression the value is only
# displayed in an interactive session or notebook.
np.sum(bank_conflicts([8 * t.x for t in thread_group(8, 8)]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment