Last active
May 4, 2025 19:38
-
-
Save a1678991/ba26f3c9a601502e58ff52308f833bd8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import subprocess | |
import os | |
import re | |
import xml.etree.ElementTree as ET | |
import yaml # Requires PyYAML installation | |
from collections import defaultdict | |
import sys | |
# import shutil # No longer needed for backup when using API | |
import libvirt # Import the libvirt library | |
import difflib # For showing differences | |
from xml.dom import minidom # For pretty-printing XML for diff | |
import argparse # For command-line arguments | |
def _drop_whitespace_text(node):
    """Recursively remove whitespace-only text nodes below *node*, in place."""
    for child in list(node.childNodes):
        if child.nodeType == child.TEXT_NODE:
            if not child.data.strip():
                node.removeChild(child)
        else:
            _drop_whitespace_text(child)


def pretty_print_xml(xml_string):
    """Pretty-print an XML string with minidom, without the XML declaration.

    Whitespace-only text nodes (the indentation already present in libvirt
    XML) are dropped first. Otherwise ``toprettyxml`` re-indents them and
    emits spurious whitespace-only lines, which show up as noise in diffs.

    Args:
        xml_string: XML document as text.

    Returns:
        str: the re-indented document. If the input cannot be parsed, a
        warning goes to stderr and the input is returned unchanged.
    """
    try:
        dom = minidom.parseString(xml_string)
        _drop_whitespace_text(dom)
        pretty_xml = dom.toprettyxml(indent=" ")
    except Exception as e:  # best effort: this output is only used for display
        print(f"Warning: Could not pretty-print XML for diff: {e}", file=sys.stderr)
        return xml_string
    lines = pretty_xml.splitlines()
    # toprettyxml() always starts with "<?xml ...?>"; drop it.
    if lines and lines[0].startswith('<?xml'):
        return '\n'.join(lines[1:])
    return pretty_xml
def get_pci_details(bdf): | |
"""Get detailed information for a given PCI BDF.""" | |
details = { | |
'bdf': bdf, | |
'vendor_id': None, | |
'device_id': None, | |
'vendor_device_id': None, | |
'driver': None, | |
'description': None # Potentially add later if needed | |
} | |
try: | |
# Get vendor/device IDs using lspci -nmm | |
lspci_output = subprocess.check_output(['lspci', '-nmm', '-s', bdf], text=True, stderr=subprocess.DEVNULL).strip() | |
match_vd = re.search(r'\[([0-9a-fA-F]{4}):([0-9a-fA-F]{4})\]', lspci_output) | |
if match_vd: | |
details['vendor_id'] = match_vd.group(1) | |
details['device_id'] = match_vd.group(2) | |
details['vendor_device_id'] = f"{match_vd.group(1)}:{match_vd.group(2)}" | |
else: | |
# Fallback using lspci -n | |
lspci_output_n = subprocess.check_output(['lspci', '-n', '-s', bdf], text=True, stderr=subprocess.DEVNULL).strip() | |
match_n_vd = re.search(r':\s*([0-9a-fA-F]{4}):([0-9a-fA-F]{4})', lspci_output_n) | |
if match_n_vd: | |
details['vendor_id'] = match_n_vd.group(1) | |
details['device_id'] = match_n_vd.group(2) | |
details['vendor_device_id'] = f"{match_n_vd.group(1)}:{match_n_vd.group(2)}" | |
else: | |
print(f"Warning: Could not extract vendor/device ID for {bdf}", file=sys.stderr) | |
except subprocess.CalledProcessError: | |
print(f"Warning: Failed to execute lspci for BDF {bdf}", file=sys.stderr) | |
except Exception as e: | |
print(f"Warning: Error parsing lspci output for BDF {bdf}: {e}", file=sys.stderr) | |
# Get driver info from sysfs | |
driver_link = f"/sys/bus/pci/devices/{bdf}/driver" | |
try: | |
if os.path.islink(driver_link): | |
driver_path = os.readlink(driver_link) | |
details['driver'] = os.path.basename(driver_path) | |
except OSError as e: | |
# Ignore if path doesn't exist or permission error, means no driver or inaccessible | |
if e.errno != 2 and e.errno != 13: | |
print(f"Warning: Error checking driver for {bdf}: {e}", file=sys.stderr) | |
pass # No driver bound or error reading link | |
return details | |
def get_iommu_groups(iommu_base="/sys/kernel/iommu_groups"):
    """Map IOMMU groups to their PCI devices and gather details for each device.

    Args:
        iommu_base: directory holding one numbered sub-directory per IOMMU
            group, each with a ``devices`` directory of BDF entries. Defaults
            to the kernel's sysfs location.

    Returns:
        tuple: (groups, device_to_group, all_device_details)
            groups (dict): group_id -> list of device BDFs (sorted)
            device_to_group (dict): device BDF -> group_id
            all_device_details (dict): device BDF -> dict from get_pci_details
        All three are empty dicts when the base directory is missing or
        unreadable.
    """
    groups = defaultdict(list)
    device_to_group = {}
    all_device_details = {}

    if not os.path.isdir(iommu_base):
        print(f"Error: IOMMU directory not found at {iommu_base}. Is IOMMU enabled?", file=sys.stderr)
        return {}, {}, {}

    # PermissionError is a subclass of OSError, so it must be caught first.
    try:
        group_ids = sorted(
            (entry for entry in os.listdir(iommu_base)
             if entry.isdigit() and os.path.isdir(os.path.join(iommu_base, entry))),
            key=int)
    except PermissionError:
        print(f"Error: Permission denied when trying to read {iommu_base}. Run as root?", file=sys.stderr)
        return {}, {}, {}
    except OSError as e:
        print(f"Error reading IOMMU groups directory {iommu_base}: {e}", file=sys.stderr)
        return {}, {}, {}

    print(f"Found {len(group_ids)} IOMMU groups. Parsing devices...")
    for group_id_str in group_ids:
        group_path = os.path.join(iommu_base, group_id_str, "devices")
        if not os.path.isdir(group_path):
            continue
        try:
            device_bdfs = sorted(os.listdir(group_path))
        except PermissionError:
            print(f"Warning: Permission denied reading devices for IOMMU group {group_id_str}", file=sys.stderr)
            continue
        except OSError as e:
            print(f"Warning: Could not read devices for IOMMU group {group_id_str}: {e}", file=sys.stderr)
            continue
        for bdf in device_bdfs:
            # Only entries shaped like DDDD:BB:SS.F are PCI devices.
            if not re.fullmatch(r'[0-9a-fA-F]{4}:[0-9a-fA-F]{2}:[0-9a-fA-F]{2}\.[0-9a-fA-F]', bdf):
                continue
            groups[group_id_str].append(bdf)
            device_to_group[bdf] = group_id_str
            all_device_details[bdf] = get_pci_details(bdf)

    print("Finished parsing IOMMU groups.")
    return groups, device_to_group, all_device_details
def parse_pci_address(bdf):
    """Split a PCI address ``DDDD:BB:SS.F`` into libvirt ``<address>`` attributes.

    Args:
        bdf: PCI address string such as ``0000:01:00.0``.

    Returns:
        dict | None: ``{'domain': '0xDDDD', 'bus': '0xBB', 'slot': '0xSS',
        'function': '0xF'}``, keeping the original hex digits. Returns None
        when *bdf* is not exactly in that form; trailing characters are
        rejected rather than silently ignored.
    """
    match = re.fullmatch(r'([0-9a-fA-F]{4}):([0-9a-fA-F]{2}):([0-9a-fA-F]{2})\.([0-9a-fA-F])', bdf)
    if match is None:
        return None
    domain, bus, slot, function = match.groups()
    return {
        'domain': f'0x{domain}',
        'bus': f'0x{bus}',
        'slot': f'0x{slot}',
        'function': f'0x{function}',
    }
# XPath selecting the PCI passthrough <hostdev> elements this tool manages.
_HOSTDEV_XPATH = "./hostdev[@type='pci'][@mode='subsystem']"
# Attribute names of a libvirt PCI <address>, in BDF order.
_PCI_ADDR_FIELDS = ('domain', 'bus', 'slot', 'function')


def _xml_int(text):
    """Parse an integer attribute written as hex (``0x1f``) or decimal (``31``)."""
    text = text.strip().lower()
    return int(text, 16) if text.startswith('0x') else int(text, 10)


def _device_matches(criteria, device_info):
    """Return True when *device_info* satisfies every key/value pair in *criteria*.

    Strings compare case-insensitively. YAML loads all-digit IDs such as
    ``vendor_id: 1002`` as ints, so an int is compared to a string device
    value through ``str()``. Everything else uses plain equality, so e.g.
    ``driver: null`` matches devices without a driver.
    """
    for key, wanted in criteria.items():
        actual = device_info.get(key)
        if isinstance(actual, str) and isinstance(wanted, int) and not isinstance(wanted, bool):
            wanted = str(wanted)
        if isinstance(actual, str) and isinstance(wanted, str):
            if actual.lower() != wanted.lower():
                return False
        elif actual != wanted:
            return False
    return True


def _resolve_passthrough_bdfs(vm_name, passthrough_requests, groups, device_to_group, all_device_details):
    """Expand match-criteria sets into every BDF that must be passed through.

    Any host device matching a criteria set pulls in its whole IOMMU group.

    Returns:
        tuple: (set of BDF strings, number of warnings emitted)
    """
    final_bdfs = set()
    processed_groups = set()
    warnings = 0
    matched_indices = set()

    # Keys a criteria set may use: whatever the device-detail dicts carry.
    known_keys = set()
    for info in all_device_details.values():
        known_keys.update(info)

    # Validate each criteria set once, so an unknown key is reported once
    # rather than once per host device.
    active_requests = {}
    for idx, request in enumerate(passthrough_requests):
        criteria = request['match']
        if not criteria:
            continue  # empty match blocks are ignored
        unknown = [key for key in criteria if key not in known_keys]
        if known_keys and unknown:
            for key in unknown:
                print(f"Warning: Unknown match key '{key}' in request for VM '{vm_name}'. Skipping criteria.", file=sys.stderr)
            continue
        active_requests[idx] = criteria

    for bdf, device_info in all_device_details.items():
        for idx, criteria in active_requests.items():
            if not _device_matches(criteria, device_info):
                continue
            print(f" Device {bdf} ({device_info.get('vendor_device_id') or 'N/A'}, driver: {device_info.get('driver') or 'None'}) matches criteria set {idx+1}.")
            matched_indices.add(idx)
            group_id = device_to_group.get(bdf)
            if not group_id:
                print(f"Warning: Could not find IOMMU group for matched device {bdf}. Cannot add its group.", file=sys.stderr)
                warnings += 1
                continue
            if group_id in processed_groups:
                continue  # group already added by an earlier match
            print(f" Device {bdf} is in IOMMU Group {group_id}.")
            group_bdfs = groups.get(group_id, [])
            if group_bdfs:
                print(f" Adding all devices from IOMMU Group {group_id}: {', '.join(group_bdfs)}")
                final_bdfs.update(group_bdfs)
            else:
                print(f"Warning: IOMMU group {group_id} associated with {bdf} appears empty in parsed data!", file=sys.stderr)
                warnings += 1
            processed_groups.add(group_id)

    for idx, request in enumerate(passthrough_requests):
        if request['match'] and idx not in matched_indices:
            print(f"Warning: No host device found matching criteria set {idx+1} for VM '{vm_name}': {request['match']}", file=sys.stderr)
            warnings += 1
    return final_bdfs, warnings


def _register_xml_namespaces(xml_string):
    """Register the document's namespace prefixes with ElementTree.

    Without this, ``ET.tostring`` serialises prefixed elements (e.g.
    ``qemu:commandline``) under generated ``ns0:``-style prefixes. That
    produces a noisy diff and a needlessly different definition.
    """
    parser = ET.XMLPullParser(events=('start-ns',))
    try:
        parser.feed(xml_string)
        namespaces = [ns for _event, ns in parser.read_events()]
        parser.close()
    except ET.ParseError:
        return  # ET.fromstring() will surface the parse error to the caller
    for prefix, uri in namespaces:
        if not prefix:
            continue
        try:
            ET.register_namespace(prefix, uri)
        except ValueError:
            pass  # reserved "nsN" prefix; let ElementTree choose its own


def _hostdev_host_address(hostdev):
    """Return the ``(domain, bus, slot, function)`` ints of a hostdev's host address.

    Reads ``<source><address>``. Missing attributes count as 0. Returns None
    when the element is absent or a value is not a valid integer.
    """
    address = hostdev.find('./source/address')
    if address is None:
        return None
    try:
        return tuple(_xml_int(address.get(field, '0')) for field in _PCI_ADDR_FIELDS)
    except ValueError:
        return None


def _sync_hostdevs(devices_element, desired_bdfs, all_device_details):
    """Make *devices_element* hold exactly one PCI subsystem hostdev per desired BDF.

    Existing hostdevs whose host address is still wanted are left untouched.
    That keeps libvirt-assigned guest addresses and manual tweaks (e.g.
    ``<rom>`` or ``<boot>``), and an unchanged config proposes no changes.
    Stale or duplicate hostdevs are removed; missing ones are appended.

    Returns:
        tuple: (added_count, removed_count, warning_count)
    """
    def describe(bdf):
        info = all_device_details.get(bdf, {})
        return (f"{bdf} (Vendor:Device {info.get('vendor_device_id') or 'N/A'}, "
                f"Driver: {info.get('driver') or 'None'})")

    warnings = 0
    wanted = {}  # (domain, bus, slot, function) -> (bdf, parse_pci_address() dict)
    for bdf in sorted(desired_bdfs):
        pci_addr = parse_pci_address(bdf)
        if not pci_addr:
            print(f"Warning: Could not parse BDF '{bdf}' into PCI address components. Skipping.", file=sys.stderr)
            warnings += 1
            continue
        key = tuple(_xml_int(pci_addr[field]) for field in _PCI_ADDR_FIELDS)
        wanted[key] = (bdf, pci_addr)

    kept = set()
    removed = 0
    for hostdev in devices_element.findall(_HOSTDEV_XPATH):
        key = _hostdev_host_address(hostdev)
        if key in wanted and key not in kept:
            kept.add(key)
        else:
            devices_element.remove(hostdev)
            removed += 1
    if removed:
        print(f" Will remove {removed} existing PCI subsystem hostdev element(s)...")

    for key, (bdf, _pci_addr) in wanted.items():
        if key in kept:
            print(f" = Keeping existing: {describe(bdf)}")

    missing = [(bdf, pci_addr) for key, (bdf, pci_addr) in wanted.items() if key not in kept]
    if missing:
        print(f" Will add {len(missing)} devices for passthrough:")
    for bdf, pci_addr in missing:
        hostdev = ET.SubElement(devices_element, 'hostdev',
                                attrib={'mode': 'subsystem', 'type': 'pci', 'managed': 'yes'})
        source = ET.SubElement(hostdev, 'source')
        ET.SubElement(source, 'address', attrib={'type': 'pci', **pci_addr})
        print(f" + Proposing: {describe(bdf)}")
    return len(missing), removed, warnings


def _normalize_for_diff(xml_string):
    """Return a pretty-printed copy of *xml_string* with quote style and ``/>`` spacing normalised.

    The result always ends with a newline, so the last diff line is
    terminated and following output does not run onto it.
    """
    text = pretty_print_xml(xml_string)
    text = re.sub(r'=(["\'])(.*?)\1', r"='\2'", text)
    text = re.sub(r'\s+/>', '/>', text)
    return text if text.endswith('\n') else text + '\n'


def _show_diff_and_confirm(vm_name, original_for_diff, modified_for_diff):
    """Print a unified diff of two normalised XML texts and ask for confirmation.

    Returns:
        bool: True only if the user answers 'y'. EOF on stdin counts as no.
    """
    header = f" Proposed changes for {vm_name} "
    print("\n" + "-" * 15 + header + "-" * 15)
    print("(Note: Diff ignores attribute quotes and self-closing tag spacing)")
    sys.stdout.writelines(difflib.unified_diff(
        original_for_diff.splitlines(keepends=True),
        modified_for_diff.splitlines(keepends=True),
        fromfile='current definition',
        tofile='proposed definition',
        lineterm='\n',
    ))
    print("-" * (30 + len(header)))
    try:
        answer = input("Apply these changes? [y/N]: ").strip().lower()
    except EOFError:
        print("\nNo input detected, assuming No.", file=sys.stderr)
        return False
    return answer == 'y'


def _save_debug_xml(vm_name, debug_xml_file, xml_string):
    """Write a pretty-printed copy of *xml_string* to *debug_xml_file* (UTF-8). I/O errors only warn."""
    try:
        print(f" Saving proposed XML for {vm_name} to '{debug_xml_file}'...")
        with open(debug_xml_file, 'w', encoding='utf-8') as f:
            f.write(pretty_print_xml(xml_string))
        print(" Successfully saved debug XML.")
    except OSError as e:
        print(f"Warning: Failed to save debug XML to '{debug_xml_file}': {e}", file=sys.stderr)


