Source code for cassiopeia.simulator.BirthDeathFitnessSimulator

"""
This file stores a general phylogenetic tree simulator using a forward
birth-death process, including differing fitness on lineages on the tree.
Allows for a variety of division and fitness regimes to be specified by the
user.
"""
from typing import Callable, Dict, Generator, List, Optional, Union

import networkx as nx
import numpy as np
from queue import PriorityQueue

from cassiopeia.data.CassiopeiaTree import CassiopeiaTree
from cassiopeia.mixins import TreeSimulatorError
from cassiopeia.simulator.TreeSimulator import TreeSimulator


class BirthDeathFitnessSimulator(TreeSimulator):
    """Simulator class for a general forward birth-death process with fitness.

    Implements a flexible phylogenetic tree simulator using a forward birth-
    death process. In this process starting from an initial root lineage,
    births represent the branching of a new lineage and death represents the
    cessation of an existing lineage. The process is represented as a tree,
    with internal nodes representing division events, branch lengths
    representing the lifetimes of individuals, and leaves representing samples
    observed at the end of the experiment.

    Allows any distribution on birth and death waiting times to be specified,
    including constant, exponential, weibull, etc. If no death waiting time
    distribution is provided, the process reduces to a Yule birth process.
    Also robustly simulates differing fitness on lineages within a simulated
    tree. Fitness in this context represents potential mutations that may be
    acquired on a lineage that change the rate at which new members are born.
    Each lineage maintains its own birth scale parameter, altered from an
    initial specified experiment-wide birth scale parameter by accrued
    mutations. Different fitness regimes can be specified based on user
    provided distributions on how often fitness mutations occur and their
    respective strengths.

    There are two stopping conditions for the simulation. The first is "number
    of extant nodes", which specifies the simulation to run until the first
    moment a number of extant nodes exist. The second is "experiment time",
    which specifies the time at which lineages are sampled. At least one of
    these two stopping criteria must be provided. Both can be provided in which
    case the simulation is run until one of the stopping conditions is reached.

    Example use snippet:
        # note that numpy uses a different parameterization of the
        # exponential distribution with the scale parameter, which is 1/rate

        birth_waiting_distribution = lambda scale: np.random.exponential(scale)
        death_waiting_distribution = lambda: np.random.exponential(1.5)
        initial_birth_scale = 0.5
        mutation_distribution = lambda: 1 if np.random.uniform() > 0.5 else 0
        fitness_distribution = lambda: np.random.uniform(-1,1)
        fitness_base = 2

        bd_sim = BirthDeathFitnessSimulator(
            birth_waiting_distribution,
            initial_birth_scale,
            death_waiting_distribution=death_waiting_distribution,
            mutation_distribution=mutation_distribution,
            fitness_distribution=fitness_distribution,
            fitness_base=fitness_base,
            num_extant=8
        )
        tree = bd_sim.simulate_tree()

    Args:
        birth_waiting_distribution: A function that samples waiting times
            from the birth distribution. Determines how often births occur.
            Must take a scale parameter as the input
        initial_birth_scale: The initial scale parameter that is used at the
            start of the experiment
        death_waiting_distribution: A function that samples waiting times
            from the death distribution. Determines how often deaths occur.
            If None, lineages never die (the waiting time is always infinite)
        mutation_distribution: A function that samples the number of
            mutations that occur at a division event. If None, then no
            mutations are sampled
        fitness_distribution: One of the two elements in determining the
            multiplicative coefficient of a fitness mutation. A function
            that samples the exponential that the fitness base is raised by.
            Determines the distribution of fitness mutation strengths. Must
            not be None if mutation_distribution provided
        fitness_base: One of the two elements in determining the
            multiplicative strength of a fitness mutation. The base that is
            raised by the value given by the fitness distribution. Determines
            the base strength of fitness mutations. By default is e, Euler's
            Constant
        num_extant: Specifies the number of extant lineages existing at the
            same time as a stopping condition for the experiment
        experiment_time: Specifies the total time that the experiment runs as
            a stopping condition for the experiment
        collapse_unifurcations: Specifies whether to collapse unifurcations in
            the tree resulting from pruning dead lineages
        random_seed: A seed for reproducibility. Any integer, including 0,
            seeds numpy's global random state; None leaves it untouched

    Raises:
        TreeSimulatorError if invalid stopping conditions are provided or if a
        fitness distribution is not provided when a mutation distribution is
    """

    def __init__(
        self,
        birth_waiting_distribution: Callable[[float], float],
        initial_birth_scale: float,
        death_waiting_distribution: Optional[
            Callable[[], float]
        ] = lambda: np.inf,
        mutation_distribution: Optional[Callable[[], int]] = None,
        fitness_distribution: Optional[Callable[[], float]] = None,
        fitness_base: float = np.e,
        num_extant: Optional[int] = None,
        experiment_time: Optional[float] = None,
        collapse_unifurcations: bool = True,
        random_seed: Optional[int] = None,
    ):
        if num_extant is None and experiment_time is None:
            raise TreeSimulatorError(
                "Please specify at least one stopping condition"
            )
        if mutation_distribution is not None and fitness_distribution is None:
            raise TreeSimulatorError(
                "Please specify a fitness strength distribution"
            )
        if num_extant is not None:
            # Check the type before comparing so that a non-numeric value
            # surfaces as a TreeSimulatorError rather than a raw TypeError
            # from the `<=` comparison.
            if type(num_extant) is not int:
                raise TreeSimulatorError(
                    "Please specify an integer number of extant tips"
                )
            if num_extant <= 0:
                raise TreeSimulatorError(
                    "Please specify number of extant lineages greater than 0"
                )
        if experiment_time is not None and experiment_time <= 0:
            raise TreeSimulatorError(
                "Please specify an experiment time greater than 0"
            )

        # The parameter is annotated as Optional: treat an explicit None the
        # same as the default, i.e. a pure-birth process without deaths.
        if death_waiting_distribution is None:
            death_waiting_distribution = lambda: np.inf  # noqa: E731

        self.birth_waiting_distribution = birth_waiting_distribution
        self.initial_birth_scale = initial_birth_scale
        self.death_waiting_distribution = death_waiting_distribution
        self.mutation_distribution = mutation_distribution
        self.fitness_distribution = fitness_distribution
        self.fitness_base = fitness_base
        self.num_extant = num_extant
        self.experiment_time = experiment_time
        self.collapse_unifurcations = collapse_unifurcations
        self.random_seed = random_seed

    def simulate_tree(
        self,
    ) -> CassiopeiaTree:
        """Simulates trees from a general birth/death process with fitness.

        A forward-time birth/death process is simulated by tracking a series of
        lineages and sampling event waiting times for each lineage. Each
        lineage draws death waiting times from the same distribution, but
        maintains its own birth scale parameter that determines the shape of
        its birth waiting time distribution. At each division event, fitness
        mutation events are sampled, and the birth scale parameter is scaled by
        their multiplicative coefficients. This updated birth scale is passed
        on to successors.

        Returns:
            A CassiopeiaTree with the tree topology initialized with the
            simulated tree

        Raises:
            TreeSimulatorError if all lineages die before a stopping condition
        """

        def node_name_generator() -> Generator[str, None, None]:
            """Generates unique node names for the tree."""
            i = 0
            while True:
                yield str(i)
                i += 1

        names = node_name_generator()

        # Set the seed. Compare against None explicitly so that a seed of 0
        # is honored instead of being skipped as falsy.
        if self.random_seed is not None:
            np.random.seed(self.random_seed)

        # Instantiate the implicit root
        tree = nx.DiGraph()
        root = next(names)
        tree.add_node(root)
        tree.nodes[root]["birth_scale"] = self.initial_birth_scale
        tree.nodes[root]["time"] = 0
        current_lineages = PriorityQueue()
        # Records the nodes that are observed at the end of the experiment
        observed_nodes = []
        starting_lineage = {
            "id": root,
            "birth_scale": self.initial_birth_scale,
            "total_time": 0,
            "active": True,
        }

        # Sample the waiting time until the first division
        self.sample_lineage_event(
            starting_lineage, current_lineages, tree, names, observed_nodes
        )

        # Perform the process until there are no active extant lineages left
        while not current_lineages.empty():
            # If number of extant lineages is the stopping criterion, at the
            # first instance of having n extant tips, stop the experiment
            # and set the total lineage time for each lineage to be equal to
            # the minimum, to produce ultrametric trees. Also, the birth_scale
            # parameter of each leaf is rolled back to equal its parent's.
            if (
                self.num_extant is not None
                and current_lineages.qsize() == self.num_extant
            ):
                remaining_lineages = []
                while not current_lineages.empty():
                    _, _, lineage = current_lineages.get()
                    remaining_lineages.append(lineage)
                # The queue yields entries in increasing time order, so the
                # first lineage drained carries the minimum total time.
                min_total_time = remaining_lineages[0]["total_time"]
                for lineage in remaining_lineages:
                    node = lineage["id"]
                    parent = list(tree.predecessors(node))[0]
                    tree.nodes[node]["time"] += (
                        min_total_time - lineage["total_time"]
                    )
                    tree.nodes[node]["birth_scale"] = tree.nodes[parent][
                        "birth_scale"
                    ]
                    observed_nodes.append(node)
                break

            # Pop the minimum age lineage to simulate forward time
            _, _, lineage = current_lineages.get()
            # If the lineage is no longer active, just remove it from the
            # queue. This represents the time at which the lineage dies.
            if lineage["active"]:
                # A division produces two descendants
                for _ in range(2):
                    self.sample_lineage_event(
                        lineage, current_lineages, tree, names, observed_nodes
                    )

        cassiopeia_tree = CassiopeiaTree(tree=tree)
        time_dictionary = {
            node: tree.nodes[node]["time"] for node in tree.nodes
        }
        cassiopeia_tree.set_times(time_dictionary)

        # Prune dead lineages and collapse resulting unifurcations
        to_remove = list(set(cassiopeia_tree.leaves) - set(observed_nodes))
        cassiopeia_tree.remove_leaves_and_prune_lineages(to_remove)
        if self.collapse_unifurcations and len(cassiopeia_tree.nodes) > 1:
            # "1" is the name of the second node generated, i.e. the single
            # child of the implicit root "0"; collapsing starts below the
            # root so the root's unifurcating edge is kept.
            cassiopeia_tree.collapse_unifurcations(source="1")

        # If only implicit root remains after pruning dead lineages, error
        if len(cassiopeia_tree.nodes) == 1:
            raise TreeSimulatorError(
                "All lineages died before stopping condition"
            )

        return cassiopeia_tree

    def sample_lineage_event(
        self,
        lineage: Dict[str, Union[str, int, float, bool]],
        current_lineages: PriorityQueue,
        tree: nx.DiGraph,
        names: Generator,
        observed_nodes: List[str],
    ) -> None:
        """A helper function that samples an event for a lineage.

        Takes a lineage and determines the next event in that lineage's
        future. Simulates the lifespan of a new descendant. Birth and
        death waiting times are sampled, representing how long the
        descendant lived. If a death event occurs first, then the lineage
        with the new descendant is added to the queue of currently alive,
        but its status is marked as inactive and will be removed at the
        time the lineage dies. If a birth event occurs first, then the
        lineage with the new descendant is added to the queue, but with its
        status marked as active, and further events will be sampled at the
        time the lineage divides. Additionally, its fitness will be updated
        by altering its birth rate. The descendant node is added to the
        tree object, with its "time" attribute set to the parent's total
        time plus the sampled waiting time.

        In the case the descendant would live past the end of the
        experiment (both birth and death times reach or exceed the end of
        the experiment), then the lifespan is cut off at the experiment time
        and a final observed sample is added to the tree. In this case the
        lineage is marked as inactive as well.

        Args:
            lineage: The current extant lineage to extend. Contains the ID
                of the internal node to attach the descendant to, the
                current birth scale parameter of the lineage, the current
                total lived time of the lineage, and the status of whether
                the lineage is still dividing
            current_lineages: The queue containing currently alive lineages
            tree: The tree object being constructed by the simulator
                representing the birth death process
            names: A generator providing unique names for tree nodes
            observed_nodes: A list of nodes that are observed at the end of
                the experiment

        Raises:
            TreeSimulatorError if a zero or negative waiting time is sampled
            or a non-active lineage is passed in
        """
        if not lineage["active"]:
            raise TreeSimulatorError(
                "Cannot sample event for non-active lineage"
            )

        unique_id = next(names)

        # Birth is sampled before death; keep this order so that seeded
        # simulations draw random numbers in a stable sequence.
        birth_waiting_time = self.birth_waiting_distribution(
            lineage["birth_scale"]
        )
        death_waiting_time = self.death_waiting_distribution()
        if birth_waiting_time <= 0 or death_waiting_time <= 0:
            raise TreeSimulatorError("0 or negative waiting time detected")

        birth_time = birth_waiting_time + lineage["total_time"]
        death_time = death_waiting_time + lineage["total_time"]

        if (
            self.experiment_time is not None
            and birth_time >= self.experiment_time
            and death_time >= self.experiment_time
        ):
            # Both events would happen after the end of the experiment: cut
            # the living branch off at the experiment time and record the
            # node as an observed sample.
            event_time = self.experiment_time
            birth_scale = lineage["birth_scale"]
            active = False
            observed_nodes.append(unique_id)
        elif birth_waiting_time < death_waiting_time:
            # Division happens first: the descendant stays active and
            # carries a birth scale updated by any fitness mutations.
            event_time = birth_time
            birth_scale = self.update_fitness(lineage["birth_scale"])
            active = True
        else:
            # Death happens first: the descendant is queued as inactive and
            # is dropped when popped at its death time.
            event_time = death_time
            birth_scale = lineage["birth_scale"]
            active = False

        # Annotate parameters for the new node in the tree
        tree.add_node(unique_id)
        tree.nodes[unique_id]["birth_scale"] = birth_scale
        tree.add_edge(lineage["id"], unique_id)
        tree.nodes[unique_id]["time"] = event_time

        # Queue entries are ordered by event time; the unique node name
        # breaks ties so the lineage dicts themselves are never compared.
        current_lineages.put(
            (
                event_time,
                unique_id,
                {
                    "id": unique_id,
                    "birth_scale": birth_scale,
                    "total_time": event_time,
                    "active": active,
                },
            )
        )

    def update_fitness(self, birth_scale: float) -> float:
        """Updates a lineage birth scale, which represents its fitness.

        At each division event, the fitness is updated by sampling from a
        distribution determining the number of mutations. The birth scale
        parameter of the lineage is then scaled by the total multiplicative
        coefficient across all mutations and passed on to the descendant
        nodes. The multiplicative factor of each mutation is determined by
        exponentiating a base parameter by a value drawn from another
        'fitness' distribution. Therefore, negative values from the fitness
        distribution are valid and down-scale the birth scale parameter.
        The base determines the base strength of the mutations in either
        direction and the fitness distribution determines how the mutations
        are distributed.

        Args:
            birth_scale: The birth_scale to be updated

        Returns:
            The updated birth_scale

        Raises:
            TreeSimulatorError if a negative number of mutations is sampled
        """
        base_selection_coefficient = 1
        if self.mutation_distribution is not None:
            num_mutations = int(self.mutation_distribution())
            if num_mutations < 0:
                raise TreeSimulatorError(
                    "Negative number of mutations detected"
                )
            for _ in range(num_mutations):
                base_selection_coefficient *= (
                    self.fitness_base ** self.fitness_distribution()
                )
        return birth_scale * base_selection_coefficient