Skip to content

Commit

Permalink
before pull
Browse files Browse the repository at this point in the history
  • Loading branch information
Oceanusity committed Jan 29, 2024
1 parent e7ba030 commit b8d909b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@

.DS_Store
OpenDFT/QHBench/QH9/test_codes/*
**/*.pt
**/.idea
7 changes: 3 additions & 4 deletions OpenDFT/QHNet/train_wH.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torchvision.transforms import Compose
from torch_geometric.loader import DataLoader

from ori_dataset import MD17_DFT, random_split, recorder_pos, get_mask
from ori_dataset import MD17_DFT, random_split, get_mask
from torch_ema import ExponentialMovingAverage
from transformers import get_polynomial_decay_schedule_with_warmup, get_cosine_schedule_with_warmup
logger = logging.getLogger()
Expand Down Expand Up @@ -95,9 +95,8 @@ def main(conf):
dataset = MD17_DFT(
os.path.join('/data/haiyang/QC_matrix/equiwave', 'dataset'),
name=conf.dataset.dataset_name,
transform=Compose([
recorder_pos,
get_mask]))
transform=get_mask
)

train_dataset, valid_dataset, test_dataset = \
random_split(dataset,
Expand Down

0 comments on commit b8d909b

Please sign in to comment.