Skip to content

Instantly share code, notes, and snippets.

@mentha
Last active April 10, 2023 14:22
Show Gist options
  • Save mentha/7469681fe42e21b5049a726a1acdb868 to your computer and use it in GitHub Desktop.
Save mentha/7469681fe42e21b5049a726a1acdb868 to your computer and use it in GitHub Desktop.
Run libvirt KVM guests with VFIO PCI passthrough
#!/usr/bin/env python3
from argparse import ArgumentParser
from contextlib import suppress
from dataclasses import dataclass
from fnmatch import fnmatch
from functools import cached_property
from select import select
from signal import SIGHUP, SIGINT, SIGTERM, SIG_IGN, signal
from traceback import format_exc
from xml.etree import ElementTree as ET
import hashlib
import libvirt
import logging
import os
import re
import subprocess as sp
import sys
import time
# Name shared by the transient systemd unit and the inhibitor lock owner.
PROGRAM: str = 'virt-createvfio'
# The logging module itself is used as the logger: its module-level
# debug()/info()/warning()/error() helpers log through the root logger,
# which main() configures via logging.basicConfig().
logger = logging
class Xml:
    """Thin wrapper around an ``xml.etree.ElementTree`` element.

    Attribute access not defined here is forwarded to the wrapped element
    (``tag``, ``attrib``, ``text``, ``get``, ``clear`` ...).  ``subelem`` and
    ``newelem`` create nested children from a small path syntax: steps
    separated by '/', each a tag name optionally followed by ``[@attr]`` or
    ``[@attr="value"]`` predicates, e.g. ``'cputune/vcpupin[@vcpu="0"]'``.
    """

    # One path step: a tag name (may contain '-') plus optional predicates.
    _STEP_RE = re.compile(r'''[-\w]+(\[@[-\w]+(=('[^']*'|"[^"]*"))?\])*''')
    # A step carrying predicates, split into tag name and predicate text.
    # Must accept '-' in tag names exactly like _STEP_RE does.
    _TAG_RE = re.compile(r'([-\w]+)(\[.*)$')
    # A single ``[@attr]`` / ``[@attr='v']`` / ``[@attr="v"]`` predicate.
    _PRED_RE = re.compile(r'''\[@([-\w]+)(=('[^']*'|"[^"]*"))?\]''')

    def __init__(self, n):
        self.node = n

    @classmethod
    def wrap(cls, n):
        """Wrap element ``n``; ``None`` is passed through unchanged."""
        if n is None:
            return None
        return cls(n)

    @classmethod
    def fromstring(cls, s):
        """Parse XML text and wrap its root element."""
        return cls(ET.fromstring(s))

    def __getattr__(self, k):
        # Only reached when normal lookup fails.  Refuse 'node' itself so an
        # instance without a node (created via __new__, or mid copy/unpickle)
        # raises AttributeError instead of recursing until RecursionError.
        if k == 'node':
            raise AttributeError(k)
        return getattr(self.node, k)

    def __setattr__(self, k, v):
        # Everything except the wrapped node (e.g. ``.text``) goes to the element.
        if k == 'node':
            return super().__setattr__(k, v)
        return setattr(self.node, k, v)

    def __repr__(self):
        attrs = ''.join(f' {k}="{v}"' for k, v in self.attrib.items())
        # Elements with text or children are abbreviated as '>...<'.
        inner = '>...<' if self.text or len(self.node) > 0 else ''
        return f'<{self.tag}{attrs}{inner}/>'

    def tostring(self):
        """Serialize the wrapped element to a str."""
        return ET.tostring(self.node, encoding='unicode')

    def set(self, k=None, v=None, **ka):
        """Set one attribute (``set(k, v)``) or several (``set(a=..., b=...)``)."""
        if k is None:
            for n, val in ka.items():
                self.node.set(n, val)
        else:
            self.node.set(k, v)

    def _setattrs(self, attrs):
        # Apply a dict of attributes without routing through set(**attrs),
        # so attribute names 'k' and 'v' cannot collide with set()'s params.
        for n, val in attrs.items():
            self.node.set(n, val)

    def find(self, m):
        """Wrapped ``Element.find``; returns None when nothing matches."""
        return self.wrap(self.node.find(m))

    def findall(self, m):
        """Wrapped ``Element.findall``."""
        return [self.wrap(r) for r in self.node.findall(m)]

    def remove(self, e):
        """Remove child ``e`` (an Xml wrapper or a raw Element)."""
        if not isinstance(e, ET.Element):
            e = e.node
        self.node.remove(e)

    def _procattrs(self, a):
        """Strip an ``attr_`` prefix from keyword names.

        Lets callers pass attribute names that would clash with these
        helpers' own parameters (e.g. ``attr_match``) or Python keywords.
        """
        return {(k[5:] if k.startswith('attr_') else k): v for k, v in a.items()}

    def _subelem(self, tag, clear=False, match=True, **attrs):
        """Return a direct child for one path step, creating it if needed.

        Predicates in ``tag`` are merged into ``attrs`` (predicates win).  With
        ``match`` the first existing child having that tag and those attribute
        values is reused; otherwise a new child is always appended.  ``clear``
        empties the child before its attributes are (re)applied.

        Raises ValueError for a malformed step.
        """
        attrs = self._procattrs(attrs)
        if '[' in tag:
            m = self._TAG_RE.match(tag)
            if m is None:
                raise ValueError(f'invalid XML path step: {tag!r}')
            tag, preds = m[1], m[2]
            while preds:
                m = self._PRED_RE.match(preds)
                if m is None:
                    raise ValueError(f'invalid XML path predicate: {preds!r}')
                preds = preds[len(m[0]):]
                # A bare [@attr] predicate means an empty value.
                attrs[m[1]] = m[3][1:-1] if m[3] else ''
        if not match:
            return self.wrap(ET.SubElement(self.node, tag, attrs))
        query = tag
        if attrs:
            query += '[' + ']['.join(f'@{n}={str(v)!r}' for n, v in attrs.items()) + ']'
        r = self.find(query)
        if r is None:
            r = self.wrap(ET.SubElement(self.node, tag, attrs))
        if clear:
            r.clear()
        r._setattrs(attrs)
        return r

    def subelem(self, tag, clear=False, match=True, **attrs):
        """Walk/create the '/'-separated path ``tag`` and return the last element.

        Intermediate steps are matched-or-created without clearing; ``clear``,
        ``match`` and ``attrs`` apply to the final step only.  An empty path
        returns None.  Raises ValueError for a malformed path.
        """
        e = self
        while tag:
            m = self._STEP_RE.match(tag)
            if m is None:
                raise ValueError(f'invalid XML path: {tag!r}')
            step = m[0]
            tag = tag[len(step):].lstrip('/')
            if not tag:
                return e._subelem(step, clear, match, **attrs)
            e = e._subelem(step)
        return None

    def newelem(self, tag, match=True, **newattrs):
        """Like ``subelem`` but always clears the final element first, then
        sets ``newattrs`` (``attr_`` prefixes are stripped as in subelem)."""
        r = self.subelem(tag, clear=True, match=match)
        r._setattrs(self._procattrs(newattrs))
        return r
