# Source code for cassiopeia.solver.VanillaGreedySolver

"""
This file stores a subclass of GreedySolver, the VanillaGreedySolver. The
inference procedure here is the "vanilla" Cassiopeia-Greedy, originally proposed
in Jones et al, Genome Biology (2020). In essence, the algorithm proceeds by
recursively splitting samples into mutually exclusive groups based on the
presence, or absence, of the most frequently occurring mutation.
"""
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd

from cassiopeia.solver import (
    GreedySolver,
    missing_data_methods,
    solver_utilities,
)


class VanillaGreedySolver(GreedySolver.GreedySolver):
    """
    A class for the basic Cassiopeia-Greedy solver.

    The VanillaGreedySolver implements a top-down algorithm that optimizes
    for parsimony by recursively splitting the sample set based on the
    presence, or absence, of the most frequent mutation. Multiple missing
    data imputation methods are included for handling the case when a sample
    has a missing value on the character being split, where presence or
    absence of the character is ambiguous. The user can also supply a custom
    missing data method.

    TODO(richardyz98): Implement fuzzysolver

    Args:
        missing_data_classifier: A function implementing the missing data
            method. It is called as
            ``missing_data_classifier(character_matrix,
            missing_state_indicator, left_set, right_set, missing,
            weights=weights)`` and must return the updated
            ``(left_set, right_set)`` pair. The default is
            ``missing_data_methods.assign_missing_average``.
        prior_transformation: Name of the transformation applied to the
            priors when forming weights to scale frequencies; forwarded to
            the GreedySolver base class. One of the following:
                "negative_log": Transforms each probability by the negative
                    log (default)
                "inverse": Transforms each probability p by taking 1/p
                "square_root_inverse": Transforms each probability by the
                    square root of 1/p

    Attributes:
        prior_transformation: Transformation applied to priors, if these
            are available (handled by the base class initializer).
        missing_data_classifier: Function to classify missing data during
            character splits.
    """

    def __init__(
        self,
        missing_data_classifier: Callable = missing_data_methods.assign_missing_average,
        prior_transformation: str = "negative_log",
    ):
        super().__init__(prior_transformation)
        self.missing_data_classifier = missing_data_classifier

    def perform_split(
        self,
        character_matrix: pd.DataFrame,
        samples: List[str],
        weights: Optional[Dict[int, Dict[int, float]]] = None,
        missing_state_indicator: int = -1,
    ) -> Tuple[List[str], List[str]]:
        """Partitions samples based on the most frequent (character, state) pair.

        Among ``samples``, selects the (character, state) pair with the
        highest frequency (multiplied by ``weights[character][state]`` when
        ``weights`` is non-empty). The missing state, state 0 (presumably
        the unmutated state), and any state carried by every sample with
        observed data at that character are never chosen. Ties go to the
        first pair encountered.

        Samples carrying the chosen state form the left partition, samples
        with a different observed state form the right partition, and
        samples with missing data at the chosen character are assigned by
        ``self.missing_data_classifier``.

        Args:
            character_matrix: Character matrix
            samples: A list of sample names to partition
            weights: Weighting of each (character, state) pair. Typically a
                transformation of the priors.
            missing_state_indicator: Character representing missing data.

        Returns:
            A tuple of lists, representing the left and right partition
            groups. If no eligible mutation exists, returns
            ``(samples, [])``.
        """
        sample_indices = solver_utilities.convert_sample_names_to_indices(
            character_matrix.index, samples
        )
        mutation_frequencies = self.compute_mutation_frequencies(
            samples, character_matrix, missing_state_indicator
        )

        best_frequency = 0
        chosen_character = 0
        chosen_state = 0
        for character, state_counts in mutation_frequencies.items():
            # Samples with observed (non-missing) data at this character.
            num_observed = len(samples) - state_counts.get(
                missing_state_indicator, 0
            )
            for state, count in state_counts.items():
                if state == missing_state_indicator or state == 0:
                    continue
                # Avoid splitting on mutations shared by all samples
                if count >= num_observed:
                    continue
                score = count * weights[character][state] if weights else count
                if score > best_frequency:
                    chosen_character, chosen_state = character, state
                    best_frequency = score

        # chosen_state is still 0 only when no eligible mutation was found.
        if chosen_state == 0:
            return samples, []

        # Only the chosen column is needed; select it by position to avoid
        # converting the whole character matrix on every split.
        chosen_column = character_matrix.iloc[:, chosen_character].to_numpy()
        sample_names = list(character_matrix.index)

        left_set: List[str] = []
        right_set: List[str] = []
        missing: List[str] = []
        for i in sample_indices:
            observed_state = chosen_column[i]
            if observed_state == chosen_state:
                left_set.append(sample_names[i])
            elif observed_state == missing_state_indicator:
                missing.append(sample_names[i])
            else:
                right_set.append(sample_names[i])

        left_set, right_set = self.missing_data_classifier(
            character_matrix,
            missing_state_indicator,
            left_set,
            right_set,
            missing,
            weights=weights,
        )

        return left_set, right_set