Last active
August 13, 2024 21:08
-
-
Save moyix/5bac4b2e383a466b7d015b8c04db13b5 to your computer and use it in GitHub Desktop.
Some handy utils for messing with MXCSR (x86-64 SSE FPU control register)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import sys, os | |
import platform | |
import ctypes as ct | |
import mmap | |
from enum import Enum | |
import importlib | |
import functools | |
import errno | |
# ANSI escape sequences used to colour text written to a terminal.
COLOR_RED = "\033[31m"    # start red text
COLOR_GREEN = "\033[32m"  # start green text
COLOR_RESET = "\033[0m"   # back to the terminal's default attributes
# Refuse to load anywhere but a 64-bit x86_64 interpreter: the rest of the
# module executes raw x86-64 machine code.
if not (platform.machine() == 'x86_64' and sys.maxsize > 2**32):
    raise RuntimeError("This module only works on x86_64")
# Set up an anonymous one-page buffer to hold a couple of tiny machine-code
# stubs. It is mapped read/write while we fill it in and is switched to
# read/execute further down, so it is never writable and executable at once.
_code_buf = mmap.mmap(-1, mmap.PAGESIZE, prot=mmap.PROT_READ | mmap.PROT_WRITE)

# NOTE: adjacent bytes literals are joined *before* `*` is applied, so the
# padding must be added with an explicit `+`; otherwise the whole stub
# (instruction, ret and one nop) is repeated four times.
_set_mxcsr_asm = (
    b"\x0F\xAE\x17"    # ldmxcsr [rdi]
    b"\xc3"            # ret
    + b"\x90" * 4      # padding
)
_get_mxcsr_asm = (
    b"\x0F\xAE\x1F"    # stmxcsr [rdi]
    b"\xc3"            # ret
    + b"\x90" * 4      # padding
)

# Copy the assembly into the buffer. The two stubs are written back to back,
# so the second one starts len(_set_mxcsr_asm) bytes after the first.
_code_buf_addr = ct.addressof(ct.c_void_p.from_buffer(_code_buf))
_code_buf.write(_set_mxcsr_asm)
_set_mxcsr_addr = _code_buf_addr
_code_buf.write(_get_mxcsr_asm)
_get_mxcsr_addr = _set_mxcsr_addr + len(_set_mxcsr_asm)

# Make our code buffer read-only (and executable) after we're done with it.
mprotect = ct.CDLL(None, use_errno=True).mprotect
mprotect.argtypes = [ct.c_void_p, ct.c_size_t, ct.c_int]
mprotect.restype = ct.c_int
if mprotect(_code_buf_addr, mmap.PAGESIZE, mmap.PROT_READ | mmap.PROT_EXEC) != 0:
    e = ct.get_errno()
    # Fall back to the numeric value if the errno has no symbolic name, so
    # that reporting the failure cannot itself fail with a KeyError.
    raise OSError("mprotect: " + errno.errorcode.get(e, str(e)) + f" ({os.strerror(e)})")
