From 0b79126487895a85a0b594d66a9c40506f93668c Mon Sep 17 00:00:00 2001 From: Edoardo Cignoni Date: Mon, 5 Jun 2023 16:41:42 +0200 Subject: [PATCH] added possibility to exclude the 1s orbitals of heavy atoms from the basis --- halex/train_utils.py | 60 ++++++++++++++++++++++++++++++++++++++------ halex/utils.py | 19 ++++++++++++++ 2 files changed, 72 insertions(+), 7 deletions(-) diff --git a/halex/train_utils.py b/halex/train_utils.py index 0440766..96a8414 100644 --- a/halex/train_utils.py +++ b/halex/train_utils.py @@ -136,6 +136,7 @@ def batched_dataset_for_a_single_molecule( core_feats: List[TensorMap] = None, mo_indices=None, lowdin_mo_indices=None, + ignore_heavy_1s: bool = False, ) -> BatchedMemoryDataset: """ Create a BatchedMemoryDataset (which is what our models expect) @@ -149,25 +150,38 @@ def batched_dataset_for_a_single_molecule( frames = small_basis.frames + # get the number of core elements as the number of atoms that are + # not hydrogens. Choose on the basis of the first frame + if ignore_heavy_1s: + ncore = _number_of_heavy_elements(big_basis.frames[0]) + else: + ncore = 0 + # truncate the big basis MO energies. # If indices are present, use them + mo_energy = big_basis.mo_energy[:, ncore:] + if mo_indices is None: - mo_energy = big_basis.mo_energy[:, : small_basis.mo_energy.shape[1]] + mo_energy = mo_energy[:, : small_basis.mo_energy.shape[1] - ncore] else: - mo_energy = torch.take(big_basis.mo_energy, mo_indices) + mo_energy = torch.take(mo_energy, mo_indices) # no need to truncate here as they refer to _occupied_ MOs lowdin_charges = ( - big_basis.lowdin_charges_byMO + big_basis.lowdin_charges_byMO[:, ncore:] if lowdin_charges_by_MO else big_basis.lowdin_charges ) # orbitals in the small basis (because we predict a small basis Fock) orbs = small_basis.orbs + if ignore_heavy_1s: + orbs = _drop_heavy_1s_from_orbs(orbs) # ao labels in the small basis ao_labels = small_basis.ao_labels + if ignore_heavy_1s: + ao_labels = _drop_heavy_1s_from_ao_labels(ao_labels) if core_feats is None: return BatchedMemoryDataset( @@ -198,6 +212,36 @@ def batched_dataset_for_a_single_molecule( ) +def _number_of_heavy_elements(frame): + return sum(frame.numbers != 1) + + +def _drop_heavy_1s_from_orbs(orbs): + new_orbs = dict() + for key in orbs.keys(): + if key == 1: + new_orbs[key] = orbs[key] + else: + new_orbs[key] = list() + for nlm in orbs[key]: + if tuple(nlm) == (1, 0, 0): + pass + else: + new_orbs[key].append(nlm) + return new_orbs + + +def _drop_heavy_1s_from_ao_labels(ao_labels): + new_ao_labels = [] + for lbl in ao_labels: + if lbl[1] == "H": + new_ao_labels.append(lbl) + else: + if tuple(lbl[2]) != (1, 0, 0): + new_ao_labels.append(lbl) + return new_ao_labels + + def baselined_batched_dataset_for_a_single_molecule( scf_datasets: Tuple[SCFData, SCFData], feats: List[TensorMap], @@ -207,6 +251,7 @@ def baselined_batched_dataset_for_a_single_molecule( core_feats: List[TensorMap] = None, mo_indices=None, lowdin_mo_indices=None, + ignore_heavy_1s: bool = False, ) -> BatchedMemoryDataset: """ Create a BatchedMemoryDataset (which is what our models expect) @@ -220,12 +265,13 @@ def baselined_batched_dataset_for_a_single_molecule( frames = small_basis.frames - # truncate the big basis MO energies. - # If indices are present, use them + # get the number of core elements as the number of atoms that are + # not hydrogens. Choose on the basis of the first frame + mo_energy = big_basis.mo_energy if mo_indices is None: - mo_energy = big_basis.mo_energy[:, : small_basis.mo_energy.shape[1]] + mo_energy = mo_energy[:, : small_basis.mo_energy.shape[1]] else: - mo_energy = torch.take(big_basis.mo_energy, mo_indices) + mo_energy = torch.take(mo_energy, mo_indices) # no need to truncate here as they refer to _occupied_ MOs lowdin_charges = ( diff --git a/halex/utils.py b/halex/utils.py index df86489..aa2d79c 100644 --- a/halex/utils.py +++ b/halex/utils.py @@ -7,6 +7,7 @@ import ase.io import torch +import equistore from equistore import Labels, TensorBlock, TensorMap @@ -275,6 +276,24 @@ def drop_target_samples( return TensorMap(targ_coupled.keys, blocks) +def drop_target_heavy_1s(targ_coupled, verbose=False): + def is_core(key, at="i"): + cond0 = key[f"a_{at}"] != 1 + cond1 = key[f"n_{at}"] == 1 + cond2 = key[f"l_{at}"] == 0 + return cond0 and cond1 and cond2 + + keys_to_drop = [] + for key in targ_coupled.keys: + if is_core(key, "i") or is_core(key, "j"): + keys_to_drop.append(tuple(key)) + if verbose: + print(f"Dropping key: {key}") + keys_to_drop = Labels(targ_coupled.keys.names, values=np.array(keys_to_drop)) + + return equistore.drop_blocks(targ_coupled, keys_to_drop) + + def fix_pyscf_l1_crossoverlap(cross_ovlp, frame, orbs_sb, orbs_bb): indeces_sb = fix_pyscf_l1(None, frame, orbs_sb, return_index=True) indeces_bb = fix_pyscf_l1(None, frame, orbs_bb, return_index=True)