Skip to content

Commit

Permalink
Adding OneFormer model to DeepChem (deepchem#4146)
Browse files Browse the repository at this point in the history
* Added OneFormer model + tests + init + docs

* Modified docs

* Added model to rst

* Removed dummy tester file

* Added save/reload + overfit tests

* DQC fix, attempt 1

* DQC fix, attempt 2

* Made changes to tests

* Modified mini config

* Updated tests
  • Loading branch information
aaronrockmenezes authored Nov 6, 2024
1 parent 11d3932 commit b30a523
Show file tree
Hide file tree
Showing 6 changed files with 637 additions and 2 deletions.
1 change: 1 addition & 0 deletions deepchem/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from deepchem.models.torch_models import HuggingFaceModel
from deepchem.models.torch_models import Chemberta
from deepchem.models.torch_models import MoLFormer
from deepchem.models.torch_models import OneFormer
except ImportError as e:
logger.warning(e)

Expand Down
1 change: 1 addition & 0 deletions deepchem/models/torch_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from deepchem.models.torch_models.chemberta import Chemberta
from deepchem.models.torch_models.molformer import MoLFormer
from deepchem.models.torch_models.prot_bert import ProtBERT
from deepchem.models.torch_models.oneformer import OneFormer

except ModuleNotFoundError as e:
logger.warning(f'Skipped loading modules with transformers dependency. {e}')
5 changes: 4 additions & 1 deletion deepchem/models/torch_models/hf_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from deepchem.trans import Transformer, undo_transforms
from deepchem.utils.typing import LossFn, OneOrMany
from transformers.data.data_collator import DataCollatorForLanguageModeling
from transformers.models.auto import AutoModel, AutoModelForSequenceClassification, AutoModelForMaskedLM
from transformers.models.auto import AutoModel, AutoModelForSequenceClassification, AutoModelForMaskedLM, AutoModelForUniversalSegmentation

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -216,6 +216,9 @@ def load_from_pretrained( # type: ignore
elif self.task in ['mtr', 'regression', 'classification']:
self.model = AutoModelForSequenceClassification.from_pretrained(
model_dir)
elif self.task == "universal_segmentation":
self.model = AutoModelForUniversalSegmentation.from_pretrained(
model_dir)
else:
self.model = AutoModel.from_pretrained(model_dir)
elif not from_hf_checkpoint:
Expand Down
Loading

0 comments on commit b30a523

Please sign in to comment.