-
Notifications
You must be signed in to change notification settings - Fork 3
/
polish_nets.py
233 lines (193 loc) · 10.3 KB
/
polish_nets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
# ------------------------------------------------
# TensorFlow import
# ------------------------------------------------
# TensorFlow setting: Which GPU to use and not to consume the whole GPU:
import os
#os.environ['CUDA_VISIBLE_DEVICES'] = '0' # Which GPU to use.
#os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Filters TensorFlow warnings.
#os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' # Prevents TensorFlow from consuming the whole GPU.
# Import TensorFlow:
import tensorflow as tf
#physical_devices = tf.config.list_physical_devices('GPU')
#tf.config.experimental.set_memory_growth(physical_devices[0], True)
# ------------------------------------------------
# Other imports
# ------------------------------------------------
import logging
import pandas as pd
import numpy as np
from dbitnet import make_model as make_dbitnet
from gohrnet import make_model as make_gohrnet
# ------------------------------------------------
# Configuration and constants
# ------------------------------------------------
ABORT_TRAINING_BELOW_ACC = 0.5050 # validation-accuracy threshold intended to abort training; NOTE(review): not referenced in the visible code of this file - confirm it is still needed.
EPOCHS = 1 # epochs per call to train_one_round, i.e. one epoch on each freshly generated training set
NUM_SAMPLES = 10**7 # create 10 million training samples per polishing iteration
NUM_VAL_SAMPLES = 10**7 # create 10 million validation samples (generated once and reused)
BATCHSIZE = 10_000 # training batch size
def train_one_round(model,
                    X, Y, X_val, Y_val,
                    round_number: int,
                    epochs=40,
                    model_name = 'model',
                    load_weight_file=False,
                    log_prefix = '',
                    weight_file=None):
    """Fit `model` on (X, Y) for one round and report its best validation accuracy.

    Side effects: writes a Keras checkpoint
    `{log_prefix}_{model_name}_round{round_number}.h5` (best `val_loss` seen
    during this call) and pickles the training history to
    `{log_prefix}_{model_name}_training_history_round{round_number}.pkl`.

    :param model: compiled Keras model; its history must contain the key
        'val_acc' (i.e. it was compiled with a metric named 'acc').
    :param X, Y: training data
    :param X_val, Y_val: validation data
    :param round_number: cipher round, used only to build the file names
    :param epochs: number of epochs to train
    :param model_name: model identifier used in the file names
    :param load_weight_file: Boolean (if True: load weights before training)
    :param log_prefix: prefix prepended to every file name written or read
    :param weight_file: explicit weight file to load; if None (and
        `load_weight_file` is True) the checkpoint of round `round_number-1`
        is loaded instead
    :return: best validation accuracy over all epochs of this call
    """
    from tensorflow.keras.callbacks import ModelCheckpoint

    #------------------------------------------------
    # Optionally restore weights before training
    #------------------------------------------------
    if load_weight_file:
        if weight_file is not None:
            print(f"loading weights from file {weight_file}...")
            model.load_weights(weight_file)
        else:
            print("loading weights from previous round...")
            model.load_weights(f'{log_prefix}_{model_name}_round{round_number-1}.h5')

    #------------------------------------------------
    # Checkpoint callback: keep the epoch with the lowest validation loss
    #------------------------------------------------
    checkpoint_path = f'{log_prefix}_{model_name}_round{round_number}.h5'
    save_best_cb = ModelCheckpoint(checkpoint_path, monitor='val_loss', save_best_only = True)

    #------------------------------------------------
    # Train the model
    #------------------------------------------------
    fit_result = model.fit(X, Y,
                           epochs=epochs,
                           batch_size=BATCHSIZE,
                           validation_data=(X_val, Y_val),
                           callbacks=[save_best_cb],
                           verbose=True)
    training_log = fit_result.history
    top_val_acc = np.max(training_log['val_acc'])
    print("Best validation accuracy: ", top_val_acc)

    # persist the per-epoch metrics of this round
    history_path = f'{log_prefix}_{model_name}_training_history_round{round_number}.pkl'
    pd.to_pickle(training_log, history_path)

    return top_val_acc
