Source code for cassiopeia.solver.MaxCutSolver

"""
This file stores a subclass of GreedySolver, the MaxCutSolver. This subclass 
implements an inference procedure inspired by Snir and Rao (2006) that 
approximates the max-cut problem on a connectivity graph generated from the 
observed mutations on a group of samples. The connectivity graph represents a
supertree generated over phylogenetic trees for each individual character, and
encodes similarities and differences in mutations between the samples. The
goal is to find a partition on the graph that resolves triplets on the 
supertree, grouping together samples that share mutations and separating
samples that differ in mutations.
"""
from typing import Callable, Dict, List, Optional, Tuple, Union

import itertools
import networkx as nx
import numpy as np
import pandas as pd

from cassiopeia.solver import graph_utilities, GreedySolver


class MaxCutSolver(GreedySolver.GreedySolver):
    """
    A MaxCut graph-based CassiopeiaSolver.

    The MaxCutSolver implements a top-down algorithm that recursively
    partitions the sample set based on connectivity. At each recursive step,
    a connectivity graph is generated for the sample set, where edge weights
    between samples represent the number of triplets that separate those
    samples minus the number of triplets that have those samples as an
    ingroup. Shared mutations are coded as strong negative connections and
    differing mutations are coded as positive connections. Then a partition
    is generated by finding a maximum weight cut over the graph. The final
    partition is also improved upon by a greedy hill-climbing procedure that
    further optimizes the cut.

    Args:
        sdimension: The number of dimensions to use for the embedding space.
            Acts as a hyperparameter
        iterations: The number of iterations in updating the embeddings.
            Acts as a hyperparameter
        prior_transformation: A function defining a transformation on the
            priors in forming weights to scale frequencies and the
            contribution of each mutation in the connectivity graph. One of
            the following:
                "negative_log": Transforms each probability by the negative
                    log (default)
                "inverse": Transforms each probability p by taking 1/p
                "square_root_inverse": Transforms each probability by the
                    square root of 1/p

    Attributes:
        prior_transformation: Function to use when transforming priors into
            weights.
        sdimension: The number of dimensions to use for the embedding space
        iterations: The number of iterations in updating the embeddings
    """

    def __init__(
        self,
        sdimension: Optional[int] = 3,
        iterations: Optional[int] = 50,
        prior_transformation: str = "negative_log",
    ):
        super().__init__(prior_transformation)
        self.sdimension = sdimension
        self.iterations = iterations

    def perform_split(
        self,
        character_matrix: pd.DataFrame,
        samples: List[int],
        weights: Optional[Dict[int, Dict[int, float]]] = None,
        missing_state_indicator: int = -1,
    ) -> Tuple[List[str], List[str]]:
        """Generate a partition of the samples by finding the max-cut.

        First, a connectivity graph is generated with samples as nodes such
        that samples with shared mutations have strong negative edge weight
        and samples with distant mutations have positive edge weight. Then,
        the algorithm finds a partition by using a heuristic method to find
        the max-cut on the connectivity graph. The samples are randomly
        embedded on a unit sphere in (sdimension + 1)-dimensional space and
        the embeddings for each node are iteratively updated based on
        neighboring edge weights in the connectivity graph such that nodes
        with stronger connectivity cluster together. The final partition is
        generated by choosing 3 * (sdimension + 1) random hyperplanes to
        bisect the sphere and taking the one that maximizes the cut, which is
        then passed to ``graph_utilities.max_cut_improve_cut``.

        Args:
            character_matrix: Character matrix
            samples: A list of samples to partition
            weights: Weighting of each (character, state) pair. Typically a
                transformation of the priors.
            missing_state_indicator: Character representing missing data.

        Returns:
            A tuple of lists, representing the left and right partition
            groups. If the connectivity graph has no edges, all samples are
            returned on the left and the right group is empty.
        """
        mutation_frequencies = self.compute_mutation_frequencies(
            samples, character_matrix, missing_state_indicator
        )

        G = graph_utilities.construct_connectivity_graph(
            character_matrix,
            mutation_frequencies,
            missing_state_indicator,
            samples,
            weights=weights,
        )

        # Nothing to cut: keep every sample together on the left side.
        if len(G.edges) == 0:
            return samples, []

        embedding_dimension = self.sdimension + 1

        # Start every node at a random point on the unit sphere.
        emb = {}
        for node in G.nodes():
            x = np.random.normal(size=embedding_dimension)
            emb[node] = x / np.linalg.norm(x)

        # Each round builds the new embeddings from the previous round's
        # values only (synchronous update), hence the separate dictionary.
        for _ in range(self.iterations):
            new_emb = {}
            for i in G.nodes:
                cm = np.zeros(embedding_dimension, dtype=float)
                for j in G.neighbors(i):
                    cm -= (
                        G[i][j]["weight"]
                        * np.linalg.norm(emb[i] - emb[j])
                        * emb[j]
                    )
                # An all-zero vector cannot be normalized; it is stored
                # unchanged, so such a node never falls on the positive side
                # of a hyperplane below (its dot product is exactly 0).
                if cm.any():
                    cm = cm / np.linalg.norm(cm)
                new_emb[i] = cm
            emb = new_emb

        # A candidate cut is adopted only if it scores strictly above 0;
        # otherwise the empty cut is handed to the improvement step.
        return_cut = []
        best_score = 0
        for _ in range(3 * embedding_dimension):
            b = np.random.normal(size=embedding_dimension)
            b = b / np.linalg.norm(b)
            cut = [i for i in G.nodes() if np.dot(emb[i], b) > 0]
            this_score = self.evaluate_cut(cut, G)
            if this_score > best_score:
                return_cut = cut
                best_score = this_score

        improved_left_set = graph_utilities.max_cut_improve_cut(G, return_cut)

        # Build the lookup set once so the complement is computed in linear
        # time instead of scanning the left list for every sample.
        left_members = set(improved_left_set)
        improved_right_set = [s for s in samples if s not in left_members]

        return improved_left_set, improved_right_set

    def evaluate_cut(self, cut: List[str], G: nx.DiGraph) -> float:
        """A simple function to evaluate the weight of a cut.

        For each edge in the graph, checks if it is in the cut (as decided by
        ``graph_utilities.check_if_cut``), and then adds its edge weight to
        the cut if it is.

        Args:
            cut: A list of nodes that represents one of the sides of a cut
                on the graph
            G: The graph the cut is over

        Returns:
            The weight of the cut
        """
        cut_score = 0
        for u, v in G.edges():
            w_uv = G[u][v]["weight"]
            if graph_utilities.check_if_cut(u, v, cut):
                cut_score += float(w_uv)
        return cut_score