Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/gauxc/xc_integrator_settings.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ struct IntegratorSettingsSNLinK : public IntegratorSettingsEXX {
/// Polymorphic root of the XC integrator settings hierarchy.
///
/// Carries no data of its own; the virtual destructor makes the type
/// polymorphic so that callers can recover a concrete settings type via
/// `dynamic_cast`.
struct IntegratorSettingsXC {
  virtual ~IntegratorSettingsXC() noexcept = default;
};
/// Settings consumed by the KS integration entry points.
struct IntegratorSettingsKS : public IntegratorSettingsXC {
  /// Tolerance used by the GKS code paths.
  /// NOTE(review): presumably a density threshold -- confirm against the GKS kernels.
  double gks_dtol{1e-12};

  /// Controls how an RKS density matrix is interpreted.
  ///
  /// `false` (default): the matrix is treated as a one-spin density.
  /// `true`: the caller supplies the spin-summed closed-shell density.
  bool rks_density_matrix_is_spin_summed{false};
};

struct IntegratorSettingsEXC_GRAD : public IntegratorSettingsKS {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ class IncoreReplicatedXCDeviceIntegrator :
const value_type* Py, int64_t ldpy,
const value_type* Px, int64_t ldpx,
host_task_iterator task_begin, host_task_iterator task_end,
XCDeviceData& device_data, bool do_vxc );
XCDeviceData& device_data, bool do_vxc,
const IntegratorSettingsXC& settings );

void exc_vxc_local_work_( const basis_type& basis, const value_type* Ps, int64_t ldps,
const value_type* Pz, int64_t ldpz,
Expand All @@ -127,14 +128,14 @@ class IncoreReplicatedXCDeviceIntegrator :
value_type* VXCy, int64_t ldvxcy,
value_type* VXCx, int64_t ldvxcx, value_type* EXC, value_type *N_EL,
host_task_iterator task_begin, host_task_iterator task_end,
XCDeviceData& device_data );
XCDeviceData& device_data, const IntegratorSettingsXC& settings );

void fxc_contraction_local_work_( const basis_type& basis, const value_type* Ps, int64_t ldps,
const value_type* Pz, int64_t ldpz,
const value_type* tPs, int64_t ldtps,
const value_type* tPz, int64_t ldtpz,
host_task_iterator task_begin, host_task_iterator task_end,
XCDeviceData& device_data);
XCDeviceData& device_data, const IntegratorSettingsXC& settings);

void fxc_contraction_local_work_( const basis_type& basis, const value_type* Ps, int64_t ldps,
const value_type* Pz, int64_t ldpz,
Expand All @@ -144,7 +145,7 @@ class IncoreReplicatedXCDeviceIntegrator :
value_type* FXCs, int64_t ldfxcs,
value_type* FXCz, int64_t ldfxcz,
host_task_iterator task_begin, host_task_iterator task_end,
XCDeviceData& device_data );
XCDeviceData& device_data, const IntegratorSettingsXC& settings );

void eval_exc_grad_local_work_( const basis_type& basis, const value_type* Ps, int64_t ldps,
const value_type* Pz, int64_t ldpz,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ void IncoreReplicatedXCDeviceIntegrator<ValueType>::
exc_vxc_local_work_( basis, Ps, ldps, Pz, ldpz, Py, ldpy, Px, ldpx,
// Passing nullptr for VXCs disables VXC entirely
nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0, EXC, &N_EL,
tasks.begin(), tasks.end(), *device_data_ptr);
tasks.begin(), tasks.end(), *device_data_ptr, ks_settings);
});

GAUXC_MPI_CODE(
Expand Down Expand Up @@ -100,4 +100,3 @@ void IncoreReplicatedXCDeviceIntegrator<ValueType>::

}
}

Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ void IncoreReplicatedXCDeviceIntegrator<ValueType>::
IntegratorSettingsEXC_GRAD exc_grad_settings;
if( auto* tmp = dynamic_cast<const IntegratorSettingsEXC_GRAD*>(&settings) ) {
exc_grad_settings = *tmp;
} else if( auto* ks_tmp = dynamic_cast<const IntegratorSettingsKS*>(&settings) ) {
exc_grad_settings.gks_dtol = ks_tmp->gks_dtol;
exc_grad_settings.rks_density_matrix_is_spin_summed =
ks_tmp->rks_density_matrix_is_spin_summed;
}

// Check that Partition Weights have been calculated
Expand Down Expand Up @@ -221,7 +225,8 @@ void IncoreReplicatedXCDeviceIntegrator<ValueType>::
else lwd->eval_collocation_gradient( &device_data );

