forked from InternLM/xtuner
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path arxiv.py
38 lines (32 loc) · 1.13 KB
/
arxiv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmchat.datasets import process_hf_dataset
from mmchat.datasets.collate_fns import default_collate_fn
from mmchat.datasets.map_fns import arxiv_map_fn
# Directory that holds the preprocessed arXiv JSON file.
data_root = './data/'

# Preparing the data file:
#   1. Download the raw metadata from
#      https://kaggle.com/datasets/Cornell-University/arxiv
#   2. Convert it with `./tools/data_preprocess/arxiv.py`, which should
#      produce the file named below (the name suggests cs.AI / cs.CL / cs.CV
#      papers -- confirm against the preprocessing script).
json_file = 'arxiv_postprocess_csAIcsCLcsCV_20200101.json'

# Raw dataset spec: `type` is the callable and the remaining keys are its
# kwargs. Presumably this is built as
# `load_dataset('json', data_files={'train': <path>})`; the single JSON file
# is exposed as the `train` split.
arxiv = {
    'type': load_dataset,
    'path': 'json',
    'data_files': {'train': f'{data_root}{json_file}'},
}
# Processed training dataset built on top of the raw `arxiv` spec via
# `process_hf_dataset`, taking the `train` split and applying `arxiv_map_fn`.
arxiv_dataset = {
    'type': process_hf_dataset,
    'dataset': arxiv,
    'split': 'train',
    # Left unset here. NOTE(review): presumably supplied by the config that
    # includes this file -- confirm.
    'tokenizer': None,
    'max_length': 2048,
    'map_fn': arxiv_map_fn,
    # Original arXiv metadata columns to drop from the processed dataset
    # (assumed to be consumed by `arxiv_map_fn` -- verify against the map fn).
    'remove_columns': [
        'id',
        'submitter',
        'authors',
        'title',
        'comments',
        'journal-ref',
        'doi',
        'report-no',
        'categories',
        'license',
        'abstract',
        'versions',
        'update_date',
        'authors_parsed',
    ],
    # NOTE(review): presumably packs samples up to `max_length` tokens --
    # confirm in process_hf_dataset.
    'concat_to_max_length': True,
}
# Training dataloader spec:
# - batch size 1 over `arxiv_dataset`;
# - `num_workers` 0 (presumably forwarded to the PyTorch DataLoader, meaning
#   no worker subprocesses -- confirm);
# - shuffled sampling through mmengine's `DefaultSampler`;
# - batches assembled by `default_collate_fn`.
train_dataloader = {
    'batch_size': 1,
    'num_workers': 0,
    'dataset': arxiv_dataset,
    'sampler': {'type': DefaultSampler, 'shuffle': True},
    'collate_fn': {'type': default_collate_fn},
}