Skip to content

Commit

Permalink
Merge pull request Cornell-RelaxML#34 from Cornell-RelaxML/patchc4
Browse files: browse the repository at this point in the history
patch c4
  • Loading branch information
tsengalb99 authored Jan 10, 2024
2 parents 6a506a3 + 1284949 commit 1ac36c1
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 9 deletions.
2 changes: 1 addition & 1 deletion eval_ppl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


def main(args):
datasets = ['wikitext2', 'c4']
datasets = ['wikitext2', 'c4', 'c4_new']
model, model_str = model_from_hf_path(args.hf_path,
use_cuda_graph=not args.no_use_cuda_graph,
use_flash_attn=not args.no_use_flash_attn)
Expand Down
13 changes: 6 additions & 7 deletions lib/utils/gptq_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,11 @@ def get_ptb(nsamples, seed, seqlen, model):
def get_c4(nsamples, seed, seqlen, model):
from datasets import load_dataset
traindata = load_dataset(
'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train'
'allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train'
)
valdata = load_dataset(
'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation'
'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation'
)

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)

Expand Down Expand Up @@ -126,10 +125,10 @@ def get_ptb_new(nsamples, seed, seqlen, model):
def get_c4_new(nsamples, seed, seqlen, model):
from datasets import load_dataset
traindata = load_dataset(
'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train'
'allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train'
)
valdata = load_dataset(
'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation'
'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation'
)

from transformers import AutoTokenizer
Expand Down Expand Up @@ -184,9 +183,9 @@ def get_test_tokens(
train_samples = 0
if name == 'wikitext2':
return get_wikitext2(train_samples, seed, seqlen, model)[1]['input_ids']
elif name == 'ptb':
return get_ptb_new(train_samples, seed, seqlen, model)[1].input_ids
elif name == 'c4':
return get_c4(train_samples, seed, seqlen, model)[1].input_ids
elif name == 'c4_new':
return get_c4_new(train_samples, seed, seqlen, model)[1].input_ids
else:
raise Exception
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ click==8.1.7
colorama==0.4.6
cryptography
DataProperty==1.0.1
datasets==2.14.6
datasets==2.16.1
dill==0.3.7
distro==1.8.0
einops==0.7.0
Expand Down

0 comments on commit 1ac36c1

Please sign in to comment.