@dataclass(order=True, frozen=True)
class PciAddr:
    """Hashable, ordered PCI address ``domain:bus:slot.function``.

    Built from a string (``'0000:01:00.0'`` or ``'01:00.0'``, hex fields),
    from a libvirt ``<address>`` element (``Element`` or ``Xml``), or from
    keyword fields.  Keyword fields override parsed ones.
    """

    domain: int
    bus: int
    slot: int
    function: int

    def fromstring(self, a):
        """Fill the fields from ``[domain:]bus:slot.function`` (hex)."""
        head, func = a.split('.')
        pieces = head.split(':')
        if len(pieces) == 3:
            dom, bus, slot = pieces
        else:
            dom = '0'
            bus, slot = pieces
        # The dataclass is frozen, so bypass its __setattr__.
        for name, text in zip(('domain', 'bus', 'slot', 'function'), (dom, bus, slot, func)):
            object.__setattr__(self, name, int(text, 16))

    def fromxml(self, x):
        """Fill the fields from an element's attributes (any int() base-0 literal)."""
        for name in ('domain', 'bus', 'slot', 'function'):
            object.__setattr__(self, name, int(x.get(name), 0))

    def __init__(self, s=None, **ka):
        if s is not None:
            parse = self.fromxml if isinstance(s, (ET.Element, Xml)) else self.fromstring
            parse(s)
        self.__dict__.update(ka)

    def toxml(self, x):
        """Write the address as libvirt-style hex attributes onto element ``x``."""
        layout = {
            'domain': '0x%04x',
            'bus': '0x%02x',
            'slot': '0x%02x',
            'function': '0x%x',
        }
        for name, fmt in layout.items():
            x.set(name, fmt % getattr(self, name))

    def __repr__(self):
        return f'{self.domain:04x}:{self.bus:02x}:{self.slot:02x}.{self.function:x}'
def parsePciAddrs(s):
    """Expand one --pci argument into a list of PciAddr.

    ``s`` is either a hex ``vendor:device`` ID pair (e.g. ``10de:1b80``),
    selecting every matching device under /sys/bus/pci/devices, or a single
    PCI address accepted by ``PciAddr`` (e.g. ``01:00.0``).

    Raises ValueError when a vendor:device pair matches no device, so
    argparse rejects the argument instead of silently passing nothing
    through; a malformed address raises ValueError from PciAddr.
    """
    m = re.match(r'^([0-9a-fA-F]+):([0-9a-fA-F]+)$', s)
    if m is None:
        return [PciAddr(s)]
    vendor = int(m[1], 16)
    device = int(m[2], 16)
    sysfs = '/sys/bus/pci/devices'

    def read_id(path):
        # sysfs exposes IDs as '0x10de\n'; int(..., 0) handles prefix/newline.
        with open(path) as f:
            return int(f.read(), 0)

    r = []
    for a in sorted(os.listdir(sysfs)):
        if read_id(f'{sysfs}/{a}/vendor') != vendor:
            continue
        if read_id(f'{sysfs}/{a}/device') != device:
            continue
        r.append(PciAddr(a))
    if not r:
        raise ValueError(f'no PCI device with ID {s}')
    return r
class Opt:
    """Base class for one optimization applied around a guest's lifetime.

    Subclasses provide ``start()`` and optionally ``stop()``; they read the
    parsed arguments, the libvirt connection, the host capabilities XML, the
    target domain and the domain's editable XML from the attributes set here.
    """

    def __init__(self, args, conn, cap, dom, xml):
        self.args, self.conn, self.cap = args, conn, cap
        self.dom, self.xml = dom, xml
class PciPassthruOpt(Opt):
    """Pass host PCI devices to the guest; optionally drop emulated display."""

    @staticmethod
    def config_parser(a):
        group = a.add_argument_group('PCI passthrough')
        group.add_argument('--pci', '-p', nargs='+', metavar='ADDRESS', type=parsePciAddrs, action='extend', default=[], help='passthrough additional pci devices')
        group.add_argument('--nodisplay', action='store_true', help='remove virtualized display and audio devices')

    def nodisplay(self):
        """Remove graphics, video, input, sound and audio devices."""
        devices = self.xml.subelem('devices')
        doomed = [
            e
            for kind in ('graphics', 'video', 'input', 'sound', 'audio')
            for e in devices.findall(kind)
        ]
        for e in doomed:
            logger.debug('removing display device %r', e)
            devices.remove(e)

    def addpci(self, a):
        """Append a managed ``<hostdev>`` for PCI address ``a``."""
        hostdev = self.xml.newelem('devices/hostdev', match=False, mode='subsystem', type='pci', managed='yes')
        a.toxml(hostdev.newelem('source/address'))
        logger.debug('added host pci dev %r', hostdev)

    def setpm(self):
        """Disable guest suspend-to-mem and suspend-to-disk."""
        pm = self.xml.newelem('pm')
        for state in ('suspend-to-mem', 'suspend-to-disk'):
            pm.newelem(state, enabled='no')

    def start(self):
        if self.args.nodisplay:
            self.nodisplay()
        # --pci yields one list per occurrence; deduplicate across all of them.
        wanted = {addr for group in self.args.pci for addr in group}
        for addr in sorted(wanted):
            self.addpci(addr)
        self.setpm()
