-
Notifications
You must be signed in to change notification settings - Fork 3
Design Ideas
There are three distinct ways #ifdef DFTEFE_WITH_DEVICE appears in DFTEFE.
Preferred approach: template the class on MemorySpace.
This is the cleanest design when the class owns memory-space-specific data.
// Example of the preferred pattern: the whole class is templated on the
// memory space, so the memory space of the wavefunction data is fixed by the
// template argument of the class.
template <utils::MemorySpace memorySpace>
class DensityCalculator
{
public:
// Declaration only. Takes the occupations as a host std::vector, the
// wavefunctions as a MultiVector living in `memorySpace`, and writes into
// `rho`, which is always a HOST container regardless of `memorySpace`.
// NOTE(review): RealType and ValueTypeBasisCoeff are not declared in this
// snippet - presumably additional template parameters that were elided.
void computeRho(
const std::vector<RealType> &occupation,
const linearAlgebra::MultiVector<ValueTypeBasisCoeff, memorySpace>
&waveFunc,
quadrature::QuadratureValuesContainer<RealType,
utils::MemorySpace::HOST> &rho);
};Implementation:
// Out-of-class definition of DensityCalculator<memorySpace>::computeRho.
// The body is elided in this document; per the comment inside it, the real
// implementation forwards to computeRhoInBatch(...).
template <utils::MemorySpace memorySpace>
void DensityCalculator<memorySpace>::computeRho(
const std::vector<RealType> &occupation,
const linearAlgebra::MultiVector<ValueTypeBasisCoeff, memorySpace>
&waveFunc,
quadrature::QuadratureValuesContainer<RealType,
utils::MemorySpace::HOST> &rho)
{
// internally calls computeRhoInBatch(...)
}template <
typename ValueType,
typename RealType,
utils::MemorySpace memorySpace>
// Primary template of the kernel helper: a class with a single static
// function whose containers all live in `memorySpace`.
class DensityCalculatorKernels
{
public:
// Declaration only. Every MemoryStorage / QuadratureValuesContainer argument
// and the LinAlgOpContext are parameterised on the same `memorySpace`.
// NOTE(review): presumably combines `occupationInBatch` with the wavefunction
// values in `psiBatchQuad` to fill `rhoBatch`, using `modPsiSqBatchQuad` as
// scratch - confirm against the implementation, which is not shown here.
static void computeRhoInBatch(
const utils::MemoryStorage<RealType, memorySpace> &occupationInBatch,
quadrature::QuadratureValuesContainer<ValueType, memorySpace>
&psiBatchQuad,
quadrature::QuadratureValuesContainer<RealType, memorySpace>
&modPsiSqBatchQuad,
std::shared_ptr<const quadrature::QuadratureRuleContainer>
quadRuleContainer,
quadrature::QuadratureValuesContainer<RealType, memorySpace>
&rhoBatch,
linearAlgebra::LinAlgOpContext<memorySpace>
&linAlgOpContext);
};#ifdef DFTEFE_WITH_DEVICE
// Partial specialization of DensityCalculatorKernels for
// utils::MemorySpace::DEVICE. It declares the same static computeRhoInBatch
// interface as the primary template, with every container and the
// LinAlgOpContext fixed to DEVICE memory.
template <typename ValueType, typename RealType>
class DensityCalculatorKernels<
ValueType,
RealType,
utils::MemorySpace::DEVICE>
{
public:
// Declaration only; the device implementation is not shown in this document.
static void computeRhoInBatch(
const utils::MemoryStorage<
RealType,
utils::MemorySpace::DEVICE> &occupationInBatch,
quadrature::QuadratureValuesContainer<
ValueType,
utils::MemorySpace::DEVICE> &psiBatchQuad,
quadrature::QuadratureValuesContainer<
RealType,
utils::MemorySpace::DEVICE> &modPsiSqBatchQuad,
std::shared_ptr<const quadrature::QuadratureRuleContainer>
quadRuleContainer,
quadrature::QuadratureValuesContainer<
RealType,
utils::MemorySpace::DEVICE> &rhoBatch,
linearAlgebra::LinAlgOpContext<
utils::MemorySpace::DEVICE> &linAlgOpContext);
};
#endif
The memory space is resolved at compile time, so no runtime branching is needed.
Alternative: template only the member function, for cases where templating the whole class would require a large refactoring.
This pattern keeps the external interface unchanged.
/**
 * Example of a non-templated class that exposes a single member function
 * templated on the memory space and forwards it to virtual host/device hooks.
 *
 * NOTE(review): `Q` is not declared in this snippet - presumably the value
 * type of an elided class template parameter.
 */
class ScalarSpatialFunction
{
public:
  /**
   * Compile-time dispatch on `memorySpace`.
   *
   * Any memory space other than DEVICE is routed to evalHost(). DEVICE is
   * routed to evalDevice() when DFTEFE_WITH_DEVICE is defined; otherwise
   * utils::throwException is called with `false` and an explanatory message.
   *
   * @param numPoints forwarded unchanged to the selected hook
   * @param t         forwarded unchanged to the selected hook
   * @param q         forwarded unchanged to the selected hook
   */
  template <utils::MemorySpace memorySpace>
  void
  eval(const size_type numPoints, const double *t, Q *q) const
  {
    if constexpr (memorySpace != utils::MemorySpace::DEVICE)
      {
        evalHost(numPoints, t, q);
      }
    else
      {
#ifdef DFTEFE_WITH_DEVICE
        evalDevice(numPoints, t, q);
#else
        utils::throwException(
          false,
          "eval<DEVICE>() called but DEVICE support not compiled.");
#endif
      }
  }

protected:
  /** Host hook invoked by eval() for every non-DEVICE memory space. */
  virtual void
  evalHost(size_type numPoints, const double *t, Q *q) const;

#ifdef DFTEFE_WITH_DEVICE
  /** Device hook invoked by eval<DEVICE>(); only declared in device builds. */
  virtual void
  evalDevice(size_type numPoints, const double *t, Q *q) const;
#endif
};

/**
 * Example derived class: it only overrides the protected hooks, so the public
 * eval<memorySpace>() entry point of the base class stays untouched.
 */
class SmearChargePotentialFunction : public ScalarSpatialFunctionReal
{
protected:
  void
  evalHost(size_type numPoints, const double *t, double *q) const override;

#ifdef DFTEFE_WITH_DEVICE
  void
  evalDevice(size_type numPoints, const double *t, double *q) const override;
#endif
};
- Minimal API disruption
- No class-wide templating
- Preserves polymorphism
This is the hardest case in the current DFTEFE design.
This happens when a class member function itself must execute inside a CUDA/HIP kernel.
Example:
obj->getValueDevice(...)
This is problematic because:
- Host objects cannot generally be dereferenced on device
- Virtual dispatch on device is difficult
- Host object pointers must be transferred to device
- Polymorphism does not map cleanly
Expose a device representation.
// Host-side class that, in device builds, additionally exposes a plain
// struct (DeviceView) meant to be handed to device kernels in place of the
// polymorphic host object.
class SphericalDataNumerical : public SphericalData
{
public:
#ifdef DFTEFE_WITH_DEVICE
// Flat, non-virtual snapshot of the state needed for evaluation.
// As filled by getDeviceView(): `l` and `m` come from d_qNumbers[1] and
// d_qNumbers[2], `mEff` is |m|, `constant` is Clm(l, m) * Dm(m), and the
// remaining scalars mirror d_cutoff, d_smoothness and d_polarAngleTolerance.
struct DeviceView
{
// Device-side view of the radial spline, obtained from
// d_spline->getDeviceView().
utils::SplineDeviceView radialSpline;
int l;
int m;
int mEff;
double constant;
double cutoff;
double smoothness;
double polarAngleTolerance;
// Evaluation entry point called from inside kernels as
// views[i].getValueDevice(point, origin).
// NOTE(review): kernels in this document pass `point` and `origin` as
// pointers to 3 consecutive doubles - confirm the expected layout.
DFTEFE_DEVICE_FUNC
double getValueDevice(
const double *point,
const double *origin) const;
};
// Returns a DeviceView by value, built on the host.
DeviceView getDeviceView() const;
// NOTE(review): this overload is declared on the host class itself, while
// the surrounding text states that host objects cannot generally be
// dereferenced on device - confirm whether it is intended or a leftover.
DFTEFE_DEVICE_FUNC
double getValueDevice(
const double *point,
const double *origin) const;
#endif
private:
#ifdef DFTEFE_WITH_DEVICE
// NOTE(review): the getDeviceView() definition shown in this document builds
// a fresh DeviceView on every call and does not read this member - confirm
// whether the cached copy is needed.
DeviceView d_deviceView;
#endif
};#ifdef DFTEFE_WITH_DEVICE
/**
 * Build, on the host, the DeviceView that mirrors this object.
 *
 * The radial spline view is taken from d_spline->getDeviceView(); l and m are
 * read from entries 1 and 2 of d_qNumbers; the remaining scalars are copied
 * from the corresponding d_* members.
 *
 * @return a freshly populated DeviceView (returned by value)
 */
SphericalDataNumerical::DeviceView
SphericalDataNumerical::getDeviceView() const
{
  DeviceView view;
  view.radialSpline = d_spline->getDeviceView();

  const int l = d_qNumbers[1];
  const int m = d_qNumbers[2];

  view.l    = l;
  view.m    = m;
  view.mEff = std::abs(m);
  // Fold the (l, m)-dependent prefactor into one scalar so that it is
  // computed once here rather than at evaluation time.
  view.constant = Clm(l, m) * Dm(m);

  view.cutoff              = d_cutoff;
  view.smoothness          = d_smoothness;
  view.polarAngleTolerance = d_polarAngleTolerance;
  return view;
}
#endifDFTEFE_DEVICE_FUNC
// Definition of the device-callable evaluation on a DeviceView; the body is
// elided ("...") in this document.
// NOTE(review): DeviceView is declared only under DFTEFE_WITH_DEVICE, but
// this definition follows the #endif, i.e. it sits outside that guard. In a
// real source file it would need to be inside the guard to compile in
// host-only builds.
double
SphericalDataNumerical::DeviceView::getValueDevice(
const double *point,
const double *origin) const
{
...
}The key idea is:
- Construct `DeviceView` on host
- Copy an array of `DeviceView`s to device
- Pass the device array directly into kernels
- Invoke device member functions from inside the kernel
namespace
{
DFTEFE_CREATE_KERNEL(
void,
evalEnrichmentInCell,
{
for (size_type iThread = globalThreadId;
iThread < numEnrichInCell * numQuadInCell;
iThread += nThreadsPerBlock * nThreadBlock)
{
const size_type enrichId = iThread % numQuadInCell;
const size_type quadId = iThread / numQuadInCell;
output[iThread] =
views[enrichId].getValueDevice(
quadPtsInCell + quadId * 3,
origin + enrichId * 3);
}
},
const double *quadPtsInCell,
const double *origin,
const size_type numEnrichInCell,
const size_type numQuadInCell,
const atoms::SphericalDataNumerical::DeviceView *views,
double *output);
}The important line is:
views[enrichId].getValueDevice(...)
This is possible because `views` is a device-accessible array of `DeviceView` objects, which have to be created on the host and copied to the device first.
/**
 * Launch evalEnrichmentInCell once per cell in [cellRange.first,
 * cellRange.second) that has at least one enrichment function, distributing
 * the launches round-robin over the streams of `linAlgOpContext`, then
 * synchronize every stream before returning.
 *
 * All per-cell offsets are global: cells before cellRange.first are first
 * walked to accumulate their enrichment, quadrature-point and
 * (enrichment x quadrature) counts, so the pointers passed to the kernel are
 * offset from the start of the all-cells arrays.
 *
 * Loop counters are size_type to match the type of `cellRange` and of the
 * stream count; an `int` counter would be converted on every comparison and
 * would overflow for cell counts beyond INT_MAX.
 *
 * NOTE(review): the two std::vector parameters are taken by value, which
 * copies them on every call. Switching to const references needs the
 * matching declaration (not shown here) to change as well.
 *
 * @param quadPtsInAllCells      3 doubles per quadrature point, all cells
 * @param originPtsInAllCells    3 doubles per enrichment function, all cells
 * @param cellRange              half-open range of cell indices to process
 * @param numEnrichIdsInAllCells number of enrichment functions per cell
 * @param numQuadPtsInAllCells   number of quadrature points per cell
 * @param data                   one DeviceView per enrichment function, all
 *                               cells (indexed by the cumulative enrichment
 *                               count)
 * @param output                 results, indexed by the cumulative
 *                               (enrichment x quadrature) count
 * @param linAlgOpContext        provides the streams used for the launches
 */
void
EnrichmentDataEvalKernels<utils::MemorySpace::DEVICE>::
  getEnrichmentValuesInCellRange(
    const double *quadPtsInAllCells,
    const double *originPtsInAllCells,
    std::pair<size_type, size_type> cellRange,
    const std::vector<size_type> numEnrichIdsInAllCells,
    const std::vector<size_type> numQuadPtsInAllCells,
    const atoms::SphericalDataNumerical::DeviceView *data,
    double *output,
    linearAlgebra::LinAlgOpContext<utils::MemorySpace::DEVICE>
      &linAlgOpContext)
{
  const size_type numStreams = linAlgOpContext.numBlasStreams();
  auto *streams              = linAlgOpContext.getBlasStreamsVec();
  constexpr size_type dim    = 3;

  size_type cumulativeEnrichInCellRange      = 0;
  size_type cumulativeQuadPtsInCellRange     = 0;
  size_type cumulativeQuadxEnrichInCellRange = 0;

  // Skip over the cells that precede the requested range, accumulating the
  // offsets at which the first requested cell starts.
  for (size_type iCell = 0; iCell < cellRange.first; ++iCell)
    {
      const size_type numEnrichInCell = numEnrichIdsInAllCells[iCell];
      const size_type numQuadInCell   = numQuadPtsInAllCells[iCell];
      cumulativeEnrichInCellRange += numEnrichInCell;
      cumulativeQuadPtsInCellRange += numQuadInCell;
      cumulativeQuadxEnrichInCellRange += numEnrichInCell * numQuadInCell;
    }

  // Counts only cells that actually launch a kernel, so that launches are
  // spread evenly over the streams.
  size_type cumulativeCellWithNonZeroNumEnrich = 0;
  for (size_type iCell = cellRange.first; iCell < cellRange.second; ++iCell)
    {
      const size_type numEnrichInCell = numEnrichIdsInAllCells[iCell];
      const size_type numQuadInCell   = numQuadPtsInAllCells[iCell];
      if (numEnrichInCell > 0)
        {
          const size_type sid =
            cumulativeCellWithNonZeroNumEnrich % numStreams;
          const size_type total     = numEnrichInCell * numQuadInCell;
          const size_type blockSize = utils::DEVICE_BLOCK_SIZE;
          // Ceiling division: enough blocks to cover `total` entries.
          const size_type grid = (total + blockSize - 1) / blockSize;
          DFTEFE_LAUNCH_KERNEL(
            evalEnrichmentInCell,
            grid,
            blockSize,
            streams[sid],
            quadPtsInAllCells + cumulativeQuadPtsInCellRange * dim,
            originPtsInAllCells + cumulativeEnrichInCellRange * dim,
            numEnrichInCell,
            numQuadInCell,
            data + cumulativeEnrichInCellRange,
            output + cumulativeQuadxEnrichInCellRange);
          cumulativeCellWithNonZeroNumEnrich++;
        }
      cumulativeQuadPtsInCellRange += numQuadInCell;
      cumulativeEnrichInCellRange += numEnrichInCell;
      cumulativeQuadxEnrichInCellRange += numEnrichInCell * numQuadInCell;
    }

  // Wait for every stream, including ones that received no work.
  for (size_type s = 0; s < numStreams; ++s)
    utils::deviceStreamSynchronize(streams[s]);
}
The kernel receives:
const atoms::SphericalDataNumerical::DeviceView *views
Each thread can safely call:
views[i].getValueDevice(...)
because:
- `DeviceView` contains only device-safe data
- no host pointers are dereferenced
- all state is explicitly copied to device memory