Last active
April 10, 2023 14:22
-
-
Save mentha/7469681fe42e21b5049a726a1acdb868 to your computer and use it in GitHub Desktop.
run libvirt kvm guests with vfio pci passthrough
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
from argparse import ArgumentParser | |
from contextlib import suppress | |
from dataclasses import dataclass | |
from fnmatch import fnmatch | |
from functools import cached_property | |
from select import select | |
from signal import SIGHUP, SIGINT, SIGTERM, SIG_IGN, signal | |
from traceback import format_exc | |
from xml.etree import ElementTree as ET | |
import hashlib | |
import libvirt | |
import logging | |
import os | |
import re | |
import subprocess as sp | |
import sys | |
import time | |
# Program name; also used as the systemd unit name (PROGRAM.service).
PROGRAM = 'virt-createvfio'
# Named module logger. Handlers and level are configured in main() via
# logging.basicConfig; records propagate to the root handler.
logger = logging.getLogger(PROGRAM)
class Xml:
    """Thin wrapper around an ``ET.Element`` with path-based editing helpers.

    Attribute access not defined here is forwarded to the wrapped element
    (``tag``, ``attrib``, ``text``, ``get``, ``clear``, ...). Methods that
    return elements wrap them in ``Xml``.

    ``subelem``/``newelem`` take slash-separated paths whose components are
    ``tag`` or ``tag[@attr]`` / ``tag[@attr='value']`` (tags and attribute
    names may contain ``-``). Predicate attributes are used both to find a
    matching child and to set on a newly created one.
    """

    # One path component: tag name plus optional [@attr] / [@attr='v'] predicates.
    _COMPONENT_RE = re.compile(r'''[-\w]+(\[@[-\w]+(=('[^']*'|"[^"]*"))?\])*''')
    # One [@attr] / [@attr='v'] predicate: group 1 = name, group 3 = quoted value.
    _PREDICATE_RE = re.compile(r'''\[@([-\w]+)(=('[^']*'|"[^"]*"))?\]''')

    def __init__(self, n):
        self.node = n

    @classmethod
    def wrap(cls, n):
        """Wrap element ``n``; ``None`` is passed through unchanged."""
        if n is None:
            return None
        return cls(n)

    @classmethod
    def fromstring(cls, s):
        """Parse an XML document string and wrap its root element."""
        return cls(ET.fromstring(s))

    def __getattr__(self, k):
        # Only reached for names not found on the wrapper itself.
        return getattr(self.node, k)

    def __setattr__(self, k, v):
        # 'node' is stored on the wrapper; anything else (e.g. .text) is set
        # on the wrapped element.
        if k in {'node'}:
            return super().__setattr__(k, v)
        return setattr(self.node, k, v)

    def __repr__(self):
        # Abbreviated form: <tag a="v"/>, or <tag a="v">...</> when the
        # element has text or children.
        r = '<' + self.tag
        for k, v in self.attrib.items():
            r += ' ' + k + '="' + v + '"'
        if self.text or len(self.node) > 0:
            r += '>...<'
        r += '/>'
        return r

    def tostring(self):
        """Serialize the wrapped element (and its subtree) to a str."""
        return ET.tostring(self.node, encoding='unicode')

    def set(self, k=None, v=None, **ka):
        """Set one attribute (``set(k, v)``) or several (``set(a=..., b=...)``)."""
        if k is None:
            for n, val in ka.items():
                self.node.set(n, val)
        else:
            self.node.set(k, v)

    def find(self, m):
        """ElementTree ``find`` returning a wrapped element or None."""
        return self.wrap(self.node.find(m))

    def findall(self, m):
        """ElementTree ``findall`` returning a list of wrapped elements."""
        return [self.wrap(r) for r in self.node.findall(m)]

    def remove(self, e):
        """Remove child ``e`` (either an ``Xml`` wrapper or a raw element)."""
        if not isinstance(e, ET.Element):
            e = e.node
        self.node.remove(e)

    def _procattrs(self, a):
        # Strip an 'attr_' prefix from keyword names, presumably so callers can
        # pass attribute names that clash with Python keywords or with this
        # class's own parameters (clear/match).
        na = {}
        for k, v in a.items():
            if k.startswith('attr_'):
                k = k[5:]
            na[k] = v
        return na

    def _subelem(self, tag, clear=False, match=True, **attrs):
        """Find or create one child described by a single path component.

        With ``match`` the first child whose tag and attributes equal the
        requested ones is reused (emptied first when ``clear``). Otherwise,
        or when none matches, a new child is appended. The requested
        attributes are (re)applied to the returned element.

        Raises ValueError if ``tag`` is not a well-formed component.
        """
        attrs = self._procattrs(attrs)
        if '[' in tag:
            # Split "name[@a='x'][@b]" into the tag name and its predicates.
            # The tag pattern must accept '-' like _COMPONENT_RE does, otherwise
            # hyphenated tags (e.g. suspend-to-mem) with predicates fail here.
            m = re.match(r'([-\w]+)(\[.*)$', tag)
            if m is None:
                raise ValueError(f'malformed element path component: {tag!r}')
            tag, preds = m[1], m[2]
            while preds:
                m = self._PREDICATE_RE.match(preds)
                if m is None:
                    raise ValueError(f'malformed predicate in {tag!r}: {preds!r}')
                preds = preds[len(m[0]):]
                # A bare [@k] is treated as k=''.
                attrs[m[1]] = m[3][1:-1] if m[3] else ''
        if not match:
            return self.wrap(ET.SubElement(self.node, tag, **attrs))
        path = tag
        if len(attrs) > 0:
            path += '[' + ']['.join(f'@{n}={str(v)!r}' for n, v in attrs.items()) + ']'
        r = self.find(path)
        if r is None:
            r = self.wrap(ET.SubElement(self.node, tag, **attrs))
        if clear:
            # clear() also drops attributes; they are restored just below.
            r.clear()
        r.set(**attrs)
        return r

    def subelem(self, tag, clear=False, match=True, **attrs):
        """Walk/create the slash-separated path ``tag``; return its last element.

        Intermediate components are matched or created with default options.
        ``clear``, ``match`` and ``attrs`` apply to the final component only.
        Returns None for an empty path. Raises ValueError on a malformed path.
        """
        e = self
        while tag:
            m = self._COMPONENT_RE.match(tag)
            if m is None:
                raise ValueError(f'malformed element path: {tag!r}')
            tag = tag[len(m[0]):].lstrip('/')
            if tag:
                e = e._subelem(m[0])
            else:
                return e._subelem(m[0], clear, match, **attrs)

    def newelem(self, tag, match=True, **newattrs):
        """Like ``subelem`` but the final element is emptied, then given ``newattrs``."""
        r = self.subelem(tag, clear=True, match=match)
        r.set(**newattrs)
        return r
@dataclass(order=True, frozen=True)
class PciAddr:
    """A hashable, ordered PCI address ``domain:bus:slot.function``.

    May be built from a string (``0000:01:00.0`` or ``01:00.0``, hex, domain
    defaulting to 0), from an element with ``domain``/``bus``/``slot``/
    ``function`` attributes, or from the four fields as keywords.
    """
    domain: int
    bus: int
    slot: int
    function: int

    def fromstring(self, a):
        """Load fields from a textual ``[domain:]bus:slot.function`` address."""
        head, func = a.split('.')
        parts = head.split(':')
        if head.count(':') == 2:
            dom, bus, slot = parts
        else:
            dom = '0'
            bus, slot = parts
        # The dataclass is frozen, so bypass its __setattr__.
        for name, text in (('domain', dom), ('bus', bus), ('slot', slot), ('function', func)):
            object.__setattr__(self, name, int(text, 16))

    def fromxml(self, x):
        """Load fields from an element's attributes (any int literal base)."""
        for name in ('domain', 'bus', 'slot', 'function'):
            object.__setattr__(self, name, int(x.get(name), 0))

    def __init__(self, s=None, **ka):
        if s is not None:
            loader = self.fromxml if isinstance(s, (ET.Element, Xml)) else self.fromstring
            loader(s)
        self.__dict__.update(ka)

    def toxml(self, x):
        """Write the address as hex attributes onto element ``x``."""
        layout = {
            'domain': '0x{:04x}',
            'bus': '0x{:02x}',
            'slot': '0x{:02x}',
            'function': '0x{:x}',
        }
        for name, fmt in layout.items():
            x.set(name, fmt.format(getattr(self, name)))

    def __repr__(self):
        return f'{self.domain:04x}:{self.bus:02x}:{self.slot:02x}.{self.function:x}'
def parsePciAddrs(s):
    """Expand one ``--pci`` argument into a list of PciAddr.

    ``s`` is either a PCI address (``[domain:]bus:slot.function``) or a
    ``vendor:device`` hex ID pair. For an ID pair, every device under
    /sys/bus/pci/devices whose vendor and device IDs match is returned
    (possibly none).
    """
    ids = re.match(r'^([0-9a-fA-F]+):([0-9a-fA-F]+)$', s)
    if ids is None:
        return [PciAddr(s)]
    want_vendor, want_device = int(ids[1], 16), int(ids[2], 16)
    found = []
    for name in os.listdir('/sys/bus/pci/devices'):
        with open(f'/sys/bus/pci/devices/{name}/vendor') as fh:
            vendor_ok = int(fh.read(), 0) == want_vendor
        if not vendor_ok:
            continue
        with open(f'/sys/bus/pci/devices/{name}/device') as fh:
            device_ok = int(fh.read(), 0) == want_device
        if device_ok:
            found.append(PciAddr(name))
    return found
class Opt:
    """Base class for one option applied around the guest's lifetime.

    Subclasses implement ``start()``, optionally ``stop()``, and a
    ``config_parser(parser)`` hook that registers their CLI arguments.
    """

    def __init__(self, args, conn, cap, dom, xml):
        # args: parsed CLI namespace; conn: libvirt connection;
        # cap: host capabilities (Xml); dom: the libvirt domain handle;
        # xml: the guest definition (Xml), edited in place.
        self.args, self.conn, self.cap = args, conn, cap
        self.dom, self.xml = dom, xml
class PciPassthruOpt(Opt):
    """Attach host PCI devices to the guest; optionally strip emulated display/audio.

    Guest suspend-to-mem and suspend-to-disk are always disabled.
    """

    @staticmethod
    def config_parser(a):
        grp = a.add_argument_group('PCI passthrough')
        grp.add_argument('--pci', '-p', nargs='+', metavar='ADDRESS', type=parsePciAddrs, action='extend', default=[], help='passthrough additional pci devices')
        grp.add_argument('--nodisplay', action='store_true', help='remove virtualized display and audio devices')

    def nodisplay(self):
        """Remove graphics, video, input, sound and audio devices."""
        devices = self.xml.subelem('devices')
        kinds = ('graphics', 'video', 'input', 'sound', 'audio')
        doomed = [dev for kind in kinds for dev in devices.findall(kind)]
        for dev in doomed:
            logger.debug('removing display device %r', dev)
            devices.remove(dev)

    def addpci(self, a):
        """Append a managed <hostdev> for PCI address ``a``."""
        hostdev = self.xml.newelem('devices/hostdev', match=False, mode='subsystem', type='pci', managed='yes')
        a.toxml(hostdev.newelem('source/address'))
        logger.debug('added host pci dev %r', hostdev)

    def setpm(self):
        """Disable guest suspend-to-mem and suspend-to-disk."""
        pm = self.xml.newelem('pm')
        for state in ('suspend-to-mem', 'suspend-to-disk'):
            pm.newelem(state, enabled='no')

    def start(self):
        if self.args.nodisplay:
            self.nodisplay()
        # --pci collects one list per argument; flatten, dedupe, then add in order.
        unique = {addr for group in self.args.pci for addr in group}
        for addr in sorted(unique):
            self.addpci(addr)
        self.setpm()
class CpuOpt(Opt):
    """Rewrite the guest CPU definition and optionally pin guest threads to host CPUs."""

    @staticmethod
    def config_parser(a):
        a = a.add_argument_group('CPU Options')
        a.add_argument('--cpus', type=int, help='override guest cpu count')
        a.add_argument('--pincpu', action='store_true', help='pin guest cpus')
        a.add_argument('--pinrest', type=int, metavar='CPUCOUNT', nargs='?', const=0, help='pin rest threads, 1 cpu per iothread plus 1 for emulator by default')
        a.add_argument('--ignore-queues', action='store_true', help='do not adjust disk queue size')

    @cached_property
    def cpucount(self):
        """Guest vCPU count: ``--cpus`` if given, else the domain's <vcpu> value."""
        return self.args.cpus or int(self.xml.find('vcpu').text)

    @cached_property
    def iothreadcount(self):
        """Number of guest iothreads from <iothreads>, 0 if absent."""
        e = self.xml.find('iothreads')
        if e is None:
            return 0
        return int(e.text)

    @cached_property
    def cputhreadcount(self):
        """Threads per core in the host CPU topology (default 1)."""
        return int(self.cap.find('host/cpu/topology').get('threads', 1))

    @cached_property
    def cpucorecount(self):
        # NOTE(review): integer division drops any remainder, so a vCPU count
        # that is not a multiple of the host thread count gives a topology
        # smaller than <vcpu>.
        return self.cpucount // self.cputhreadcount

    def resetcpu(self):
        """Replace vCPU/tuning config with a host-passthrough CPU of ``cpucount`` vCPUs."""
        # Read (and cache) the count before <vcpu> is removed below.
        c = self.cpucount
        for m in ('vcpu', 'vcpus', 'cputune'):
            for e in self.xml.findall(m):
                self.xml.remove(e)
        self.xml.newelem('vcpu', placement='static').text = str(c)
        cpu = self.xml.subelem('cpu')
        cpu.set(mode='host-passthrough', check='partial', migratable='off')
        cpu.newelem('topology', sockets='1', dies='1', cores=str(self.cpucorecount), threads=str(self.cputhreadcount))
        logger.debug('guest has %d cores %d threads', self.cpucorecount, self.cputhreadcount)
        # Require topoext whenever the host CPU exposes it (AMD).
        if self.cap.find('host/cpu/feature[@name="topoext"]') is not None:
            logger.debug('enabling topoext on amd')
            cpu.newelem('feature[@name="topoext"]', policy='require')

    def _hostcores(self):
        """Return host CPU ids grouped per core for socket 0 / die 0 of ``--cell``.

        The result is a list of per-core CPU-id lists ordered by ascending
        core_id. Core ids that do not appear are simply absent, so
        non-contiguous numbering leaves no empty slots. Each core gets its own
        list; ``[[]] * k`` would alias one list across several cores.
        """
        bycore = {}
        for e in self.cap.findall(f'host/topology/cells/cell[@id="{self.args.cell}"]/cpus/cpu[@socket_id="0"][@die_id="0"]'):
            bycore.setdefault(int(e.get('core_id')), []).append(int(e.get('id')))
        return [bycore[core] for core in sorted(bycore)]

    def pincpu(self, pincpu, pinrest):
        """Pin guest vCPUs and/or the remaining guest threads to host CPUs.

        Host cores are consumed from the highest core_id downwards.

        Args:
            pincpu: if true, pin each guest core to one whole host core.
            pinrest: if truthy, a CPU count. All iothreads and the emulator
                thread share that many host CPUs taken from the next cores.

        Raises:
            RuntimeError: the cell has too few cores, or a host core has
                fewer threads than the guest topology needs.
        """
        cores = self._hostcores()

        def take_core():
            # Next unused host core (highest core_id first).
            if not cores:
                raise RuntimeError(f'not enough host cores in cell {self.args.cell} for cpu pinning')
            return cores.pop()

        if pincpu:
            for core in range(self.cpucorecount):
                hcore = take_core()
                for thr in range(self.cputhreadcount):
                    if not hcore:
                        raise RuntimeError('host core has fewer threads than the guest cpu topology')
                    cpu = core * self.cputhreadcount + thr
                    c = hcore.pop()
                    self.xml.newelem(f'cputune/vcpupin[@vcpu="{cpu}"]', cpuset=str(c))
                    logger.debug('guest cpu %d pinned to host cpu %d', cpu, c)
        if pinrest:
            rest = take_core()
            while len(rest) < pinrest:
                rest.extend(take_core())
            rest = ','.join(str(x) for x in sorted(rest[:pinrest]))
            for i in range(self.iothreadcount):
                self.xml.newelem(f'cputune/iothreadpin[@iothread="{i + 1}"]', cpuset=rest)
            self.xml.newelem('cputune/emulatorpin', cpuset=rest)
            logger.debug('other threads pinned to host %s', rest)

    def adjust_queue(self):
        """Set ``queues`` to the vCPU count on iothread-backed SCSI controllers and disks."""
        for drv in self.xml.findall('devices/controller[@type="scsi"]/driver[@iothread]') + self.xml.findall('devices/disk[@device="disk"]/driver[@iothread]'):
            drv.set('queues', str(self.cpucount))

    def start(self):
        self.resetcpu()
        pinrest = self.args.pinrest
        # --pinrest without a value (const 0) means one CPU per iothread plus
        # one for the emulator thread.
        if pinrest == 0:
            pinrest = self.iothreadcount + 1
        self.pincpu(self.args.pincpu, pinrest)
        if not self.args.ignore_queues:
            self.adjust_queue()
class MemOpt(Opt):
    """Set the guest memory size and back it with host hugepages when possible."""
    HP_DISABLED = 0x1
    HP_FORCE = 0x2
    HP_HARDER = 0x4
    # --hugepages choice -> combination of the HP_* flags above.
    hugepages_map = {
        'disabled': HP_DISABLED,
        'try': 0,
        'tryharder': HP_HARDER,
        'force': HP_HARDER | HP_FORCE,
    }

    @classmethod
    def config_parser(cls, a):
        a.add_argument('--mem', type=int, help='override guest memory size in GiB')
        a.add_argument('--hugepages', default='force', choices=cls.hugepages_map.keys(), help='hugepages allocation mode, force by default')

    def __init__(self, *a):
        super().__init__(*a)
        # (page size in KiB, page count) added to the host pool for this
        # guest, or None when nothing was allocated.
        self.allocated = None

    @property
    def hugepages(self):
        """HP_* flag combination for the selected ``--hugepages`` mode."""
        return self.hugepages_map[self.args.hugepages]

    @staticmethod
    def unit2k(u):
        """Return how many KiB one ``u`` is (KiB/MiB/GiB, case-insensitive)."""
        return {
            'kib': 1,
            'mib': 1024,
            'gib': 1024**2,
        }[u.lower()]

    @cached_property
    def mem_k(self):
        """Guest memory in KiB: ``--mem`` GiB if given, else the domain's <memory>."""
        if self.args.mem:
            return self.args.mem * self.unit2k('GiB')
        m = self.xml.find('memory')
        return self.unit2k(m.get('unit', 'KiB')) * int(m.text)

    def resetmem(self):
        """Set <memory> to ``mem_k`` KiB, drop current/max memory, lock and unshare pages."""
        for m in ('currentMemory', 'maxMemory'):
            for e in self.xml.findall(m):
                self.xml.remove(e)
        memk = self.mem_k
        self.xml.newelem('memory', unit='KiB').text = str(memk)
        self.xml.newelem('memoryBacking/nosharepages')
        self.xml.newelem('memoryBacking/locked')
        logger.debug('guest has %f MiB memory', memk / 1024)

    def gethugepages(self, s):
        """Return the number of *free* ``s`` KiB pages in the cell (libvirt getFreePages)."""
        return self.conn.getFreePages([s], self.args.cell, 1)[self.args.cell][s]

    def gettotalpages(self, s):
        """Return the current size of the ``s`` KiB hugepage pool in the cell.

        sethugepages() uses VIR_NODE_ALLOC_PAGES_SET, which resizes the whole
        pool, while gethugepages() only reports free pages. Pool arithmetic
        must therefore start from the pool size, or pages in use elsewhere
        make it wrong. This reads the kernel's nr_hugepages for the NUMA node
        (assumes libvirt cell ids are host NUMA node numbers — the per-node
        pool libvirt itself resizes), falling back to the system-wide pool
        when per-node files are absent.

        Raises RuntimeError if neither counter exists for size ``s``.
        """
        for path in (
            f'/sys/devices/system/node/node{self.args.cell}/hugepages/hugepages-{s}kB/nr_hugepages',
            f'/sys/kernel/mm/hugepages/hugepages-{s}kB/nr_hugepages',
        ):
            try:
                with open(path) as f:
                    return int(f.read())
            except FileNotFoundError:
                continue
        raise RuntimeError(f'cannot determine hugepage pool size for {s} KiB pages')

    def sethugepages(self, s, n):
        """Resize the ``s`` KiB pool of the cell to ``n`` pages; libvirt errors are ignored.

        Callers verify the outcome by reading the pool size back.
        """
        with suppress(libvirt.libvirtError):
            return self.conn.allocPages({s: n}, self.args.cell, 1, libvirt.VIR_NODE_ALLOC_PAGES_SET)

    def alloc_hugepage(self, ps, firstonly=False):
        """Grow a hugepage pool by exactly enough pages for the guest.

        Tries the page sizes in ``ps`` (KiB) in order, skipping sizes that do
        not divide the guest memory. On partial failure the pool is shrunk
        back. With ``firstonly``, gives up after the first attempted size.
        Returns True and records ``self.allocated`` on success.
        """
        for s in ps:
            if self.mem_k % s != 0:
                continue
            n = self.mem_k // s
            logger.debug('trying to allocate %d pages of %d KiB each', n, s)
            pool = self.gettotalpages(s)
            target = pool + n
            self.sethugepages(s, target)
            # The kernel only reports the pages it actually managed to allocate.
            if self.gettotalpages(s) >= target:
                self.allocated = (s, n)
                return True
            self.sethugepages(s, pool)
            if firstonly:
                break
        return False

    def prepare_hugepage(self, mode):
        """Try to reserve hugepages according to HP_* flags in ``mode``.

        Raises RuntimeError if HP_FORCE is set and nothing could be reserved.
        """
        if mode & self.HP_DISABLED:
            return
        hpsz = []
        for e in self.cap.findall(f'host/topology/cells/cell[@id="{self.args.cell}"]/pages'):
            s = self.unit2k(e.get('unit', 'KiB')) * int(e.get('size'))
            hpsz.append(s)
            logger.debug('host support pagesize %d KiB', s)
        # Largest first; drop the smallest size, assumed to be the base
        # (non-huge) page size.
        hpsz = sorted(hpsz, reverse=True)[:-1]
        if self.alloc_hugepage(hpsz, (mode & self.HP_HARDER) != 0):
            return
        if mode & self.HP_HARDER:
            # Drop caches and compact memory so large contiguous allocations
            # are more likely to succeed, then retry every size.
            sp.run(['sync'])
            sp.run(['sysctl', 'vm.drop_caches=3', 'vm.compact_memory=1'])
            if self.alloc_hugepage(hpsz):
                return
        if mode & self.HP_FORCE:
            raise RuntimeError('cannot allocate hugepage')
        logger.info('cannot allocate hugepage')

    def enable_hugepage(self):
        """Back guest memory with the allocated page size."""
        s, _ = self.allocated
        self.xml.newelem('memoryBacking/hugepages').newelem('page', size=str(s), unit='KiB')

    def start(self):
        self.resetmem()
        logger.debug('preparing hugepages')
        if not self.args.dry_run:
            self.prepare_hugepage(self.hugepages)
        if self.allocated is not None:
            self.enable_hugepage()

    def stop(self):
        """Shrink the pool by the pages added in start(); runs after the guest is gone."""
        logger.debug('stopping hugepages')
        if self.allocated is None:
            return
        s, n = self.allocated
        self.sethugepages(s, max(0, self.gettotalpages(s) - n))
        self.allocated = None
