Skip to content

Instantly share code, notes, and snippets.

@RavuAlHemio
Created May 9, 2014 07:46
Show Gist options
  • Save RavuAlHemio/2e53d3b174c69653ca29 to your computer and use it in GitHub Desktop.
Save RavuAlHemio/2e53d3b174c69653ca29 to your computer and use it in GitHub Desktop.
# Dissect a RIFF file and extract all WAVE components into .wav files.
import io
import struct
# Tags of chunks whose payload is parsed as a 4-byte content tag followed by
# nested sub-chunks (see deserialize_riff); any other tag is kept as raw data.
container_chunks = {b"RIFF", b"LIST"}
class RiffChunk:
    """A plain (non-container) RIFF chunk: a tag plus its raw data bytes."""

    def __init__(self, tag, data):
        self.tag = tag
        self.data = data

    def __repr__(self):
        return "RiffChunk({0!r}, {1!r})".format(self.tag, self.data)

    def serialize(self):
        """Return the chunk as bytes: tag, little-endian length, data, padding.

        The length field holds the unpadded data length; a single zero byte
        is appended after the data when that length is odd.
        """
        payload = self.data
        header = struct.pack("<4sI", self.tag, len(payload))
        padding = b"\0" * (len(payload) % 2)
        return b"".join((header, payload, padding))
class RiffContainerChunk:
    """A RIFF chunk holding a content tag followed by nested chunks.

    Attributes:
        container_tag: tag written first in the header (bytes).
        content_tag: tag written after the length field (bytes).
        subchunks: iterable of objects whose ``serialize()`` returns bytes.
    """

    # struct's "4s" always emits exactly four bytes for the content tag
    # (zero-padding or truncating as needed), so the length field must count
    # four bytes regardless of len(content_tag).
    _CONTENT_TAG_SIZE = 4

    def __init__(self, container_tag, content_tag, subchunks):
        self.container_tag = container_tag
        self.content_tag = content_tag
        self.subchunks = subchunks

    def __repr__(self):
        return "RiffContainerChunk({0}, {1}, {2})".format(
            repr(self.container_tag), repr(self.content_tag), repr(self.subchunks)
        )

    def serialize(self):
        """Return the container as bytes.

        Layout: container tag, little-endian length, content tag, then every
        sub-chunk's serialized form in order. The length field covers the
        content tag and the serialized sub-chunks.
        """
        serialized_subchunks = b"".join(sub.serialize() for sub in self.subchunks)
        # Count the packed size of the content tag rather than its Python
        # length, so the header always matches the bytes actually written.
        my_data_length = self._CONTENT_TAG_SIZE + len(serialized_subchunks)
        head = struct.pack(
            "<4sI4s", self.container_tag, my_data_length, self.content_tag
        )
        return head + serialized_subchunks
def deserialize_riff(stream):
    """Read one RIFF chunk from *stream* and return it as an object.

    *stream* must provide ``read(n)`` returning bytes. Returns ``None`` when
    the stream yields no bytes for the tag (end of input), a
    ``RiffContainerChunk`` when the tag is in ``container_chunks``, and a
    ``RiffChunk`` otherwise.

    Raises:
        ValueError: if the tag or the length field is cut short.

    NOTE: chunk data shorter than the declared length is not checked; the
    bytes that could be read are used as-is.
    """
    tag = stream.read(4)
    if not tag:
        return None
    if len(tag) != 4:
        raise ValueError("early end!")

    raw_length = stream.read(4)
    if len(raw_length) != 4:
        raise ValueError("early end!")
    payload_size = struct.unpack("<I", raw_length)[0]

    payload = stream.read(payload_size)
    if payload_size % 2 == 1:
        # Consume the pad byte that follows odd-length chunk data.
        stream.read(1)

    if tag not in container_chunks:
        return RiffChunk(tag, payload)

    # Container payload: 4-byte content tag, then a sequence of sub-chunks.
    nested = io.BytesIO(payload[4:])
    # Keep parsing sub-chunks until the nested stream reports end of input.
    children = list(iter(lambda: deserialize_riff(nested), None))
    return RiffContainerChunk(tag, payload[0:4], children)
def extract_waves(riff, prefix, counter):
    """Write every WAVE container found in *riff* to its own ``.wav`` file.

    Walks *riff* recursively. Each ``RiffContainerChunk`` whose content tag is
    ``b"WAVE"`` is serialized to a file named ``<prefix><NNNN>.wav``, where
    ``NNNN`` is the current ``counter.value`` zero-padded to four digits.
    Anything that is not a container chunk (including ``None``) is ignored.

    Args:
        riff: parsed chunk object to search.
        prefix: string prepended to each output file name.
        counter: object with a ``value`` attribute and an ``increment()``
            method; it is advanced once per file written.
    """
    # isinstance (rather than an exact type comparison) so that subclasses of
    # the container chunk are searched as well.
    if not isinstance(riff, RiffContainerChunk):
        return

    if riff.content_tag == b"WAVE":
        out_name = "{0}{1:04}.wav".format(prefix, counter.value)
        counter.increment()
        with open(out_name, "wb") as f:
            f.write(riff.serialize())

    # Containers may nest, so keep looking inside every sub-chunk.
    for subchunk in riff.subchunks:
        extract_waves(subchunk, prefix, counter)
if __name__ == '__main__':
    import sys

    class Counter:
        """Mutable integer holder shared across recursive extract_waves calls."""

        def __init__(self, init_value=0):
            self.value = init_value

        def increment(self):
            self.value += 1

    # Each command-line argument is an input file; its path doubles as the
    # prefix for the output names, and numbering restarts for every file.
    for input_path in sys.argv[1:]:
        with open(input_path, "rb") as source:
            parsed = deserialize_riff(source)
        extract_waves(parsed, input_path, Counter())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment