Skip to content

Design Ideas

Avirup Sircar edited this page May 1, 2026 · 17 revisions

CPU–GPU (`#ifdef DFTEFE_WITH_DEVICE`) Design Patterns in DFTEFE

There are three distinct ways #ifdef DFTEFE_WITH_DEVICE appears in DFTEFE.


1. Class stores device data members and launches device kernels internally

Preferred approach: template the class on MemorySpace.

This is the cleanest design when the class owns memory-space-specific data.

Example

// Pattern 1: the class owns memory-space-specific data, so the whole class
// is templated on MemorySpace.  The memory space is a compile-time choice;
// no runtime CPU/GPU branching is needed inside the class.
template <utils::MemorySpace memorySpace>
class DensityCalculator
{
public:
  // Computes the electron density `rho` (always accumulated on HOST) from
  // the occupation numbers and the wavefunctions stored in `memorySpace`.
  void computeRho(
    const std::vector<RealType> &occupation,
    const linearAlgebra::MultiVector<ValueTypeBasisCoeff, memorySpace>
      &waveFunc,
    quadrature::QuadratureValuesContainer<RealType,
                                          utils::MemorySpace::HOST> &rho);
};

Implementation:

// Single definition valid for every memory space: all space-specific work
// is delegated to DensityCalculatorKernels::computeRhoInBatch, whose DEVICE
// behavior comes from the class-template specialization below.
template <utils::MemorySpace memorySpace>
void DensityCalculator<memorySpace>::computeRho(
  const std::vector<RealType> &occupation,
  const linearAlgebra::MultiVector<ValueTypeBasisCoeff, memorySpace>
    &waveFunc,
  quadrature::QuadratureValuesContainer<RealType,
                                        utils::MemorySpace::HOST> &rho)
{
  // internally calls computeRhoInBatch(...)
}

Kernel dispatch helper

// Kernel-dispatch helper: the primary template provides the HOST path.
// A DEVICE specialization (guarded by DFTEFE_WITH_DEVICE) supplies the GPU
// implementation, so callers never write #ifdef themselves.
template <
  typename ValueType,
  typename RealType,
  utils::MemorySpace memorySpace>
class DensityCalculatorKernels
{
public:
  // Accumulates |psi|^2-weighted contributions of one wavefunction batch
  // into rhoBatch; all containers live in `memorySpace`.
  static void computeRhoInBatch(
    const utils::MemoryStorage<RealType, memorySpace> &occupationInBatch,
    quadrature::QuadratureValuesContainer<ValueType, memorySpace>
      &psiBatchQuad,
    quadrature::QuadratureValuesContainer<RealType, memorySpace>
      &modPsiSqBatchQuad,
    std::shared_ptr<const quadrature::QuadratureRuleContainer>
      quadRuleContainer,
    quadrature::QuadratureValuesContainer<RealType, memorySpace>
      &rhoBatch,
    linearAlgebra::LinAlgOpContext<memorySpace>
      &linAlgOpContext);
};

Device specialization

#ifdef DFTEFE_WITH_DEVICE

// DEVICE specialization of the kernel-dispatch helper.  Compiled only when
// DFTEFE_WITH_DEVICE is defined (see the surrounding #ifdef), so the rest of
// the library never references device code in host-only builds.
template <typename ValueType, typename RealType>
class DensityCalculatorKernels<
  ValueType,
  RealType,
  utils::MemorySpace::DEVICE>
{
public:
  // Same contract as the primary template, with every container pinned to
  // DEVICE memory; the implementation launches the GPU kernels.
  static void computeRhoInBatch(
    const utils::MemoryStorage<
      RealType,
      utils::MemorySpace::DEVICE> &occupationInBatch,

    quadrature::QuadratureValuesContainer<
      ValueType,
      utils::MemorySpace::DEVICE> &psiBatchQuad,

    quadrature::QuadratureValuesContainer<
      RealType,
      utils::MemorySpace::DEVICE> &modPsiSqBatchQuad,

    std::shared_ptr<const quadrature::QuadratureRuleContainer>
      quadRuleContainer,

    quadrature::QuadratureValuesContainer<
      RealType,
      utils::MemorySpace::DEVICE> &rhoBatch,

    linearAlgebra::LinAlgOpContext<
      utils::MemorySpace::DEVICE> &linAlgOpContext);
};

#endif

Why this works

The memory space is resolved at compile time, so no runtime branching is needed.


2. Base class exposes host API and internally dispatches to host/device implementations

Alternative when templating the class would require large refactoring.

This pattern keeps the external interface unchanged.

Base class

// Pattern 2: the public API stays non-templated; only the thin `eval`
// member template dispatches (at compile time) to the host or device
// virtual implementation.  Derived classes override evalHost/evalDevice.
class ScalarSpatialFunction
{
public:
  // Polymorphic base: a virtual destructor is required so deleting a
  // derived object through a ScalarSpatialFunction* is well-defined.
  virtual ~ScalarSpatialFunction() = default;

  // Evaluates the function at `numPoints` points `t`, writing into `q`.
  // The memory space is a template parameter, so the #ifdef is confined
  // to this one function and callers never branch at runtime.
  template <utils::MemorySpace memorySpace>
  void eval(
    const size_type numPoints,
    const double   *t,
    Q              *q) const
  {
    if constexpr (memorySpace == utils::MemorySpace::DEVICE)
    {
#ifdef DFTEFE_WITH_DEVICE
      evalDevice(numPoints, t, q);
#else
      // Reaching here means a DEVICE instantiation in a host-only build;
      // fail loudly rather than silently evaluating on the host.
      utils::throwException(
        false,
        "eval<DEVICE>() called but DEVICE support not compiled.");
#endif
    }
    else
    {
      evalHost(numPoints, t, q);
    }
  }

protected:
  // Host-side evaluation; overridden by concrete functions.
  virtual void evalHost(
    size_type      numPoints,
    const double  *t,
    Q             *q) const;

#ifdef DFTEFE_WITH_DEVICE
  // Device-side evaluation; only declared when device support is built.
  virtual void evalDevice(
    size_type      numPoints,
    const double  *t,
    Q             *q) const;
#endif
};

Derived class

// Concrete function: overrides the protected host/device hooks; the public
// eval<memorySpace>() entry point is inherited unchanged from the base.
// NOTE(review): the base shown above is named ScalarSpatialFunction while
// this derives from ScalarSpatialFunctionReal — presumably a Real-typed
// alias/instantiation of the same pattern; confirm against the codebase.
class SmearChargePotentialFunction
  : public ScalarSpatialFunctionReal
{
protected:
  void evalHost(
    size_type      numPoints,
    const double  *t,
    double        *q) const override;

#ifdef DFTEFE_WITH_DEVICE
  void evalDevice(
    size_type      numPoints,
    const double  *t,
    double        *q) const override;
#endif
};

Advantages

  • Minimal API disruption
  • No class-wide templating
  • Preserves polymorphism

3. Class member functions callable inside device kernels

This is the hardest case in the current DFTEFE design.

This happens when a class member function itself must execute inside a CUDA/HIP kernel.

Example:

obj->getValueDevice(...)

This is problematic because:

  • Host objects cannot generally be dereferenced on device
  • Virtual dispatch on device is difficult
  • Host object pointers must be transferred to device
  • Polymorphism does not map cleanly

Recommended solution: the DeviceView pattern (a device-safe mirror of the host object)

Expose a device representation.


Header: SphericalDataNumerical.h

// Pattern 3: a host-side polymorphic object exposes a plain-old-data
// "DeviceView" mirror that is safe to copy into device memory and call
// from inside CUDA/HIP kernels (no virtual dispatch, no host pointers).
class SphericalDataNumerical : public SphericalData
{
public:

#ifdef DFTEFE_WITH_DEVICE

  // Device-safe snapshot of this object's state.  Contains only
  // trivially-copyable members plus a device view of the radial spline.
  struct DeviceView
  {
    utils::SplineDeviceView radialSpline;

    // Quantum numbers (mEff = |m|).
    int l;
    int m;
    int mEff;

    // Precomputed normalization constant and cutoff parameters, copied
    // from the host object at view-construction time.
    double constant;
    double cutoff;
    double smoothness;
    double polarAngleTolerance;

    // Callable from inside device kernels.
    DFTEFE_DEVICE_FUNC
    double getValueDevice(
      const double *point,
      const double *origin) const;
  };

  // Builds the DeviceView on the host; the caller copies it (or an array
  // of them) to device memory before kernel launch.
  DeviceView getDeviceView() const;

  DFTEFE_DEVICE_FUNC
  double getValueDevice(
    const double *point,
    const double *origin) const;

#endif

private:

#ifdef DFTEFE_WITH_DEVICE
  // Cached view; kept in sync with the host-side members it mirrors.
  DeviceView d_deviceView;
#endif
};

DeviceView construction

#ifdef DFTEFE_WITH_DEVICE

// Assembles a device-safe snapshot of this object's state.  Everything the
// view holds is either trivially copyable or itself a device view, so the
// result can be memcpy'd to device memory and used inside kernels.
SphericalDataNumerical::DeviceView
SphericalDataNumerical::getDeviceView() const
{
  DeviceView view;

  // Device-accessible handle for the radial spline owned by the host object.
  view.radialSpline = d_spline->getDeviceView();

  // Quantum numbers: index 1 holds l, index 2 holds m (assumed layout of
  // d_qNumbers — confirm against its definition).
  view.l    = d_qNumbers[1];
  view.m    = d_qNumbers[2];
  view.mEff = std::abs(view.m);

  // Precompute the angular normalization once on the host so the device
  // code only multiplies.
  view.constant = Clm(view.l, view.m) * Dm(view.m);

  // Cutoff/smoothing parameters copied verbatim from the host members.
  view.cutoff              = d_cutoff;
  view.smoothness          = d_smoothness;
  view.polarAngleTolerance = d_polarAngleTolerance;

  return view;
}

#endif

Device-side implementation

// Device-side evaluation: runs entirely on data stored inside the view
// (spline view, quantum numbers, precomputed constants), so it is safe to
// call from within a kernel.  Body elided in this design note.
DFTEFE_DEVICE_FUNC
double
SphericalDataNumerical::DeviceView::getValueDevice(
  const double *point,
  const double *origin) const
{
  ...
}

How it is used inside kernels

The key idea is:

  • Construct DeviceView on host
  • Copy an array of DeviceViews to device
  • Pass the device array directly into kernels
  • Invoke device member functions from inside the kernel

Kernel definition

namespace
{
  // Evaluates every (enrichment, quadrature-point) pair of one cell on the
  // device.  The flattened index iThread enumerates pairs with the
  // quadrature index varying fastest; a grid-stride loop covers all
  // numEnrichInCell * numQuadInCell pairs.
  DFTEFE_CREATE_KERNEL(
    void,
    evalEnrichmentInCell,
    {
      for (size_type iThread = globalThreadId;
           iThread < numEnrichInCell * numQuadInCell;
           iThread += nThreadsPerBlock * nThreadBlock)
      {
        // BUGFIX: modulo by numQuadInCell yields the quadrature index and
        // division yields the enrichment index.  The original had these
        // swapped, so enrichId ranged over [0, numQuadInCell) and could
        // index views[] out of bounds whenever
        // numQuadInCell > numEnrichInCell.
        const size_type quadId   = iThread % numQuadInCell;
        const size_type enrichId = iThread / numQuadInCell;

        output[iThread] =
          views[enrichId].getValueDevice(
            quadPtsInCell + quadId * 3,
            origin + enrichId * 3);
      }
    },
    const double *quadPtsInCell,
    const double *origin,
    const size_type numEnrichInCell,
    const size_type numQuadInCell,
    const atoms::SphericalDataNumerical::DeviceView *views,
    double *output);
}

The important line is:

views[enrichId].getValueDevice(...)

This is possible because `views` is a device-accessible array of DeviceView objects, constructed on the host via getDeviceView() and copied into device memory before the kernel launch.


Host-side kernel launcher

// Launches one evaluation kernel per non-empty cell in [cellRange.first,
// cellRange.second), round-robining the launches over the available BLAS
// streams, then synchronizes every stream before returning.
//
// NOTE(review): the two std::vector parameters are taken by value (a copy
// per call).  They should be const& — left unchanged here because the
// out-of-view declaration must match this definition.
void
EnrichmentDataEvalKernels<utils::MemorySpace::DEVICE>::
getEnrichmentValuesInCellRange(
  const double *quadPtsInAllCells,
  const double *originPtsInAllCells,
  std::pair<size_type, size_type> cellRange,
  const std::vector<size_type> numEnrichIdsInAllCells,
  const std::vector<size_type> numQuadPtsInAllCells,
  const atoms::SphericalDataNumerical::DeviceView *data,
  double *output,
  linearAlgebra::LinAlgOpContext<
    utils::MemorySpace::DEVICE> &linAlgOpContext)
{
  const size_type numStreams =
    linAlgOpContext.numBlasStreams();

  auto *streams =
    linAlgOpContext.getBlasStreamsVec();

  constexpr size_type dim = 3;

  // Offsets of cellRange.first into the flat per-cell arrays.
  size_type cumulativeEnrichInCellRange = 0;
  size_type cumulativeQuadPtsInCellRange = 0;
  size_type cumulativeQuadxEnrichInCellRange = 0;

  // size_type index avoids the signed/unsigned comparison the original
  // `int iCell` produced against the size_type bound.
  for (size_type iCell = 0; iCell < cellRange.first; iCell++)
  {
    const size_type numEnrichInCell =
      numEnrichIdsInAllCells[iCell];

    const size_type numQuadInCell =
      numQuadPtsInAllCells[iCell];

    cumulativeEnrichInCellRange += numEnrichInCell;
    cumulativeQuadPtsInCellRange += numQuadInCell;
    cumulativeQuadxEnrichInCellRange +=
      numEnrichInCell * numQuadInCell;
  }

  // Counts only cells that actually launch a kernel, so streams are
  // assigned round-robin over real work.
  size_type cumulativeCellWithNonZeroNumEnrich = 0;

  for (size_type iCell = cellRange.first;
       iCell < cellRange.second;
       iCell++)
  {
    const size_type numEnrichInCell =
      numEnrichIdsInAllCells[iCell];

    const size_type numQuadInCell =
      numQuadPtsInAllCells[iCell];

    if (numEnrichInCell > 0)
    {
      const size_type sid =
        cumulativeCellWithNonZeroNumEnrich % numStreams;

      const size_type total =
        numEnrichInCell * numQuadInCell;

      const size_type blockSize =
        utils::DEVICE_BLOCK_SIZE;

      // Ceiling division: enough blocks to cover `total` threads.
      const size_type grid =
        (total + blockSize - 1) / blockSize;

      DFTEFE_LAUNCH_KERNEL(
        evalEnrichmentInCell,
        grid,
        blockSize,
        streams[sid],
        quadPtsInAllCells +
          cumulativeQuadPtsInCellRange * dim,
        originPtsInAllCells +
          cumulativeEnrichInCellRange * dim,
        numEnrichInCell,
        numQuadInCell,
        data + cumulativeEnrichInCellRange,
        output +
          cumulativeQuadxEnrichInCellRange);

      cumulativeCellWithNonZeroNumEnrich++;
    }

    // Advance offsets for every cell, launched or not.
    cumulativeQuadPtsInCellRange += numQuadInCell;
    cumulativeEnrichInCellRange += numEnrichInCell;
    cumulativeQuadxEnrichInCellRange +=
      numEnrichInCell * numQuadInCell;
  }

  // Block until all per-stream launches finish before returning.
  for (size_type s = 0; s < numStreams; ++s)
    utils::deviceStreamSynchronize(streams[s]);
}

Why this works

The kernel receives:

const atoms::SphericalDataNumerical::DeviceView *views

Each thread can safely call:

views[i].getValueDevice(...)

because:

  • DeviceView contains only device-safe data
  • no host pointers are dereferenced
  • all state is explicitly copied to device memory