Skip to content

Instantly share code, notes, and snippets.

@mentha
Last active November 6, 2022 01:37
Show Gist options
  • Save mentha/9eb6656de2053c4d369b25027098597d to your computer and use it in GitHub Desktop.
Save mentha/9eb6656de2053c4d369b25027098597d to your computer and use it in GitHub Desktop.
Hot-unplug PCI devices on Linux
#!/usr/bin/env python3
from argparse import ArgumentParser
from contextlib import suppress
import logging
import os
import re
# Root logger (getLogger() with no name): the level configured below governs
# every message this script emits.
logger = logging.getLogger()

parser = ArgumentParser(description='hot unplug pci device safely')
parser.add_argument('device', nargs='+', help='device address')
parser.add_argument('--port', '-p', action='store_true',
                    help='remove all devices connected to a port')
parser.add_argument('--devid', action='store_true',
                    help='match devices using vendor:device')
parser.add_argument('--remove', '-r', action='store_true',
                    help='remove devices from bus')
parser.add_argument('--nowait', '-n', action='store_true',
                    help='do not wait for device unplug')
parser.add_argument('--verbose', '-v', action='count', default=0,
                    help='show more info')
parser.add_argument('--dry', action='store_true',
                    help='do not actually do anything')
# `a` holds the parsed command-line options used throughout the script.
a = parser.parse_args()

# No -v: WARN; -v: INFO; -vv or more: DEBUG.
_levels = (logging.WARN, logging.INFO, logging.DEBUG)
logging.basicConfig(level=_levels[min(a.verbose, len(_levels) - 1)])
# PCI addresses (directory names under /sys/bus/pci/devices) selected for
# unplugging; filled in from the command-line arguments below.
devs = set()


def parsehex(h):
    """Return the integer value of hex string *h*; a 0x/0X prefix is optional."""
    base = 0 if h.lower().startswith('0x') else 16
    return int(h, base=base)


def readhex(p):
    """Read the file at path *p* and parse its stripped contents as hex."""
    with open(p) as fh:
        content = fh.read()
    return parsehex(content.strip())
# Resolve the command-line arguments into PCI addresses (directory names under
# /sys/bus/pci/devices) and collect them in `devs`.
if a.devid:
    # --devid: arguments are vendor:device ID pairs; select every present
    # device whose sysfs `vendor`/`device` attributes match one of them.
    logger.debug('finding devices using id')
    ids = set()
    for d in a.device:
        v, sep, p = d.partition(':')
        if not sep:
            # Report the malformed argument instead of a bare unpacking error.
            raise RuntimeError(f'invalid device id {d}, expected vendor:device')
        v = parsehex(v)
        p = parsehex(p)
        logger.debug('match device id %04x:%04x', v, p)
        ids.add((v, p))
    for d in os.listdir('/sys/bus/pci/devices'):
        dp = os.path.join('/sys/bus/pci/devices', d)
        v = readhex(os.path.join(dp, 'vendor'))
        p = readhex(os.path.join(dp, 'device'))
        if (v, p) in ids:
            logger.debug('matched device at %s', d)
            devs.add(d)
    if not devs:
        # Address mode fails loudly on a miss; at least tell the user here.
        logger.warning('no device matches id %s', ' '.join(a.device))
else:
    # Arguments are PCI addresses; the domain may be omitted ("01:00.0").
    logger.debug('finding devices using address')
    present = set(os.listdir('/sys/bus/pci/devices'))
    for d in a.device:
        # sysfs names are lowercase hex, so accept e.g. "01:00.A" as well.
        d = d.lower()
        if d.count(':') < 2:
            d = '0000:' + d
        if d not in present:
            raise RuntimeError(f'device at {d} not found')
        devs.add(d)
# A PCI device name as it appears in sysfs: domain:bus:device.function in
# lowercase hex. The domain can be wider than 4 digits (e.g. VMD domains
# start at 0x10000), and the "." must be escaped so it matches only a dot.
re_pciaddr = re.compile(r'^[0-9a-f]{4,}:[0-9a-f]{2}:[0-9a-f]{2}\.[0-7]$')


def listsubdev(d):
    """Yield the names of entries in device *d*'s sysfs directory that are
    themselves PCI device addresses (i.e. devices directly below *d*)."""
    for e in os.listdir(os.path.join('/sys/bus/pci/devices', d)):
        if re_pciaddr.match(e):
            yield e
# --port: the given addresses are ports; replace them with the devices listed
# directly below each port in sysfs, leaving the ports themselves alone.
if a.port:
    logger.debug('finding devices connected to ports')
    pdevs = set()
    pairs = ((port, child) for port in devs for child in listsubdev(port))
    for port, child in pairs:
        logger.debug('found device at %s connected to port %s', child, port)
        pdevs.add(child)
    devs = pdevs
# Breadth-first walk from the selected devices down through sysfs.
#   ldevs   - every device visited (the selected ones plus all downstream)
#   downmap - device -> set of devices directly below it (only when non-empty)
# `devs` serves as the current frontier and is empty when the walk finishes.
logger.debug('finding downstream devices for %r', devs)
ldevs = set()
downmap = {}
while devs:
    frontier, devs = devs, set()
    for d in frontier:
        for e in listsubdev(d):
            downmap.setdefault(d, set()).add(e)
            if e in ldevs or e in frontier:
                continue
            devs.add(e)
            logger.debug('found downstream device at %s', e)
        ldevs.add(d)
# Unbind every device found above (and, with --remove, delete it from the bus).
# A device's downstream devices are always handled before the device itself.
logger.debug('unbinding devices')
devs = ldevs.copy()  # devices still waiting to be processed


def unbinddev(d):
    """Unbind device *d* after recursively unbinding its pending children.

    Writes the address to ``driver/unbind`` and, with ``--remove``, ``1`` to
    ``remove`` under /sys/bus/pci/devices/<d>. With ``--dry`` it only logs.
    Takes *d* off the pending set ``devs``.
    """
    # Drop d from the pending set up front so it can never be processed twice.
    devs.discard(d)
    for dd in downmap.get(d, ()):
        if dd in devs:
            unbinddev(dd)
    path = os.path.join('/sys/bus/pci/devices', d)
    logger.info('unbinding device at %s', d)
    if not a.dry:
        # No driver/unbind file (presumably no driver is bound): nothing to do.
        with suppress(FileNotFoundError):
            with open(os.path.join(path, 'driver/unbind'), 'w') as f:
                f.write(d + '\n')
    if a.remove:
        # Kept independent of the unbind step so that a device with no
        # driver bound is still removed from the bus.
        logger.info('removing device at %s', d)
        if not a.dry:
            with suppress(FileNotFoundError):
                with open(os.path.join(path, 'remove'), 'w') as f:
                    f.write('1\n')


while devs:
    d = devs.pop()
    unbinddev(d)
# Unless --dry or --nowait is given, block until none of the processed devices
# is listed under /sys/bus/pci/devices any more, polling once per second.
if not a.dry and not a.nowait:
    from time import sleep
    # WARNING rather than INFO: without -v only warnings are shown, so an
    # info-level prompt would leave the user with a silent, waiting process.
    logger.warning('unplug devices now')
    while True:
        sleep(1)
        remaining = ldevs.intersection(os.listdir('/sys/bus/pci/devices'))
        if not remaining:
            break
        logger.debug('still present: %s', ' '.join(sorted(remaining)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment