forked from DeepRec-AI/DeepRec
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
1032 lines (939 loc) · 44.8 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import time
import argparse
import numbers
import tensorflow as tf
import os
import sys
import math
import collections
from tensorflow.python.client import timeline
import json
from tensorflow.python.ops import partitioned_variables
# Verbosity: INFO reports training progress (TensorFlow's default is WARN;
# ERROR is the quietest setting).
tf.logging.set_verbosity(tf.logging.INFO)
print("Using TensorFlow version {}".format(tf.__version__))
# ---------------------------------------------------------------------------
# INPUT CONFIG SPECIFICS
# ---------------------------------------------------------------------------
TRAIN_DATA_NAME = "taobao_train_data"
TEST_DATA_NAME = "taobao_test_data"
# Label columns; in CSV records they come first, in this order, followed by
# HASH_INPUTS. One label column per task tower.
LABEL_COLUMNS = ["clk", "buy"]
# Categorical inputs, each hashed into HASH_BUCKET_SIZES[name] buckets.
HASH_INPUTS = [
    "pid",
    "adgroup_id",
    "cate_id",
    "campaign_id",
    "customer",
    "brand",
    "user_id",
    "cms_segid",
    "cms_group_id",
    "final_gender_code",
    "age_level",
    "pvalue_level",
    "shopping_level",
    "occupation",
    "new_user_class_level",
    "tag_category_list",
    "tag_brand_list",
    "price"
]
ALL_FEATURE_COLUMNS = HASH_INPUTS
ALL_INPUT = LABEL_COLUMNS + HASH_INPUTS
# Columns that are parsed from the data but skipped when building feature columns.
NOT_USED_CATEGORY = ["final_gender_code"]
HASH_BUCKET_SIZES = {
    'pid': 10,
    'adgroup_id': 100000,
    'cate_id': 10000,
    'campaign_id': 100000,
    'customer': 100000,
    'brand': 100000,
    'user_id': 100000,
    'cms_segid': 100,
    'cms_group_id': 100,
    'final_gender_code': 10,
    'age_level': 10,
    'pvalue_level': 10,
    'shopping_level': 10,
    'occupation': 10,
    'new_user_class_level': 10,
    'tag_category_list': 100000,
    'tag_brand_list': 100000,
    'price': 50
}
# ---------------------------------------------------------------------------
# END OF INPUT CONFIG SPECIFICS
# ---------------------------------------------------------------------------

# ---------------------------------------------------------------------------
# MODEL CONFIG SPECIFICS
# ---------------------------------------------------------------------------
DNN_ACTIVATION = tf.nn.relu
# Default for --l2_regularization.
L2_REGULARIZATION = 1e-06
# NOTE(review): not read anywhere in the visible script — confirm it is still needed.
EMBEDDING_REGULARIZATION = 5e-05
# NOTE(review): passed to PLE as num_experts / expert_hidden_units, but PLE
# sizes its experts from shared_expert_num, specific_expert_num and
# expert_dnn_hidden_units instead.
EXPERTS_COUNT = 4
EXPERT_HIDDEN_UNITS = [256, 128, 64]
EMBEDDING_DIM = 16
# Tower tuple structure: (tower name, label name, hidden units).
TOWERS = [
    ("ctr", "clk", [256, 128, 64]),
    ("cvr", "buy", [256, 128, 64])
]
# ---------------------------------------------------------------------------
# END OF MODEL CONFIG SPECIFICS
# ---------------------------------------------------------------------------
def l2_regularizer(scale, scope=None):
    """Return a function that applies L2 regularization to a weights tensor.

    Args:
      scale: non-negative real multiplier for the penalty. Integers
        (including bools) are rejected.
      scope: optional name for the name scope wrapping the penalty op.

    Returns:
      A callable ``f(weights)``. It returns ``scale * tf.nn.l2_loss(weights)``
      when ``scale`` is positive, and ``None`` when ``scale`` is exactly 0.

    Raises:
      ValueError: if ``scale`` is an integer or a negative real number.
    """
    scale_is_real = isinstance(scale, numbers.Real)
    if isinstance(scale, numbers.Integral):
        raise ValueError(f'Scale cannot be an integer: {scale}')
    if scale_is_real and scale < 0.:
        raise ValueError(f'Setting a scale less than 0 on a regularizer: {scale}.')
    if scale_is_real and scale == 0.:
        # A zero scale disables the penalty entirely.
        return lambda _: None

    def l2(weights):
        base_dtype = weights.dtype.base_dtype
        with tf.name_scope(scope, 'l2_regularizer', [weights]) as op_name:
            scale_tensor = tf.convert_to_tensor(scale, dtype=base_dtype, name='scale')
            penalty = tf.nn.l2_loss(weights)
            return tf.math.multiply(scale_tensor, penalty, name=op_name)

    return l2
class PLE():
    """Progressive Layered Extraction (PLE) multi-task model in TF1 graph mode.

    Creating an instance builds the whole training graph:
      * the embedding input layer;
      * ``num_layers`` stacked CGC extraction layers;
      * one DNN tower per entry of ``towers``;
      * a sigmoid cross-entropy loss;
      * the train op (``train_op``) and streaming AUC/accuracy metrics.

    Args:
      input: ``(features, labels)`` pair, e.g. ``iterator.get_next()``.
        ``features`` is a dict keyed by column name. ``labels`` presumably has
        one column per tower (TODO confirm against the input pipeline).
      feature_column: feature columns for ``tf.feature_column.input_layer``.
      num_experts: stored only. Expert counts actually come from
        ``shared_expert_num`` / ``specific_expert_num``.
      expert_hidden_units: stored only. Experts use ``expert_dnn_hidden_units``.
      towers: sequence of ``(tower_name, label_name, hidden_units)`` tuples.
      dnn_activation: activation of every hidden dense layer.
      num_layers: number of stacked CGC extraction layers.
      expert_dnn_hidden_units: hidden sizes of each expert DNN.
      gate_dnn_hidden_units: hidden sizes of the DNN that feeds each gate.
      shared_expert_num: experts shared by all tasks, per layer.
      specific_expert_num: experts owned by each task, per layer.
      optimizer_type: 'adam', 'adagrad', 'adamasync' or 'adagraddecay'.
        With ``stock_tf`` truthy, Adam is always used.
      learning_rate: optimizer learning rate.
      use_bn: add batch normalization after each hidden dense layer.
      bf16: run the dense part in bfloat16. Forced off when ``stock_tf``.
      stock_tf: truthy to restrict the model to stock TensorFlow APIs.
      adaptive_emb: feed DeepRec adaptive-embedding masks to the input layer.
      input_layer_partitioner: variable partitioner for the embedding layer.
      dense_layer_partitioner: variable partitioner for dense layers.

    Raises:
      ValueError: if ``input`` is empty or ``optimizer_type`` is unknown.

    Note: ``_create_metrics`` is hard-coded to exactly two tasks (label
    columns 0 and 1).
    """

    def __init__(self,
                 input,
                 feature_column,
                 num_experts,
                 expert_hidden_units,
                 towers,
                 dnn_activation=DNN_ACTIVATION,
                 num_layers=3,
                 expert_dnn_hidden_units=(256, 128, 64),
                 gate_dnn_hidden_units=(256, 128, 64),
                 shared_expert_num=1,
                 specific_expert_num=2,
                 optimizer_type='adam',
                 learning_rate=0.001,
                 use_bn=True,
                 bf16=False,
                 stock_tf=None,
                 adaptive_emb=False,
                 input_layer_partitioner=None,
                 dense_layer_partitioner=None):
        if not input:
            raise ValueError("Dataset is not defined.")
        self._feature = input[0]
        self._label = input[1]
        self._feature_column = feature_column
        self._num_experts = num_experts
        self._expert_hidden_units = expert_hidden_units
        self._towers = towers
        self._num_tasks = len(towers)
        self._num_layers = num_layers
        self._dnn_activation = dnn_activation
        self._shared_expert_num = shared_expert_num
        self._specific_expert_num = specific_expert_num
        self._expert_dnn_hidden_units = expert_dnn_hidden_units
        self._gate_dnn_hidden_units = gate_dnn_hidden_units
        self._learning_rate = learning_rate
        self.tf = stock_tf
        # BF16 is a DeepRec feature; it is unavailable on stock TF.
        self._bf16 = False if self.tf else bf16
        self.use_bn = use_bn
        self.is_training = True
        self._adaptive_emb = adaptive_emb
        self._optimizer_type = optimizer_type
        self._input_layer_partitioner = input_layer_partitioner
        self._dense_layer_partitioner = dense_layer_partitioner
        self.model = self._create_model()
        with tf.name_scope('head'):
            self._create_loss()
            self._create_optimizer()
            self._create_metrics()

    def _add_layer_summary(self, value, tag):
        """Record the zero fraction and an activation histogram for TensorBoard."""
        tf.summary.scalar('%s/fraction_of_zero_values' % tag,
                          tf.nn.zero_fraction(value))
        tf.summary.histogram('%s/activation' % tag, value)

    def _make_scope(self, name, bf16, part):
        """Open a reusable variable scope with partitioner ``part``.

        With ``bf16``, DeepRec's ``keep_weights(dtype=tf.float32)`` is applied
        to the scope.
        """
        if bf16:
            return tf.variable_scope(name, partitioner=part, reuse=tf.AUTO_REUSE).keep_weights(dtype=tf.float32)
        return tf.variable_scope(name, partitioner=part, reuse=tf.AUTO_REUSE)

    def _dnn(self, dnn_input, dnn_hidden_units=None, layer_name=''):
        """Stack dense layers, one per entry of ``dnn_hidden_units``.

        Each layer is optionally followed by batch normalization and gets
        TensorBoard summaries. Returns ``dnn_input`` unchanged when no hidden
        units are given.
        """
        for layer_id, num_hidden_units in enumerate(dnn_hidden_units or ()):
            with tf.variable_scope(layer_name + '_%d' % layer_id,
                                   partitioner=self._dense_layer_partitioner,
                                   reuse=tf.AUTO_REUSE) as dnn_layer_scope:
                dnn_input = tf.layers.dense(
                    dnn_input,
                    units=num_hidden_units,
                    activation=self._dnn_activation,
                    name=dnn_layer_scope)
                if self.use_bn:
                    dnn_input = tf.layers.batch_normalization(
                        dnn_input, training=self.is_training, trainable=True)
                self._add_layer_summary(dnn_input, dnn_layer_scope.name)
        return dnn_input

    def cgc_model(self, inputs, level_name, is_last=False):
        """Build one CGC extraction layer.

        Args:
          inputs: ``num_tasks + 1`` tensors; entries ``0..num_tasks-1`` feed
            the task-specific experts/gates, the last one feeds the shared branch.
          level_name: prefix that makes layer names unique per extraction level.
          is_last: when False, also build the shared gate whose output becomes
            the next level's shared input.

        Returns:
          List of ``num_tasks`` gated outputs (plus one shared output when
          ``is_last`` is False).
        """
        # Task-specific experts, laid out task-major: task i owns entries
        # [i * specific_expert_num, (i + 1) * specific_expert_num).
        specific_expert_outputs = []
        for i in range(self._num_tasks):
            for j in range(self._specific_expert_num):
                expert_network = self._dnn(
                    inputs[i], dnn_hidden_units=self._expert_dnn_hidden_units,
                    layer_name=level_name + 'task_' + self._towers[i][0] + '_expert_specific_' + str(j))
                specific_expert_outputs.append(expert_network)

        # Task-shared experts.
        shared_expert_outputs = []
        for k in range(self._shared_expert_num):
            expert_network = self._dnn(
                inputs[-1], dnn_hidden_units=self._expert_dnn_hidden_units,
                layer_name=level_name + 'expert_shared_' + str(k))
            shared_expert_outputs.append(expert_network)

        # One gate per task, mixing that task's experts with the shared experts.
        cgc_outs = []
        for i in range(self._num_tasks):
            cur_expert_num = self._specific_expert_num + self._shared_expert_num
            cur_experts = specific_expert_outputs[
                i * self._specific_expert_num:(i + 1) * self._specific_expert_num] + shared_expert_outputs
            expert_concat = tf.keras.layers.Lambda(lambda x: tf.stack(x, axis=1))(cur_experts)
            gate_input = self._dnn(inputs[i], dnn_hidden_units=self._gate_dnn_hidden_units,
                                   layer_name=level_name + 'gate_specific_' + self._towers[i][0])
            # Softmax so the gate yields a convex mixture of experts, matching
            # the shared gate below (previously a raw linear projection).
            gate_out = tf.layers.dense(gate_input, units=cur_expert_num, use_bias=False, activation='softmax',
                                       name=level_name + 'gate_softmax_specific_' + self._towers[i][0])
            gate_out = tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=-1))(gate_out)
            gate_mul_expert = tf.keras.layers.Lambda(
                lambda x: tf.math.reduce_sum(x[0] * x[1], axis=1, keepdims=False),
                name=level_name + 'gate_mul_expert_specific_' + self._towers[i][0])([expert_concat, gate_out])
            cgc_outs.append(gate_mul_expert)

        # Every level except the last also needs a shared gate over all experts.
        if not is_last:
            cur_expert_num = self._num_tasks * self._specific_expert_num + self._shared_expert_num
            cur_experts = specific_expert_outputs + shared_expert_outputs
            expert_concat = tf.keras.layers.Lambda(lambda x: tf.stack(x, axis=1))(cur_experts)
            gate_input = self._dnn(inputs[-1], dnn_hidden_units=self._gate_dnn_hidden_units,
                                   layer_name=level_name + 'gate_shared')
            gate_out = tf.layers.dense(gate_input, units=cur_expert_num, use_bias=False, activation='softmax',
                                       name=level_name + 'gate_softmax_shared')
            gate_out = tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=-1))(gate_out)
            gate_mul_expert = tf.keras.layers.Lambda(
                lambda x: tf.math.reduce_sum(x[0] * x[1], axis=1, keepdims=False),
                name=level_name + 'gate_mul_expert_shared')([expert_concat, gate_out])
            cgc_outs.append(gate_mul_expert)
        return cgc_outs

    def _create_model(self):
        """Build the forward graph and set ``_logits``, ``probability`` and ``output``."""
        # Multi-valued tag features arrive as '|'-joined strings; split them
        # into tokens before the embedding lookup.
        TAG_COLUMN = ['tag_category_list', 'tag_brand_list']
        for key in TAG_COLUMN:
            self._feature[key] = tf.strings.split(self._feature[key], '|')

        key_dict = {}
        with self._make_scope('input_layer', self._bf16, self._input_layer_partitioner):
            print('Adaptive emb = ', self._adaptive_emb, 'TF = ', self.tf)
            if self._adaptive_emb and not self.tf:
                # Adaptive Embedding Feature part 1 of 2
                print('Adaptive Embedding Feature part 1 of 2')
                # NOTE(review): the mask length is the global args.batch_size,
                # not the (possibly micro-batched or short final) batch actually fed.
                adaptive_mask_tensors = {}
                for col in HASH_INPUTS:
                    adaptive_mask_tensors[col] = tf.ones([args.batch_size],
                                                         tf.int32)
                input_emb = tf.feature_column.input_layer(
                    self._feature,
                    self._feature_column,
                    adaptive_mask_tensors=adaptive_mask_tensors,
                    cols_to_output_tensors=key_dict)
            else:
                input_emb = tf.feature_column.input_layer(
                    self._feature,
                    self._feature_column,
                    cols_to_output_tensors=key_dict)

        towers = []
        with self._make_scope('PLE', self._bf16, self._dense_layer_partitioner):
            if self._bf16:
                input_emb = tf.cast(input_emb, dtype=tf.bfloat16)

            # Every task branch and the shared branch start from the same embedding.
            ple_inputs = [input_emb] * (self._num_tasks + 1)
            ple_outputs = []
            for i in range(self._num_layers):
                with tf.variable_scope(f'extraction_network_{i}'):
                    if i == self._num_layers - 1:  # the last extraction net
                        ple_outputs = self.cgc_model(inputs=ple_inputs, level_name='level_'+str(i)+'_', is_last=True)
                    else:
                        ple_outputs = self.cgc_model(inputs=ple_inputs, level_name='level_'+str(i)+'_', is_last=False)
                    ple_inputs = ple_outputs

            for i, tower in enumerate(self._towers):
                tower_name = tower[0]
                hidden_units = tower[2]
                with tf.variable_scope(tower_name, reuse=tf.AUTO_REUSE):
                    tower_output = self._dnn(ple_outputs[i], dnn_hidden_units=hidden_units,
                                             layer_name='tower_'+tower_name)
                    final_tower_predict = tf.layers.dense(inputs=tower_output,
                                                          units=1,
                                                          activation=None,
                                                          name=f'{tower_name}_output')
                    self._add_layer_summary(final_tower_predict, f'{tower_name}_output')
                    if self._bf16:
                        final_tower_predict = tf.cast(final_tower_predict, dtype=tf.float32)
                    towers.append(final_tower_predict)

        # Stack the per-tower [batch, 1] logits to [batch, num_towers, 1],
        # then drop the trailing axis.
        tower_stack = tf.stack(towers, axis=1)
        self._logits = tf.squeeze(tower_stack, [2])
        self.probability = tf.math.sigmoid(self._logits)
        self.output = tf.round(self.probability)

    def _create_loss(self):
        """Mean sigmoid cross-entropy over all tasks and examples."""
        # _logits is already [batch, num_tasks] (squeezed on axis 2 in
        # _create_model). A second, axis-less squeeze discarded the static shape
        # and collapsed a batch of size 1 to [num_tasks], mismatching labels.
        self.loss = tf.losses.sigmoid_cross_entropy(
            self._label,
            self._logits,
            scope='loss',
            reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE)
        print(self.loss)
        tf.summary.scalar('loss', self.loss)

    def _create_optimizer(self):
        """Create ``global_step`` and ``train_op`` for the selected optimizer.

        Raises:
          ValueError: if ``optimizer_type`` is not recognized.
        """
        self.global_step = tf.train.get_or_create_global_step()
        print('self.tf = ', self.tf, ' self._optimizer_type = ', self._optimizer_type)
        # The non-Adam optimizers below are DeepRec-only, so stock TF always uses Adam.
        if self.tf or self._optimizer_type == 'adam':
            optimizer = tf.train.AdamOptimizer(
                learning_rate=self._learning_rate,
                beta1=0.9,
                beta2=0.999,
                epsilon=1e-8)
        elif self._optimizer_type == 'adagrad':
            optimizer = tf.train.AdagradOptimizer(
                learning_rate=self._learning_rate,
                initial_accumulator_value=0.1,
                use_locking=False)
        elif self._optimizer_type == 'adamasync':
            optimizer = tf.train.AdamAsyncOptimizer(
                learning_rate=self._learning_rate,
                beta1=0.9,
                beta2=0.999,
                epsilon=1e-8)
        elif self._optimizer_type == 'adagraddecay':
            optimizer = tf.train.AdagradDecayOptimizer(
                learning_rate=self._learning_rate,
                global_step=self.global_step)
        else:
            raise ValueError("Optimizer type error.")
        # Run batch-norm moving-average updates together with each train step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            self.train_op = optimizer.minimize(self.loss,
                                               global_step=self.global_step)

    def _create_metrics(self):
        """Streaming AUC and accuracy for task 0 (``*1``) and task 1 (``*2``)."""
        self.auc1, self.auc_op1 = tf.metrics.auc(labels=self._label[:, 0],
                                                 predictions=self.probability[:, 0],
                                                 num_thresholds=1000)
        self.auc2, self.auc_op2 = tf.metrics.auc(labels=self._label[:, 1],
                                                 predictions=self.probability[:, 1],
                                                 num_thresholds=1000)
        self.acc1, self.acc_op1 = tf.metrics.accuracy(labels=self._label[:, 0],
                                                      predictions=self.output[:, 0])
        self.acc2, self.acc_op2 = tf.metrics.accuracy(labels=self._label[:, 1],
                                                      predictions=self.output[:, 1])
        tf.summary.scalar('eval_auc1', self.auc1)
        tf.summary.scalar('eval_auc2', self.auc2)
        tf.summary.scalar('eval_acc1', self.acc1)
        tf.summary.scalar('eval_acc2', self.acc2)
def build_model_input(filename, batch_size, num_epochs):
    """Create the tf.data pipeline that yields ``(features, label)`` batches.

    Reads parquet through DeepRec's ParquetDataset when ``--parquet_dataset``
    is set (and ``--tf`` is not); otherwise reads CSV lines. When DeepRec's
    WorkQueue serves the file list, epochs are handled by the queue and the
    dataset is not repeated here.

    Args:
      filename: path of the data file.
      batch_size: examples per batch.
      num_epochs: passes over the data.

    Returns:
      A dataset of ``(features_dict, label)``, where ``label`` stacks the
      LABEL_COLUMNS along axis 1 and ``features_dict`` holds the remaining columns.
    """
    def parse_csv(value):
        tf.logging.info('Parsing {}'.format(filename))
        HASH_defaults = [[" "] for i in range(0, len(HASH_INPUTS))]
        label_defaults = [[0] for i in range(0, len(LABEL_COLUMNS))]
        column_headers = LABEL_COLUMNS + HASH_INPUTS
        record_defaults = label_defaults + HASH_defaults
        columns = tf.io.decode_csv(value, record_defaults=record_defaults)
        all_columns = collections.OrderedDict(zip(column_headers, columns))
        labels = []
        for i in range(0, len(LABEL_COLUMNS)):
            labels.append(all_columns.pop(LABEL_COLUMNS[i]))
        label = tf.stack(labels, axis=1)
        features = all_columns
        return features, label

    def parse_parquet(value):
        tf.logging.info('Parsing {}'.format(filename))
        labels = []
        for i in range(0, len(LABEL_COLUMNS)):
            labels.append(value.pop(LABEL_COLUMNS[i]))
        label = tf.stack(labels, axis=1)
        features = value
        return features, label

    # Work Queue Feature (DeepRec only). Track whether it is actually in use:
    # with --tf it is ignored, so the dataset must still be repeated below.
    use_workqueue = args.workqueue and not args.tf
    if use_workqueue:
        from tensorflow.python.ops.work_queue import WorkQueue
        work_queue = WorkQueue([filename], num_epochs=num_epochs)
        files = work_queue.input_dataset()
    else:
        files = filename

    # Extract lines from input files using the Dataset API.
    if args.parquet_dataset and not args.tf:
        from tensorflow.python.data.experimental.ops import parquet_dataset_ops
        dataset = parquet_dataset_ops.ParquetDataset(files, batch_size=batch_size)
        if args.parquet_dataset_shuffle:
            dataset = dataset.shuffle(buffer_size=40000,
                                      seed=args.seed)  # fix seed for reproducing
        if not use_workqueue:
            dataset = dataset.repeat(num_epochs)
        dataset = dataset.map(parse_parquet, num_parallel_calls=28)
    else:
        dataset = tf.data.TextLineDataset(files)
        dataset = dataset.shuffle(buffer_size=400000,
                                  seed=args.seed)  # set seed for reproducing
        if not use_workqueue:
            dataset = dataset.repeat(num_epochs)
        dataset = dataset.batch(batch_size)
        dataset = dataset.map(parse_csv, num_parallel_calls=28)
    dataset = dataset.prefetch(2)
    return dataset
def _make_ev_option():
    """Build DeepRec's EmbeddingVariableOption from --ev_elimination / --ev_filter."""
    # Feature Elimination of EmbeddingVariable Feature
    if args.ev_elimination == 'gstep':
        # Feature elimination based on global steps
        evict_opt = tf.GlobalStepEvict(steps_to_live=4000)
    elif args.ev_elimination == 'l2':
        # Feature elimination based on l2 weight. DeepRec's keyword is
        # `l2_weight_threshold`; the previous `l2_weigt_threshold` was a typo.
        evict_opt = tf.L2WeightEvict(l2_weight_threshold=1.0)
    else:
        evict_opt = None

    # Feature Filter of EmbeddingVariable Feature
    if args.ev_filter == 'cbf':
        # CBF-based feature filter
        filter_option = tf.CBFFilter(
            filter_freq=3,
            max_element_size=2**30,
            false_positive_probability=0.01,
            counter_type=tf.int64)
    elif args.ev_filter == 'counter':
        # Counter-based feature filter
        filter_option = tf.CounterFilter(filter_freq=3)
    else:
        filter_option = None
    return tf.EmbeddingVariableOption(
        evict_option=evict_opt, filter_option=filter_option)


def _build_embedding_column(column_name):
    """Return the mean-combined embedding column for one hashed categorical input."""
    print('Column name = ', column_name, ' hash bucket size = ', HASH_BUCKET_SIZES[column_name])
    categorical_column = tf.feature_column.categorical_column_with_hash_bucket(
        column_name,
        hash_bucket_size=HASH_BUCKET_SIZES[column_name],
        dtype=tf.string)
    if not args.tf:
        ev_opt = _make_ev_option()
        if args.ev:
            # Embedding Variable Feature
            categorical_column = tf.feature_column.categorical_column_with_embedding(
                column_name, dtype=tf.string, ev_option=ev_opt)
        elif args.adaptive_emb:
            # Adaptive Embedding Feature part 2 of 2. Besides this column,
            # tf.feature_column.input_layer needs an `adaptive_mask_tensors`
            # dict mapping each column name to an int32 tensor of shape [batch_size].
            categorical_column = tf.feature_column.categorical_column_with_adaptive_embedding(
                column_name,
                hash_bucket_size=HASH_BUCKET_SIZES[column_name],
                dtype=tf.string,
                ev_option=ev_opt)
        elif args.dynamic_ev:
            # Dynamic-dimension Embedding Variable
            print("Dynamin-dimension Embedding Variable isn't really enabled in model.")
            sys.exit()

    if args.tf or not args.emb_fusion:
        return tf.feature_column.embedding_column(
            categorical_column,
            dimension=EMBEDDING_DIM,
            combiner='mean')
    # Embedding Fusion Feature (DeepRec only)
    return tf.feature_column.embedding_column(
        categorical_column,
        dimension=EMBEDDING_DIM,
        combiner='mean',
        do_fusion=args.emb_fusion)


def _collect_feature_cols():
    """Build embedding columns for every used input in ALL_FEATURE_COLUMNS order.

    Raises:
      ValueError: if a column is not in HASH_INPUTS.
    """
    feature_cols = []
    for column_name in ALL_FEATURE_COLUMNS:
        if column_name in NOT_USED_CATEGORY:
            continue
        if column_name not in HASH_INPUTS:
            raise ValueError('Unexpected column name occured')
        feature_cols.append(_build_embedding_column(column_name))
    return feature_cols


def build_feature_cols():
    """Return the model's feature columns.

    With ``--group_embedding`` (DeepRec only), the columns are built inside a
    group-embedding column scope named "categorical".
    """
    if args.group_embedding and not args.tf:
        with tf.feature_column.group_embedding_column_scope(name="categorical"):
            return _collect_feature_cols()
    return _collect_feature_cols()
def train(sess_config,
          input_hooks,
          model,
          data_init_op,
          steps,
          checkpoint_dir,
          tf_config=None,
          server=None):
    """Train ``model`` in a MonitoredTrainingSession until global step ``steps``.

    Local variables and the iterator (``data_init_op``) are initialized via the
    scaffold. Step/loss are logged every 100 iterations, checkpoints and
    summaries go to ``checkpoint_dir``, and an optional ProfilerHook records
    timelines.

    Incremental checkpointing: ``tf.train.MonitoredTrainingSession`` takes a
    ``save_incremental_checkpoint_secs`` argument (default None). To use it,
    pass e.g. ``save_incremental_checkpoint_secs=args.incremental_ckpt``. It is
    not wired up here, so requesting it exits early.

    Args:
      sess_config: ``tf.ConfigProto`` for the session.
      input_hooks: extra hooks; the list itself is not modified.
      model: a built PLE model exposing ``loss``, ``train_op``, ``global_step``.
      data_init_op: initializer for the training iterator.
      steps: last global step to train to.
      checkpoint_dir: directory for checkpoints, summaries and timelines.
      tf_config: worker description dict for distributed runs, or None.
      server: ``tf.distribute.Server`` for distributed runs, or None.
    """
    model.is_training = True

    scaffold = tf.train.Scaffold(
        local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
        saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max))

    tensors_to_log = {'steps': model.global_step, 'loss': model.loss}
    hooks = list(input_hooks) + [
        tf.train.StopAtStepHook(last_step=steps),
        tf.train.LoggingTensorHook(tensors_to_log, every_n_iter=100),
    ]
    if args.timeline > 0:
        hooks.append(tf.train.ProfilerHook(save_steps=args.timeline,
                                           output_dir=checkpoint_dir))

    # Without an eval pass or an explicit interval, only checkpoint at the end.
    if args.save_steps or args.no_eval:
        save_steps = args.save_steps
    else:
        save_steps = steps

    if args.incremental_ckpt and not args.tf:
        print("Incremental_Checkpoint is not really enabled.")
        print("Please see the comments in the code.")
        sys.exit()

    master = server.target if server else ''
    is_chief = tf_config['is_chief'] if tf_config else True
    with tf.train.MonitoredTrainingSession(master=master,
                                           is_chief=is_chief,
                                           hooks=hooks,
                                           scaffold=scaffold,
                                           checkpoint_dir=checkpoint_dir,
                                           save_checkpoint_steps=save_steps,
                                           summary_dir=checkpoint_dir,
                                           save_summaries_steps=args.save_steps,
                                           config=sess_config) as sess:
        fetches = [model.loss, model.train_op]
        while not sess.should_stop():
            sess.run(fetches)
    print("Training completed.")
def eval(sess_config, input_hooks, model, data_init_op, steps, checkpoint_dir):
    """Evaluate the latest checkpoint in ``checkpoint_dir`` over ``steps`` batches.

    Every batch updates all four streaming metrics (AUC and accuracy for both
    tasks). On the final batch the metric values and merged summaries are
    fetched, the summaries are written under ``<checkpoint_dir>/eval``, and the
    results are printed.

    Args:
      sess_config: ``tf.ConfigProto`` for the session.
      input_hooks: extra hooks; the list itself is not modified.
      model: a built PLE model exposing the ``auc_op*`` / ``acc_op*`` update ops.
      data_init_op: initializer for the test iterator.
      steps: number of evaluation batches.
      checkpoint_dir: directory holding the checkpoint to restore.
    """
    # NOTE(review): the graph was already built in training mode, so this flag
    # does not switch batch normalization to inference — confirm this is intended.
    model.is_training = False
    hooks = list(input_hooks)
    scaffold = tf.train.Scaffold(
        local_init_op=tf.group(tf.local_variables_initializer(), data_init_op))
    session_creator = tf.train.ChiefSessionCreator(
        scaffold=scaffold, checkpoint_dir=checkpoint_dir, config=sess_config)
    writer = tf.summary.FileWriter(os.path.join(checkpoint_dir, 'eval'))
    merged = tf.summary.merge_all()
    # Previously only the AUC ops ran on intermediate batches, so the reported
    # accuracy covered just the last batch.
    update_ops = [model.auc_op1, model.auc_op2, model.acc_op1, model.acc_op2]
    try:
        with tf.train.MonitoredSession(session_creator=session_creator,
                                       hooks=hooks) as sess:
            for step in range(1, steps + 1):
                if step != steps:
                    sess.run(update_ops)
                    if step % 1000 == 0:
                        print("Evaluation complete:[{}/{}]".format(step, steps))
                else:
                    eval_auc1, eval_auc2, eval_acc1, eval_acc2, events = sess.run(
                        update_ops + [merged])
                    writer.add_summary(events, step)
                    print("Evaluation complete:[{}/{}]".format(step, steps))
                    print("ACC1 = {}\nAUC1 = {}".format(eval_acc1, eval_auc1))
                    print("ACC2 = {}\nAUC2 = {}".format(eval_acc2, eval_auc2))
    finally:
        # Flush pending events and release the event file.
        writer.close()
def main(tf_config=None, server=None):
    """Check the data, build the input pipeline and PLE model, then train and evaluate.

    Evaluation is skipped with ``--no_eval`` and in distributed runs
    (``tf_config`` given).

    Args:
      tf_config: worker description dict from ``generate_cluster_info`` for
        distributed runs, or None for a single-process run.
      server: the ``tf.distribute.Server`` for distributed runs, or None.
    """
    # check dataset and count data set size
    print("Checking dataset...")
    train_file = args.data_location + '/' + TRAIN_DATA_NAME
    test_file = args.data_location + '/' + TEST_DATA_NAME
    use_parquet = args.parquet_dataset and not args.tf
    if use_parquet:
        train_file += '.parquet'
        test_file += '.parquet'
    if (not os.path.exists(train_file)) or (not os.path.exists(test_file)):
        print("Dataset does not exist in the given data_location.")
        sys.exit()

    if use_parquet:
        import pyarrow.parquet as pq
        no_of_training_examples = pq.read_table(train_file).num_rows
        no_of_test_examples = pq.read_table(test_file).num_rows
    else:
        # One CSV record per line; `with` closes the files (they used to leak).
        with open(train_file) as f:
            no_of_training_examples = sum(1 for _ in f)
        with open(test_file) as f:
            no_of_test_examples = sum(1 for _ in f)
    print("Number of training dataset is {}".format(no_of_training_examples))
    print("Number of test dataset is {}".format(no_of_test_examples))

    # set batch size, epoch & steps
    batch_size = math.ceil(
        args.batch_size / args.micro_batch
    ) if args.micro_batch and not args.tf else args.batch_size
    if args.steps == 0:
        no_of_epochs = 100
        train_steps = math.ceil(
            (float(no_of_epochs) * no_of_training_examples) / batch_size)
    else:
        no_of_epochs = math.ceil(
            (float(batch_size) * args.steps) / no_of_training_examples)
        train_steps = args.steps
    test_steps = math.ceil(float(no_of_test_examples) / batch_size)
    print("The training steps is {}".format(train_steps))
    print("The testing steps is {}".format(test_steps))

    # set fixed random seed
    tf.set_random_seed(args.seed)

    # set directory path for checkpoint_dir
    model_dir = os.path.join(args.output_dir,
                             'model_PLE_' + str(int(time.time())))
    checkpoint_dir = args.checkpoint if args.checkpoint else model_dir
    print("Saving model checkpoints to = " + checkpoint_dir)

    # Create the train & test pipelines behind one reinitializable iterator.
    # Both pipelines share one structure; take types and shapes from the same dataset.
    train_dataset = build_model_input(train_file, batch_size, no_of_epochs)
    test_dataset = build_model_input(test_file, batch_size, 1)
    dataset_output_types = tf.data.get_output_types(train_dataset)
    dataset_output_shapes = tf.data.get_output_shapes(train_dataset)
    iterator = tf.data.Iterator.from_structure(dataset_output_types,
                                               dataset_output_shapes)
    next_element = iterator.get_next()
    train_init_op = iterator.make_initializer(train_dataset)
    test_init_op = iterator.make_initializer(test_dataset)

    # create feature columns
    feature_cols = build_feature_cols()

    # Variable partitioners for distributed training: input slices in MB
    # (<< 20), dense slices in KB (<< 10).
    num_ps_replicas = len(tf_config['ps_hosts']) if tf_config else 0
    input_layer_partitioner = partitioned_variables.min_max_variable_partitioner(
        max_partitions=num_ps_replicas,
        min_slice_size=args.input_layer_partitioner <<
        20) if args.input_layer_partitioner else None
    dense_layer_partitioner = partitioned_variables.min_max_variable_partitioner(
        max_partitions=num_ps_replicas,
        min_slice_size=args.dense_layer_partitioner <<
        10) if args.dense_layer_partitioner else None

    # Session config
    sess_config = tf.ConfigProto()
    sess_config.inter_op_parallelism_threads = args.inter
    sess_config.intra_op_parallelism_threads = args.intra

    # Session hooks
    hooks = []

    if args.smartstaged and not args.tf:
        # Smart staged Feature
        next_element = tf.staged(next_element, num_threads=4, capacity=40)
        sess_config.graph_options.optimizer_options.do_smart_stage = True
        hooks.append(tf.make_prefetch_hook())
    if args.op_fusion and not args.tf:
        # Auto Graph Fusion
        sess_config.graph_options.optimizer_options.do_op_fusion = True
    if args.micro_batch and not args.tf:
        # Auto Micro Batch
        sess_config.graph_options.optimizer_options.micro_batch_num = args.micro_batch

    # create model
    # NOTE(review): args.learning_rate is not forwarded, so PLE uses its default
    # learning_rate=0.001 — confirm this is intended.
    model = PLE(input=next_element,
                feature_column=feature_cols,
                num_experts=EXPERTS_COUNT,
                expert_hidden_units=EXPERT_HIDDEN_UNITS,
                towers=TOWERS,
                optimizer_type=args.optimizer,
                bf16=args.bf16,
                stock_tf=args.tf,
                adaptive_emb=args.adaptive_emb,
                input_layer_partitioner=input_layer_partitioner,
                dense_layer_partitioner=dense_layer_partitioner)

    # run model training and evaluation
    train(sess_config, hooks, model, train_init_op, train_steps,
          checkpoint_dir, tf_config, server)
    if not (args.no_eval or tf_config):
        eval(sess_config, hooks, model, test_init_op, test_steps,
             checkpoint_dir)
def boolean_string(string):
    """Convert a case-insensitive 'true'/'false' command-line value to a bool.

    Raises:
      ValueError: for any string other than 'true' or 'false' (any case).
    """
    normalized = string.lower()
    if normalized == 'true':
        return True
    if normalized == 'false':
        return False
    raise ValueError('Not a valid boolean string')
def get_arg_parser():
    """Build the command-line parser for this training script.

    ``ArgumentDefaultsHelpFormatter`` appends each option's default to its help
    text. Boolean feature switches are parsed with ``boolean_string``.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    add = parser.add_argument

    # Data, schedule and output locations.
    add('--data_location', help='Full path of train data',
        required=False, default='./data')
    add('--steps', help='set the number of steps on train dataset',
        type=int, default=0)
    add('--batch_size', help='Batch size to train',
        type=int, default=2048)
    add('--output_dir', help='Full path to logs & model output directory',
        required=False, default='./result')
    add('--checkpoint', help='Full path to checkpoints input/output directory',
        required=False)
    add('--model_dir', help='Full path to test model directory',
        required=False)

    # Optimization and checkpointing.
    add('--learning_rate', help='Learning rate for model',
        type=float, default=0.1)
    add('--l2_regularization', help='L2 regularization for the model',
        type=float, default=L2_REGULARIZATION)
    add('--timeline', help='number of steps on saving timeline',
        type=int, default=0)
    add('--save_steps', help='set the number of steps on saving checkpoints',
        type=int, default=0)
    add('--seed', help='set random seed',
        type=int, default=2021)
    add('--keep_checkpoint_max', help='Maximum number of recent checkpoint to keep',
        type=int, default=1)
    add('--bf16', help='enable DeepRec BF16 in deep model',
        action='store_true')
    add('--no_eval', help='not evaluate trained model by eval dataset.',
        action='store_true')

    # Runtime, threading and partitioning.
    add('--protocol', type=str,
        choices=['grpc', 'grpc++', 'star_server'], default='grpc')
    add('--inter', help='set inter op parallelism threads',
        type=int, default=0)
    add('--intra', help='set intra op parallelism threads',
        type=int, default=0)
    add('--input_layer_partitioner', help='slice size of input layer partitioner. units MB',
        type=int, default=0)
    add('--dense_layer_partitioner', help='slice size of dense layer partitioner. units KB',
        type=int, default=0)
    add('--optimizer', type=str,
        choices=['adam', 'adamasync', 'adagraddecay', 'adagrad'], default='adam')
    add('--tf', help='Use TF 1.15.5 API and disable DeepRec feature to run a baseline.',
        action='store_true')

    # DeepRec feature switches.
    add('--smartstaged', help='Whether to enable smart staged feature of DeepRec, Default to True.',
        type=boolean_string, default=True)
    add('--emb_fusion', help='Whether to enable embedding fusion, Default to True.',
        type=boolean_string, default=True)
    add('--ev', help='Whether to enable DeepRec EmbeddingVariable. Default False.',
        type=boolean_string, default=False)
    add('--ev_elimination', help='Feature Elimination of EmbeddingVariable Feature. Default closed.',
        type=str, choices=[None, 'l2', 'gstep'], default=None)
    add('--ev_filter', help='Feature Filter of EmbeddingVariable Feature. Default closed.',
        type=str, choices=[None, 'counter', 'cbf'], default=None)
    add('--op_fusion', help='Whether to enable Auto graph fusion feature. Default to True',
        type=boolean_string, default=True)
    add('--micro_batch', help='Set num for Auto Mirco Batch. Default close.',
        type=int, default=0)  # TODO: default to enabled
    add('--adaptive_emb', help='Whether to enable Adaptive Embedding. Default to False.',
        type=boolean_string, default=False)
    add('--dynamic_ev', help='Whether to enable Dynamic-dimension Embedding Variable. Default to False.',
        type=boolean_string, default=False)  # TODO: enable
    add('--incremental_ckpt', help='Set time of save Incremental Checkpoint. Default 0 to close.',
        type=int, default=0)
    add('--workqueue', help='Whether to enable Work Queue. Default to False.',
        type=boolean_string, default=False)
    add("--parquet_dataset", help='Whether to enable Parquet DataSet. Defualt to True.',
        type=boolean_string, default=True)
    add("--parquet_dataset_shuffle",
        help='Whether to enable shuffle operation for Parquet Dataset. Default to False.',
        type=boolean_string, default=False)
    add("--group_embedding", help='Whether to enable Group Embedding. Defualt to None.',
        type=str, choices=[None, 'localized', 'collective'], default=None)
    return parser
def generate_cluster_info(TF_CONFIG):
    """Parse a TF_CONFIG JSON string and start this task's ``tf.distribute.Server``.

    Chief hosts are prepended to the worker list, so the chief becomes worker 0
    and every ordinary worker index is shifted up by one. A parameter-server
    task blocks in ``server.join()``.

    Returns:
      For a worker/chief task, ``(tf_config, server, tf_device)``:
        * ``tf_config`` is a dict with ps/worker hosts, type, index and
          ``is_chief``;
        * ``tf_device`` is a ``replica_device_setter`` device context.

    Exits the process when the ps or worker host lists are empty, or when the
    task type is unrecognized.
    """
    print(TF_CONFIG)
    parsed = json.loads(TF_CONFIG)
    cluster_config = parsed.get('cluster')
    ps_hosts = cluster_config.get('ps', [])
    worker_hosts = cluster_config.get('worker', [])
    chief_hosts = cluster_config.get('chief', [])
    if chief_hosts:
        worker_hosts = chief_hosts + worker_hosts

    if not ps_hosts or not worker_hosts:
        print('TF_CONFIG ERROR')
        sys.exit()

    task_config = parsed.get('task')
    task_type = task_config.get('type')
    # Workers are shifted by one when a chief occupies worker slot 0.
    index_offset = 1 if (task_type == 'worker' and chief_hosts) else 0
    task_index = task_config.get('index') + index_offset
    if task_type == 'chief':
        task_type = 'worker'
    is_chief = task_index == 0

    cluster = tf.train.ClusterSpec({'ps': ps_hosts, 'worker': worker_hosts})
    server = tf.distribute.Server(cluster,
                                  job_name=task_type,
                                  task_index=task_index,
                                  protocol=args.protocol)

    if task_type == 'ps':
        server.join()
        return None
    if task_type != 'worker':
        print("Task type or index error.")
        sys.exit()

    worker_config = {
        'ps_hosts': ps_hosts,
        'worker_hosts': worker_hosts,
        'type': task_type,
        'index': task_index,
        'is_chief': is_chief
    }
    device_setter = tf.train.replica_device_setter(
        worker_device='/job:worker/task:%d' % task_index,
        cluster=cluster)
    return worker_config, server, tf.device(device_setter)
def set_env_for_DeepRec():
'''
Set some ENV for these DeepRec's features enabled by ENV.
More Detail information is shown in https://deeprec.readthedocs.io/zh/latest/index.html.
START_STATISTIC_STEP & STOP_STATISTIC_STEP: On CPU platform, DeepRec supports memory optimization
in both stand-alone and distributed trainging. It's default to open, and the
default start and stop steps of collection is 1000 and 1100. Reduce the initial
cold start time by the following settings.
MALLOC_CONF: On CPU platform, DeepRec can use memory optimization with the jemalloc library.
Please preload libjemalloc.so by `LD_PRELOAD=./libjemalloc.so.2 python ...`
'''
os.environ['START_STATISTIC_STEP'] = '100'
os.environ['STOP_STATISTIC_STEP'] = '110'
os.environ['MALLOC_CONF']= \
'background_thread:true,metadata_thp:auto,dirty_decay_ms:20000,muzzy_decay_ms:20000'
if args.group_embedding == "collective":