Source code for cassiopeia.solver.PercolationSolver

from collections import defaultdict
import itertools
import networkx as nx
import numpy as np
import pandas as pd
from typing import Callable, Dict, Generator, List, Optional, Tuple, Union

from cassiopeia.data import CassiopeiaTree
from cassiopeia.data import utilities as data_utilities
from cassiopeia.solver import (
    CassiopeiaSolver,
    dissimilarity_functions,
    solver_utilities,
)


class PercolationSolver(CassiopeiaSolver.CassiopeiaSolver):
    """
    A top-down percolation-based CassiopeiaSolver.

    The PercolationSolver implements a top-down algorithm that recursively
    partitions the sample set based on similarity in the observed mutations.
    It is an implicit version of Aho's algorithm for tree discovery (1981).
    At each recursive step, the similarities of each sample pair are embedded
    as edges in a graph with weight equal to the similarity between the nodes.
    The graph is then percolated by removing the minimum edges until multiple
    connected components are produced. Ultimately, this groups clusters of
    samples that share strong similarity amongst themselves. If more than two
    connected components are produced by the percolation procedure, then the
    components are clustered by applying the specified solver to the LCAs of
    the clusters, obeying parsimony.

    TODO(richardyz98): Experiment to find the best default similarity function

    Args:
        joining_solver: The CassiopeiaSolver that is used to cluster groups of
            samples in the case that the percolation procedure generates more
            than two groups of samples in the partition
        prior_transformation: Name of the transformation applied to the priors
            to form the weights that scale the contribution of each mutation
            in the similarity graph. It is forwarded to the parent constructor
            (presumably stored there as ``self.prior_transformation`` --
            confirm against CassiopeiaSolver).
        similarity_function: A function that calculates a similarity score
            between two given samples and their observed mutations. It is
            called as ``f(sample_1, sample_2, missing_state_indicator,
            weights)``. The default is
            ``dissimilarity_functions.hamming_similarity_without_missing``.
        threshold: The minimum similarity threshold. A similarity threshold of
            1 for example means that only samples with similarities strictly
            above 1 will have an edge between them in the graph. Acts as a
            hyperparameter that controls the sparsity of the graph by
            filtering low similarities.

    Attributes:
        joining_solver: The CassiopeiaSolver that is used to cluster groups of
            samples after percolation steps that produce more than two groups
        prior_transformation: Transformation to use when turning priors into
            weights.
        similarity_function: A function that calculates a similarity score
            between two given samples and their observed mutations
        threshold: A minimum similarity threshold
    """

    def __init__(
        self,
        joining_solver: CassiopeiaSolver.CassiopeiaSolver,
        prior_transformation: str = "negative_log",
        similarity_function: Optional[
            Callable[
                [
                    np.array,
                    np.array,
                    int,
                    Optional[Dict[int, Dict[int, float]]],
                ],
                float,
            ]
        ] = dissimilarity_functions.hamming_similarity_without_missing,
        threshold: Optional[int] = 0,
    ):

        super().__init__(prior_transformation)

        self.joining_solver = joining_solver
        self.threshold = threshold
        self.similarity_function = similarity_function

    def solve(
        self,
        cassiopeia_tree: CassiopeiaTree,
        layer: Optional[str] = None,
        collapse_mutationless_edges: bool = False,
        logfile: str = "stdout.log",
    ):
        """Implements a solving procedure for the Percolation Algorithm.

        The procedure recursively splits a set of samples to build a tree. At
        each partition of the samples produced by the percolation procedure,
        an ancestral node is created and each side of the partition is placed
        as a daughter clade of that node. This continues until each side of
        the partition is comprised only of single samples. If an algorithm
        cannot produce a split on a set of samples, then those samples are
        placed as sister nodes and the procedure terminates, generating a
        polytomy in the tree. This function will populate a tree inside the
        input CassiopeiaTree.

        Args:
            cassiopeia_tree: CassiopeiaTree storing a character matrix and
                priors.
            layer: Layer storing the character matrix for solving. If None,
                the default character matrix is used in the CassiopeiaTree.
            collapse_mutationless_edges: Indicates if the final reconstructed
                tree should collapse mutationless edges based on internal
                states inferred by Camin-Sokal parsimony. In scoring accuracy,
                this removes artifacts caused by arbitrarily resolving
                polytomies.
            logfile: Location to write standard out. Not currently used.
        """
        # Shared by the recursion and by the duplicate-placement step so that
        # every internal node drawn during this call comes from one generator.
        node_name_generator = solver_utilities.node_name_generator()

        # A helper function that builds the subtree given a set of samples
        def _solve(
            samples: List[Union[str, int]],
            tree: nx.DiGraph,
            unique_character_matrix: pd.DataFrame,
            priors: Dict[int, Dict[int, float]],
            weights: Dict[int, Dict[int, float]],
            missing_state_indicator: int,
        ):
            # Base case: a single sample is its own subtree root
            if len(samples) == 1:
                return samples[0]

            # Partition the set of samples by percolating a similarity graph.
            # Empty groups are filtered while building the list instead of
            # being removed from it during iteration, which could skip
            # elements. percolate() yields an empty group when it cannot split
            # the samples.
            clades = [
                clade
                for clade in self.percolate(
                    unique_character_matrix,
                    samples,
                    priors,
                    weights,
                    missing_state_indicator,
                )
                if len(clade) > 0
            ]

            # Generate a root for this subtree with a fresh identifier
            root = next(node_name_generator)
            tree.add_node(root)

            # If unable to return a split, generate a polytomy and return
            if len(clades) == 1:
                for sample in clades[0]:
                    tree.add_edge(root, sample)
                return root

            # Recursively generate the subtrees for each daughter clade
            for clade in clades:
                child = _solve(
                    clade,
                    tree,
                    unique_character_matrix,
                    priors,
                    weights,
                    missing_state_indicator,
                )
                tree.add_edge(root, child)
            return root

        # Weights are only derived when the tree actually carries priors
        weights = None
        priors = None
        if cassiopeia_tree.priors:
            weights = solver_utilities.transform_priors(
                cassiopeia_tree.priors, self.prior_transformation
            )
            priors = cassiopeia_tree.priors

        # Extract the character matrix. A copy is taken because the
        # duplicate-placement step renames the matrix index in place.
        if layer:
            character_matrix = cassiopeia_tree.layers[layer].copy()
        else:
            character_matrix = cassiopeia_tree.character_matrix.copy()

        # Solve on unique rows only; duplicates are re-attached afterwards
        unique_character_matrix = character_matrix.drop_duplicates()

        tree = nx.DiGraph()
        tree.add_nodes_from(list(unique_character_matrix.index))

        _solve(
            list(unique_character_matrix.index),
            tree,
            unique_character_matrix,
            priors,
            weights,
            cassiopeia_tree.missing_state_indicator,
        )

        # Append duplicate samples
        duplicates_tree = self.__add_duplicates_to_tree(
            tree, character_matrix, node_name_generator
        )
        cassiopeia_tree.populate_tree(duplicates_tree, layer=layer)

        # Collapse mutationless edges
        if collapse_mutationless_edges:
            cassiopeia_tree.collapse_mutationless_edges(
                infer_ancestral_characters=True
            )

    def percolate(
        self,
        character_matrix: pd.DataFrame,
        samples: List[str],
        priors: Optional[Dict[int, Dict[int, float]]] = None,
        weights: Optional[Dict[int, Dict[int, float]]] = None,
        missing_state_indicator: int = -1,
    ) -> Union[Tuple[List[str], List[str]], List[List[str]]]:
        """The function used by the percolation algorithm to partition the
        set of samples.

        First, a pairwise similarity graph is generated with samples as nodes
        such that edges between a pair of nodes is some provided function on
        the number of character/state mutations shared. Only pairs whose
        similarity is strictly greater than ``self.threshold`` receive an
        edge. Then, the algorithm removes the minimum edge (in the case of
        ties all are removed) until the graph is split into multiple
        connected components. If there are more than two connected
        components, the procedure groups them. This is done by inferring the
        mutations of the LCA of each sample set obeying Camin-Sokal
        Parsimony, and then clustering the groups of samples based on their
        LCAs. The provided joining solver builds a tree on the LCAs, and the
        split at the first branching node of that tree (unifurcations are
        skipped) determines the groups.

        Args:
            character_matrix: Character matrix
            samples: A list of samples to partition
            priors: A dictionary storing the probability of each character
                mutating to a particular state.
            weights: Weighting of each (character, state) pair. Typically a
                transformation of the priors.
            missing_state_indicator: Character representing missing data.

        Returns:
            The partition groups, each a list of sample names. When the
            similarity graph has no edges, the tuple ``(samples, [])`` is
            returned to signal that no split was possible. Otherwise a list
            of groups is returned: the connected components themselves when
            there are exactly two, or one group per child of the first
            branching node of the joining solver's tree when there are more
            (this may be more than two groups if that node is not binary).
        """
        sample_indices = solver_utilities.convert_sample_names_to_indices(
            character_matrix.index, samples
        )
        unique_character_array = character_matrix.to_numpy()

        G = nx.Graph()
        G.add_nodes_from(sample_indices)

        # Add edges to the similarity graph, bucketed by their weight so that
        # all edges tied at the minimum weight can be removed together
        edge_weight_buckets = defaultdict(list)
        for i, j in itertools.combinations(sample_indices, 2):
            similarity = self.similarity_function(
                unique_character_array[i, :],
                unique_character_array[j, :],
                missing_state_indicator,
                weights,
            )
            if similarity > self.threshold:
                edge_weight_buckets[similarity].append((i, j))
                G.add_edge(i, j)

        # No pair passed the threshold: report that no split was found
        if len(G.edges) == 0:
            return samples, []

        connected_components = list(nx.connected_components(G))
        # Sorted descending so that pop() yields the current minimum weight
        sorted_edge_weights = sorted(edge_weight_buckets, reverse=True)

        # Percolate the similarity graph by continuously removing the minimum
        # edge until at least two components exist
        while len(connected_components) <= 1:
            min_weight = sorted_edge_weights.pop()
            for u, v in edge_weight_buckets[min_weight]:
                G.remove_edge(u, v)
            connected_components = list(nx.connected_components(G))

        # Components come back as sets; lists are needed for array indexing
        connected_components = [
            list(component) for component in connected_components
        ]

        # If the number of connected components > 2, group components by
        # clustering the LCAs of each component with the joining solver
        if len(connected_components) > 2:
            lcas = {}
            component_to_nodes = {}
            # Find the LCA of the nodes in each connected component
            for ind, component in enumerate(connected_components):
                component_identifier = "component" + str(ind)
                component_to_nodes[component_identifier] = component
                character_vectors = [
                    list(row) for row in unique_character_array[component, :]
                ]
                lcas[component_identifier] = data_utilities.get_lca_characters(
                    character_vectors, missing_state_indicator
                )

            # Build a tree on the LCA characters to cluster the components
            lca_tree = CassiopeiaTree(
                pd.DataFrame.from_dict(lcas, orient="index"),
                missing_state_indicator=missing_state_indicator,
                priors=priors,
            )
            self.joining_solver.solve(
                lca_tree, collapse_mutationless_edges=False
            )

            # Take the split at the root as the clusters of components
            # in the split, ignoring unifurcations
            grouped_components = []
            current_node = lca_tree.root
            while len(grouped_components) == 0:
                successors = lca_tree.children(current_node)
                if len(successors) == 1:
                    current_node = successors[0]
                else:
                    for successor in successors:
                        grouped_components.append(
                            lca_tree.leaves_in_subtree(successor)
                        )

            # For each component in each cluster, take the nodes in that
            # component to form the final split
            partition_sides = [
                [
                    sample_index
                    for component in cluster
                    for sample_index in component_to_nodes[component]
                ]
                for cluster in grouped_components
            ]
        else:
            partition_sides = connected_components

        # Convert from component indices back to the sample names in the
        # original character matrix
        sample_names = list(character_matrix.index)
        return [
            [sample_names[sample_index] for sample_index in sample_index_group]
            for sample_index_group in partition_sides
        ]

    def __add_duplicates_to_tree(
        self,
        tree: nx.DiGraph,
        character_matrix: pd.DataFrame,
        node_name_generator: Generator[str, None, None],
    ) -> nx.DiGraph:
        """Takes duplicate samples and places them in the tree.

        Places samples removed in removing duplicates in the tree as sisters
        to the corresponding cells that share the same mutations.

        Note that this renames the index of ``character_matrix`` to "index"
        in place and relabels nodes of ``tree`` in place.

        Args:
            tree: The tree to have duplicates added to
            character_matrix: Character matrix
            node_name_generator: Generator supplying names for the new
                internal nodes that will parent each group of duplicates

        Returns:
            The tree with duplicates added
        """
        # The reset_index()/["index"] steps below rely on this index name
        character_matrix.index.name = "index"

        # Map the first sample of every duplicated row pattern to the tuple
        # of all samples sharing that pattern (the first sample included)
        duplicate_groups = (
            character_matrix[character_matrix.duplicated(keep=False)]
            .reset_index()
            .groupby(character_matrix.columns.tolist())["index"]
            .agg(["first", tuple])
            .set_index("first")["tuple"]
            .to_dict()
        )

        for representative, duplicates in duplicate_groups.items():
            # The representative's position in the tree becomes a new
            # internal node, under which all identical samples are attached
            new_internal_node = next(node_name_generator)
            nx.relabel_nodes(
                tree, {representative: new_internal_node}, copy=False
            )
            for duplicate in duplicates:
                tree.add_edge(new_internal_node, duplicate)

        return tree