Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions python/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
[1.0.x] - YYYY-MM-DD
--------------------

**Features**

- ``ts.samples(population=...)`` now accepts dictionaries to filter samples
by population metadata. (:user:`hyanwong`, :issue:`1697` :pr:`3345`)

**Bugfixes**

- ``ts.samples(population=...)`` now raises a ``ValueError`` if the population
Expand Down
82 changes: 81 additions & 1 deletion python/tests/test_highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,11 +998,91 @@ def test_bad_samples(self, pop):
with pytest.raises(ValueError, match="must be an integer ID"):
ts.samples(population=pop)

@pytest.mark.parametrize(
    "pop", [0, np.int32(0), np.int64(0), np.uint32(0), {"name": "pop_0"}, {}]
)
def test_good_samples(self, pop):
    """Check that every accepted way of naming the only population works.

    The simulated tree sequence has exactly one population, so selecting it
    by integer ID (Python or numpy integer types) or by a metadata
    dictionary (including the empty dictionary) must return all samples.
    """
    # A single parametrize decorator only: stacking a second one for the
    # same argument name makes pytest fail at collection time.
    ts = msprime.sim_ancestry(2)
    assert ts.num_populations == 1
    assert np.array_equiv(ts.samples(population=pop), ts.samples())

@pytest.mark.parametrize(
    "pop",
    [
        {"name": "nonexistent"},
        {"name": "pop_0", "description": "nonexistent"},
        {"name": "pop_0", "nonexistent": ""},
    ],
)
def test_samples_metadata_no_selected(self, pop):
    """A metadata query that no population satisfies must raise ValueError."""
    tree_seq = msprime.sim_ancestry(2)
    expected_message = "No populations match the specified metadata"
    with pytest.raises(ValueError, match=expected_message):
        tree_seq.samples(population=pop)

@pytest.mark.parametrize("pop", [{"name": "pop_0"}, {}])
def test_samples_metadata_nopop(self, pop):
    """Dictionary queries raise when the tree sequence has no populations."""
    balanced_ts = tskit.Tree.generate_balanced(4).tree_sequence
    assert balanced_ts.num_populations == 0
    expected_message = "No populations match the specified metadata"
    with pytest.raises(ValueError, match=expected_message):
        balanced_ts.samples(population=pop)

def test_samples_metadata_multipop(self):
    """A query matched by several populations raises and lists their IDs."""
    demography = msprime.Demography()
    # Same three populations, in the same order, as separate add calls.
    for pop_name, pop_size in (("A", 10_000), ("B", 5_000), ("C", 1_000)):
        demography.add_population(name=pop_name, initial_size=pop_size)
    demography.add_population_split(time=1000, derived=["A", "B"], ancestral="C")
    sample_counts = {"A": 1, "B": 1}
    tree_seq = msprime.sim_ancestry(
        sample_counts, demography=demography, random_seed=12
    )
    with pytest.raises(ValueError, match=r"populations \(\[0, 1, 2\]\) match"):
        tree_seq.samples(population={"description": ""})

@pytest.mark.parametrize(
    "pop_param",
    [
        {"name": "B"},
        {"name": "B", "description": "A&B"},
        {"description": "A&B", "+": "B⊕C"},
    ],
)
def test_samples_metadata_onepop(self, pop_param):
    """Each query must select the same samples as population B's integer ID."""
    pop_size = 100
    demography = msprime.Demography()
    demography.add_population(name="A", description="A&B", initial_size=pop_size)
    demography.add_population(
        name="B",
        description="A&B",
        extra_metadata={"+": "B⊕C"},
        initial_size=pop_size,
    )
    demography.add_population(
        name="C", extra_metadata={"+": "B⊕C"}, initial_size=pop_size
    )
    demography.add_population_split(time=1000, derived=["A", "B"], ancestral="C")
    tree_seq = msprime.sim_ancestry(
        samples={"A": 1, "B": 1}, demography=demography, random_seed=12
    )
    selected_by_metadata = tree_seq.samples(population=pop_param)
    # Look up B's integer ID by name rather than assuming insertion order.
    ids_by_name = {p.metadata["name"]: p.id for p in tree_seq.populations()}
    selected_by_id = tree_seq.samples(population=ids_by_name["B"])
    assert np.array_equiv(selected_by_metadata, selected_by_id)

@pytest.mark.parametrize("md", [b"{}", b"", None])
def test_bad_pop_metadata(self, md):
    """A dictionary query raises when population metadata is not a dict."""
    base_ts = tskit.Tree.generate_balanced(4).tree_sequence
    tables = base_ts.dump_tables()
    tables.populations.add_row(metadata=md)
    ts_with_pop = tables.tree_sequence()
    expected_message = "metadata is not a dictionary"
    with pytest.raises(ValueError, match=expected_message):
        ts_with_pop.samples(population={})

def test_empty_pop_metadata(self):
    """An empty query matches a population whose JSON metadata is empty."""
    # The docs state "Tskit deviates from standard JSON in that
    # empty metadata is interpreted as an empty object." - test this
    tables = tskit.Tree.generate_balanced(4).tree_sequence.dump_tables()
    populations = tables.populations
    populations.add_row()
    populations.metadata_schema = tskit.MetadataSchema.permissive_json()
    # Assign every node to population 0.
    tables.nodes.population = np.zeros_like(tables.nodes.population)
    tree_seq = tables.tree_sequence()
    selected = tree_seq.samples(population={})
    all_samples = tree_seq.samples()
    assert np.array_equiv(selected, all_samples)

@pytest.mark.parametrize("time", [0, 0.1, 1 / 3, 1 / 4, 5 / 7])
def test_samples_time(self, time):
ts = self.get_tree_sequence(num_demes=2, n=20, times=[time, 0.2, 1, 15])
Expand Down
39 changes: 36 additions & 3 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -6514,15 +6514,32 @@ def samples(self, population=None, *, population_id=None, time=None):
time is approximately equal to the specified time. If `time` is a pair
of values of the form `(min_time, max_time)`, only return sample IDs
whose node time `t` is in this interval such that `min_time <= t < max_time`.
If both `population` and `time` are specified, the returned samples
will satisfy both criteria.

:param int population: The population of interest. If None, do not
filter samples by population.
.. note::
The population can be specified either by an integer (in which case
this is the population ID) or a dictionary matching information in the
population metadata. If a dictionary, it should contain key-value pair(s)
that match the metadata of the desired population; for instance,
``population={'name': 'abc'}`` will select the population that has a
'name' of 'abc' in metadata. Exactly one population must have matching
key-value pair(s); otherwise, an error is raised.

:param Union[int, dict] population: The population of interest. If an
integer, this is the population ID. If a dictionary, the keys
in the dictionary specify metadata key-value pairs to match (see note
above). If None, do not filter samples by population.
:param int population_id: Deprecated alias for ``population``.
:param float,tuple time: The time or time interval of interest. If
None, do not filter samples by time.
:return: A numpy array of the node IDs for the samples of interest,
listed in numerical order.
:rtype: numpy.ndarray (dtype=np.int32)
:raises ValueError: If population or time is specified incorrectly.
:raises ValueError: If multiple or no populations match the specified metadata.
:raises ValueError: If a dictionary is specified to select a population
but existing population metadata entries cannot be treated as dictionaries.
"""
if population is not None and population_id is not None:
raise ValueError(
Expand All @@ -6533,8 +6550,24 @@ def samples(self, population=None, *, population_id=None, time=None):
samples = self._ll_tree_sequence.get_samples()
keep = np.full(shape=samples.shape, fill_value=True)
if population is not None:
if isinstance(population, dict):
# look for the key names in the population metadata: we don't expect
# there to be many populations, so a simple loop is fine.
pops = []
for pop in self.populations():
if not isinstance(pop.metadata, dict):
raise ValueError("Population metadata is not a dictionary")
if set(population.items()).issubset(pop.metadata.items()):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're trying to be too concise and "clever" here. Clever code that tries to do too much at once is bad. Here's a way to do it that's clearer, more obvious and less clever.

# If the query keys are a subset of the population metadata keys, we try to compare. 
# Note that *all* keys must match.
pops = []
if set(population.keys()).issubset(pop.metadata.keys()):
       for key, query_value in population.items():
              # This requires the values are comparable, which should work for nested dictionaries
              # and so on. 
              if query_value != pop.metadata[key]:
                      break
        else:
                pops.append(pop.id)

Note that we've ended up with some quite tricky logic which needs to be tested now. So, I think we should separate this out into its own function that can be unit tested. Something like

def match_metadata(table, query):
     """
     Return the row IDs of the specified table that match the specified query dictionary. All   
     rows matching *all* key-value pairs will be returned.
      """
      # implementation above

class TestMatchMetadata:
       # test all the quirky combos of stuff here.

pops.append(pop.id)
if len(pops) == 0:
raise ValueError("No populations match the specified metadata")
if len(pops) > 1:
raise ValueError(
f"Multiple populations ({pops}) match the specified metadata"
)
population = pops[0]
if not isinstance(population, numbers.Integral):
raise ValueError("`population` must be an integer ID")
raise ValueError("`population` must be an integer ID or a dictionary")
population = int(population)
sample_population = self.nodes_population[samples]
keep = np.logical_and(keep, sample_population == population)
Expand Down