Source code for cassiopeia.spatial.spatial_utilities

"""
Utilities for spatial lineage-tracing module.
"""

from typing import Tuple

import anndata
import networkx as nx
import numpy as np
import pandas as pd

from cassiopeia.mixins import try_import


def get_spatial_graph_from_anndata(
    adata: anndata.AnnData,
    neighborhood_radius: float = 30.0,
    neighborhood_size: int | None = None,
    connect_key: str = "spatial",
    spatial_key: str = "spatial",
) -> nx.Graph:
    """Get a spatial graph structure from a spatial AnnData.

    Constructs a graph connecting each observation to its neighbors in space,
    using either a fixed number of nearest neighbors or a fixed radius. The
    spatial coordinates are read from ``adata.obsm[spatial_key]``.

    Args:
        adata: AnnData of spatially-resolved data. Only the spatial
            coordinates need to be stored; they are used to construct the
            graph. An entry is added to ``adata.obsp`` (see ``connect_key``).
        neighborhood_radius: Radius of the connectivity graph. Used only when
            ``neighborhood_size`` is not given (``None`` or 0).
        neighborhood_size: Number of nearest neighbors to connect to each
            node. When given (and non-zero), it takes precedence over
            ``neighborhood_radius``.
        connect_key: Key used to store spatial connectivities in
            ``adata.obsp``. It is passed as the ``key_added`` argument of
            ``sq.gr.spatial_neighbors``, and the resulting entry
            ``adata.obsp[f"{connect_key}_connectivities"]`` is read back to
            build the graph.
        spatial_key: Key in ``adata.obsm`` holding the spatial coordinates.
            Defaults to ``"spatial"``.

    Returns:
        An undirected networkx graph whose nodes are ``adata.obs_names`` and
        whose edges carry the connectivity values as the ``weight`` attribute.

    Raises:
        ImportError: If the optional dependency squidpy is not installed.
    """
    # squidpy is an optional dependency that is only needed to build the graph.
    sq = try_import("squidpy")
    if sq is None:
        raise ImportError(
            "If you would like to infer a spatial graph, please "
            "install squidpy."
        )

    # A truthy neighborhood_size selects a kNN graph; otherwise use a radius.
    if neighborhood_size:
        neighbor_kwargs = {"n_neighs": neighborhood_size}
    else:
        neighbor_kwargs = {"radius": neighborhood_radius}

    sq.gr.spatial_neighbors(
        adata,
        coord_type="generic",
        spatial_key=spatial_key,
        key_added=connect_key,
        **neighbor_kwargs,
    )

    connectivities = adata.obsp[f"{connect_key}_connectivities"]
    spatial_graph = nx.from_numpy_array(connectivities)

    # from_numpy_array labels nodes 0..n-1 in row order; map them back to the
    # observation names so the graph can be queried by cell barcode.
    node_map = dict(zip(range(connectivities.shape[0]), adata.obs_names))
    return nx.relabel_nodes(spatial_graph, node_map)
def impute_single_state(
    cell: str,
    character: int,
    character_matrix: pd.DataFrame,
    neighborhood_graph: nx.Graph | None = None,
    number_of_hops: int = 1,
    max_neighbor_distance: float = np.inf,
    coordinates: pd.DataFrame | None = None,
) -> Tuple[int, float, int]:
    """Imputes a missing character state for a cell from its neighbors.

    Neighbors are collected with a breadth-first search of
    ``neighborhood_graph`` that starts at ``cell`` (the cell itself is not
    counted) and goes at most ``number_of_hops`` hops deep. Each neighbor
    present in ``character_matrix`` votes with its state at column position
    ``character``. Missing states (-1) do not vote. A tuple of states
    contributes one vote per member. If ``coordinates`` is given, neighbors
    whose Euclidean distance from ``cell`` exceeds ``max_neighbor_distance``
    do not vote.

    Args:
        cell: Cell barcode whose state is imputed.
        character: Positional index of the character (column) to impute.
        character_matrix: Character matrix of all character states, indexed
            by cell barcode. A state of -1 is treated as missing.
        neighborhood_graph: Graph connecting cells to their spatial
            neighbors. Required despite the ``None`` default.
        number_of_hops: Number of hops to make during imputation.
        max_neighbor_distance: Maximum distance to a neighbor for it to be
            used for imputation. Only applied when ``coordinates`` is given.
        coordinates: Optional spatial coordinates, indexed by cell barcode,
            used to compute neighbor distances.

    Returns:
        The winning state, its fraction of all votes, and its absolute number
        of votes. Ties are broken in favor of the smallest state value.
        Returns ``(-1, 0, 0)`` when no neighbor casts a vote.

    Raises:
        ValueError: If ``neighborhood_graph`` is not provided.
        KeyError: If ``coordinates`` is given but lacks ``cell``, or lacks a
            neighbor that is present in ``character_matrix``.
    """
    if neighborhood_graph is None:
        raise ValueError("A neighborhood graph is required for imputation.")

    # The query cell's position does not change across neighbors.
    cell_position = None if coordinates is None else coordinates.loc[cell].values

    votes = []
    for _, node in nx.bfs_edges(
        neighborhood_graph, cell, depth_limit=number_of_hops
    ):
        if node not in character_matrix.index:
            continue

        distance = 0
        if cell_position is not None:
            distance = np.sqrt(
                np.sum((cell_position - coordinates.loc[node].values) ** 2)
            )

        state = character_matrix.loc[node].iloc[character]
        if distance <= max_neighbor_distance and state != -1:
            # Ambiguous states are stored as tuples; every member gets a vote.
            if isinstance(state, tuple):
                votes.extend(state)
            else:
                votes.append(state)

    if not votes:
        return -1, 0, 0

    # np.unique returns sorted values, so argmax picks the smallest state
    # among those tied for the most votes.
    values, counts = np.unique(votes, return_counts=True)
    best = np.argmax(counts)
    return values[best], counts[best] / np.sum(counts), counts[best]