def polish_neural_distinguisher(starting_round,
                                data_generator,
                                model_name,
                                input_size,
                                word_size,
                                model_weights,
                                log_prefix = './'):
    """Polish a pre-trained neural distinguisher in three steps with decreasing learning rates.

    For each learning rate in [1e-4, 1e-5, 1e-6] a fresh model is built, the
    best weights available so far are loaded, and the model is trained 100
    times for EPOCHS epochs, each time on NUM_SAMPLES newly generated samples.
    Whenever the validation accuracy improves within a step, the weights are
    saved to `{log_prefix}_{model_name}_round{starting_round}_best_polish{step}.h5`.

    :param starting_round: Integer, the cipher round the distinguisher is polished for.
    :param data_generator: Data_generator(number_of_samples, current_round) returns X, Y.
    :param model_name: 'dbitnet' or 'gohr'.
    :param input_size: the network is built for an input of 2*input_size
        (presumably the bit size of one ciphertext - confirm against the model builders).
    :param word_size: forwarded to the 'gohr' model builder only.
    :param model_weights: weight file loaded at the start of the first polishing step.
    :param log_prefix: prefix for all files written by this function.
    :return: best_round, best_val_acc - `best_round` is always `starting_round`;
        `best_val_acc` is the best validation accuracy of the LAST polishing
        step (0.50 if that step never exceeded 0.50).
    :raises ValueError: if `model_name` is not a supported model.
    """
    # Fail fast on an unsupported model, before any expensive data generation.
    if model_name not in ('dbitnet', 'gohr'):
        raise ValueError(f"Unknown model_name {model_name!r}; expected 'dbitnet' or 'gohr'.")

    #------------------------------------------------
    # Set parameters
    #------------------------------------------------
    current_round = starting_round
    best_val_acc = None
    # Train 100 times:
    # create NUM_SAMPLES = 10**7 one-hundred times to train overall on 10**9 samples
    # ... (attempts to create 10**9 samples at once will take a very long time, and cause
    # ... memory issues)
    nTimes = 100
    # Weight file the next polishing step starts from (initially the user-supplied one).
    weights_to_load = model_weights

    # ------------------------------------------------
    # Create validation data (once, reused for every polishing step)
    # ------------------------------------------------
    print(f"CREATE CIPHER DATA for round {current_round} (validation samples={NUM_VAL_SAMPLES:.0e})...")
    X_val, Y_val = data_generator(NUM_VAL_SAMPLES, current_round)

    # ------------------------------------------------
    # Polish in three steps with constant, and decreasing learning rates
    # ------------------------------------------------
    for polishStep, LR in enumerate([1e-4, 1e-5, 1e-6]):
        tf.keras.backend.clear_session()
        #------------------------------------------------
        # Create the neural network model
        #------------------------------------------------
        print(f'CREATE NEURAL NETWORK MODEL {model_name}')
        if model_name == 'dbitnet':
            model = make_dbitnet(2*input_size)
        else:  # 'gohr' (validated above)
            model = make_gohrnet(2*input_size, word_size=word_size)
        optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, amsgrad=False) # update of LR below
        model.compile(optimizer=optimizer, loss='mse', metrics=['acc'])
        local_best_val_acc = 0.50 # save the best model for every polishing run
        optimizer.learning_rate.assign(LR) # update of LR
        print(f"\n\n Set new learning rate to {optimizer.learning_rate} (LR={LR})")

        # load the following weight files:
        # polishing step 0: manual input weight file
        # polishing step 1 and 2: best weight file from previous polishing step
        #   (or, if that step never saved one, the file it started from)
        print(f"\t Load weight file = {weights_to_load}")
        best_weights_file = f'{log_prefix}_{model_name}_round{current_round}_best_polish{polishStep}.h5'
        saved_new_best = False
        # Weights are loaded only in the first iteration of each step;
        # afterwards training continues from the in-memory model.
        load_weight_file = True
        first_iteration_weights = weights_to_load

        #------------------------------------------------
        # Train / polish
        #------------------------------------------------
        for nCounter in range(nTimes):
            print(f"Learning rate {LR} polishing {nCounter}/{nTimes}")
            print(f"\t CREATE CIPHER DATA for round {current_round} (training samples={NUM_SAMPLES:.0e}...")
            X, Y = data_generator(NUM_SAMPLES, current_round)
            # train model for the current round
            print(f"\t POLISH neural network for round {current_round}...")
            val_acc = train_one_round(model,
                                      X, Y, X_val, Y_val,
                                      current_round,
                                      epochs = EPOCHS,
                                      load_weight_file = load_weight_file,
                                      log_prefix = log_prefix,
                                      model_name = model_name,
                                      weight_file = first_iteration_weights)
            if val_acc > local_best_val_acc:
                local_best_val_acc = val_acc
                model.save_weights(best_weights_file)
                saved_new_best = True
            load_weight_file = False
            first_iteration_weights = None
            print('current best val_acc = ', local_best_val_acc)
            #------------------------------------------------
            # Free the memory before generating the next training set
            #------------------------------------------------
            del X
            del Y

        if saved_new_best:
            weights_to_load = best_weights_file
        else:
            # No checkpoint was written in this step, so the file the next step
            # would expect does not exist; keep starting from the previous weights.
            print(f"\t WARNING: no improvement above {local_best_val_acc} in polishing step {polishStep}; "
                  f"next step will reuse {weights_to_load}")
        best_val_acc = local_best_val_acc

    best_round = current_round
    return best_round, best_val_acc
