pbc.df.ft_ao on GPU (#291)
* Add ft_ao cuda kernel

* Add ft_ao.py

* Add helper functions in gpu4pyscf.gto.mole

* Update pbc.ft_ao

* ft_ao runs, output incorrect

* PBC ft_ao general kernel correct

* ft_ao unrolled

* Modified kpts_to_kmesh

* Add tests

* Handle non-symmetric case; add more tests.

* Lint

* Missing files

* Apply the ft_ao GPU implementation in aft and aft_jk

* Update VHFOpt in scf.jk module

* Undefined variables

* vhfopt.mol -> vhfopt.sorted_mol

* Fix J-engine due to the change of _VHFOpt class

* Remove print statements

* Apache header

* More Apache headers

---------

Co-authored-by: Qiming Sun <[email protected]>
sunqm and Qiming Sun authored Dec 24, 2024
1 parent 632c563 commit 9d28f26
Showing 26 changed files with 3,452 additions and 417 deletions.
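As context for the new pbc.df.ft_ao GPU kernels: they evaluate the analytic Fourier transform of the AO basis, phi~_mu(G) = \int phi_mu(r) exp(-i G.r) d^3r, over a batch of plane waves G (and k-points); the CPU reference implementation lives in pyscf.pbc.df.ft_ao. The sketch below only illustrates that quantity for a single normalized s-type Gaussian and checks it against a 1D radial quadrature; it does not call the new GPU code, and the exponent and |G| value are arbitrary.

    import numpy as np

    # Normalized s-type Gaussian phi(r) = N exp(-alpha r^2); FT convention
    # phi~(G) = \int phi(r) exp(-i G.r) d^3r.
    alpha = 0.8                       # arbitrary exponent
    G = 1.3                           # arbitrary |G|
    norm = (2 * alpha / np.pi) ** 0.75

    analytic = norm * (np.pi / alpha) ** 1.5 * np.exp(-G**2 / (4 * alpha))

    # Spherical symmetry reduces the 3D FT to (4 pi / G) \int_0^inf r phi(r) sin(G r) dr
    r = np.linspace(1e-8, 20.0, 200001)
    f = r * norm * np.exp(-alpha * r**2) * np.sin(G * r)
    numeric = 4 * np.pi / G * np.sum(0.5 * (f[1:] + f[:-1]) * np.diff(r))

    assert abs(analytic - numeric) < 1e-5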
4 changes: 2 additions & 2 deletions gpu4pyscf/df/int3c2e.py
@@ -113,8 +113,8 @@ def build(self, cutoff=1e-14, group_size=None,
_mol = self.mol
_auxmol = self.auxmol

mol = basis_seg_contraction(_mol,allow_replica=True)
auxmol = basis_seg_contraction(_auxmol, allow_replica=True)
mol = basis_seg_contraction(_mol, allow_replica=True)[0]
auxmol = basis_seg_contraction(_auxmol, allow_replica=True)[0]

log = logger.new_logger(_mol, _mol.verbose)
cput0 = log.init_timer()
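A note on the call-site changes in this and the following files: basis_seg_contraction now returns a (pmol, contr_coeff) pair instead of a bare Mole (see the gpu4pyscf/gto/mole.py diff below), which is why these callers now take element [0]. A minimal sketch; the geometry and basis are arbitrary placeholders.

    import pyscf
    from gpu4pyscf.gto.mole import basis_seg_contraction

    mol = pyscf.M(atom='O 0 0 0; H 0 0 0.96; H 0.93 0 -0.24', basis='cc-pvdz')

    # New signature: the segment-contracted (Cartesian) Mole plus the coefficient
    # matrix mapping its AOs back to the original AO basis.
    pmol, contr_coeff = basis_seg_contraction(mol, allow_replica=True)
    assert contr_coeff.shape[1] == mol.nao

    # Callers that only need the decontracted Mole keep the first element.
    pmol = basis_seg_contraction(mol, allow_replica=True)[0]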
2 changes: 1 addition & 1 deletion gpu4pyscf/dft/numint.py
@@ -1999,7 +1999,7 @@ def build(self, mol=None):
if hasattr(mol, '_decontracted') and mol._decontracted:
raise RuntimeError('mol object is already decontracted')

pmol = basis_seg_contraction(mol, allow_replica=True)
pmol = basis_seg_contraction(mol, allow_replica=True)[0]
pmol.cart = mol.cart
coeff = cupy.eye(mol.nao) # without cart2sph transformation
# Sort basis according to angular momentum and contraction patterns so
2 changes: 1 addition & 1 deletion gpu4pyscf/grad/rhf.py
@@ -128,7 +128,7 @@ def _jk_energy_per_atom(mol, dm, vhfopt=None,
if vhfopt is None:
vhfopt = _VHFOpt(mol).build()

mol = vhfopt.mol
mol = vhfopt.sorted_mol
nao, nao_orig = vhfopt.coeff.shape

dm = cp.asarray(dm, order='C')
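The vhfopt.mol to vhfopt.sorted_mol rename here (and in the Hessian module below) reflects the reworked _VHFOpt class: the attribute now explicitly names the sorted, segment-contracted molecule consumed by the GPU kernels. A minimal sketch, with the import path assumed from the commit message ('Update VHFOpt in scf.jk module') rather than taken from this diff:

    import pyscf
    from gpu4pyscf.scf.jk import _VHFOpt   # import path assumed

    mol = pyscf.M(atom='Ne 0 0 0', basis='def2-svp')
    vhfopt = _VHFOpt(mol).build()

    sorted_mol = vhfopt.sorted_mol         # previously accessed as vhfopt.mol
    nao, nao_orig = vhfopt.coeff.shape     # sorted Cartesian AO count vs. original AO count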
14 changes: 5 additions & 9 deletions gpu4pyscf/grad/tests/test_rhf_grad.py
@@ -32,15 +32,11 @@

def setUpModule():
global mol_sph, mol_cart
mol_sph = pyscf.M(atom=atom, basis=bas0, max_memory=32000)
mol_sph.output = '/dev/null'
mol_sph.build()
mol_sph.verbose = 1

mol_cart = pyscf.M(atom=atom, basis=bas0, max_memory=32000, cart=1)
mol_cart.output = '/dev/null'
mol_cart.build()
mol_cart.verbose = 1
mol_sph = pyscf.M(atom=atom, basis=bas0, max_memory=32000,
output='/dev/null', verbose=1)

mol_cart = pyscf.M(atom=atom, basis=bas0, max_memory=32000, cart=1,
output='/dev/null', verbose=1)

def tearDownModule():
global mol_sph, mol_cart
2 changes: 1 addition & 1 deletion gpu4pyscf/gto/int3c1e.py
@@ -67,7 +67,7 @@ def __del__(self):

def build(self, cutoff=1e-13, group_size=BLKSIZE, diag_block_with_triu=False, aosym=True):
original_mol = self.mol
mol = basis_seg_contraction(original_mol, allow_replica=True)
mol = basis_seg_contraction(original_mol, allow_replica=True)[0]

log = logger.new_logger(original_mol, original_mol.verbose)
cput0 = log.init_timer()
188 changes: 156 additions & 32 deletions gpu4pyscf/gto/mole.py
Expand Up @@ -13,12 +13,15 @@
# limitations under the License.


import os
import numpy as np
import cupy
import functools
import copy
import numpy as np
import scipy.linalg
from pyscf import gto
from pyscf.gto import (ANG_OF, ATOM_OF, NPRIM_OF, NCTR_OF, PTR_COORD, PTR_COEFF,
PTR_EXP)
from gpu4pyscf.lib import logger

PTR_BAS_COORD = 7

@functools.lru_cache(20)
def get_cart2sph(lmax=12):
@@ -28,67 +31,90 @@ def get_cart2sph(lmax=12):
cart2sph.append(np.asarray(c2s, order='C'))
return cart2sph

def basis_seg_contraction(mol, allow_replica=False):
def basis_seg_contraction(mol, allow_replica=1):
'''transform generally contracted basis to segment contracted basis
Kwargs:
allow_replica:
transform the generally contracted basis to replicated
segment-contracted basis
when angular momentum lower than (or equal to) this value, transform
the generally contracted basis to replicated segment-contracted basis.
By default, high angular momentum functions (d, f shells) are fully
uncontracted.
'''
# Ensure backward compatibility. When allow_replica is True, decontraction
# to primitive functions is disabled. When allow_replica is False, all
# general contraction are decontracted.
if allow_replica is True:
allow_replica = 8
elif allow_replica is False:
allow_replica = -1

bas_templates = {}
_bas = []
_env = mol._env.copy()

contr_coeff = []
aoslices = mol.aoslice_by_atom()
for ia, (ib0, ib1) in enumerate(aoslices[:,:2]):
key = tuple(mol._bas[ib0:ib1,gto.PTR_EXP])
key = tuple(mol._bas[ib0:ib1,PTR_COEFF])
if key in bas_templates:
bas_of_ia = bas_templates[key]
bas_of_ia, coeff = bas_templates[key]
bas_of_ia = bas_of_ia.copy()
bas_of_ia[:,gto.ATOM_OF] = ia
bas_of_ia[:,ATOM_OF] = ia
else:
# Generate the template for decontracted basis
coeff = []
bas_of_ia = []
for shell in mol._bas[ib0:ib1]:
l = shell[gto.ANG_OF]
nctr = shell[gto.NCTR_OF]
l = shell[ANG_OF]
nf = (l + 1) * (l + 2) // 2
nctr = shell[NCTR_OF]
if nctr == 1:
bas_of_ia.append(shell)
coeff.append(np.eye(nf))
continue

# Only basis with nctr > 1 needs to be decontracted
nprim = shell[gto.NPRIM_OF]
pcoeff = shell[gto.PTR_COEFF]
if allow_replica:
nprim = shell[NPRIM_OF]
pcoeff = shell[PTR_COEFF]
if l <= allow_replica:
coeff.extend([np.eye(nf)] * nctr)
bs = np.repeat(shell[np.newaxis], nctr, axis=0)
bs[:,gto.NCTR_OF] = 1
bs[:,gto.PTR_COEFF] = np.arange(pcoeff, pcoeff+nprim*nctr, nprim)
bs[:,NCTR_OF] = 1
bs[:,PTR_COEFF] = np.arange(pcoeff, pcoeff+nprim*nctr, nprim)
bas_of_ia.append(bs)
else:
pexp = shell[gto.PTR_EXP]
else: # To avoid recomputation, decontract to primitive functions
pexp = shell[PTR_EXP]
exps = _env[pexp:pexp+nprim]
norm = gto.gto_norm(l, exps)
# remove normalization from contraction coefficients
c = _env[pcoeff:pcoeff+nprim*nctr].reshape(nctr,nprim)
c = np.einsum('ip,p,ef->iepf', c, 1/norm, np.eye(nf))
coeff.append(c.reshape(nf*nctr, nf*nprim).T)

_env[pcoeff:pcoeff+nprim] = norm
bs = np.repeat(shell[np.newaxis], nprim, axis=0)
bs[:,gto.NPRIM_OF] = 1
bs[:,gto.NCTR_OF] = 1
bs[:,gto.PTR_EXP] = np.arange(pexp, pexp+nprim)
bs[:,gto.PTR_COEFF] = np.arange(pcoeff, pcoeff+nprim)
bs[:,NPRIM_OF] = 1
bs[:,NCTR_OF] = 1
bs[:,PTR_EXP] = np.arange(pexp, pexp+nprim)
bs[:,PTR_COEFF] = np.arange(pcoeff, pcoeff+nprim)
bas_of_ia.append(bs)

bas_of_ia = np.vstack(bas_of_ia)
bas_templates[key] = bas_of_ia
if len(bas_of_ia) > 0:
bas_of_ia = np.vstack(bas_of_ia)
bas_templates[key] = (bas_of_ia, coeff)
else:
continue

_bas.append(bas_of_ia)
contr_coeff.extend(coeff)

pmol = mol.copy()
pmol.output = mol.output
pmol.verbose = mol.verbose
pmol.stdout = mol.stdout
pmol.cart = True #mol.cart
pmol.cart = True
pmol._bas = np.asarray(np.vstack(_bas), dtype=np.int32)
pmol._env = _env
return pmol
contr_coeff = scipy.linalg.block_diag(*contr_coeff)

if not mol.cart:
contr_coeff = contr_coeff.dot(mol.cart2sph_coeff())
return pmol, contr_coeff

def sort_atoms(mol):
"""
@@ -133,3 +159,101 @@ def sort_atoms(mol):
full_path[heavy_idx].append(hydrogen_atoms[i])

return [x for heavy_list in full_path for x in heavy_list]

def group_basis(mol, tile=1, group_size=None):
'''Group basis functions according to their [l, nprim] patterns'''
mol, coeff = basis_seg_contraction(mol)
# Sort basis according to angular momentum and contraction patterns so
# as to group the basis functions to blocks in GPU kernel.
l_ctrs = mol._bas[:,[ANG_OF, NPRIM_OF]]
# Ensure the more contracted Gaussians being accessed first
l_ctrs_descend = l_ctrs.copy()
l_ctrs_descend[:,1] = -l_ctrs[:,1]
uniq_l_ctr, where, inv_idx, l_ctr_counts = np.unique(
l_ctrs_descend, return_index=True, return_inverse=True, return_counts=True, axis=0)
uniq_l_ctr[:,1] = -uniq_l_ctr[:,1]

nao_orig = coeff.shape[1]
ao_loc = mol.ao_loc
coeff = np.split(coeff, ao_loc[1:-1], axis=0)

pad_bas = []
if tile > 1:
l_ctr_counts_orig = l_ctr_counts.copy()
pad_inv_idx = []
env_ptr = mol._env.size
# for each pattern, padding basis to the end of mol._bas, ensure alignment to tile
for n, (l_ctr, m, counts) in enumerate(zip(uniq_l_ctr, where, l_ctr_counts)):
if counts % tile == 0: continue
n_alined = (counts+tile-1) & (0x100000-tile)
padding = n_alined - counts
l_ctr_counts[n] = n_alined

bas = mol._bas[m].copy()
bas[PTR_COEFF] = env_ptr
pad_bas.extend([bas] * padding)
pad_inv_idx.extend([n] * padding)

l = l_ctr[0]
nf = (l + 1) * (l + 2) // 2
coeff.extend([np.zeros((nf, nao_orig))] * padding)

inv_idx = np.hstack([inv_idx.ravel(), pad_inv_idx])

sorted_idx = np.argsort(inv_idx.ravel(), kind='stable').astype(np.int32)
coeff = np.vstack([coeff[i] for i in sorted_idx])
assert coeff.shape[0] < 32768

max_nprims = uniq_l_ctr[:,1].max()
mol._env = np.append(mol._env, np.zeros(max_nprims))
if pad_bas:
mol._bas = np.vstack([mol._bas, pad_bas])[sorted_idx]
else:
mol._bas = mol._bas[sorted_idx]
assert mol._bas.dtype == np.int32

## Limit the number of AOs in each group
if group_size is not None:
uniq_l_ctr, l_ctr_counts = _split_l_ctr_groups(
uniq_l_ctr, l_ctr_counts, group_size, tile)

if mol.verbose >= logger.DEBUG1:
logger.debug1(mol, 'Number of shells for each [l, nprim] group')
if tile > 1:
for l_ctr, n, n8 in zip(uniq_l_ctr, l_ctr_counts_orig, l_ctr_counts):
logger.debug1(mol, ' %s : %s -> %s', l_ctr, n, n8)
else:
for l_ctr, n in zip(uniq_l_ctr, l_ctr_counts):
logger.debug1(mol, ' %s : %s', l_ctr, n)

# PTR_BAS_COORD is required by various CUDA kernels
mol._bas[:,PTR_BAS_COORD] = mol._atm[mol._bas[:,ATOM_OF],PTR_COORD]
return mol, coeff, uniq_l_ctr, l_ctr_counts

def _split_l_ctr_groups(uniq_l_ctr, l_ctr_counts, group_size, align=1):
'''Splits l_ctr patterns into small groups with group_size the maximum
number of AOs in each group
'''
l = uniq_l_ctr[:,0]
nf = l * (l + 1) // 2
_l_ctrs = []
_l_ctr_counts = []
for l_ctr, counts in zip(uniq_l_ctr, l_ctr_counts):
l = l_ctr[0]
nf = (l + 1) * (l + 2) // 2
max_shells = max(group_size//nf-align+1, align, 2)
max_shells = (max_shells + align - 1) & (0x100000-align)
if counts <= max_shells:
_l_ctrs.append(l_ctr)
_l_ctr_counts.append(counts)
continue

nsubs, remaining = counts.__divmod__(max_shells)
_l_ctrs.extend([l_ctr] * nsubs)
_l_ctr_counts.extend([max_shells] * nsubs)
if remaining > 0:
_l_ctrs.append(l_ctr)
_l_ctr_counts.append(remaining)
uniq_l_ctr = np.vstack(_l_ctrs)
l_ctr_counts = np.hstack(_l_ctr_counts)
return uniq_l_ctr, l_ctr_counts
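To summarize the new helpers above: group_basis sorts the segment-contracted shells by their [l, nprim] pattern, pads each pattern group to a multiple of tile (the expression (n + tile - 1) & (0x100000 - tile) rounds up to the next multiple of a power-of-two tile), and _split_l_ctr_groups then caps each group at roughly group_size AOs. A standalone sketch of the same bookkeeping on made-up shell patterns, without calling gpu4pyscf:

    import numpy as np

    def round_up(n, tile):
        # Same rounding as in group_basis; valid when tile is a power of two,
        # since 0x100000 - tile clears all bits below the tile boundary.
        return (n + tile - 1) & (0x100000 - tile)

    # Made-up [l, nprim] patterns: three s shells with 6 primitives,
    # five p shells with 3, two d shells with 1.
    l_ctrs = np.array([[0, 6]] * 3 + [[1, 3]] * 5 + [[2, 1]] * 2)

    # Sort so the more contracted shells are accessed first (the sign trick on
    # nprim mirrors l_ctrs_descend in group_basis).
    descend = l_ctrs.copy()
    descend[:, 1] = -descend[:, 1]
    uniq, inv_idx, counts = np.unique(descend, return_inverse=True,
                                      return_counts=True, axis=0)
    uniq[:, 1] = -uniq[:, 1]
    order = np.argsort(inv_idx.ravel(), kind='stable')  # shell order used to sort _bas

    tile = 4
    padded = [round_up(int(c), tile) for c in counts]
    print(uniq.tolist())    # [[0, 6], [1, 3], [2, 1]]
    print(counts.tolist())  # [3, 5, 2]
    print(padded)           # [4, 8, 4]  (each group padded to a multiple of tile)

The padding shells added in group_basis point their coefficients at a zero-filled block appended to _env, so they contribute nothing to integrals while keeping every GPU block fully tiled.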
4 changes: 2 additions & 2 deletions gpu4pyscf/hessian/rhf.py
@@ -269,7 +269,7 @@ def _partial_ejk_ip2(mol, dm, vhfopt=None, j_factor=1., k_factor=1., verbose=Non
if vhfopt is None:
vhfopt = _VHFOpt(mol).build()

mol = vhfopt.mol
mol = vhfopt.sorted_mol
nao, nao_orig = vhfopt.coeff.shape

dm = cp.asarray(dm, order='C')
@@ -487,7 +487,7 @@ def _get_jk(mol, dm, with_j=True, with_k=True, atoms_slice=None, verbose=None):
vhfopt.tile = 1
vhfopt.build()

mol = vhfopt.mol
mol = vhfopt.sorted_mol
nao, nao_orig = vhfopt.coeff.shape

dm = cp.asarray(dm, order='C')
1 change: 1 addition & 0 deletions gpu4pyscf/lib/CMakeLists.txt
@@ -146,6 +146,7 @@ endif()

add_subdirectory(gvhf-rys)
add_subdirectory(gvhf-md)
add_subdirectory(pbc)

option(BUILD_LIBXC "Using libxc for DFT" ON)
if(BUILD_LIBXC)
9 changes: 8 additions & 1 deletion gpu4pyscf/lib/gvhf-md/md_j_driver.cu
@@ -420,7 +420,7 @@ int MD_build_j(double *vj, double *dm, int n_dm, int nao,
return 0;
}

void init_mdj_constant(int shm_size)
int init_mdj_constant(int shm_size)
{
Fold2Index i_in_fold2idx[165];
Fold3Index i_in_fold3idx[495];
@@ -446,5 +446,12 @@ void init_mdj_constant(int shm_size)
cudaMemcpyToSymbol(c_i_in_fold3idx, i_in_fold3idx, 495*sizeof(Fold3Index));
cudaFuncSetAttribute(md_j_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size);
set_md_j_unrolled_shm_size();
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
fprintf(stderr, "Failed to set CUDA shm size %d: %s\n", shm_size,
cudaGetErrorString(err));
return 1;
}
return 0;
}
}
20 changes: 17 additions & 3 deletions gpu4pyscf/lib/gvhf-rys/rys_jk_driver.cu
@@ -473,8 +473,8 @@ int RYS_per_atom_jk_ip2_type3(double *ejk, double j_factor, double k_factor,
return 0;
}

void RYS_init_constant(int *g_pair_idx, int *offsets,
double *env, int env_size, int shm_size)
int RYS_init_constant(int *g_pair_idx, int *offsets,
double *env, int env_size, int shm_size)
{
// TODO: test whether the constant memory c_env can improve performance
//cudaMemcpyToSymbol(c_env, env, sizeof(double)*env_size);
@@ -486,9 +486,16 @@ void RYS_init_constant(int *g_pair_idx, int *offsets,
cudaFuncSetAttribute(rys_ejk_ip1_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size);
cudaFuncSetAttribute(rys_ejk_ip2_type12_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size);
cudaFuncSetAttribute(rys_ejk_ip2_type3_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
fprintf(stderr, "Failed to set CUDA shm size %d: %s\n", shm_size,
cudaGetErrorString(err));
return 1;
}
return 0;
}

void RYS_init_rysj_constant(int shm_size)
int RYS_init_rysj_constant(int shm_size)
{
Fold2Index i_in_fold2idx[165];
Fold3Index i_in_fold3idx[495];
@@ -512,6 +519,13 @@ void RYS_init_rysj_constant(int shm_size)
cudaMemcpyToSymbol(c_i_in_fold3idx, i_in_fold3idx, 495*sizeof(Fold3Index));
cudaFuncSetAttribute(rys_j_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size);
cudaFuncSetAttribute(rys_j_with_gout_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
fprintf(stderr, "Failed to set CUDA shm size %d: %s\n", shm_size,
cudaGetErrorString(err));
return 1;
}
return 0;
}

int cuda_version()
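Note on the CUDA driver changes above: init_mdj_constant, RYS_init_constant, and RYS_init_rysj_constant now return an int status instead of void, so the Python layer can surface cudaFuncSetAttribute failures (for example, requesting more dynamic shared memory than the device supports). A hypothetical ctypes-style check; the library name and call pattern are assumptions for illustration, not gpu4pyscf's actual loader code:

    import ctypes

    # Assumed library handle; gpu4pyscf normally loads its CUDA extensions
    # through its own helpers rather than a hard-coded soname like this.
    libgvhf = ctypes.CDLL('libgvhf_rys.so')
    libgvhf.RYS_init_rysj_constant.restype = ctypes.c_int

    shm_size = 96 * 1024   # requested dynamic shared memory in bytes (example value)
    if libgvhf.RYS_init_rysj_constant(ctypes.c_int(shm_size)) != 0:
        raise RuntimeError(f'Failed to set CUDA shared memory size {shm_size}')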
