"""
This file stores generally important functionality for the Cassiopeia-Preprocess
pipeline.
"""
import functools
import itertools
import os
import time
from typing import Callable, Dict, List, Optional, Tuple
import warnings
from collections import defaultdict, OrderedDict
import ngs_tools as ngs
import numpy as np
import pandas as pd
import pylab
import pysam
import re
from tqdm.auto import tqdm
from cassiopeia.mixins import is_ambiguous_state, logger, PreprocessWarning
def log_molecule_table(wrapped: Callable):
    """Decorator that debug-logs summary statistics of a returned DataFrame.

    The wrapped function must return a DataFrame with ``UMI``, ``readCount``
    and ``cellBC`` columns. After each call, the total read count, the number
    of UMIs and the number of distinct cell barcodes are written to the debug
    log. A non-object ``UMI`` column is interpreted as per-row UMI counts
    (an allele table) and summed; otherwise every row counts as one UMI (a
    molecule table).

    Args:
        wrapped: The wrapped original function. Since this is a function
            decorator, this argument is passed implicitly by Python internals.

    Returns:
        A wrapper that calls ``wrapped``, logs the statistics and returns the
        DataFrame unchanged.
    """

    @functools.wraps(wrapped)
    def stats_logging_wrapper(*args, **kwargs):
        result = wrapped(*args, **kwargs)
        # Numeric UMI column -> UMI counts (allele table); object -> sequences.
        is_allele_table = result["UMI"].dtype != object
        table_kind = "alleletable" if is_allele_table else "molecule_table"
        logger.debug(f"Resulting {table_kind} statistics:")
        logger.debug(f"# Reads: {result['readCount'].sum()}")
        if is_allele_table:
            logger.debug(f"# UMIs: {result['UMI'].sum()}")
        else:
            logger.debug(f"# UMIs: {result.shape[0]}")
        logger.debug(f"# Cell BCs: {result['cellBC'].nunique()}")
        return result

    return stats_logging_wrapper
def log_runtime(wrapped: Callable):
    """Decorator that logs when a function starts and how long it ran.

    The elapsed wall-clock time (from ``time.time``) is logged even when the
    wrapped function raises; the exception then propagates unchanged.

    Args:
        wrapped: The wrapped original function. Since this is a function
            decorator, this argument is passed implicitly by Python internals.

    Returns:
        A wrapper that logs around a call to ``wrapped`` and returns its result.
    """

    @functools.wraps(wrapped)
    def timing_wrapper(*args, **kwargs):
        start = time.time()
        logger.info("Starting...")
        try:
            result = wrapped(*args, **kwargs)
        finally:
            elapsed = time.time() - start
            logger.info(f"Finished in {elapsed} s.")
        return result

    return timing_wrapper
def log_kwargs(wrapped: Callable):
    """Decorator that writes a call's keyword arguments to the debug log.

    Positional arguments are intentionally not logged: they usually contain
    Pandas DataFrames, which are difficult to log cleanly as text.

    Args:
        wrapped: The wrapped original function. Since this is a function
            decorator, this argument is passed implicitly by Python internals.

    Returns:
        A wrapper that logs ``kwargs`` and then delegates to ``wrapped``.
    """

    @functools.wraps(wrapped)
    def kwargs_logging_wrapper(*args, **kwargs):
        message = f"Keyword arguments: {kwargs}"
        logger.debug(message)
        result = wrapped(*args, **kwargs)
        return result

    return kwargs_logging_wrapper
