"""train_single.py

Train a stacked-LSTM network on note sequences parsed from the MIDI files
found under ``midi/``.
"""
from music21 import *
import glob
import pickle
import numpy
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import LSTM
from keras.layers import Activation
from keras.utils import np_utils
from keras.callbacks import ModelCheckpoint
def train_network():
    """Run the whole pipeline: parse notes, build sequences, build and fit the model."""
    parsed_notes = get_notes()

    # The vocabulary size is the number of distinct note tokens.
    vocabulary_size = len(set(parsed_notes))

    inputs, targets = prepare_sequences(parsed_notes, vocabulary_size)
    network = create_network(inputs, vocabulary_size)
    train(network, inputs, targets)
def get_notes(include_rests=True):
    """Parse every ``midi/*.mid`` file into one flat list of note tokens.

    Token encoding:
      * ``note.Note``   -> ``str(element.pitch)``
      * ``chord.Chord`` -> the chord's ``normalOrder`` values joined by "."
      * ``note.Rest``   -> ``"rest"`` (only when ``include_rests`` is true)

    The resulting list is also pickled to ``data/notes``; the ``data``
    directory is created if it does not exist yet.

    Args:
        include_rests: When true (the default, matching the previous
            hard-coded behaviour) rests are emitted as ``"rest"`` tokens.

    Returns:
        list[str]: the tokens of all parsed files, in parse order.
    """
    # Local import so this block does not depend on a file-level change.
    import os

    notes = []
    for midi_path in glob.glob("midi/*.mid"):
        midi = converter.parse(midi_path)
        print("Parsing %s" % midi_path)

        try:
            # File has instrument parts: use the first part only.
            inst = instrument.partitionByInstrument(midi)
            print("Number of instrument parts: " + str(len(inst.parts)))
            notes_to_parse = inst.parts[0].recurse()
        except Exception:
            # Best-effort fallback for files whose notes are in a flat
            # structure (e.g. no partition available). `Exception` rather
            # than a bare `except` so Ctrl-C / SystemExit still propagate.
            # NOTE(review): music21's `.notes` presumably yields notes and
            # chords only, so this path emits no "rest" tokens - confirm.
            notes_to_parse = midi.flat.notes

        for element in notes_to_parse:
            if isinstance(element, note.Note):
                notes.append(str(element.pitch))
            elif isinstance(element, chord.Chord):
                notes.append('.'.join(str(n) for n in element.normalOrder))
            elif include_rests and isinstance(element, note.Rest):
                notes.append("rest")

    # Make sure the output directory exists so the parsing work is not lost
    # to a FileNotFoundError at the very end.
    os.makedirs("data", exist_ok=True)
    with open('data/notes', 'wb') as filepath:
        pickle.dump(notes, filepath)

    return notes
def prepare_sequences(notes, n_vocab):
    """Turn the note tokens into training arrays for the LSTM.

    Every window of 100 consecutive tokens becomes one input sequence and
    the token that follows the window becomes its target.

    Args:
        notes: sequence of note tokens (strings).
        n_vocab: number of distinct tokens. Used to scale the inputs and as
            the width of the one-hot targets.

    Returns:
        tuple ``(network_input, network_output)``:
          * ``network_input``: float array of shape ``(n_patterns, 100, 1)``
            holding ``token_index / n_vocab``.
          * ``network_output``: float32 one-hot array of shape
            ``(n_patterns, n_vocab)``.

    Raises:
        ValueError: if ``notes`` has 100 items or fewer, so no
            (sequence, target) pair can be formed.
    """
    sequence_length = 100

    # Map each distinct token to its index in sorted order.
    pitchnames = sorted(set(notes))
    note_to_int = {name: index for index, name in enumerate(pitchnames)}
    print("Dictionary size: %d" % len(note_to_int))

    # create input sequences and the corresponding outputs
    print("Create input sequences and the corresponding outputs")
    encoded = [note_to_int[token] for token in notes]
    n_patterns = len(encoded) - sequence_length
    if n_patterns <= 0:
        raise ValueError(
            "need more than %d notes to build training sequences, got %d"
            % (sequence_length, len(encoded))
        )
    network_input = [
        encoded[start:start + sequence_length] for start in range(n_patterns)
    ]
    # The target of window `start` is the token right after it.
    targets = encoded[sequence_length:]

    # reshape the input into a format compatible with LSTM layers
    print("Reshape the input into a format compatible with LSTM layers")
    network_input = numpy.reshape(network_input, (n_patterns, sequence_length, 1))

    # normalize input
    print("Normalize input")
    network_input = network_input / float(n_vocab)

    # One-hot encode with a fixed width of n_vocab. Inferring the width from
    # the largest target seen would yield too few columns whenever the
    # highest-index token never occurs as a target, which would not match
    # the model's Dense(n_vocab) output layer.
    network_output = numpy.zeros((n_patterns, n_vocab), dtype="float32")
    network_output[numpy.arange(n_patterns), targets] = 1.0

    return (network_input, network_output)
def create_network(network_input, n_vocab):
    """Build and compile the stacked-LSTM model.

    Args:
        network_input: array whose ``shape[1]`` and ``shape[2]`` give the
            input shape of the first LSTM layer.
        n_vocab: width of the final Dense (softmax) layer.

    Returns:
        The compiled ``Sequential`` model.
    """
    print("Creating model")
    model = Sequential()

    timesteps = network_input.shape[1]
    features = network_input.shape[2]

    # Layers are listed in the order they are stacked.
    layer_stack = [
        LSTM(512, input_shape=(timesteps, features), return_sequences=True),
        Dropout(0.3),
        LSTM(512, return_sequences=True),
        Dropout(0.3),
        LSTM(512),
        Dense(256),
        Dropout(0.3),
        Dense(n_vocab),
        Activation('softmax'),
    ]
    for layer in layer_stack:
        model.add(layer)

    model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
    return model
def train(model, network_input, network_output):
    """Fit ``model`` for 200 epochs (batch size 64) with loss-based checkpoints.

    A ``ModelCheckpoint`` callback monitoring ``loss`` (mode ``min``,
    ``save_best_only=True``) writes weight files named after the epoch and
    loss value.
    """
    print("Training model")

    best_loss_checkpoint = ModelCheckpoint(
        "weights-improvement-{epoch:02d}-{loss:.4f}-bigger.hdf5",
        monitor='loss',
        verbose=0,
        save_best_only=True,
        mode='min',
    )

    model.fit(
        network_input,
        network_output,
        epochs=200,
        batch_size=64,
        callbacks=[best_loss_checkpoint],
    )
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    train_network()