-
Notifications
You must be signed in to change notification settings - Fork 47
/
Copy pathmulti_choice_evaluator.py
62 lines (53 loc) · 2 KB
/
multi_choice_evaluator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# -*- coding: utf-8 -*-
# @Time : 2023/3/5 11:10 PM
# @Author : JianingWang
# @File : multi_choice_evaluator.py
import json
import os.path
import math
import torch
import numpy as np
from typing import Dict, Union, Any, Optional, Callable, List, Tuple, Iterator
import datasets
from datasets import Dataset
from config import DataTrainingArguments, TrainingArguments, ModelArguments
from hugnlp_trainer import HugTrainer
from processors.ProcessorBase import DataProcessor
from evaluators.EvaluatorBase import NO_GENERATE, DO_GENERATE, Evaluator, ClassificationEvaluator
from metrics.classification_metric import ClassificationMetric
from tools.runner_utils.log_util import logging
logger = logging.getLogger(__name__)
class MultiChoiceEvaluator(ClassificationEvaluator):
    """Evaluator for the task of multi-choice classification with Masked PLMs.

    All evaluation logic is inherited from ``ClassificationEvaluator``; this
    subclass only tags the instance with the ``NO_GENERATE`` paradigm.

    Args:
        model_args: Model configuration, forwarded to the base class.
        data_args: Data/task configuration, forwarded to the base class.
        training_args: Training configuration, forwarded to the base class.
        processor: Task data processor, forwarded to the base class.
        model: The model under evaluation, forwarded to the base class.
        trainer: Optional trainer, forwarded to the base class.
        eval_dataset: Optional evaluation dataset, forwarded to the base class.
        test_dataset: Optional test dataset, forwarded to the base class.
    """

    def __init__(
        self,
        model_args: ModelArguments,
        data_args: DataTrainingArguments,
        training_args: TrainingArguments,
        processor: DataProcessor,
        model: torch.nn.Module,
        trainer: Optional[HugTrainer] = None,
        eval_dataset: Optional[Dataset] = None,
        test_dataset: Optional[Dataset] = None,
    ) -> None:
        super().__init__(
            model_args,
            data_args,
            training_args,
            processor,
            model,
            trainer,
            eval_dataset,
            test_dataset,
        )
        # Assigned after the base initializer so this value overrides any
        # ``paradigm`` the base class may have set.
        self.paradigm = NO_GENERATE
class CausalMultiChoiceEvaluator(Evaluator):
    """Evaluator for the task of multi-choice classification with Causal PLMs.

    All evaluation logic is inherited from ``Evaluator``; this subclass only
    tags the instance with the ``DO_GENERATE`` paradigm.

    NOTE(review): this derives from ``Evaluator`` rather than
    ``ClassificationEvaluator`` (unlike the masked-PLM variant in this
    module); confirm that is intended.

    Args:
        model_args: Model configuration, forwarded to the base class.
        data_args: Data/task configuration, forwarded to the base class.
        training_args: Training configuration, forwarded to the base class.
        processor: Task data processor, forwarded to the base class.
        model: The model under evaluation, forwarded to the base class.
        trainer: Optional trainer, forwarded to the base class.
        eval_dataset: Optional evaluation dataset, forwarded to the base class.
        test_dataset: Optional test dataset, forwarded to the base class.
    """

    def __init__(
        self,
        model_args: ModelArguments,
        data_args: DataTrainingArguments,
        training_args: TrainingArguments,
        processor: DataProcessor,
        model: torch.nn.Module,
        trainer: Optional[HugTrainer] = None,
        eval_dataset: Optional[Dataset] = None,
        test_dataset: Optional[Dataset] = None,
    ) -> None:
        super().__init__(
            model_args,
            data_args,
            training_args,
            processor,
            model,
            trainer,
            eval_dataset,
            test_dataset,
        )
        # Assigned after the base initializer so this value overrides any
        # ``paradigm`` the base class may have set.
        self.paradigm = DO_GENERATE