class IoSchedOpt(Opt):
    """Switch host block devices to a matching I/O scheduler while the guest runs.

    ``--iosched`` is an fnmatch pattern matched against each device's
    available schedulers. Devices already on the chosen one are left alone.
    """

    @staticmethod
    def config_parser(a):
        a.add_argument('--iosched', default='*deadline', help='set host io scheduler, *deadline by default')

    def __init__(self, *a):
        super().__init__(*a)
        # block device name -> scheduler that was active before start()
        self.saved = {}

    def start(self):
        for blk in os.listdir('/sys/block'):
            spath = f'/sys/block/{blk}/queue/scheduler'
            if not os.path.exists(spath):
                continue
            with open(spath, 'r') as f:
                names = f.read().strip().split()
            current = None
            chosen = None
            for name in names:
                # The active scheduler is shown as "[name]".
                if name.startswith('[') and name.endswith(']'):
                    name = name[1:-1]
                    current = name
                if fnmatch(name, self.args.iosched):
                    chosen = name
            if None in (current, chosen) or current == chosen:
                continue
            self.saved[blk] = current
            if not self.args.dry_run:
                with open(spath, 'w') as f:
                    f.write(chosen + '\n')
            logger.debug('set io scheduler of %s: %s => %s', blk, current, chosen)

    def stop(self):
        """Restore saved schedulers; failures are logged, not raised."""
        for blk, sched in self.saved.items():
            try:
                if not self.args.dry_run:
                    with open(f'/sys/block/{blk}/queue/scheduler', 'w') as f:
                        f.write(sched + '\n')
                logger.debug('restore io scheduler of %s: => %s', blk, sched)
            except OSError as err:
                logger.error('error restoring io scheduler of %s to %s: %r', blk, sched, err)
class DRand:
    """Deterministic pseudo-random generator seeded from a str or bytes.

    Output is a SHA-512 hash chain, so the same seed always yields the same
    sequence. Not suitable for anything security-sensitive.
    """

    def __init__(self, seed):
        if isinstance(seed, str):
            seed = seed.encode('utf8')
        self.h = hashlib.sha512(seed)

    def randbytes(self, n):
        """Return the next ``n`` bytes of the stream."""
        r = b''
        while len(r) < n:
            d = self.h.digest()
            r += d
            self.h.update(d)
        return r[:n]

    def randint(self, l, h=None):
        """Return an int in [l, h] inclusive, or in [0, l] when ``h`` is omitted.

        Raises ValueError if the range is empty.
        """
        if h is None:
            h = l
            l = 0
        s = h - l + 1
        if s <= 0:
            raise ValueError(f'empty range for randint: [{l}, {h}]')
        # Deliberately draws s.bit_length() *bytes*, not the minimal
        # (bit_length + 7) // 8: the draw size determines every value
        # derived from a seed (e.g. --sysinfo identities seeded by the guest
        # UUID), so it must stay fixed. Byte order is explicit because
        # int.from_bytes() only has a default byteorder since Python 3.11.
        r = int.from_bytes(self.randbytes(s.bit_length()), 'big')
        return r % s + l

    def choice(self, s):
        """Return a pseudo-random element of sequence ``s``."""
        return s[self.randint(len(s) - 1)]

    def randword(self, digits=False, minlen=4, maxlen=10):
        """Return a lowercase word (optionally with digits) of minlen..maxlen chars."""
        a = 'abcdefghijklmnopqrstuvwxyz'
        if digits:
            a += '0123456789'
        return ''.join(self.choice(a) for _ in range(self.randint(minlen, maxlen)))
class GuestTuneOpt(Opt):
    """Guest-side tuning: Hyper-V enlightenments, timers/perf, VM hiding, fake SMBIOS."""

    @staticmethod
    def config_parser(a):
        grp = a.add_argument_group('Guest tuning')
        grp.add_argument('--nohyperv', action='store_true', help='do not add hyper-v enlightments')
        grp.add_argument('--noperf', action='store_true', help='do not add performance tuning')
        grp.add_argument('--hidevm', action='store_true', help='try to hide hypervisor info')
        grp.add_argument('--sysinfo', action='store_true', help='generate fake sysinfo')

    def start(self):
        steps = (
            (not self.args.nohyperv, self.tune_hyperv),
            (not self.args.noperf, self.tune_perf),
            (self.args.hidevm, self.tune_hidevm),
            (self.args.sysinfo, self.tune_sysinfo),
        )
        for enabled, step in steps:
            if enabled:
                step()

    def tune_hyperv(self):
        """Enable Hyper-V enlightenments (passthrough mode) and the hypervclock timer."""
        hyperv = self.xml.newelem('features/hyperv', mode='passthrough')
        simple = ('relaxed', 'vapic', 'vpindex', 'runtime', 'synic', 'reset',
                  'frequencies', 'reenlightenment', 'tlbflush', 'ipi')
        for feature in simple:
            hyperv.newelem(feature, state='on')
        hyperv.newelem('spinlocks', state='on', retries='8191')
        stimer = hyperv.newelem('stimer', state='on')
        stimer.newelem('direct', state='on')
        self.xml.newelem('clock/timer[@name="hypervclock"]', present='yes')

    def tune_perf(self):
        """Configure timers, require invtsc, drop the balloon and unused SPICE devices."""
        timers = (
            ('rtc', {'present': 'yes', 'tickpolicy': 'catchup'}),
            ('pit', {'present': 'no'}),
            ('hpet', {'present': 'no'}),
            ('tsc', {'present': 'yes', 'mode': 'native'}),
        )
        for name, attrs in timers:
            self.xml.newelem(f'clock/timer[@name="{name}"]', **attrs)
        self.xml.newelem('cpu/feature[@name="invtsc"]', policy='require')
        self.xml.newelem('clock/timer[@name="kvmclock"]', present='yes')
        self.xml.newelem('features/msrs', unknown='ignore')
        self.xml.newelem('features/pmu', state='off')
        devices = self.xml.subelem('devices')
        devices.newelem('memballoon', model='none')
        # SPICE helper devices are only useful with SPICE graphics.
        if devices.find('graphics[@type="spice"]') is not None:
            return
        spice_only = ('redirdev[@type="spicevmc"]',
                      'smartcard[@type="spicevmc"]',
                      'channel[@type="spicevmc"]',
                      'audio[@type="spice"]')
        for dev in [dev for pattern in spice_only for dev in devices.findall(pattern)]:
            devices.remove(dev)

    def tune_hidevm(self):
        """Disable the 'hypervisor' CPU feature and set kvm hidden (plus a Hyper-V vendor_id)."""
        self.xml.newelem('cpu/feature[@name="hypervisor"]', policy='disable')
        self.xml.newelem('features/kvm/hidden', state='on')
        if not self.args.nohyperv:
            self.xml.newelem('features/hyperv/vendor_id', state='on', value='libvirt')

    def tune_sysinfo(self):
        """Fill SMBIOS sysinfo with fake values derived deterministically from the guest UUID."""
        rng = DRand(self.xml.find('uuid').text)
        # The draw order below fixes the generated values; keep it stable.
        vendor = rng.randword().capitalize()
        product = rng.randword(True).upper()
        serial = rng.randword(True, 10, 20).upper()
        version = f'{rng.randint(1, 9)}{rng.randword(True, 3, 3)}{rng.randint(0, 9)}'
        date = f'{rng.randint(1, 12)}/{rng.randint(1, 28)}/{rng.randint(2010, 2020)}'
        board_like = (('manufacturer', vendor), ('product', product), ('version', version), ('serial', serial))
        sections = (
            ('bios', (('vendor', vendor), ('version', version), ('date', date))),
            ('system', board_like),
            ('baseBoard', board_like),
            ('chassis', (('manufacturer', vendor), ('version', version), ('serial', serial))),
        )
        self.xml.newelem('os/smbios', mode='sysinfo')
        sysinfo = self.xml.newelem('sysinfo', type='smbios')
        for section, entries in sections:
            node = sysinfo.newelem(section)
            for key, value in entries:
                node.newelem(f'entry[@name="{key}"]').text = value
