Last active
December 24, 2020 01:11
-
-
Save thvasilo/21af37da926f9f7d62bfc7eb5953ea00 to your computer and use it in GitHub Desktop.
A matrix multiplication benchmark.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Benchmark for measuring matrix multiplication speed, Martin Nilsson, Rise SICS | |
# relevant for certain Machine Learning tasks v1.0 2017-11-21 | |
# v1.1 Theodore Vasiloudis (PyTorch solution)
# ==================================================== | |
# Run by: | |
# | |
# python3 multiplytest.py 10000 | |
# | |
# to measure squaring a 10000 x 10000 random matrix. | |
# Oddly, the K80 and the Titan X return slightly different results, probably because of differences in numerical accuracy.
from numpy.random import seed, randn | |
from numpy import float32, float64 | |
from time import process_time, perf_counter | |
from sys import argv, maxsize | |
from platform import uname, python_version | |
import torch | |
def multiplytest(n, device=None):
    """Time squaring an n x n random float64 matrix with PyTorch.

    A reproducible (``seed(1)``) standard-normal matrix ``a`` is built with
    NumPy, copied to ``device`` without changing its dtype, and ``a @ a.T``
    is timed.  The timings and some host information are printed.

    Args:
        n: Matrix dimension; must be >= 1.
        device: Torch device for the multiplication.  ``None`` (default)
            picks ``'cuda'`` when a GPU is available and ``'cpu'`` otherwise.

    Returns:
        The sanity-check value ``||a @ a.T - n*I||_F**2 / n**3``, which
        should be close to 1.0 (see ``_sanity_check``).

    Raises:
        ValueError: if ``n < 1``.
    """
    if n < 1:
        raise ValueError('n must be a positive integer, got {0!r}'.format(n))
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    seed(1)
    # Choose precision here (standard is float64, NOT float32!)
    a_host = randn(n, n).astype(float64)
    print('Size of matrix ({0} x {0}): {1:.0f}MB'.format(n, a_host.nbytes/1e6))
    # Convert to a PyTorch tensor and copy it to the device.  torch.from_numpy
    # keeps the NumPy dtype, whereas torch.Tensor(...) would silently convert
    # to the default float32 dtype and defeat the precision choice above.
    a = torch.from_numpy(a_host).to(device)
    t_process, t_elapsed, b = _time_matmul(a)
    # The time is cubic in n unless the system implements Strassen-type
    # matrix multiplication (but BLAS implementations usually don't).
    print('{0:.3f}s process time (total for all cores)'.format(t_process))
    print('{0:.3f}s perf time (elapsed time)'.format(t_elapsed))
    _print_system_info()
    return _sanity_check(b, n)


def _time_matmul(a):
    """Return (CPU seconds, wall-clock seconds, a @ a.T) for one multiplication."""
    on_gpu = a.is_cuda
    # Untimed warm-up on a 1-row slice, so one-time set-up triggered by the
    # first mm call is not counted in the measurement.
    torch.mm(a[:1], a[:1].t())
    if on_gpu:
        torch.cuda.synchronize(a.device)
    t_process = process_time()
    t_elapsed = perf_counter()
    b = torch.mm(a, a.t())
    if on_gpu:
        # CUDA kernels launch asynchronously: wait for completion so the
        # timers cover the multiplication itself, not just the launch.
        torch.cuda.synchronize(a.device)
    t_process = process_time() - t_process
    t_elapsed = perf_counter() - t_elapsed
    return t_process, t_elapsed, b


def _print_system_info():
    """Print OS, Python and CPU details, plus psutil data when available."""
    u = uname()
    print('Operating system:', u.system, 'release', u.release + ',',
          '64-bit,' if (maxsize > 2**31 - 1) else '32-bit;',
          'Python version', python_version(),
          '\nCPU:', u.processor)
    try:
        from psutil import cpu_count, virtual_memory, cpu_freq
    except ImportError:
        print('(Install psutil to find more details!')
        print(' You may have to do \'sudo apt-get install python-dev\'')
        print(' or similarly before \'pip install psutil\'.)')
    else:
        print('Cores:',
              cpu_count(logical=False), 'physical,',
              cpu_count(logical=True), 'logical;',
              'RAM: {0:.3f}GB total'.format(virtual_memory().total/1e9))
        try:
            freq = cpu_freq()
        except (OSError, NotImplementedError):
            # NOTE(review): some platforms/VMs cannot report a frequency.
            freq = None
        if freq:
            print('Current CPU frequency: {0:.3f}GHz'.format(freq.current/1e3))
    # System specific information
    if u.system == 'Windows':
        from subprocess import CalledProcessError, check_output
        try:
            info = check_output('wmic cpu get name').decode().split()[1:]
        except (OSError, CalledProcessError):
            # wmic is deprecated and may be absent on newer Windows builds.
            info = []
        if info:
            print(' '.join(info))
    elif u.system == 'Linux':
        _print_linux_info()
    elif u.system == 'Darwin':
        from platform import mac_ver
        # mac_ver() returns a plain (release, versioninfo, machine) tuple.
        release, versioninfo, machine = mac_ver()
        print('Release {0},'.format(release),
              'version {0},'.format(versioninfo),
              'machine {0}'.format(machine))


def _print_linux_info():
    """Print the distribution and the CPU model/hardware lines of /proc/cpuinfo."""
    import platform
    if hasattr(platform, 'freedesktop_os_release'):  # Python 3.10+
        try:
            os_release = platform.freedesktop_os_release()
        except OSError:
            os_release = {}
        print('Distribution: {0} {1}'.format(os_release.get('NAME', ''),
                                             os_release.get('VERSION_ID', '')))
    elif hasattr(platform, 'linux_distribution'):  # removed in Python 3.8
        print('Distribution: {0} {1}'.format(*platform.linux_distribution()))
    try:
        with open('/proc/cpuinfo') as cpuinfo:
            lines = cpuinfo.read().splitlines()
    except OSError:
        return
    # Each line looks like "key<TAB>: value"; keep the first match per key.
    details = []
    for key in ('model name', 'hardware'):
        for line in lines:
            if key in line.lower() and ':' in line:
                details.append(line.split(':', 1)[1].strip())
                break
    if details:
        print(', '.join(details))


def _sanity_check(b, n):
    """Return ||b - n*I||_F**2 / n**3 for the product b = a @ a.T.

    For a standard-normal ``a``, each diagonal entry of ``b`` has mean ``n``
    and each of the ~n**2 off-diagonal entries has variance ``n``, so the
    result should be approximately 1.0.
    """
    from numpy import arange
    from numpy.linalg import norm
    b = b.cpu().numpy()  # Copy back to main memory and convert to NumPy
    diag = arange(n)
    b[diag, diag] -= n
    return norm(b)**2/n**3
# Returned value observed on a Titan X: 0.99979501050625
# and on a K80: 0.9972694914682851
# Command-line entry point: a single optional argument sets the matrix size
# (10000 when it is absent or when more than one argument is given).
if __name__ == '__main__':
    size = int(argv[1]) if len(argv) == 2 else 10000
    print("Returned value: {}".format(multiplytest(size)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment