Source code for cassiopeia.spatial.spatial_imputation
"""
Functionality for spatial imputation.
"""
import anndata
import networkx as nx
import numpy as np
import pandas as pd
import tqdm
from cassiopeia.spatial import spatial_utilities
[docs]
def impute_alleles_from_spatial_data(
    character_matrix: pd.DataFrame,
    adata: anndata.AnnData | None = None,
    spatial_graph: nx.Graph | None = None,
    neighborhood_size: int | None = None,
    neighborhood_radius: float = 30.0,
    imputation_hops: int = 2,
    imputation_concordance: float = 0.8,
    num_imputation_iterations: int = 1,
    max_neighbor_distance: float = np.inf,
    coordinates: pd.DataFrame | None = None,
    connect_key: str = "spatial_connectivities",
) -> pd.DataFrame:
    """Imputes missing character states based on spatial location.

    This procedure iterates over the missing entries (encoded as ``-1``) of a
    character matrix and imputes them based on spatial neighborhoods.

    The procedure is an iterative algorithm where, in each iteration, every
    missing character is imputed based on the neighborhood of its node. Node
    neighborhoods are built off of a spatial graph that can be passed in
    directly or inferred from a specified spatial anndata. In the simplest
    example, node neighborhoods are just the immediate neighbors of a node,
    but users can also include neighbors-of-neighbors (and
    neighbors-of-neighbors-of-neighbors, and so on) using the
    `imputation_hops` argument. An imputation is accepted if the proportion
    of votes reported for the proposed state is at least the specified
    `imputation_concordance` threshold (by default 0.8) and the proposed
    state is neither missing (``-1``) nor uncut (``0``). This procedure can
    be repeated for several rounds, controlled by the
    `num_imputation_iterations` argument, and in this way approximates a
    message-passing process.

    Within a round, every vote is computed from the matrix as it stood at the
    start of that round; values imputed during a round only become visible to
    neighbors in the next round. The input `character_matrix` is never
    modified.

    Args:
        character_matrix: A character matrix of spots, constructed using a
            function like `convert_allele_table_to_character_matrix`.
        adata: Anndata of spatially-resolved data. Only the spatial
            coordinates need to be stored, and this is used to construct a
            graph. Ignored if `spatial_graph` is provided.
        spatial_graph: Optionally, the user can provide a spatial
            connectivity graph instead of passing in an adata.
        neighborhood_size: If a connectivity graph is being constructed,
            this is the number of nearest neighbors to connect to a node. If
            both neighborhood_size and neighborhood_radius are passed in,
            neighborhood_size is preferred.
        neighborhood_radius: Instead of passing in `neighborhood_size`, this
            is the radius of the connectivity graph.
        imputation_hops: Number of adjacent node's adjacencies to query. For
            example, if this is 2, this means that imputation is done not
            just on nearest neighbors of a given node, but also each nearest
            neighbor's nearest neighbors.
        imputation_concordance: Minimum fraction of votes that must agree in
            order to accept an imputation.
        num_imputation_iterations: Number of iterations for the imputation
            procedure. If this is 0, an unmodified copy of
            `character_matrix` is returned.
        max_neighbor_distance: Maximum distance to neighbor to be used for
            imputation.
        coordinates: If an AnnData is not specified, and you wish to set an
            upper limit on the distance for spatial imputation, these
            coordinates can be passed to the imputation procedure.
        connect_key: Key used to store spatial connectivities in
            `adata.obsp`. This will be passed into the `key_added` argument
            of sq.gr.spatial_neighbors and an entry in `adata.obsp` will be
            added of the form `{connect_key}_connectivities`.

    Returns:
        An imputed character matrix (a new DataFrame).

    Raises:
        ValueError: If neither `spatial_graph` nor `adata` is specified.
    """
    # Compare against None explicitly: container-like objects (e.g. a graph
    # with no nodes) are falsy, and truthiness would wrongly treat them as
    # "not provided".
    if spatial_graph is None and adata is None:
        raise ValueError(
            "One of the following must be specified: "
            "`spatial_graph` or `adata`."
        )

    if spatial_graph is None:
        # create spatial graph if needed
        spatial_graph = spatial_utilities.get_spatial_graph_from_anndata(
            adata,
            neighborhood_radius=neighborhood_radius,
            neighborhood_size=neighborhood_size,
            connect_key=connect_key,
        )

    # Bind the result up front so that zero iterations returns a plain copy
    # instead of failing on an unbound name.
    character_matrix_imputed = character_matrix.copy()

    for _round in range(num_imputation_iterations):
        print(f">> Imputation round {_round+1}...")

        # Votes are read from the frozen snapshot of the previous round while
        # accepted imputations are written to a fresh copy, so the order in
        # which entries are visited does not affect the outcome of a round.
        prev_character_matrix_imputed = character_matrix_imputed
        character_matrix_imputed = prev_character_matrix_imputed.copy()

        missing_rows, missing_columns = np.where(
            prev_character_matrix_imputed == -1
        )
        node_names = prev_character_matrix_imputed.index.values

        for i, j in tqdm.tqdm(
            zip(missing_rows, missing_columns),
            total=len(missing_rows),
        ):
            (imputed_value, proportion_of_votes, _number_of_votes) = (
                spatial_utilities.impute_single_state(
                    node_names[i],
                    j,
                    prev_character_matrix_imputed,
                    neighborhood_graph=spatial_graph,
                    number_of_hops=imputation_hops,
                    max_neighbor_distance=max_neighbor_distance,
                    coordinates=coordinates,
                )
            )

            # Only accept confident, informative states: never overwrite a
            # missing entry with another missing (-1) or uncut (0) state.
            if (
                proportion_of_votes >= imputation_concordance
                and imputed_value != -1
                and imputed_value != 0
            ):
                character_matrix_imputed.iloc[i, j] = int(imputed_value)

    return character_matrix_imputed