if __name__=='__main__':
    # TODO: Implement command line interface for the polishing step
    print("NOTE: Unfortunately, this script is not yet executable using command line arguments. Please adjust the main-routine in polish_nets.py as necessary.")

    # ------------------------------------------------
    # Settings to adjust by hand
    # ------------------------------------------------
    model_weights = 'YOUR/MODEL.h5'      # pre-trained weights to polish
    cipher_name = 'speck3264'            # module name inside the `ciphers` package
    starting_round = 8
    manualInputDiffInt = np.uint16([0x40, 0x0000]).reshape(-1, 1)  # input difference as a column of 16-bit words
    logname = '0x400000'                 # tag used in the output file names
    scenario = 'single-key'
    output_dir = 'results'
    delta_key = 0

    # ------------------------------------------------
    # Project imports and cipher lookup
    # ------------------------------------------------
    import main
    import importlib
    from ciphers.speck3264 import convert_to_binary
    cipher = importlib.import_module('ciphers.' + cipher_name, package='ciphers')

    # cipher.__name__ is 'ciphers.<cipher_name>'; [8:] drops the 'ciphers.' prefix
    s = f"{cipher.__name__[8:]}_{scenario}"

    plain_bits = cipher.plain_bits
    key_bits = cipher.key_bits
    word_size = cipher.word_size
    encryption_function = cipher.encrypt

    # binary input difference, truncated to the plaintext length
    delta = convert_to_binary(manualInputDiffInt)
    delta_plain = delta[:plain_bits]

    def data_generator(num_samples, nr):
        """Generate (X, Y) with `num_samples` samples for `nr` cipher rounds."""
        return main.make_train_data(encryption_function,
                                    plain_bits,
                                    key_bits,
                                    num_samples,
                                    nr,
                                    delta_plain,
                                    delta_key)

    # ------------------------------------------------
    # Polish the DBitNet distinguisher
    # ------------------------------------------------
    best_round_dbitnet, best_val_acc_dbitnet = polish_neural_distinguisher(
        starting_round = starting_round,
        data_generator = data_generator,
        model_name = 'dbitnet',
        input_size = plain_bits,
        word_size = word_size,
        log_prefix = f'{output_dir}/{s}_{logname}',
        model_weights = model_weights)