class CpuOpt(Opt):
    """Guest CPU topology, optional host CPU pinning and disk queue sizing."""

    @staticmethod
    def config_parser(a):
        a = a.add_argument_group('CPU Options')
        a.add_argument('--cpus', type=int, help='override guest cpu count')
        a.add_argument('--pincpu', action='store_true', help='pin guest cpus')
        a.add_argument('--pinrest', type=int, metavar='CPUCOUNT', nargs='?', const=0, help='pin rest threads, 1 cpu per iothread plus 1 for emulator by default')
        a.add_argument('--ignore-queues', action='store_true', help='do not adjust disk queue size')

    @cached_property
    def cpucount(self):
        """Guest vCPU count: --cpus if given, else the domain's <vcpu>."""
        return self.args.cpus or int(self.xml.find('vcpu').text)

    @cached_property
    def iothreadcount(self):
        """Number of <iothreads> declared by the domain (0 if absent)."""
        e = self.xml.find('iothreads')
        if e is None:
            return 0
        return int(e.text)

    @cached_property
    def cputhreadcount(self):
        """Threads per core from the host CPU topology (1 if unspecified)."""
        return int(self.cap.find('host/cpu/topology').get('threads', 1))

    @cached_property
    def cpucorecount(self):
        """Guest core count, using the host's threads-per-core."""
        return self.cpucount // self.cputhreadcount

    def resetcpu(self):
        """Replace vcpu/vcpus/cputune and set a host-passthrough CPU topology."""
        # Read the count before <vcpu> is removed below.
        c = self.cpucount
        for m in ('vcpu', 'vcpus', 'cputune'):
            for e in self.xml.findall(m):
                self.xml.remove(e)
        self.xml.newelem('vcpu', placement='static').text = str(c)
        cpu = self.xml.subelem('cpu')
        cpu.set(mode='host-passthrough', check='partial', migratable='off')
        cpu.newelem('topology', sockets='1', dies='1', cores=str(self.cpucorecount), threads=str(self.cputhreadcount))
        logger.debug('guest has %d cores %d threads', self.cpucorecount, self.cputhreadcount)
        if self.cap.find('host/cpu/feature[@name="topoext"]') is not None:
            logger.debug('enabling topoext on amd')
            cpu.newelem('feature[@name="topoext"]', policy='require')

    def _host_cores(self):
        """Host CPU ids of socket 0 / die 0 in --cell, one list per core.

        Cores are ordered by core_id; ids within a core keep capability order.
        A dict is used because core_ids need not be contiguous or start at 0.
        """
        bycore = {}
        for e in self.cap.findall(f'host/topology/cells/cell[@id="{self.args.cell}"]/cpus/cpu[@socket_id="0"][@die_id="0"]'):
            bycore.setdefault(int(e.get('core_id')), []).append(int(e.get('id')))
        return [bycore[c] for c in sorted(bycore)]

    def pincpu(self, pincpu, pinrest):
        """Pin guest vCPUs and/or the remaining guest threads to host CPUs.

        Host cores are consumed from the highest core_id downwards.  With
        ``pincpu`` each guest core gets one whole host core.  With ``pinrest``
        (a CPU count) iothreads and the emulator share the next ``pinrest``
        host CPUs.

        Raises RuntimeError when the cell has too few cores or threads.
        """
        cores = self._host_cores()

        def take_core():
            if not cores:
                raise RuntimeError(f'not enough host cpu cores in cell {self.args.cell} for pinning')
            return cores.pop()

        if pincpu:
            for core in range(self.cpucorecount):
                hcore = take_core()
                if len(hcore) < self.cputhreadcount:
                    raise RuntimeError(f'host core has {len(hcore)} threads, guest needs {self.cputhreadcount}')
                for thr in range(self.cputhreadcount):
                    cpu = core * self.cputhreadcount + thr
                    c = hcore.pop()
                    self.xml.newelem(f'cputune/vcpupin[@vcpu="{cpu}"]', cpuset=str(c))
                    logger.debug('guest cpu %d pinned to host cpu %d', cpu, c)
        if pinrest:
            rest = take_core()
            while len(rest) < pinrest:
                rest.extend(take_core())
            rest = ','.join(str(x) for x in sorted(rest[:pinrest]))
            # libvirt numbers iothreads from 1.
            for i in range(self.iothreadcount):
                self.xml.newelem(f'cputune/iothreadpin[@iothread="{i + 1}"]', cpuset=rest)
            self.xml.newelem('cputune/emulatorpin', cpuset=rest)
            logger.debug('other threads pinned to host %s', rest)

    def adjust_queue(self):
        """Give every iothread-backed SCSI controller / disk one queue per vCPU."""
        for drv in self.xml.findall('devices/controller[@type="scsi"]/driver[@iothread]') + self.xml.findall('devices/disk[@device="disk"]/driver[@iothread]'):
            drv.set('queues', str(self.cpucount))

    def start(self):
        self.resetcpu()
        pinrest = self.args.pinrest
        # --pinrest without a value (const 0) means: 1 per iothread + 1 emulator.
        if pinrest is not None:
            if pinrest == 0:
                pinrest = self.iothreadcount + 1
        self.pincpu(self.args.pincpu, pinrest)
        if not self.args.ignore_queues:
            self.adjust_queue()
