Commit fb7628e

update variable name to avoid confusion
RuiWang1998 committed Jul 25, 2022
1 parent f6e6ab5 commit fb7628e
Showing 3 changed files with 18 additions and 15 deletions.
8 changes: 5 additions & 3 deletions main.py
@@ -70,7 +70,9 @@ def main():
             )
         ):
             logging.info(f"Predicting {i + 1}th chain in {args.input_file}")
-            logging.info(f"{len(input_data[0]['msa'][0])} residues in this chain.")
+            logging.info(
+                f"{len(input_data[0]['p_msa'][0])} residues in this chain."
+            )
             ts = time.time()
             try:
                 output = model(
@@ -88,8 +90,8 @@ def main():
             pipeline.save_pdb(
                 pos14=output['final_atom_positions'],
                 b_factors=output['confidence'] * 100,
-                sequence=input_data[0]['msa'][0],
-                mask=input_data[0]['msa_mask'][0],
+                sequence=input_data[0]['p_msa'][0],
+                mask=input_data[0]['p_msa_mask'][0],
                 save_path=save_path,
                 model=0
             )
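
For context, a minimal sketch of what the renamed keys look like to main.py, with toy tensors standing in for the real input_data that pipeline.fasta2inputs yields (one dict per recycling iteration); the logged chain length is the width of the first pseudo-MSA row.

import torch

# Hypothetical stand-in for one chain's inputs; the real list comes from
# pipeline.fasta2inputs, one dict per recycling iteration.
input_data = [{
    "p_msa": torch.randint(0, 21, (4, 8)),             # 4 rows, 8 residues
    "p_msa_mask": torch.ones(4, 8, dtype=torch.bool),
}]

# Matches the logging call above: chain length = width of the first row.
num_residues = len(input_data[0]["p_msa"][0])
print(num_residues)  # 8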
13 changes: 7 additions & 6 deletions omegafold/model.py
@@ -154,18 +154,18 @@ def forward(
         """
         # Preparation before entering the cycles
-        primary_sequence = inputs[0]['msa'][..., 0, :]
+        primary_sequence = inputs[0]['p_msa'][..., 0, :]
         max_confidence = 0
         prev_dict = self.create_initial_prev_dict(len(primary_sequence))
         final_result = None
 
         # Start cycling
         for cycle_data in inputs:
-            msa, msa_mask = cycle_data['msa'], cycle_data['msa_mask']
-            fasta, mask = msa[..., 0, :], msa_mask[..., 0, :]
+            p_msa, p_msa_mask = cycle_data['p_msa'], cycle_data['p_msa_mask']
+            fasta, mask = p_msa[..., 0, :], p_msa_mask[..., 0, :]
             node_repr, edge_repr = self.deep_sequence_embed(
-                msa,
-                msa_mask,
+                p_msa,
+                p_msa_mask,
                 fwd_cfg
             )
             prev_dict['fasta'] = fasta
@@ -175,7 +175,7 @@ def forward(
 
             result, prev_dict = self.omega_fold_cycle(
                 fasta=fasta,
-                mask=msa_mask,
+                mask=p_msa_mask,
                 node_repr=node_repr,
                 edge_repr=edge_repr,
                 fwd_cfg=fwd_cfg
@@ -208,6 +208,7 @@ def deep_sequence_embed(
         Args:
             fasta: the fasta sequence
             mask: the mask indicating the validity of the token
+            fwd_cfg:
         Returns:
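
A small standalone check of the indexing convention forward() relies on, with hypothetical toy shapes: row 0 along the pseudo-MSA axis holds the query sequence, so p_msa[..., 0, :] recovers the primary sequence and p_msa_mask[..., 0, :] its per-residue validity mask.

import torch

num_rows, num_res = 4, 8                       # hypothetical toy shapes
p_msa = torch.randint(0, 21, (num_rows, num_res))
p_msa_mask = torch.ones(num_rows, num_res, dtype=torch.bool)

fasta = p_msa[..., 0, :]                       # primary (query) sequence
mask = p_msa_mask[..., 0, :]                   # its validity mask
assert fasta.shape == (num_res,)
assert torch.equal(fasta, p_msa[0])

The rename from msa to p_msa presumably makes explicit that these rows are pseudo-MSA copies of the query built in pipeline.py below, not a real alignment.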
12 changes: 6 additions & 6 deletions pipeline.py
@@ -151,14 +151,14 @@ def fasta2inputs(
         g = torch.Generator()
         g.manual_seed(num_res)
         for _ in range(num_cycle):
-            msa = aatype[None, :].repeat(num_pseudo_msa, 1)
-            msa_mask = torch.rand(
+            p_msa = aatype[None, :].repeat(num_pseudo_msa, 1)
+            p_msa_mask = torch.rand(
                 [num_pseudo_msa, num_res], generator=g
             ).gt(mask_rate)
-            msa_mask = torch.cat((mask[None, :], msa_mask), dim=0)
-            msa = torch.cat((aatype[None, :], msa), dim=0)
-            msa[~msa_mask.bool()] = 21
-            data.append({"msa": msa, "msa_mask": msa_mask})
+            p_msa_mask = torch.cat((mask[None, :], p_msa_mask), dim=0)
+            p_msa = torch.cat((aatype[None, :], p_msa), dim=0)
+            p_msa[~p_msa_mask.bool()] = 21
+            data.append({"p_msa": p_msa, "p_msa_mask": p_msa_mask})
 
         yield (
             utils.recursive_to(data, device=device),
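
To make the loop above concrete, a self-contained rerun of the pseudo-MSA construction with assumed toy values for num_pseudo_msa, mask_rate, and the input tensors (the mask token 21 is from the diff): the query is copied num_pseudo_msa times, positions where rand() <= mask_rate are replaced by token 21, and the unmasked query is prepended as row 0.

import torch

num_res, num_pseudo_msa, mask_rate = 8, 3, 0.12    # toy values
aatype = torch.randint(0, 20, (num_res,))          # stand-in residue types
mask = torch.ones(num_res, dtype=torch.bool)       # every position valid

g = torch.Generator()
g.manual_seed(num_res)

# Copies of the query form the pseudo-MSA rows.
p_msa = aatype[None, :].repeat(num_pseudo_msa, 1)
# True where a position is kept; .gt() masks positions with rand <= mask_rate.
p_msa_mask = torch.rand([num_pseudo_msa, num_res], generator=g).gt(mask_rate)

# Prepend the unmasked query as row 0.
p_msa_mask = torch.cat((mask[None, :], p_msa_mask), dim=0)
p_msa = torch.cat((aatype[None, :], p_msa), dim=0)
p_msa[~p_msa_mask.bool()] = 21                     # 21 = mask token

print(p_msa.shape)  # torch.Size([4, 8])

Seeding the generator with num_res makes the masking pattern deterministic for a given sequence length, so repeated runs on the same chain produce the same pseudo-MSA.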
