Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/pyabacus/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = ["scikit-build-core>=0.3.3", "pybind11>=2.10.0"]
requires = ["scikit-build-core<0.10", "pybind11>=2.10.0"]
build-backend = "scikit_build_core.build"


Expand Down
2 changes: 2 additions & 0 deletions source/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ list(APPEND device_srcs
source_base/module_device/device_helpers.cpp
source_base/module_device/output_device.cpp
source_base/module_device/memory_op.cpp
source_base/module_device/memory_op_dsp.cpp
source_base/module_device/dsp_selector.cpp
source_base/kernels/math_kernel_op.cpp
source_base/kernels/math_kernel_op_vec.cpp

Expand Down
5 changes: 5 additions & 0 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,14 @@ OBJS_BASE=abfs-vector3_order.o\
pulay_mixing.o\
broyden_mixing.o\
memory_op.o\
memory_op_dsp.o\
dsp_selector.o\
device.o\
device_helpers.o\
output_device.o\
parallel_2d.o\


OBJS_CELL=atom_pseudo.o\
atom_spec.o\
pseudo.o\
Expand Down Expand Up @@ -515,6 +518,7 @@ OBJS_XC=xc_functional.o\
exx_info.o\

OBJS_IO=module_parameter/input_conv.o\
module_parameter/dsp_config.o\
module_unk/berryphase.o\
module_bessel/bessel_basis.o\
cal_test.o\
Expand Down Expand Up @@ -592,6 +596,7 @@ OBJS_IO=module_parameter/input_conv.o\
filename.o\
ucell_io.o\


OBJS_IO_LCAO=cal_r_overlap_R.o\
write_orb_info.o\
write_dos_lcao.o\
Expand Down
44 changes: 0 additions & 44 deletions source/source_base/module_device/CMakeLists.txt

This file was deleted.

64 changes: 64 additions & 0 deletions source/source_base/module_device/dsp_selector.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#include "dsp_selector.h"
#include <string>
#include <stdexcept>

#ifdef __DSP

namespace base_device
{
namespace memory
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This dsp_selector belongs to base_device::memory, but will be used by every function that needs to be run on DSP.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's fine.

{

// Global selector instance
std::unique_ptr<DspSelector> dsp_selector = nullptr;

// Get current DSP selector
DspSelector* get_dsp_selector()
{
if (!dsp_selector)
{
throw std::runtime_error(
"ModuleBase::memory::get_dsp_selector: "
"DSP selector not initialized. Call init_dsp_selector first."
);
}
return dsp_selector.get();
}

// Default DSP selector implementation
class DefaultDspSelector : public DspSelector
{
private:
int rank_ = 0;

public:
int get_rank() const override
{
return rank_;
}

void set_rank(const int rank) override
{
if (rank < 0)
{
throw std::runtime_error(
"ModuleBase::memory::DspSelector: "
"DSP rank must be non-negative"
);
}
rank_ = rank;
}
};


// Create default DSP selector and set it as global
void create_default_selector(const int rank)
{
dsp_selector = std::unique_ptr<DefaultDspSelector>(new DefaultDspSelector());
dsp_selector->set_rank(rank);
}

} // namespace memory
} // namespace base_device

#endif
35 changes: 35 additions & 0 deletions source/source_base/module_device/dsp_selector.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#ifndef MODULE_DEVICE_DSP_SELECTOR_H_
#define MODULE_DEVICE_DSP_SELECTOR_H_

#ifdef __DSP

#include <memory>

namespace base_device {
namespace memory {

// DSP selector interface
class DspSelector {
public:
virtual ~DspSelector() = default;
// Get DSP rank
virtual int get_rank() const = 0;
// Set DSP rank
virtual void set_rank(const int rank) = 0;
};

// Global selector instance
extern std::unique_ptr<DspSelector> dsp_selector;

// Get current DSP selector
DspSelector* get_dsp_selector();

// Create default DSP selector and set it as global
void create_default_selector(const int rank);

} // namespace memory
} // namespace base_device

#endif // end __DSP

#endif // MODULE_DEVICE_DSP_SELECTOR_H_
74 changes: 0 additions & 74 deletions source/source_base/module_device/memory_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@

#include "source_base/memory.h"
#include "source_base/tool_threading.h"
#ifdef __DSP
#include "source_base/kernels/dsp/dsp_connector.h"
#include "source_base/global_variable.h"
#include "source_io/module_parameter/parameter.h"
#endif

#include <complex>
#include <cstring>
Expand Down Expand Up @@ -442,76 +437,7 @@ template struct delete_memory_op<std::complex<float>, base_device::DEVICE_GPU>;
template struct delete_memory_op<std::complex<double>, base_device::DEVICE_GPU>;
#endif

#ifdef __DSP

template <typename FPTYPE>
struct resize_memory_op_mt<FPTYPE, base_device::DEVICE_CPU>
{
void operator()(FPTYPE*& arr, const size_t size, const char* record_in)
{
if (arr != nullptr)
{
mtfunc::free_ht(arr);
}
arr = (FPTYPE*)mtfunc::malloc_ht(sizeof(FPTYPE) * size, GlobalV::MY_RANK % PARAM.inp.dsp_count);
std::string record_string;
if (record_in != nullptr)
{
record_string = record_in;
}
else
{
record_string = "no_record";
}

if (record_string != "no_record")
{
ModuleBase::Memory::record(record_string, sizeof(FPTYPE) * size);
}
}
};

template <typename FPTYPE>
struct set_memory_op_mt<FPTYPE, base_device::DEVICE_CPU>
{
void operator()(FPTYPE* arr, const int var, const size_t size)
{
ModuleBase::OMP_PARALLEL([&](int num_thread, int thread_id) {
int beg = 0, len = 0;
ModuleBase::BLOCK_TASK_DIST_1D(num_thread, thread_id, size, (size_t)4096 / sizeof(FPTYPE), beg, len);
memset(arr + beg, var, sizeof(FPTYPE) * len);
});
}
};

template <typename FPTYPE>
struct delete_memory_op_mt<FPTYPE, base_device::DEVICE_CPU>
{
void operator()(FPTYPE* arr)
{
mtfunc::free_ht(arr);
}
};


template struct resize_memory_op_mt<int, base_device::DEVICE_CPU>;
template struct resize_memory_op_mt<float, base_device::DEVICE_CPU>;
template struct resize_memory_op_mt<double, base_device::DEVICE_CPU>;
template struct resize_memory_op_mt<std::complex<float>, base_device::DEVICE_CPU>;
template struct resize_memory_op_mt<std::complex<double>, base_device::DEVICE_CPU>;

template struct set_memory_op_mt<int, base_device::DEVICE_CPU>;
template struct set_memory_op_mt<float, base_device::DEVICE_CPU>;
template struct set_memory_op_mt<double, base_device::DEVICE_CPU>;
template struct set_memory_op_mt<std::complex<float>, base_device::DEVICE_CPU>;
template struct set_memory_op_mt<std::complex<double>, base_device::DEVICE_CPU>;

template struct delete_memory_op_mt<int, base_device::DEVICE_CPU>;
template struct delete_memory_op_mt<float, base_device::DEVICE_CPU>;
template struct delete_memory_op_mt<double, base_device::DEVICE_CPU>;
template struct delete_memory_op_mt<std::complex<float>, base_device::DEVICE_CPU>;
template struct delete_memory_op_mt<std::complex<double>, base_device::DEVICE_CPU>;
#endif

template <typename FPTYPE>
void resize_memory(FPTYPE* arr, const size_t size, base_device::AbacusDevice_t device_type)
Expand Down
41 changes: 1 addition & 40 deletions source/source_base/module_device/memory_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define MODULE_DEVICE_MEMORY_H_

#include "types.h"
#include "memory_op_dsp.h"

#include <complex>
#include <cstddef>
Expand Down Expand Up @@ -218,47 +219,7 @@ struct delete_memory_op<FPTYPE, base_device::DEVICE_GPU>
};
#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM

#ifdef __DSP

template <typename FPTYPE, typename Device>
struct resize_memory_op_mt
{
/// @brief Allocate memory for a given pointer. Note this op will free the pointer first.
///
/// Input Parameters
/// \param size : array size
/// \param record_string : label for memory record
///
/// Output Parameters
/// \param arr : allocated array
void operator()(FPTYPE*& arr, const size_t size, const char* record_in = nullptr);
};

template <typename FPTYPE, typename Device>
struct set_memory_op_mt
{
/// @brief memset for DSP memory allocated by mt allocator.
///
/// Input Parameters
/// \param var : the specified constant byte value
/// \param size : array size
///
/// Output Parameters
/// \param arr : output array initialized by the input value
void operator()(FPTYPE* arr, const int var, const size_t size);
};

template <typename FPTYPE, typename Device>
struct delete_memory_op_mt
{
/// @brief free memory for multi-device
///
/// Input Parameters
/// \param arr : the input array
void operator()(FPTYPE* arr);
};

#endif // __DSP

} // end of namespace memory
} // end of namespace base_device
Expand Down
Loading
Loading