Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ requires = [
"jinja2 ==3.*",
"packaging >=24.0",
"tomlkit >=0.13",
"setuptools>=80",
"setuptools-scm[simple]>=9.2"
]
build-backend = "poetry_dynamic_versioning.backend"

Expand Down
20 changes: 1 addition & 19 deletions simpeg_drivers/utils/nested.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,13 @@
import warnings
from collections.abc import Iterable
from copy import copy
from itertools import chain
from pathlib import Path

import numpy as np
from dask import compute, delayed
from dask.distributed import get_client
from discretize import TensorMesh, TreeMesh
from geoh5py.shared.utils import uuid_from_values
from scipy.optimize import linear_sum_assignment
from scipy.spatial import cKDTree
from scipy.spatial.distance import cdist
from simpeg import data, data_misfit, maps, meta, objective_function
from simpeg.dask.objective_function import DistributedComboMisfits
from simpeg.data_misfit import L2DataMisfit
from simpeg.electromagnetics.base_1d import BaseEM1DSimulation
from simpeg.electromagnetics.frequency_domain.simulation import BaseFDEMSimulation
Expand Down Expand Up @@ -539,21 +533,9 @@ def tile_locations(
from sklearn.cluster import KMeans

kmeans = KMeans(n_clusters=n_tiles, random_state=0, n_init="auto")
cluster_size = int(np.ceil(grid_locs.shape[0] / n_tiles))
kmeans.fit(grid_locs)

if labels is not None:
cluster_id = kmeans.labels_
else:
# Redistribute cluster centers to even out the number of points
centers = kmeans.cluster_centers_
centers = (
centers.reshape(-1, 1, grid_locs.shape[1])
.repeat(cluster_size, 1)
.reshape(-1, grid_locs.shape[1])
)
distance_matrix = cdist(grid_locs, centers)
cluster_id = linear_sum_assignment(distance_matrix)[1] // cluster_size
cluster_id = kmeans.labels_

tiles = []
for tid in set(cluster_id):
Expand Down
56 changes: 28 additions & 28 deletions tests/locations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,34 +115,34 @@ def test_filter(tmp_path: Path):
assert np.all(filtered_data["key"] == [2, 3, 4])


def test_tile_locations(tmp_path: Path):
with Workspace.create(tmp_path / f"{__name__}.geoh5") as ws:
grid_x, grid_y = np.meshgrid(np.arange(100), np.arange(100))
choices = np.c_[grid_x.ravel(), grid_y.ravel(), np.zeros(grid_x.size)]
inds = np.random.randint(0, 10000, 1000)
pts = Points.create(
ws,
name="test-points",
vertices=choices[inds],
)
tiles = tile_locations(pts.vertices[:, :2], n_tiles=8)

values = np.zeros(pts.n_vertices)
pop = []
for ind, tile in enumerate(tiles):
values[tile] = ind
pop.append(len(tile))

pts.add_data(
{
"values": {
"values": values,
}
}
)
assert np.std(pop) / np.mean(pop) < 0.02, (
"Population of tiles are not almost equal."
)
# TODO: find a scalable algorithm (better than linear_sum_assignment) for an even split.
# The tiling strategy should yield even "densities" (area x n_receivers).
# def test_tile_locations(tmp_path: Path):
# with Workspace.create(tmp_path / f"{__name__}.geoh5") as ws:
# grid_x, grid_y = np.meshgrid(np.arange(100), np.arange(100))
# choices = np.c_[grid_x.ravel(), grid_y.ravel(), np.zeros(grid_x.size)]
# inds = np.random.randint(0, 10000, 1000)
# pts = Points.create(
# ws,
# name="test-points",
# vertices=choices[inds],


def test_tile_locations():
    """Tiling random 2D points should cover every index once and stay balanced."""
    total = 1000
    generator = np.random.default_rng(0)
    points = generator.standard_normal((total, 2))

    tiles = tile_locations(points, n_tiles=8)

    # Every point index must appear in exactly one tile.
    covered = np.sort(np.concatenate(tiles))
    assert np.array_equal(covered, np.arange(total))

    # No tile may be empty, and the populations should be reasonably even
    # (coefficient of variation below 0.5).
    populations = np.array([len(t) for t in tiles])
    assert populations.min() > 0
    assert np.std(populations) / np.mean(populations) < 0.5


def test_tile_locations_labels(tmp_path: Path):
Expand Down
Loading