# Source code for cassiopeia.solver.MaxCutGreedySolver
"""
This file stores a subclass of GreedySolver, the MaxCutGreedySolver. The
inference procedure here extends the "vanilla" Cassiopeia-Greedy, originally
proposed in Jones et al., Genome Biology (2020). After each putative split of
the samples generated by Cassiopeia-Greedy, the hill-climbing procedure from
the MaxCutSolver is applied to the partition to optimize it for the max cut
criterion. The max cut is computed on a connectivity graph built from the
mutations observed in the samples, which represents a supertree of the
phylogenetic trees on each individual character.
"""
from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
from cassiopeia.solver import (
graph_utilities,
GreedySolver,
missing_data_methods,
solver_utilities,
)
class MaxCutGreedySolver(GreedySolver.GreedySolver):
    """
    A Cassiopeia-Greedy solver with the max cut criterion.

    The MaxCutGreedySolver implements a top-down algorithm that recursively
    splits the sample set based on the presence/absence of the most frequent
    mutation. Additionally, the hill-climbing procedure from the MaxCutSolver
    is used to further optimize each split for the max cut on the similarity
    graph on the samples. This effectively moves samples across the partition
    so that samples with similar mutations are grouped together and samples
    with different mutations are separated. Multiple missing data imputation
    methods are included for handling the case when a sample has a missing
    value on the character being split, where presence or absence of the
    character is ambiguous. The user can also specify a missing data method.

    TODO: Implement fuzzy solver

    Args:
        missing_data_classifier: Function used to assign samples that are
            missing data on the split character to one side of the split.
            Defaults to missing_data_methods.assign_missing_average.
        prior_transformation: Name of the transformation applied to the
            priors in forming weights to scale frequencies and the
            contribution of each mutation in the connectivity graph. One of
            the following:
                "negative_log": Transforms each probability by the negative
                    log (default)
                "inverse": Transforms each probability p by taking 1/p
                "square_root_inverse": Transforms each probability by the
                    square root of 1/p

    Attributes:
        prior_transformation: Function to transform priors, if these are
            available.
        missing_data_classifier: Function to classify missing data during
            character splits.
    """

    def __init__(
        self,
        missing_data_classifier: Callable = missing_data_methods.assign_missing_average,
        prior_transformation: str = "negative_log",
    ):
        # The prior transformation is forwarded to the GreedySolver base class.
        super().__init__(prior_transformation)
        self.missing_data_classifier = missing_data_classifier

    def perform_split(
        self,
        character_matrix: pd.DataFrame,
        samples: List[str],
        weights: Optional[Dict[int, Dict[int, float]]] = None,
        missing_state_indicator: int = -1,
    ) -> Tuple[List[str], List[str]]:
        """Performs a partition using both Greedy and MaxCut criteria.

        First, uses the most frequent (character, state) pair to split the
        list of samples. Samples with missing data on the chosen character
        are assigned to a side by the missing data classifier. Then, the
        partition is optimized for the max cut on a connectivity graph
        constructed on the samples using a hill-climbing method.

        Args:
            character_matrix: Character matrix
            samples: A list of sample names (index labels of
                ``character_matrix``) to partition
            weights: Weighting of each (character, state) pair. Typically a
                transformation of the priors.
            missing_state_indicator: Integer state representing missing data.

        Returns:
            A tuple of lists, representing the left and right partition
            groups. If no mutation can separate the samples, returns
            ``(samples, [])``.
        """
        sample_indices = solver_utilities.convert_sample_names_to_indices(
            character_matrix.index, samples
        )
        mutation_frequencies = self.compute_mutation_frequencies(
            samples, character_matrix, missing_state_indicator
        )

        chosen_character, chosen_state = self._select_split_mutation(
            mutation_frequencies,
            len(samples),
            weights,
            missing_state_indicator,
        )
        # State 0 is never selected, so it doubles as the "nothing found"
        # sentinel: no candidate mutation separates the samples.
        if chosen_state == 0:
            return samples, []

        left_set, right_set, missing = self._partition_on_mutation(
            character_matrix,
            sample_indices,
            chosen_character,
            chosen_state,
            missing_state_indicator,
        )

        # Only the left side is needed from here on: the right side is
        # recomputed as the complement after the max-cut improvement.
        left_set, _ = self.missing_data_classifier(
            character_matrix,
            missing_state_indicator,
            left_set,
            right_set,
            missing,
            weights=weights,
        )

        G = graph_utilities.construct_connectivity_graph(
            character_matrix,
            mutation_frequencies,
            missing_state_indicator,
            samples,
            weights=weights,
        )

        improved_left_set = graph_utilities.max_cut_improve_cut(G, left_set)

        # A set gives O(1) membership tests. Iterating `samples` keeps the
        # right side in the caller's sample order.
        in_left = set(improved_left_set)
        improved_right_set = [s for s in samples if s not in in_left]
        return improved_left_set, improved_right_set

    @staticmethod
    def _select_split_mutation(
        mutation_frequencies: Dict[int, Dict[int, int]],
        n_samples: int,
        weights: Optional[Dict[int, Dict[int, float]]],
        missing_state_indicator: int,
    ) -> Tuple[int, int]:
        """Picks the (character, state) pair with the highest (weighted) count.

        The missing state and state 0 are never chosen. A state is also skipped
        when it is present in every sample that has data for that character,
        since it would not separate them. If ``weights`` is truthy, each count
        is scaled by ``weights[character][state]``. Ties go to the first pair
        encountered, and only scores strictly above 0 are accepted.

        Returns:
            The chosen ``(character, state)``, or ``(0, 0)`` when no pair
            qualifies.
        """
        best_score = 0
        chosen_character, chosen_state = 0, 0
        for character, state_counts in mutation_frequencies.items():
            n_missing = state_counts.get(missing_state_indicator, 0)
            for state, count in state_counts.items():
                if state == missing_state_indicator or state == 0:
                    continue
                # Avoid splitting on mutations shared by all samples
                if count >= n_samples - n_missing:
                    continue
                score = count * weights[character][state] if weights else count
                if score > best_score:
                    best_score = score
                    chosen_character, chosen_state = character, state
        return chosen_character, chosen_state

    @staticmethod
    def _partition_on_mutation(
        character_matrix: pd.DataFrame,
        sample_indices,
        chosen_character: int,
        chosen_state: int,
        missing_state_indicator: int,
    ) -> Tuple[List[str], List[str], List[str]]:
        """Splits samples by their state at the chosen character.

        Returns:
            ``(left, right, missing)``:
            - ``left``: samples with ``chosen_state``.
            - ``right``: samples with any other non-missing state.
            - ``missing``: samples with ``missing_state_indicator``.

            Each entry is an index label of ``character_matrix``.
        """
        # `chosen_character` is used as a positional column index.
        column = character_matrix.to_numpy()[:, chosen_character]
        sample_names = list(character_matrix.index)
        left, right, missing = [], [], []
        for i in sample_indices:
            state = column[i]
            if state == chosen_state:
                left.append(sample_names[i])
            elif state == missing_state_indicator:
                missing.append(sample_names[i])
            else:
                right.append(sample_names[i])
        return left, right, missing