Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix empty track issue #25

Merged
merged 3 commits into from
Dec 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions midigen/midigen.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ def __init__(
:param key_signature: The key signature as a string, e.g., 'C' for C major.
"""
self._midi_file = MidiFile()
self._track = MidiTrack()
self._midi_file.tracks.append(self._track)
self._midi_file.add_track()
self.set_tempo(tempo)
self.set_time_signature(*time_signature)
if key_signature is None:
Expand All @@ -40,7 +39,7 @@ def __str__(self):
:return: A string with the track, tempo, time signature, and key signature of the MidiGen object.
"""
return (
f"Track: {self._track}\nTempo: {self.tempo}\n \
f"Track: {self.track}\nTempo: {self.tempo}\n \
Time Signature: {self.time_signature}\nKey Signature: {self.key_signature}"
)

Expand Down Expand Up @@ -74,9 +73,9 @@ def set_tempo(self, tempo: int) -> None:
raise ValueError("Invalid tempo value: tempo must be a positive integer")

# Remove existing 'set_tempo' messages
self._track = [msg for msg in self._track if msg.type != "set_tempo"]
self.midi_file.tracks[0] = [msg for msg in self.track if msg.type != "set_tempo"]
self.tempo = bpm2tempo(tempo)
self._track.append(MetaMessage("set_tempo", tempo=self.tempo))
self.track.append(MetaMessage("set_tempo", tempo=self.tempo))

def set_time_signature(self, numerator: int, denominator: int) -> None:
"""
Expand All @@ -100,12 +99,12 @@ def set_time_signature(self, numerator: int, denominator: int) -> None:
)

# Remove existing 'time_signature' messages
self._track = [msg for msg in self._track if msg.type != "time_signature"]
self.midi_file.tracks[0] = [msg for msg in self.track if msg.type != "time_signature"]

self.time_signature = MetaMessage(
"time_signature", numerator=numerator, denominator=denominator
)
self._track.append(self.time_signature)
self.track.append(self.time_signature)

def set_key_signature(self, key: Key) -> None:
"""
Expand All @@ -118,7 +117,7 @@ def set_key_signature(self, key: Key) -> None:
ValueError: If key is not a valid key signature string.
"""
self.key_signature = key
self._track.append(MetaMessage("key_signature", key=str(key)))
self.track.append(MetaMessage("key_signature", key=str(key)))

def add_program_change(self, channel: int, program: int) -> None:
"""
Expand All @@ -137,7 +136,7 @@ def add_program_change(self, channel: int, program: int) -> None:
if not isinstance(program, int) or program < 0 or program > 127:
raise ValueError(f"Invalid program value: {program}. Program must be an integer between 0 and 127")

self._track.append(Message("program_change", channel=channel, program=program))
self.track.append(Message("program_change", channel=channel, program=program))

def add_control_change(self, channel: int, control: int, value: int, time: int = 0) -> None:
"""
Expand All @@ -159,7 +158,7 @@ def add_control_change(self, channel: int, control: int, value: int, time: int =
if not isinstance(value, int) or not 0 <= value <= 127:
raise ValueError(f"Invalid value: {value}. Value must be between 0 and 127.")

self._track.append(
self.track.append(
Message(
"control_change",
channel=channel,
Expand All @@ -185,7 +184,7 @@ def add_pitch_bend(self, channel: int, value: int, time: int = 0) -> None:
if not isinstance(time, int) or time < 0:
raise ValueError("Invalid time value: time must be a non-negative integer")

self._track.append(
self.track.append(
Message("pitchwheel", channel=channel, pitch=value, time=time)
)

Expand All @@ -198,8 +197,8 @@ def add_note(self, note: Note) -> None:
Raises:
ValueError: If note, velocity, duration or time is not an integer or outside valid range.
"""
self._track.append(Message("note_on", note=note.pitch, velocity=note.velocity, time=note.time))
self._track.append(
self.track.append(Message("note_on", note=note.pitch, velocity=note.velocity, time=note.time))
self.track.append(
Message("note_off", note=note.pitch, velocity=note.velocity, time=(note.time+note.duration))
)

Expand Down Expand Up @@ -281,7 +280,6 @@ def load_midi_file(self, filename: str) -> MidiFile:
raise FileNotFoundError(f"No such file or directory: '{filename}'")

self._midi_file = MidiFile(filename)
self._track = self._midi_file.tracks[0]
return self._midi_file

@property
Expand All @@ -291,7 +289,7 @@ def track(self) -> MidiTrack:
:return: The track of the MIDI file.
"""
return self._track
return self.midi_file.tracks[0]

@property
def midi_file(self) -> MidiFile:
Expand Down
11 changes: 9 additions & 2 deletions tests/test_midigen.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,16 @@ def create_note_off_message(self, note, time):
def test_midi_gen_creation(self):
self.assertIsNotNone(self.midi_gen)

def test_tracks(self):
self.assertEqual(len(self.midi_gen.midi_file.tracks), 1)
midigen_track = self.midi_gen.track
midi_file_track = self.midi_gen.midi_file.tracks[0]
self.assertEqual(len(midigen_track), len(midi_file_track))
self.assertEqual(str(midigen_track), str(midi_file_track))

def test_set_tempo(self):
self.midi_gen.set_tempo(90)
tempo_msgs = [msg for msg in self.midi_gen._track if msg.type == "set_tempo"]
tempo_msgs = [msg for msg in self.midi_gen.track if msg.type == "set_tempo"]
self.assertEqual(len(tempo_msgs), 1)
tempo_msg = tempo_msgs[0]
self.assertEqual(tempo_msg.type, "set_tempo")
Expand All @@ -33,7 +40,7 @@ def test_set_tempo(self):
def test_set_time_signature(self):
self.midi_gen.set_time_signature(3, 4)
time_sig_msgs = [
msg for msg in self.midi_gen._track if msg.type == "time_signature"
msg for msg in self.midi_gen.track if msg.type == "time_signature"
]
self.assertEqual(len(time_sig_msgs), 1)
time_sig_msg = time_sig_msgs[0]
Expand Down