Skip to content
Open
17 changes: 16 additions & 1 deletion include/bout/fft.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,15 @@
#ifndef BOUT_FFT_H
#define BOUT_FFT_H

#include "bout/dcomplex.hxx"
#include "bout/build_defines.hxx"

#include <bout/array.hxx>
#include <bout/bout_enum_class.hxx>
#include <bout/dcomplex.hxx>

#include <string_view>

class Mesh;
class Options;

BOUT_ENUM_CLASS(FFT_MEASUREMENT_FLAG, estimate, measure, exhaustive);
Expand Down Expand Up @@ -111,6 +116,16 @@ Array<dcomplex> rfft(const Array<BoutReal>& in);
/// Expects that `in.size() == (length / 2) + 1`
Array<BoutReal> irfft(const Array<dcomplex>& in, int length);

/// Check simulation is using 1 processor in Z, throw exception if not
///
/// Generally, FFTs must be done over the full Z domain. Currently, most
/// methods using FFTs don't handle parallelising in Z
#if BOUT_CHECK_LEVEL > 0
void assertZSerial(const Mesh& mesh, std::string_view name);
#else
inline void assertZSerial([[maybe_unused]] const Mesh& mesh,
[[maybe_unused]] std::string_view name) {}
#endif
} // namespace fft
} // namespace bout

Expand Down
14 changes: 8 additions & 6 deletions include/bout/mesh.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -354,10 +354,12 @@ public:

// non-local communications

virtual int getNXPE() = 0; ///< The number of processors in the X direction
virtual int getNYPE() = 0; ///< The number of processors in the Y direction
virtual int getXProcIndex() = 0; ///< This processor's index in X direction
virtual int getYProcIndex() = 0; ///< This processor's index in Y direction
virtual int getNXPE() const = 0; ///< The number of processors in the X direction
virtual int getNYPE() const = 0; ///< The number of processors in the Y direction
virtual int getNZPE() const = 0; ///< The number of processors in the Z direction
virtual int getXProcIndex() const = 0; ///< This processor's index in X direction
virtual int getYProcIndex() const = 0; ///< This processor's index in Y direction
virtual int getZProcIndex() const = 0; ///< This processor's index in Z direction

// X communications
virtual bool firstX()
Expand All @@ -368,8 +370,6 @@ public:
/// Domain is periodic in X?
bool periodicX{false};

int NXPE, PE_XIND; ///< Number of processors in X, and X processor index

/// Send a buffer of data to processor at X index +1
///
/// @param[in] buffer The data to send. Must be at least length \p size
Expand Down Expand Up @@ -507,8 +507,10 @@ public:

virtual BoutReal GlobalX(int jx) const = 0; ///< Continuous X index between 0 and 1
virtual BoutReal GlobalY(int jy) const = 0; ///< Continuous Y index (0 -> 1)
virtual BoutReal GlobalZ(int jz) const = 0; ///< Continuous Z index (0 -> 1)
virtual BoutReal GlobalX(BoutReal jx) const = 0; ///< Continuous X index between 0 and 1
virtual BoutReal GlobalY(BoutReal jy) const = 0; ///< Continuous Y index (0 -> 1)
virtual BoutReal GlobalZ(BoutReal jz) const = 0; ///< Continuous Z index (0 -> 1)

//////////////////////////////////////////////////////////

Expand Down
4 changes: 4 additions & 0 deletions src/field/field3d.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,7 @@ FieldPerp pow(const Field3D& lhs, const FieldPerp& rhs, const std::string& rgn)
Field3D filter(const Field3D& var, int N0, const std::string& rgn) {
TRACE("filter(Field3D, int)");

bout::fft::assertZSerial(*var.getMesh(), "`filter`");
checkData(var);

int ncz = var.getNz();
Expand Down Expand Up @@ -683,6 +684,7 @@ Field3D filter(const Field3D& var, int N0, const std::string& rgn) {
Field3D lowPass(const Field3D& var, int zmax, bool keep_zonal, const std::string& rgn) {
TRACE("lowPass(Field3D, {}, {})", zmax, keep_zonal);

bout::fft::assertZSerial(*var.getMesh(), "`lowPass`");
checkData(var);
int ncz = var.getNz();

Expand Down Expand Up @@ -732,6 +734,8 @@ Field3D lowPass(const Field3D& var, int zmax, bool keep_zonal, const std::string
*/
void shiftZ(Field3D& var, int jx, int jy, double zangle) {
TRACE("shiftZ");

bout::fft::assertZSerial(*var.getMesh(), "`shiftZ`");
checkData(var);
var.allocate(); // Ensure that var is unique
Mesh* localmesh = var.getMesh();
Expand Down
18 changes: 17 additions & 1 deletion src/invert/fft_fftw.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,17 @@

#include "bout/build_defines.hxx"

#include <bout/coordinates.hxx>
#include <bout/fft.hxx>
#include <bout/globals.hxx>
#include <bout/mesh.hxx>
#include <bout/options.hxx>
#include <bout/unused.hxx>

#if BOUT_HAS_FFTW
#include <bout/constants.hxx>
#include <bout/openmpwrap.hxx>

#include <cmath>
#include <fftw3.h>

#if BOUT_USE_OPENMP
Expand All @@ -46,6 +47,12 @@
#include <bout/boutexception.hxx>
#endif // BOUT_HAS_FFTW

#if BOUT_CHECK_LEVEL > 0
#include <bout/boutexception.hxx>

#include <string_view>
#endif

namespace bout {
namespace fft {

Expand Down Expand Up @@ -527,5 +534,14 @@ Array<BoutReal> irfft(const Array<dcomplex>& in, int length) {
return out;
}

#if BOUT_CHECK_LEVEL > 0
void assertZSerial(const Mesh& mesh, std::string_view name) {
if (mesh.getNZPE() != 1) {
throw BoutException("{} uses FFTs which are currently incompatible with multiple "
"processors in Z (using {})",
name, mesh.getNZPE());
}
}
#endif
} // namespace fft
} // namespace bout
2 changes: 2 additions & 0 deletions src/invert/laplace/impls/cyclic/cyclic_laplace.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ LaplaceCyclic::LaplaceCyclic(Options* opt, const CELL_LOC loc, Mesh* mesh_in,
Solver* UNUSED(solver))
: Laplacian(opt, loc, mesh_in), Acoef(0.0), C1coef(1.0), C2coef(1.0), Dcoef(1.0) {

bout::fft::assertZSerial(*localmesh, "`cyclic` inversion");

Acoef.setLocation(location);
C1coef.setLocation(location);
C2coef.setLocation(location);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,14 @@ LaplaceIPT::LaplaceIPT(Options* opt, CELL_LOC loc, Mesh* mesh_in, Solver* UNUSED
au(ny, nmode), bu(ny, nmode), rl(nmode), ru(nmode), r1(ny, nmode), r2(ny, nmode),
first_call(ny), x0saved(ny, 4, nmode), converged(nmode), fine_error(4, nmode) {

bout::fft::assertZSerial(*localmesh, "`ipt` inversion");

A.setLocation(location);
C.setLocation(location);
D.setLocation(location);

// Number of procs must be a factor of 2
const int n = localmesh->NXPE;
const int n = localmesh->getNXPE();
if (!is_pow2(n)) {
throw BoutException("LaplaceIPT error: NXPE must be a power of 2");
}
Expand Down
2 changes: 2 additions & 0 deletions src/invert/laplace/impls/pcr/pcr.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ LaplacePCR::LaplacePCR(Options* opt, CELL_LOC loc, Mesh* mesh_in, Solver* UNUSED
ncx(localmesh->LocalNx), ny(localmesh->LocalNy), avec(ny, nmode, ncx),
bvec(ny, nmode, ncx), cvec(ny, nmode, ncx) {

bout::fft::assertZSerial(*localmesh, "`pcr` inversion");

Acoef.setLocation(location);
C1coef.setLocation(location);
C2coef.setLocation(location);
Expand Down
2 changes: 2 additions & 0 deletions src/invert/laplace/impls/pcr_thomas/pcr_thomas.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ LaplacePCR_THOMAS::LaplacePCR_THOMAS(Options* opt, CELL_LOC loc, Mesh* mesh_in,
ncx(localmesh->LocalNx), ny(localmesh->LocalNy), avec(ny, nmode, ncx),
bvec(ny, nmode, ncx), cvec(ny, nmode, ncx) {

bout::fft::assertZSerial(*localmesh, "`pcr_thomas` inversion");

Acoef.setLocation(location);
C1coef.setLocation(location);
C2coef.setLocation(location);
Expand Down
3 changes: 3 additions & 0 deletions src/invert/laplace/impls/serial_band/serial_band.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@
LaplaceSerialBand::LaplaceSerialBand(Options* opt, const CELL_LOC loc, Mesh* mesh_in,
Solver* UNUSED(solver))
: Laplacian(opt, loc, mesh_in), Acoef(0.0), Ccoef(1.0), Dcoef(1.0) {

bout::fft::assertZSerial(*localmesh, "`band` inversion");

Acoef.setLocation(location);
Ccoef.setLocation(location);
Dcoef.setLocation(location);
Expand Down
7 changes: 4 additions & 3 deletions src/invert/laplace/impls/serial_tri/serial_tri.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,15 @@
#include <bout/lapack_routines.hxx>
#include <bout/mesh.hxx>
#include <bout/openmpwrap.hxx>
#include <bout/utils.hxx>
#include <cmath>

#include <bout/output.hxx>
#include <bout/utils.hxx>

LaplaceSerialTri::LaplaceSerialTri(Options* opt, CELL_LOC loc, Mesh* mesh_in,
Solver* UNUSED(solver))
: Laplacian(opt, loc, mesh_in), A(0.0), C(1.0), D(1.0) {

bout::fft::assertZSerial(*localmesh, "`tri` inversion");

A.setLocation(location);
C.setLocation(location);
D.setLocation(location);
Expand Down
15 changes: 9 additions & 6 deletions src/invert/laplace/impls/spt/spt.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@
LaplaceSPT::LaplaceSPT(Options* opt, const CELL_LOC loc, Mesh* mesh_in,
Solver* UNUSED(solver))
: Laplacian(opt, loc, mesh_in), Acoef(0.0), Ccoef(1.0), Dcoef(1.0) {

bout::fft::assertZSerial(*localmesh, "`spt` inversion");

Acoef.setLocation(location);
Ccoef.setLocation(location);
Dcoef.setLocation(location);
Expand Down Expand Up @@ -341,14 +344,14 @@ int LaplaceSPT::start(const FieldPerp& b, SPT_data& data) {
// Send data
localmesh->sendXOut(std::begin(data.buffer), 4 * (maxmode + 1), data.comm_tag);

} else if (localmesh->PE_XIND == 1) {
} else if (localmesh->getXProcIndex() == 1) {
// Post a receive
data.recv_handle =
localmesh->irecvXIn(std::begin(data.buffer), 4 * (maxmode + 1), data.comm_tag);
}

data.proc++; // Now moved onto the next processor
if (localmesh->NXPE == 2) {
if (localmesh->getNXPE() == 2) {
data.dir = -1; // Special case. Otherwise reversal handled in spt_continue
}

Expand All @@ -366,7 +369,7 @@ int LaplaceSPT::next(SPT_data& data) {
return 1;
}

if (localmesh->PE_XIND == data.proc) {
if (localmesh->getXProcIndex() == data.proc) {
/// This processor's turn to do inversion

// Wait for data to arrive
Expand Down Expand Up @@ -450,7 +453,7 @@ int LaplaceSPT::next(SPT_data& data) {
}
}

if (localmesh->PE_XIND != 0) { // If not finished yet
if (localmesh->getXProcIndex() != 0) { // If not finished yet
/// Send data

if (data.dir > 0) {
Expand All @@ -460,7 +463,7 @@ int LaplaceSPT::next(SPT_data& data) {
}
}

} else if (localmesh->PE_XIND == data.proc + data.dir) {
} else if (localmesh->getXProcIndex() == data.proc + data.dir) {
// This processor is next, post receive

if (data.dir > 0) {
Expand All @@ -474,7 +477,7 @@ int LaplaceSPT::next(SPT_data& data) {

data.proc += data.dir;

if (data.proc == localmesh->NXPE - 1) {
if (data.proc == localmesh->getNXPE() - 1) {
data.dir = -1; // Reverses direction at the end
}

Expand Down
1 change: 1 addition & 0 deletions src/invert/laplacexz/impls/cyclic/laplacexz-cyclic.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
LaplaceXZcyclic::LaplaceXZcyclic(Mesh* m, Options* options, const CELL_LOC loc)
: LaplaceXZ(m, options, loc) {
// Note: `m` may be nullptr, but localmesh is set in LaplaceXZ base constructor
bout::fft::assertZSerial(*localmesh, "`cyclic` X-Z inversion");

// Number of Z Fourier modes, including DC
nmode = (localmesh->LocalNz) / 2 + 1;
Expand Down
2 changes: 2 additions & 0 deletions src/invert/parderiv/impls/cyclic/cyclic.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@

InvertParCR::InvertParCR(Options* opt, CELL_LOC location, Mesh* mesh_in)
: InvertPar(opt, location, mesh_in), A(1.0), B(0.0), C(0.0), D(0.0), E(0.0) {

bout::fft::assertZSerial(*localmesh, "InvertParCR");
// Number of k equations to solve for each x location
nsys = 1 + (localmesh->LocalNz) / 2;

Expand Down
1 change: 1 addition & 0 deletions src/invert/pardiv/impls/cyclic/pardiv_cyclic.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@

InvertParDivCR::InvertParDivCR(Options* opt, CELL_LOC location, Mesh* mesh_in)
: InvertParDiv(opt, location, mesh_in) {
bout::fft::assertZSerial(*localmesh, "InvertParDivCR");
// Number of k equations to solve for each x location
nsys = 1 + (localmesh->LocalNz) / 2;
}
Expand Down
4 changes: 4 additions & 0 deletions src/mesh/boundary_standard.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -2633,6 +2633,7 @@ void BoundaryNeumann_NonOrthogonal::apply(Field3D& f) {
#if not(BOUT_USE_METRIC_3D)
Mesh* mesh = bndry->localmesh;
ASSERT1(mesh == f.getMesh());
bout::fft::assertZSerial(*mesh, "Zero Laplace on Field3D");
int ncz = mesh->LocalNz;

Coordinates* metric = f.getCoordinates();
Expand Down Expand Up @@ -2736,6 +2737,7 @@ void BoundaryNeumann_NonOrthogonal::apply(Field3D& f) {
#if not(BOUT_USE_METRIC_3D)
Mesh* mesh = bndry->localmesh;
ASSERT1(mesh == f.getMesh());
bout::fft::assertZSerial(*mesh, "Zero Laplace on Field3D");
const int ncz = mesh->LocalNz;

ASSERT0(ncz % 2 == 0); // Allocation assumes even number
Expand Down Expand Up @@ -2845,6 +2847,8 @@ void BoundaryNeumann_NonOrthogonal::apply(Field3D& f) {

Mesh* mesh = bndry->localmesh;
ASSERT1(mesh == f.getMesh());
bout::fft::assertZSerial(*mesh, "Zero Laplace on Field3D");

Coordinates* metric = f.getCoordinates();

int ncz = mesh->LocalNz;
Expand Down
4 changes: 2 additions & 2 deletions src/mesh/coordinates.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -1662,7 +1662,7 @@ Field3D Coordinates::Delp2(const Field3D& f, CELL_LOC outloc, bool useFFT) {

Field3D result{emptyFrom(f).setLocation(outloc)};

if (useFFT and not bout::build::use_metric_3d) {
if (useFFT and not bout::build::use_metric_3d and localmesh->getNZPE() == 1) {
int ncz = localmesh->LocalNz;

// Allocate memory
Expand Down Expand Up @@ -1731,7 +1731,7 @@ FieldPerp Coordinates::Delp2(const FieldPerp& f, CELL_LOC outloc, bool useFFT) {
int jy = f.getIndex();
result.setIndex(jy);

if (useFFT) {
if (useFFT and localmesh->getNZPE() == 1) {
int ncz = localmesh->LocalNz;

// Allocate memory
Expand Down
3 changes: 1 addition & 2 deletions src/mesh/data/gridfromoptions.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,7 @@ bool GridFromOptions::get(Mesh* m, std::vector<BoutReal>& var, const std::string
}
case GridDataSource::Z: {
for (int z = 0; z < len; z++) {
pos.set("z",
(TWOPI * (z - m->OffsetZ + offset)) / static_cast<BoutReal>(m->LocalNz));
pos.set("z", TWOPI * m->GlobalZ(z - m->OffsetZ + offset));
var[z] = gen->generate(pos);
}
break;
Expand Down
Loading
Loading