Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
jkwang93 committed May 20, 2021
1 parent 4b59b09 commit 18c7b8a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
8 changes: 5 additions & 3 deletions 4_train_agent_save_smiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@


def train_agent(restore_prior_from='./data/DM_middle_drd.ckpt',
restore_agent_from='./data/DM_middle_drd.ckpt',
restore_agent_from='./data/DM_middle_drd.ckpt',agent_save='./',
batch_size=128, n_steps=5000, sigma=60, save_dir='./MCMG_results/',
experience_replay=0):
voc = Vocabulary(init_from_file="data/Voc_RE1")
Expand Down Expand Up @@ -90,6 +90,8 @@ def train_agent(restore_prior_from='./data/DM_middle_drd.ckpt',
save_smiles_df = pd.DataFrame(smiles_save)
save_smiles_df.to_csv(save_dir + '_MCMG_drd.csv', index=False, header=False)
break
if step % 100 == 0 and step != 0:
torch.save(Agent.rnn.state_dict(), agent_save)

# Calculate augmented likelihood
augmented_likelihood = prior_likelihood + sigma * Variable(score)
Expand Down Expand Up @@ -151,10 +153,10 @@ def train_agent(restore_prior_from='./data/DM_middle_drd.ckpt',
parser.add_argument('--middle', action='store', dest='restore_prior_from',
default='./data/DM_middle_drd.ckpt',
help='Path to an RNN checkpoint file to use as a Prior')
parser.add_argument('--agent', action='store', dest='restore_agent_from',
parser.add_argument('--agent', action='store', dest='agent_save',
default='./data/DM_middle_drd.ckpt',
help='Path to an RNN checkpoint file to use as a Agent.')
parser.add_argument('--save-dir', action='store', dest='save_dir',
parser.add_argument('--save-file-path', action='store', dest='save_dir',
help='Path where results and model are saved. Default is data/results/run_<datetime>.')

arg_dict = vars(parser.parse_args())
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@ The default task of our code is the DRD2 target. Users can customize their own t
python 1_train_prior_Transformer.py --train-data {your_training_data_path} --valid-data {your_valid_data_path} --save-prior-path {path_to_save_prior_model}
python 2_generator_Transformer.py --prior {piror_model_path} --save_molecules_path {save_molecules_path}
python 3_train_middle_model_dm.py --train-data {your_training_data_path} --save-middle-path {path_to_save_middle_model}
python 4_train_agent_save_smiles.py --num-steps 5000 --batch-size 128 --middle {path_of_middle_model} --agent {path_to_save_agent_model} --save-dir {save_smiles}
python 4_train_agent_save_smiles.py --num-steps 5000 --batch-size 128 --middle {path_of_middle_model} --agent {path_to_save_agent_model} ---save-file-path{save_smiles}
```

0 comments on commit 18c7b8a

Please sign in to comment.