Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
920232796 committed Jan 13, 2021
1 parent 9fc9f14 commit 5b30635
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 10 deletions.
Binary file modified .DS_Store
Binary file not shown.
10 changes: 4 additions & 6 deletions bert_seq2seq/model/roberta_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,10 @@ def __init__(self, config):
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)

def forward(self, input_ids=None, token_type_ids=None, position_ids=None):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]

def forward(self, input_ids, token_type_ids=None, position_ids=None):

input_shape = input_ids.size()

seq_length = input_shape[1]
device = input_ids.device
if position_ids is None:
Expand Down
8 changes: 4 additions & 4 deletions test/auto_title_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
## 加载训练的模型参数~
load_recent_model(bert_model, recent_model_path=auto_title_model, device=device)

test_data = ["针对央视3·15晚会曝光的电信行业乱象,工信部在公告中表示将严查央视3·15晚会曝光通信违规违法行为,工信部称已约谈三大运营商有关负责人并连夜责成三大运营商和所在省通信管理局进行调查依法依规严肃处理",
"楚天都市报记者采访了解到,对于进口冷链食品,武汉已经采取史上最严措施,进行“红区”管理,严格执行证明查验制度,确保冷冻冷藏肉等冻品的安全。",
"新华社受权于18日全文播发修改后的《中华人民共和国立法法》修改后的立法法分为“总则”“法律”“行政法规”“地方性法规自治条例和单行条例规章”“适用与备案审查”“附则”等6章共计105条"]

# test_data = ["针对央视3·15晚会曝光的电信行业乱象,工信部在公告中表示将严查央视3·15晚会曝光通信违规违法行为,工信部称已约谈三大运营商有关负责人并连夜责成三大运营商和所在省通信管理局进行调查依法依规严肃处理",
# "楚天都市报记者采访了解到,对于进口冷链食品,武汉已经采取史上最严措施,进行“红区”管理,严格执行证明查验制度,确保冷冻冷藏肉等冻品的安全。",
# "新华社受权于18日全文播发修改后的《中华人民共和国立法法》修改后的立法法分为“总则”“法律”“行政法规”“地方性法规自治条例和单行条例规章”“适用与备案审查”“附则”等6章共计105条"]
test_data = ["重庆潼南县的8位村民一年前在河道里挖出一根30米长乌木,卖得19.6万元,大家分了这笔数额不小的意外之财。如今,当地财政局将他们起诉到法院,称乌木在河道中发现,其所有权应归国家。法院一审二审都判决村民们还钱"]
for text in test_data:
with torch.no_grad():
print(bert_model.generate(text, beam_size=3, device=device))
Expand Down
5 changes: 5 additions & 0 deletions test/test1..py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

a = 10
c = 5
b = a + c
print(b)

0 comments on commit 5b30635

Please sign in to comment.