[docs]
@log_molecule_table
def filter_cells(
    molecule_table: pd.DataFrame,
    min_umi_per_cell: int = 10,
    min_avg_reads_per_umi: float = 2.0,
) -> pd.DataFrame:
    """Remove cell barcodes with too few UMIs or too low reads-per-UMI.

    A cell is kept only if it has at least ``min_umi_per_cell`` UMIs AND its
    total read count divided by its UMI count is at least
    ``min_avg_reads_per_umi``. A numeric ``UMI`` column is treated as per-row
    UMI counts (summed per cell); otherwise each row is one UMI.

    Args:
        molecule_table: A molecule table of cellBC-UMI pairs to be filtered
        min_umi_per_cell: Minimum number of UMIs per cell for cell to not be
            filtered. Defaults to 10.
        min_avg_reads_per_umi: Minimum coverage (i.e. average) reads per
            UMI in a cell needed in order for that cell not to be filtered.
            Defaults to 2.0.

    Returns:
        A filtered copy of the molecule table.
    """
    has_umi_counts = molecule_table["UMI"].dtype != object
    by_cell = molecule_table.groupby("cellBC")
    if has_umi_counts:
        umis_per_cell = by_cell["UMI"].sum()
    else:
        umis_per_cell = by_cell.size()
    reads_per_umi = by_cell["readCount"].sum() / umis_per_cell

    # Both Series share the same per-cell index, so they combine row-wise.
    keep = (umis_per_cell >= min_umi_per_cell) & (
        reads_per_umi >= min_avg_reads_per_umi
    )
    passing_cells = set(keep.index[keep])

    in_passing_cell = molecule_table["cellBC"].isin(passing_cells)
    total_cells = molecule_table["cellBC"].nunique()
    logger.info(
        f"Filtered out {total_cells - len(passing_cells)} cells with too few UMIs "
        "or too few average number of reads per UMI."
    )

    removed = molecule_table[~in_passing_cell]
    removed_umis = removed["UMI"].sum() if has_umi_counts else removed.shape[0]
    logger.info(f"Filtered out {removed_umis} UMIs as a result.")
    return molecule_table[in_passing_cell].copy()
[docs]
@log_molecule_table
def filter_umis(
    molecule_table: pd.DataFrame, min_reads_per_umi: int = 100
) -> pd.DataFrame:
    """
    Filters out UMIs with too few reads.

    Keeps only the UMIs whose read count is >= min_reads_per_umi, i.e. filters
    out all UMIs with a read count < min_reads_per_umi.

    Args:
        molecule_table: A molecule table of cellBC-UMI pairs to be filtered
        min_reads_per_umi: The minimum read count needed for a UMI to not be
            filtered. Defaults to 100.

    Returns:
        A filtered copy of the molecule table
    """
    enough_reads = molecule_table["readCount"] >= min_reads_per_umi
    # Return an independent copy (as filter_cells does) so that later in-place
    # edits of the result do not trigger pandas' SettingWithCopy machinery.
    return molecule_table[enough_reads].copy()
