Skip to content

Commit

Permalink
Additional preparation for mGENRE HF release
Browse files Browse the repository at this point in the history
  • Loading branch information
nicola-decao committed Jun 9, 2022
1 parent 1fa1623 commit 9a28197
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 7 deletions.
10 changes: 7 additions & 3 deletions genre/fairseq_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
logger = logging.getLogger(__name__)


class GENREHubInterface(BARTHubInterface):
class _GENREHubInterface:
def sample(
self,
sentences: List[str],
Expand Down Expand Up @@ -71,6 +71,11 @@ def encode(self, sentence) -> torch.LongTensor:
else:
return tokens

class GENREHubInterface(_GENREHubInterface, BARTHubInterface):
    """GENRE hub interface.

    Adds no members of its own; it only combines ``_GENREHubInterface``
    (listed first, so its methods win in the MRO) with ``BARTHubInterface``.
    """

class mGENREHubInterface(_GENREHubInterface, BARTHubInterface):
    """mGENRE hub interface.

    Adds no members of its own; it only combines ``_GENREHubInterface``
    (listed first, so its methods win in the MRO) with ``BARTHubInterface``.
    """

class GENRE(BARTModel):
@classmethod
Expand All @@ -95,7 +100,6 @@ def from_pretrained(
)
return GENREHubInterface(x["args"], x["task"], x["models"][0])


class mGENRE(BARTModel):
@classmethod
def from_pretrained(
Expand All @@ -120,4 +124,4 @@ def from_pretrained(
sentencepiece_model=os.path.join(model_name_or_path, sentencepiece_model),
**kwargs,
)
return GENREHubInterface(x["args"], x["task"], x["models"][0])
return mGENREHubInterface(x["args"], x["task"], x["models"][0])
14 changes: 10 additions & 4 deletions genre/hf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,18 @@

import torch
from transformers import (
BartForConditionalGeneration,
BartTokenizer,
BartForConditionalGeneration,
XLMRobertaTokenizer,
MBartForConditionalGeneration,
)

from genre.utils import chunk_it, post_process_wikidata

logger = logging.getLogger(__name__)


class GENREHubInterface(BartForConditionalGeneration):
class _GENREHubInterface:
def sample(
self,
sentences: List[str],
Expand Down Expand Up @@ -70,6 +71,11 @@ def sample(
def encode(self, sentence):
return self.tokenizer.encode(sentence, return_tensors="pt")[0]

class GENREHubInterface(_GENREHubInterface, BartForConditionalGeneration):
    """GENRE interface on top of ``BartForConditionalGeneration``.

    Adds no members of its own; ``_GENREHubInterface`` is listed first so its
    methods take precedence over the model class in the MRO.
    """

class mGENREHubInterface(_GENREHubInterface, MBartForConditionalGeneration):
    """mGENRE interface on top of ``MBartForConditionalGeneration``.

    Adds no members of its own; ``_GENREHubInterface`` is listed first so its
    methods take precedence over the model class in the MRO.
    """

class GENRE(BartForConditionalGeneration):
@classmethod
Expand All @@ -79,9 +85,9 @@ def from_pretrained(cls, model_name_or_path):
return model


class mGENRE(BartForConditionalGeneration):
class mGENRE(MBartForConditionalGeneration):
@classmethod
def from_pretrained(cls, model_name_or_path):
model = GENREHubInterface.from_pretrained(model_name_or_path)
model = mGENREHubInterface.from_pretrained(model_name_or_path)
model.tokenizer = XLMRobertaTokenizer.from_pretrained(model_name_or_path)
return model
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ kilt
fairseq
transformers
bs4
marisa_trie
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import os
import torch
from genre.fairseq_model import GENRE, mGENRE
from transformers import (
BartConfig,
BartTokenizer,
BartForConditionalGeneration,
TFBartForConditionalGeneration,
MBartConfig,
XLMRobertaTokenizer,
MBartForConditionalGeneration,
TFMBartForConditionalGeneration,
load_pytorch_model_in_tf2_model,
)


def remove_ignore_keys_(state_dict):
    """Delete a fixed set of keys from ``state_dict`` in place.

    Keys that are not present are skipped silently. Nothing is returned;
    the mapping passed in is mutated.
    """
    for unwanted in (
        "encoder.version",
        "decoder.version",
        "model.encoder.version",
        "model.decoder.version",
        "_float_tensor",
        "decoder.output_projection.weight",
    ):
        if unwanted in state_dict:
            del state_dict[unwanted]


def make_linear_from_emb(emb):
    """Build a bias-free linear layer that reuses an embedding's weights.

    Args:
        emb: Module exposing a 2-D ``weight`` of shape
            ``(vocab_size, emb_size)``, e.g. ``torch.nn.Embedding``.

    Returns:
        A ``torch.nn.Linear`` without bias whose ``weight.data`` is the very
        same tensor as ``emb.weight.data`` (no copy is made), mapping
        ``emb_size`` inputs to ``vocab_size`` outputs.
    """
    vocab_size, emb_size = emb.weight.shape
    # nn.Linear takes (in_features, out_features) and stores its weight as
    # (out_features, in_features). The layer maps emb_size -> vocab_size, so
    # the sizes are passed in that order. Passing them swapped still works
    # at forward time (the weight is overwritten below) but leaves
    # in_features / out_features reporting the wrong dimensions.
    lin_layer = torch.nn.Linear(emb_size, vocab_size, bias=False)
    lin_layer.weight.data = emb.weight.data
    return lin_layer


# Load GENRE
# (Alternative setup, left commented out: it points the same variables at the
# English GENRE checkpoint with a Bart config/tokenizer instead of mGENRE.)

# fairseq_path = "../models/fairseq_entity_disambiguation_aidayago"
# hf_path = "../models/hf_entity_disambiguation_aidayago"

# fairseq_model = GENRE.from_pretrained(fairseq_path).eval()
# config = BartConfig(vocab_size=50264)
# hf_model = BartForConditionalGeneration(config).eval()
# hf_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")

# Load mGENRE

# Input directory holding the fairseq checkpoint and output directory for the
# converted Hugging Face files; both are relative to the working directory.
fairseq_path = "../models/fairseq_multilingual_entity_disambiguation"
hf_path = "../models/hf_multilingual_entity_disambiguation"

# Load the fairseq model and construct a fresh HF MBart model from a config
# (not from pretrained weights); both are switched to eval mode.
fairseq_model = mGENRE.from_pretrained(fairseq_path).eval()
# NOTE(review): vocab_size=256001 presumably matches the 256000-piece
# sentencepiece model plus one extra symbol -- confirm against the checkpoint.
config = MBartConfig(vocab_size=256001, scale_embedding=True)
hf_model = MBartForConditionalGeneration(config).eval()
# The tokenizer is built directly from the sentencepiece file that sits in
# the fairseq checkpoint directory.
hf_tokenizer = XLMRobertaTokenizer(os.path.join(fairseq_path, "spm_256000.model"))

# Convert model

# Take the fairseq weights, drop the keys listed in remove_ignore_keys_, and
# expose the decoder embedding matrix under the extra key "shared.weight"
# before loading everything into the HF model.
state_dict = fairseq_model.model.state_dict()
remove_ignore_keys_(state_dict)
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
hf_model.model.load_state_dict(state_dict)
# Replace the LM head with a linear layer that reuses the shared embedding
# weights.
hf_model.lm_head = make_linear_from_emb(hf_model.model.shared)

# Save

# Write the tokenizer and the PyTorch model to hf_path.
hf_tokenizer.save_pretrained(hf_path)
hf_model.save_pretrained(hf_path)

# Convert TF GENRE
# (Commented-out counterpart of the TF conversion below for the Bart config.)

# hf_model = load_pytorch_model_in_tf2_model(
# TFBartForConditionalGeneration(
# config
# ),
# hf_model,
# )

# Convert TF mGENRE

# hf_model is rebound here: from this point on it refers to the TF model that
# was populated from the PyTorch model converted above.
hf_model = load_pytorch_model_in_tf2_model(
TFMBartForConditionalGeneration(config),
hf_model,
)

# Save

# Saves the TF model into the same hf_path directory used for the PyTorch
# model and tokenizer above.
hf_model.save_pretrained(hf_path)

0 comments on commit 9a28197

Please sign in to comment.