Skip to content

Commit 67e074a

Browse files
authored
Adds a UQ model for primitive screening (#155)
* backup * backukp * backup * refactor estimator * uses normalized coeffs, refactor testing infrastructure * cs screening seems to work * backup * module for screening pairs * finally r2g * run precommit/add missing header * ifdef protect sigma code * blank lines for precommit * address comments * call utils::set_defaults
1 parent 3274ae5 commit 67e074a

44 files changed

Lines changed: 1923 additions & 278 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,12 @@ cmaize_add_library(
7272
DEPENDS "${project_depends}"
7373
)
7474

75+
cmaize_add_executable(
76+
primitive_error_models
77+
SOURCE_DIR "examples/primitive_error_models"
78+
DEPENDS "${PROJECT_NAME}"
79+
)
80+
7581
include(nwx_pybind11)
7682
nwx_add_pybind11_module(
7783
${PROJECT_NAME}
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Copyright 2026 NWChemEx-Project
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include <integrals/integrals.hpp>
18+
#include <simde/simde.hpp>
19+
20+
/* This example showcases how to:
21+
*
22+
* 1. Compute the analytic error in an ERI4 integral tensor owing to primitive
23+
* pair screening.
24+
*/
25+
26+
namespace {
27+
28+
// This makes a basis set for H2 (bond distance 1.40 a.u.) using STO-3G.
29+
inline simde::type::ao_basis_set h2_sto3g_basis_set() {
30+
using ao_basis_t = simde::type::ao_basis_set;
31+
using atomic_basis_t = simde::type::atomic_basis_set;
32+
using cg_t = simde::type::contracted_gaussian;
33+
using point_t = simde::type::point;
34+
using doubles_t = std::vector<double>;
35+
36+
point_t r0{0.0, 0.0, 0.0};
37+
point_t r1{0.0, 0.0, 1.40};
38+
39+
doubles_t cs{0.1543289673, 0.5353281423, 0.4446345422};
40+
doubles_t es{3.425250914, 0.6239137298, 0.1688554040};
41+
cg_t cg0(cs.begin(), cs.end(), es.begin(), es.end(), r0);
42+
cg_t cg1(cs.begin(), cs.end(), es.begin(), es.end(), r1);
43+
atomic_basis_t h0("sto-3g", 1, r0);
44+
atomic_basis_t h1("sto-3g", 1, r1);
45+
h0.add_shell(chemist::ShellType::cartesian, 0, cg0);
46+
h1.add_shell(chemist::ShellType::cartesian, 0, cg1);
47+
48+
ao_basis_t bs;
49+
bs.add_center(h0);
50+
bs.add_center(h1);
51+
return bs;
52+
}
53+
54+
} // namespace
55+
56+
// Property types for the ERI4 and the error in the ERI4
57+
using eri4_pt = simde::ERI4;
58+
using eri4_error_pt = integrals::property_types::Uncertainty<eri4_pt>;
59+
60+
int main(int argc, char* argv[]) {
61+
// Makes sure the environment doesn't go out of scope before the end.
62+
auto rt = std::make_unique<parallelzone::runtime::RuntimeView>();
63+
64+
// Initializes a ModuleManager object with the integrals plugin
65+
pluginplay::ModuleManager mm(std::move(rt), nullptr);
66+
integrals::load_modules(mm);
67+
integrals::set_defaults(mm);
68+
69+
// Modules for computing analytic error and estimating error
70+
auto& analytic_error_mod = mm.at("Analytic Error");
71+
auto& error_model = mm.at("Primitive Error Model");
72+
73+
// Makes: basis set, direct product of the basis set, and 1/r12 operator
74+
simde::type::aos aos(h2_sto3g_basis_set());
75+
simde::type::aos_squared aos2(aos, aos);
76+
simde::type::v_ee_type op{};
77+
78+
// Make BraKet
79+
chemist::braket::BraKet mnls(aos2, op, aos2);
80+
81+
// Compute the error by screening with tolerance "tol"
82+
double tol = 1E-10;
83+
auto error = analytic_error_mod.run_as<eri4_error_pt>(mnls, tol);
84+
auto approx_error = error_model.run_as<eri4_error_pt>(mnls, tol);
85+
86+
std::cout << "Analytic error: " << error << std::endl;
87+
std::cout << "Estimated error: " << approx_error << std::endl;
88+
89+
return 0;
90+
}

include/integrals/integrals.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
* your unit test also needs most of the headers included by it).
2323
*/
2424
#pragma once
25-
#include "integrals/integrals_mm.hpp"
25+
#include <integrals/integrals_mm.hpp>
26+
#include <integrals/property_types.hpp>
2627

2728
/** @namespace integrals
2829
*

include/integrals/integrals_mm.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,6 @@ namespace integrals {
2727
*/
2828
DECLARE_PLUGIN(integrals);
2929

30+
void set_defaults(pluginplay::ModuleManager&);
31+
3032
} // end namespace integrals

include/integrals/property_types.hpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,64 @@
3030
*/
3131
namespace integrals::property_types {
3232

33+
// PT used to estimate the contribution of primitive pairs
34+
DECLARE_PROPERTY_TYPE(PrimitivePairEstimator);
35+
PROPERTY_TYPE_INPUTS(PrimitivePairEstimator) {
36+
using ao_basis = const simde::type::ao_basis_set&;
37+
auto rv = pluginplay::declare_input()
38+
.add_field<ao_basis>("Bra Basis Set")
39+
.add_field<ao_basis>("Ket Basis Set");
40+
rv["Bra Basis Set"].set_description(
41+
"The atomic orbital basis set for the bra");
42+
rv["Ket Basis Set"].set_description(
43+
"The atomic orbital basis set for the ket");
44+
return rv;
45+
}
46+
47+
PROPERTY_TYPE_RESULTS(PrimitivePairEstimator) {
48+
using tensor = simde::type::tensor;
49+
auto rv = pluginplay::declare_result().add_field<tensor>(
50+
"Primitive Pair Estimates");
51+
rv["Primitive Pair Estimates"].set_description(
52+
"A tensor containing the estimated values for each primitive pair "
53+
"integral");
54+
return rv;
55+
}
56+
57+
DECLARE_PROPERTY_TYPE(PairScreener);
58+
PROPERTY_TYPE_INPUTS(PairScreener) {
59+
using ao_basis = const simde::type::ao_basis_set&;
60+
auto rv = pluginplay::declare_input()
61+
.add_field<ao_basis>("Bra Basis Set")
62+
.add_field<ao_basis>("Ket Basis Set")
63+
.add_field<double>("Tolerance");
64+
return rv;
65+
}
66+
67+
PROPERTY_TYPE_RESULTS(PairScreener) {
68+
using index_vector = std::vector<std::size_t>;
69+
using pair_vector = std::vector<index_vector>;
70+
auto rv = pluginplay::declare_result().add_field<pair_vector>(
71+
"Primitive Pairs Passing Screening");
72+
return rv;
73+
}
74+
75+
template<typename BasePT>
76+
DECLARE_TEMPLATED_PROPERTY_TYPE(Uncertainty, BasePT);
77+
78+
template<typename BasePT>
79+
TEMPLATED_PROPERTY_TYPE_INPUTS(Uncertainty, BasePT) {
80+
auto rv = BasePT::inputs();
81+
auto rv0 = rv.template add_field<double>("Tolerance");
82+
rv0["Tolerance"].set_description("The screening threshold");
83+
return rv0;
84+
}
85+
86+
template<typename BasePT>
87+
TEMPLATED_PROPERTY_TYPE_RESULTS(Uncertainty, BasePT) {
88+
return BasePT::results();
89+
}
90+
3391
using DecontractBasisSet =
3492
simde::Convert<simde::type::ao_basis_set, simde::type::ao_basis_set>;
3593

src/integrals/ao_integrals/ao_integrals.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ inline void set_defaults(pluginplay::ModuleManager& mm) {
4040
mm.change_submod("Density Fitting Integral", "Coulomb Metric",
4141
"Coulomb Metric");
4242
mm.change_submod("UQ Driver", "ERIs", "ERI4");
43+
mm.change_submod("UQ Driver", "ERI Error", "Primitive Error Model");
4344
}
4445

4546
inline void load_modules(pluginplay::ModuleManager& mm) {

src/integrals/ao_integrals/uq_driver.cpp

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
*/
1616

1717
#include "ao_integrals.hpp"
18+
#include <integrals/integrals.hpp>
19+
#ifdef ENABLE_SIGMA
20+
#include <sigma/sigma.hpp>
21+
#endif
1822

1923
using namespace tensorwrapper;
2024

@@ -37,19 +41,25 @@ struct Kernel {
3741
const std::span<FloatType> error) {
3842
Tensor rv;
3943

40-
if constexpr(types::is_uncertain_v<FloatType>) {
41-
auto rv_buffer = buffer::make_contiguous<FloatType>(m_shape);
42-
auto rv_data = buffer::get_raw_data<FloatType>(rv_buffer);
44+
using float_type = std::decay_t<FloatType>;
45+
if constexpr(types::is_uncertain_v<float_type>) {
46+
throw std::runtime_error("Did not expect an uncertain type");
47+
} else {
48+
#ifdef ENABLE_SIGMA
49+
using uq_type = sigma::Uncertain<float_type>;
50+
auto rv_buffer = buffer::make_contiguous<uq_type>(m_shape);
51+
auto rv_data = buffer::get_raw_data<uq_type>(rv_buffer);
4352
for(std::size_t i = 0; i < t.size(); ++i) {
44-
const auto elem = t[i].mean();
45-
const auto elem_error = error[i].mean();
46-
rv_data[i] = FloatType(elem, elem_error);
53+
const auto elem = t[i];
54+
const auto elem_error = error[i];
55+
rv_data[i] = uq_type(elem, elem_error);
4756
}
48-
4957
rv = tensorwrapper::Tensor(m_shape, std::move(rv_buffer));
50-
} else {
51-
throw std::runtime_error("Expected an uncertain type");
58+
#else
59+
throw std::runtime_error("Sigma support not enabled!");
60+
#endif
5261
}
62+
5363
return rv;
5464
}
5565
shape_type m_shape;
@@ -63,35 +73,24 @@ UQ Integrals Driver
6373

6474
} // namespace
6575

66-
using eri_pt = simde::ERI4;
76+
using eri_pt = simde::ERI4;
77+
using error_pt = integrals::property_types::Uncertainty<eri_pt>;
6778

6879
MODULE_CTOR(UQDriver) {
6980
satisfies_property_type<eri_pt>();
7081
description(desc);
7182
add_submodule<eri_pt>("ERIs");
72-
add_input<double>("benchmark precision").set_default(1.0e-16);
73-
add_input<double>("precision").set_default(1.0e-16);
83+
add_submodule<error_pt>("ERI Error");
7484
}
7585

7686
MODULE_RUN(UQDriver) {
77-
auto tau_0 = inputs.at("benchmark precision").value<double>();
78-
auto tau = inputs.at("precision").value<double>();
87+
const auto& [braket] = eri_pt::unwrap_inputs(inputs);
7988

8089
auto& eri_mod = submods.at("ERIs").value();
90+
auto tol = eri_mod.inputs().at("Threshold").value<double>();
8191

82-
auto benchmark_mod = eri_mod.unlocked_copy();
83-
benchmark_mod.change_input("Threshold", tau_0);
84-
benchmark_mod.change_input("With UQ?", true);
85-
86-
auto normal_mod = eri_mod.unlocked_copy();
87-
normal_mod.change_input("Threshold", tau);
88-
normal_mod.change_input("With UQ?", true);
89-
90-
const auto& [t_0] = eri_pt::unwrap_results(benchmark_mod.run(inputs));
91-
const auto& [t] = eri_pt::unwrap_results(normal_mod.run(inputs));
92-
93-
simde::type::tensor error;
94-
error("m,n,l,s") = t("m,n,l,s") - t_0("m,n,l,s");
92+
const auto& t = eri_mod.run_as<eri_pt>(braket);
93+
const auto& error = submods.at("ERI Error").run_as<error_pt>(braket, tol);
9594

9695
using buffer::visit_contiguous_buffer;
9796
shape::Smooth shape = t.buffer().layout().shape().as_smooth().make_smooth();

src/integrals/integrals_mm.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ namespace integrals {
2828
* @throw none No throw guarantee
2929
*/
3030
void set_defaults(pluginplay::ModuleManager& mm) {
31+
libint::set_defaults(mm);
32+
ao_integrals::set_defaults(mm);
33+
utils::set_defaults(mm);
3134
mm.change_submod("AO integral driver", "Kinetic", "Kinetic");
3235
mm.change_submod("AO integral driver", "Electron-Nuclear attraction",
3336
"Nuclear");
@@ -41,8 +44,6 @@ void load_modules(pluginplay::ModuleManager& mm) {
4144
ao_integrals::load_modules(mm);
4245
libint::load_modules(mm);
4346
utils::load_modules(mm);
44-
set_defaults(mm);
45-
ao_integrals::set_defaults(mm);
4647
}
4748

4849
} // namespace integrals
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright 2026 NWChemEx-Project
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include "libint.hpp"
18+
#include <integrals/integrals.hpp>
19+
20+
namespace integrals::libint {
21+
namespace {
22+
23+
const auto desc = "Uses the error in the ERI4 as the uncertainty.";
24+
}
25+
26+
using eri4_pt = simde::ERI4;
27+
using pt = integrals::property_types::Uncertainty<eri4_pt>;
28+
29+
MODULE_CTOR(AnalyticError) {
30+
satisfies_property_type<pt>();
31+
description(desc);
32+
33+
add_submodule<eri4_pt>("ERI4s");
34+
}
35+
36+
MODULE_RUN(AnalyticError) {
37+
const auto& [braket, tol] = pt::unwrap_inputs(inputs);
38+
39+
auto& eri_mod = submods.at("ERI4s");
40+
41+
auto normal_mod = eri_mod.value().unlocked_copy();
42+
normal_mod.change_input("Threshold", tol);
43+
44+
// N.b., t_0 is the benchmark value
45+
const auto& t_0 = eri_mod.run_as<eri4_pt>(braket);
46+
const auto& t = normal_mod.run_as<eri4_pt>(braket);
47+
48+
simde::type::tensor error;
49+
error("m,n,l,s") = t("m,n,l,s") - t_0("m,n,l,s");
50+
51+
// Wrap and return the results
52+
auto rv = results();
53+
return pt::wrap_results(rv, error);
54+
}
55+
56+
} // namespace integrals::libint

0 commit comments

Comments
 (0)