class ConflictingDomainOpt(Opt):
    """Stop other running domains that use any of the guest's host PCI devices.

    Conflicting domains are shut down (or destroyed) before the guest starts.
    With a ``*-restart`` policy they are started again when the guest exits.
    """
    C_SHUTDOWN = 0x1
    C_RESTART = 0x2
    # --onconflict choice -> combination of the C_* flags above. Without
    # C_SHUTDOWN a conflicting domain is destroyed immediately.
    onconflict_map = {
        'terminate': 0,
        'terminate-restart': C_RESTART,
        'shutdown': C_SHUTDOWN,
        'shutdown-restart': C_SHUTDOWN | C_RESTART,
    }

    @classmethod
    def config_parser(cls, a):
        a.add_argument('--onconflict', default='shutdown-restart',
            choices=cls.onconflict_map.keys(), help='configure ways to resolve resource conflict, shutdown-restart by default')
        a.add_argument('--onconflict-shutdown-timeout', type=int, default=60)
        a.add_argument('--onconflict-start-interval', type=int, default=10)

    def __init__(self, *a):
        super().__init__(*a)
        # Domains stopped by start(), in the order they were stopped.
        self.saved = []

    @property
    def onconflict(self):
        """C_* flag combination for the selected ``--onconflict`` mode."""
        return self.onconflict_map[self.args.onconflict]

    @staticmethod
    def _isactive(d):
        """Return ``d.isActive()``, treating a domain that vanished as inactive.

        Transient domains cease to exist once stopped, after which libvirt
        calls on them raise libvirtError.
        """
        try:
            return bool(d.isActive())
        except libvirt.libvirtError:
            return False

    def findconflict(self):
        """Yield each active domain sharing a host PCI device with the guest."""
        usedpci = {PciAddr(e) for e in self.xml.findall('devices/hostdev[@type="pci"]/source/address')}
        logger.debug('guest use pci devices %s', sorted(usedpci))
        for d in self.conn.listAllDomains():
            if not self._isactive(d):
                continue
            logger.debug('found active domain %s', d.name())
            dx = Xml.fromstring(d.XMLDesc())
            for e in dx.findall('devices/hostdev[@type="pci"]/source/address'):
                a = PciAddr(e)
                logger.debug('active domain %s use hostdev %s', d.name(), a)
                if a in usedpci:
                    logger.debug('found conflicting domain %s on %r', d.name(), a)
                    yield d
                    break

    def start(self):
        for d in self.findconflict():
            self.saved.append(d)
            if self.onconflict & self.C_SHUTDOWN:
                logger.info('shutting down conflicting domain %s', d.name())
                if not self.args.dry_run:
                    with suppress(libvirt.libvirtError):
                        d.shutdown()
                    # Wait for a clean power-off before falling back to
                    # destroy(). No shutdown is requested in dry-run, so
                    # waiting there would only stall.
                    ddl = self.args.onconflict_shutdown_timeout + time.time()
                    while self._isactive(d) and time.time() < ddl:
                        time.sleep(1)
            if self._isactive(d):
                logger.warning('destroying conflicting domain %s', d.name())
                if not self.args.dry_run:
                    with suppress(libvirt.libvirtError):
                        d.destroy()

    def stop(self):
        """Restart stopped domains (restart policies only); one failure does not stop the rest."""
        if not self.onconflict & self.C_RESTART:
            return
        for d in self.saved:
            logger.info('starting stopped conflicting domain %s', d.name())
            if self.args.dry_run:
                continue
            try:
                d.create()
            except libvirt.libvirtError as e:
                # e.g. a transient domain that no longer exists
                logger.error('error restarting conflicting domain %s: %r', d.name(), e)
                continue
            time.sleep(self.args.onconflict_start_interval)
class UtilityDomainOpt(Opt):
    """Start extra domains with the guest and stop them when it exits.

    Each ``--start-domain`` value is ``NAME`` or ``NAME:ONEXIT``. ONEXIT is
    ``ignore`` (the default), ``shutdown`` or ``destroy``.
    """

    @classmethod
    def config_parser(cls, a):
        a.add_argument('--start-domain', action='append', default=[], help='start other kvm domains')

    def __init__(self, *a):
        super().__init__(*a)
        # (domain, onexit) pairs started by start(), for stop().
        self.saved = []

    def start(self):
        """Start each requested domain; raises RuntimeError on an unknown ONEXIT."""
        for dn in self.args.start_domain:
            onexit = 'ignore'
            if ':' in dn:
                dn, onexit = dn.split(':', 1)
            if onexit not in {'ignore', 'shutdown', 'destroy'}:
                raise RuntimeError(f'invalid destroy method {onexit} for domain {dn}')
            d = self.conn.lookupByName(dn)
            self.saved.append((d, onexit))
            logger.info('starting domain %s', dn)
            if not self.args.dry_run:
                d.create()

    def stop(self):
        """Apply each domain's ONEXIT action; errors are logged so the rest still run."""
        for d, onexit in self.saved:
            dn = d.name()
            logger.info('destroy domain %s by %s', dn, onexit)
            if self.args.dry_run or onexit == 'ignore':
                continue
            try:
                if onexit == 'shutdown':
                    d.shutdown()
                else:
                    d.destroy()
            except libvirt.libvirtError as e:
                logger.error('error stopping domain %s by %s: %r', dn, onexit, e)
class RealtimeOpt(Opt):
    """With --realtime, put vCPU, iothread and emulator threads on the 'fifo' scheduler, priority 1."""

    @staticmethod
    def config_parser(a):
        a.add_argument('--realtime', action='store_true', help='grant realtime priority to the guest')

    def start(self):
        if not self.args.realtime:
            return
        nvcpu = int(self.xml.find('vcpu').text)
        node = self.xml.find('iothreads')
        niothreads = 0 if node is None else int(node.text)
        fifo = {'scheduler': 'fifo', 'priority': '1'}
        for idx in range(nvcpu):
            self.xml.newelem(f'cputune/vcpusched[@vcpus="{idx}"]', **fifo)
        # iothread ids are 1-based
        for idx in range(1, niothreads + 1):
            self.xml.newelem(f'cputune/iothreadsched[@iothreads="{idx}"]', **fifo)
        self.xml.newelem('cputune/emulatorsched', **fifo)
class CpuFreqOpt(Opt):
    """Switch the host cpufreq governor (and optionally disable boost) while the guest runs."""

    # Directory of cpufreq policies; may be missing on hosts without cpufreq.
    cpufreq_dir = '/sys/devices/system/cpu/cpufreq'
    # Global boost switch; may be absent (treated as nothing to do).
    boost_path = '/sys/devices/system/cpu/cpufreq/boost'

    @staticmethod
    def config_parser(a):
        a.add_argument('--cpufreq', default='ondemand', help='set host cpu frequency governor, ondemand by default')
        a.add_argument('--noboost', action='store_true', help='disable host cpu frequency boost')

    def __init__(self, *a):
        super().__init__(*a)
        # policy dir (with trailing '/') -> governor active before start()
        self.savedfreq = {}
        # boost value read before start_noboost(), or None if untouched
        self.savedboost = None

    def start_cpufreq(self):
        """Set ``--cpufreq`` on every policy that offers it, saving the previous governor."""
        try:
            policies = os.listdir(self.cpufreq_dir)
        except FileNotFoundError:
            logger.debug('no cpufreq support, not changing cpu governor')
            return
        for p in policies:
            d = self.cpufreq_dir + '/' + p + '/'
            if not (os.path.exists(d + 'scaling_available_governors') and os.path.exists(d + 'scaling_governor')):
                continue
            with open(d + 'scaling_available_governors') as f:
                if self.args.cpufreq not in f.read().strip().split():
                    continue
            newg = self.args.cpufreq
            with open(d + 'scaling_governor') as f:
                old = f.read().strip()
            self.savedfreq[d] = old
            if not self.args.dry_run:
                with open(d + 'scaling_governor', 'w') as f:
                    f.write(newg + '\n')
            logger.debug('set cpu governor of %s: %s => %s', p, old, newg)

    def stop_cpufreq(self):
        """Restore saved governors; failures are logged, not raised."""
        for d, g in self.savedfreq.items():
            try:
                if not self.args.dry_run:
                    with open(d + 'scaling_governor', 'w') as f:
                        f.write(g + '\n')
            except OSError as e:
                logger.error('error restoring cpu governor of %s to %s: %r', d, g, e)

    def start_noboost(self):
        """Save the boost state and turn boost off (no write in dry-run)."""
        try:
            with open(self.boost_path) as f:
                self.savedboost = int(f.read().strip())
        except FileNotFoundError:
            return
        if not self.args.dry_run:
            with open(self.boost_path, 'w') as f:
                f.write('0\n')
        logger.debug('cpufreq boost disabled')

    def stop_noboost(self):
        """Restore the saved boost state, if any (no write in dry-run)."""
        if self.savedboost is None or self.args.dry_run:
            return
        try:
            with open(self.boost_path, 'w') as f:
                f.write(f'{self.savedboost}\n')
        except OSError as e:
            logger.error('error restoring cpufreq boost state: %r', e)

    def start(self):
        self.start_cpufreq()
        if self.args.noboost:
            self.start_noboost()

    def stop(self):
        self.stop_cpufreq()
        self.stop_noboost()
