Source code for cassiopeia.tools.branch_length_estimator.IIDExponentialMLE
"""
This file stores a subclass of BranchLengthEstimator, the IIDExponentialMLE.
Briefly, this model assumes that CRISPR/Cas9 mutates each site independently
and identically, with an exponential waiting time.
"""
from collections import defaultdict
from typing import List, Optional
import cvxpy as cp
import numpy as np
from cassiopeia.data import CassiopeiaTree
from cassiopeia.mixins import IIDExponentialMLEError
from .BranchLengthEstimator import BranchLengthEstimator
[docs]
class IIDExponentialMLE(BranchLengthEstimator):
"""
MLE under a model of IID memoryless CRISPR/Cas9 mutations.
In more detail, this model assumes that CRISPR/Cas9 mutates each site
independently and identically, with an exponential waiting time. The
tree is assumed to have depth exactly 1, and the user can provide a
minimum branch length. The MLE under this set of assumptions can be
solved with a special kind of convex optimization problem known as an
exponential cone program, which can be readily solved with off-the-shelf
(open source) solvers.
This estimator requires that the ancestral characters be provided (these
can be imputed with CassiopeiaTree's reconstruct_ancestral_characters
method if they are not known, which is usually the case for real data).
The estimated mutation rate(s) will be stored as an attribute called
`mutation_rate`. The log-likelihood will be stored in an attribute
called `log_likelihood`.
Missing states are treated as missing at random by the model.
Args:
minimum_branch_length: Estimated branch lengths will be constrained to
have length at least this value. By default it is set to 0.01,
since the MLE tends to collapse mutationless edges to length 0.
relative_mutation_rates: List of positive floats of length equal to the
number of character sites. Number at each character site indicates
the relative mutation rate at that site. Must be fully specified or
None in which case all sites are assumed to evolve at the same rate.
None is the default value for this argument.
solver: Convex optimization solver to use. Can be "SCS", "ECOS", or
"MOSEK". Note that "MOSEK" solver should be installed separately.
verbose: Verbosity level.
Attributes:
mutation_rate: The estimated CRISPR/Cas9 mutation rate, assuming that
the tree has depth exactly 1.
log_likelihood: The log-likelihood of the training data under the
estimated model.
"""
def __init__(
self,
minimum_branch_length: float = 0.01,
relative_mutation_rates: Optional[List[float]] = None,
verbose: bool = False,
solver: str = "SCS",
):
allowed_solvers = ["ECOS", "SCS", "MOSEK"]
if solver not in allowed_solvers:
raise ValueError(
f"Solver {solver} not allowed. "
f"Allowed solvers: {allowed_solvers}"
) # pragma: no cover
self._minimum_branch_length = minimum_branch_length
self._relative_mutation_rates = relative_mutation_rates
self._verbose = verbose
self._solver = solver
self._mutation_rate = None
self._log_likelihood = None
[docs]
def estimate_branch_lengths(self, tree: CassiopeiaTree) -> None:
r"""
MLE under a model of IID memoryless CRISPR/Cas9 mutations.
The only caveat is that this method raises an IIDExponentialMLEError
if the underlying convex optimization solver fails, or a
ValueError if the character matrix is degenerate (fully mutated,
or fully unmutated).
Raises:
IIDExponentialMLEError
ValueError
"""
# Extract parameters
minimum_branch_length = self._minimum_branch_length
relative_mutation_rates = self._relative_mutation_rates
solver = self._solver
verbose = self._verbose
# # # # # Check that the character has at least one mutation # # # # #
if (tree.character_matrix == 0).all().all():
raise ValueError(
"The character matrix has no mutations. Please check your data."
)
# # # # # Check that the character is not saturated # # # # #
if (tree.character_matrix != 0).all().all():
raise ValueError(
"The character matrix is fully mutated. The MLE does not "
"exist. Please check your data."
)
# # # # # Check that the relative_mutation_rates list is valid # # # # #
is_rates_specified = False
if relative_mutation_rates is not None:
is_rates_specified = True
if tree.character_matrix.shape[1] != len(relative_mutation_rates):
raise ValueError(
"The number of character sites does not match the length \
of the provided relative_mutation_rates list. Please check \
your data."
)
for x in relative_mutation_rates:
if x <= 0:
raise ValueError(
f"Relative mutation rates must be strictly positive, \
but you provided: {relative_mutation_rates}"
)
else:
relative_mutation_rates = [1.0] * tree.character_matrix.shape[1]
# Group together sites having the same rate
sites_by_rate = defaultdict(list)
for i in range(len(relative_mutation_rates)):
rate = relative_mutation_rates[i]
sites_by_rate[rate].append(i)
# # # # # Create variables of the optimization problem # # # # #
t_variables = dict(
[
(node_id, cp.Variable(name=f"t_{node_id}"))
for node_id in tree.nodes
]
)
# # # # # Create constraints of the optimization problem # # # # #
a_leaf = tree.leaves[0]
root = tree.root
root_has_time_0_constraint = [t_variables[root] == 0]
minimum_branch_length_constraints = [
t_variables[child]
>= t_variables[parent] + minimum_branch_length * t_variables[a_leaf]
for (parent, child) in tree.edges
]
ultrametric_constraints = [
t_variables[leaf] == t_variables[a_leaf]
for leaf in tree.leaves
if leaf != a_leaf
]
all_constraints = (
root_has_time_0_constraint
+ minimum_branch_length_constraints
+ ultrametric_constraints
)
# # # # # Compute the log-likelihood # # # # #
log_likelihood = 0
for (parent, child) in tree.edges:
edge_length = t_variables[child] - t_variables[parent]
parent_states = tree.get_character_states(parent)
child_states = tree.get_character_states(child)
for rate in sites_by_rate.keys():
num_mutated = 0
num_unmutated = 0
for site in sites_by_rate[rate]:
if parent_states[site] == 0 and child_states[site] == 0:
num_unmutated += 1
elif parent_states[site] != child_states[site]:
if (
parent_states[site] != tree.missing_state_indicator
and child_states[site]
!= tree.missing_state_indicator
):
num_mutated += 1
if num_unmutated > 0:
log_likelihood += num_unmutated * (-edge_length * rate)
if num_mutated > 0:
log_likelihood += num_mutated * cp.log(
1 - cp.exp(-edge_length * rate - 1e-5)
)
# # # # # Solve the problem # # # # #
obj = cp.Maximize(log_likelihood)
prob = cp.Problem(obj, all_constraints)
try:
prob.solve(solver=solver, verbose=verbose)
except cp.SolverError: # pragma: no cover
raise IIDExponentialMLEError("Third-party solver failed")
# # # # # Extract the mutation rate # # # # #
scaling_factor = float(t_variables[a_leaf].value)
if scaling_factor < 1e-8 or scaling_factor > 15.0:
raise IIDExponentialMLEError(
"The solver failed when it shouldn't have."
)
if is_rates_specified:
self._mutation_rate = tuple(
[rate * scaling_factor for rate in relative_mutation_rates]
)
else:
self._mutation_rate = scaling_factor
# # # # # Extract the log-likelihood # # # # #
log_likelihood = float(log_likelihood.value)
if np.isnan(log_likelihood):
log_likelihood = -np.inf
self._log_likelihood = log_likelihood
# # # # # Populate the tree with the estimated branch lengths # # # # #
times = {
node: float(t_variables[node].value) / scaling_factor
for node in tree.nodes
}
# Make sure that the root has time 0 (avoid epsilons)
times[tree.root] = 0.0
# We smooth out epsilons that might make a parent's time greater
# than its child (which can happen if minimum_branch_length=0)
for (parent, child) in tree.depth_first_traverse_edges():
times[child] = max(times[parent], times[child])
tree.set_times(times)
@property
def log_likelihood(self):
"""
The log-likelihood of the training data under the estimated model.
"""
return self._log_likelihood
@property
def mutation_rate(self):
"""
The estimated CRISPR/Cas9 mutation rate(s) under the given model. If
relative_mutation_rates is specified, we return a list of rates (one per
site). Otherwise all sites have the same rate and that rate is returned.
"""
return self._mutation_rate