Skip to content

Instantly share code, notes, and snippets.

@esnya
Created December 28, 2024 08:32
Show Gist options
  • Save esnya/f36acc27c5a7be96da331acab534eeca to your computer and use it in GitHub Desktop.
Save esnya/f36acc27c5a7be96da331acab534eeca to your computer and use it in GitHub Desktop.
This script converts a MIDI file into a piano roll punch card image.
"""
This script converts a MIDI file into a piano roll punch card image.
Usage:
Run the script with the appropriate arguments to convert a MIDI file to a punch card image
or to create a blank template. Use the -h or --help flag for more information.
"""
from argparse import ArgumentParser
from collections.abc import Collection, Iterable, Sequence
from dataclasses import dataclass
from logging import getLogger
from pathlib import Path
from typing import NamedTuple
from mido import MidiFile, tempo2bpm
from PIL import Image, ImageDraw
# MIDI note numbers drawn as rows on the roll: 60 through 93 inclusive (34 rows).
DEFAULT_NOTE_RANGE = range(60, 94)

# Canvas size in pixels: width of a blank template, and image height.
DEFAULT_IMAGE_WIDTH, DEFAULT_IMAGE_HEIGHT = 4096, 256

# Margins. Top and bottom are measured in note rows (multiples of the row
# height); left is measured in beats (multiples of the beat width).
DEFAULT_MARGIN_TOP, DEFAULT_MARGIN_BOTTOM, DEFAULT_MARGIN_LEFT = 2.9, 2.3, 1.0

# Beat width in pixels for a 256 px tall image; it is rescaled with the image
# height (and, for MIDI input, with the tempo relative to 120 BPM).
DEFAULT_BEAT_WIDTH = 31.68

# Upper bound on the width of an image rendered from a MIDI file.
DEFAULT_MAX_IMAGE_WIDTH = 8192

# RGBA colors.
DEFAULT_BACKGROUND_COLOR = (230, 220, 180, 255)
DEFAULT_LINE_COLOR = (100, 100, 100, 255)
DEFAULT_BAR_LINE_COLOR = DEFAULT_TEXT_COLOR = (50, 50, 50, 255)
# Fully transparent (alpha 0) — presumably so punched holes show through the
# card; confirm against the consumer of the image.
DEFAULT_NOTE_COLOR = (0, 0, 0, 0)

logger = getLogger(__name__)
class Note(NamedTuple):
    """A single note onset on the roll: when it starts and which pitch it is."""

    # Absolute start time in MIDI ticks, accumulated from message delta times.
    tick: int
    # MIDI note number, after any octave shift applied while parsing.
    note: int
@dataclass
class Melody:
    """
    A melody extracted from a MIDI file, ready to be drawn on a punch card.

    Attributes:
    ----------
    note_events : list[Note]
        Note onsets in the order they occur in the merged track. At most one
        note is kept per tick; later notes at the same tick are skipped.
    ticks_per_beat : int
        The number of ticks per beat in the MIDI file.
    beat_label : str
        The last time signature seen, as "numerator/denominator" ("4/4" if none).
    bpm : int
        The last tempo seen in beats per minute, truncated to int (120 if none).
    copyright : str | None
        The copyright text of the MIDI file, if available.
    """

    note_events: list[Note]
    ticks_per_beat: int
    beat_label: str
    bpm: int
    copyright: str | None

    @classmethod
    def parse(
        cls,
        midi: MidiFile,
        note_range: Collection[int],
        octave_shift: float = 0,
        channels: Collection[int] = (),
        all_channels: bool = False,
    ) -> "Melody":
        """
        Parse a MIDI file and extract melody information.

        Args:
            midi (MidiFile): The MIDI file to parse.
            note_range (Collection[int]): MIDI note numbers (after shifting) to keep.
            octave_shift (float, optional): Octaves to shift every note by; converted
                to semitones with ``int(octave_shift * 12)``. Defaults to 0.
            channels (Collection[int], optional): MIDI channels to include.
                Defaults to no channels (only useful with ``all_channels``).
            all_channels (bool, optional): Include notes from every channel,
                ignoring ``channels``. Defaults to False.

        Returns:
            Melody: An object containing the parsed melody information.
        """
        _logger = logger.getChild(cls.__name__)
        notes: dict[int, int] = {}
        current_time = 0
        copyright_text = None
        bpm = 120
        beat_label = "4/4"
        for msg in midi.merged_track:
            # Message times are deltas; accumulate them into an absolute tick.
            current_time += int(msg.time)
            if msg.type == "time_signature":
                beat_label = cls._process_time_signature(msg, _logger)
            elif msg.type == "set_tempo":
                bpm = cls._process_tempo(msg, _logger)
            elif msg.type == "copyright":
                copyright_text = cls._process_copyright(msg, _logger)
            elif msg.type == "note_on":
                cls._process_note_on(
                    msg,
                    notes,
                    current_time,
                    note_range,
                    octave_shift,
                    channels,
                    all_channels,
                    _logger,
                )
            else:
                _logger.info(f"Skipping non-note message: {msg}")
        return cls(
            note_events=[Note(tick, note) for tick, note in notes.items()],
            ticks_per_beat=midi.ticks_per_beat,
            beat_label=beat_label,
            bpm=bpm,
            copyright=copyright_text,
        )

    @staticmethod
    def _process_time_signature(msg, logger) -> str:
        """Return the time signature message as a "numerator/denominator" label."""
        logger.info(f"Time signature: {msg}")
        return f"{msg.numerator}/{msg.denominator}"

    @staticmethod
    def _process_tempo(msg, logger) -> int:
        """Return the tempo message converted to whole beats per minute."""
        logger.info(f"Tempo: {msg}")
        return int(tempo2bpm(msg.tempo))

    @staticmethod
    def _process_copyright(msg, logger) -> str:
        """Return the text of a copyright meta message."""
        logger.info(f"Copyright: {msg}")
        return str(msg.text)

    @staticmethod
    def _process_note_on(
        msg,
        notes: dict[int, int],
        current_time: int,
        note_range: Collection[int],
        octave_shift: float,
        channels: Collection[int],
        all_channels: bool,
        logger,
    ) -> None:
        """
        Record a note-on event into ``notes`` (tick -> note), in place.

        Events with velocity 0 and events on unselected channels are ignored.
        Notes outside ``note_range`` after shifting, and a second note at an
        already-used tick, are skipped with a warning.
        """
        # A note_on with velocity 0 is conventionally a note-off.
        if msg.velocity > 0 and (all_channels or msg.channel in channels):
            shifted_note = msg.note + int(octave_shift * 12)
            if shifted_note not in note_range:
                logger.warning(f"Skipping out-of-range note: {msg}")
            elif current_time in notes:
                logger.warning(f"Skipping overlapping note: {msg}")
            else:
                notes[current_time] = shifted_note