def update_vm_definition(vm_name, vm_config, groups, device_to_group, all_device_details, debug_xml_file=None):
    """Reconcile a VM's PCI passthrough devices with its configuration through libvirt.

    The steps are:
    1. Validate ``vm_config['passthrough_devices']``, which must be a list of
       ``{'match': {key: value, ...}}`` dicts.
    2. Expand matching host devices to their whole IOMMU groups.
    3. Edit the persistent domain XML so it holds exactly those PCI hostdevs.
    4. Show a unified diff and call ``defineXML`` only if the user answers 'y'.

    Args:
        vm_name: libvirt domain name.
        vm_config: per-VM configuration dict.
        groups: IOMMU group id -> list of BDFs.
        device_to_group: BDF -> IOMMU group id.
        all_device_details: BDF -> details dict as produced by get_pci_details.
        debug_xml_file: optional path. The proposed XML is written there,
            pretty-printed, before the diff is shown.

    Returns:
        bool: False on invalid configuration or a libvirt/XML error. True
        otherwise, including when nothing changes or the user declines.
    """
    passthrough_requests = vm_config.get('passthrough_devices', []) or []
    if not isinstance(passthrough_requests, list):
        print(f"Error: 'passthrough_devices' for VM '{vm_name}' must be a list of match criteria.", file=sys.stderr)
        return False
    for i, req in enumerate(passthrough_requests):
        if not isinstance(req, dict) or 'match' not in req or not isinstance(req['match'], dict):
            print(f"Error: Entry {i} in 'passthrough_devices' for VM '{vm_name}' is invalid. Expected format: {{ 'match': {{ key: value, ... }} }}", file=sys.stderr)
            return False

    print(f"\n--- Processing VM: {vm_name} ---")
    final_bdfs, resolution_warnings = _resolve_passthrough_bdfs(
        vm_name, passthrough_requests, groups, device_to_group, all_device_details)

    conn = None
    try:
        conn = libvirt.open(None)
        if conn is None:
            print('Error: Failed to open connection to the hypervisor. Run as root?', file=sys.stderr)
            return False
        try:
            dom = conn.lookupByName(vm_name)
        except libvirt.libvirtError as e:
            print(f"Error: Failed to find domain '{vm_name}': {e}", file=sys.stderr)
            return False

        # INACTIVE: edit the persistent definition, not the live one, so
        # runtime-only state is not written back by defineXML.
        # SECURE: include security-sensitive data (e.g. graphics passwords).
        # Without it those values are absent from the XML, and redefining the
        # domain would silently discard them.
        original_xml_string = dom.XMLDesc(libvirt.VIR_DOMAIN_XML_INACTIVE | libvirt.VIR_DOMAIN_XML_SECURE)
        if not original_xml_string:
            print(f"Error: Failed to get XML description for domain '{vm_name}'.", file=sys.stderr)
            return False

        _register_xml_namespaces(original_xml_string)
        root = ET.fromstring(original_xml_string)
        devices_element = root.find('./devices')
        if devices_element is None:
            print(f"Error: Cannot find <devices> element in XML for VM '{vm_name}'", file=sys.stderr)
            return False

        added, removed, sync_warnings = _sync_hostdevs(devices_element, final_bdfs, all_device_details)
        resolution_warnings += sync_warnings

        if added or removed:
            modified_xml_string = ET.tostring(root, encoding='unicode')
            if debug_xml_file:
                _save_debug_xml(vm_name, debug_xml_file, modified_xml_string)
            original_for_diff = _normalize_for_diff(original_xml_string)
            modified_for_diff = _normalize_for_diff(modified_xml_string)
            if original_for_diff == modified_for_diff:
                print(f" No effective changes detected for VM '{vm_name}' after processing (ignoring quote style and tag spacing).")
            elif _show_diff_and_confirm(vm_name, original_for_diff, modified_for_diff):
                print(f"\nApplying updated definition to libvirt for VM '{vm_name}'...")
                # Apply the ET.tostring() output, not the pretty-printed diff text.
                conn.defineXML(modified_xml_string)
                print(f"Successfully applied changes to '{vm_name}'.")
            else:
                print(f"Changes for VM '{vm_name}' aborted by user.")
        else:
            print(f" No changes to PCI passthrough devices needed for VM '{vm_name}'.")

        if resolution_warnings > 0:
            print(f"NOTE: There were {resolution_warnings} warnings during device resolution for this VM.")
        return True
    except libvirt.libvirtError as e:
        print(f"Error: Libvirt API error processing VM '{vm_name}': {e}", file=sys.stderr)
        return False
    except ET.ParseError as e:
        print(f"Error: Failed to parse XML received from libvirt for VM '{vm_name}': {e}", file=sys.stderr)
        return False
    except Exception as e:
        print(f"Error: An unexpected error occurred while processing VM '{vm_name}': {e}", file=sys.stderr)
        return False
    finally:
        if conn:
            try:
                conn.close()
            except libvirt.libvirtError:
                pass
