Port stress/response calculations to the GPU #1187

Merged
mfherbst merged 9 commits into JuliaMolSim:master from abussy:stress_gpu
Jan 14, 2026

Conversation

Collaborator

@abussy abussy commented Nov 7, 2025

This PR enables ForwardDiff calculations (stress and response) on the GPU. Main changes are:

  1. Data transfer from/to the device where necessary
  2. Various small changes to avoid GPU compiler confusion (e.g. see changes in src/workarounds/forwarddiff_rules.jl)
  3. CPU fall-backs for all XC operations taking place in the DftFunctionals.jl package
  4. Refactoring of the ForwardDiff tests, such that all tests can be run on various architectures (CPU, CUDA, AMDGPU)

With this PR, all ForwardDiff workflows currently tested on the CPU successfully run on both NVIDIA and AMD GPUs.
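Items 1 and 3 of the list above (device transfers and CPU fall-backs) can be illustrated with a minimal sketch. The helper names `to_host`, `to_device_like` and `with_cpu_fallback` are hypothetical and are not taken from DFTK; the sketch only assumes that the array type supports `Array(x)`, `similar` and `copyto!`, which holds for plain `Array`s and should also hold for typical GPU array types.

```julia
# Hypothetical helpers for illustration only; these are NOT DFTK's API.

# Copy any array (host or device) into a plain host Array.
to_host(x::AbstractArray) = Array(x)

# Allocate an array of the same kind as `template` and copy host data into it.
function to_device_like(template::AbstractArray, x::Array)
    out = similar(template, eltype(x), size(x))
    copyto!(out, x)
end

# Run `f` on the host, then move the result back to where `x` lives.
# `f` is assumed to return a plain Array.
function with_cpu_fallback(f, x::AbstractArray)
    y = f(to_host(x))
    to_device_like(x, y)
end

# Usage on the CPU, where both transfers are ordinary copies.
squared = with_cpu_fallback(v -> v .^ 2, [1.0, 2.0, 3.0])
```

With a GPU array as input, the same call would pay for one download and one upload per invocation, which is the cost of keeping an operation on the CPU.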

Future improvements will come with:

Collaborator Author

abussy commented Nov 12, 2025

Merged master. Adapted tests to the refactoring brought by #1182.

Additionally, removed this problematic bit of code in ext/DFTKAMDGPUExt.jl:

# Enable comparisons of Duals on AMD GPUs
_val(x) = x
_val(x::Dual) = _val(ForwardDiff.value(x))
function Base.:<(x::Dual{T,V,N},
                 y::Dual{T,V,N}) where {T,V,N}
    _val(x) < _val(y)
end
function Base.:>(x::Dual{T,V,N},
                 y::Dual{T,V,N}) where {T,V,N}
    _val(x) > _val(y)
end

It turns out that comparison of Duals does not take place on the GPU, as long as all XC operations are done on the CPU. This might become a concern again in the future, once DftFunctionals.jl is refactored.

Comment thread ext/DFTKAMDGPUExt.jl Outdated
Comment thread src/gpu/gpu_arrays.jl Outdated
Comment thread src/workarounds/forwarddiff_rules.jl Outdated
Comment thread src/terms/xc.jl
Comment thread src/workarounds/forwarddiff_rules.jl Outdated
copyto!(y, _mul(p, x))
end
function Base.:*(p::AbstractFFTs.Plan, x::AbstractArray{<:Complex{<:Dual{Tg}}}) where {Tg}
function _mul(p::AbstractFFTs.Plan, x::AbstractArray{<:Complex{<:Dual{Tg}}}) where {Tg}
Member

Again, this feels strange and is surprising to me. Why did you need this?

Collaborator Author

Without this workaround, the GPU compiler throws an invalid LLVM IR error during stress calculations. I think there is confusion around which method of Base.:* to use, but I don't understand why.

Member

Ok, this we need to understand.

@niklasschmitz I recall we only needed this anyway because it was not properly supported on the AbstractFFTs side. Could it be that it now is, and we can drop our type piracy workaround altogether?

Collaborator Author

I made progress there by properly reading the error message, and it turns out that a more specific definition of Base.:* in cufft takes priority: https://github.com/JuliaGPU/CUDA.jl/blob/44cde93bf03812012da5c883b6532d80a5226268/lib/cufft/fft.jl#L359-L377. While I understand the problem now, I don't see a better way to deal with it than the current solution. Any help/suggestion is welcome.

CUDA.jl overloads Base.:* for more specific types than AbstractFFTs.Plan and AbstractArray, but it breaks with (complex) Duals.

Member

So it's effectively a bug in CUDA.jl? Their typing is too broad, as it covers Duals, which they don't support?

Collaborator Author

That's my understanding, yes.

Member

Once there is an issue opened and referenced here (please @mention me) this is fine.
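The shape of the workaround discussed in this thread, namely putting the Dual-aware logic into an internal helper such as `_mul` instead of a method of a shared operator, can be sketched as follows. This is a toy illustration under stated assumptions: `ToyDual`, `MatrixPlan`, `apply` and `_apply_dual` are hypothetical stand-ins and are not types or functions from DFTK, AbstractFFTs.jl, ForwardDiff.jl or CUDA.jl. The sketch also assumes the plan is a linear map, so it can be applied to the value and partial components separately.

```julia
# Toy stand-ins (hypothetical): not the types used in DFTK, AbstractFFTs.jl or ForwardDiff.jl.
struct ToyDual{T<:Real} <: Real   # one value plus one partial derivative
    value::T
    partial::T
end

struct MatrixPlan                 # a "plan" that is just a dense linear map
    mat::Matrix{ComplexF64}
end

# Plain complex input: apply the linear map directly.
apply(p::MatrixPlan, x::AbstractVector{<:Complex{<:AbstractFloat}}) = p.mat * x

# Dual-valued input: an internal helper with its own name (compare `_mul` in the diff),
# so no method of a shared operator needs to be added for Dual element types.
function _apply_dual(p::MatrixPlan, x::AbstractVector{<:Complex{<:ToyDual}})
    # By linearity, transform the values and the partials independently.
    vals  = apply(p, [complex(real(z).value,   imag(z).value)   for z in x])
    parts = apply(p, [complex(real(z).partial, imag(z).partial) for z in x])
    # Recombine: real and imaginary parts each carry their own partial.
    [Complex(ToyDual(real(v), real(d)), ToyDual(imag(v), imag(d)))
     for (v, d) in zip(vals, parts)]
end

# Usage: a 4-point DFT matrix as the plan, input values 1..4 with unit partials.
n = 4
dft = [cis(-2π * j * k / n) for j in 0:n-1, k in 0:n-1]
plan = MatrixPlan(dft)
x = [Complex(ToyDual(Float64(k), 1.0), ToyDual(0.0, 0.0)) for k in 1:n]
y = _apply_dual(plan, x)

first_value   = real(y[1]).value     # zero-frequency component: sum of the values
first_partial = real(y[1]).partial   # zero-frequency component: sum of the partials
```

Because the helper has a private name, a more specific method defined elsewhere for the shared operator cannot intercept the call, which is the property the thread relies on.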

Comment thread test/forwarddiff.jl
Comment thread src/terms/local_nonlinearity.jl Outdated
Comment thread test/forwarddiff_gpu.jl Outdated
Comment thread src/terms/xc.jl Outdated
Collaborator Author

abussy commented Jan 7, 2026

I reorganized the ForwardDiff tests:

  • To avoid having a single gigantic file, I split the tests into 3 separate files based on categories. forwarddiff_geometry.jl contains tests based on perturbation of geometry/symmetry. forwarddiff_parameters.jl contains tests on variation of model parameters. forwarddiff_generic.jl contains small generic FD tests. Thanks to @niklasschmitz for helping with the categories.
  • The tests now follow the logic of the various silicon_*.jl files: a @testmodule defines a test function, which is then called with various parameters in different @testitem blocks. In this case, it allows running the same test on CPU and GPU with minimal code duplication. The test definition and calls are now in the same location too.
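The pattern described in the second bullet (one shared test function, called once per architecture) can be sketched with the `Test` standard library. This is a simplified stand-in: it uses `@testset` rather than `@testmodule`/`@testitem`, and `HostArch`, `to_arch` and `check_energy_derivative` are hypothetical names that do not come from DFTK. The toy "energy" and its analytic derivative are likewise made up for illustration.

```julia
using Test

# Hypothetical architecture tag; DFTK's own architecture types are not reproduced here.
struct HostArch end
to_arch(::HostArch, x::AbstractArray) = Array(x)
# A GPU tag would add its own `to_arch` method that uploads `x` to the device.

# Shared test body: compare a finite-difference derivative with the analytic one
# for the toy energy E(a) = sum((a * c).^2), whose derivative is 2a * sum(c.^2).
function check_energy_derivative(arch; h=1e-6, rtol=1e-6)
    c = [1.0, 2.0, 3.0]
    energy(a) = sum(abs2, to_arch(arch, a .* c))
    a = 0.7
    fd = (energy(a + h) - energy(a - h)) / (2 * h)   # central difference
    isapprox(fd, 2 * a * sum(abs2, c); rtol=rtol)
end

# One call per architecture; only the tag changes between runs.
@testset "toy derivative ($(nameof(typeof(arch))))" for arch in (HostArch(),)
    @test check_energy_derivative(arch)
end
```

Adding another architecture then amounts to adding one more tag to the tuple, without duplicating the test body.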

I additionally addressed the concerns raised by @Technici4n in comments on src/terms/xc.jl and src/terms/local_nonlinearity.jl. Finally, I also added a couple of GPU tests for stress calculations.

I believe the last remaining issue is the overload of the Base.:* operators, i.e.:

function Base.:*(p::AbstractFFTs.Plan, x::AbstractArray{<:Complex{<:Dual{Tg}}}) where {Tg}

which fails to compile with CUDA.

I made progress there by properly reading the error message, and it turns out that a more specific definition of Base.:* in cufft takes priority: https://github.com/JuliaGPU/CUDA.jl/blob/44cde93bf03812012da5c883b6532d80a5226268/lib/cufft/fft.jl#L359-L377. While I understand the problem now, I don't see a better way to deal with it than the current solution. Any help/suggestion is welcome.

Member

@mfherbst mfherbst left a comment

Very nice refactoring!

Comment thread src/terms/xc.jl
Comment thread src/workarounds/forwarddiff_rules.jl Outdated
Comment thread src/workarounds/forwarddiff_rules.jl Outdated
Comment thread src/workarounds/forwarddiff_rules.jl Outdated
Comment thread test/forwarddiff/generic.jl
Comment thread test/forwarddiff/geometry.jl
Comment thread test/forwarddiff/parameters.jl
Comment thread test/forwarddiff/geometry.jl
Comment thread test/forwarddiff/parameters.jl
Comment thread test/forwarddiff/parameters.jl
Collaborator Author

abussy commented Jan 9, 2026

  • Addressed review comments on the FD tests, and moved them to their own subfolder.
  • Addressed src/workarounds/forwarddiff_rules.jl comments.
  • Rebased on top of current master for merge compatibility.

Comment thread test/forwarddiff/parameters.jl
Comment thread src/workarounds/forwarddiff_rules.jl
Comment thread src/workarounds/forwarddiff_rules.jl Outdated

Member

@mfherbst mfherbst left a comment

Final nits before we merge.

Comment thread src/workarounds/forwarddiff_rules.jl Outdated
@mfherbst mfherbst enabled auto-merge (squash) January 14, 2026 11:37
@mfherbst mfherbst disabled auto-merge January 14, 2026 14:42
@mfherbst mfherbst merged commit d8d425a into JuliaMolSim:master Jan 14, 2026
7 of 10 checks passed

3 participants