@dataclass
class MelodyImageMetaData:
    """
    Geometry of a punch card image: canvas size, margins, row and beat pitch.

    Attributes:
        size (tuple[int, int]): The size of the image (width, height).
        margin (tuple[float, float, float]): The margins (top, left, bottom).
            Top/bottom are in note rows (multiples of ``note_size``); left is
            in beats (multiples of ``beat_width``).
        beat_width (float): The width of a beat in pixels (already scaled).
        num_ticks (int): ``num_beats * ticks_per_beat`` (0 for blank templates).
        num_beats (int): The number of beats laid out on the image.
        num_notes (int): The number of note rows.
        note_size (float): The height of one note row in pixels.
        tick_scale (float): Pixels per MIDI tick (0 for blank templates).
        note_range (Sequence[int]): The range of notes shown on the image.

    Properties:
        margin_top (float): The top margin.
        margin_left (float): The left margin.
        margin_bottom (float): The bottom margin.
        image_width (int): The width of the image.
        image_height (int): The height of the image.
        margin_bottom_pixels (int): The bottom margin in pixels.
        margin_top_pixels (int): The top margin in pixels.
        margin_left_pixels (int): The left margin in pixels.
        note_size_pixels (int): The size of a note in pixels (truncated).
    """

    size: tuple[int, int] = (0, 0)
    margin: tuple[float, float, float] = (0, 0, 0)
    beat_width: float = 0
    num_ticks: int = 0
    num_beats: int = 0
    num_notes: int = 0
    note_size: float = 0
    tick_scale: float = 0
    note_range: Sequence[int] = ()

    @property
    def margin_top(self) -> float:
        return self.margin[0]

    @property
    def margin_left(self) -> float:
        return self.margin[1]

    @property
    def margin_bottom(self) -> float:
        return self.margin[2]

    @property
    def image_width(self) -> int:
        return self.size[0]

    @property
    def image_height(self) -> int:
        return self.size[1]

    @property
    def margin_bottom_pixels(self) -> int:
        return int(self.margin_bottom * self.note_size)

    @property
    def margin_top_pixels(self) -> int:
        return int(self.margin_top * self.note_size)

    @property
    def margin_left_pixels(self) -> int:
        return int(self.margin_left * self.beat_width)

    @property
    def note_size_pixels(self) -> int:
        return int(self.note_size)

    @staticmethod
    def _calculate_meta_data(
        note_range: Sequence[int],
        size: Collection[int],
        margin: Collection[float],
        beat_width: float,
        ticks_per_beat: int = 0,
        num_ticks: int = 0,
        bpm: int = 120,
    ) -> "MelodyImageMetaData":
        """
        Compute image geometry for a note range, canvas size and margins.

        With ``ticks_per_beat > 0`` (MIDI mode) the image is sized to hold
        ``num_ticks`` of music, capped at ``size[0]``. Otherwise (template
        mode) as many beats as fit into ``size[0]`` are laid out.

        Args:
            note_range: Notes shown as rows.
            size: (maximum image width, image height) in pixels.
            margin: (top, left, bottom) margins; top/bottom in rows, left in beats.
            beat_width: Beat width in pixels for a 256 px tall image.
            ticks_per_beat: MIDI resolution; 0 selects template mode.
            num_ticks: Tick of the last note (MIDI mode only).
            bpm: Tempo used to scale the beat width (MIDI mode only).

        Returns:
            MelodyImageMetaData: The computed geometry.
        """
        [max_image_width, image_height] = size
        [margin_top, margin_left, margin_bottom] = margin
        num_notes = len(note_range)
        # Rows share the image height with the top/bottom margins (row units).
        note_size = image_height / (num_notes + margin_top + margin_bottom)
        if ticks_per_beat > 0:
            # Scale with image height, and inversely with tempo so that a
            # fixed duration always spans the same width.
            beat_width = beat_width * (image_height / 256) * (120 / bpm)
            # Round the length UP to whole 4-beat bars (integer ceiling
            # division). Rounding to the nearest bar could leave the last
            # notes past the right edge, where they were dropped.
            ticks_per_bar = ticks_per_beat * 4
            num_bars = (int(num_ticks) + ticks_per_bar - 1) // ticks_per_bar
            num_beats = num_bars * 4
            tick_scale = beat_width / ticks_per_beat
        else:
            beat_width = beat_width * image_height / 256
            # NOTE(review): this subtracts margin_left * note_size, while
            # margin_left_pixels uses margin_left * beat_width; kept as-is so
            # template output width does not change — confirm intent.
            num_beats = max(
                0, int((max_image_width - (margin_left * note_size)) / beat_width)
            )
            tick_scale = 0
        image_width = min(
            max_image_width,
            int(num_beats * beat_width + int(margin_left * beat_width) * 2),
        )
        logger.info(f"image_size: {image_width}x{image_height}")
        return MelodyImageMetaData(
            size=(image_width, image_height),
            margin=(margin_top, margin_left, margin_bottom),
            beat_width=beat_width,
            num_ticks=num_beats * ticks_per_beat,
            num_beats=num_beats,
            num_notes=num_notes,
            note_size=note_size,
            tick_scale=tick_scale,
            note_range=note_range,
        )

    @staticmethod
    def from_melody(
        melody: Melody,
        note_range: Sequence[int],
        size: Collection[int],
        margin: Collection[float],
        beat_width: float,
    ) -> "MelodyImageMetaData":
        """
        Create a MelodyImageMetaData sized to fit a parsed melody.

        Args:
            melody (Melody): The melody whose last note, tempo and resolution
                determine the layout. A melody with no notes yields a layout of
                zero beats instead of raising.
            note_range (Sequence[int]): The notes shown as rows.
            size (Collection[int]): (maximum image width, image height).
            margin (Collection[float]): (top, left, bottom) margins.
            beat_width (float): Beat width in pixels for a 256 px tall image.

        Returns:
            MelodyImageMetaData: The metadata for the melody image.
        """
        last_tick = max((tick for tick, _ in melody.note_events), default=0)
        return MelodyImageMetaData._calculate_meta_data(
            note_range=note_range,
            size=size,
            margin=margin,
            beat_width=beat_width,
            ticks_per_beat=melody.ticks_per_beat,
            num_ticks=int(last_tick),
            bpm=melody.bpm,
        )

    @staticmethod
    def from_size(
        note_range: Sequence[int],
        size: Collection[int],
        margin: Collection[float],
        beat_width: float,
    ) -> "MelodyImageMetaData":
        """
        Create a MelodyImageMetaData for a blank template of a given size.

        Args:
            note_range (Sequence[int]): The notes shown as rows.
            size (Collection[int]): (image width, image height) in pixels.
            margin (Collection[float]): (top, left, bottom) margins.
            beat_width (float): Beat width in pixels for a 256 px tall image.

        Returns:
            MelodyImageMetaData: The metadata for the blank template.
        """
        return MelodyImageMetaData._calculate_meta_data(
            note_range=note_range,
            size=size,
            margin=margin,
            beat_width=beat_width,
        )
