Skip to content

Commit

Permalink
torch.no_grad vs. requires_grad
Browse files Browse the repository at this point in the history
  • Loading branch information
chunhuizhang committed Jul 3, 2022
1 parent 8c1d502 commit 8131406
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 0 deletions.
20 changes: 20 additions & 0 deletions fine_tune/bert/tutorials/03_bert_input_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@

from transformers import BertTokenizer, BertModel
from transformers.models.bert import BertModel
import torch
from torch import nn


model_name = 'bert-base-uncased'

tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name, output_hidden_states=True)

test_sent = 'this is a test sentence'

model_input = tokenizer(test_sent, return_tensors='pt')


model.eval()
with torch.no_grad():
output = model(**model_input)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 2 additions & 0 deletions fine_tune/bert_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
model = BertModel.from_pretrained(model_name)
cls_model = BertForSequenceClassification.from_pretrained(model_name)



total_params = 0
total_learnable_params = 0
total_embedding_params = 0
Expand Down
13 changes: 13 additions & 0 deletions fine_tune/input_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@

from transformers import BertModel, BertTokenizer

model_name = 'bert-base-uncased'

tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)

raw_sentences = ['Tom likes cats', 'Liz likes dogs']

inputs = tokenizer.encode_plus(raw_sentences[0], raw_sentences[1], return_tensors='pt')
# inputs = tokenizer('Hello, my dog is cute', return_tensors='pt')
model(**inputs)
11 changes: 11 additions & 0 deletions myweb/demo/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from flask import Flask, render_template

app = Flask(__name__)

@app.route('/')
def index():
return render_template('index.html')

if __name__ == '__main__':
app.run()

4 changes: 4 additions & 0 deletions myweb/demo/templates/index.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

<html>
hello world!
</html>

0 comments on commit 8131406

Please sign in to comment.