class CreateVfioSvc:
    """Run the guest: apply every option, start it, wait for power-off, then undo."""

    # Options started in this order and stopped in reverse.
    OPTS = [
        PciPassthruOpt,
        CpuOpt,
        IoSchedOpt,
        GuestTuneOpt,
        ConflictingDomainOpt,
        MemOpt,
        UtilityDomainOpt,
        RealtimeOpt,
        CpuFreqOpt,
    ]

    @classmethod
    def config_parser(cls, a):
        """Let every option register its CLI arguments on parser ``a``."""
        for o in cls.OPTS:
            o.config_parser(a)

    def __init__(self, a):
        self.args = a
        self.conn = None
        self.dom = None
        # Options whose start() was invoked (even if it raised), for teardown.
        self.activeopts = []
        # True once createXML() succeeded, i.e. this process owns a running guest.
        self.started = False

    def start(self):
        """Apply all options to the domain XML and start the guest from it.

        In dry-run the connection is read-only and the XML is printed instead.
        """
        if self.args.dry_run:
            self.conn = libvirt.openReadOnly('qemu:///system')
        else:
            self.conn = libvirt.open('qemu:///system')
        cap = Xml.fromstring(self.conn.getCapabilities())
        self.dom = self.conn.lookupByName(self.args.domain)
        xml = Xml.fromstring(self.dom.XMLDesc())
        for o in self.OPTS:
            o = o(self.args, self.conn, cap, self.dom, xml)
            # Register before start() so a partially applied option is undone too.
            self.activeopts.append(o)
            o.start()
        xmltext = xml.tostring()
        logger.info('starting domain with optimized xml')
        if self.args.dry_run:
            sp.run('xmllint -format - || cat', shell=True, universal_newlines=True, input=xmltext)
        else:
            self.conn.createXML(xmltext)
            self.started = True

    def wait(self):
        """Poll once a second until the guest is no longer active (no-op in dry-run)."""
        logger.info('waiting until guest domain power off')
        if self.args.dry_run:
            return
        while self.dom.isActive():
            time.sleep(1)

    def stop(self):
        """Destroy the guest if this process started it, then stop options in reverse."""
        # Only tear down a guest we actually created: if start() failed
        # earlier (e.g. the domain was already running), self.dom refers to
        # a domain this process does not own.
        if self.started and not self.args.dry_run:
            with suppress(libvirt.libvirtError):
                self.dom.destroy()
        while len(self.activeopts) > 0:
            o = self.activeopts.pop()
            try:
                if hasattr(o, 'stop'):
                    o.stop()
            except BaseException:
                # Deliberately broad: keep restoring host state even if one
                # option fails or a signal-triggered SystemExit lands here.
                logger.error('error stopping %r: %s', o, format_exc())

    def setdevnull(self):
        """Redirect stdin/stdout/stderr to /dev/null."""
        with open('/dev/null', 'rb') as f:
            os.dup2(f.fileno(), sys.stdin.fileno())
        with open('/dev/null', 'wb') as f:
            os.dup2(f.fileno(), sys.stdout.fileno())
            os.dup2(f.fileno(), sys.stderr.fileno())

    def main(self):
        """Start the guest, signal the launcher, wait for power-off, always clean up."""
        try:
            self.start()
            if not self.args.dry_run:
                # Closing startfd tells main_cmd() that the guest is up.
                os.close(int(os.environ['startfd']))
                if self.args.background:
                    self.setdevnull()
            self.wait()
        finally:
            self.stop()
def main_svc(a):
    """Run the guest in this process (the service side, or directly for dry-run)."""
    service = CreateVfioSvc(a)
    service.main()
def main_cmd(a):
    """Launch the guest as a transient systemd service and wait for it to start.

    Dry runs execute in-process. Otherwise the command re-executes itself
    through sudo when not root, and refuses to run if PROGRAM.service is
    already active. It then spawns ``--action=svc`` under
    systemd-run/systemd-inhibit. The caller blocks until the start pipe hits
    EOF. With ``--background`` it then returns; otherwise it waits for
    systemd-run to finish.
    """
    if a.dry_run:
        main_svc(a)
        return
    if os.getuid() != 0:
        logger.debug('using sudo to run as root')
        os.execvp('sudo', ['sudo', sys.executable] + sys.argv)
    if sp.run(['systemctl', '--quiet', 'is-active', f'{PROGRAM}.service']).returncode == 0:
        raise RuntimeError(f'already running as {PROGRAM}.service')
    # Reset any failed state / leftovers of a previous run of the unit.
    sp.run(['systemctl', '--quiet', '--wait', 'reset-failed', f'{PROGRAM}.service'])
    sp.run(['systemctl', '--quiet', '--wait', 'clean', f'{PROGRAM}.service'])
    # The service closes 'startfd' once the guest is up, so EOF on the read
    # end means "started" (or that every writer has gone away).
    # NOTE(review): systemd-run --pipe forwards stdio to the service; confirm
    # this extra fd actually reaches it, otherwise EOF only arrives when
    # systemd-run itself exits.
    piper, pipew = os.pipe()
    # Transient system services run with / as their working directory, so a
    # relative script path (e.g. ./virt-createvfio) would not resolve there.
    script = os.path.abspath(sys.argv[0])
    cmd = ['systemd-run', f'--unit={PROGRAM}.service', '--pipe', '--wait', '--',
        'systemd-inhibit', f'--what={a.inhibit}', f'--who={PROGRAM}.service', f'--why=vfio guest {a.domain} running', '--mode=block', '--',
        'env', f'startfd={pipew}',
        sys.executable, script, '--action=svc'] + sys.argv[1:]
    logger.debug('spawning start process with %r', cmd)
    p = sp.Popen(cmd, stdin=sp.DEVNULL, pass_fds=(pipew,))
    os.close(pipew)
    select([piper], [], [])
    os.close(piper)
    if a.background:
        # Stop only the local systemd-run client; the service keeps running.
        p.terminate()
        return
    p.wait()
def handlesig(*_):
    """Signal handler: exit with status 0 via SystemExit so pending finally-blocks run."""
    raise SystemExit(0)
def main():
    """Entry point: install signal handlers, parse arguments, dispatch the action.

    An optional leading ``--action=NAME`` selects ``cmd`` (the default,
    launcher) or ``svc`` (used by the launched service).
    """
    signal(SIGHUP, SIG_IGN)
    for signum in (SIGINT, SIGTERM):
        signal(signum, handlesig)
    argv = sys.argv[1:]
    action = 'cmd'
    if argv and argv[0].startswith('--action='):
        action = argv.pop(0).split('=', 1)[1]
    parser = ArgumentParser(description='Create libvirt KVM guest with vfio pci passthrough and optimizations')
    parser.add_argument('--background', '-b', action='store_true', help='background after guest started')
    parser.add_argument('--dry-run', '-n', action='store_true', help='show actions that would be taken and exit')
    parser.add_argument('--verbose', '-v', action='count', default=0, help='be verbose')
    parser.add_argument('--inhibit', default='sleep:idle', help='inhibition lock, sleep:idle by default')
    parser.add_argument('--cell', type=int, default=0, help='run guest in specified libvirt cell')
    CreateVfioSvc.config_parser(parser)
    parser.add_argument('domain', help='KVM guest domain on qemu:///system')
    args = parser.parse_args(argv)
    # -v -> INFO, -vv (or more) -> DEBUG
    levels = (logging.WARNING, logging.INFO, logging.DEBUG)
    logging.basicConfig(level=levels[min(args.verbose, 2)])
    logger.debug('action=%s', action)
    logger.debug('arguments=%r', args)
    handlers = {'cmd': main_cmd, 'svc': main_svc}
    handlers[action](args)
# Script entry point; importing this module does not start anything.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment