Skip to content

Commit

Permalink
added option to select only some MOs when doing population analysis MO by MO
Browse files Browse the repository at this point in the history
  • Loading branch information
ecignoni committed May 30, 2023
1 parent a361376 commit e48c081
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
13 changes: 13 additions & 0 deletions halex/models/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def _loss_eigenvalues_lowdinq_vectorized(
weight_eigvals: float = 1.0,
weight_lowdinq: float = 1.0,
weight_regloss: float = 1.0,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
r"""combined loss on MO energies and Lowdin charges
Expand Down Expand Up @@ -258,6 +259,7 @@ def loss_fn(
weight_eigvals=1.5e6,
weight_lowdinq=1e6,
weight_regloss=1.0,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return _loss_eigenvalues_lowdinq_vectorized(
pred_blocks=pred_blocks,
Expand All @@ -271,6 +273,7 @@ def loss_fn(
weight_eigvals=weight_eigvals,
weight_lowdinq=weight_lowdinq,
weight_regloss=weight_regloss,
**kwargs,
)

def _train_step(
Expand All @@ -294,6 +297,7 @@ def _train_step(
orbs=train_dataset.orbs,
ao_labels=train_dataset.ao_labels,
nelec_dict=train_dataset.nelec_dict,
mo_indices=train_dataset.lowdin_mo_indices,
**loss_kwargs,
)

Expand Down Expand Up @@ -332,6 +336,7 @@ def _validation_step(
orbs=valid_dataset.orbs,
ao_labels=valid_dataset.ao_labels,
nelec_dict=valid_dataset.nelec_dict,
mo_indices=valid_dataset.lowdin_mo_indices,
**loss_kwargs,
)

Expand Down Expand Up @@ -403,6 +408,7 @@ def loss_fn(
orbs: Dict[int, List],
ao_labels: List[int],
nelec_dict: Dict[str, float],
mo_indices=None,
weight_eigvals=1.5e6,
weight_lowdinq=1e6,
weight_lowdinq_tot=1e6,
Expand All @@ -416,6 +422,7 @@ def loss_fn(
orbs=orbs,
ao_labels=ao_labels,
nelec_dict=nelec_dict,
mo_indices=mo_indices,
regloss=self.regloss_,
weight_eigvals=weight_eigvals,
weight_lowdinq=weight_lowdinq,
Expand Down Expand Up @@ -473,6 +480,7 @@ def loss_fn(
weight_eigvals=1.5e6,
weight_lowdinq=1e6,
weight_regloss=1.0,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return _loss_eigenvalues_lowdinq_vectorized(
pred_blocks=pred_blocks,
Expand All @@ -486,6 +494,7 @@ def loss_fn(
weight_eigvals=weight_eigvals,
weight_lowdinq=weight_lowdinq,
weight_regloss=weight_regloss,
**kwargs,
)

def _train_step(
Expand Down Expand Up @@ -513,6 +522,7 @@ def _train_step(
orbs=train_dataset.orbs,
ao_labels=train_dataset.ao_labels,
nelec_dict=train_dataset.nelec_dict,
mo_indices=train_dataset.lowdin_mo_indices,
**loss_kwargs,
)

Expand Down Expand Up @@ -550,6 +560,7 @@ def _validation_step(
orbs=valid_dataset.orbs,
ao_labels=valid_dataset.ao_labels,
nelec_dict=valid_dataset.nelec_dict,
mo_indices=valid_dataset.lowdin_mo_indices,
**loss_kwargs,
)

Expand Down Expand Up @@ -622,6 +633,7 @@ def loss_fn(
orbs: Dict[int, List[Tuple[int, int, int]]],
ao_labels: List[List[Any]],
nelec_dict: Dict[str, float],
mo_indices=None,
weight_eigvals=1.5e6,
weight_lowdinq=1e6,
weight_lowdinq_tot=1e6,
Expand All @@ -635,6 +647,7 @@ def loss_fn(
orbs=orbs,
ao_labels=ao_labels,
nelec_dict=nelec_dict,
mo_indices=mo_indices,
regloss=self.regloss_,
weight_eigvals=weight_eigvals,
weight_lowdinq=weight_lowdinq,
Expand Down
11 changes: 9 additions & 2 deletions halex/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def batched_dataset_for_a_single_molecule(
lowdin_charges_by_MO: bool = False,
core_feats: List[TensorMap] = None,
mo_indices=None,
lowdin_mo_indices=None,
) -> BatchedMemoryDataset:
"""
Create a BatchedMemoryDataset (which is what our models expect)
Expand Down Expand Up @@ -179,6 +180,7 @@ def batched_dataset_for_a_single_molecule(
orbs=orbs,
nelec_dict=nelec_dict,
batch_size=batch_size,
lowdin_mo_indices=lowdin_mo_indices,
)
else:
return BatchedMemoryDataset(
Expand All @@ -192,6 +194,7 @@ def batched_dataset_for_a_single_molecule(
orbs=orbs,
nelec_dict=nelec_dict,
batch_size=batch_size,
lowdin_mo_indices=lowdin_mo_indices,
)


Expand All @@ -216,6 +219,7 @@ def coupled_fock_matrix_from_multiple_molecules(

def indices_from_MOM(cross_ovlp_paths, scf_datasets):
indices = {}
projections = []
for path, (mol, (sb, bb)) in zip(cross_ovlp_paths, scf_datasets.items()):
cross_ovlps = load_cross_ovlps(
path,
Expand All @@ -227,6 +231,7 @@ def indices_from_MOM(cross_ovlp_paths, scf_datasets):
c_sb = unorthogonalize_coeff(sb.ovlps, sb.mo_coeff_orth)
c_bb = unorthogonalize_coeff(bb.ovlps, bb.mo_coeff_orth)
proj = mom_orbital_projection(cross_ovlps, c_sb, c_bb, which="2over1")
projections.append(proj)
nocc = sum(sb.mo_occ == 2).item()
nvir = sum(sb.mo_occ == 0).item()
mo_vir_idx = indices_highest_orbital_projection(proj, n=nvir, skip_n=nocc)
Expand All @@ -235,11 +240,12 @@ def indices_from_MOM(cross_ovlp_paths, scf_datasets):
)
selected = torch.column_stack([mo_occ, mo_vir_idx])
indices[mol] = selected
return indices
return indices, projections


def indices_from_PMOM(cross_ovlp_paths, scf_datasets):
indices = {}
projections = []
for path, (mol, (sb, bb)) in zip(cross_ovlp_paths, scf_datasets.items()):
cross_ovlps = load_cross_ovlps(
path,
Expand All @@ -251,6 +257,7 @@ def indices_from_PMOM(cross_ovlp_paths, scf_datasets):
c_sb = unorthogonalize_coeff(sb.ovlps, sb.mo_coeff_orth)
c_bb = unorthogonalize_coeff(bb.ovlps, bb.mo_coeff_orth)
proj = pmom_orbital_projection(cross_ovlps, c_sb, c_bb, which="2over1")
projections.append(proj)
nocc = sum(sb.mo_occ == 2).item()
nvir = sum(sb.mo_occ == 0).item()
mo_vir_idx = indices_highest_orbital_projection(proj, n=nvir, skip_n=nocc)
Expand All @@ -259,4 +266,4 @@ def indices_from_PMOM(cross_ovlp_paths, scf_datasets):
)
selected = torch.column_stack([mo_occ, mo_vir_idx])
indices[mol] = selected
return indices
return indices, projections

0 comments on commit e48c081

Please sign in to comment.