diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 4f4c105781..8afd90a569 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -219,6 +219,7 @@ OBJS_DEEPKS=LCAO_deepks.o\ deepks_vdelta.o\ deepks_vdpre.o\ deepks_vdrpre.o\ + deepks_grad.o\ deepks_pdm.o\ deepks_phialpha.o\ LCAO_deepks_io.o\ diff --git a/source/source_io/module_parameter/input_parameter.h b/source/source_io/module_parameter/input_parameter.h index 34a548b576..efe9ade040 100644 --- a/source/source_io/module_parameter/input_parameter.h +++ b/source/source_io/module_parameter/input_parameter.h @@ -283,6 +283,7 @@ struct Input_para int deepks_bandgap = 0; ///< for bandgap label. QO added 2021-12-15 std::vector deepks_band_range = {-1, 0}; ///< the range of bands to calculate bandgap int deepks_v_delta = 0; ///< for v_delta label. xuan added + bool deepks_grad = false; ///< output descriptor-gradient label intermediates for DeePKS training bool deepks_equiv = false; ///< whether to use equivariant version of DeePKS bool deepks_out_unittest = false; ///< if set to true, prints intermediate quantities that shall ///< be used for making unit test diff --git a/source/source_io/module_parameter/read_input_item_deepks.cpp b/source/source_io/module_parameter/read_input_item_deepks.cpp index 7f69b58dd5..a9e33c9b50 100644 --- a/source/source_io/module_parameter/read_input_item_deepks.cpp +++ b/source/source_io/module_parameter/read_input_item_deepks.cpp @@ -307,6 +307,24 @@ void ReadInput::item_deepks() }; this->add_item(item); } + { + Input_Item item("deepks_grad"); + item.annotation = "output descriptor-gradient label intermediates for DeePKS training"; + item.category = "DeePKS"; + item.type = "Boolean"; + item.default_value = "False"; + item.unit = ""; + item.availability = "NAO basis; requires deepks_out_labels >= 1 and deepks_v_delta < 0"; + read_sync_bool(input.deepks_grad); + item.check_value = [](const Input_Item& item, const Parameter& para) { + if (para.input.deepks_grad && para.input.deepks_out_labels == 0) + { + ModuleBase::WARNING_QUIT("ReadInput", + "deepks_grad requires deepks_out_labels 1"); + } + }; + this->add_item(item); + } { Input_Item item("deepks_out_unittest"); item.annotation = "if set 1, prints intermediate quantities that shall " diff --git a/source/source_lcao/module_deepks/CMakeLists.txt b/source/source_lcao/module_deepks/CMakeLists.txt index 332b0a77db..6722cfd50e 100644 --- a/source/source_lcao/module_deepks/CMakeLists.txt +++ b/source/source_lcao/module_deepks/CMakeLists.txt @@ -13,6 +13,7 @@ if(ENABLE_MLALGO) deepks_vdelta.cpp deepks_vdpre.cpp deepks_vdrpre.cpp + deepks_grad.cpp deepks_pdm.cpp deepks_phialpha.cpp LCAO_deepks_io.cpp diff --git a/source/source_lcao/module_deepks/LCAO_deepks_interface.cpp b/source/source_lcao/module_deepks/LCAO_deepks_interface.cpp index b18b448459..bc34abc5ca 100644 --- a/source/source_lcao/module_deepks/LCAO_deepks_interface.cpp +++ b/source/source_lcao/module_deepks/LCAO_deepks_interface.cpp @@ -10,6 +10,7 @@ #include "source_lcao/module_deepks/deepks_check.h" #include "source_lcao/module_deepks/deepks_descriptor.h" #include "source_lcao/module_deepks/deepks_fpre.h" +#include "source_lcao/module_deepks/deepks_grad.h" #include "source_lcao/module_deepks/deepks_orbital.h" #include "source_lcao/module_deepks/deepks_orbpre.h" #include "source_lcao/module_deepks/deepks_pdm.h" @@ -159,7 +160,13 @@ void LCAO_Deepks_Interface::out_deepks_labels(const double& etot, // new gedm is also useful in cal_f_delta, so it should be ld->gedm if (PARAM.inp.deepks_equiv) { - DeePKS_domain::cal_edelta_gedm_equiv(nat, deepks_param, descriptor, ld->model_deepks, ld->gedm, E_delta, rank); + DeePKS_domain::cal_edelta_gedm_equiv(nat, + deepks_param, + descriptor, + ld->model_deepks, + ld->gedm, + E_delta, + rank); } else // traditional version { @@ -629,7 +636,16 @@ void LCAO_Deepks_Interface::out_deepks_labels(const double& etot, int R_size = DeePKS_domain::get_R_size(*h_deltaR); torch::Tensor overlap_out; torch::Tensor iRmat; - DeePKS_domain::prepare_phialpha_iRmat(nlocal, R_size, deepks_param, phialpha, ucell, orb, *ParaV, GridD, overlap_out, iRmat); + DeePKS_domain::prepare_phialpha_iRmat(nlocal, + R_size, + deepks_param, + phialpha, + ucell, + orb, + *ParaV, + GridD, + overlap_out, + iRmat); const std::string file_overlap = PARAM.globalv.global_out_dir + "deepks_phialpha_r.npy"; LCAO_deepks_io::save_tensor2npy(file_overlap, overlap_out, rank); const std::string file_iRmat = PARAM.globalv.global_out_dir + "deepks_iRmat.npy"; @@ -641,7 +657,66 @@ void LCAO_Deepks_Interface::out_deepks_labels(const double& etot, } //================================================================================ - // 6. Hk + // 6. deepks_grad: descriptor-gradient label intermediates + //================================================================================ + if (PARAM.inp.deepks_grad && is_after_scf) + { + { + torch::Tensor dot_phialpha_hamilt; + DeePKS_domain::cal_phialpha_hamilt_proj(nlocal, + nat, + deepks_param, + phialpha, + *p_ham->getHR(), + ucell, + orb, + *ParaV, + GridD, + dot_phialpha_hamilt); + LCAO_deepks_io::save_tensor2npy(PARAM.globalv.global_out_dir + "deepks_dot_phialpha_hamilt.npy", + dot_phialpha_hamilt, + rank); + + if (PARAM.inp.deepks_scf) + { + // B^T B square matrix: (nat*des_per_atom, nat*des_per_atom) + // R_size: conservatively 2x the phialpha R extent to cover + // all dR2-dR1 differences produced by iterate_ad2. + const int btb_R_size = 2 * DeePKS_domain::get_R_size(*phialpha[0]); + torch::Tensor vdrpre_square; + DeePKS_domain::cal_vdrpre_square(nlocal, + nat, + btb_R_size, + deepks_param, + phialpha, + gevdm, + ucell, + orb, + *ParaV, + GridD, + vdrpre_square); + LCAO_deepks_io::save_tensor2npy(PARAM.globalv.global_out_dir + "deepks_vdrpre_square.npy", + vdrpre_square, + rank); + + // Avoid double-writing: the deepks_v_delta==-2 path already + // outputs deepks_gevdm.npy when deepks_out_labels==1. + const bool gevdm_already_output + = (PARAM.inp.deepks_v_delta == -2 && PARAM.inp.deepks_out_labels == 1); + if (!gevdm_already_output) + { + torch::Tensor gevdm_out; + DeePKS_domain::prepare_gevdm(nat, deepks_param, orb, gevdm, gevdm_out); + LCAO_deepks_io::save_tensor2npy(PARAM.globalv.global_out_dir + "deepks_gevdm.npy", + gevdm_out, + rank); + } + } + } + } + + //================================================================================ + // 7. Hk //================================================================================ if (PARAM.inp.deepks_v_delta > 0) diff --git a/source/source_lcao/module_deepks/deepks_grad.cpp b/source/source_lcao/module_deepks/deepks_grad.cpp new file mode 100644 index 0000000000..0f2260fb8a --- /dev/null +++ b/source/source_lcao/module_deepks/deepks_grad.cpp @@ -0,0 +1,234 @@ +// deepks_grad.cpp -- Descriptor-gradient label building blocks +// See deepks_grad.h for the full derivation and Python usage. + +#ifdef __MLALGO + +#include "deepks_grad.h" + +#include "deepks_iterate.h" +#include "deepks_vdrpre.h" +#include "source_base/parallel_reduce.h" + +// --------------------------------------------------------------------------- +// Step 1: dot_phialpha_hamilt[inl, m1, m2] +// = +// --------------------------------------------------------------------------- +template +void DeePKS_domain::cal_phialpha_hamilt_proj(const int nlocal, + const int nat, + const DeePKS_Param& deepks_param, + const std::vector*>& phialpha, + const hamilt::HContainer& hR, + const UnitCell& ucell, + const LCAO_Orbitals& orb, + const Parallel_Orbitals& pv, + const Grid_Driver& GridD, + torch::Tensor& dot_phialpha_hamilt) +{ + ModuleBase::TITLE("DeePKS_domain", "cal_phialpha_hamilt_proj"); + ModuleBase::timer::start("DeePKS_domain", "cal_phialpha_hamilt_proj"); + + using TorchScalar = typename std::conditional::value, double, c10::complex>::type; + + const int inlmax = deepks_param.inlmax; + const int lmaxd = deepks_param.lmaxd; + const int nm_max = 2 * lmaxd + 1; + + const torch::Dtype dtype = std::is_same::value ? torch::kFloat64 : torch::kComplexDouble; + torch::Tensor dot_ph = torch::zeros({inlmax, nm_max, nm_max}, torch::TensorOptions().dtype(dtype)); + auto dot_ph_acc = dot_ph.accessor(); + + DeePKS_domain::iterate_ad2( + ucell, + GridD, + orb, + false, + [&](const int iat, + const ModuleBase::Vector3& tau0, + const int ibt1, + const ModuleBase::Vector3& tau1, + const int start1, + const int nw1_tot, + ModuleBase::Vector3 dR1, + const int ibt2, + const ModuleBase::Vector3& tau2, + const int start2, + const int nw2_tot, + ModuleBase::Vector3 dR2) { + if (phialpha[0]->find_matrix(iat, ibt1, dR1.x, dR1.y, dR1.z) == nullptr + || phialpha[0]->find_matrix(iat, ibt2, dR2.x, dR2.y, dR2.z) == nullptr) + { + return; + } + + const ModuleBase::Vector3 R_H = dR2 - dR1; + + const auto row_indexes = pv.get_indexes_row(ibt1); + const auto col_indexes = pv.get_indexes_col(ibt2); + const int row_size = static_cast(row_indexes.size()); + const int col_size = static_cast(col_indexes.size()); + if (row_size == 0 || col_size == 0) + { + return; + } + + const hamilt::BaseMatrix* h_mat = hR.find_matrix(ibt1, ibt2, R_H.x, R_H.y, R_H.z); + if (h_mat == nullptr) + { + return; + } + const TR* h_data = h_mat->get_pointer(); + + hamilt::BaseMatrix* pa1 = phialpha[0]->find_matrix(iat, ibt1, dR1); + hamilt::BaseMatrix* pa2 = phialpha[0]->find_matrix(iat, ibt2, dR2); + + const int T0 = ucell.iat2it[iat]; + const int I0 = ucell.iat2ia[iat]; + + int ib = 0; + for (int L0 = 0; L0 <= orb.Alpha[0].getLmax(); ++L0) + { + for (int N0 = 0; N0 < orb.Alpha[0].getNchi(L0); ++N0) + { + const int inl = deepks_param.inl_index[T0](I0, L0, N0); + const int nm = 2 * L0 + 1; + + for (int irow = 0; irow < row_size; ++irow) + { + for (int icol = 0; icol < col_size; ++icol) + { + const TR h_val = h_data[irow * col_size + icol]; + if (h_val == TR{}) + { + continue; + } + for (int m1 = 0; m1 < nm; ++m1) + { + const double pa1_val = pa1->get_value(row_indexes[irow], ib + m1); + for (int m2 = 0; m2 < nm; ++m2) + { + dot_ph_acc[inl][m1][m2] += static_cast( + pa1_val * h_val * pa2->get_value(col_indexes[icol], ib + m2)); + } + } + } + } + ib += nm; + } + } + }); + +#ifdef __MPI + if (std::is_same::value) + { + Parallel_Reduce::reduce_all(dot_ph.data_ptr(), dot_ph.numel()); + } + else + { + Parallel_Reduce::reduce_all(reinterpret_cast(dot_ph.data_ptr>()), + 2 * dot_ph.numel()); + } +#endif + + dot_phialpha_hamilt = dot_ph; + + ModuleBase::timer::end("DeePKS_domain", "cal_phialpha_hamilt_proj"); +} + +// --------------------------------------------------------------------------- +void DeePKS_domain::cal_vdrpre_square(const int nlocal, + const int nat, + const int R_size, + const DeePKS_Param& deepks_param, + const std::vector*>& phialpha, + const std::vector& gevdm, + const UnitCell& ucell, + const LCAO_Orbitals& orb, + const Parallel_Orbitals& pv, + const Grid_Driver& GridD, + torch::Tensor& vdrpre_square) +{ + ModuleBase::TITLE("DeePKS_domain", "cal_vdrpre_square"); + ModuleBase::timer::start("DeePKS_domain", "cal_vdrpre_square"); + + torch::Tensor vdr_precalc; + const std::vector> kvec_d_dummy; + DeePKS_domain::cal_vdr_precalc(nlocal, + nat, + 1, + R_size, + deepks_param, + kvec_d_dummy, + phialpha, + gevdm, + ucell, + orb, + pv, + GridD, + vdr_precalc); + + const int N_alpha = nat * deepks_param.des_per_atom; + const torch::Tensor flat = vdr_precalc.reshape({-1, N_alpha}); + const int K = static_cast(flat.size(0)); + const double* flat_data = flat.data_ptr(); + + // vdr_precalc is sparse: only (R, mu, nu) rows with actual phialpha overlaps + // are non-zero. Collect them into a compact matrix M and compute M^T @ M + // instead of flat^T @ flat, skipping the work for the many zero rows. + std::vector touched; + touched.reserve(K); + for (int k = 0; k < K; ++k) + { + const double* row = flat_data + k * N_alpha; + for (int j = 0; j < N_alpha; ++j) + { + if (row[j] != 0.0) + { + touched.push_back(k); + break; + } + } + } + + const int ntouched = static_cast(touched.size()); + torch::Tensor M = torch::empty({ntouched, N_alpha}, torch::TensorOptions().dtype(torch::kFloat64)); + double* M_data = M.data_ptr(); + for (int t = 0; t < ntouched; ++t) + { + const double* src = flat_data + touched[t] * N_alpha; + double* dst = M_data + t * N_alpha; + for (int j = 0; j < N_alpha; ++j) + { + dst[j] = src[j]; + } + } + + vdrpre_square = M.t().mm(M); + + ModuleBase::timer::end("DeePKS_domain", "cal_vdrpre_square"); +} + +template void DeePKS_domain::cal_phialpha_hamilt_proj(const int, + const int, + const DeePKS_Param&, + const std::vector*>&, + const hamilt::HContainer&, + const UnitCell&, + const LCAO_Orbitals&, + const Parallel_Orbitals&, + const Grid_Driver&, + torch::Tensor&); + +template void DeePKS_domain::cal_phialpha_hamilt_proj>( + const int, + const int, + const DeePKS_Param&, + const std::vector*>&, + const hamilt::HContainer>&, + const UnitCell&, + const LCAO_Orbitals&, + const Parallel_Orbitals&, + const Grid_Driver&, + torch::Tensor&); + +#endif diff --git a/source/source_lcao/module_deepks/deepks_grad.h b/source/source_lcao/module_deepks/deepks_grad.h new file mode 100644 index 0000000000..376511f632 --- /dev/null +++ b/source/source_lcao/module_deepks/deepks_grad.h @@ -0,0 +1,64 @@ +#ifndef DEEPKS_GRAD_H +#define DEEPKS_GRAD_H + +#ifdef __MLALGO + +#include "deepks_param.h" +#include "source_base/timer.h" +#include "source_basis/module_ao/parallel_orbitals.h" +#include "source_cell/module_neighbor/sltk_grid_driver.h" +#include "source_lcao/module_hcontainer/hcontainer.h" + +#include + +namespace DeePKS_domain +{ +//------------------------ +// deepks_grad.cpp +//------------------------ +// Building block for the descriptor-gradient label g* = (B^T B)^{-1} B^T H. +// B: vdr_precalc, g: gradient (\partial E_delta / \partial descriptor), H: Hamiltonian(HR) +// +// B_{(mu,nu,R),(I,nl,k)} = sum_{m,m'} v^I_{nl,k,m} * Theta^I_{nl,mu,nu}(R)_{mm'} * v^I_{nl,k,m'} +// where Theta^I_{nl,mu,nu}(R)_{mm'} = sum_{dR1-dR2=R} phi[mu,m,dR1] * phi[nu,m',dR2] +// +// cal_phialpha_hamilt_proj (Step 1) +// dot_phialpha_hamilt[inl, m, m'] +// = +// = sum_{mu,nu,R} phi[mu,m,dR1] * H[mu,nu](dR2-dR1) * phi[nu,m',dR2] +// Shape: (inlmax, nm_max, nm_max). +// +// cal_vdrpre_square: full B^T B matrix for the descriptor-gradient label. +// (B^T B)_{a1,a2} = sum_{R,mu,nu} B_{(mu,nu,R),a1} * B_{(mu,nu,R),a2} +// = flat(vdr_precalc)^T @ flat(vdr_precalc) +// Shape: (nat * des_per_atom, nat * des_per_atom). +// +// Requires gevdm (deepks_scf must be active). +void cal_vdrpre_square(const int nlocal, + const int nat, + const int R_size, + const DeePKS_Param& deepks_param, + const std::vector*>& phialpha, + const std::vector& gevdm, + const UnitCell& ucell, + const LCAO_Orbitals& orb, + const Parallel_Orbitals& pv, + const Grid_Driver& GridD, + torch::Tensor& vdrpre_square); + +template +void cal_phialpha_hamilt_proj(const int nlocal, + const int nat, + const DeePKS_Param& deepks_param, + const std::vector*>& phialpha, + const hamilt::HContainer& hR, + const UnitCell& ucell, + const LCAO_Orbitals& orb, + const Parallel_Orbitals& pv, + const Grid_Driver& GridD, + torch::Tensor& dot_phialpha_hamilt); + +} // namespace DeePKS_domain + +#endif +#endif diff --git a/source/source_lcao/module_deepks/deepks_vdrpre.cpp b/source/source_lcao/module_deepks/deepks_vdrpre.cpp index d640b48a95..71efe0b065 100644 --- a/source/source_lcao/module_deepks/deepks_vdrpre.cpp +++ b/source/source_lcao/module_deepks/deepks_vdrpre.cpp @@ -14,15 +14,15 @@ #include "source_lcao/module_hcontainer/atom_pair.h" void DeePKS_domain::prepare_phialpha_iRmat(const int nlocal, - const int R_size, - const DeePKS_Param& deepks_param, - const std::vector*> phialpha, - const UnitCell& ucell, - const LCAO_Orbitals& orb, - const Parallel_Orbitals& pv, - const Grid_Driver& GridD, - torch::Tensor& overlap, - torch::Tensor& iRmat) + const int R_size, + const DeePKS_Param& deepks_param, + const std::vector*> phialpha, + const UnitCell& ucell, + const LCAO_Orbitals& orb, + const Parallel_Orbitals& pv, + const Grid_Driver& GridD, + torch::Tensor& overlap, + torch::Tensor& iRmat) { ModuleBase::TITLE("DeePKS_domain", "prepare_phialpha_iRmat"); ModuleBase::timer::start("DeePKS_domain", "prepare_phialpha_iRmat"); @@ -30,27 +30,24 @@ void DeePKS_domain::prepare_phialpha_iRmat(const int nlocal, // get the maximum nnmax std::vector nnmax_vec(ucell.nat, 0); - DeePKS_domain::iterate_ad1( - ucell, - GridD, - orb, - false, // no trace_alpha - [&](const int iat, - const ModuleBase::Vector3& tau0, - const int ibt, - const ModuleBase::Vector3& tau1, - const int start, - const int nw_tot, - ModuleBase::Vector3 dR) - { - if (phialpha[0]->find_matrix(iat, ibt, dR.x, dR.y, dR.z) == nullptr) - { - return; // to next loop - } - nnmax_vec[iat]++; - } - ); - + DeePKS_domain::iterate_ad1(ucell, + GridD, + orb, + false, // no trace_alpha + [&](const int iat, + const ModuleBase::Vector3& tau0, + const int ibt, + const ModuleBase::Vector3& tau1, + const int start, + const int nw_tot, + ModuleBase::Vector3 dR) { + if (phialpha[0]->find_matrix(iat, ibt, dR.x, dR.y, dR.z) == nullptr) + { + return; // to next loop + } + nnmax_vec[iat]++; + }); + int nnmax = *std::max_element(nnmax_vec.begin(), nnmax_vec.end()); overlap = torch::zeros({ucell.nat, nnmax, nlocal, deepks_param.des_per_atom}, dtype); torch::Tensor dRmat_tmp = torch::zeros({ucell.nat, nnmax, 3}, torch::kInt32); @@ -58,42 +55,40 @@ void DeePKS_domain::prepare_phialpha_iRmat(const int nlocal, auto dRmat_accessor = dRmat_tmp.accessor(); std::fill(nnmax_vec.begin(), nnmax_vec.end(), 0); - DeePKS_domain::iterate_ad1( - ucell, - GridD, - orb, - false, // no trace_alpha - [&](const int iat, - const ModuleBase::Vector3& tau0, - const int ibt, - const ModuleBase::Vector3& tau1, - const int start, - const int nw_tot, - ModuleBase::Vector3 dR) - { - hamilt::BaseMatrix* overlap_mat = phialpha[0]->find_matrix(iat, ibt, dR); - if (overlap_mat == nullptr) - { - return; // to next loop - } - dRmat_accessor[iat][nnmax_vec[iat]][0] = dR.x; - dRmat_accessor[iat][nnmax_vec[iat]][1] = dR.y; - dRmat_accessor[iat][nnmax_vec[iat]][2] = dR.z; + DeePKS_domain::iterate_ad1(ucell, + GridD, + orb, + false, // no trace_alpha + [&](const int iat, + const ModuleBase::Vector3& tau0, + const int ibt, + const ModuleBase::Vector3& tau1, + const int start, + const int nw_tot, + ModuleBase::Vector3 dR) { + hamilt::BaseMatrix* overlap_mat = phialpha[0]->find_matrix(iat, ibt, dR); + if (overlap_mat == nullptr) + { + return; // to next loop + } + dRmat_accessor[iat][nnmax_vec[iat]][0] = dR.x; + dRmat_accessor[iat][nnmax_vec[iat]][1] = dR.y; + dRmat_accessor[iat][nnmax_vec[iat]][2] = dR.z; - for (int ix = 0; ix < nw_tot; ix++) - { - if (pv.global2local_row(start + ix) < 0 || pv.global2local_col(start + ix) < 0) - { - continue; - } - for (int iy = 0; iy < deepks_param.des_per_atom; iy++) - { - overlap_accessor[iat][nnmax_vec[iat]][start + ix][iy] = overlap_mat->get_value(ix, iy); - } - } - nnmax_vec[iat]++; - } - ); + for (int ix = 0; ix < nw_tot; ix++) + { + if (pv.global2local_row(start + ix) < 0 || pv.global2local_col(start + ix) < 0) + { + continue; + } + for (int iy = 0; iy < deepks_param.des_per_atom; iy++) + { + overlap_accessor[iat][nnmax_vec[iat]][start + ix][iy] + = overlap_mat->get_value(ix, iy); + } + } + nnmax_vec[iat]++; + }); #ifdef __MPI Parallel_Reduce::reduce_all(overlap.data_ptr(), overlap.numel()); #endif @@ -119,124 +114,153 @@ void DeePKS_domain::cal_vdr_precalc(const int nlocal, ModuleBase::TITLE("DeePKS_domain", "calc_vdr_precalc"); ModuleBase::timer::start("DeePKS_domain", "calc_vdr_precalc"); - torch::Tensor vdr_pdm = torch::zeros({R_size, - R_size, - R_size, - nlocal, - nlocal, - deepks_param.inlmax, - (2 * deepks_param.lmaxd + 1), - (2 * deepks_param.lmaxd + 1)}, - torch::TensorOptions().dtype(torch::kFloat64)); - auto accessor = vdr_pdm.accessor(); + const int des_per_atom = deepks_param.des_per_atom; - DeePKS_domain::iterate_ad2(ucell, - GridD, - orb, - false, // no trace_alpha - [&](const int iat, - const ModuleBase::Vector3& tau0, - const int ibt1, - const ModuleBase::Vector3& tau1, - const int start1, - const int nw1_tot, - ModuleBase::Vector3 dR1, - const int ibt2, - const ModuleBase::Vector3& tau2, - const int start2, - const int nw2_tot, - ModuleBase::Vector3 dR2) { - const int T0 = ucell.iat2it[iat]; - const int I0 = ucell.iat2ia[iat]; - if (phialpha[0]->find_matrix(iat, ibt1, dR1.x, dR1.y, dR1.z) == nullptr - || phialpha[0]->find_matrix(iat, ibt2, dR2.x, dR2.y, dR2.z) == nullptr) - { - return; // to next loop - } + // Shape: (R_size, R_size, R_size, nlocal, nlocal, nat, des_per_atom) + torch::Tensor vdr_prec = torch::zeros({R_size, R_size, R_size, nlocal, nlocal, nat, des_per_atom}, + torch::TensorOptions().dtype(torch::kFloat64)); + auto accessor = vdr_prec.accessor(); - hamilt::BaseMatrix* overlap_1 = phialpha[0]->find_matrix(iat, ibt1, dR1); - hamilt::BaseMatrix* overlap_2 = phialpha[0]->find_matrix(iat, ibt2, dR2); - assert(overlap_1->get_col_size() == overlap_2->get_col_size()); - ModuleBase::Vector3 dR = dR2 - dR1; - int iRx = DeePKS_domain::mapping_R(dR.x); - int iRy = DeePKS_domain::mapping_R(dR.y); - int iRz = DeePKS_domain::mapping_R(dR.z); - // Make sure the index is in range we need to save - if (iRx >= R_size || iRy >= R_size || iRz >= R_size) - { - return; // to next loop - } + DeePKS_domain::iterate_ad2( + ucell, + GridD, + orb, + false, + [&](const int iat, + const ModuleBase::Vector3& tau0, + const int ibt1, + const ModuleBase::Vector3& tau1, + const int start1, + const int nw1_tot, + ModuleBase::Vector3 dR1, + const int ibt2, + const ModuleBase::Vector3& tau2, + const int start2, + const int nw2_tot, + ModuleBase::Vector3 dR2) { + if (phialpha[0]->find_matrix(iat, ibt1, dR1.x, dR1.y, dR1.z) == nullptr + || phialpha[0]->find_matrix(iat, ibt2, dR2.x, dR2.y, dR2.z) == nullptr) + { + return; + } - for (int iw1 = 0; iw1 < nw1_tot; ++iw1) - { - const int iw1_all = start1 + iw1; // this is \mu - const int iw1_local = pv.global2local_row(iw1_all); - if (iw1_local < 0) - { - continue; - } - for (int iw2 = 0; iw2 < nw2_tot; ++iw2) - { - const int iw2_all = start2 + iw2; // this is \nu - const int iw2_local = pv.global2local_col(iw2_all); - if (iw2_local < 0) - { - continue; - } + hamilt::BaseMatrix* overlap_1 = phialpha[0]->find_matrix(iat, ibt1, dR1); + hamilt::BaseMatrix* overlap_2 = phialpha[0]->find_matrix(iat, ibt2, dR2); + assert(overlap_1->get_col_size() == overlap_2->get_col_size()); - int ib = 0; - for (int L0 = 0; L0 <= orb.Alpha[0].getLmax(); ++L0) - { - for (int N0 = 0; N0 < orb.Alpha[0].getNchi(L0); ++N0) - { - const int inl = deepks_param.inl_index[T0](I0, L0, N0); - const int nm = 2 * L0 + 1; + const ModuleBase::Vector3 dR = dR2 - dR1; + const int iRx = DeePKS_domain::mapping_R(dR.x); + const int iRy = DeePKS_domain::mapping_R(dR.y); + const int iRz = DeePKS_domain::mapping_R(dR.z); + if (iRx >= R_size || iRy >= R_size || iRz >= R_size) + { + return; + } - for (int m1 = 0; m1 < nm; ++m1) // nm = 1 for s, 3 for p, 5 for d - { - for (int m2 = 0; m2 < nm; ++m2) // nm = 1 for s, 3 for p, 5 for d - { - double tmp = overlap_1->get_value(iw1, ib + m1) - * overlap_2->get_value(iw2, ib + m2); - accessor[iRx][iRy][iRz][iw1_all][iw2_all][inl][m1][m2] - += tmp; - } - } - ib += nm; - } - } - } // iw2 - } // iw1 - }); + // Collect MPI-local orbital indices for row (ibt1) and col (ibt2). + std::vector row_iw, row_iw_all, col_iw, col_iw_all; + row_iw.reserve(nw1_tot); + row_iw_all.reserve(nw1_tot); + col_iw.reserve(nw2_tot); + col_iw_all.reserve(nw2_tot); + for (int iw1 = 0; iw1 < nw1_tot; ++iw1) + { + if (pv.global2local_row(start1 + iw1) >= 0) + { + row_iw.push_back(iw1); + row_iw_all.push_back(start1 + iw1); + } + } + for (int iw2 = 0; iw2 < nw2_tot; ++iw2) + { + if (pv.global2local_col(start2 + iw2) >= 0) + { + col_iw.push_back(iw2); + col_iw_all.push_back(start2 + iw2); + } + } + const int nrow = static_cast(row_iw.size()); + const int ncol = static_cast(col_iw.size()); + if (nrow == 0 || ncol == 0) + { + return; + } + const int npairs = nrow * ncol; -#ifdef __MPI - const int size = R_size * R_size * R_size * nlocal * nlocal * deepks_param.inlmax * (2 * deepks_param.lmaxd + 1) - * (2 * deepks_param.lmaxd + 1); - double* data_ptr = vdr_pdm.data_ptr(); - Parallel_Reduce::reduce_all(data_ptr, size); -#endif + int ib = 0; + int nl = 0; + for (int L0 = 0; L0 <= orb.Alpha[0].getLmax(); ++L0) + { + for (int N0 = 0; N0 < orb.Alpha[0].getNchi(L0); ++N0, ++nl) + { + const int nm = 2 * L0 + 1; + const int nm2 = nm * nm; - // transfer v_delta_pdm to v_delta_pdm_vector - int nlmax = deepks_param.inlmax / nat; - std::vector vdr_pdm_vector; - for (int nl = 0; nl < nlmax; ++nl) - { - int nm = 2 * deepks_param.inl2l[nl] + 1; - torch::Tensor vdr_pdm_sliced - = vdr_pdm.slice(5, nl, deepks_param.inlmax, nlmax).slice(6, 0, nm, 1).slice(7, 0, nm, 1); - vdr_pdm_vector.push_back(vdr_pdm_sliced); - } + // T_flat[k1*ncol+k2, m1*nm+m2] = pa1[iw1,ib+m1] * pa2[iw2,ib+m2] + std::vector T_flat(npairs * nm2); + for (int k1 = 0; k1 < nrow; ++k1) + { + const int iw1 = row_iw[k1]; + for (int k2 = 0; k2 < ncol; ++k2) + { + const int iw2 = col_iw[k2]; + const int k = k1 * ncol + k2; + for (int m1 = 0; m1 < nm; ++m1) + { + const double a = overlap_1->get_value(iw1, ib + m1); + for (int m2 = 0; m2 < nm; ++m2) + { + T_flat[k * nm2 + m1 * nm + m2] = a * overlap_2->get_value(iw2, ib + m2); + } + } + } + } - assert(vdr_pdm_vector.size() == nlmax); + // gdata[v, m1*nm+m2] = gevdm[nl][iat, v, m1, m2] + // gevdm[nl] shape: (nat, nm, nm, nm), contiguous. + const double* gdata = gevdm[nl].data_ptr() + static_cast(iat) * nm * nm2; - // einsum for each nl: - std::vector vdr_vector; - for (int nl = 0; nl < nlmax; ++nl) - { - vdr_vector.push_back(at::einsum("pqrxyamn, avmn->pqrxyav", {vdr_pdm_vector[nl], gevdm[nl]})); - } + // result[k, v] = sum_m T_flat[k, m] * gdata[v, m] + std::vector result(npairs * nm, 0.0); + for (int k = 0; k < npairs; ++k) + { + const double* T_row = T_flat.data() + k * nm2; + for (int v = 0; v < nm; ++v) + { + const double* g_row = gdata + v * nm2; + double sum = 0.0; + for (int m = 0; m < nm2; ++m) + { + sum += T_row[m] * g_row[m]; + } + result[k * nm + v] = sum; + } + } + + for (int k1 = 0; k1 < nrow; ++k1) + { + const int iw1_all = row_iw_all[k1]; + for (int k2 = 0; k2 < ncol; ++k2) + { + const int iw2_all = col_iw_all[k2]; + const double* res_row = result.data() + (k1 * ncol + k2) * nm; + for (int v = 0; v < nm; ++v) + { + accessor[iRx][iRy][iRz][iw1_all][iw2_all][iat][ib + v] += res_row[v]; + } + } + } + + ib += nm; + } + } + }); + +#ifdef __MPI + Parallel_Reduce::reduce_all(vdr_prec.data_ptr(), vdr_prec.numel()); +#endif - vdr_precalc = torch::cat(vdr_vector, -1); + vdr_precalc = vdr_prec; ModuleBase::timer::end("DeePKS_domain", "calc_vdr_precalc"); return;