
Commit 8e267d1
✨ download large files from aws s3
bage79 committed Dec 23, 2021
1 parent ecd659f
Showing 6 changed files with 115 additions and 87 deletions.
61 changes: 23 additions & 38 deletions examples/nsmc.py
@@ -24,8 +24,6 @@
 import argparse
 import logging
 import os
-
-import wget
 import pandas as pd
 import numpy as np
 import torch
@@ -35,8 +33,9 @@
 from transformers.optimization import AdamW, get_cosine_schedule_with_warmup
 from transformers import BartForSequenceClassification
 
-
 from kobart import get_kobart_tokenizer, get_pytorch_kobart_model
+from kobart.utils.utils import download
+
 
 logger = logging.getLogger()
 logger.setLevel(logging.INFO)
@@ -46,20 +45,6 @@ class ArgsBase:
     @staticmethod
     def add_model_specific_args(parent_parser):
         parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
-        parser.add_argument(
-            "--train_file",
-            type=str,
-            default="nsmc/ratings_train.txt",
-            help="train file",
-        )
-
-        parser.add_argument(
-            "--test_file",
-            type=str,
-            default="nsmc/ratings_test.txt",
-            help="test file",
-        )
-
         parser.add_argument("--batch_size", type=int, default=128, help="")
         parser.add_argument("--max_seq_len", type=int, default=128, help="")
         return parser
@@ -102,27 +87,27 @@ def __getitem__(self, index):
 
 
 class NSMCDataModule(pl.LightningDataModule):
-    def __init__(self, train_file, test_file, max_seq_len=128, batch_size=32):
+    def __init__(self, max_seq_len=128, batch_size=32):
         super().__init__()
         self.batch_size = batch_size
         self.max_seq_len = max_seq_len
-        self.train_file_path = os.path.join(args.cachedir, train_file)
-        self.test_file_path = os.path.join(args.cachedir, test_file)
-        print("train_file_path:", self.train_file_path)
-        print("test_file_path:", self.test_file_path)
 
-        os.makedirs(os.path.dirname(self.train_file_path), exist_ok=True)
-        os.makedirs(os.path.dirname(self.test_file_path), exist_ok=True)
-        if not os.path.exists(self.train_file_path):
-            wget.download(
-                "https://www.dropbox.com/s/374ftkec978br3d/ratings_train.txt?dl=1",
-                self.train_file_path,
-            )
-        if not os.path.exists(self.test_file_path):
-            wget.download(
-                "https://www.dropbox.com/s/977gbwh542gdy94/ratings_test.txt?dl=1",
-                self.test_file_path,
-            )
+        s3_train_file = {
+            "url": "s3://skt-lsl-nlp-model/KoBART/datasets/nsmc/ratings_train.txt",
+            "chksum": None,
+        }
+        s3_test_file = {
+            "url": "s3://skt-lsl-nlp-model/KoBART/datasets/nsmc/ratings_test.txt",
+            "chksum": None,
+        }
+
+        os.makedirs(os.path.dirname(args.cachedir), exist_ok=True)
+        self.train_file_path, is_cached = download(
+            s3_train_file["url"], s3_train_file["chksum"], cachedir=args.cachedir
+        )
+        self.test_file_path, is_cached = download(
+            s3_test_file["url"], s3_test_file["chksum"], cachedir=args.cachedir
+        )
 
     @staticmethod
     def add_model_specific_args(parent_parser):
@@ -276,7 +261,9 @@ def validation_epoch_end(self, outputs):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="subtask for KoBART")
-    parser.add_argument("--cachedir", type=str, default=".cache")
+    parser.add_argument(
+        "--cachedir", type=str, default=os.path.join(os.getcwd(), ".cache")
+    )
     parser.add_argument("--subtask", type=str, default="NSMC", help="NSMC")
     parser = Classification.add_model_specific_args(parser)
     parser = ArgsBase.add_model_specific_args(parser)
@@ -294,8 +281,6 @@ def validation_epoch_end(self, outputs):
     if args.subtask == "NSMC":
         # init data
         dm = NSMCDataModule(
-            args.train_file,
-            args.test_file,
             batch_size=args.batch_size,
             max_seq_len=args.max_seq_len,
         )
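Note: after this change, the NSMC files are fetched through the shared download() helper rather than wget. For reference, a minimal sketch of the new call path, assuming the kobart package from this commit is importable and the S3 objects are publicly readable:

from kobart.utils.utils import download

# download() returns (local_file_path, is_cached).
# Note: with chksum=None the cached-file MD5 check can never match,
# so the file is re-fetched from S3 on every run.
train_path, is_cached = download(
    "s3://skt-lsl-nlp-model/KoBART/datasets/nsmc/ratings_train.txt",
    chksum=None,
    cachedir=".cache",
)
print(train_path)  # <cwd>/.cache/ratings_train.txt
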
2 changes: 1 addition & 1 deletion kobart/__init__.py
@@ -21,7 +21,7 @@
 # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
 # OR OTHER DEALINGS IN THE SOFTWARE.
 
-from kobart.utils import download
+from kobart.utils.utils import download
 from kobart.pytorch_kobart import get_pytorch_kobart_model, get_kobart_tokenizer
 
 __all__ = ("download", "get_kobart_tokenizer", "get_pytorch_kobart_model")
25 changes: 10 additions & 15 deletions kobart/pytorch_kobart.py
@@ -26,23 +26,18 @@
 from zipfile import ZipFile
 from transformers import PreTrainedTokenizerFast
 
-from kobart.utils import download as _download
-
-pytorch_kobart = {
-    "url": "https://kobert.blob.core.windows.net/models/kobart/kobart_base_cased_ff4bda5738.zip",
-    "fname": "kobart_base_cased_ff4bda5738.zip",
-    "chksum": "ff4bda5738",
-}
+from kobart.utils.utils import download
 
 
 def get_pytorch_kobart_model(ctx="cpu", cachedir=".cache"):
     # download model
-    global pytorch_kobart
-    model_info = pytorch_kobart
-    model_zip, is_cached = _download(
-        model_info["url"], model_info["fname"], model_info["chksum"], cachedir=cachedir
+    pytorch_kobart = {
+        "url": "s3://skt-lsl-nlp-model/KoBART/models/kobart_base_cased_ff4bda5738.zip",
+        "chksum": "ff4bda5738",
+    }
+    model_zip, is_cached = download(
+        pytorch_kobart["url"], pytorch_kobart["chksum"], cachedir=cachedir
     )
-    cachedir_full = os.path.expanduser(cachedir)
+    cachedir_full = os.path.join(os.getcwd(), cachedir)
     model_path = os.path.join(cachedir_full, "kobart_from_pretrained")
     if not os.path.exists(model_path) or not is_cached:
         if not is_cached:
@@ -55,10 +50,10 @@ def get_pytorch_kobart_model(ctx="cpu", cachedir=".cache"):
 def get_kobart_tokenizer(cachedir=".cache"):
     """Get KoGPT2 Tokenizer file path after downloading"""
     tokenizer = {
-        "url": "s3://skt-lsl-apne2/model/public_storage/KoBART/tokenizers/kobart_base_tokenizer_cased_cf74400bce.zip",
+        "url": "s3://skt-lsl-nlp-model/KoBART/tokenizers/kobart_base_tokenizer_cased_cf74400bce.zip",
         "chksum": "cf74400bce",
     }
-    file_path, is_cached = _download(
+    file_path, is_cached = download(
         tokenizer["url"], tokenizer["chksum"], cachedir=cachedir
     )
     cachedir_full = os.path.expanduser(cachedir)
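Caller-facing usage is unchanged; only the storage backend moves from Azure blob storage to S3. A short sketch, assuming this commit is installed and the bucket is reachable (the BartModel pattern follows the project README):

from transformers import BartModel
from kobart import get_kobart_tokenizer, get_pytorch_kobart_model

# First call downloads and unzips from S3 into <cwd>/.cache/;
# later calls reuse the cached, extracted files.
kobart_tokenizer = get_kobart_tokenizer()
model = BartModel.from_pretrained(get_pytorch_kobart_model())
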
67 changes: 67 additions & 0 deletions kobart/utils/aws_s3_downloader.py
@@ -0,0 +1,67 @@
import boto3
import os
import sys
from botocore import UNSIGNED
from botocore.client import Config


class AwsS3Downloader(object):
def __init__(
self,
aws_access_key_id=None,
aws_secret_access_key=None,
):
self.resource = boto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
).resource("s3")
self.client = boto3.client(
"s3",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
config=Config(signature_version=UNSIGNED),
)

def __split_url(self, url: str):
if url.startswith("s3://"):
url = url.replace("s3://", "")
bucket, key = url.split("/", maxsplit=1)
return bucket, key

def download(self, url: str, local_dir: str):
bucket, key = self.__split_url(url)
filename = os.path.basename(key)
file_path = os.path.join(local_dir, filename)

os.makedirs(os.path.dirname(file_path), exist_ok=True)
meta_data = self.client.head_object(Bucket=bucket, Key=key)
total_length = int(meta_data.get("ContentLength", 0))

downloaded = 0

def progress(chunk):
nonlocal downloaded
downloaded += chunk
done = int(50 * downloaded / total_length)
sys.stdout.write(
"\r{}[{}{}]".format(file_path, "█" * done, "." * (50 - done))
)
sys.stdout.flush()

try:
with open(file_path, "wb") as f:
self.client.download_fileobj(bucket, key, f, Callback=progress)
sys.stdout.write("\n")
sys.stdout.flush()
except:
raise Exception(f"downloading file is failed. {url}")
return file_path


if __name__ == "__main__":
s3 = AwsS3Downloader()

s3.download(
url="s3://skt-lsl-nlp-model/KoBART/models/kobart_base_cased_ff4bda5738.zip",
local_dir=".cache",
)
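
A note on the client configuration: Config(signature_version=UNSIGNED) makes boto3 issue anonymous, unsigned requests, so downloads need no AWS credentials as long as the objects are publicly readable. A standalone sketch for a single file, using boto3's download_file (a convenience wrapper related to the download_fileobj call above):

import boto3
from botocore import UNSIGNED
from botocore.client import Config

# Anonymous S3 client: works for public objects without credentials.
client = boto3.client("s3", config=Config(signature_version=UNSIGNED))
client.download_file(
    "skt-lsl-nlp-model",
    "KoBART/datasets/nsmc/ratings_test.txt",
    "ratings_test.txt",
)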
45 changes: 13 additions & 32 deletions kobart/utils.py → kobart/utils/utils.py
@@ -24,42 +24,23 @@
 import hashlib
 import os
 
+from kobart.utils.aws_s3_downloader import AwsS3Downloader
+
 
-tokenizer = {
-    "url": "https://kobert.blob.core.windows.net/models/kobart/kobart_base_tokenizer_cased_cf74400bce.zip",
-    "fname": "kobart_base_tokenizer_cased_cf74400bce.zip",
-    "chksum": "cf74400bce",
-}
-
-
-def download(url, filename, chksum, cachedir=".cached"):
-    f_cachedir = os.path.expanduser(cachedir)
-    os.makedirs(f_cachedir, exist_ok=True)
-    file_path = os.path.join(f_cachedir, filename)
+def download(url, chksum=None, cachedir=".cache"):
+    cachedir_full = os.path.join(os.getcwd(), cachedir)
+    os.makedirs(cachedir_full, exist_ok=True)
+    filename = os.path.basename(url)
+    file_path = os.path.join(cachedir_full, filename)
     if os.path.isfile(file_path):
         if hashlib.md5(open(file_path, "rb").read()).hexdigest()[:10] == chksum:
-            print("using cached model")
+            print(f"using cached model. {file_path}")
             return file_path, True
-    with open(file_path, "wb") as f:
-        response = requests.get(url, stream=True)
-        total = response.headers.get("content-length")
-
-        if total is None:
-            f.write(response.content)
-        else:
-            downloaded = 0
-            total = int(total)
-            for data in response.iter_content(
-                chunk_size=max(int(total / 1000), 1024 * 1024)
-            ):
-                downloaded += len(data)
-                f.write(data)
-                done = int(50 * downloaded / total)
-                sys.stdout.write("\r[{}{}{}]".format(file_path, "█" * done, "." * (50 - done)))
-                sys.stdout.flush()
-        sys.stdout.write("\n")
-    assert (
-        chksum == hashlib.md5(open(file_path, "rb").read()).hexdigest()[:10]
-    ), "corrupted file!"
+    s3 = AwsS3Downloader()
+    file_path = s3.download(url, cachedir_full)
+    if chksum:
+        assert (
+            chksum == hashlib.md5(open(file_path, "rb").read()).hexdigest()[:10]
+        ), "corrupted file!"
     return file_path, False
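
Throughout these helpers, "chksum" is the first 10 hex characters of the file's MD5 digest, which also appears as the suffix in the model file names (e.g. kobart_base_cased_ff4bda5738.zip). A small sketch of the same verification:

import hashlib

def short_md5(path: str) -> str:
    # First 10 hex chars of the MD5 digest, matching the "chksum" fields above.
    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()[:10]

# e.g. short_md5(".cache/kobart_base_cased_ff4bda5738.zip") == "ff4bda5738"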
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,5 +1,5 @@
+boto3
 pandas
 pytorch-lightning == 1.2.1
 torch == 1.7.1
 transformers == 4.3.3
-wget
