Created
December 28, 2024 08:32
-
-
Save esnya/f36acc27c5a7be96da331acab534eeca to your computer and use it in GitHub Desktop.
This script converts a MIDI file into a piano roll punch card image.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This script converts a MIDI file into a piano roll punch card image.

Usage:
    Run the script with a MIDI file to render it as a punch card image, or pass
    -T/--template to create a blank template instead. Use the -h or --help flag
    for more information.
""" | |
from argparse import ArgumentParser | |
from collections.abc import Collection, Iterable, Sequence | |
from dataclasses import dataclass | |
from logging import getLogger | |
from pathlib import Path | |
from typing import NamedTuple | |
from mido import MidiFile, tempo2bpm | |
from PIL import Image, ImageDraw | |
# MIDI note numbers covered by the roll: 60 through 93 inclusive.
DEFAULT_NOTE_RANGE = range(60, 94)

# Default width of the blank template and default image height.
DEFAULT_IMAGE_WIDTH = 4096
DEFAULT_IMAGE_HEIGHT = 256

# Margins around the grid. Top and bottom are multiples of the note row
# height; left is a multiple of the beat width.
DEFAULT_MARGIN_TOP = 2.9
DEFAULT_MARGIN_BOTTOM = 2.3
DEFAULT_MARGIN_LEFT = 1.0

# Width of one beat before it is scaled by the image height (and by the
# tempo when a MIDI file is rendered).
DEFAULT_BEAT_WIDTH = 31.68

# Upper bound for the width of an image rendered from a MIDI file.
DEFAULT_MAX_IMAGE_WIDTH = 8192

# RGBA colors. The note color has an alpha channel of 0.
DEFAULT_BACKGROUND_COLOR = (230, 220, 180, 255)
DEFAULT_LINE_COLOR = (100, 100, 100, 255)
DEFAULT_BAR_LINE_COLOR = (50, 50, 50, 255)
DEFAULT_TEXT_COLOR = (50, 50, 50, 255)
DEFAULT_NOTE_COLOR = (0, 0, 0, 0)

# Module-level logger; the classes below derive child loggers from it.
logger = getLogger(__name__)
class Note(NamedTuple):
    """A single note event: the tick it occurs at and its MIDI note number."""

    # Absolute position of the event, accumulated from message delta times.
    tick: int
    # Note number after any octave shift has been applied.
    note: int
@dataclass
class Melody:
    """
    A class to represent a melody parsed from a MIDI file.

    Attributes:
    ----------
    note_events : list[Note]
        The note events of the melody; at most one note is kept per tick.
    ticks_per_beat : int
        The number of ticks per beat in the MIDI file.
    beat_label : str
        The time signature of the melody, formatted as "numerator/denominator".
    bpm : int
        The tempo of the melody in beats per minute, rounded to the nearest integer.
    copyright : str | None
        The copyright information of the MIDI file, if available.
    """

    note_events: list[Note]
    ticks_per_beat: int
    beat_label: str
    bpm: int
    copyright: str | None

    @classmethod
    def parse(
        cls,
        midi: MidiFile,
        note_range: Collection[int],
        octave_shift: float = 0,
        channels: Collection[int] = (),
        all_channels: bool = False,
    ) -> "Melody":
        """
        Parses a MIDI file and extracts melody information.

        The last time signature, tempo and copyright message encountered win.
        When the file contains none, the melody defaults to "4/4", 120 BPM and
        no copyright text.

        Args:
            midi (MidiFile): The MIDI file to parse.
            note_range (Collection[int]): The range of MIDI note numbers to include.
            octave_shift (float, optional): The number of octaves to shift the notes. Defaults to 0.
            channels (Collection[int], optional): The MIDI channels to include. Defaults to no channels.
            all_channels (bool, optional): Whether to include all MIDI channels. Defaults to False.

        Returns:
            Melody: An object containing the parsed melody information.
        """
        _logger = logger.getChild(cls.__name__)
        # Maps tick -> note number; dict insertion order keeps events in time order.
        notes: dict[int, int] = {}
        current_time = 0
        copyright_text = None
        bpm = 120
        beat_label = "4/4"

        for msg in midi.merged_track:
            # Message times are deltas; accumulate them into an absolute position.
            current_time += int(msg.time)
            if msg.type == "time_signature":
                beat_label = cls._process_time_signature(msg, _logger)
            elif msg.type == "set_tempo":
                bpm = cls._process_tempo(msg, _logger)
            elif msg.type == "copyright":
                copyright_text = cls._process_copyright(msg, _logger)
            elif msg.type == "note_on":
                cls._process_note_on(
                    msg,
                    notes,
                    current_time,
                    note_range,
                    octave_shift,
                    channels,
                    all_channels,
                    _logger,
                )
            else:
                _logger.info(f"Skipping non-note message: {msg}")

        return cls(
            note_events=[Note(tick, note) for tick, note in notes.items()],
            ticks_per_beat=midi.ticks_per_beat,
            beat_label=beat_label,
            bpm=bpm,
            copyright=copyright_text,
        )

    @staticmethod
    def _process_time_signature(msg, logger) -> str:
        """Log a time signature message and return it as "numerator/denominator"."""
        logger.info(f"Time signature: {msg}")
        return f"{msg.numerator}/{msg.denominator}"

    @staticmethod
    def _process_tempo(msg, logger) -> int:
        """Log a tempo message and return its tempo in whole beats per minute."""
        logger.info(f"Tempo: {msg}")
        # MIDI stores tempo as integer microseconds per beat, so converting back
        # rarely yields an exact integer (99 BPM is stored as 606061 and converts
        # to 98.99994). Round instead of truncating to recover the intended value.
        return round(tempo2bpm(msg.tempo))

    @staticmethod
    def _process_copyright(msg, logger) -> str:
        """Log a copyright message and return its text."""
        logger.info(f"Copyright: {msg}")
        return str(msg.text)

    @staticmethod
    def _process_note_on(
        msg,
        notes: dict[int, int],
        current_time: int,
        note_range: Collection[int],
        octave_shift: float,
        channels: Collection[int],
        all_channels: bool,
        logger,
    ) -> None:
        """
        Record a note-on message in ``notes`` (mutated in place).

        Messages with zero velocity, or on a channel that is not selected, are
        ignored silently. Notes that fall outside ``note_range`` after shifting,
        or whose tick is already occupied, are skipped with a warning.
        """
        if msg.velocity <= 0:
            return
        if not (msg.channel in channels or all_channels):
            return

        shifted_note = msg.note + int(octave_shift * 12)
        if shifted_note not in note_range:
            logger.warning(f"Skipping out-of-range note: {msg}")
        elif current_time in notes:
            # Only one note per tick is kept; the first one wins.
            logger.warning(f"Skipping overlapping note: {msg}")
        else:
            notes[current_time] = shifted_note
@dataclass
class MelodyImageMetaData:
    """
    A class to represent metadata for a melody image.

    Attributes:
        size (tuple[int, int]): The size of the image (width, height).
        margin (tuple[float, float, float]): The margins (top, left, bottom).
        beat_width (float): The width of a beat in the image.
        num_ticks (int): The number of ticks in the melody.
        num_beats (int): The number of beats in the melody.
        num_notes (int): The number of notes in the melody.
        note_size (float): The size of a note in the image.
        tick_scale (float): The scale of a tick in the image.
        note_range (Sequence[int]): The range of notes in the melody.

    Properties:
        margin_top (float): The top margin.
        margin_left (float): The left margin.
        margin_bottom (float): The bottom margin.
        image_width (int): The width of the image.
        image_height (int): The height of the image.
        margin_bottom_pixels (int): The bottom margin in pixels.
        margin_top_pixels (int): The top margin in pixels.
        margin_left_pixels (int): The left margin in pixels.
        note_size_pixels (int): The size of a note in pixels.
    """

    size: tuple[int, int] = (0, 0)
    margin: tuple[float, float, float] = (0, 0, 0)
    beat_width: float = 0
    num_ticks: int = 0
    num_beats: int = 0
    num_notes: int = 0
    note_size: float = 0
    tick_scale: float = 0
    note_range: Sequence[int] = ()

    @property
    def margin_top(self) -> float:
        return self.margin[0]

    @property
    def margin_left(self) -> float:
        return self.margin[1]

    @property
    def margin_bottom(self) -> float:
        return self.margin[2]

    @property
    def image_width(self) -> int:
        return self.size[0]

    @property
    def image_height(self) -> int:
        return self.size[1]

    @property
    def margin_bottom_pixels(self) -> int:
        # Vertical margins are expressed in note rows.
        return int(self.margin_bottom * self.note_size)

    @property
    def margin_top_pixels(self) -> int:
        return int(self.margin_top * self.note_size)

    @property
    def margin_left_pixels(self) -> int:
        # The horizontal margin is expressed in beats.
        return int(self.margin_left * self.beat_width)

    @property
    def note_size_pixels(self) -> int:
        return int(self.note_size)

    @classmethod
    def _calculate_meta_data(
        cls,
        note_range: Sequence[int],
        size: Collection[int],
        margin: Collection[float],
        beat_width: float,
        ticks_per_beat: int = 0,
        num_ticks: int = 0,
        bpm: int = 120,
    ) -> "MelodyImageMetaData":
        """
        Derive the image layout from the requested size, margins and beat width.

        Args:
            note_range (Sequence[int]): The notes, one grid row each.
            size (Collection[int]): (maximum image width, image height).
            margin (Collection[float]): (top, left, bottom) margins.
            beat_width (float): The unscaled width of a beat.
            ticks_per_beat (int): Ticks per beat of the melody; 0 selects the
                blank-template layout that fills the maximum width.
            num_ticks (int): Tick of the last note of the melody.
            bpm (int): Tempo of the melody.

        Returns:
            MelodyImageMetaData: The calculated layout.
        """
        max_image_width, image_height = size
        margin_top, margin_left, margin_bottom = margin
        num_notes = len(note_range)
        # The rows and both vertical margins share the image height.
        note_size = image_height / (num_notes + margin_top + margin_bottom)

        if ticks_per_beat > 0:
            # Beat width scales with image height and inversely with tempo.
            beat_width = beat_width * (image_height / 256) * (120 / bpm)
            # Round the melody length to the nearest multiple of four beats.
            # NOTE(review): rounding to nearest can round down, leaving the last
            # notes past the grid - confirm this is intended.
            num_beats = int(num_ticks / ticks_per_beat / 4 + 0.5) * 4
            tick_scale = beat_width / ticks_per_beat
        else:
            beat_width = beat_width * image_height / 256
            # NOTE(review): the margin is scaled by note_size here, while
            # margin_left_pixels scales it by beat_width - confirm which is meant.
            num_beats = int((max_image_width - (margin_left * note_size)) / beat_width)
            tick_scale = 0

        # The left margin is applied on both sides; never exceed the maximum width.
        image_width = min(
            max_image_width,
            int(num_beats * beat_width + int(margin_left * beat_width) * 2),
        )
        logger.info(f"image_size: {image_width}x{image_height}")

        return cls(
            size=(image_width, image_height),
            margin=(margin_top, margin_left, margin_bottom),
            beat_width=beat_width,
            num_ticks=num_beats * ticks_per_beat,
            num_beats=num_beats,
            num_notes=num_notes,
            note_size=note_size,
            tick_scale=tick_scale,
            note_range=note_range,
        )

    @classmethod
    def from_melody(
        cls,
        melody: Melody,
        note_range: Sequence[int],
        size: Collection[int],
        margin: Collection[float],
        beat_width: float,
    ) -> "MelodyImageMetaData":
        """
        Creates a MelodyImageMetaData object from a given melody.

        A melody without note events yields a layout with zero beats instead of
        raising an error.

        Args:
            melody (Melody): The melody object containing note events and other metadata.
            note_range (Sequence[int]): The range of notes to be included in the metadata.
            size (Collection[int]): The size dimensions for the metadata.
            margin (Collection[float]): The margin values for the metadata.
            beat_width (float): The width of each beat in the metadata.

        Returns:
            MelodyImageMetaData: The metadata object containing calculated information based on the melody.
        """
        if not melody.note_events:
            logger.warning("Melody contains no note events")
        # default=0 keeps max() from raising ValueError on an empty melody.
        last_tick = max((tick for tick, _ in melody.note_events), default=0)
        return cls._calculate_meta_data(
            note_range=note_range,
            size=size,
            margin=margin,
            beat_width=beat_width,
            ticks_per_beat=melody.ticks_per_beat,
            num_ticks=int(last_tick),
            bpm=melody.bpm,
        )

    @classmethod
    def from_size(
        cls,
        note_range: Sequence[int],
        size: Collection[int],
        margin: Collection[float],
        beat_width: float,
    ) -> "MelodyImageMetaData":
        """
        Create a MelodyImageMetaData instance based on the given parameters.

        Args:
            note_range (Sequence[int]): The range of notes to be included.
            size (Collection[int]): The size dimensions of the image.
            margin (Collection[float]): The margins around the image.
            beat_width (float): The width of each beat in the image.

        Returns:
            MelodyImageMetaData: An instance of MelodyImageMetaData with calculated metadata.
        """
        return cls._calculate_meta_data(
            note_range=note_range,
            size=size,
            margin=margin,
            beat_width=beat_width,
        )
class RollRenderer:
    """
    A class to render a musical roll image with various elements such as lines, notes, and titles.

    Attributes:
        background_color (tuple): RGBA color for the background.
        line_color (tuple): RGBA color for the lines.
        bar_line_color (tuple): RGBA color for the bar lines.
        text_color (tuple): RGBA color for the text.
        note_color (tuple): RGBA color for the notes.
    """

    background_color = DEFAULT_BACKGROUND_COLOR
    line_color = DEFAULT_LINE_COLOR
    bar_line_color = DEFAULT_BAR_LINE_COLOR
    text_color = DEFAULT_TEXT_COLOR
    note_color = DEFAULT_NOTE_COLOR

    def __init__(self, meta: MelodyImageMetaData, image: Image.Image) -> None:
        """
        Initializes the RollRenderer object and draws the empty grid.

        Args:
            meta (MelodyImageMetaData): Metadata for the melody image.
            image (Image.Image): The image on which to draw.
        """
        self.logger = logger.getChild(__class__.__name__)
        self.meta = meta
        self.draw = ImageDraw.Draw(image)
        # The grid is drawn immediately; notes and title are added on request.
        self._hlines()
        self._vlines()
        self._note_placeholders()

    def _vlines(self) -> None:
        """Draw one vertical line per beat; every fourth beat gets a heavier bar line."""
        top = self.meta.margin_top_pixels
        bottom = self.meta.image_height - self.meta.margin_bottom_pixels
        for beat in range(self.meta.num_beats + 2):
            x = int(beat * self.meta.beat_width) + self.meta.margin_left_pixels
            is_bar_line = beat % 4 == 0
            self.draw.line(
                [(x, top), (x, bottom)],
                fill=self.bar_line_color if is_bar_line else self.line_color,
                width=2 if is_bar_line else 1,
            )

    def _hlines(self) -> None:
        """Draw the horizontal lines that separate the note rows."""
        left = self.meta.margin_left_pixels
        right = self.meta.image_width - self.meta.margin_left_pixels
        for row in range(self.meta.num_notes + 1):
            y = int(row * self.meta.note_size) + self.meta.margin_top_pixels
            self.draw.line([(left, y), (right, y)], fill=self.line_color, width=1)

    def _note_placeholders(
        self,
    ) -> None:
        """Draw a short vertical tick across every row line at the middle of each beat."""
        for row in range(self.meta.num_notes + 1):
            y = int(row * self.meta.note_size) + self.meta.margin_top_pixels - 1
            for beat in range(self.meta.num_beats):
                x = int((beat + 0.5) * self.meta.beat_width) + self.meta.margin_left_pixels
                self.draw.line(((x, y), (x, y + 2)), fill=self.line_color, width=1)

    def notes(
        self,
        note_events: Iterable[Note],
    ) -> None:
        """
        Draws notes on the image based on the provided note events.

        Args:
            note_events (Iterable[Note]): An iterable of note events,
                where each event is a tuple containing a tick and a note.

        Notes:
            - The method calculates the x and y coordinates for each note based on the tick and note values.
            - If a note's x coordinate exceeds the image width, a warning is logged and the note is skipped.
            - If a note is not part of the note range, a warning is logged and the note is skipped.
            - The note is drawn as a circle on the image at the calculated coordinates.
        """
        # Row 0 is the top of the grid and holds the last note of the range.
        row_by_note = {
            note: row for row, note in enumerate(reversed(self.meta.note_range))
        }
        row_height = self.meta.note_size_pixels
        radius = int(row_height / 2)
        x_offset = self.meta.margin_left_pixels + row_height / 2

        for tick, note in note_events:
            x = int(tick * self.meta.tick_scale + x_offset)
            if x >= self.meta.image_width:
                self.logger.warning(f"Skipping note outside image: {note}")
                continue
            row = row_by_note.get(note)
            if row is None:
                # Skip with a warning instead of failing with a bare KeyError.
                self.logger.warning(f"Skipping note outside note range: {note}")
                continue
            # NOTE(review): rows use the truncated note_size_pixels here while the
            # grid lines use the fractional note_size - confirm this is intended.
            y = int((row + 0.5) * row_height + self.meta.margin_top_pixels)
            self.draw.circle((x, y), radius, fill=self.note_color)

    def title(
        self,
        text_list: Iterable[str],
    ) -> None:
        """
        Draws a title on the image using the provided list of text strings.

        Args:
            text_list (Iterable[str]): A list of strings to be used as the title text.
                Each string will be joined with spaces and newlines
                will be replaced with spaces.
        """
        # Anchor "lb" puts the left/bottom corner of the text just above the grid.
        self.draw.text(
            (self.meta.margin_left_pixels, self.meta.margin_top_pixels - 1),
            " ".join(text_list).replace("\n", " "),
            anchor="lb",
            fill=self.text_color,
        )
def blank_punch_card(
    note_range: Sequence[int],
    margin_top: float,
    margin_bottom: float,
    margin_left: float,
    image_height: int,
    image_width: int,
    beat_width: float,
) -> Image.Image:
    """
    Build an empty punch card: the grid only, without notes or title.

    Args:
        note_range (Sequence[int]): The notes to provide rows for.
        margin_top (float): The top margin of the punch card.
        margin_bottom (float): The bottom margin of the punch card.
        margin_left (float): The left margin of the punch card.
        image_height (int): The height of the punch card image.
        image_width (int): The width of the punch card image.
        beat_width (float): The width of each beat on the punch card.

    Returns:
        Image.Image: The blank punch card image.
    """
    requested_size = (image_width, image_height)
    margins = (margin_top, margin_left, margin_bottom)
    layout = MelodyImageMetaData.from_size(
        note_range=note_range,
        size=requested_size,
        margin=margins,
        beat_width=beat_width,
    )

    canvas = Image.new("RGBA", layout.size, RollRenderer.background_color)
    # Constructing the renderer draws the grid; the instance itself is not needed.
    RollRenderer(layout, canvas)
    return canvas
def midi_to_punch_card(
    midi: MidiFile,
    note_range: Sequence[int],
    octave_shift: float,
    margin_top: float,
    margin_bottom: float,
    margin_left: float,
    max_image_width: int,
    image_height: int,
    beat_width: float,
    channels: list[int],
    all_channels: bool,
) -> Image.Image:
    """
    Render a MIDI file as a punch card image.

    The work happens in three stages: parsing the melody, calculating the image
    layout, and drawing. A failure in any stage is logged and then re-raised.

    Args:
        midi (MidiFile): The MIDI file to be converted.
        note_range (Sequence[int]): The range of MIDI notes to be included in the punch card.
        octave_shift (float): The number of octaves to shift the notes.
        margin_top (float): The top margin of the image.
        margin_bottom (float): The bottom margin of the image.
        margin_left (float): The left margin of the image.
        max_image_width (int): The maximum width of the generated image.
        image_height (int): The height of the generated image.
        beat_width (float): The width of each beat in the image.
        channels (list[int]): The list of MIDI channels to include.
        all_channels (bool): Whether to include all MIDI channels.

    Returns:
        Image.Image: The generated punch card image.

    Raises:
        Exception: Whatever the failing stage raised, after it has been logged.
    """
    # Stage 1: extract the melody.
    try:
        melody = Melody.parse(
            midi,
            note_range,
            octave_shift=octave_shift,
            channels=channels,
            all_channels=all_channels,
        )
    except Exception:
        logger.error("Failed to parse melody")
        raise

    # Stage 2: work out the image layout for that melody.
    try:
        layout = MelodyImageMetaData.from_melody(
            melody=melody,
            note_range=note_range,
            size=(max_image_width, image_height),
            margin=(margin_top, margin_left, margin_bottom),
            beat_width=beat_width,
        )
    except Exception:
        logger.error("Failed to calculate image metadata")
        raise

    # Stage 3: draw grid, notes and title.
    try:
        canvas = Image.new("RGBA", layout.size, RollRenderer.background_color)
        renderer = RollRenderer(layout, canvas)
        renderer.notes(melody.note_events)

        if midi.filename:
            file_label = Path(midi.filename).name
        else:
            file_label = "untitled"
        title_parts = [
            file_label,
            melody.beat_label,
            f"BPM {melody.bpm}",
            melody.copyright or "anonymous",
        ]
        renderer.title(title_parts)
    except Exception:
        logger.error("Failed to draw image")
        raise
    return canvas
class PunchRollArgumentParser(ArgumentParser):
    """
    PunchRollArgumentParser is a custom argument parser for converting MIDI files to piano roll punch card images.

    Exactly one of a MIDI input file or -T/--template must be given.
    Use the -h or --help flag for more information on the available arguments.
    """

    def __init__(self, *args, **kwargs):
        """
        Create the parser and register all arguments.

        Args:
            *args: Positional arguments forwarded to ArgumentParser.
            **kwargs: Keyword arguments forwarded to ArgumentParser. A
                ``description`` given here replaces the built-in one.
        """
        # setdefault avoids "multiple values for keyword argument" when the
        # caller supplies its own description.
        kwargs.setdefault(
            "description", "Convert MIDI file to piano roll punch card image."
        )
        super().__init__(*args, **kwargs)

        common_group = self.add_argument_group("Common")
        common_group.add_argument(
            "-O",
            "--output",
            default=None,
            help=(
                "Path of the image to write. (Default: template.png with --template, "
                "otherwise the input file name with a .png extension)"
            ),
        )
        common_group.add_argument(
            "-t",
            "--margin_top",
            type=float,
            default=DEFAULT_MARGIN_TOP,
            help="Top margin of the image. (Default: %(default)s)",
        )
        common_group.add_argument(
            "-b",
            "--margin_bottom",
            type=float,
            default=DEFAULT_MARGIN_BOTTOM,
            help="Bottom margin of the image. (Default: %(default)s)",
        )
        common_group.add_argument(
            "-l",
            "--margin_left",
            type=float,
            default=DEFAULT_MARGIN_LEFT,
            help="Left margin of the image. (Default: %(default)s)",
        )
        common_group.add_argument(
            "-s",
            "--image_height",
            type=int,
            default=DEFAULT_IMAGE_HEIGHT,
            help="Height of the image. (Default: %(default)s)",
        )

        # Either a MIDI file is rendered or a blank template is drawn, never both.
        exclusive_group = self.add_mutually_exclusive_group(required=True)
        exclusive_group.add_argument(
            "input",
            help="MIDI file to draw. Required if --template is not specified.",
            nargs="?",
        )
        exclusive_group.add_argument(
            "-T", "--template", action="store_true", help="Draw blank template."
        )

        midi_group = self.add_argument_group("Drawing MIDI File")
        midi_group.add_argument(
            "-o",
            "--octave_shift",
            type=float,
            default=0,
            help="Number of octaves to shift the notes. (Default: %(default)s)",
        )
        midi_group.add_argument(
            "-B",
            "--beat_width",
            type=float,
            default=DEFAULT_BEAT_WIDTH,
            help="Width of each beat in the pixel. (Default: %(default)s)",
        )
        midi_group.add_argument(
            "-C",
            "--all_channels",
            action="store_true",
            default=False,
            help="Include all MIDI channels regardless of the --channel option.",
        )
        # The default stays None: with action="append" a list default would be
        # extended rather than replaced. The script falls back to channel 0, so
        # the help text states that instead of printing "None".
        midi_group.add_argument(
            "-c",
            "--channel",
            type=int,
            default=None,
            action="append",
            help="MIDI channel to include. May be given multiple times. (Default: 0)",
        )
        midi_group.add_argument(
            "--max_image_width",
            type=int,
            default=DEFAULT_MAX_IMAGE_WIDTH,
            help="Maximum width of the image. Notes will be truncated if the image exceeds this width. (Default: %(default)s)",  # noqa: E501
        )

        template_group = self.add_argument_group("Drawing Blank Template")
        template_group.add_argument(
            "-w",
            "--image_width",
            type=int,
            default=DEFAULT_IMAGE_WIDTH,
            help="Width of the image for the blank template. (Default: %(default)s)",
        )
if __name__ == "__main__": | |
from logging import _STYLES, DEBUG, INFO, WARNING, Formatter, StreamHandler | |
from sys import stderr, stdout | |
logger.setLevel(INFO) | |
formatter = Formatter(_STYLES["%"][1]) | |
stderr_handler = StreamHandler(stderr) | |
stderr_handler.setLevel(WARNING) | |
stderr_handler.setFormatter(formatter) | |
logger.addHandler(stderr_handler) | |
stdout_handler = StreamHandler(stdout) | |
stdout_handler.setLevel(DEBUG) | |
stdout_handler.setFormatter(formatter) | |
logger.addHandler(stdout_handler) | |
parser = PunchRollArgumentParser() | |
try: | |
args = parser.parse_args() | |
if args.template: | |
logger.info("Drawing blank template") | |
image = blank_punch_card( | |
note_range=DEFAULT_NOTE_RANGE, | |
margin_top=args.margin_top, | |
margin_bottom=args.margin_bottom, | |
margin_left=args.margin_left, | |
image_height=args.image_height, | |
image_width=args.image_width, | |
beat_width=args.beat_width, | |
) | |
else: | |
logger.info(f"Reading {args.input}") | |
midi = MidiFile(filename=args.input) | |
logger.info("Drawing midi to punch card") | |
image = midi_to_punch_card( | |
midi, | |
note_range=DEFAULT_NOTE_RANGE, | |
octave_shift=args.octave_shift, | |
margin_top=args.margin_top, | |
margin_bottom=args.margin_bottom, | |
margin_left=args.margin_left, | |
beat_width=args.beat_width, | |
image_height=args.image_height, | |
max_image_width=args.max_image_width, | |
all_channels=args.all_channels, | |
channels=args.channel or [0], | |
) | |
output = ( | |
args.output or "template.png" | |
if args.template | |
else f"{Path(args.input).stem}.png" | |
) | |
image.save(output) | |
logger.info(f"Saved {output}") | |
except Exception as e: | |
logger.error("Unexpected error") | |
logger.exception(e) | |
exit(1) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment