
Commit

Restore deleted code
jiangzhonglian committed Dec 30, 2022
1 parent 2feccb3 commit 12978f2
Showing 1 changed file with 57 additions and 57 deletions.
114 changes: 57 additions & 57 deletions src/py3.x/tensorflow2.x/text_bert.py
@@ -93,62 +93,62 @@ def seq_padding(X, padding=0):
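The hunk starts just below seq_padding(X, padding=0), which is not visible in this diff. For context, a common form of that helper in keras-bert classification examples (an assumption, not necessarily this file's actual code) pads every sequence in a batch to the batch's longest length:

import numpy as np

def seq_padding(X, padding=0):
    # Pad each sequence in X with `padding` up to the longest sequence in the batch
    max_len = max(len(x) for x in X)
    return np.array([
        np.concatenate([x, [padding] * (max_len - len(x))]) if len(x) < max_len else x
        for x in X
    ])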
## bert / Embedding + lstm + crf


#%%
# Load the data
class TextBert():
    def __init__(self):
        self.path_config = Config.bert.path_config
        self.path_checkpoint = Config.bert.path_checkpoint

        # Build the BERT vocabulary: token -> index, in file order
        self.token_dict = {}
        with codecs.open(Config.bert.dict_path, 'r', 'utf8') as reader:
            for line in reader:
                token = line.strip()
                self.token_dict[token] = len(self.token_dict)

    def prepare_data(self):
        neg = pd.read_excel(Config.bert.path_neg, header=None)
        pos = pd.read_excel(Config.bert.path_pos, header=None)
        data = []
        for d in neg[0]:
            data.append((d, 0))
        for d in pos[0]:
            data.append((d, 1))
        # Split into training and validation sets with a 9:1 ratio
        random_order = list(range(len(data)))
        np.random.shuffle(random_order)
        train_data = [data[j] for i, j in enumerate(random_order) if i % 10 != 0]
        valid_data = [data[j] for i, j in enumerate(random_order) if i % 10 == 0]
        return train_data, valid_data

    def build_model(self, m_type="bert"):
        if m_type == "bert":
            bert_model = load_trained_model_from_checkpoint(self.path_config, self.path_checkpoint, seq_len=None)
            for l in bert_model.layers:
                l.trainable = True
            x1_in = Input(shape=(None,))
            x2_in = Input(shape=(None,))
            x = bert_model([x1_in, x2_in])
            x = Lambda(lambda x: x[:, 0])(x)  # Take the [CLS] vector
            p = Dense(1, activation='sigmoid')(x)  # Adjust for the number of classes; more layers can be added
            model = Model([x1_in, x2_in], p)
            model.compile(
                loss='binary_crossentropy',
                optimizer=Adam(1e-5),  # Use a sufficiently small learning rate
                metrics=['accuracy']
            )
        else:
            # Otherwise use a randomly initialized Embedding + BiLSTM + CRF
            model = Sequential()
            model.add(Embedding(len(vocab), EMBED_DIM, mask_zero=True))  # Random embedding
            model.add(Bidirectional(LSTM(BiRNN_UNITS // 2, return_sequences=True)))
            crf = CRF(len(chunk_tags), sparse_target=True)
            model.add(crf)
            model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy])

        model.summary()
        return model
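A minimal usage sketch for the class above, assuming the keras_bert Tokenizer (which pairs with the load_trained_model_from_checkpoint call in build_model); everything else follows the code shown here:

from keras_bert import Tokenizer  # assumed import, matching the keras_bert checkpoint loader used above

tb = TextBert()
train_data, valid_data = tb.prepare_data()
model = tb.build_model(m_type="bert")

# Encode one example into the (token ids, segment ids) pair the BERT branch expects
tokenizer = Tokenizer(tb.token_dict)
text, label = train_data[0]
token_ids, segment_ids = tokenizer.encode(text)
# Batches are then padded to a common length (e.g. with seq_padding) before model.fit / train_on_batch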


#%%
@@ -348,4 +348,4 @@ def _generator():
print('Fill with: ', list(map(lambda x: token_dict_rev[x], predicts[0][1:3])))
# Fill with: ['数', '学']
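These two lines are the tail of keras-bert's masked-language-model demo; the code that builds predicts and token_dict_rev sits outside this hunk. A sketch of the typical setup, assumed from the keras-bert README rather than taken from this file (config_path, checkpoint_path, token_dict, and tokenizer stand in for values defined elsewhere in text_bert.py):

import numpy as np
from keras_bert import Tokenizer, load_trained_model_from_checkpoint

# Load the checkpoint with the masked-LM head attached
mlm_model = load_trained_model_from_checkpoint(config_path, checkpoint_path, training=True, seq_len=None)
token_dict_rev = {v: k for k, v in token_dict.items()}
tokenizer = Tokenizer(token_dict)

tokens = tokenizer.tokenize(u'数学是利用符号语言研究数量、结构、变化以及空间等概念的一门学科')  # illustrative sentence
tokens[1] = tokens[2] = '[MASK]'  # mask the first two characters ('数', '学')
indices = np.array([[token_dict[t] for t in tokens]])
segments = np.array([[0] * len(tokens)])
masks = np.array([[0, 1, 1] + [0] * (len(tokens) - 3)])

predicts = mlm_model.predict([indices, segments, masks])[0].argmax(axis=-1).tolist()
# Mapping predicts[0][1:3] through token_dict_rev recovers the masked tokens, as printed above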

# %%