class RollRenderer:
    """
    Render a piano-roll punch card grid, notes and title onto a Pillow image.

    Constructing an instance immediately draws the horizontal row lines, the
    vertical beat/bar lines and the small placeholder ticks. Call ``notes``
    and ``title`` afterwards to add content.

    Attributes:
        background_color (tuple): RGBA color for the background.
        line_color (tuple): RGBA color for the lines.
        bar_line_color (tuple): RGBA color for the bar lines.
        text_color (tuple): RGBA color for the text.
        note_color (tuple): RGBA color for the notes.
    """

    background_color = DEFAULT_BACKGROUND_COLOR
    line_color = DEFAULT_LINE_COLOR
    bar_line_color = DEFAULT_BAR_LINE_COLOR
    text_color = DEFAULT_TEXT_COLOR
    note_color = DEFAULT_NOTE_COLOR

    def __init__(self, meta: MelodyImageMetaData, image: Image.Image) -> None:
        """
        Initialize the renderer and draw the empty grid onto ``image``.

        Args:
            meta (MelodyImageMetaData): Geometry of the image.
            image (Image.Image): The image to draw on (modified in place).
        """
        self.logger = logger.getChild(__class__.__name__)
        self.meta = meta
        self.draw = ImageDraw.Draw(image)
        self._hlines()
        self._vlines()
        self._note_placeholders()

    def _vlines(self) -> None:
        """Draw one vertical line per beat; every 4th beat is a thicker bar line."""
        top = self.meta.margin_top_pixels
        bottom = self.meta.image_height - self.meta.margin_bottom_pixels
        for beat in range(self.meta.num_beats + 2):
            x = int(beat * self.meta.beat_width) + self.meta.margin_left_pixels
            is_bar_line = beat % 4 == 0
            self.draw.line(
                [(x, top), (x, bottom)],
                fill=self.bar_line_color if is_bar_line else self.line_color,
                width=2 if is_bar_line else 1,
            )

    def _hlines(self) -> None:
        """Draw the num_notes + 1 horizontal lines bounding the note rows."""
        for i in range(self.meta.num_notes + 1):
            y = int(i * self.meta.note_size) + self.meta.margin_top_pixels
            self.draw.line(
                [
                    (self.meta.margin_left_pixels, y),
                    (self.meta.image_width - self.meta.margin_left_pixels, y),
                ],
                fill=self.line_color,
                width=1,
            )

    def _note_placeholders(
        self,
    ) -> None:
        """Draw short 3 px ticks on every row line at each half-beat position."""
        for i in range(self.meta.num_notes + 1):
            for j in range(self.meta.num_beats):
                x = int((j + 0.5) * self.meta.beat_width) + self.meta.margin_left_pixels
                y = int(i * self.meta.note_size) + self.meta.margin_top_pixels - 1
                self.draw.line(((x, y), (x, y + 2)), fill=self.line_color, width=1)

    def notes(
        self,
        note_events: Iterable[Note],
    ) -> None:
        """
        Draw each note event as a filled circle on the image.

        Args:
            note_events (Iterable[Note]): (tick, note) pairs to draw.

        Notes:
            - x is the tick scaled to pixels, offset by the left margin and by
              the circle radius.
            - y is the centre of the note's row; the highest note in the range
              is the top row.
            - Notes whose x falls at or beyond the image width, and notes not
              in ``meta.note_range``, are skipped with a warning.
        """
        # Reverse so the highest note maps to row 0 (the top of the image).
        note_index_map = {
            note: idx for idx, note in enumerate(self.meta.note_range[::-1])
        }
        radius = int(self.meta.note_size_pixels / 2)
        for tick, note in note_events:
            x = int(
                tick * self.meta.tick_scale
                + self.meta.margin_left_pixels
                + self.meta.note_size_pixels / 2
            )
            if x >= self.meta.image_width:
                self.logger.warning(f"Skipping note outside image: {note}")
                continue
            row = note_index_map.get(note)
            if row is None:
                self.logger.warning(f"Skipping note outside note range: {note}")
                continue
            # Use the same float row pitch as _hlines; the truncated integer
            # pitch drifted by up to a pixel per row, moving low notes into
            # the wrong rows.
            y = int((row + 0.5) * self.meta.note_size + self.meta.margin_top_pixels)
            self.draw.circle((x, y), radius, fill=self.note_color)

    def title(
        self,
        text_list: Iterable[str],
    ) -> None:
        """
        Draw a single-line title just above the grid's top-left corner.

        Args:
            text_list (Iterable[str]): Strings joined with spaces; newlines
                inside them are replaced with spaces.
        """
        self.draw.text(
            (self.meta.margin_left_pixels, self.meta.margin_top_pixels - 1),
            " ".join(text_list).replace("\n", " "),
            anchor="lb",
            fill=self.text_color,
        )
def blank_punch_card(
    note_range: Sequence[int],
    margin_top: float,
    margin_bottom: float,
    margin_left: float,
    image_height: int,
    image_width: int,
    beat_width: float,
) -> Image.Image:
    """
    Draw an empty punch card grid (no notes, no title).

    Args:
        note_range (Sequence[int]): Notes shown as rows on the card.
        margin_top (float): Top margin, in note rows.
        margin_bottom (float): Bottom margin, in note rows.
        margin_left (float): Left margin, in beats.
        image_height (int): Image height in pixels.
        image_width (int): Requested image width in pixels.
        beat_width (float): Beat width in pixels for a 256 px tall image.

    Returns:
        Image.Image: A new RGBA image containing only the grid.
    """
    template_meta = MelodyImageMetaData.from_size(
        note_range=note_range,
        size=(image_width, image_height),
        margin=(margin_top, margin_left, margin_bottom),
        beat_width=beat_width,
    )
    canvas = Image.new("RGBA", template_meta.size, RollRenderer.background_color)
    # Constructing the renderer draws the grid onto the canvas; the renderer
    # object itself is not needed afterwards.
    RollRenderer(template_meta, canvas)
    return canvas
def midi_to_punch_card(
    midi: MidiFile,
    note_range: Sequence[int],
    octave_shift: float,
    margin_top: float,
    margin_bottom: float,
    margin_left: float,
    max_image_width: int,
    image_height: int,
    beat_width: float,
    channels: list[int],
    all_channels: bool,
) -> Image.Image:
    """
    Render a MIDI file as a punch card image.

    The work runs in three stages: parse the melody, compute the image
    geometry, then draw the grid, notes and title. If a stage raises, a
    stage-specific error is logged and the exception is re-raised unchanged.

    Args:
        midi (MidiFile): The MIDI file to render.
        note_range (Sequence[int]): MIDI notes shown as rows.
        octave_shift (float): Octaves to shift every note by.
        margin_top (float): Top margin, in note rows.
        margin_bottom (float): Bottom margin, in note rows.
        margin_left (float): Left margin, in beats.
        max_image_width (int): Maximum width of the image in pixels.
        image_height (int): Image height in pixels.
        beat_width (float): Beat width in pixels for a 256 px tall image at 120 BPM.
        channels (list[int]): MIDI channels to include.
        all_channels (bool): Include every channel, ignoring ``channels``.

    Returns:
        Image.Image: The rendered RGBA punch card.

    Raises:
        Exception: Whatever the failing stage raised.
    """
    # Error logged if the current stage raises; updated as each stage begins.
    failure_message = "Failed to parse melody"
    try:
        melody = Melody.parse(
            midi,
            note_range,
            octave_shift=octave_shift,
            channels=channels,
            all_channels=all_channels,
        )

        failure_message = "Failed to calculate image metadata"
        meta = MelodyImageMetaData.from_melody(
            melody=melody,
            note_range=note_range,
            size=(max_image_width, image_height),
            margin=(margin_top, margin_left, margin_bottom),
            beat_width=beat_width,
        )

        failure_message = "Failed to draw image"
        canvas = Image.new("RGBA", meta.size, RollRenderer.background_color)
        renderer = RollRenderer(meta, canvas)
        renderer.notes(melody.note_events)
        title_parts = [
            Path(midi.filename).name if midi.filename else "untitled",
            melody.beat_label,
            f"BPM {melody.bpm}",
            melody.copyright or "anonymous",
        ]
        renderer.title(title_parts)
    except Exception:
        logger.error(failure_message)
        raise
    return canvas
