Skip to content

Commit

Permalink
added possibility to exclude the 1s orbitals of heavy atoms from the …
Browse files Browse the repository at this point in the history
…basis
  • Loading branch information
ecignoni committed Jun 5, 2023
1 parent 3efb10c commit 0b79126
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 7 deletions.
60 changes: 53 additions & 7 deletions halex/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def batched_dataset_for_a_single_molecule(
core_feats: List[TensorMap] = None,
mo_indices=None,
lowdin_mo_indices=None,
ignore_heavy_1s: bool = False,
) -> BatchedMemoryDataset:
"""
Create a BatchedMemoryDataset (which is what our models expect)
Expand All @@ -149,25 +150,38 @@ def batched_dataset_for_a_single_molecule(

frames = small_basis.frames

# get the number of core elements as the number of atoms that are
# not hydrogens. Choose on the basis of the first frame
if ignore_heavy_1s:
ncore = _number_of_heavy_elements(big_basis.frames[0])
else:
ncore = 0

# truncate the big basis MO energies.
# If indices are present, use them
mo_energy = big_basis.mo_energy[:, ncore:]

if mo_indices is None:
mo_energy = big_basis.mo_energy[:, : small_basis.mo_energy.shape[1]]
mo_energy = mo_energy[:, : small_basis.mo_energy.shape[1] - ncore]
else:
mo_energy = torch.take(big_basis.mo_energy, mo_indices)
mo_energy = torch.take(mo_energy, mo_indices)

# no need to truncate here as they refer to _occupied_ MOs
lowdin_charges = (
big_basis.lowdin_charges_byMO
big_basis.lowdin_charges_byMO[:, ncore:]
if lowdin_charges_by_MO
else big_basis.lowdin_charges
)

# orbitals in the small basis (because we predict a small basis Fock)
orbs = small_basis.orbs
if ignore_heavy_1s:
orbs = _drop_heavy_1s_from_orbs(orbs)

# ao labels in the small basis
ao_labels = small_basis.ao_labels
if ignore_heavy_1s:
ao_labels = _drop_heavy_1s_from_ao_labels(ao_labels)

if core_feats is None:
return BatchedMemoryDataset(
Expand Down Expand Up @@ -198,6 +212,36 @@ def batched_dataset_for_a_single_molecule(
)


def _number_of_heavy_elements(frame):
    """
    Count the atoms of ``frame`` that are not hydrogens, i.e. the
    entries of ``frame.numbers`` that differ from 1.

    NOTE(review): assumes ``frame.numbers`` is an array that supports
    elementwise comparison (as for an ASE ``Atoms`` object) - confirm.
    """
    # boolean mask, True where the atom is not a hydrogen
    not_hydrogen = frame.numbers != 1
    return sum(not_hydrogen)


def _drop_heavy_1s_from_orbs(orbs):
    """
    Build a new orbital dictionary in which the (1, 0, 0) entry has been
    removed for every key different from 1.

    Parameters
    ----------
    orbs:
        mapping from a key (compared against 1, presumably the atomic
        number - confirm against the caller) to an iterable of entries
        that can be turned into ``(n, l, m)`` tuples.

    Returns
    -------
    dict
        A new dictionary with the same keys. Every value is a freshly
        built list, so modifying the result never modifies ``orbs``.
        The entry for key 1 is kept complete; for every other key the
        entries equal to ``(1, 0, 0)`` are left out.
    """
    core_1s = (1, 0, 0)
    new_orbs = {}
    for key, nlms in orbs.items():
        if key == 1:
            # key 1 keeps all of its orbitals. Copy the list instead of
            # aliasing it, so that the output is independent of the input
            # exactly as it is for the other keys.
            new_orbs[key] = list(nlms)
        else:
            new_orbs[key] = [nlm for nlm in nlms if tuple(nlm) != core_1s]
    return new_orbs


def _drop_heavy_1s_from_ao_labels(ao_labels):
    """
    Filter a sequence of AO labels, returning a new list without the
    labels whose third field is (1, 0, 0) and whose second field is
    not "H".

    Each label is indexed positionally: ``lbl[1]`` is compared with the
    string "H" and ``lbl[2]`` is converted to a tuple and compared with
    (1, 0, 0). The order of the surviving labels is preserved.
    """

    def _is_kept(lbl):
        # labels marked "H" are always kept (their third field is not
        # even inspected); any other label is kept unless it is a 1s
        return lbl[1] == "H" or tuple(lbl[2]) != (1, 0, 0)

    return [lbl for lbl in ao_labels if _is_kept(lbl)]


def baselined_batched_dataset_for_a_single_molecule(
scf_datasets: Tuple[SCFData, SCFData],
feats: List[TensorMap],
Expand All @@ -207,6 +251,7 @@ def baselined_batched_dataset_for_a_single_molecule(
core_feats: List[TensorMap] = None,
mo_indices=None,
lowdin_mo_indices=None,
ignore_heavy_1s: bool = False,
) -> BatchedMemoryDataset:
"""
Create a BatchedMemoryDataset (which is what our models expect)
Expand All @@ -220,12 +265,13 @@ def baselined_batched_dataset_for_a_single_molecule(

frames = small_basis.frames

# truncate the big basis MO energies.
# If indices are present, use them
# get the number of core elements as the number of atoms that are
# not hydrogens. Choose on the basis of the first frame
mo_energy = big_basis.mo_energy
if mo_indices is None:
mo_energy = big_basis.mo_energy[:, : small_basis.mo_energy.shape[1]]
mo_energy = mo_energy[:, : small_basis.mo_energy.shape[1]]
else:
mo_energy = torch.take(big_basis.mo_energy, mo_indices)
mo_energy = torch.take(mo_energy, mo_indices)

# no need to truncate here as they refer to _occupied_ MOs
lowdin_charges = (
Expand Down
19 changes: 19 additions & 0 deletions halex/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import ase.io
import torch

import equistore
from equistore import Labels, TensorBlock, TensorMap


Expand Down Expand Up @@ -275,6 +276,24 @@ def drop_target_samples(
return TensorMap(targ_coupled.keys, blocks)


def drop_target_heavy_1s(targ_coupled, verbose=False):
    """
    Drop from ``targ_coupled`` every block whose key refers to a heavy
    atom 1s orbital on either side of the pair ("i" or "j").

    Parameters
    ----------
    targ_coupled:
        object exposing ``keys`` (iterable of keys that can be indexed by
        the names ``a_i``, ``n_i``, ``l_i``, ``a_j``, ``n_j``, ``l_j``) and
        accepted by ``equistore.drop_blocks``.
    verbose:
        if True, print every key that is going to be dropped.

    Returns
    -------
    The result of ``equistore.drop_blocks`` when at least one key has to
    be dropped. When no key matches, ``targ_coupled`` itself is returned
    unchanged.
    """

    def is_core(key, at="i"):
        # a key is "core" on side ``at`` when a != 1, n == 1 and l == 0
        # (presumably: not a hydrogen, principal number 1, s shell)
        cond0 = key[f"a_{at}"] != 1
        cond1 = key[f"n_{at}"] == 1
        cond2 = key[f"l_{at}"] == 0
        return cond0 and cond1 and cond2

    keys_to_drop = []
    for key in targ_coupled.keys:
        if is_core(key, "i") or is_core(key, "j"):
            keys_to_drop.append(tuple(key))
            if verbose:
                print(f"Dropping key: {key}")

    # Nothing to drop: np.array([]) would be a 1D empty array, which is
    # not a table of key values with one column per key name, so skip
    # the construction of the Labels altogether.
    if not keys_to_drop:
        return targ_coupled

    keys_to_drop = Labels(targ_coupled.keys.names, values=np.array(keys_to_drop))

    return equistore.drop_blocks(targ_coupled, keys_to_drop)


def fix_pyscf_l1_crossoverlap(cross_ovlp, frame, orbs_sb, orbs_bb):
indeces_sb = fix_pyscf_l1(None, frame, orbs_sb, return_index=True)
indeces_bb = fix_pyscf_l1(None, frame, orbs_bb, return_index=True)
Expand Down

0 comments on commit 0b79126

Please sign in to comment.