class MemOpt(Opt):
    """Guest memory sizing and hugepage backing.

    start() rewrites <memory>, requests locked/non-shared memory and, unless
    --dry-run, tries to reserve hugepages for the whole guest in the selected
    cell; on success the guest is backed by that page size.  stop() shrinks
    the hugepage pool again.
    """
    # Bit flags combined by the --hugepages modes below.
    HP_DISABLED = 0x1  # do not touch hugepages at all
    HP_FORCE = 0x2  # raise RuntimeError if no hugepages could be reserved
    HP_HARDER = 0x4  # try largest size first, then sync/drop caches/compact and retry all sizes
    hugepages_map = {
        'disabled': HP_DISABLED,
        'try': 0,
        'tryharder': HP_HARDER,
        'force': HP_HARDER | HP_FORCE,
    }
    @classmethod
    def config_parser(cls, a):
        a.add_argument('--mem', type=int, help='override guest memory size in GiB')
        a.add_argument('--hugepages', default='force', choices=cls.hugepages_map.keys(), help='hugepages allocation mode, force by default')
    def __init__(self, *a):
        super().__init__(*a)
        # (page size in KiB, page count) reserved by alloc_hugepage(), or None.
        self.allocated = None
    @property
    def hugepages(self):
        """HP_* flags for the selected --hugepages mode."""
        return self.hugepages_map[self.args.hugepages]
    @staticmethod
    def unit2k(u):
        """KiB per unit ``u`` (case-insensitive); only KiB/MiB/GiB, else KeyError."""
        return {
            'kib': 1,
            'mib': 1024,
            'gib': 1024**2,
        }[u.lower()]
    @cached_property
    def mem_k(self):
        """Guest memory in KiB: --mem (GiB) if given, else the domain's <memory>."""
        if self.args.mem:
            return self.args.mem * self.unit2k('GiB')
        m = self.xml.find('memory')
        return self.unit2k(m.get('unit', 'KiB')) * int(m.text)
    def resetmem(self):
        """Drop currentMemory/maxMemory, set <memory> in KiB, lock and unshare pages."""
        for m in ('currentMemory', 'maxMemory'):
            for e in self.xml.findall(m):
                self.xml.remove(e)
        memk = self.mem_k
        self.xml.newelem('memory', unit='KiB').text = str(memk)
        self.xml.newelem('memoryBacking/nosharepages')
        self.xml.newelem('memoryBacking/locked')
        logger.debug('guest has %f MiB memory', memk / 1024)
    def gethugepages(self, s):
        """Free pages of size ``s`` KiB in --cell, as reported by getFreePages."""
        return self.conn.getFreePages([s], self.args.cell, 1)[self.args.cell][s]
    def sethugepages(self, s, n):
        """Ask libvirt for ``n`` pages of size ``s`` KiB in --cell (ALLOC_PAGES_SET).

        libvirt errors are suppressed (None is returned then); callers judge
        success by re-reading the free count.
        """
        # NOTE(review): gethugepages() reports *free* pages while
        # VIR_NODE_ALLOC_PAGES_SET sets the *pool size* per libvirt docs; the
        # arithmetic in alloc_hugepage()/stop() only lines up when no pages of
        # this size are already in use by something else -- verify.
        with suppress(libvirt.libvirtError):
            return self.conn.allocPages({s: n}, self.args.cell, 1, libvirt.VIR_NODE_ALLOC_PAGES_SET)
    def alloc_hugepage(self, ps, firstonly=False):
        """Try page sizes ``ps`` in order and reserve the guest's memory.

        Sizes that do not divide the guest memory evenly are skipped.  For a
        candidate size, n more pages are requested; if the free count rose by
        exactly n the reservation is kept (self.allocated) and True returned,
        otherwise the previous value is restored.  ``firstonly`` stops after
        the first size actually attempted.  Returns False if nothing worked.
        """
        for s in ps:
            if self.mem_k % s != 0:
                continue
            n = self.mem_k // s
            logger.debug('trying to allocate %d pages of %d KiB each', n, s)
            pre = self.gethugepages(s)
            target = pre + n
            self.sethugepages(s, target)
            post = self.gethugepages(s)
            if post == target:
                self.allocated = (s, n)
                return True
            # Partial or failed allocation: roll back.
            self.sethugepages(s, pre)
            if firstonly:
                break
        return False
    def prepare_hugepage(self, mode):
        """Reserve hugepages according to HP_* flags in ``mode``.

        Raises RuntimeError under HP_FORCE when no size could be reserved;
        otherwise only logs.
        """
        if mode & self.HP_DISABLED:
            return
        hpsz = []
        for e in self.cap.findall(f'host/topology/cells/cell[@id="{self.args.cell}"]/pages'):
            s = self.unit2k(e.get('unit', 'KiB')) * int(e.get('size'))
            hpsz.append(s)
            logger.debug('host support pagesize %d KiB', s)
        # Largest first; the smallest size is dropped, presumably because it
        # is the normal (non-huge) base page size.
        hpsz = sorted(hpsz, reverse=True)[:-1]
        if self.alloc_hugepage(hpsz, (mode & self.HP_HARDER) != 0):
            return
        if mode & self.HP_HARDER:
            # Free page cache and defragment memory, then retry every size.
            sp.run(['sync'])
            sp.run(['sysctl', 'vm.drop_caches=3', 'vm.compact_memory=1'])
            if self.alloc_hugepage(hpsz):
                return
        if mode & self.HP_FORCE:
            raise RuntimeError('cannot allocate hugepage')
        logger.info('cannot allocate hugepage')
    def enable_hugepage(self):
        """Back the guest with the page size reserved in self.allocated."""
        s, _ = self.allocated
        self.xml.newelem('memoryBacking/hugepages').newelem('page', size=str(s), unit='KiB')
    def start(self):
        self.resetmem()
        logger.debug('preparing hugepages')
        if not self.args.dry_run:
            self.prepare_hugepage(self.hugepages)
        if self.allocated is not None:
            self.enable_hugepage()
    def stop(self):
        """Give back the pages reserved by start() (never below zero)."""
        logger.debug('stopping hugepages')
        if self.allocated is not None:
            s, n = self.allocated
            pre = self.gethugepages(s)
            self.sethugepages(s, max(0, pre - n))
class IoSchedOpt(Opt):
    """Switch host block devices to a matching I/O scheduler while the guest runs."""

    @staticmethod
    def config_parser(a):
        a.add_argument('--iosched', default='*deadline', help='set host io scheduler, *deadline by default')

    def __init__(self, *a):
        super().__init__(*a)
        # block device name -> scheduler to restore in stop()
        self.saved = {}

    @staticmethod
    def _parse_scheduler(text, pattern):
        """Return (current, chosen) from a sysfs scheduler listing.

        The current scheduler is the bracketed entry; the chosen one is the
        last entry matching the fnmatch ``pattern``.  Either may be None.
        """
        old = None
        news = None
        for s in text.split():
            if s[0] == '[' and s[-1] == ']':
                old = s = s[1:-1]
            if fnmatch(s, pattern):
                news = s
        return old, news

    def start(self):
        for blk in os.listdir('/sys/block'):
            spath = f'/sys/block/{blk}/queue/scheduler'
            if not os.path.exists(spath):
                continue
            with open(spath, 'r') as f:
                old, news = self._parse_scheduler(f.read(), self.args.iosched)
            if old is None or news is None or old == news:
                continue
            if not self.args.dry_run:
                # Best effort, like stop(): one device refusing the change
                # must not abort the guest launch.
                try:
                    with open(spath, 'w') as f:
                        f.write(news + '\n')
                except OSError as e:
                    logger.warning('error setting io scheduler of %s to %s: %r', blk, news, e)
                    continue
            # Only record devices whose scheduler was actually changed.
            self.saved[blk] = old
            logger.debug('set io scheduler of %s: %s => %s', blk, old, news)

    def stop(self):
        for blk, s in self.saved.items():
            try:
                if not self.args.dry_run:
                    with open(f'/sys/block/{blk}/queue/scheduler', 'w') as f:
                        f.write(s + '\n')
                logger.debug('restore io scheduler of %s: => %s', blk, s)
            except OSError as e:
                logger.error('error restoring io scheduler of %s to %s: %r', blk, s, e)
class DRand:
    """Deterministic pseudo-random generator seeded from a str or bytes.

    The stream is SHA-512 chained over its own digests, so every value is a
    pure function of the seed (e.g. a domain UUID gives stable fake SMBIOS
    data across runs).  Not suitable for anything security-sensitive.
    """

    def __init__(self, seed):
        if isinstance(seed, str):
            seed = seed.encode('utf8')
        self.h = hashlib.sha512(seed)

    def randbytes(self, n):
        """Return the next ``n`` bytes of the stream (whole digests are consumed)."""
        r = b''
        while len(r) < n:
            d = self.h.digest()
            r += d
            self.h.update(d)
        return r[:n]

    def randint(self, l, h=None):
        """Return an int in [l, h]; with a single argument, in [0, l].

        Raises ValueError for an empty range.
        """
        if h is None:
            l, h = 0, l
        span = h - l + 1
        if span <= 0:
            raise ValueError(f'empty range [{l}, {h}]')
        # Draws span.bit_length() *bytes* -- more than the (bit_length + 7) // 8
        # strictly needed.  Kept deliberately: changing it would change every
        # value derived from an existing seed; it only lowers modulo bias.
        # byteorder is explicit because it is only optional on Python 3.11+.
        r = int.from_bytes(self.randbytes(span.bit_length()), 'big')
        return r % span + l

    def choice(self, s):
        """Return one element of the non-empty sequence ``s``."""
        return s[self.randint(len(s) - 1)]

    def randword(self, digits=False, minlen=4, maxlen=10):
        """Return a lowercase word (optionally with digits) of length [minlen, maxlen]."""
        a = 'abcdefghijklmnopqrstuvwxyz'
        if digits:
            a += '0123456789'
        return ''.join(self.choice(a) for _ in range(self.randint(minlen, maxlen)))
class GuestTuneOpt(Opt):
    """Guest-visible tuning: Hyper-V features, timers/devices, VM hiding, fake SMBIOS."""

    @staticmethod
    def config_parser(a):
        group = a.add_argument_group('Guest tuning')
        group.add_argument('--nohyperv', action='store_true', help='do not add hyper-v enlightments')
        group.add_argument('--noperf', action='store_true', help='do not add performance tuning')
        group.add_argument('--hidevm', action='store_true', help='try to hide hypervisor info')
        group.add_argument('--sysinfo', action='store_true', help='generate fake sysinfo')

    def start(self):
        steps = (
            (not self.args.nohyperv, self.tune_hyperv),
            (not self.args.noperf, self.tune_perf),
            (self.args.hidevm, self.tune_hidevm),
            (self.args.sysinfo, self.tune_sysinfo),
        )
        for enabled, step in steps:
            if enabled:
                step()

    def tune_hyperv(self):
        """Replace <hyperv> with passthrough mode plus explicit enlightenments."""
        hyperv = self.xml.newelem('features/hyperv', mode='passthrough')
        switches = ('relaxed', 'vapic', 'vpindex', 'runtime', 'synic', 'reset', 'frequencies', 'reenlightenment', 'tlbflush', 'ipi')
        for name in switches:
            hyperv.newelem(name, state='on')
        hyperv.newelem('spinlocks', state='on', retries='8191')
        stimer = hyperv.newelem('stimer', state='on')
        stimer.newelem('direct', state='on')
        self.xml.newelem('clock/timer[@name="hypervclock"]', present='yes')

    def tune_perf(self):
        """Set timers/CPU features, disable ballooning, drop unused spice channels."""
        settings = (
            ('clock/timer[@name="rtc"]', {'present': 'yes', 'tickpolicy': 'catchup'}),
            ('clock/timer[@name="pit"]', {'present': 'no'}),
            ('clock/timer[@name="hpet"]', {'present': 'no'}),
            ('clock/timer[@name="tsc"]', {'present': 'yes', 'mode': 'native'}),
            ('cpu/feature[@name="invtsc"]', {'policy': 'require'}),
            ('clock/timer[@name="kvmclock"]', {'present': 'yes'}),
            ('features/msrs', {'unknown': 'ignore'}),
            ('features/pmu', {'state': 'off'}),
        )
        for path, attrs in settings:
            self.xml.newelem(path, **attrs)
        devices = self.xml.subelem('devices')
        devices.newelem('memballoon', model='none')
        # Spice-only devices are useless without spice graphics.
        if devices.find('graphics[@type="spice"]') is not None:
            return
        spice_only = (
            'redirdev[@type="spicevmc"]',
            'smartcard[@type="spicevmc"]',
            'channel[@type="spicevmc"]',
            'audio[@type="spice"]',
        )
        for pattern in spice_only:
            for e in devices.findall(pattern):
                devices.remove(e)

    def tune_hidevm(self):
        """Hide the hypervisor CPUID bit and KVM signature; mask Hyper-V vendor id."""
        self.xml.newelem('cpu/feature[@name="hypervisor"]', policy='disable')
        self.xml.newelem('features/kvm/hidden', state='on')
        if not self.args.nohyperv:
            self.xml.newelem('features/hyperv/vendor_id', state='on', value='libvirt')

    def tune_sysinfo(self):
        """Generate SMBIOS data deterministically from the domain UUID."""
        rng = DRand(self.xml.find('uuid').text)
        # Draw order matters: it fixes which value each field gets.
        vendor = rng.randword().capitalize()
        product = rng.randword(True).upper()
        serial = rng.randword(True, 10, 20).upper()
        version = f'{rng.randint(1, 9)}{rng.randword(True, 3, 3)}{rng.randint(0, 9)}'
        date = f'{rng.randint(1, 12)}/{rng.randint(1, 28)}/{rng.randint(2010, 2020)}'
        self.xml.newelem('os/smbios', mode='sysinfo')
        info = self.xml.newelem('sysinfo', type='smbios')
        board = (('manufacturer', vendor), ('product', product), ('version', version), ('serial', serial))
        tables = (
            ('bios', (('vendor', vendor), ('version', version), ('date', date))),
            ('system', board),
            ('baseBoard', board),
            ('chassis', (('manufacturer', vendor), ('version', version), ('serial', serial))),
        )
        for section, entries in tables:
            sec = info.newelem(section)
            for name, value in entries:
                sec.newelem(f'entry[@name="{name}"]').text = value
class ConflictingDomainOpt(Opt):
    """Stop running domains that use the same host PCI devices, and
    optionally restart them afterwards."""

    C_SHUTDOWN = 0x1  # try a graceful shutdown before destroying
    C_RESTART = 0x2  # start the stopped domains again in stop()
    onconflict_map = {
        'terminate': 0,
        'terminate-restart': C_RESTART,
        'shutdown': C_SHUTDOWN,
        'shutdown-restart': C_SHUTDOWN | C_RESTART,
    }

    @classmethod
    def config_parser(cls, a):
        a.add_argument('--onconflict', default='shutdown-restart',
                       choices=cls.onconflict_map.keys(), help='configure ways to resolve resource conflict, shutdown-restart by default')
        a.add_argument('--onconflict-shutdown-timeout', type=int, default=60)
        a.add_argument('--onconflict-start-interval', type=int, default=10)

    def __init__(self, *a):
        super().__init__(*a)
        # Domains stopped by start(), in the order they were stopped.
        self.saved = []

    @property
    def onconflict(self):
        """C_* flags for the selected --onconflict mode."""
        return self.onconflict_map[self.args.onconflict]

    def findconflict(self):
        """Yield each active domain that uses any of this guest's host PCI devices."""
        usedpci = set()
        for e in self.xml.findall('devices/hostdev[@type="pci"]/source/address'):
            usedpci.add(PciAddr(e))
        logger.debug('guest use pci devices %s', sorted(usedpci))
        for d in self.conn.listAllDomains():
            if not d.isActive():
                continue
            logger.debug('found active domain %s', d.name())
            dx = Xml.fromstring(d.XMLDesc())
            for e in dx.findall('devices/hostdev[@type="pci"]/source/address'):
                a = PciAddr(e)
                logger.debug('active domain %s use hostdev %s', d.name(), a)
                if a in usedpci:
                    logger.debug('found conflicting domain %s on %r', d.name(), a)
                    yield d
                    break

    def start(self):
        for d in self.findconflict():
            self.saved.append(d)
            if self.onconflict & self.C_SHUTDOWN:
                logger.info('shutting down conflicting domain %s', d.name())
                if not self.args.dry_run:
                    with suppress(libvirt.libvirtError):
                        d.shutdown()
                    # Give the guest up to the timeout to power off by itself.
                    ddl = self.args.onconflict_shutdown_timeout + time.time()
                    while d.isActive() and time.time() < ddl:
                        time.sleep(1)
            if d.isActive():
                logger.warning('destroying conflicting domain %s', d.name())
                if not self.args.dry_run:
                    with suppress(libvirt.libvirtError):
                        d.destroy()

    def stop(self):
        """Restart the domains stopped by start() when a *-restart mode is set.

        Each restart is attempted independently: a failure is logged and the
        remaining domains are still started.  The start interval is only
        waited between actual starts (not in dry-run, not after the last).
        """
        if not self.onconflict & self.C_RESTART:
            return
        pause = False
        for d in self.saved:
            if pause:
                time.sleep(self.args.onconflict_start_interval)
                pause = False
            logger.info('starting stopped conflicting domain %s', d.name())
            if self.args.dry_run:
                continue
            try:
                d.create()
            except libvirt.libvirtError as e:
                logger.error('error starting conflicting domain %s: %r', d.name(), e)
            else:
                pause = True
class UtilityDomainOpt(Opt):
    """Start extra domains alongside the guest and stop them afterwards."""

    @classmethod
    def config_parser(cls, a):
        a.add_argument('--start-domain', action='append', default=[], help='start other kvm domains')

    def __init__(self, *a):
        super().__init__(*a)
        # (domain, onexit) pairs started by start().
        self.saved = []

    def start(self):
        """Start each ``NAME[:ignore|shutdown|destroy]`` given via --start-domain.

        Raises RuntimeError for an unknown exit method.
        """
        for dn in self.args.start_domain:
            onexit = 'ignore'
            if ':' in dn:
                dn, onexit = dn.split(':', 1)
            if onexit not in {'ignore', 'shutdown', 'destroy'}:
                raise RuntimeError(f'invalid destroy method {onexit} for domain {dn}')
            d = self.conn.lookupByName(dn)
            self.saved.append((d, onexit))
            logger.info('starting domain %s', dn)
            if not self.args.dry_run:
                d.create()

    def stop(self):
        """Apply each domain's exit method; one failure does not skip the rest."""
        for d, onexit in self.saved:
            dn = d.name()
            logger.info('destroy domain %s by %s', dn, onexit)
            if self.args.dry_run or onexit == 'ignore':
                continue
            try:
                if onexit == 'shutdown':
                    d.shutdown()
                else:
                    d.destroy()
            except libvirt.libvirtError as e:
                # e.g. the domain was already powered off by hand.
                logger.error('error stopping domain %s: %r', dn, e)
class RealtimeOpt(Opt):
    """Give vCPUs, iothreads and the emulator thread SCHED_FIFO priority 1."""

    @staticmethod
    def config_parser(a):
        a.add_argument('--realtime', action='store_true', help='grant realtime priority to the guest')

    def start(self):
        if not self.args.realtime:
            return
        nvcpu = int(self.xml.find('vcpu').text)
        iothreads = self.xml.find('iothreads')
        niothreads = 0 if iothreads is None else int(iothreads.text)
        fifo = {'scheduler': 'fifo', 'priority': '1'}
        for idx in range(nvcpu):
            self.xml.newelem(f'cputune/vcpusched[@vcpus="{idx}"]', **fifo)
        # iothread ids are 1-based.
        for idx in range(1, niothreads + 1):
            self.xml.newelem(f'cputune/iothreadsched[@iothreads="{idx}"]', **fifo)
        self.xml.newelem('cputune/emulatorsched', **fifo)
class CpuFreqOpt(Opt):
    """Set the host cpufreq governor (and optionally disable boost) while
    the guest runs, restoring the previous settings afterwards."""

    cpufreq_path = '/sys/devices/system/cpu/cpufreq'
    boost_path = '/sys/devices/system/cpu/cpufreq/boost'

    @staticmethod
    def config_parser(a):
        a.add_argument('--cpufreq', default='ondemand', help='set host cpu frequency governor, ondemand by default')
        a.add_argument('--noboost', action='store_true', help='disable host cpu frequency boost')

    def __init__(self, *a):
        super().__init__(*a)
        # policy dir (with trailing '/') -> governor to restore
        self.savedfreq = {}
        # previous boost value, or None if boost was not touched
        self.savedboost = None

    def start_cpufreq(self):
        """Switch every cpufreq policy that offers --cpufreq to that governor."""
        try:
            entries = os.listdir(self.cpufreq_path)
        except FileNotFoundError:
            # No cpufreq driver on this host (e.g. nested VM): nothing to tune.
            logger.debug('no cpufreq support, leaving governors unchanged')
            return
        for p in entries:
            d = self.cpufreq_path + '/' + p + '/'
            # Skip non-policy entries such as the 'boost' file.
            if not (os.path.exists(d + 'scaling_available_governors') and os.path.exists(d + 'scaling_governor')):
                continue
            with open(d + 'scaling_available_governors') as f:
                if self.args.cpufreq not in f.read().split():
                    continue
            with open(d + 'scaling_governor') as f:
                self.savedfreq[d] = f.read().strip()
            if not self.args.dry_run:
                with open(d + 'scaling_governor', 'w') as f:
                    f.write(self.args.cpufreq + '\n')

    def stop_cpufreq(self):
        """Restore saved governors; failures are logged per policy."""
        for d, g in self.savedfreq.items():
            try:
                if not self.args.dry_run:
                    with open(d + 'scaling_governor', 'w') as f:
                        f.write(g + '\n')
            except OSError as e:
                logger.error('error restoring cpu governor of %s to %s: %r', d, g, e)

    def start_noboost(self):
        """Remember the boost setting and turn boost off (no write in dry-run)."""
        try:
            with open(self.boost_path) as f:
                self.savedboost = int(f.read().strip())
        except FileNotFoundError:
            return
        if not self.args.dry_run:
            with open(self.boost_path, 'w') as f:
                f.write('0\n')
        logger.debug('cpufreq boost disabled')

    def stop_noboost(self):
        """Restore the remembered boost setting (no write in dry-run)."""
        if self.savedboost is None or self.args.dry_run:
            return
        try:
            with open(self.boost_path, 'w') as f:
                f.write(f'{self.savedboost}\n')
        except OSError as e:
            logger.error('error restoring cpufreq boost state: %r', e)

    def start(self):
        self.start_cpufreq()
        if self.args.noboost:
            self.start_noboost()

    def stop(self):
        self.stop_cpufreq()
        self.stop_noboost()
