Skip to content

Commit

Permalink
new pad argument, sos_eos_tokens arg renamed sos_eos, tokenizers vocab creation fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
Natooz committed Nov 3, 2022
1 parent b9218bf commit f9cb109
Show file tree
Hide file tree
Showing 13 changed files with 146 additions and 120 deletions.
1 change: 1 addition & 0 deletions miditok/bpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def save_params(self, out_dir: Union[str, Path, PurePath]):
'beat_res': {f'{k1}_{k2}': v for (k1, k2), v in self.beat_res.items()},
'nb_velocities': len(self.velocities),
'additional_tokens': self.additional_tokens,
'_pad': self._pad,
'_sos_eos': self._sos_eos,
'_mask': self._mask,
'encoding': f'{getmro(self.__class__)[1].__name__}_bpe',
Expand Down
1 change: 0 additions & 1 deletion miditok/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
'tempo_range': (40, 250), # (min_tempo, max_tempo)
# time signature params
'time_signature_range': (8, 2)} # (max_beat_res, max_bar_length_in_NOTE)
SPECIAL_TOKENS = {'pad': True, 'mask': False, 'sos_eos': False}

# Defaults values when writing new MIDI files
TIME_DIVISION = 384 # 384 and 480 are convenient as divisible by 4, 8, 12, 16, 24, 32
Expand Down
23 changes: 13 additions & 10 deletions miditok/cp_word.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,17 @@ class CPWord(MIDITokenizer):
The values are the resolution, in samples per beat, of the given range, ex 8
:param nb_velocities: number of velocity bins
:param additional_tokens: specifies additional tokens (chords, time signature, rests, tempo...)
:param sos_eos_tokens: adds Start Of Sequence (SOS) and End Of Sequence (EOS) tokens to the vocabulary
:param pad: will include a PAD token, used when training a model with batch of sequences of
unequal lengths, and usually at index 0 of the vocabulary. (default: True)
:param sos_eos: adds Start Of Sequence (SOS) and End Of Sequence (EOS) tokens to the vocabulary.
(default: False)
:param mask: will add a MASK token to the vocabulary (default: False)
:param params: can be a path to the parameter (json encoded) file or a dictionary
:param params: can be a path to the parameter (json encoded) file or a dictionary. (default: None)
"""

def __init__(self, pitch_range: range = PITCH_RANGE, beat_res: Dict[Tuple[int, int], int] = BEAT_RES,
nb_velocities: int = NB_VELOCITIES, additional_tokens: Dict[str, bool] = ADDITIONAL_TOKENS,
sos_eos_tokens: bool = False, mask: bool = False, params=None):
pad: bool = True, sos_eos: bool = False, mask: bool = False, params=None):
# Indexes of additional token types within a compound token
add_idx = 5
additional_tokens['TimeSignature'] = False # not compatible
Expand All @@ -65,7 +68,7 @@ def __init__(self, pitch_range: range = PITCH_RANGE, beat_res: Dict[Tuple[int, i
if additional_tokens['Tempo']:
self.tempo_idx = add_idx

super().__init__(pitch_range, beat_res, nb_velocities, additional_tokens, sos_eos_tokens, mask, params)
super().__init__(pitch_range, beat_res, nb_velocities, additional_tokens, pad, sos_eos, mask, params)

def track_to_tokens(self, track: Instrument) -> List[List[int]]:
r"""Converts a track (miditoolkit.Instrument object) into a sequence of tokens
Expand Down Expand Up @@ -197,7 +200,7 @@ def create_cp_token(self, time: int, bar: bool = False, pos: int = None, pitch:
:param desc: an optional argument for debug and used to spot position tokens in track_to_tokens
:return: The compound token as a list of integers
"""
cp_token_template = [Event(type_='Family', time=time, value='Metric', desc=desc),
cp_token_template = [Event(type_='Family', value='Metric', time=time, desc=desc),
self.vocab[1].event_to_token['Position_Ignore'],
self.vocab[2].event_to_token['Pitch_Ignore'],
self.vocab[3].event_to_token['Velocity_Ignore'],
Expand Down Expand Up @@ -299,7 +302,7 @@ def _create_vocabulary(self, sos_eos_tokens: bool = None) -> List[Vocabulary]:
print('\033[93msos_eos_tokens argument is depreciated and will be removed in a future update, '
'_create_vocabulary now uses self._sos_eos attribute set a class init \033[0m')

vocab = [Vocabulary({'PAD_None': 0}, sos_eos=self._sos_eos, mask=self._mask) for _ in range(5)]
vocab = [Vocabulary(pad=self._pad, sos_eos=self._sos_eos, mask=self._mask) for _ in range(5)]

vocab[0].add_event('Family_Metric')
vocab[0].add_event('Family_Note')
Expand All @@ -324,26 +327,26 @@ def _create_vocabulary(self, sos_eos_tokens: bool = None) -> List[Vocabulary]:

# PROGRAM
if self.additional_tokens['Program']:
vocab.append(Vocabulary({'PAD_None': 0}, sos_eos=self._sos_eos, mask=self._mask))
vocab.append(Vocabulary(pad=self._pad, sos_eos=self._sos_eos, mask=self._mask))
vocab[-1].add_event('Program_Ignore')
vocab[-1].add_event(f'Program_{program}' for program in range(-1, 128))

# CHORD
if self.additional_tokens['Chord']:
vocab.append(Vocabulary({'PAD_None': 0}, sos_eos=self._sos_eos, mask=self._mask))
vocab.append(Vocabulary(pad=self._pad, sos_eos=self._sos_eos, mask=self._mask))
vocab[-1].add_event('Chord_Ignore')
vocab[-1].add_event(f'Chord_{i}' for i in range(3, 6)) # non recognized chords (between 3 and 5 notes)
vocab[-1].add_event(f'Chord_{chord_quality}' for chord_quality in CHORD_MAPS)

# REST
if self.additional_tokens['Rest']:
vocab.append(Vocabulary({'PAD_None': 0}, sos_eos=self._sos_eos, mask=self._mask))
vocab.append(Vocabulary(pad=self._pad, sos_eos=self._sos_eos, mask=self._mask))
vocab[-1].add_event('Rest_Ignore')
vocab[-1].add_event(f'Rest_{".".join(map(str, rest))}' for rest in self.rests)

# TEMPO
if self.additional_tokens['Tempo']:
vocab.append(Vocabulary({'PAD_None': 0}, sos_eos=self._sos_eos, mask=self._mask))
vocab.append(Vocabulary(pad=self._pad, sos_eos=self._sos_eos, mask=self._mask))
vocab[-1].add_event('Tempo_Ignore')
vocab[-1].add_event(f'Tempo_{i}' for i in self.tempos)

Expand Down
35 changes: 19 additions & 16 deletions miditok/midi_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,19 @@ class MIDILike(MIDITokenizer):
The values are the resolution, in samples per beat, of the given range, ex 8
:param nb_velocities: number of velocity bins
:param additional_tokens: specifies additional tokens (chords, time signature, rests, tempo...)
:param sos_eos_tokens: adds Start Of Sequence (SOS) and End Of Sequence (EOS) tokens to the vocabulary
:param pad: will include a PAD token, used when training a model with batch of sequences of
unequal lengths, and usually at index 0 of the vocabulary. (default: True)
:param sos_eos: adds Start Of Sequence (SOS) and End Of Sequence (EOS) tokens to the vocabulary.
(default: False)
:param mask: will add a MASK token to the vocabulary (default: False)
:param params: can be a path to the parameter (json encoded) file or a dictionary
"""

def __init__(self, pitch_range: range = PITCH_RANGE, beat_res: Dict[Tuple[int, int], int] = BEAT_RES,
nb_velocities: int = NB_VELOCITIES, additional_tokens: Dict[str, bool] = ADDITIONAL_TOKENS,
sos_eos_tokens: bool = False, mask: bool = False, params=None):
pad: bool = True, sos_eos: bool = False, mask: bool = False, params=None):
additional_tokens['TimeSignature'] = False # not compatible
super().__init__(pitch_range, beat_res, nb_velocities, additional_tokens, sos_eos_tokens, mask, params)
super().__init__(pitch_range, beat_res, nb_velocities, additional_tokens, pad, sos_eos, mask, params)