class PunchRollArgumentParser(ArgumentParser):
    """
    Command-line parser for converting MIDI files to piano roll punch card images.

    Exactly one of the positional MIDI ``input`` or ``--template`` must be
    given. Use the -h or --help flag for more information on the available
    arguments.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            description="Convert MIDI file to piano roll punch card image.",
            **kwargs,
        )

        common_group = self.add_argument_group("Common")
        common_group.add_argument(
            "-O",
            "--output",
            default=None,
            help="Path of the output PNG image.",
        )
        common_group.add_argument(
            "-t",
            "--margin_top",
            type=float,
            default=DEFAULT_MARGIN_TOP,
            help="Top margin of the image. (Default: %(default)s)",
        )
        common_group.add_argument(
            "-b",
            "--margin_bottom",
            type=float,
            default=DEFAULT_MARGIN_BOTTOM,
            help="Bottom margin of the image. (Default: %(default)s)",
        )
        common_group.add_argument(
            "-l",
            "--margin_left",
            type=float,
            default=DEFAULT_MARGIN_LEFT,
            help="Left margin of the image. (Default: %(default)s)",
        )
        common_group.add_argument(
            "-s",
            "--image_height",
            type=int,
            default=DEFAULT_IMAGE_HEIGHT,
            help="Height of the image. (Default: %(default)s)",
        )
        # Used for both MIDI rendering and blank templates, so it is common.
        common_group.add_argument(
            "-B",
            "--beat_width",
            type=float,
            default=DEFAULT_BEAT_WIDTH,
            help="Width of each beat in pixels for a 256 px tall image; scaled with the image height (and with tempo for MIDI input). (Default: %(default)s)",  # noqa: E501
        )

        exclusive_group = self.add_mutually_exclusive_group(required=True)
        exclusive_group.add_argument(
            "input",
            help="MIDI file to draw. Required if --template is not specified.",
            nargs="?",
        )
        exclusive_group.add_argument(
            "-T", "--template", action="store_true", help="Draw blank template."
        )

        midi_group = self.add_argument_group("Drawing MIDI File")
        midi_group.add_argument(
            "-o",
            "--octave_shift",
            type=float,
            default=0,
            help="Number of octaves to shift the notes. (Default: %(default)s)",
        )
        midi_group.add_argument(
            "-C",
            "--all_channels",
            action="store_true",
            default=False,
            help="Include all MIDI channels regardless of the --channel option.",
        )
        # default stays None so the caller can tell "not given" apart from an
        # explicit list; the script falls back to channel 0 in that case.
        midi_group.add_argument(
            "-c",
            "--channel",
            type=int,
            default=None,
            action="append",
            help="MIDI channel to include; repeat to include several. (Default: 0)",
        )
        midi_group.add_argument(
            "--max_image_width",
            type=int,
            default=DEFAULT_MAX_IMAGE_WIDTH,
            help="Maximum width of the image. Notes will be truncated if the image exceeds this width. (Default: %(default)s)",  # noqa: E501
        )

        template_group = self.add_argument_group("Drawing Blank Template")
        template_group.add_argument(
            "-w",
            "--image_width",
            type=int,
            default=DEFAULT_IMAGE_WIDTH,
            help="Width of the image for the blank template. (Default: %(default)s)",
        )
if __name__ == "__main__":
    from logging import BASIC_FORMAT, DEBUG, INFO, WARNING, Formatter, StreamHandler
    from sys import stderr, stdout

    logger.setLevel(INFO)
    # BASIC_FORMAT is the public name for the default "%" style format string.
    formatter = Formatter(BASIC_FORMAT)

    # WARNING and above go to stderr ...
    stderr_handler = StreamHandler(stderr)
    stderr_handler.setLevel(WARNING)
    stderr_handler.setFormatter(formatter)
    logger.addHandler(stderr_handler)

    # ... and everything below WARNING goes to stdout. Without the filter,
    # every warning/error would be printed twice (once by each handler).
    stdout_handler = StreamHandler(stdout)
    stdout_handler.setLevel(DEBUG)
    stdout_handler.addFilter(lambda record: record.levelno < WARNING)
    stdout_handler.setFormatter(formatter)
    logger.addHandler(stdout_handler)

    parser = PunchRollArgumentParser()
    try:
        args = parser.parse_args()
        if args.template:
            logger.info("Drawing blank template")
            image = blank_punch_card(
                note_range=DEFAULT_NOTE_RANGE,
                margin_top=args.margin_top,
                margin_bottom=args.margin_bottom,
                margin_left=args.margin_left,
                image_height=args.image_height,
                image_width=args.image_width,
                beat_width=args.beat_width,
            )
        else:
            logger.info(f"Reading {args.input}")
            midi = MidiFile(filename=args.input)
            logger.info("Drawing midi to punch card")
            image = midi_to_punch_card(
                midi,
                note_range=DEFAULT_NOTE_RANGE,
                octave_shift=args.octave_shift,
                margin_top=args.margin_top,
                margin_bottom=args.margin_bottom,
                margin_left=args.margin_left,
                beat_width=args.beat_width,
                image_height=args.image_height,
                max_image_width=args.max_image_width,
                all_channels=args.all_channels,
                channels=args.channel or [0],
            )
        # --output takes precedence in both modes. Parenthesized because `or`
        # binds tighter than the conditional expression; unparenthesized,
        # --output was silently ignored for MIDI input.
        output = args.output or (
            "template.png" if args.template else f"{Path(args.input).stem}.png"
        )
        image.save(output)
        logger.info(f"Saved {output}")
    except Exception:
        # logger.exception records the message and the traceback in one entry.
        logger.exception("Unexpected error")
        # SystemExit does not depend on the site-provided `exit` builtin.
        raise SystemExit(1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment