Skip to content

Commit 7aa420b

Browse files
authored
Merge pull request #135 from MiraGeoscience/GEOPY-1962
GEOPY-1962: Convert all of potential fields to BaseData params class
2 parents 7a0c6aa + 39bdc08 commit 7aa420b

26 files changed

Lines changed: 502 additions & 1257 deletions

simpeg_drivers-assets/__init__.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,3 @@
77
# (see LICENSE file at the root of this source code package). '
88
# '
99
# '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
10-
11-
#
12-
# This file is part of simpeg-drivers.
13-
#
14-
#
15-
# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
16-
#
17-
# This file is part of simpeg_drivers package.
18-
#
19-
# All rights reserved.

simpeg_drivers/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,16 @@ def assets_path() -> Path:
9797
),
9898
"magnetic scalar": (
9999
"simpeg_drivers.potential_fields.magnetic_scalar.driver",
100-
{"inversion": "MagneticScalarDriver"},
100+
{
101+
"forward": "MagneticScalarForwardDriver",
102+
"inversion": "MagneticScalarInversionDriver",
103+
},
101104
),
102105
"magnetic vector": (
103106
"simpeg_drivers.potential_fields.magnetic_vector.driver",
104-
{"inversion": "MagneticVectorDriver"},
107+
{
108+
"forward": "MagneticScalarForwardDriver",
109+
"inversion": "MagneticVectorInversionDriver",
110+
},
105111
),
106112
}

simpeg_drivers/components/data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import numpy as np
2626
from discretize import TreeMesh
27+
from geoh5py.shared.utils import fetch_active_workspace
2728
from scipy.spatial import cKDTree
2829
from simpeg import maps
2930
from simpeg.electromagnetics.static.utils.static_utils import geometric_factor

simpeg_drivers/components/factories/source_factory.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,11 @@ def assemble_keyword_arguments( # pylint: disable=arguments-differ
127127
_ = (receivers, frequency)
128128
kwargs = {}
129129
if self.factory_type in ["magnetic scalar", "magnetic vector"]:
130-
kwargs = dict(
131-
zip(
132-
["amplitude", "inclination", "declination"],
133-
self.params.inducing_field_aid(),
134-
strict=False,
135-
)
136-
)
130+
kwargs = {
131+
"amplitude": self.params.inducing_field_strength,
132+
"inclination": self.params.inducing_field_inclination,
133+
"declination": self.params.inducing_field_declination,
134+
}
137135
if self.factory_type in ["magnetotellurics", "tipper"]:
138136
background = deepcopy(self.params.background_conductivity)
139137

simpeg_drivers/components/locations.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -167,12 +167,12 @@ def set_z_from_topo(self, locs: np.ndarray):
167167
if locs is None:
168168
return None
169169

170-
topo = self.get_locations(self.params.topography_object)
171-
if self.params.topography is not None:
172-
if isinstance(self.params.topography, Entity):
173-
z = self.params.topography.values
170+
topo = self.get_locations(self.params.active_cells.topography_object)
171+
if self.params.active_cells.topography is not None:
172+
if isinstance(self.params.active_cells.topography, Entity):
173+
z = self.params.active_cells.topography.values
174174
else:
175-
z = np.ones_like(topo[:, 2]) * self.params.topography
175+
z = np.ones_like(topo[:, 2]) * self.params.active_cells.topography
176176

177177
topo[:, 2] = z
178178

simpeg_drivers/params.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,9 @@ def mesh_cannot_be_rotated(cls, value: Octree):
128128
def out_group_if_none(cls, data) -> SimPEGGroup:
129129
group = data.get("out_group", None)
130130

131+
if isinstance(group, SimPEGGroup):
132+
return data
133+
131134
if isinstance(group, UIJsonGroup | type(None)):
132135
name = cls.title if group is None else group.name
133136
with fetch_active_workspace(data["geoh5"], mode="r+") as geoh5:
@@ -140,8 +143,9 @@ def out_group_if_none(cls, data) -> SimPEGGroup:
140143
@model_validator(mode="after")
141144
def update_out_group_options(self):
142145
assert self.out_group is not None
143-
self.out_group.options = self.serialize()
144-
self.out_group.metadata = None
146+
with fetch_active_workspace(self.geoh5):
147+
self.out_group.options = self.serialize()
148+
self.out_group.metadata = None
145149
return self
146150

147151
@property

simpeg_drivers/potential_fields/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,14 @@
1010

1111

1212
from .gravity.params import GravityForwardParams, GravityInversionParams
13-
from .magnetic_scalar.params import MagneticScalarParams
14-
from .magnetic_vector.params import MagneticVectorParams
13+
from .magnetic_scalar.params import (
14+
MagneticScalarForwardParams,
15+
MagneticScalarInversionParams,
16+
)
17+
from .magnetic_vector.params import (
18+
MagneticVectorForwardParams,
19+
MagneticVectorInversionParams,
20+
)
1521

1622
# pylint: disable=unused-import
1723
# flake8: noqa

simpeg_drivers/potential_fields/magnetic_scalar/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
1010

1111

12-
from .params import MagneticScalarParams
12+
from .params import MagneticScalarForwardParams, MagneticScalarInversionParams
1313

1414
# pylint: disable=unused-import
1515
# flake8: noqa

simpeg_drivers/potential_fields/magnetic_scalar/driver.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,18 @@
1212
from __future__ import annotations
1313

1414
from simpeg_drivers.driver import InversionDriver
15+
from simpeg_drivers.potential_fields.magnetic_scalar.constants import validations
16+
from simpeg_drivers.potential_fields.magnetic_scalar.params import (
17+
MagneticScalarForwardParams,
18+
MagneticScalarInversionParams,
19+
)
1520

16-
from .constants import validations
17-
from .params import MagneticScalarParams
1821

19-
20-
class MagneticScalarDriver(InversionDriver):
21-
_params_class = MagneticScalarParams
22+
class MagneticScalarForwardDriver(InversionDriver):
23+
_params_class = MagneticScalarForwardParams
2224
_validations = validations
2325

24-
def __init__(self, params: MagneticScalarParams):
25-
super().__init__(params)
26+
27+
class MagneticScalarInversionDriver(InversionDriver):
28+
_params_class = MagneticScalarInversionParams
29+
_validations = validations

0 commit comments

Comments
 (0)