def main():
    """Command-line entry point.

    Parses arguments, requires root, gathers IOMMU/PCI data, loads the YAML
    configuration and runs update_vm_definition() for every VM listed under
    the top-level ``vms`` mapping.

    Returns:
        int: exit status. 0 when every VM was processed or there was nothing
        to do; 1 on any error or failed VM.
    """
    parser = argparse.ArgumentParser(description="Manage Libvirt PCI passthrough using the libvirt API.")
    parser.add_argument('--config', default='passthrough_config.yaml',
                        help="Path to the YAML configuration file (default: passthrough_config.yaml)")
    parser.add_argument('--debug-xml', metavar='FILENAME',
                        help="Save the proposed XML for each VM to the specified file before applying.")
    # TODO: further options (e.g. --yes, --dry-run, --vm).
    args = parser.parse_args()
    config_file = args.config

    print("Libvirt Passthrough Manager (API Mode with Diff/Confirm)")
    print("=" * 55)

    if os.geteuid() != 0:
        print("Error: This script needs root privileges to connect to the system libvirt daemon and modify domain definitions.", file=sys.stderr)
        return 1

    print("Gathering IOMMU group information...")
    groups, device_to_group, all_device_details = get_iommu_groups()
    if not device_to_group:  # the core BDF -> group mapping must be populated
        print("Error: Failed to get IOMMU group or device information. Exiting.", file=sys.stderr)
        return 1

    print(f"\nLoading configuration from '{config_file}'...")
    try:
        with open(config_file, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
    except FileNotFoundError:
        print(f"Error: Configuration file '{config_file}' not found.", file=sys.stderr)
        return 1
    except yaml.YAMLError as e:
        print(f"Error: Failed to parse configuration file '{config_file}': {e}", file=sys.stderr)
        return 1
    except PermissionError:
        print(f"Error: Permission denied reading config file '{config_file}'.", file=sys.stderr)
        return 1
    except Exception as e:
        print(f"Error: Could not read config file '{config_file}': {e}", file=sys.stderr)
        return 1

    # safe_load() can yield None, a list or a scalar; only a mapping with a
    # 'vms' mapping is valid (indexing a list/str here would raise TypeError).
    vms = config.get('vms') if isinstance(config, dict) else None
    if not isinstance(vms, dict):
        print(f"Error: Configuration file '{config_file}' is invalid. It must contain a top-level 'vms' dictionary.", file=sys.stderr)
        return 1
    if not vms:
        print("Info: No VMs defined in the 'vms' section of the configuration file. Nothing to do.")
        return 0

    print("\nProcessing VM configurations...")
    processed_count = 0
    fail_count = 0
    total_vms = len(vms)
    for vm_name, vm_config_data in vms.items():
        if not isinstance(vm_config_data, dict):
            print(f"Warning: Invalid configuration format for VM '{vm_name}' (expected a dictionary). Skipping.", file=sys.stderr)
            fail_count += 1
            continue
        # Per-VM debug file: insert the VM name before the extension.
        vm_debug_xml_file = None
        if args.debug_xml:
            base, ext = os.path.splitext(args.debug_xml)
            vm_debug_xml_file = f"{base}_{vm_name}{ext}"
        if update_vm_definition(vm_name, vm_config_data, groups, device_to_group, all_device_details, vm_debug_xml_file):
            processed_count += 1
        else:
            fail_count += 1

    print("\n" + "=" * 55)
    print("Processing Complete.")
    print(f" Total VMs in config: {total_vms}")
    print(f" Successfully processed: {processed_count}")
    print(f" Failed/Skipped: {fail_count}")
    print("=" * 55)
    # Every VM is counted as either processed or failed, so no failures
    # implies at least one success here.
    if fail_count > 0:
        print("\nCheck warnings/errors above for details on failures.", file=sys.stderr)
        return 1
    return 0
if __name__ == "__main__":
    # Run the CLI and hand main()'s integer status to the interpreter as the exit code.
    raise SystemExit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment