Skip to content

Commit

Permalink
More PBC modules (#264)
Browse files Browse the repository at this point in the history
* Add AFTDF

* Add PBC GDF code

* Add k-point enabled HF and DFT modules

* Add tests for KSCF

* Refactor pbc.dft.numint

* Lint error

* Fix dtype issue in cupy_helper.contract function

* lint

* typo

* typo

---------

Co-authored-by: Qiming Sun <[email protected]>
Co-authored-by: Xiaojie Wu <[email protected]>
  • Loading branch information
3 people authored Dec 11, 2024
1 parent 5811bb4 commit 4386b44
Show file tree
Hide file tree
Showing 34 changed files with 3,832 additions and 323 deletions.
2 changes: 1 addition & 1 deletion gpu4pyscf/dft/libxc.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def compute(self, inp, output=None, do_exc=True, do_vxc=True, do_fxc=False, do_k
output = _check_arrays(output, output_labels[3:4], xc_func_sizes, npoints, do_kxc)
output = _check_arrays(output, output_labels[4:5], xc_func_sizes, npoints, do_lxc)

args.extend([ inp[x] for x in input_labels])
args.extend([ inp[x].ravel() for x in input_labels])
args.extend([output[x] for x in output_labels])

out_params = xc_lda_out_params()
Expand Down
15 changes: 5 additions & 10 deletions gpu4pyscf/dft/numint.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,10 +454,7 @@ def _nr_rks_task(ni, mol, grids, xc_code, dms, mo_coeff, mo_occ,
excsum = cupy.zeros(nset)
wv = []
for i in range(nset):
if xctype == 'LDA':
exc, vxc = ni.eval_xc_eff(xc_code, rho_tot[i][0], deriv=1, xctype=xctype)[:2]
else:
exc, vxc = ni.eval_xc_eff(xc_code, rho_tot[i], deriv=1, xctype=xctype)[:2]
exc, vxc = ni.eval_xc_eff(xc_code, rho_tot[i], deriv=1, xctype=xctype)[:2]
vxc = cupy.asarray(vxc, order='C')
exc = cupy.asarray(exc, order='C')
den = rho_tot[i][0] * weights
Expand Down Expand Up @@ -1399,18 +1396,15 @@ def eval_xc_eff(ni, xc_code, rho, deriv=1, omega=None, xctype=None, verbose=None
'''
Different from PySCF, this function employ cuda version libxc
'''
if xctype == 'LDA':
spin_polarized = rho.ndim >= 2
else:
spin_polarized = rho.ndim == 3

if omega is None: omega = ni.omega
if xctype is None: xctype = ni._xc_type(xc_code)

spin_polarized = rho.ndim >= 2 and rho.shape[0] == 2
xcfuns = ni._init_xcfuns(xc_code, spin_polarized)

inp = {}
if not spin_polarized:
assert rho.dtype == np.float64
if xctype == 'LDA':
inp['rho'] = rho
if xctype == 'GGA':
Expand All @@ -1421,8 +1415,9 @@ def eval_xc_eff(ni, xc_code, rho, deriv=1, omega=None, xctype=None, verbose=None
inp['sigma'] = batch_square(rho[1:4])
inp['tau'] = rho[-1] # can be 4 (without laplacian) or 5 (with laplacian)
else:
assert rho[0].dtype == np.float64
if xctype == 'LDA':
inp['rho'] = cupy.stack([rho[0], rho[1]], axis=1)
inp['rho'] = cupy.stack([rho[0].ravel(), rho[1].ravel()], axis=1)
if xctype == 'GGA':
inp['rho'] = cupy.stack([rho[0,0], rho[1,0]], axis=1)
sigma0 = batch_square(rho[0,1:4])
Expand Down
30 changes: 27 additions & 3 deletions gpu4pyscf/lib/cupy_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,33 @@ def filter_ret(*args, **kwargs):
return to_cupy(ret)
return filter_ret

def unpack_tril(cderi_tril, cderi, stream=None):
nao = cderi.shape[1]
def pack_tril(a):
if a.ndim == 2:
a = a[None]
n = a.shape[-1]
idx = cupy.arange(n)
mask = idx[:,None] >= idx
return a[:,mask]

def unpack_tril(cderi_tril, cderi=None, stream=None):
assert cderi_tril.flags.c_contiguous
if cderi_tril.ndim == 1:
cderi_tril = cderi_tril[None]
count = cderi_tril.shape[0]
if cderi is None:
nao = int((2*cderi_tril.shape[1])**.5)
cderi = cupy.empty((count,nao,nao), dtype=cderi_tril.dtype)
else:
nao = cderi.shape[1]

if cderi_tril.dtype != np.float64:
idx = cupy.arange(nao)
mask = idx[:,None] >= idx
cderiT = cderi.transpose(0,2,1)
cderiT[:,mask] = cderi_tril.conj()
cderi [:,mask] = cderi_tril
return cderi

if stream is None:
stream = cupy.cuda.get_current_stream()
err = libcupy_helper.unpack_tril(
Expand All @@ -214,7 +238,7 @@ def unpack_tril(cderi_tril, cderi, stream=None):
ctypes.c_int(count))
if err != 0:
raise RuntimeError('failed in unpack_tril kernel')
return
return cderi

def unpack_sparse(cderi_sparse, row, col, p0, p1, nao, out=None, stream=None):
if stream is None:
Expand Down
4 changes: 3 additions & 1 deletion gpu4pyscf/lib/cutensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,10 @@ def contraction(
mode_b = list(str_b)
mode_c = list(str_c)

dtype = np.result_type(a.dtype, b.dtype)
a = cupy.asarray(a, dtype=dtype)
b = cupy.asarray(b, dtype=dtype)
if out is None:
dtype = np.result_type(a, b, alpha)
out = cupy.empty([shape[k] for k in str_c], order='C', dtype=dtype)
c = out

Expand Down
2 changes: 1 addition & 1 deletion gpu4pyscf/lib/diis.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def _store(self, key, value):
else:
self._diisfile[key] = value
# to avoid "Unable to find a valid file signature" error when reload the hdf5
# file from a crashed claculation
# file from a crashed calculation
self._diisfile.flush()

def push_err_vec(self, xerr):
Expand Down
14 changes: 14 additions & 0 deletions gpu4pyscf/lib/tests/test_cupy_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import unittest
import numpy
import cupy
from gpu4pyscf.lib import cupy_helper
from gpu4pyscf.lib.cupy_helper import (
take_last2d, transpose_sum, krylov, unpack_sparse,
add_sparse, takebak, empty_mapped, dist_matrix,
Expand Down Expand Up @@ -201,6 +202,19 @@ def test_cart2sph(self):
a_sph1 = cart2sph(a_cart, axis=1, ang=7)
assert cupy.linalg.norm(a_sph0 - a_sph1) < 1e-8
'''

def test_unpack_tril(self):
d = 10
n = 515
npair = n * (n+1) // 2
atril = cupy.random.rand(d, npair) + cupy.random.rand(d, npair)*1j
a = cupy_helper.unpack_tril(atril)
idx, idy = cupy.tril_indices(n)
ref = cupy.empty((d, n, n), dtype=atril.dtype)
ref[:,idy,idx] = atril.conj()
ref[:,idx,idy] = atril
assert abs(a - ref).max() < 1e-12

if __name__ == "__main__":
print("Full tests for cupy helper module")
unittest.main()
10 changes: 4 additions & 6 deletions gpu4pyscf/pbc/df/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from . import fft
#from . import aft
#from . import df
from . import aft
from . import df
from .fft import FFTDF
#from .df import DF, GDF
#from .aft import AFTDF

class DF: pass # Just a placeholder
from .df import GDF
from .aft import AFTDF
213 changes: 213 additions & 0 deletions gpu4pyscf/pbc/df/aft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
#!/usr/bin/env python
#
# Copyright 2024 The GPU4PySCF Developers. All Rights Reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

'''Density expansion on plane waves'''

__all__ = [
'get_pp', 'get_nuc', 'AFTDF'
]

import contextlib
import numpy as np
import cupy as cp
from pyscf import lib
from pyscf import gto
from pyscf.pbc.df import aft as aft_cpu
from pyscf.pbc.gto.pseudo import pp_int
from pyscf.pbc.lib.kpts_helper import is_zero
from pyscf.pbc.df import ft_ao
from pyscf.pbc.df.aft import _check_kpts
from pyscf.pbc.tools import k2gamma
from gpu4pyscf.pbc.tools.pbc import get_coulG
from gpu4pyscf.pbc.df import aft_jk
from gpu4pyscf.lib import logger, utils
from gpu4pyscf.lib.cupy_helper import return_cupy_array, contract, unpack_tril

KE_SCALING = aft_cpu.KE_SCALING

def _get_pp_loc_part1(mydf, kpts=None, with_pseudo=True):
kpts, is_single_kpt = _check_kpts(mydf, kpts)
log = logger.new_logger(mydf)
cell = mydf.cell
mesh = np.asarray(mydf.mesh)
nkpts = len(kpts)
nao = cell.nao_nr()
nao_pair = nao * (nao+1) // 2

kpt_allow = np.zeros(3)
if cell.dimension > 0:
ke_guess = aft_cpu.estimate_ke_cutoff(cell, cell.precision)
mesh_guess = cell.cutoff_to_mesh(ke_guess)
if np.any(mesh < mesh_guess*KE_SCALING):
logger.warn(mydf, 'mesh %s is not enough for AFTDF.get_nuc function '
'to get integral accuracy %g.\nRecommended mesh is %s.',
mesh, cell.precision, mesh_guess)
log.debug1('aft.get_pp_loc_part1 mesh = %s', mesh)
Gv, Gvbase, kws = cell.get_Gv_weights(mesh)

if with_pseudo:
vpplocG = pp_int.get_gth_vlocG_part1(cell, Gv)
vpplocG = -np.einsum('ij,ij->j', cell.get_SI(Gv), vpplocG)
vpplocG = cp.asarray(vpplocG)
else:
fakenuc = aft_cpu._fake_nuc(cell, with_pseudo=with_pseudo)
aoaux = cp.asarray(ft_ao.ft_ao(fakenuc, Gv))
charges = cp.asarray(cell.atom_charges(), dtype=np.float64)
coulG = get_coulG(cell, kpt_allow, mesh=mesh, Gv=Gv)
vpplocG = contract('i,xi->x', -charges, aoaux)
vpplocG *= coulG

vpplocG *= cp.asarray(kws)
vj = cp.zeros((nkpts, nao_pair), dtype=np.complex128)
for Gpq, p0, p1 in mydf.ft_loop(mesh, kpt_allow, kpts, aosym='s2'):
vj += contract('kGx,G->kx', Gpq, vpplocG[p0:p1].conj())

vj_kpts = unpack_tril(vj)
if is_zero(kpts):
vj_kpts = vj_kpts.real
if is_single_kpt:
vj_kpts = vj_kpts[0]
return vj_kpts

def get_pp(mydf, kpts=None):
'''Get the periodic pseudopotential nuc-el AO matrix, with G=0 removed.
Kwargs:
mesh: custom mesh grids. By default mesh is determined by the
function _guess_eta from module pbc.df.gdf_builder.
'''
cell = mydf.cell
kpts, is_single_kpt = aft_cpu._check_kpts(mydf, kpts)
vpp = _get_pp_loc_part1(mydf, kpts, with_pseudo=True)
pp2builder = aft_cpu._IntPPBuilder(cell, kpts)
vpp += cp.asarray(pp2builder.get_pp_loc_part2())
vpp += cp.asarray(pp_int.get_pp_nl(cell, kpts))
if is_single_kpt:
vpp = vpp[0]
return vpp


def get_nuc(mydf, kpts=None):
'''Get the periodic nuc-el AO matrix, with G=0 removed.
Kwargs:
function _guess_eta from module pbc.df.gdf_builder.
'''
return _get_pp_loc_part1(mydf, kpts, with_pseudo=False)


class AFTDFMixin:

weighted_coulG = return_cupy_array(aft_cpu.weighted_coulG)
pw_loop = NotImplemented

def ft_loop(self, mesh=None, q=np.zeros(3), kpts=None, shls_slice=None,
max_memory=4000, aosym='s1', intor='GTO_ft_ovlp', comp=1,
bvk_kmesh=None, return_complex=True):
'''
Fourier transform iterator for all kpti which satisfy
2pi*N = (kpts - kpti - q)*a, N = -1, 0, 1
The tensors returned by this function is different to the one in PySCF CPU version
'''
assert return_complex
cell = self.cell
if mesh is None:
mesh = self.mesh
if kpts is None:
assert (is_zero(q))
kpts = self.kpts
kpts = np.asarray(kpts)
nkpts = len(kpts)

nao = cell.nao
Gv, Gvbase, kws = cell.get_Gv_weights(mesh)
gxyz = lib.cartesian_prod([np.arange(len(x)) for x in Gvbase])
ngrids = gxyz.shape[0]

assert shls_slice is None
if aosym == 's2':
nij = nao * (nao+1) // 2
else:
nij = nao * nao

if bvk_kmesh is None:
bvk_kmesh = k2gamma.kpts_to_kmesh(cell, kpts)

rcut = ft_ao.estimate_rcut(cell)
supmol = ft_ao.ExtendedMole.from_cell(cell, bvk_kmesh, rcut.max())
supmol = supmol.strip_basis(rcut)
ft_kern = supmol.gen_ft_kernel(aosym, intor=intor, comp=comp,
return_complex=True)

blksize = max(16, int(max_memory*.9e6/(nij*nkpts*16*comp)))
blksize = min(blksize, ngrids, 16384)

for p0, p1 in lib.prange(0, ngrids, blksize):
dat = ft_kern(Gv[p0:p1], gxyz[p0:p1], Gvbase, q, kpts, shls_slice)
yield cp.asarray(dat), p0, p1

range_coulomb = aft_cpu.AFTDFMixin.range_coulomb


class AFTDF(lib.StreamObject, AFTDFMixin):
'''Density expansion on plane waves
'''

_keys = aft_cpu.AFTDF._keys

__init__ = aft_cpu.AFTDF.__init__
dump_flags = aft_cpu.AFTDF.dump_flags
reset = aft_cpu.AFTDF.reset
check_sanity = aft_cpu.AFTDF.check_sanity
build = aft_cpu.AFTDF.build

get_nuc = get_nuc
get_pp = get_pp

# Note: Special exxdiv by default should not be used for an arbitrary
# input density matrix. When the df object was used with the molecular
# post-HF code, get_jk was often called with an incomplete DM (e.g. the
# core DM in CASCI). An SCF level exxdiv treatment is inadequate for
# post-HF methods.
def get_jk(self, dm, hermi=1, kpts=None, kpts_band=None,
with_j=True, with_k=True, omega=None, exxdiv=None):
if omega is not None: # J/K for RSH functionals
with self.range_coulomb(omega) as rsh_df:
return rsh_df.get_jk(dm, hermi, kpts, kpts_band, with_j, with_k,
omega=None, exxdiv=exxdiv)

kpts, is_single_kpt = _check_kpts(self, kpts)
if is_single_kpt:
return aft_jk.get_jk(self, dm, hermi, kpts[0], kpts_band, with_j,
with_k, exxdiv)

vj = vk = None
if with_k:
vk = aft_jk.get_k_kpts(self, dm, hermi, kpts, kpts_band, exxdiv)
if with_j:
vj = aft_jk.get_j_kpts(self, dm, hermi, kpts, kpts_band)
return vj, vk

get_eri = get_ao_eri = NotImplemented
ao2mo = get_mo_eri = NotImplemented
ao2mo_7d = NotImplemented
get_ao_pairs_G = get_ao_pairs = NotImplemented
get_mo_pairs_G = get_mo_pairs = NotImplemented

to_gpu = utils.to_gpu
device = utils.device
to_cpu = utils.to_cpu
Loading

0 comments on commit 4386b44

Please sign in to comment.