def track_to_tokens(self, track: Instrument) -> List[int]:
r"""Converts a track (miditoolkit.Instrument object) into a sequence of tokens
Expand All @@ -63,15 +66,15 @@ def track_to_tokens(self, track: Instrument) -> List[int]:
# Creates the Note On, Note Off and Velocity events
for n, note in enumerate(track.notes):
# Note On
events.append(Event(type_='Note-On', time=note.start, value=note.pitch, desc=note.end))
events.append(Event(type_='Note-On', value=note.pitch, time=note.start, desc=note.end))
# Velocity
events.append(Event(type_='Velocity', time=note.start, value=note.velocity, desc=f'{note.velocity}'))
events.append(Event(type_='Velocity', value=note.velocity, time=note.start, desc=f'{note.velocity}'))
# Note Off
events.append(Event(type_='Note-Off', time=note.end, value=note.pitch, desc=note.end))
events.append(Event(type_='Note-Off', value=note.pitch, time=note.end, desc=note.end))
# Adds tempo events if specified
if self.additional_tokens['Tempo']:
for tempo_change in self.current_midi_metadata['tempo_changes']:
events.append(Event(type_='Tempo', time=tempo_change.time, value=tempo_change.tempo,
events.append(Event(type_='Tempo', value=tempo_change.tempo, time=tempo_change.time,
desc=tempo_change.tempo))

# Sorts events
Expand All @@ -95,30 +98,30 @@ def track_to_tokens(self, track: Instrument) -> List[int]:
rest_tick = previous_tick # untouched tick value, so that the order is not messed up after sorting

if rest_beat > 0:
events.append(Event(type_='Rest', time=rest_tick, value=f'{rest_beat}.0',
events.append(Event(type_='Rest', value=f'{rest_beat}.0', time=rest_tick,
desc=f'{rest_beat}.0'))
previous_tick += rest_beat * self.current_midi_metadata['time_division']

while rest_pos >= self.rests[0][1]:
rest_pos_temp = min([r[1] for r in self.rests], key=lambda x: abs(x - rest_pos))
events.append(Event(type_='Rest', time=rest_tick, value=f'0.{rest_pos_temp}',
events.append(Event(type_='Rest', value=f'0.{rest_pos_temp}', time=rest_tick,
desc=f'0.{rest_pos_temp}'))
previous_tick += round(rest_pos_temp * ticks_per_sample)
rest_pos -= rest_pos_temp

# Adds an additional time shift if needed
# Adds a time-shift if needed
if rest_pos > 0:
time_shift = round(rest_pos * ticks_per_sample)
index = np.argmin(np.abs(dur_bins - time_shift))
events.append(Event(type_='Time-Shift', time=previous_tick,
value='.'.join(map(str, self.durations[index])), desc=f'{time_shift} ticks'))
events.append(Event(type_='Time-Shift', value='.'.join(map(str, self.durations[index])),
time=previous_tick, desc=f'{time_shift} ticks'))

# Time shift
# Time-shift
else:
time_shift = event.time - previous_tick
index = np.argmin(np.abs(dur_bins - time_shift))
events.append(Event(type_='Time-Shift', time=previous_tick,
value='.'.join(map(str, self.durations[index])), desc=f'{time_shift} ticks'))
events.append(Event(type_='Time-Shift', value='.'.join(map(str, self.durations[index])),
time=previous_tick, desc=f'{time_shift} ticks'))

if event.type == 'Note-On':
previous_note_end = max(previous_note_end, event.desc)
Expand Down Expand Up @@ -210,7 +213,7 @@ def _create_vocabulary(self, sos_eos_tokens: bool = None) -> Vocabulary:
if sos_eos_tokens is not None:
print('\033[93msos_eos_tokens argument is depreciated and will be removed in a future update, '
'_create_vocabulary now uses self._sos_eos attribute set a class init \033[0m')
vocab = Vocabulary({'PAD_None': 0}, sos_eos=self._sos_eos, mask=self._mask)
vocab = Vocabulary(pad=self._pad, sos_eos=self._sos_eos, mask=self._mask)

# NOTE ON
vocab.add_event(f'Note-On_{i}' for i in self.pitch_range)
Expand Down
47 changes: 23 additions & 24 deletions miditok/midi_tokenizer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,26 @@ class MIDITokenizer(ABC):
The values are the resolution, in samples per beat, of the given range, ex 8
:param nb_velocities: number of velocity bins
:param additional_tokens: specifies additional tokens (chords, rests, tempo, time signature...)
:param sos_eos_tokens: adds Start Of Sequence (SOS) and End Of Sequence (EOS) tokens to the vocabulary
:param pad: will include a PAD token, used when training a model with batch of sequences of
unequal lengths, and usually at index 0 of the vocabulary. (default: True)
:param sos_eos: adds Start Of Sequence (SOS) and End Of Sequence (EOS) tokens to the vocabulary.
(default: False)
:param mask: will add a MASK token to the vocabulary (default: False)
:param params: can be a path to the parameter (json encoded) file or a dictionary
"""

def __init__(self, pitch_range: range, beat_res: Dict[Tuple[int, int], int], nb_velocities: int,
additional_tokens: Dict[str, Union[bool, int, Tuple[int, int]]], sos_eos_tokens: bool = False,
mask: bool = False, params: Union[str, Path, PurePath, Dict[str, Any]] = None):
additional_tokens: Dict[str, Union[bool, int, Tuple[int, int]]], pad: bool = True,
sos_eos: bool = False, mask: bool = False, params: Union[str, Path, PurePath, Dict[str, Any]] = None):
# Initialize params
self.vocab = None
if params is None:
self.pitch_range = pitch_range
self.beat_res = beat_res
self.additional_tokens = additional_tokens
self.nb_velocities = nb_velocities
self._sos_eos = sos_eos_tokens
self._pad = pad
self._sos_eos = sos_eos
self._mask = mask
else:
self.load_params(params)
Expand Down Expand Up @@ -77,7 +81,7 @@ def __init__(self, pitch_range: range, beat_res: Dict[Tuple[int, int], int], nb_
self.time_signatures = self.__create_time_signatures()

# Vocabulary and token types graph
if self.vocab is None: # in case it was already loaded by an overriding load_params method, such as with BPE
if self.vocab is None: # in case it was already loaded by an overridden load_params method, such as with BPE
self.vocab = self._create_vocabulary()
self.tokens_types_graph = self._create_token_types_graph()

Expand Down Expand Up @@ -349,22 +353,15 @@ def _create_token_types_graph(self) -> Dict[str, List[str]]:
See other classes (REMI, MIDILike ...) for examples of how to implement it."""
raise NotImplementedError

def _add_pad_type_to_graph(self, dic: Dict[str, List[str]]):
r"""DEPRECATED: has been replaced by _add_special_tokens_to_types_graph.
This method will call _add_special_tokens_to_types_graph, see below.
:param dic: token types graph to add PAD type
"""
self._add_special_tokens_to_types_graph(dic)

def _add_special_tokens_to_types_graph(self, dic: Dict[str, List[str]]):
r"""Inplace adds special tokens (PAD, EOS, SOS, MASK) types to the token types graph dictionary.
:param dic: token types graph to add PAD type
"""
for value in dic.values():
value.append('PAD')
dic['PAD'] = ['PAD']
if self._pad:
for value in dic.values():
value.append('PAD')
dic['PAD'] = ['PAD']
if self._sos_eos:
dic['SOS'] = list(dic.keys())
dic['EOS'] = []
Expand Down Expand Up @@ -600,19 +597,19 @@ def save_params(self, out_dir: Union[str, Path, PurePath]):
'beat_res': {f'{k1}_{k2}': v for (k1, k2), v in self.beat_res.items()},
'nb_velocities': len(self.velocities),
'additional_tokens': self.additional_tokens,
'_pad': self._pad,
'_sos_eos': self._sos_eos,
'_mask': self._mask,
'encoding': self.__class__.__name__,
'miditok_version': CURRENT_PACKAGE_VERSION}, outfile, indent=4)

def load_params(self, params: Union[str, Path, PurePath, Dict[str, Any]]):
def load_params(self, params: Union[str, Path, PurePath]):
r"""Load parameters and set the encoder attributes
:param params: can be a path to the parameter (json encoded) file or a dictionary
:param params: can be a path to the parameter (json encoded) file
"""
if isinstance(params, (str, Path, PurePath)):
with open(params) as param_file:
params = json.load(param_file)
with open(params) as param_file:
params = json.load(param_file)

if not isinstance(params['pitch_range'], range):
params['pitch_range'] = range(*params['pitch_range'])
Expand All @@ -626,10 +623,12 @@ def load_params(self, params: Union[str, Path, PurePath, Dict[str, Any]]):
value['TimeSignature'] = value.get('TimeSignature', False)
setattr(self, key, value)

# when loading from params of miditok < v1.2.0
if '_sos_eos' not in params:
# when loading from params of miditok of previous versions
if '_pad' not in params: # miditok < v1.3.0
self._pad = False
if '_sos_eos' not in params: # miditok < v1.2.0
self._sos_eos = False
if '_mask' not in params:
if '_mask' not in params: # miditok < v1.2.0
self._mask = False

def __call__(self, midi: MidiFile, *args, **kwargs):
Expand Down
Loading

0 comments on commit f9cb109

Please sign in to comment.