Commit: add environment

rubycheen committed Mar 21, 2022
1 parent 5c8f56e commit 6acb61c
Showing 4 changed files with 434 additions and 0 deletions.
6 changes: 6 additions & 0 deletions hw1/r10946029/preprocess.sh
@@ -0,0 +1,6 @@
if [ ! -f glove.840B.300d.txt ]; then
    wget http://nlp.stanford.edu/data/glove.840B.300d.zip -O glove.840B.300d.zip
    unzip glove.840B.300d.zip
fi
python preprocess_intent.py
python preprocess_slot.py
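
The GloVe archive is a multi-gigabyte download, so it can be worth confirming that the unzipped file parses before running both preprocessors. A minimal sanity check, not part of this commit, assuming glove.840B.300d.txt sits in the working directory:

# Hypothetical sanity check (not in this commit): read one line of the GloVe
# file and report its embedding dimension, which should be 300 for 840B.300d.
with open("glove.840B.300d.txt") as fp:
    cols = fp.readline().rstrip().split(" ")
    print(f"first token: {cols[0]!r}, dimension: {len(cols) - 1}")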
129 changes: 129 additions & 0 deletions hw1/r10946029/preprocess_intent.py
@@ -0,0 +1,129 @@
import json
import logging
import pickle
import re
from argparse import ArgumentParser, Namespace
from collections import Counter
from pathlib import Path
from random import random, seed
from typing import List, Dict

import torch
from tqdm.auto import tqdm

from utils import Vocab

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)


def build_vocab(
    words: Counter, vocab_size: int, output_dir: Path, glove_path: Path
) -> None:
    common_words = {w for w, _ in words.most_common(vocab_size)}
    vocab = Vocab(common_words)
    vocab_path = output_dir / "vocab.pkl"
    with open(vocab_path, "wb") as f:
        pickle.dump(vocab, f)
    logging.info(f"Vocab saved at {str(vocab_path.resolve())}")

    glove: Dict[str, List[float]] = {}
    logging.info(f"Loading glove: {str(glove_path.resolve())}")
    with open(glove_path) as fp:
        row1 = fp.readline()
        # if the first row is a "count dim" header, skip it; otherwise rewind
        if not re.match("^[0-9]+ [0-9]+$", row1):
            fp.seek(0)

        for i, line in tqdm(enumerate(fp)):
            cols = line.rstrip().split(" ")
            word = cols[0]
            vector = [float(v) for v in cols[1:]]

            # keep only vectors for words in the vocabulary
            if word not in common_words:
                continue
            glove[word] = vector
            glove_dim = len(vector)

    assert all(len(v) == glove_dim for v in glove.values())
    assert len(glove) <= vocab_size

    num_matched = sum(token in glove for token in vocab.tokens)
    logging.info(
        f"Token covered: {num_matched} / {len(vocab.tokens)} = {num_matched / len(vocab.tokens)}"
    )
    # tokens without a pretrained vector get a random vector in [-1, 1)
    embeddings: List[List[float]] = [
        glove.get(token, [random() * 2 - 1 for _ in range(glove_dim)])
        for token in vocab.tokens
    ]
    embeddings = torch.tensor(embeddings)
    embedding_path = output_dir / "embeddings.pt"
    torch.save(embeddings, str(embedding_path))
    logging.info(f"Embedding shape: {embeddings.shape}")
    logging.info(f"Embedding saved at {str(embedding_path.resolve())}")


def main(args):
    seed(args.rand_seed)

    intents = set()
    words = Counter()
    for split in ["train", "eval"]:
        dataset_path = args.data_dir / f"{split}.json"
        dataset = json.loads(dataset_path.read_text())
        logging.info(f"Dataset loaded at {str(dataset_path.resolve())}")

        intents.update({instance["intent"] for instance in dataset})
        words.update(
            [token for instance in dataset for token in instance["text"].split()]
        )

    intent2idx = {tag: i for i, tag in enumerate(intents)}
    intent_tag_path = args.output_dir / "intent2idx.json"
    intent_tag_path.write_text(json.dumps(intent2idx, indent=2))
    logging.info(f"Intent 2 index saved at {str(intent_tag_path.resolve())}")

    build_vocab(words, args.vocab_size, args.output_dir, args.glove_path)


def parse_args() -> Namespace:
    parser = ArgumentParser()
    parser.add_argument(
        "--data_dir",
        type=Path,
        help="Directory of the dataset.",
        default="./data/intent/",
    )
    parser.add_argument(
        "--glove_path",
        type=Path,
        help="Path to the GloVe embedding file.",
        default="./glove.840B.300d.txt",
    )
    parser.add_argument("--rand_seed", type=int, help="Random seed.", default=13)
    parser.add_argument(
        "--output_dir",
        type=Path,
        help="Directory to save the processed files.",
        default="./cache/intent/",
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        help="Number of tokens in the vocabulary.",
        default=10_000,
    )
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)
    main(args)
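
preprocess_intent.py leaves three artifacts under the output directory: vocab.pkl, embeddings.pt, and intent2idx.json. A minimal sketch of how downstream training code might consume them, assuming the script's default paths (this loading code is not part of the commit):

import json
import pickle
from pathlib import Path

import torch

cache_dir = Path("./cache/intent/")
# unpickling the Vocab instance requires the utils module to be importable
with open(cache_dir / "vocab.pkl", "rb") as f:
    vocab = pickle.load(f)
# one 300-dimensional row per token in vocab.tokens
embeddings = torch.load(cache_dir / "embeddings.pt")
intent2idx = json.loads((cache_dir / "intent2idx.json").read_text())
print(embeddings.shape, len(intent2idx))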
72 changes: 72 additions & 0 deletions hw1/r10946029/preprocess_slot.py
@@ -0,0 +1,72 @@
import json
import logging
from argparse import ArgumentParser, Namespace
from collections import Counter
from pathlib import Path
from random import seed

from preprocess_intent import build_vocab

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)


def main(args):
    seed(args.rand_seed)

    tags = set()
    words = Counter()
    for split in ["train", "eval"]:
        dataset_path = args.data_dir / f"{split}.json"
        dataset = json.loads(dataset_path.read_text())
        logging.info(f"Dataset loaded at {str(dataset_path.resolve())}")

        tags.update({tag for instance in dataset for tag in instance["tags"]})
        words.update([token for instance in dataset for token in instance["tokens"]])

    tag2idx = {tag: i for i, tag in enumerate(tags)}
    tag_idx_path = args.output_dir / "tag2idx.json"
    tag_idx_path.write_text(json.dumps(tag2idx, indent=2))
    logging.info(f"Tag 2 index saved at {str(tag_idx_path.resolve())}")

    build_vocab(words, args.vocab_size, args.output_dir, args.glove_path)


def parse_args() -> Namespace:
    parser = ArgumentParser()
    parser.add_argument(
        "--data_dir",
        type=Path,
        help="Directory of the dataset.",
        default="./data/slot/",
    )
    parser.add_argument(
        "--glove_path",
        type=Path,
        help="Path to the GloVe embedding file.",
        default="./glove.840B.300d.txt",
    )
    parser.add_argument("--rand_seed", type=int, help="Random seed.", default=13)
    parser.add_argument(
        "--output_dir",
        type=Path,
        help="Directory to save the processed files.",
        default="./cache/slot/",
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        help="Number of tokens in the vocabulary.",
        default=10_000,
    )
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)
    main(args)
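
Both scripts import Vocab from utils, a file not shown in this diff. From its usage in build_vocab, Vocab(common_words) must accept an iterable of words and expose a tokens list covering the vocabulary (plus any special tokens). A hypothetical minimal class consistent with that usage; the actual implementation in utils.py may differ:

from typing import Iterable, List


class Vocab:
    # hypothetical sketch of the interface assumed by build_vocab
    PAD = "[PAD]"
    UNK = "[UNK]"

    def __init__(self, words: Iterable[str]) -> None:
        # special tokens first, then the vocabulary in sorted order
        self.token2idx = {
            Vocab.PAD: 0,
            Vocab.UNK: 1,
            **{token: i for i, token in enumerate(sorted(words), start=2)},
        }

    @property
    def tokens(self) -> List[str]:
        return list(self.token2idx.keys())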