############################################################################################################################## | |
# Bits of the MXCSR register. Diagram was # +----+--------+----+----+----+----+----+----+----+----+----+----+----+----+----+ # | |
# converted to ASCII-art from Figure 10-3 # | 15 | 14 13 | 12 | 11 | 10 | 9 | 8 | 7 | 6 | 5 | 4 | 3 | 2 | 1 | 0 | # | |
# in the Intel 64 and IA-32 Architectures # +----+--------+----+----+----+----+----+----+----+----+----+----+----+----+----+ # | |
# Software Developer's Manual, Volume 1. # FZ RC PM UM OM ZM DM IM DAZ PE UE OE ZE DE IE # | |
# ######################################### | | | | | | | | | | | | | | | # | |
# Flush to Zero -------------------------------' | | | | | | | | | | | | | | # | |
# Rounding Control -----------------------------------' | | | | | | | | | | | | | # | |
# Precision Mask --------------------------------------------' | | | | | | | | | | | | # | |
# Underflow Mask -------------------------------------------------' | | | | | | | | | | | # | |
# Overflow Mask -------------------------------------------------------' | | | | | | | | | | # | |
# Divide-by-Zero Mask ------------------------------------------------------' | | | | | | | | | # | |
# Denormal Operation Mask -------------------------------------------------------' | | | | | | | | # | |
# Invalid Operation Mask -------------------------------------------------------------' | | | | | | | # | |
# Denormals Are Zeros ---------------------------------------------------------------------' | | | | | | # | |
# Precision Flag -------------------------------------------------------------------------------' | | | | | # | |
# Underflow Flag ------------------------------------------------------------------------------------' | | | | # | |
# Overflow Flag ------------------------------------------------------------------------------------------' | | | # | |
# Divide-by-Zero Flag -----------------------------------------------------------------------------------------' | | # | |
# Denormal Flag ----------------------------------------------------------------------------------------------------' | # | |
# Invalid Operation Flag ------------------------------------------------------------------------------------------------' # | |
############################################################################################################################## | |
class MXCSR_bits(ct.LittleEndianStructure):
    """Bit-field view of the MXCSR register.

    Fields are declared starting from bit 0 (IE) up to bit 15 (FZ); the
    remaining 16 bits are exposed as ``reserved`` and cannot be assigned.
    """

    _fields_ = [
        ("IE", ct.c_uint32, 1),
        ("DE", ct.c_uint32, 1),
        ("ZE", ct.c_uint32, 1),
        ("OE", ct.c_uint32, 1),
        ("UE", ct.c_uint32, 1),
        ("PE", ct.c_uint32, 1),
        ("DAZ", ct.c_uint32, 1),
        ("IM", ct.c_uint32, 1),
        ("DM", ct.c_uint32, 1),
        ("ZM", ct.c_uint32, 1),
        ("OM", ct.c_uint32, 1),
        ("UM", ct.c_uint32, 1),
        ("PM", ct.c_uint32, 1),
        ("RC", ct.c_uint32, 2),
        ("FZ", ct.c_uint32, 1),
        ("reserved", ct.c_uint32, 16),
    ]

    class RoundingModes(Enum):
        """Possible values of the two-bit RC (Rounding Control) field."""

        RoundToNearest = 0
        RoundDown = 1
        RoundUp = 2
        RoundTowardsZero = 3

        def short(self):
            """Return the two-letter abbreviation of this rounding mode."""
            abbreviations = ("RN", "RD", "RU", "RZ")
            return abbreviations[self.value]

    # Human-readable names, in the same order as the entries of _fields_
    # (there is deliberately no entry for the trailing "reserved" field).
    full_names = [
        "Invalid Operation Flag",
        "Denormal Flag",
        "Divide-by-Zero Flag",
        "Overflow Flag",
        "Underflow Flag",
        "Precision Flag",
        "Denormals Are Zeros",
        "Invalid Operation Mask",
        "Denormal Operation Mask",
        "Divide-by-Zero Mask",
        "Overflow Mask",
        "Underflow Mask",
        "Precision Mask",
        "Rounding Control",
        "Flush to Zero",
    ]

    @staticmethod
    def _colorize(s, name, value):
        """Wrap *s* in red if stdout is a terminal, *name* is FZ/DAZ and *value* is truthy."""
        if not sys.stdout.isatty():
            return s
        if name not in ["FZ", "DAZ"]:
            return s
        if not value:
            return s
        return COLOR_RED + s + COLOR_RESET

    def __str__(self):
        """Return a compact one-line summary such as ``MXCSR(...,IM,DM,...,RC=RN)``.

        Every single-bit field occupies a fixed-width slot: its name when the
        bit is set, or the same number of spaces when it is clear.
        """
        slots = []
        for field_name, _ctype, _width in self._fields_:
            if field_name in ['RC', 'reserved']:
                continue
            if getattr(self, field_name):
                shown = field_name
            else:
                shown = ' ' * len(field_name)
            # A blanked-out slot never matches "FZ"/"DAZ", so only bits that
            # are actually set can end up coloured.
            slots.append(self._colorize(shown, shown, 1))
        slots.append(f"RC={self.RoundingModes(self.RC).short()}")
        joined = ",".join(slots)
        return f"MXCSR({joined})"

    __repr__ = __str__

    def __setattr__(self, name, value):
        """Assign a field, refusing any write to the ``reserved`` bits."""
        if name != "reserved":
            super().__setattr__(name, value)
            return
        raise ValueError("Cannot set reserved bits")

    def describe(self):
        """Return a multi-line description with one line per named field."""
        column_width = 1 + max(map(len, self.full_names))
        pieces = ["MXCSR register:"]
        # zip() stops at the shorter sequence, which skips the "reserved" field.
        for field_spec, full_name in zip(self._fields_, self.full_names):
            name = field_spec[0]
            padded_name = full_name.ljust(column_width)
            value = getattr(self, name)
            is_on_off = (
                'Flag' in full_name
                or 'Mask' in full_name
                or full_name == 'Flush to Zero'
                or full_name == 'Denormals Are Zeros'
            )
            if is_on_off:
                value_s = 'Set' if value else 'Clear'
            elif 'Rounding Control' in full_name:
                value_s = self.RoundingModes(value).name
            else:
                raise ValueError(f"You forgot one: {full_name}")
            pieces.append(self._colorize(f"\n {padded_name}: {value_s}", name, value))
        return "".join(pieces)
class MXCSR(ct.Union):
    """An MXCSR value, accessible as a raw integer or as named bit fields.

    ``value`` (a 32-bit unsigned integer) and ``bits`` (an ``MXCSR_bits``
    structure) are members of the same union, so a change made through one
    is visible through the other.
    """

    _fields_ = [
        ("bits", MXCSR_bits),
        ("value", ct.c_uint32),
    ]

    # Power-on MXCSR value: bits 7-12 set, i.e. all six exception masks on.
    RESET_VALUE = 0x1f80

    @classmethod
    def initial(cls):
        """Return a new instance holding the power-on value ``RESET_VALUE``."""
        return cls(value=cls.RESET_VALUE)

    @classmethod
    def from_dict(cls, d):
        """Build an instance from a mapping of bit-field name to value.

        The instance starts from a default-constructed (all-zero) union, so
        fields that are not mentioned in *d* stay 0 rather than taking their
        ``RESET_VALUE`` setting.

        Raises:
            ValueError: if a key is not a field of ``MXCSR_bits`` or names
                the ``reserved`` bits.
        """
        # ctypes objects accept arbitrary attribute names, so without this
        # check a misspelled key would be silently dropped instead of
        # changing the register value.
        known_fields = {field[0] for field in MXCSR_bits._fields_}
        mxcsr = cls()
        for k, v in d.items():
            if k not in known_fields:
                raise ValueError(f"Unknown MXCSR field: {k!r}")
            setattr(mxcsr.bits, k, v)
        return mxcsr

    def __repr__(self):
        """Return the raw value in hexadecimal, e.g. ``MXCSR(0x1f80)``."""
        return f"MXCSR({self.value:#x})"

    def __str__(self):
        """Return the compact per-bit summary produced by ``MXCSR_bits``."""
        return str(self.bits)

    def describe(self):
        """Return the verbose multi-line description produced by ``MXCSR_bits``."""
        return self.bits.describe()
# Callable wrapper around the `ldmxcsr [rdi]; ret` stub in the code buffer.
_set_mxcsr = ct.CFUNCTYPE(None, ct.POINTER(MXCSR))(_set_mxcsr_addr)


def set_mxcsr(val: MXCSR):
    """Load *val* into the MXCSR register of the calling thread.

    Raises:
        ValueError: if any of the upper 16 (reserved) bits of *val* is set.
    """
    # Per the Intel SDM, LDMXCSR raises a general-protection fault when a
    # reserved bit is set. That would kill the whole interpreter, so reject
    # such values up front with an ordinary exception.
    if val.value >> 16:
        raise ValueError(f"Cannot set reserved bits: {val.value:#x}")
    _set_mxcsr(ct.byref(val))
# Callable wrapper around the `stmxcsr [rdi]; ret` stub in the code buffer.
_get_mxcsr = ct.CFUNCTYPE(None, ct.POINTER(MXCSR))(_get_mxcsr_addr)


def get_mxcsr() -> MXCSR:
    """Return the current MXCSR register contents as a fresh ``MXCSR`` object."""
    snapshot = MXCSR()
    # The stub stores the register into the memory we pass by reference.
    _get_mxcsr(ct.byref(snapshot))
    return snapshot
def ensure_clean_fpu_state(function):
    """Decorator that runs *function* with MXCSR set to ``MXCSR.initial()``.

    The MXCSR value found on entry is saved and written back afterwards,
    whether *function* returns normally or raises.
    """
    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        saved_mxcsr = get_mxcsr()
        set_mxcsr(MXCSR.initial())
        try:
            outcome = function(*args, **kwargs)
        finally:
            # Put back whatever the caller had, even on an exception.
            set_mxcsr(saved_mxcsr)
        return outcome
    return wrapper
# Small demo. numpy's finfo will yell loudly if the FZ or DAZ bits are set.
def decorator_demo():
    """Show MXCSR before and after importing gevent, with and without the wrapper.

    Requires numpy and gevent to be installed; both are imported lazily here
    so that the rest of the module works without them.
    """
    import numpy as np

    @ensure_clean_fpu_state
    def tricky_numerical_operation_safe():
        np.finfo(np.float32)

    def tricky_numerical_operation_unsafe():
        np.finfo(np.float32)

    print(f"MXCSR at power on: {MXCSR.initial()}")
    print(f"MXCSR now: {get_mxcsr()}")
    print("Importing gevent, which uses ffast-math...")
    import gevent  # noqa: F401 -- imported only for its effect on MXCSR
    print(f"MXCSR after import: {get_mxcsr()}")
    print("Running np.finfo(np.float32) without FPU wrapper (you should see warnings):")
    # The escape codes carry no newline, so they must be flushed explicitly;
    # otherwise they can sit in the stdout buffer while the warnings (which
    # normally go to stderr) are printed, leaving the warnings uncoloured.
    print(COLOR_RED, end='', flush=True)
    try:
        tricky_numerical_operation_unsafe()
    finally:
        # Always restore the terminal colour, even if the call raises.
        print(COLOR_RESET, end='', flush=True)
    print("Running np.finfo(np.float32) with FPU wrapper (no warnings):")
    tricky_numerical_operation_safe()
    print(COLOR_GREEN+"All done!"+COLOR_RESET)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    decorator_demo()
I just wanted to say—this is an absolutely brilliant write-up. Thanks so much for all your hard work.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Has a more stable or easier-to-install solution ever been created? Just today I came across a project that depends entirely on FP math throwing warnings about the smallest subnormal being zero; I would like to fix that (at least in a fork), either with a proper dependency or with credit and a link here.