Skip to content

Commit

Permalink
add device_map
Browse files Browse the repository at this point in the history
  • Loading branch information
mit1280 committed Jun 2, 2024
1 parent 363ae46 commit f3c70c9
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions unlimited_classifier/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ def initialize_model(self, model: str) -> PreTrainedModel:
if config.is_encoder_decoder: # type: ignore
return ( # type: ignore
AutoModelForSeq2SeqLM # type: ignore
.from_pretrained(model, device_map=self.device, quantization_config=self.quantization_config)
.from_pretrained(model, device_map=self.device_map, quantization_config=self.quantization_config)
)
else:
try:
return ( # type: ignore
AutoModelForCausalLM # type: ignore
.from_pretrained(model, device_map=self.device, quantization_config=self.quantization_config)
.from_pretrained(model, device_map=self.device_map, quantization_config=self.quantization_config)
)
except:
raise ValueError("Expected generative model.")
Expand All @@ -91,6 +91,7 @@ def __init__(
],
prompt: str = "Classifity the following text:\n {}\nLabel:",
device: str="cpu",
device_map:str="cpu",
quantization_config=None,
num_beams: int=5,
max_new_tokens: int=512,
Expand Down Expand Up @@ -134,6 +135,7 @@ def __init__(
raise ValueError("No labels provided.")

self.device = device
self.device_map = device_map
self.num_beams = min(num_beams, len(labels))
self.max_new_tokens = max_new_tokens
self.scorer = scorer
Expand Down

0 comments on commit f3c70c9

Please sign in to comment.