Skip to content

Instantly share code, notes, and snippets.

@pujianto
Last active December 26, 2021 21:34
Show Gist options
  • Save pujianto/c5487bb0b97139f133ffdffa36cc4a3e to your computer and use it in GitHub Desktop.
Save pujianto/c5487bb0b97139f133ffdffa36cc4a3e to your computer and use it in GitHub Desktop.
Splits an XML file into multiple files
#!/bin/env python
import logging
import os
from lxml.etree import iterparse, tostring
def split_xml(xml_source, rows_per_file, output_dir):
"""
Splits an XML file into multiple files.
Non built in dependency modules: lxml
"""
def xml_generator(iter_xml):
start_tag = None
for event, element in iter_xml:
if event == 'start' and start_tag is None:
start_tag = element.tag
if event == 'end' and element.tag == start_tag:
yield element
start_tag = None
element.getparent().remove(element)
# Check file exists
if not os.path.isfile(xml_source):
raise Exception("File does not exist: {}".format(xml_source))
if int(rows_per_file) < 1:
raise Exception("Rows per file must be greater than 0")
# Create output directory if it doesn't exist
if not os.path.isdir(output_dir):
logging.log(
logging.INFO, "Creating output directory: {}".format(output_dir))
os.makedirs(output_dir, exist_ok=True)
basename = os.path.basename(xml_source).rsplit('.')[0]
# Split file template
split_file_template = str(os.path.join(output_dir, "{}-split-{:05d}.xml"))
# Open source file
doc = iterparse(xml_source, events=('start', 'end'))
_, root = next(doc)
wrapper = root.tag
current_line = 0
split_file_index = 1
split_file = split_file_template.format(basename, split_file_index)
f = open(split_file, 'w')
f.write('<?xml version="1.0" encoding="utf-8"?>\n')
f.write('<{}>\n'.format(wrapper))
for element in xml_generator(doc):
if current_line >= rows_per_file:
logging.log(
logging.INFO, "Reaching rows per file limit. create new file")
split_file_index += 1
current_line = 0
f.write('</{}>\n'.format(wrapper))
f.close()
f = open(split_file_template.format(
basename, split_file_index), 'w')
f.write('<?xml version="1.0" encoding="utf-8"?>\n')
f.write('<{}>\n'.format(wrapper))
f.write(tostring(element, encoding='unicode'))
current_line += 1
f.write('</{}>\n'.format(wrapper))
f.close()
logging.log(logging.INFO, "splitting xml completed")
if __name__ == "__main__":
    # Command-line entry point: <script> <xml> <rows> <output_dir>
    # Setting the DEBUG environment variable to 1, true, True or TRUE
    # configures DEBUG-level logging to stdout before splitting.
    import argparse

    debug_requested = os.environ.get('DEBUG') in {'1', 'true', 'True', 'TRUE'}
    if debug_requested:
        import sys
        logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)

    cli = argparse.ArgumentParser(
        description="Split an XML file into multiple files.")
    cli.add_argument("xml", help="The XML file to split")
    cli.add_argument("rows", type=int, help="The number of rows per file")
    cli.add_argument("output_dir", help="The output directory")
    options = cli.parse_args()
    split_xml(options.xml, options.rows, options.output_dir)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment