"""
A subclass of LeafSubsampler, the SupercellularSampler.
Iteratively, this subsampler randomly merges two leaves to generate a tree with
ambiguous character states. The first leaf is chosen uniformly at random, and
the probability that another leaf is chosen as its merge partner is inversely
proportional to the branch distance between the two leaves.
"""
import copy
from typing import Optional
import numpy as np
from cassiopeia.data.CassiopeiaTree import CassiopeiaTree, CassiopeiaTreeError
from cassiopeia.simulator.LeafSubsampler import (
LeafSubsampler,
LeafSubsamplerError,
)
class SupercellularSampler(LeafSubsampler):
    """Leaf subsampler that merges pairs of leaves into ambiguous leaves.

    Each merge picks a leaf uniformly at random, then picks a partner leaf
    with probability inversely proportional to its branch distance from the
    first leaf. The pair is replaced by a single leaf whose character states
    are tuples combining both leaves' states.
    """

    def __init__(
        self,
        ratio: Optional[float] = None,
        number_of_merges: Optional[int] = None,
    ):
        """
        Merge leaves in a tree to generate a new tree with ambiguous states.

        Note that according to this procedure, an already merged (and
        therefore ambiguous) leaf may be merged again.

        Exactly one of ``ratio`` or ``number_of_merges`` must be provided.

        Args:
            ratio: The number of times to merge as a ratio of the total number
                of leaves. A ratio of 0.5 indicates the number of merges will
                be approximately half the number of leaves.
            number_of_merges: Explicit number of merges to perform.

        Raises:
            LeafSubsamplerError: If both or neither of ``ratio`` and
                ``number_of_merges`` are provided.
        """
        if (ratio is None) == (number_of_merges is None):
            raise LeafSubsamplerError(
                "Exactly one of 'ratio' and 'number_of_merges' must be specified."
            )

        self.__ratio = ratio
        self.__number_of_merges = number_of_merges

    def subsample_leaves(
        self,
        tree: CassiopeiaTree,
        collapse_source: Optional[str] = None,
        collapse_duplicates: bool = True,
    ) -> CassiopeiaTree:
        """Construct a new CassiopeiaTree by merging leaves.

        Pairs of leaves in the given tree are iteratively merged until a
        stopping condition is met (specified by ``ratio`` or
        ``number_of_merges`` when initializing the sampler). Pairs of leaves
        are selected by the following procedure:

        1) A random leaf `A` is selected.
        2) Its pair `B` is randomly selected with probability inversely
           proportional to the branch distance from `A`. If one or more
           leaves are at distance zero from `A`, `B` is chosen uniformly
           among those leaves (the limit of inverse-distance weighting).
        3) The pair is merged into a new leaf with name `A-B` and character
           states merged into ambiguous states. The new leaf is connected to
           the LCA of the two leaves at the mean time of the two leaves,
           which is never earlier than the time of the LCA.

        All merges are applied to a deep copy of ``tree``.

        Args:
            tree: The CassiopeiaTree for which to subsample leaves
            collapse_source: The source node from which to collapse
                unifurcations. Defaults to the root of the merged tree.
            collapse_duplicates: Whether or not to collapse duplicated
                character states, so that only unique character states are
                present in each ambiguous state. Defaults to True.

        Returns:
            A new CassiopeiaTree containing the merged leaves.

        Raises:
            LeafSubsamplerError: If the number of merges is not smaller than
                the number of leaves in the tree, or if no merges would be
                performed (zero or negative merge count).
            CassiopeiaTreeError: If the tree's character matrix is not defined.
        """
        n_merges = (
            self.__number_of_merges
            if self.__number_of_merges is not None
            else int(tree.n_cell * self.__ratio)
        )
        # Each merge replaces two leaves with one, and every merge needs at
        # least two leaves, so at most (number of leaves - 1) merges fit.
        if n_merges >= len(tree.leaves):
            raise LeafSubsamplerError(
                "Number of required merges exceeds number of leaves in the tree."
            )
        # A negative count (e.g. from a negative ratio) would otherwise make
        # range() empty and silently perform no merges, just like zero.
        if n_merges <= 0:
            raise LeafSubsamplerError("No merges to be performed.")

        # Tree needs to have character matrix defined
        if tree.character_matrix is None:
            raise CassiopeiaTreeError("Character matrix not defined.")

        merged_tree = copy.deepcopy(tree)
        for _ in range(n_merges):
            # First leaf: uniform over the current leaves.
            leaf1 = np.random.choice(merged_tree.leaves)
            leaf1_state = merged_tree.get_character_states(leaf1)

            # Second leaf: weighted by inverse branch distance from leaf1.
            distances = merged_tree.get_distances(leaf1, leaves_only=True)
            leaf2 = self._choose_partner(leaf1, distances)
            leaf2_state = merged_tree.get_character_states(leaf2)

            # Merge these two leaves at the mean time of the two leaves.
            # Note that the mean time of the two leaves may never be earlier
            # than the LCA time, because each of the leaf times must be
            # greater than or equal to the LCA time.
            # If the tree is ultrametric, this preserves ultrametricity.
            new_leaf = f"{leaf1}-{leaf2}"
            lca = merged_tree.find_lca(leaf1, leaf2)
            new_time = (
                merged_tree.get_time(leaf1) + merged_tree.get_time(leaf2)
            ) / 2
            new_state = self._merge_character_states(leaf1_state, leaf2_state)

            # The new leaf is attached before the originals are removed.
            # NOTE(review): presumably this keeps the LCA from being pruned
            # as childless by remove_leaf_and_prune_lineage — confirm.
            merged_tree.add_leaf(lca, new_leaf, states=new_state, time=new_time)
            merged_tree.remove_leaf_and_prune_lineage(leaf1)
            merged_tree.remove_leaf_and_prune_lineage(leaf2)

        if collapse_source is None:
            collapse_source = merged_tree.root
        merged_tree.collapse_unifurcations(source=collapse_source)

        if collapse_duplicates:
            merged_tree.collapse_ambiguous_characters()

        return merged_tree

    @staticmethod
    def _choose_partner(leaf, distances: dict):
        """Sample a merge partner for ``leaf`` by inverse branch distance.

        Args:
            leaf: The already-selected leaf; it is excluded from the
                candidates.
            distances: Mapping from leaf name to branch distance from
                ``leaf``.

        Returns:
            The name of the sampled partner leaf.
        """
        # Sorted order keeps sampling reproducible under a fixed numpy seed,
        # independent of dictionary ordering.
        candidates = [other for other in sorted(distances) if other != leaf]

        zero_distance = [c for c in candidates if distances[c] == 0]
        if zero_distance:
            # 1 / 0 is undefined; in the limit of inverse weighting all the
            # probability mass goes to the zero-distance leaves, equally.
            return np.random.choice(zero_distance)

        weights = np.array([1 / distances[c] for c in candidates])
        return np.random.choice(candidates, p=weights / np.sum(weights))

    @staticmethod
    def _merge_character_states(state1: list, state2: list) -> list:
        """Combine two leaves' character states into ambiguous states.

        Each non-tuple character state is wrapped into a 1-tuple, and the two
        tuples for the same character are concatenated (duplicates are kept).

        Args:
            state1: Character states of the first leaf.
            state2: Character states of the second leaf.

        Returns:
            A list with one tuple of states per character.
        """
        merged = []
        for char1, char2 in zip(state1, state2):
            if not isinstance(char1, tuple):
                char1 = (char1,)
            if not isinstance(char2, tuple):
                char2 = (char2,)
            merged.append(char1 + char2)
        return merged