Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/xc_integrator/integrator_util/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,7 @@
#
target_sources( gauxc PRIVATE integrator_common.cxx integral_bounds.cxx
exx_screening.cxx
spherical_harmonics.cxx
onedft_util.cxx )
spherical_harmonics.cxx )
if( GAUXC_ENABLE_ONEDFT )
target_sources( gauxc PRIVATE onedft_util.cxx )
endif()
3 changes: 0 additions & 3 deletions src/xc_integrator/integrator_util/onedft_util.hpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
#include <gauxc/gauxc_config.hpp>
#include <torch/script.h>
#include <torch/torch.h>
#ifdef GAUXC_HAS_CUDA
#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
#endif
#include <nlohmann/json.hpp>
#include <gauxc/xc_integrator/local_work_driver.hpp>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
#include "incore_replicated_xc_device_integrator_exx.hpp"
#include "incore_replicated_xc_device_integrator_fxc_contraction.hpp"
#include "incore_replicated_xc_device_integrator_dd.hpp"
#ifdef GAUXC_HAS_ONEDFT
#include "incore_replicated_xc_device_integrator_onedft.hpp"
#endif
namespace GauXC {
namespace detail {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "device/common/device_blas.hpp"
#include "integrator_util/onedft_util.hpp"
#include <cuda_runtime.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include "device/cuda/cuda_backend.hpp"
#include <cstddef> // for size_t

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
#include "reference_replicated_xc_host_integrator_fxc_contraction.hpp"
#include "reference_replicated_xc_host_integrator_dd_psi.hpp"
#include "reference_replicated_xc_host_integrator_dd_psi_potential.hpp"
#ifdef GAUXC_HAS_ONEDFT
#include "reference_replicated_xc_host_integrator_onedft.hpp"
#endif

namespace GauXC::detail {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,17 @@ class ReferenceReplicatedXCHostIntegrator :
value_type* EXC, const IntegratorSettingsXC& ks_settings ) override;

/// Onedft
#ifdef GAUXC_HAS_ONEDFT
void eval_exc_vxc_onedft_( int64_t m, int64_t n, const value_type* Ps, int64_t ldps,
const value_type* Pz, int64_t ldpz, value_type* VXCs, int64_t ldvxcs,
value_type* VXCz, int64_t ldvxcz, value_type* EXC, const IntegratorSettingsXC& ks_settings ) override;
#else
void eval_exc_vxc_onedft_( int64_t, int64_t, const value_type*, int64_t,
const value_type*, int64_t, value_type*, int64_t,
value_type*, int64_t, value_type*, const IntegratorSettingsXC& ) override {
throw std::runtime_error("OneDFT support not compiled");
}
#endif

/// RKS EXC Gradient
void eval_exc_grad_( int64_t m, int64_t n, const value_type* P, int64_t ldp,
Expand Down Expand Up @@ -154,6 +162,7 @@ class ReferenceReplicatedXCHostIntegrator :

void dd_psi_potential_local_work_( const value_type* X, value_type* Vddx, unsigned max_Ylm );

#ifdef GAUXC_HAS_ONEDFT
void pre_onedft_local_work_( const basis_type& basis, const value_type* Ps, int64_t ldps,
const value_type* Pz, int64_t ldpz, value_type *N_EL,
const bool is_gga, const bool is_mgga, const bool needs_laplacian);
Expand All @@ -163,6 +172,7 @@ class ReferenceReplicatedXCHostIntegrator :
value_type* VXCs, int64_t ldvxcs,
value_type* VXCz, int64_t ldvxcz,
const bool is_gga, const bool is_mgga, const bool needs_laplacian);
#endif


public:
Expand Down