// Evaluate X matrix and V vars
const auto xmat_fac = is_rks ? 2.0 : 1.0;
const auto xmat_fac =
(is_rks and not exc_grad_settings.rks_density_matrix_is_spin_summed) ? 2.0 : 1.0;
const auto need_lapl = func.needs_laplacian();
const auto need_xmat_grad = not func.is_lda();
auto do_xmat_vvar = [&](density_id den_id) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ void IncoreReplicatedXCDeviceIntegrator<ValueType>::
// If we can do reductions on the device (e.g. NCCL)
// Don't communicate data back to the host before reduction
this->timer_.time_op("XCIntegrator.LocalWork_EXC_VXC", [&](){
exc_vxc_local_work_( basis, Ps, ldps, Pz, ldpz, Py, ldpy, Px, ldpx, tasks.begin(), tasks.end(),
*device_data_ptr, true);
exc_vxc_local_work_( basis, Ps, ldps, Pz, ldpz, Py, ldpy, Px, ldpx, tasks.begin(), tasks.end(),
*device_data_ptr, true, settings);
});

GAUXC_MPI_CODE(
Expand Down Expand Up @@ -177,7 +177,7 @@ void IncoreReplicatedXCDeviceIntegrator<ValueType>::
this->timer_.time_op("XCIntegrator.LocalWork_EXC_VXC", [&](){
exc_vxc_local_work_( basis, Ps, ldps, Pz, ldpz, Py, ldpy, Px, ldpx,
VXCs, ldvxcs, VXCz, ldvxcz, VXCy, ldvxcy, VXCx, ldvxcx, EXC,
&N_EL, tasks.begin(), tasks.end(), *device_data_ptr);
&N_EL, tasks.begin(), tasks.end(), *device_data_ptr, settings);
});

GAUXC_MPI_CODE(
Expand Down Expand Up @@ -225,7 +225,8 @@ void IncoreReplicatedXCDeviceIntegrator<ValueType>::
const value_type* Py, int64_t ldpy,
const value_type* Px, int64_t ldpx,
host_task_iterator task_begin, host_task_iterator task_end,
XCDeviceData& device_data, bool do_vxc ) {
XCDeviceData& device_data, bool do_vxc,
const IntegratorSettingsXC& settings ) {
const bool is_gks = (Pz != nullptr) and (Py != nullptr) and (Px != nullptr);
const bool is_uks = (Pz != nullptr) and (Py == nullptr) and (Px == nullptr);
const bool is_rks = (Ps != nullptr) and (not is_uks and not is_gks);
Expand All @@ -243,6 +244,11 @@ void IncoreReplicatedXCDeviceIntegrator<ValueType>::

if( func.is_mgga() and is_gks ) GAUXC_GENERIC_EXCEPTION("GKS mGGAs NYI!");

IntegratorSettingsKS ks_settings;
if( auto* tmp = dynamic_cast<const IntegratorSettingsKS*>(&settings) ) {
ks_settings = *tmp;
}

// Get basis map
BasisSetMap basis_map(basis,mol);

Expand Down Expand Up @@ -312,7 +318,8 @@ void IncoreReplicatedXCDeviceIntegrator<ValueType>::
else if( func.is_gga() ) lwd->eval_collocation_gradient( &device_data );
else lwd->eval_collocation( &device_data );

const double xmat_fac = is_rks ? 2.0 : 1.0;
const double xmat_fac =
(is_rks and not ks_settings.rks_density_matrix_is_spin_summed) ? 2.0 : 1.0;
const bool need_xmat_grad = func.is_mgga();

// Evaluate X matrix and V vars
Expand Down Expand Up @@ -396,11 +403,12 @@ void IncoreReplicatedXCDeviceIntegrator<ValueType>::
value_type* VXCy, int64_t ldvxcy,
value_type* VXCx, int64_t ldvxcx, value_type* EXC, value_type *N_EL,
host_task_iterator task_begin, host_task_iterator task_end,
XCDeviceData& device_data ) {
XCDeviceData& device_data, const IntegratorSettingsXC& settings ) {

// Integrate and keep data on device
const bool do_vxc = VXCs;
exc_vxc_local_work_( basis, Ps, ldps, Pz, ldpz, Py, ldpy, Px, ldpx, task_begin, task_end, device_data, do_vxc );
exc_vxc_local_work_( basis, Ps, ldps, Pz, ldpz, Py, ldpy, Px, ldpx,
task_begin, task_end, device_data, do_vxc, settings );
auto rt = detail::as_device_runtime(this->load_balancer_->runtime());
rt.device_backend()->master_queue_synchronize();

Expand All @@ -414,4 +422,3 @@ void IncoreReplicatedXCDeviceIntegrator<ValueType>::

}
}

Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ namespace GauXC::detail {
// Don't communicate data back to the host before reduction
this->timer_.time_op("XCIntegrator.LocalWork_FXC", [&](){
fxc_contraction_local_work_( basis, Ps, ldps, Pz, ldpz, tPs, ldtps, tPz, ldtpz,
tasks.begin(), tasks.end(), *device_data_ptr);
tasks.begin(), tasks.end(), *device_data_ptr, ks_settings);
});

GAUXC_MPI_CODE(
Expand Down Expand Up @@ -127,7 +127,8 @@ namespace GauXC::detail {
// data from device
this->timer_.time_op("XCIntegrator.LocalWork_FXC", [&](){
fxc_contraction_local_work_( basis, Ps, ldps, Pz, ldpz, tPs, ldtps, tPz, ldtpz, &N_EL,
FXCs, ldfxcs, FXCz, ldfxcz, tasks.begin(), tasks.end(), *device_data_ptr);
FXCs, ldfxcs, FXCz, ldfxcz, tasks.begin(), tasks.end(), *device_data_ptr,
ks_settings);
});

GAUXC_MPI_CODE(
Expand Down Expand Up @@ -160,7 +161,7 @@ namespace GauXC::detail {
const value_type* tPs, int64_t ldtps,
const value_type* tPz, int64_t ldtpz,
host_task_iterator task_begin, host_task_iterator task_end,
XCDeviceData& device_data) {
XCDeviceData& device_data, const IntegratorSettingsXC& settings) {
const bool is_uks = (Pz != nullptr);
const bool is_rks = !is_uks;
if (not is_rks and not is_uks) {
Expand All @@ -175,6 +176,11 @@ namespace GauXC::detail {
const auto& func = *this->func_;
const auto& mol = this->load_balancer_->molecule();

IntegratorSettingsKS ks_settings;
if( auto* tmp = dynamic_cast<const IntegratorSettingsKS*>(&settings) ) {
ks_settings = *tmp;
}

// Get basis map
BasisSetMap basis_map(basis,mol);

Expand Down Expand Up @@ -243,7 +249,8 @@ namespace GauXC::detail {
else if( func.is_gga() ) lwd->eval_collocation_gradient( &device_data );
else lwd->eval_collocation( &device_data );

const double xmat_fac = is_rks ? 2.0 : 1.0;
const double xmat_fac =
(is_rks and not ks_settings.rks_density_matrix_is_spin_summed) ? 2.0 : 1.0;
const bool need_xmat_grad = func.is_mgga();

// Evaluate X matrix and V vars
Expand Down Expand Up @@ -327,11 +334,11 @@ namespace GauXC::detail {
value_type* FXCs, int64_t ldfxcs,
value_type* FXCz, int64_t ldfxcz,
host_task_iterator task_begin, host_task_iterator task_end,
XCDeviceData& device_data ) {
XCDeviceData& device_data, const IntegratorSettingsXC& settings ) {

// Integrate and keep data on device
fxc_contraction_local_work_( basis, Ps, ldps, Pz, ldpz, tPs, ldtps, tPz, ldtpz,
task_begin, task_end, device_data);
task_begin, task_end, device_data, settings);
auto rt = detail::as_device_runtime(this->load_balancer_->runtime());
rt.device_backend()->master_queue_synchronize();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ void ReferenceReplicatedXCHostIntegrator<ValueType>::
IntegratorSettingsEXC_GRAD exc_grad_settings;
if( auto* tmp = dynamic_cast<const IntegratorSettingsEXC_GRAD*>(&settings) ) {
exc_grad_settings = *tmp;
} else if( auto* ks_tmp = dynamic_cast<const IntegratorSettingsKS*>(&settings) ) {
exc_grad_settings.gks_dtol = ks_tmp->gks_dtol;
exc_grad_settings.rks_density_matrix_is_spin_summed =
ks_tmp->rks_density_matrix_is_spin_summed;
}

// Get basis map
Expand Down Expand Up @@ -330,7 +334,8 @@ void ReferenceReplicatedXCHostIntegrator<ValueType>::

// Evaluate X matrix (fac * P * B/Bx/By/Bz) -> store in Z
// XXX: This assumes that bfn + gradients are contiguous in memory
const auto xmat_fac = is_rks ? 2.0 : 1.0;
const auto xmat_fac =
(is_rks and not exc_grad_settings.rks_density_matrix_is_spin_summed) ? 2.0 : 1.0;
const int xmat_len = func.is_lda() ? 1 : 4;
lwd->eval_xmat( xmat_len*npts, nbf, nbe, submat_map, xmat_fac, Ps, ldps, basis_eval, nbe,
xNmat, nbe, nbe_scr );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,8 @@ void ReferenceReplicatedXCHostIntegrator<ValueType>::


// Evaluate X matrix (fac * P * B) -> store in Z
const auto xmat_fac = is_rks ? 2.0 : 1.0; // TODO Fix for spinor RKS input
const auto xmat_fac =
(is_rks and not ks_settings.rks_density_matrix_is_spin_summed) ? 2.0 : 1.0;
lwd->eval_xmat( mgga_dim_scal * npts, nbf, nbe, submat_map, xmat_fac, Ps, ldps, basis_eval, nbe,
zmat, nbe, nbe_scr );

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,8 @@ void ReferenceReplicatedXCHostIntegrator<ValueType>::


// Evaluate X matrix (fac * P * B) -> store in Z
const auto xmat_fac = is_rks ? 2.0 : 1.0; // TODO Fix for spinor RKS input
const auto xmat_fac =
(is_rks and not ks_settings.rks_density_matrix_is_spin_summed) ? 2.0 : 1.0;
lwd->eval_xmat( mgga_dim_scal * npts, nbf, nbe, submat_map, xmat_fac, Ps, ldps, basis_eval, nbe,
zmat, nbe, nbe_scr );
// X matrix for Pz
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,9 @@ class ShellBatchedReplicatedXCIntegrator :
value_type* VXCy, int64_t ldvxcy,
value_type* VXCx, int64_t ldvxcx,
value_type* EXC, value_type *N_EL,
host_task_iterator task_begin, host_task_iterator task_end, incore_integrator_type& incore_integrator
host_task_iterator task_begin, host_task_iterator task_end,
incore_integrator_type& incore_integrator,
const IntegratorSettingsXC& ks_settings
);


Expand All @@ -146,7 +148,9 @@ class ShellBatchedReplicatedXCIntegrator :
value_type* VXCz, int64_t ldvxcz,
value_type* VXCy, int64_t ldvxcy,
value_type* VXCx, int64_t ldvxcx,
value_type* EXC, value_type* N_EL, incore_integrator_type& incore_integrator);
value_type* EXC, value_type* N_EL,
incore_integrator_type& incore_integrator,
const IntegratorSettingsXC& ks_settings);
public:

template <typename... Args>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ void ShellBatchedReplicatedXCIntegrator<BaseIntegratorType, IncoreIntegratorType
const value_type* Pz, int64_t ldpz,
const value_type* Py, int64_t ldpy,
const value_type* Px, int64_t ldpx,
value_type* EXC, const IntegratorSettingsXC& /*ks_settings*/) {
value_type* EXC, const IntegratorSettingsXC& ks_settings) {


const auto& basis = this->load_balancer_->basis();
Expand Down Expand Up @@ -84,7 +84,7 @@ void ShellBatchedReplicatedXCIntegrator<BaseIntegratorType, IncoreIntegratorType
this->timer_.time_op("XCIntegrator.LocalWork", [&](){
exc_vxc_local_work_( basis, Ps, ldps, Pz, ldpz, Py, ldpy, Px, ldpx,
nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0, EXC,
&N_EL, tasks.begin(), tasks.end(), incore_integrator );
&N_EL, tasks.begin(), tasks.end(), incore_integrator, ks_settings );
});

// Release ownership of LWD back to this integrator instance
Expand Down Expand Up @@ -134,4 +134,3 @@ void ShellBatchedReplicatedXCIntegrator<BaseIntegratorType, IncoreIntegratorType

}
}

Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ void ShellBatchedReplicatedXCIntegrator<BaseIntegratorType, IncoreIntegratorType
eval_exc_grad_( int64_t m, int64_t n, const value_type* P, int64_t ldp, value_type* EXC_GRAD, const IntegratorSettingsXC& settings ) {

GAUXC_GENERIC_EXCEPTION("ShellBatched exc_grad NYI" );
util::unused(m,n,P,ldp,EXC_GRAD);
util::unused(m,n,P,ldp,EXC_GRAD,settings);
}

template <typename BaseIntegratorType, typename IncoreIntegratorType>
Expand All @@ -31,7 +31,7 @@ void ShellBatchedReplicatedXCIntegrator<BaseIntegratorType, IncoreIntegratorType
const value_type* Pz, int64_t lpdz, value_type* EXC_GRAD, const IntegratorSettingsXC& settings ) {

GAUXC_GENERIC_EXCEPTION("ShellBatched exc_grad NYI" );
util::unused(m,n,Ps,ldps,Pz,lpdz,EXC_GRAD);
util::unused(m,n,Ps,ldps,Pz,lpdz,EXC_GRAD,settings);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ void ShellBatchedReplicatedXCIntegrator<BaseIntegratorType, IncoreIntegratorType
value_type* VXCz, int64_t ldvxcz,
value_type* VXCy, int64_t ldvxcy,
value_type* VXCx, int64_t ldvxcx,
value_type* EXC, const IntegratorSettingsXC& /*ks_settings*/) {
value_type* EXC, const IntegratorSettingsXC& ks_settings) {


const auto& basis = this->load_balancer_->basis();
Expand Down Expand Up @@ -98,7 +98,7 @@ void ShellBatchedReplicatedXCIntegrator<BaseIntegratorType, IncoreIntegratorType
this->timer_.time_op("XCIntegrator.LocalWork", [&](){
exc_vxc_local_work_( basis, Ps, ldps, Pz, ldpz, Py, ldpy, Px, ldpx,
VXCs, ldvxcs, VXCz, ldvxcz, VXCy, ldvxcy, VXCx, ldvxcx, EXC,
&N_EL, tasks.begin(), tasks.end(), incore_integrator );
&N_EL, tasks.begin(), tasks.end(), incore_integrator, ks_settings );
});

// Release ownership of LWD back to this integrator instance
Expand Down Expand Up @@ -166,7 +166,8 @@ void ShellBatchedReplicatedXCIntegrator<BaseIntegratorType, IncoreIntegratorType
value_type* VXCx, int64_t ldvxcx,
value_type* EXC, value_type *N_EL,
host_task_iterator task_begin, host_task_iterator task_end,
incore_integrator_type& incore_integrator ) {
incore_integrator_type& incore_integrator,
const IntegratorSettingsXC& ks_settings ) {

//incore_integrator.exc_vxc_local_work( basis, P, ldp, VXC, ldvxc, EXC, N_EL, task_begin, task_end, device_data );
//return;
Expand Down Expand Up @@ -227,7 +228,7 @@ void ShellBatchedReplicatedXCIntegrator<BaseIntegratorType, IncoreIntegratorType
// Execute task
execute_task_batch( next_task, basis, mol, Ps, ldps, Pz, ldpz,
Py, ldpy, Px, ldpx, VXCs, ldvxcs, VXCz, ldvxcz, VXCy, ldvxcy,
VXCx, ldvxcx, EXC, N_EL, incore_integrator );
VXCx, ldvxcx, EXC, N_EL, incore_integrator, ks_settings );
};


Expand Down Expand Up @@ -287,7 +288,8 @@ void ShellBatchedReplicatedXCIntegrator<BaseIntegratorType, IncoreIntegratorType
value_type* VXCy, int64_t ldvxcy,
value_type* VXCx, int64_t ldvxcx,
value_type* EXC, value_type *N_EL,
incore_integrator_type& incore_integrator ) {
incore_integrator_type& incore_integrator,
const IntegratorSettingsXC& ks_settings ) {


// Alias information
Expand Down Expand Up @@ -398,13 +400,13 @@ void ShellBatchedReplicatedXCIntegrator<BaseIntegratorType, IncoreIntegratorType
incore_integrator.exc_vxc_local_work( basis_subset, Ps_submat, nbe,
Pz_submat, nbe, Py_submat, nbe, Px_submat, nbe, VXCs_submat, nbe,
VXCz_submat, nbe, VXCy_submat, nbe, VXCx_submat, nbe,
&EXC_tmp, &NEL_tmp, task_begin, task_end, *device_data_ptr_ );
&EXC_tmp, &NEL_tmp, task_begin, task_end, *device_data_ptr_, ks_settings );
} else if constexpr (not IncoreIntegratorType::is_device) {
#endif
incore_integrator.exc_vxc_local_work( basis_subset, Ps_submat, nbe,
Pz_submat, nbe, Py_submat, nbe, Px_submat, nbe, VXCs_submat, nbe,
VXCz_submat, nbe, VXCy_submat, nbe, VXCx_submat, nbe,
&EXC_tmp, &NEL_tmp, IntegratorSettingsKS{}, task_begin, task_end );
&EXC_tmp, &NEL_tmp, ks_settings, task_begin, task_end );
#ifdef GAUXC_HAS_DEVICE
}
#endif
Expand Down Expand Up @@ -444,4 +446,3 @@ void ShellBatchedReplicatedXCIntegrator<BaseIntegratorType, IncoreIntegratorType

}
}

Loading
Loading