@log_molecule_table
def error_correct_intbc(
    molecule_table: pd.DataFrame,
    prop: float = 0.5,
    umi_count_thresh: int = 10,
    dist_thresh: int = 1,
) -> pd.DataFrame:
    """
    Error corrects close intBCs with small enough unique UMI counts.

    Considers each pair of intBCs sharing a cellBC in the DataFrame for
    correction. For a pair of intBCs, changes all instances of one to other if:
        1. They have the same allele.
        2. The Levenshtein distance between their sequences is <= dist_thresh.
        3. The number of UMIs of the intBC to be changed is <= umi_count_thresh.
        4. The proportion of the UMI count in the intBC to be changed out of the
        total UMI count in both intBCs is < prop.

    Note:
        Prop should be <= 0.5, as this algorithm only corrects intBCs with
        fewer UMIs towards intBCs with more UMIs; because condition 4 is a
        strict comparison, intBCs with equal UMI counts are never merged.
        Additionally, if multiple intBCs are valid correction targets for an
        intBC, it is corrected towards the one with the most UMIs, and each
        intBC is corrected at most once.

    Args:
        molecule_table: A molecule table of cellBC-UMI pairs to be filtered.
            Corrections are written into this DataFrame in place.
        prop: proportion by which to filter integration barcodes
        umi_count_thresh: maximum umi count for which to correct barcodes
        dist_thresh: barcode distance threshold, to decide what's similar
            enough to error correct

    Returns:
        Filtered molecule table with error corrected intBCs (the same object
        that was passed in).
    """
    if prop > 0.5:
        warnings.warn(
            "No intBC correction was done because `prop` is greater than 0.5.",
            PreprocessWarning,
        )
        return molecule_table

    cellBC_intBC_allele_groups = molecule_table.groupby(
        ["cellBC", "intBC", "allele"], sort=False
    )
    # (cellBC, intBC, allele) -> index labels of the rows in that group.
    cellBC_intBC_allele_indices = cellBC_intBC_allele_groups.groups
    # One row per (cellBC, intBC, allele) with its UMI count, most UMIs first.
    molecule_table_agg = (
        cellBC_intBC_allele_groups.agg({"UMI": "count"})
        .sort_values("UMI", ascending=False)
        .reset_index()
    )

    for (cellBC, allele), intBC_table in tqdm(
        molecule_table_agg.groupby(["cellBC", "allele"], sort=False),
        desc="Error Correcting intBCs",
    ):
        # NOTE: UMI counts are non-increasing down each group because groupby
        # operations preserve the row order of the sorted aggregate table.
        intBCs = intBC_table["intBC"].tolist()
        umis = intBC_table["UMI"].to_numpy()
        # Positions already corrected towards a higher-UMI intBC. Skipping them
        # prevents a later, lower-UMI target from overwriting that correction.
        corrected = set()
        for i1, intBC1 in enumerate(intBCs):
            UMI1 = umis[i1]
            for i2 in range(i1 + 1, len(intBCs)):
                if i2 in corrected:
                    continue
                UMI2 = umis[i2]
                if UMI2 > umi_count_thresh:
                    continue
                proportion = UMI2 / (UMI1 + UMI2)
                if not proportion < prop:
                    continue
                intBC2 = intBCs[i2]
                # Checked last: the edit distance is the expensive condition.
                distance = ngs.sequence.levenshtein_distance(intBC1, intBC2)
                if distance <= dist_thresh:
                    key_to_correct = (cellBC, intBC2, allele)
                    molecule_table.loc[
                        cellBC_intBC_allele_indices[key_to_correct], "intBC"
                    ] = intBC1
                    corrected.add(i2)

    return molecule_table
def record_stats(
    molecule_table: pd.DataFrame,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Simple function to record the number of UMIs.

    Every row of the table is counted as one UMI.

    Args:
        molecule_table: A DataFrame of alignments with ``cellBC``, ``intBC``
            and ``readCount`` columns

    Returns:
        A tuple of three arrays: the read count of each alignment (row), the
        number of UMIs (rows) per cellBC-intBC pair, and the number of UMIs
        (rows) per cellBC. Groups appear in order of first occurrence.
    """
    umis_per_intBC = (
        molecule_table.groupby(["cellBC", "intBC"], sort=False).size().values
    )
    umis_per_cellBC = molecule_table.groupby("cellBC", sort=False).size().values
    return (
        molecule_table["readCount"].values,
        umis_per_intBC,
        umis_per_cellBC,
    )
def convert_bam_to_df(data_fp: str) -> pd.DataFrame:
    """Load every read of a BAM file into a Pandas dataframe.

    Each read name must consist of exactly four underscore-separated fields
    (cell barcode, UMI, read count, group flag); a different number of fields
    raises ``ValueError``. Those fields are split into their own columns, the
    read count is converted to ``int``, and the sequence, the quality string
    and the full read name are kept alongside.

    Args:
        data_fp: The input filepath for the BAM file to be converted.

    Returns:
        A Pandas dataframe with columns cellBC, UMI, readCount, grpFlag, seq,
        qual and readName, one row per read.
    """
    column_names = [
        "cellBC",
        "UMI",
        "readCount",
        "grpFlag",
        "seq",
        "qual",
        "readName",
    ]
    records = []
    with pysam.AlignmentFile(
        data_fp, ignore_truncation=True, check_sq=False
    ) as bam:
        for read in bam:
            read_name = read.query_name
            cell_bc, umi, read_count, group_flag = read_name.split("_")
            sequence = read.query_sequence
            quality = pysam.array_to_qualitystring(read.query_qualities)
            records.append(
                [
                    cell_bc,
                    umi,
                    int(read_count),
                    group_flag,
                    sequence,
                    quality,
                    read_name,
                ]
            )
    return pd.DataFrame(records, columns=column_names)
[docs]
def convert_alleletable_to_character_matrix(
    alleletable: pd.DataFrame,
    ignore_intbcs: Optional[List[str]] = None,
    allele_rep_thresh: float = 1.0,
    missing_data_allele: Optional[str] = None,
    missing_data_state: int = -1,
    mutation_priors: Optional[pd.DataFrame] = None,
    cut_sites: Optional[List[str]] = None,
    collapse_duplicates: bool = True,
) -> Tuple[
    pd.DataFrame, Dict[int, Dict[int, float]], Dict[int, Dict[int, str]]
]:
    """Converts an AlleleTable into a character matrix.

    Given an AlleleTable storing the observed mutations for each intBC / cellBC
    combination, create a character matrix for input into a CassiopeiaSolver
    object. By default, we codify uncut mutations as '0' and missing data items
    as '-1'. The function also has the ability to ignore certain intBC sets as
    well as cut sites with too little diversity.

    Each character is one intBC-cut site pair. Within a character, alleles
    equal to "NONE" or containing "None" are encoded as 0; NaN values and
    ``missing_data_allele`` are encoded as ``missing_data_state``; every other
    allele receives the next integer state (1, 2, ...) in order of first
    appearance. Cells with no observation for a character receive
    ``missing_data_state``.

    Args:
        alleletable: Allele Table to be converted into a character matrix
        ignore_intbcs: A collection of intBCs to ignore. Defaults to None, in
            which case no intBC is ignored.
        allele_rep_thresh: A threshold for removing target sites that have an
            allele represented by this proportion. A character is dropped when
            any single allele makes up more than this share of the distinct
            alleles observed per cell, pooled over cells.
        missing_data_allele: Value in the allele table that indicates that the
            cut-site is missing. This will be converted into
            ``missing_data_state``
        missing_data_state: A state to use for missing data.
        mutation_priors: A table storing the prior probability of a mutation
            occurring. This table is used to create a character matrix-specific
            probability dictionary for reconstruction.
        cut_sites: Columns in the AlleleTable to treat as cut sites. If None,
            we assume that the cut-sites are denoted by columns of the form
            "r{int}" (e.g. "r1")
        collapse_duplicates: Whether or not to collapse duplicate character
            states present for a single cellBC-intBC pair. This option has no
            effect if there are no allele conflicts. Defaults to True.

    Returns:
        A character matrix (one row per cellBC, columns "r1", "r2", ... one
        per retained character), a probability dictionary, and a dictionary
        mapping states to the original mutation. Both dictionaries are keyed
        by the 0-based character index.
    """
    if ignore_intbcs is None:
        ignore_intbcs = []
    if cut_sites is None:
        cut_sites = get_default_cut_site_columns(alleletable)

    # cellBC -> {f"{intBC}{cut_site}": [alleles observed over all rows]}.
    # Iterating rows positionally (rather than via .loc on index labels) also
    # works when the AlleleTable index contains duplicates.
    filtered_samples = defaultdict(OrderedDict)
    rows = alleletable[["cellBC", "intBC", *cut_sites]].itertuples(
        index=False, name=None
    )
    for cell, intBC, *alleles in rows:
        if intBC in ignore_intbcs:
            continue
        for cut_site, allele in zip(cut_sites, alleles):
            filtered_samples[cell].setdefault(f"{intBC}{cut_site}", []).append(
                allele
            )

    # Pool, per character, the distinct alleles seen in each cell.
    allele_dist = defaultdict(list)
    for cell_characters in filtered_samples.values():
        for key, states in cell_characters.items():
            allele_dist[key].extend(set(states))

    # remove intBCs that are not diverse enough
    intbc_uniq = []
    dropped = []
    for key, alleles in allele_dist.items():
        props = np.unique(alleles, return_counts=True)[1] / len(alleles)
        if np.any(props > allele_rep_thresh):
            dropped.append(key)
        else:
            intbc_uniq.append(key)
    logger.info(
        "Dropping the following intBCs due to lack of diversity with threshold "
        f"{allele_rep_thresh}: {dropped}"
    )

    character_strings = defaultdict(list)
    allele_counter = defaultdict(OrderedDict)
    prior_probs = defaultdict(dict)
    indel_to_charstate = defaultdict(dict)

    for i, c in enumerate(tqdm(intbc_uniq, desc="Processing characters")):
        indel_to_charstate[i] = {}
        # for all samples, construct a character string
        for sample, characters in filtered_samples.items():
            if c not in characters:
                character_strings[sample].append(missing_data_state)
                continue

            transformed_states = []
            for state in characters[c]:
                # isinstance (not type ==) so numpy string scalars are
                # treated as strings instead of being passed to isna.
                if not isinstance(state, str) and pd.isna(state):
                    transformed_states.append(missing_data_state)
                elif state == "NONE" or "None" in state:
                    transformed_states.append(0)
                elif (
                    missing_data_allele is not None
                    and state == missing_data_allele
                ):
                    transformed_states.append(missing_data_state)
                else:
                    if state not in allele_counter[c]:
                        # First time this state is seen for this character:
                        # assign the next integer state and record its prior.
                        new_state = len(allele_counter[c]) + 1
                        allele_counter[c][state] = new_state
                        indel_to_charstate[i][new_state] = state
                        if mutation_priors is not None:
                            prob = np.mean(mutation_priors.loc[state, "freq"])
                            prior_probs[i][new_state] = float(prob)
                    transformed_states.append(allele_counter[c][state])

            if collapse_duplicates:
                # Sort for testing
                transformed_states = sorted(set(transformed_states))
            transformed_states = tuple(transformed_states)
            if len(transformed_states) == 1:
                transformed_states = transformed_states[0]
            character_strings[sample].append(transformed_states)

    character_matrix = pd.DataFrame.from_dict(
        character_strings,
        orient="index",
        columns=[f"r{i}" for i in range(1, len(intbc_uniq) + 1)],
    )
    return character_matrix, prior_probs, indel_to_charstate
[docs]
def convert_alleletable_to_lineage_profile(
    allele_table,
    cut_sites: Optional[List[str]] = None,
    collapse_duplicates: bool = True,
) -> pd.DataFrame:
    """Converts an AlleleTable to a lineage profile.

    Takes in an allele table that summarizes the indels observed at individual
    cellBC-intBC pairs and produces a lineage profile, which essentially is a
    pivot table over the cellBC / intBCs. Conceptually, these lineage profiles
    are identical to character matrices, only the values in the matrix are the
    actual indel identities.

    Args:
        allele_table: AlleleTable.
        cut_sites: Columns in the AlleleTable to treat as cut sites. If None,
            we assume that the cut-sites are denoted by columns of the form
            "r{int}" (e.g. "r1")
        collapse_duplicates: Whether or not to collapse duplicate character
            states present for a single cellBC-intBC pair. This option has no
            effect if there are no allele conflicts. Defaults to True.

    Returns:
        An NxM lineage profile with one row per cellBC and one column per
        intBC-cut site pair, named "{intBC}_{cut_site}". Columns are ordered by
        the number of AlleleTable rows observed for their intBC, most first.
        A cell holds a single allele, or a tuple when several were observed.
    """
    if cut_sites is None:
        cut_sites = get_default_cut_site_columns(allele_table)

    # For every cellBC-intBC pair, gather all alleles seen at each cut site.
    agg_recipe = {cutsite: list for cutsite in cut_sites}
    g = allele_table.groupby(["cellBC", "intBC"]).agg(agg_recipe)
    intbcs = allele_table["intBC"].unique()

    # create multiindex df by hand: one (intBC, cut site) column per pair
    i1 = [intbc for intbc in intbcs for _ in cut_sites]
    i2 = list(cut_sites) * len(intbcs)
    allele_piv = pd.DataFrame(index=g.index.levels[0], columns=[i1, i2])

    for (cell, intbc), *values in tqdm(
        g.itertuples(name=None),
        total=g.shape[0],
        desc="filling in multiindex table",
    ):
        for val, cutsite in zip(values, cut_sites):
            if collapse_duplicates:
                # Sort for testing
                val = sorted(set(val))
            val = tuple(val)
            if len(val) == 1:
                val = val[0]
            # Write the cell directly. The former chained form
            # ``allele_piv.loc[cell][intbc, cutsite] = val`` assigns into a
            # temporary row and does not update allele_piv under pandas
            # Copy-on-Write.
            allele_piv.at[cell, (intbc, cutsite)] = val

    # Number of rows per cellBC-intBC pair, used to order the intBC columns.
    allele_piv2 = pd.pivot_table(
        allele_table,
        index=["cellBC"],
        columns=["intBC"],
        values="UMI",
        aggfunc=np.size,
    )
    col_order = (
        allele_piv2.dropna(axis=1, how="all")
        .sum()
        .sort_values(ascending=False)
        .index
    )

    lineage_profile = allele_piv[col_order]

    # collapse column names here
    lineage_profile.columns = [
        "_".join(tup).rstrip("_") for tup in lineage_profile.columns.values
    ]

    return lineage_profile
[docs]
def convert_lineage_profile_to_character_matrix(
    lineage_profile: pd.DataFrame,
    indel_priors: Optional[pd.DataFrame] = None,
    missing_allele_indicator: Optional[str] = None,
    missing_state_indicator: int = -1,
) -> Tuple[
    pd.DataFrame, Dict[int, Dict[int, float]], Dict[int, Dict[int, str]]
]:
    """Converts a lineage profile to a character matrix.

    Takes in a lineage profile summarizing the explicit indel identities
    observed at each cut site in a cell and converts this into a character
    matrix where the indels are abstracted into integers.

    Within each column, NaN values, ``missing_allele_indicator``, "Missing"
    and "NC" become ``missing_state_indicator``; indels containing "none"
    (case-insensitive) become 0; every other indel receives the next integer
    state (1, 2, ...) in order of first appearance. Ambiguous entries (as
    identified by ``is_ambiguous_state``) are encoded element-wise into
    tuples.

    Note:
        The lineage profile is converted directly into a character matrix,
        without performing any collapsing of duplicate states. Instead, this
        should have been done in the previous step, when calling
        :func:`convert_alleletable_to_lineage_profile`.

    Args:
        lineage_profile: Lineage profile
        indel_priors: Dataframe mapping indels to prior probabilities
        missing_allele_indicator: An allele that is being used to represent
            missing data.
        missing_state_indicator: State to indicate missing data

    Returns:
        A character matrix (columns renamed "r0", "r1", ...; the input index is
        kept), prior probability dictionary, and mapping from character/state
        pairs to indel identities. Both dictionaries are keyed by the 0-based
        column position.
    """
    prior_probs = defaultdict(dict)
    indel_to_charstate = defaultdict(dict)

    # fillna returns a new frame, so the caller's lineage profile is untouched.
    lineage_profile = lineage_profile.fillna("Missing")
    if missing_allele_indicator:
        lineage_profile = lineage_profile.replace(
            {missing_allele_indicator: "Missing"}
        )

    lineage_profile.columns = [f"r{i}" for i in range(lineage_profile.shape[1])]

    # column name -> {indel: integer state}
    mutation_to_state = defaultdict(dict)
    for c, col in enumerate(lineage_profile.columns):
        indel_to_charstate[c] = {}
        mutation_counter = 0
        for indels in lineage_profile[col].factorize()[1].values:
            if not is_ambiguous_state(indels):
                indels = (indels,)
            for indel in indels:
                if indel == "Missing" or indel == "NC":
                    mutation_to_state[col][indel] = missing_state_indicator
                elif "none" in indel.lower():
                    mutation_to_state[col][indel] = 0
                elif indel not in mutation_to_state[col]:
                    mutation_counter += 1
                    mutation_to_state[col][indel] = mutation_counter
                    indel_to_charstate[c][mutation_counter] = indel
                    if indel_priors is not None:
                        prob = np.mean(indel_priors.loc[indel]["freq"])
                        prior_probs[c][mutation_counter] = float(prob)

    def _encode(column, value):
        """Map one lineage-profile entry of ``column`` to its state(s)."""
        states = mutation_to_state[column]
        if is_ambiguous_state(value):
            return tuple(states[v] for v in value)
        return states[value]

    character_matrix = pd.DataFrame(
        {
            col: [_encode(col, value) for value in lineage_profile[col]]
            for col in lineage_profile.columns
        },
        index=lineage_profile.index,
    )

    return character_matrix, prior_probs, indel_to_charstate
[docs]
def convert_character_matrix_to_allele_table(
    character_matrix: pd.DataFrame,
    state_to_indel: Optional[Dict[int, Dict[int, str]]] = None,
    keep_ambiguous: bool = False,
    missing_state_indicator: int = -1,
) -> pd.DataFrame:
    """Converts a character matrix back into an allele table.

    Every non-missing entry of the character matrix becomes one allele-table
    row (or several, for ambiguous entries kept with ``keep_ambiguous``). The
    intBC is named "intbc-{character position}", the allele is written to both
    the ``allele`` and ``r1`` columns, and ``UMI`` is set to 1. State 0 maps to
    "None"; other states are translated through ``state_to_indel`` when it is
    given and non-empty, and kept as-is otherwise.

    Args:
        character_matrix: A dataframe storing the character states for each
            spot.
        state_to_indel: A mapping of numerical states into indel strings.
        keep_ambiguous: A boolean whether to keep ambiguous states for a given
            intBC. If False, only the first state of an ambiguous entry is used.
        missing_state_indicator: State marking missing data; such entries are
            skipped. Defaults to -1.

    Returns:
        An allele table (object-dtype columns cellBC, intBC, allele, r1, UMI)
        with a fresh 0..n-1 index.
    """
    columns = ["cellBC", "intBC", "allele", "r1", "UMI"]

    def _state_to_allele(character, state):
        """Translate a single state of ``character`` into an allele label."""
        if state == 0:
            return "None"
        if state_to_indel:
            return state_to_indel[int(character)][state]
        return state

    def disambiguate_allele(char, allele):
        """Return the allele label(s) for one matrix entry as a list."""
        if isinstance(allele, tuple):
            if keep_ambiguous:
                return [_state_to_allele(char, a) for a in allele]
            allele = allele[0]
        return [_state_to_allele(char, allele)]

    cellbcs, intbcs, all_alleles = [], [], []
    matrix_values = character_matrix.to_numpy()
    for row, cell in enumerate(tqdm(character_matrix.index)):
        alleles = matrix_values[row]
        for iid in np.where(alleles != missing_state_indicator)[0]:
            for allele in disambiguate_allele(iid, alleles[iid]):
                cellbcs.append(cell)
                intbcs.append(f"intbc-{iid}")
                all_alleles.append(allele)

    # Build the table once (instead of concatenating per cell, which is
    # quadratic and yields a duplicated per-cell index).
    return pd.DataFrame(
        {
            "cellBC": cellbcs,
            "intBC": intbcs,
            "allele": all_alleles,
            "r1": list(all_alleles),
            "UMI": [1] * len(cellbcs),
        },
        columns=columns,
        dtype=object,
    )
[docs]
def compute_empirical_indel_priors(
    allele_table: pd.DataFrame,
    grouping_variables: Optional[List[str]] = None,
    cut_sites: Optional[List[str]] = None,
) -> pd.DataFrame:
    """Computes indel prior probabilities.

    Generates indel prior probabilities from the input allele table. The general
    idea behind this procedure is to count the number of times an indel
    occurs independently. By default, we treat each intBC as an independent
    unit, which is true if the input allele table is a clonal population. Here,
    the procedure will count the number of intBCs that contain a particular
    indel and divide by the number of intBCs in the allele table. However, a
    user can be more nuanced in their analysis and group intBC by other
    variables, such as lineage group (this is especially useful if intBCs might
    occur in several clonal populations). Then, the procedure will count the
    number of times an indel occurs in a unique lineage-intBC combination.

    Only string alleles are counted; alleles containing "none"
    (case-insensitive) are treated as uncut, and non-string values such as NaN
    are ignored.

    Args:
        allele_table: AlleleTable
        grouping_variables: Variables to stratify data by, to treat as
            independent groups in counting indel occurrences. These must be
            columns in the allele table. Defaults to ["intBC"].
        cut_sites: Columns in the AlleleTable to treat as cut sites. If None,
            we assume that the cut-sites are denoted by columns of the form
            "r{int}" (e.g. "r1")

    Returns:
        A DataFrame indexed by indel (index name "indel") with columns "count"
        (number of groups containing the indel) and "freq" (count divided by
        the number of groups).
    """
    if grouping_variables is None:
        grouping_variables = ["intBC"]
    if cut_sites is None:
        cut_sites = get_default_cut_site_columns(allele_table)

    agg_recipe = {cut_site: "unique" for cut_site in cut_sites}
    groups = allele_table.groupby(grouping_variables).agg(agg_recipe)

    indel_count = defaultdict(int)
    for observed in groups.itertuples(index=False, name=None):
        # ``observed`` holds one array of unique alleles per cut site. Sorting
        # keeps the first-seen order of indels deterministic.
        alleles = sorted(
            {
                allele
                for site_alleles in observed
                for allele in site_alleles
                if isinstance(allele, str)
            }
        )
        for a in alleles:
            if "none" not in a.lower():
                indel_count[a] += 1

    tot = len(groups.index)
    indel_freqs = {indel: count / tot for indel, count in indel_count.items()}

    indel_priors = pd.DataFrame([indel_count, indel_freqs]).T
    indel_priors.columns = ["count", "freq"]
    indel_priors.index.name = "indel"

    return indel_priors
def get_default_cut_site_columns(allele_table: pd.DataFrame) -> List[str]:
    """Return the AlleleTable columns that look like Cassiopeia cut sites.

    AlleleTables produced by the Cassiopeia pipeline name each cut-site column
    with the character "r" followed by an integer, e.g. "r1" or "r2". A column
    is selected when the pattern ``r\\d`` occurs anywhere in its name, and the
    original column order is preserved.

    Args:
        allele_table: AlleleTable

    Returns:
        Columns in the AlleleTable corresponding to the cut sites.
    """
    cut_site_pattern = re.compile(r"r\d")
    selected = []
    for column in allele_table.columns:
        if cut_site_pattern.search(column) is not None:
            selected.append(column)
    return selected