class CreateVfioSvc:
    """Run one guest: apply every Opt to its XML, start it, wait, then undo."""

    # Applied in this order by start(); undone in reverse by stop().
    OPTS = [
        PciPassthruOpt,
        CpuOpt,
        IoSchedOpt,
        GuestTuneOpt,
        ConflictingDomainOpt,
        MemOpt,
        UtilityDomainOpt,
        RealtimeOpt,
        CpuFreqOpt,
    ]

    @classmethod
    def config_parser(cls, a):
        """Register every Opt's command-line arguments on parser ``a``."""
        for opt_cls in cls.OPTS:
            opt_cls.config_parser(a)

    def __init__(self, a):
        self.args = a
        self.conn = None
        self.dom = None
        self.activeopts = []

    def start(self):
        """Connect, build the optimized XML and start (or print) the guest."""
        connect = libvirt.openReadOnly if self.args.dry_run else libvirt.open
        self.conn = connect('qemu:///system')
        cap = Xml.fromstring(self.conn.getCapabilities())
        self.dom = self.conn.lookupByName(self.args.domain)
        xml = Xml.fromstring(self.dom.XMLDesc())
        for opt_cls in self.OPTS:
            opt = opt_cls(self.args, self.conn, cap, self.dom, xml)
            # Registered before start() so a half-applied opt is still undone.
            self.activeopts.append(opt)
            opt.start()
        xmltext = xml.tostring()
        logger.info('starting domain with optimized xml')
        if not self.args.dry_run:
            self.conn.createXML(xmltext)
            return
        sp.run('xmllint -format - || cat', shell=True, universal_newlines=True, input=xmltext)

    def wait(self):
        """Block until the guest is no longer active (returns at once in dry-run)."""
        logger.info('waiting until guest domain power off')
        if self.args.dry_run:
            return
        while self.dom.isActive():
            time.sleep(1)

    def stop(self):
        """Destroy the guest if still running, then stop opts in reverse order."""
        if not self.args.dry_run and self.dom is not None:
            with suppress(libvirt.libvirtError):
                self.dom.destroy()
        while self.activeopts:
            opt = self.activeopts.pop()
            try:
                if hasattr(opt, 'stop'):
                    opt.stop()
            # Deliberately broad (same as a bare except): keep undoing host
            # changes even if a stop() fails or a signal interrupts it.
            except BaseException:
                logger.error('error stopping %r: %s', opt, format_exc())

    def setdevnull(self):
        """Point stdin/stdout/stderr at /dev/null."""
        with open('/dev/null', 'rb') as sink:
            os.dup2(sink.fileno(), sys.stdin.fileno())
        with open('/dev/null', 'wb') as sink:
            os.dup2(sink.fileno(), sys.stdout.fileno())
            os.dup2(sink.fileno(), sys.stderr.fileno())

    def main(self):
        """Start the guest, signal the launcher via $startfd, wait, always clean up."""
        try:
            self.start()
            if not self.args.dry_run:
                # Closing our end tells main_cmd's select() the guest is up.
                os.close(int(os.environ['startfd']))
            if self.args.background:
                self.setdevnull()
            self.wait()
        finally:
            self.stop()
def main_svc(a):
    """'svc' action: run the guest in this process until it powers off."""
    service = CreateVfioSvc(a)
    service.main()
def main_cmd(a):
    """'cmd' action: run the guest inside a transient systemd service.

    Re-executes through sudo when not root and refuses to run if the service
    is already active.  Otherwise it launches ``systemd-run``, which runs
    this script again with ``--action=svc`` under a ``systemd-inhibit`` lock.
    It returns once the service signals that the guest started (background
    mode) or once the service exits (foreground).  With --dry-run everything
    runs in-process instead.

    Raises RuntimeError if the service is already running.
    """
    if a.dry_run:
        main_svc(a)
        return
    if os.getuid() != 0:
        logger.debug('using sudo to run as root')
        os.execvp('sudo', ['sudo', sys.executable] + sys.argv)
    unit = f'{PROGRAM}.service'
    if sp.run(['systemctl', '--quiet', 'is-active', unit]).returncode == 0:
        raise RuntimeError(f'already running as {unit}')
    # Clear failed state / leftovers of a previous run of the same unit.
    sp.run(['systemctl', '--quiet', '--wait', 'reset-failed', unit])
    sp.run(['systemctl', '--quiet', '--wait', 'clean', unit])
    # The service does not run in our working directory (systemd starts it
    # in / by default), so a relative script path must be made absolute.
    script = os.path.abspath(sys.argv[0])
    # The service closes its inherited copy of pipew once the guest is up
    # (or exits); EOF on piper is the "started" signal.
    piper, pipew = os.pipe()
    cmd = ['systemd-run', f'--unit={unit}', '--pipe', '--wait', '--',
           'systemd-inhibit', f'--what={a.inhibit}', f'--who={unit}', f'--why=vfio guest {a.domain} running', '--mode=block', '--',
           'env', f'startfd={pipew}',
           sys.executable, script, '--action=svc'] + sys.argv[1:]
    logger.debug('spawning start process with %r', cmd)
    p = sp.Popen(cmd, stdin=sp.DEVNULL, pass_fds=(pipew,))
    os.close(pipew)
    select([piper], [], [])
    os.close(piper)
    if a.background:
        p.terminate()
        return
    p.wait()
def handlesig(*_):
    """Signal handler: exit with status 0 so pending ``finally`` cleanup runs."""
    raise SystemExit(0)
def main():
    """Parse arguments and dispatch to the 'cmd' (default) or 'svc' action."""
    signal(SIGHUP, SIG_IGN)
    for signum in (SIGINT, SIGTERM):
        signal(signum, handlesig)
    argv = sys.argv[1:]
    action = 'cmd'
    # main_cmd re-runs this script as '<script> --action=svc ...', so the
    # action selector is only recognised as the very first argument.
    if argv and argv[0].startswith('--action='):
        action = argv.pop(0).partition('=')[2]
    parser = ArgumentParser(description='Create libvirt KVM guest with vfio pci passthrough and optimizations')
    parser.add_argument('--background', '-b', action='store_true', help='background after guest started')
    parser.add_argument('--dry-run', '-n', action='store_true', help='show actions that would be taken and exit')
    parser.add_argument('--verbose', '-v', action='count', default=0, help='be verbose')
    parser.add_argument('--inhibit', default='sleep:idle', help='inhibition lock, sleep:idle by default')
    parser.add_argument('--cell', type=int, default=0, help='run guest in specified libvirt cell')
    CreateVfioSvc.config_parser(parser)
    parser.add_argument('domain', help='KVM guest domain on qemu:///system')
    a = parser.parse_args(argv)
    levels = (logging.WARNING, logging.INFO, logging.DEBUG)
    logging.basicConfig(level=levels[min(a.verbose, 2)])
    logger.debug('action=%s', action)
    logger.debug('arguments=%r', a)
    handlers = {
        'cmd': main_cmd,
        'svc': main_svc,
    }
    handlers[action](a)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment