import warnings
from collections import deque
from functools import partial
from hashlib import sha256
from typing import Callable, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
from matplotlib.colors import (
hsv_to_rgb,
rgb_to_hsv,
to_rgb,
to_rgba,
to_rgba_array,
)
from . import palettes
from ..data import CassiopeiaTree
from ..mixins import PlottingError, PlottingWarning, try_import
# Optional dependencies that are required for 3D plotting.
# NOTE(review): `Tree3D.__init__` compares these names against None, so
# `try_import` presumably returns None when a module is unavailable —
# confirm against `..mixins.try_import`.
cv2 = try_import("cv2")  # OpenCV: resizing, Gaussian blur, drawing labels
pv = try_import("pyvista")  # 3D rendering
measure = try_import("skimage.measure")  # `regionprops` of the labels array
vtk = try_import("vtk")  # only referenced in string type annotations here
def interpolate_branch(
    parent: Tuple[float, float, float], child: Tuple[float, float, float]
) -> np.ndarray:
    """Build the three-point polyline for a right-angled branch.

    The path first moves from the parent to the child's x/y position while
    keeping the parent's z value, then moves along z to the child. The two
    segments therefore meet at a 90 degree angle.

    Args:
        parent: (x, y, z) coordinates of the parent node.
        child: (x, y, z) coordinates of the child node.

    Returns:
        A 3 x 3 Numpy array whose rows are the parent, the elbow point and
        the child, in that order.
    """
    px, py, pz = parent
    cx, cy, cz = child
    elbow = (cx, cy, pz)
    return np.array([(px, py, pz), elbow, (cx, cy, cz)])
def polyline_from_points(points: np.ndarray) -> "pv.PolyData":
    """Create a Pyvista polyline that connects `points` in order.

    Args:
        points: Array of point coordinates, one point per row.

    Returns:
        A `pyvista.PolyData` holding `points` and a single line cell that
        visits every point in order.
    """
    n_points = len(points)
    # Line connectivity in VTK's flat layout: [count, id_0, id_1, ...].
    connectivity = np.concatenate(
        (
            np.array([n_points], dtype=np.int_),
            np.arange(n_points, dtype=np.int_),
        )
    )
    poly = pv.PolyData()
    poly.points = points
    poly.lines = connectivity
    return poly
def average_mixing(*c):
    """Mix colors by averaging their red, green and blue channels.

    Args:
        *c: Any number of Matplotlib color specifications.

    Returns:
        Tuple of the channel-wise mean (r, g, b); alpha is ignored.
    """
    rgb = to_rgba_array(c)[:, :3]
    return tuple(np.mean(rgb, axis=0))
def highlight(c) -> Tuple[float, float, float]:
    """Brighten a color by raising its HSV value channel by 0.5 (max 1).

    Args:
        c: RGB color, e.g. as produced by Matplotlib's `to_rgb`.

    Returns:
        The brightened RGB color as returned by `hsv_to_rgb`.
    """
    hsv_color = rgb_to_hsv(c)
    brighter = hsv_color[2] + 0.5
    hsv_color[2] = min(brighter, 1)
    return hsv_to_rgb(hsv_color)
def lowlight(c) -> Tuple[float, float, float]:
    """Dim a color by lowering its HSV value channel by 0.5 (min 0).

    Args:
        c: RGB color, e.g. as produced by Matplotlib's `to_rgb`.

    Returns:
        The dimmed RGB color as returned by `hsv_to_rgb`.
    """
    hsv_color = rgb_to_hsv(c)
    darker = hsv_color[2] - 0.5
    hsv_color[2] = max(darker, 0)
    return hsv_to_rgb(hsv_color)
[docs]
def labels_from_coordinates(
    tree: CassiopeiaTree,
    attribute_key: str = "spatial",
    shape: tuple = (1000, 1000),
) -> np.ndarray:
    """Create a (synthetic) labels Numpy array for use with 3D plotting.

    Each cell is drawn as a filled circle. The radius shrinks with the square
    root of the number of cells and is clamped to [1, 100] pixels. Cell
    coordinates are min-max normalized to [0.05, 0.95] per axis and scaled by
    `shape`: the first coordinate column maps to rows and the second to
    columns. An axis along which all cells share one value (e.g. a single
    cell) is placed at the middle instead of dividing by zero.

    Where circles overlap, cells drawn later overwrite earlier ones. Every
    cell's center pixel is nevertheless assigned that cell's own label, so
    each label keeps at least one pixel unless two centers coincide.

    Side effect: `tree.cell_meta` is replaced by a copy that has an extra
    column `{attribute_key}_label` holding each cell's integer label.

    Args:
        tree: CassiopeiaTree to generate labels for.
        attribute_key: Attribute name in the `cell_meta` of the tree
            containing coordinates. All consecutive columns of the form
            `{attribute_key}_i`, where `i` is an integer `0...`, are used.
        shape: Shape of the array to generate. This should be a tuple of two
            integers, representing the height and width of the array.

    Returns:
        An integer array of shape `shape`. 0 is background, and `1..n_cells`
        label cells in cell-meta row order.

    Raises:
        ValueError: If the tree has no cell meta, the coordinate columns are
            missing, there are not exactly two coordinate columns, `shape` is
            invalid, or the cell meta is empty.
    """
    # check that tree contains cell meta
    if tree.cell_meta is None:
        raise ValueError("CassiopeiaTree must contain cell meta.")
    if f"{attribute_key}_0" not in tree.cell_meta.columns:
        raise ValueError(
            f"Attribute key {attribute_key} not found in cell meta."
        )

    # Collect the consecutive `{attribute_key}_0`, `{attribute_key}_1`, ...
    meta = tree.cell_meta.copy()
    columns = []
    while f"{attribute_key}_{len(columns)}" in meta.columns:
        columns.append(f"{attribute_key}_{len(columns)}")
    if len(columns) != 2:
        raise ValueError(
            f"Only 2-dimensional data is supported, but found {len(columns)} "
            "dimensions."
        )

    # check that shape is length 2
    if len(shape) != 2:
        raise ValueError(
            f"Shape must be a tuple of length 2, "
            f"but found {len(shape)}."
        )
    # check that shape is integer (Python or numpy integers)
    if not all(isinstance(dim, (int, np.integer)) for dim in shape):
        raise ValueError(
            f"Shape must be an integer tuple, "
            f"but found {shape[0]} x {shape[1]}."
        )
    height, width = int(shape[0]), int(shape[1])
    # check that shape is at least 10 x 10
    if height < 10 or width < 10:
        raise ValueError(
            "Shape must be at least 10 x 10, "
            f"but found {height} x {width}. "
            "When shape is too small, cells will overlap."
        )
    if meta.shape[0] == 0:
        raise ValueError("Cell meta must contain at least one cell.")

    # Normalize coordinates to [.05, .95]
    coordinates = meta[columns].values
    min_val = coordinates.min(axis=0)
    span = coordinates.max(axis=0) - min_val
    flat = span == 0
    safe_span = np.where(flat, 1, span)
    normalized_coordinates = 0.05 + (coordinates - min_val) / safe_span * 0.9
    normalized_coordinates[:, flat] = 0.5

    # Compute scale (circle radius) based on number of points
    scale = int(min(height, width) / (np.sqrt(coordinates.shape[0]) * 3))
    scale = min(max(scale, 1), 100)

    # (row, col) pixel of each cell. Normalized values lie in [0.05, 0.95],
    # so the truncated pixels are always inside the array.
    pixel_centers = (
        normalized_coordinates * np.array([height, width])
    ).astype(int)

    # Generate labels array by drawing circles
    labels = np.zeros((height, width), dtype=int)
    canvas = np.zeros((height, width), dtype=np.uint8)
    leaf_to_label = {}
    for label, (leaf, (row, col)) in enumerate(
        zip(meta.index, pixel_centers), start=1
    ):
        canvas.fill(0)
        # OpenCV takes the center as (x, y) == (column, row).
        cv2.ellipse(
            canvas,
            (int(col), int(row)),
            (scale, scale),
            0,
            0,
            360,
            1,
            -1,
        )
        disk = canvas.astype(bool)
        disk[row, col] = True
        labels[disk] = label
        leaf_to_label[leaf] = label

    # Later circles may cover earlier ones; re-stamp every center so that
    # each cell keeps at least one pixel carrying its own label.
    for leaf, (row, col) in zip(meta.index, pixel_centers):
        labels[row, col] = leaf_to_label[leaf]

    # Add label column
    meta[f"{attribute_key}_label"] = meta.index.map(leaf_to_label)
    tree.cell_meta = meta
    return labels
[docs]
class Tree3D:
    """Create a 3D projection of a tree onto a 2D surface.

    This class provides various wrappers around Pyvista, which is used for 3D
    rendering. Each leaf is placed at the centroid of its cell in a 2D
    integer labels array, and internal nodes are stacked above their
    children. The displayed tree can be cut at a chosen time or height and
    re-rooted interactively.

    Example:
        # When labels aren't available, they can be synthetically created by
        # using `labels_from_coordinates`, as so. The tree must contain spatial
        # coordinates in the cell meta.
        labels = cas.pl.labels_from_coordinates(tree)
        tree3d = cas.pl.Tree3D(tree, labels)
        # `img` is a np.array whose first two dimensions equal labels.shape;
        # the first argument is a name used to toggle the image later.
        tree3d.add_image("image", img)
        tree3d.plot()

    Args:
        tree: The Cassiopeia tree to plot. If its cell meta has no
            `{attribute_key}_label` column, leaf names are cast with `int()`
            and used as labels, so they must then be string-casted integers.
        labels: Optional numpy array containing integer cell labels on a 2D
            surface. It must contain the label of every leaf in `tree`. If
            None, the labels are generated synthetically using
            `labels_from_coordinates`, which also writes the
            `{attribute_key}_label` column.
        offset: Offset to give to tree and subclone shading. This option exists
            because in some cases, if the tree and subclone shading are placed
            at the same height as the image, weird clipping happens.
        downscale: Downscale all images by this amount. This option is
            recommended for more responsive visualization.
        cmap: Colormap to use. Defaults to the Godsnot color palette, as defined
            at https://github.com/scverse/scanpy/blob/master/scanpy/plotting/palettes.py
        attribute_key: Attribute key to use as the integer labels for each leaf
            node. A column with the name `{attribute_key}_label` will be looked
            up in the cell meta.
    """
def __init__(
    self,
    tree: CassiopeiaTree,
    labels: Optional[np.ndarray] = None,
    offset: float = 1.0,
    downscale: float = 1.0,
    cmap: Optional[List[str]] = None,
    attribute_key: str = "spatial",
):
    """Prepare label mappings, node layout and the Pyvista plotter.

    The arguments are described in the class docstring. `tree` is copied,
    and synthetic labels (when `labels` is None) are written to the copy's
    cell meta.

    Raises:
        PlottingError: If any of the optional dependencies `cv2`, `pyvista`
            or `skimage.measure` is missing, or if `labels` does not contain
            every leaf of the tree.
    """
    # Check optional dependencies. `skimage.measure` is required as well,
    # because `init_label_mapping` calls `measure.regionprops`.
    if any(module is None for module in (cv2, pv, measure)):
        raise PlottingError(
            "Some required modules were not found. Make sure you installed "
            "Cassiopeia with the `spatial` extras, or run `pip install "
            "cassiopeia-lineage[spatial]`."
        )
    # Caches. These come first because initialization may cache stuff.
    self.cut_tree_cache = {}
    self.place_nodes_cache = {}
    self.tree = tree.copy()
    self.offset = offset
    self.downscale = downscale
    if labels is None:
        self.labels = labels_from_coordinates(self.tree, attribute_key)
    else:
        self.labels = labels
    # Largest (downscaled) image side; sizes spheres, tubes, blur and image
    # spacing.
    self.scale = max(*self.labels.shape) * downscale
    self.init_label_mapping(f"{attribute_key}_label")
    self.plotter = pv.Plotter()
    # Currently displayed actors, keyed by actor name.
    self.node_actors = {}
    self.branch_actors = {}
    self.subclone_actor = None
    self.text_actors = {}
    self.node_colors = {}
    # Only the resized shape is kept: every grid shares these dimensions.
    resized_labels = cv2.resize(
        self.labels,
        None,
        fx=downscale,
        fy=downscale,
        interpolation=cv2.INTER_NEAREST,
    )
    self.images = {}
    self.image_dims = resized_labels.shape + (1,)
    self.image_actors = {}
    self.init_nodes()
    self.init_subclones()
    # Initial visualization values
    self.root = tree.root
    self.time = self.tree.get_time(self.root)
    self.shown_images = []
    self.subclone_sigma = self.scale / 40
    self.leaves = sorted(self.cut_tree(self.root, self.time))
    self.show_nodes = False
    self.show_text = False
    self.selected_node = None
    # Widget values
    self.checkbox_size = 30
    self.checkbox_border_size = 2
    # Colormap to use. Colors are cycled through when more are needed.
    self.cmap = cmap or palettes.godsnot_102
[docs]
def init_label_mapping(self, key):
    """Build the leaf <-> label mappings and the per-label region properties.

    Labels are read from the `key` column of the tree's cell meta. If the
    cell meta or that column is missing, a PlottingWarning is emitted and
    each leaf name cast to `int` is used as its label instead.

    Sets `self.leaf_to_label`, `self.label_to_leaf` and `self.regionprops`
    (label -> `skimage.measure.regionprops` entry of `self.labels`).

    Args:
        key: Name of the cell meta column holding the integer labels.

    Raises:
        PlottingError: If the label of some leaf is absent from `self.labels`.
    """
    self.leaf_to_label = {}
    cell_meta = self.tree.cell_meta
    if cell_meta is not None and key in cell_meta:
        self.leaf_to_label = dict(cell_meta[key])
    else:
        warnings.warn(
            f"Failed to locate {key} column in cell meta. "
            "Leaf names casted as integers will be used as the labels.",
            PlottingWarning,
        )
        self.leaf_to_label = {leaf: int(leaf) for leaf in self.tree.leaves}
    self.label_to_leaf = {
        label: leaf for leaf, label in self.leaf_to_label.items()
    }

    # Every leaf must be visible somewhere in the label array.
    expected_labels = [self.leaf_to_label[leaf] for leaf in self.tree.leaves]
    if not np.isin(expected_labels, self.labels).all():
        raise PlottingError(
            "Label array must contain all leaves in the tree."
        )

    self.regionprops = {}
    for region in measure.regionprops(self.labels):
        self.regionprops[region.label] = region
[docs]
def init_nodes(self):
    """Compute a 3D position for every node of the tree.

    Each leaf is placed at the centroid of its region in `self.labels`, with
    z = 0. Each internal node is placed at the area-weighted mean of its
    children's positions, i.e. at the centroid of all cells below it.

    Sets `self.nodes`, `self.times`, `self.node_index` (node -> row) and
    `self.node_coordinates` (N x 3, rows ordered by `self.node_index`).

    Raises:
        PlottingError: If some node could not be assigned coordinates.
    """
    self.nodes = self.tree.nodes
    self.times = self.tree.get_times()
    self.node_index = {node: i for i, node in enumerate(self.nodes)}
    self.node_coordinates = np.full((len(self.nodes), 3), np.nan)
    areas = np.zeros(len(self.nodes), dtype=int)

    leaves = set(self.tree.leaves)
    for label, props in self.regionprops.items():
        # The label array may contain labels of cells that are not leaves of
        # this tree (e.g. unrelated segmented cells); skip those.
        leaf = self.label_to_leaf.get(label)
        if leaf not in leaves:
            continue
        i = self.node_index[leaf]
        areas[i] = props.area
        self.node_coordinates[i] = props.centroid + (0,)

    # Bottom-up pass: a node is computed once all of its children are.
    queue = deque({self.tree.parent(leaf) for leaf in leaves})
    processed = set(leaves)
    while queue:
        node = queue.popleft()
        if node in processed:
            # Reached again through another child; already computed.
            continue
        children = self.tree.children(node)
        if not all(child in processed for child in children):
            queue.append(node)
            continue
        i = self.node_index[node]
        children_indices = [self.node_index[child] for child in children]
        child_coordinates = self.node_coordinates[children_indices]
        child_areas = areas[children_indices]
        self.node_coordinates[i] = (
            child_coordinates * child_areas.reshape(-1, 1)
        ).sum(axis=0) / child_areas.sum()
        areas[i] = child_areas.sum()
        processed.add(node)
        if not self.tree.is_root(node):
            queue.append(self.tree.parent(node))

    if np.isnan(self.node_coordinates).any():
        raise PlottingError(
            "Failed to compute coordinates for every node in the tree."
        )
[docs]
def init_subclones(self):
    """Initialize the grid used to render subclone shading.

    The grid is raised to height `self.offset`. Render-space leaf nodes also
    sit at that height (`render_node` adds `self.offset` to z), so the tree
    and the shading share the offset that the class docstring documents for
    both. With the default offset of 1.0 this matches the previous fixed
    height of 1.
    """
    self.subclones = self.create_grid()
    self.subclones.origin = (0, 0, self.offset)
[docs]
def get_mask(self, node: str) -> np.ndarray:
    """Return a boolean mask of the cells belonging to the subclade of `node`.

    Only each region's bounding box is scanned, which avoids comparing the
    whole labels array once per leaf.

    Args:
        node: Node name; a leaf selects its own cell, an internal node selects
            every leaf below it.

    Returns:
        Boolean array shaped like `self.labels`, True on pixels of cells in
        the subclade marked by `node`.
    """
    members = (
        [node] if self.tree.is_leaf(node) else self.tree.leaves_in_subtree(node)
    )
    regions = [self.regionprops[self.leaf_to_label[leaf]] for leaf in members]
    mask = np.zeros(self.labels.shape, dtype=bool)
    for region in regions:
        top, left, bottom, right = region.bbox
        window = (slice(top, bottom), slice(left, right))
        mask[window] |= self.labels[window] == region.label
    return mask
[docs]
def create_grid(self) -> "pv.ImageData":
    """Return a new, empty Pyvista ImageData with the shared image dimensions.

    The dimensions are `self.image_dims`: the resized label array's shape
    followed by a depth of 1.
    """
    dimensions = self.image_dims
    return pv.ImageData(dimensions=dimensions)
[docs]
def add_image(self, key: str, img: np.ndarray):
    """Add an image so that it may be displayed with the tree.

    The image is resized by `self.downscale`, multiplied by 255 and cast to
    uint8. Float images with values in [0, 1] are therefore expected. The
    first image ever added is shown by default.

    Args:
        key: Identifiable name of the image. Reusing a key replaces the
            stored image.
        img: Image as a Numpy array with 2 dimensions, or 3 dimensions where
            the last one holds channels. Its first two dimensions must equal
            the shape of the labels array.

    Raises:
        PlottingError: If `img` does not have 2 or 3 dimensions, or its first
            two dimensions differ from the labels array.
    """
    if img.ndim not in (2, 3):
        raise PlottingError(
            "Only 2- and 3-dimensional images are supported."
        )
    if img.shape[:2] != self.labels.shape:
        raise PlottingError(
            "The first two dimensions of the image must have shape "
            f"{self.labels.shape}."
        )
    img = (
        cv2.resize(img, None, fx=self.downscale, fy=self.downscale) * 255
    ).astype(np.uint8)
    # Immediately convert to mesh. Fortran order is used for every grid in
    # this class so that point data lines up across images and shading.
    grid = self.create_grid()
    grid.point_data["values"] = img.reshape(
        np.prod(self.image_dims), -1, order="F"
    )
    self.images[key] = grid
    # Always show the first one
    if len(self.images) == 1:
        self.shown_images = [key]
[docs]
def cut_tree(self, root: str, time: float) -> List[str]:
    """Cut the subtree below `root` at `time` past the root's own time.

    A node belongs to the cut if its incoming edge spans the cut time, or if
    it is a leaf reached at or before the cut time. Results are memoized per
    (root, time) in `self.cut_tree_cache`; the cached list itself is
    returned.

    Args:
        root: Root to calculate the time delta from.
        time: Time from root to cut the tree at.

    Returns:
        A list of nodes at the bottom of the cut tree, in breadth-first
        order. These nodes may be internal nodes of the full tree.
    """
    cache_key = (root, time)
    if cache_key in self.cut_tree_cache:
        return self.cut_tree_cache[cache_key]

    base_time = self.tree.get_time(root)
    cut = []
    for parent, child in self.tree.breadth_first_traverse_edges(root):
        parent_time = self.tree.get_time(parent) - base_time
        child_time = self.tree.get_time(child) - base_time
        if parent_time <= time < child_time or (
            self.tree.is_leaf(child) and child_time <= time
        ):
            cut.append(child)

    self.cut_tree_cache[cache_key] = cut
    return cut
[docs]
def place_nodes(self, root: str, leaves: List[str]) -> np.ndarray:
    """Place a subset of nodes in xyz coordinates.

    Only nodes on the paths between `root` and each node in `leaves` are
    placed. Each `leaves` node keeps its precomputed position from
    `self.node_coordinates`. Each ancestor takes the mean x/y of its children
    and a z one less than its smallest child z. z decreases toward the root;
    `render_node` flips the sign when drawing. Results are memoized per
    (root, set of leaves) in `self.place_nodes_cache`.

    Args:
        root: Root node of the tree. Does not have to be the actual root of
            the tree.
        leaves: List of nodes at the bottom of the tree. May be internal
            nodes.

    Returns:
        A Numpy array of N x 3, where N is the total number of nodes in the
        tree. Nodes that were not placed will have `np.nan` as its three
        coordinates. Each row is indexed to each node according to
        `self.node_index`.

    Raises:
        PlottingError: If `root` is not an ancestor of every node in
            `leaves`.
    """
    key = (root, frozenset(leaves))
    # Only validated inputs are ever cached, so look up the cache first.
    if key in self.place_nodes_cache:
        return self.place_nodes_cache[key]
    if not all(
        root in self.tree.get_all_ancestors(leaf) for leaf in leaves
    ):
        raise PlottingError(
            f"The desired root {root} is not an ancestor of all the leaves "
            f"{leaves}."
        )

    coordinates = np.full(self.node_coordinates.shape, np.nan)
    leaf_indices = [self.node_index[leaf] for leaf in leaves]
    coordinates[leaf_indices] = self.node_coordinates[leaf_indices]

    # Bottom-up pass: a node is placed once all of its children are, and
    # only then is its parent enqueued.
    queue = deque({self.tree.parent(leaf) for leaf in leaves})
    processed = set(leaves)
    while queue:
        node = queue.popleft()
        if node in processed:
            # Reached again through another child; already placed.
            continue
        children = self.tree.children(node)
        if not all(child in processed for child in children):
            queue.append(node)
            continue
        i = self.node_index[node]
        children_indices = [self.node_index[child] for child in children]
        child_coordinates = coordinates[children_indices]
        coordinates[i] = child_coordinates.mean(axis=0)
        coordinates[i, 2] = child_coordinates[:, 2].min() - 1
        processed.add(node)
        if node != root:
            queue.append(self.tree.parent(node))

    self.place_nodes_cache[key] = coordinates
    return coordinates
[docs]
def place_branches(
    self, root: str, coordinates: np.ndarray
) -> Dict[Tuple[str, str], np.ndarray]:
    """Lay out the edges below `root` whose two end nodes are both placed.

    Args:
        root: Node whose subtree's edges are considered. Does not have to be
            the actual root of the tree.
        coordinates: N x 3 node coordinates as produced by
            `self.place_nodes`; rows containing NaN mark unplaced nodes.

    Returns:
        Dictionary mapping each (parent, child) edge with both ends placed to
        its right-angled polyline from `interpolate_branch`.
    """
    is_placed = ~np.isnan(coordinates).any(axis=1)
    branches = {}
    for parent, child in self.tree.breadth_first_traverse_edges(root):
        p_idx = self.node_index[parent]
        c_idx = self.node_index[child]
        if is_placed[p_idx] and is_placed[c_idx]:
            branches[(parent, child)] = interpolate_branch(
                coordinates[p_idx], coordinates[c_idx]
            )
    return branches
[docs]
def render_node(self, coords: np.ndarray, radius: float) -> "pv.Sphere":
    """Build the Pyvista sphere that represents one node.

    x and y are scaled by `self.downscale`. z is multiplied by
    `-self.scale * 0.1`, so smaller layout z values are drawn higher, and
    then shifted by `self.offset`.

    Args:
        coords: XYZ layout coordinates as a 1-dimensional Numpy array.
        radius: Radius of the sphere.

    Returns:
        Pyvista sphere mesh centered at the transformed coordinates.
    """
    center = coords.copy()
    center[:2] *= self.downscale
    center[2] *= -self.scale * 0.1
    center[2] += self.offset
    return pv.Sphere(center=center, radius=radius)
[docs]
def render_branch(
    self, branch_coords: np.ndarray, radius: float
) -> "pv.Tube":
    """Build the Pyvista tube that represents one branch.

    Applies the same transform as `render_node` to every point: x/y are
    scaled by `self.downscale`, z by `-self.scale * 0.1`, then shifted by
    `self.offset`.

    Args:
        branch_coords: XYZ coordinates as a 2-dimensional Numpy array, one
            point per row.
        radius: Radius of the tube.

    Returns:
        Pyvista tube mesh following the transformed polyline.
    """
    points = branch_coords.copy()
    points[:, :2] *= self.downscale
    points[:, 2] *= -self.scale * 0.1
    points[:, 2] += self.offset
    return polyline_from_points(points).tube(radius=radius)
[docs]
def clear_node_actors(self):
    """Remove every node sphere from the plotter and forget the actors."""
    for name in list(self.node_actors):
        self.plotter.remove_actor(self.node_actors[name])
    self.node_actors = {}
[docs]
def clear_branch_actors(self):
    """Remove every branch tube from the plotter and forget the actors."""
    for name in list(self.branch_actors):
        self.plotter.remove_actor(self.branch_actors[name])
    self.branch_actors = {}
[docs]
def clear_subclone_actor(self):
    """Remove the subclone shading actor from the plotter, if one exists."""
    if self.subclone_actor is None:
        return
    self.plotter.remove_actor(self.subclone_actor)
    self.subclone_actor = None
[docs]
def clear_image_actors(self):
    """Remove every image actor from the plotter and forget the actors."""
    for name in list(self.image_actors):
        self.plotter.remove_actor(self.image_actors[name])
    self.image_actors = {}
[docs]
def set_subclone_sigma(self, sigma: float):
    """Change the Gaussian blur strength of the subclone shading.

    The shading is redrawn only when the value actually changes.

    Args:
        sigma: Blur strength (Gaussian sigma, in downscaled pixels).
    """
    if sigma != self.subclone_sigma:
        self.subclone_sigma = sigma
        self.update_subclones()
[docs]
def set_height(self, height):
    """Cut the tree at a given level below the current root.

    The distinct times of all nodes in the current root's subtree are
    sorted. `height` is a zero-based index into that list, so 0 cuts right
    at the root's time, which shows the root's children. Indices past the end
    select the last time (all leaves), and negative indices are treated as 0.

    Args:
        height: Level index; anything accepted by `int()` (e.g. "3").
    """
    height = int(height)
    times = sorted(
        {
            self.times[node]
            for node in self.tree.depth_first_traverse_nodes(
                source=self.root
            )
        }
    )
    index = min(max(height, 0), len(times) - 1)
    # `set_time` / `cut_tree` measure time relative to the current root, while
    # `self.times` holds absolute times.
    self.set_time(times[index] - self.times[self.root])
[docs]
def set_time(self, time: float):
    """Cut the tree `time` after the current root and redraw if needed.

    Branches and subclone shading are only redrawn when the cut actually
    changes. The current selection is re-applied if its node is still drawn,
    and cleared otherwise.

    Args:
        time: Cutoff time, relative to the current root.
    """
    if time == self.time:
        return
    self.time = time
    new_leaves = sorted(self.cut_tree(self.root, self.time))
    if new_leaves == self.leaves:
        return
    self.leaves = new_leaves
    self.update_branches()
    self.update_subclones()
    still_drawn = f"node:{self.selected_node}" in self.node_actors
    if still_drawn:
        self.select_node(self.selected_node)
        self.clear_picked_mesh()
    else:
        self.reset_selected_node()
    # self.update_texts()
[docs]
def set_node_picking(self, flag: bool):
    """Show or hide node spheres so they can be picked.

    When disabled, the highlight actor of any picked mesh is removed too.

    Args:
        flag: True to enable node selection, False otherwise.
    """
    self.show_nodes = flag
    for actor in self.node_actors.values():
        actor.SetVisibility(flag)
    if flag:
        return
    self.plotter.remove_actor("_mesh_picking_selection")
[docs]
def set_shown_image(self, key: str, show: bool):
    """Toggle whether an image is displayed.

    Images are redrawn only if the set of shown images changes.

    Args:
        key: Image key
        show: True to show, False otherwise.
    """
    currently_shown = key in self.shown_images
    if show and not currently_shown:
        self.shown_images.append(key)
    elif not show and currently_shown:
        self.shown_images.remove(key)
    else:
        return
    self.update_images()
[docs]
def set_root(self, root: str):
    """Re-root the displayed tree at `root`.

    Any selection is always cleared. The cut is recomputed with the current
    time, and branches and shading are redrawn only if the root changes.

    Args:
        root: Desired root node
    """
    self.reset_selected_node()
    # self.update_texts()
    if root != self.root:
        self.root = root
        self.leaves = sorted(self.cut_tree(root, self.time))
        self.update_branches()
        self.update_subclones()
[docs]
def set_selected_node_as_root(self):
    """Make the currently picked node the root of the displayed tree.

    Node spheres carry their row index into `self.nodes` as the "node" field
    data (written in `update_branches`). That index is used directly to
    recover the node name; it is not a leaf name or a label. Does nothing
    when no mesh is picked.
    """
    mesh = self.plotter.picked_mesh
    if mesh is not None:
        node = self.nodes[int(mesh.field_data["node"][0])]
        self.set_root(node)
[docs]
def select_node_mesh(self, mesh):
    """Mesh-picking callback: select the node whose sphere was picked.

    The sphere's "node" field data is the node's row index into
    `self.nodes` (written in `update_branches`).

    Args:
        mesh: The picked Pyvista mesh.
    """
    node = self.nodes[int(mesh.field_data["node"][0])]
    self.select_node(node)
[docs]
def select_node(self, node: str):
    """Helper function to select a node.

    Node spheres, the branch leading into each node, and the subclone
    shading of displayed leaves inside the subtree of `node` are brightened
    with `highlight`. Everything else is dimmed with `lowlight`. Passing None
    restores every element to its base color.

    Args:
        node: Selected node, or None to clear the selection.
    """
    self.selected_node = node
    # self.update_texts()
    reset = node is None
    selected = set(
        self.tree.depth_first_traverse_nodes(
            source=self.root if reset else node
        )
    )

    # Recolor node spheres and the branch leading into each of them.
    for actor_name, actor in self.node_actors.items():
        name = actor_name[len("node:") :]
        base_color = self.node_colors[name]
        if reset:
            color = base_color
        else:
            color = (highlight if name in selected else lowlight)(base_color)
        actor.GetProperty().SetColor(color)
        if name == self.root:
            branch_name = "branch:synthetic_root"
        elif name != self.tree.root:
            branch_name = f"branch:{self.tree.parent(name)}-{name}"
        else:
            continue
        branch_actor = self.branch_actors.get(branch_name)
        if branch_actor is not None:
            branch_actor.GetProperty().SetColor(color)

    # Recolor the subclone shading of the current cut.
    key = sha256(";".join(self.leaves).encode("utf-8")).hexdigest()
    labels_key = f"labels:{key}"
    if labels_key not in self.subclones.point_data:
        # Shading for this cut has not been computed (update_subclones not
        # run yet), so there is nothing to recolor.
        return
    leaf_labels = np.array(self.subclones.point_data[labels_key]).reshape(
        *self.image_dims[:2], order="F"
    )
    colors = [(1, 1, 1)]
    for i, leaf in enumerate(self.leaves):
        base_color = to_rgb(self.cmap[i % len(self.cmap)])
        if reset:
            colors.append(base_color)
        else:
            colors.append(
                (highlight if leaf in selected else lowlight)(base_color)
            )
    # Add an alpha column: transparent background (row 0), opaque leaves.
    colors = np.pad(np.array(colors), ((0, 0), (0, 1)))
    colors[1:, 3] = 1
    leaf_colors = colors[leaf_labels]
    mask = leaf_colors[:, :, 3] > 0
    blur = cv2.GaussianBlur(leaf_colors, (0, 0), sigmaX=self.subclone_sigma)
    alpha = cv2.GaussianBlur(
        mask.astype(float), (0, 0), sigmaX=self.subclone_sigma
    )
    alpha -= alpha.min()
    peak = alpha.max()
    # A constant alpha would divide by zero; use the hard coverage mask.
    alpha = alpha / peak if peak > 0 else mask.astype(float)
    blur[:, :, 3] = alpha
    self.subclones.point_data["values"] = blur.reshape(
        np.prod(self.image_dims), -1, order="F"
    )
    self.subclones.set_active_scalars("values")
    self.subclone_actor = self.plotter.add_mesh(
        self.subclones, rgba=True, name="subclones", pickable=False
    )
[docs]
def clear_picked_mesh(self):
    """Drop the mesh-picking highlight and forget the picked mesh.

    NOTE(review): writes Pyvista's private `_picked_mesh`, which presumably
    backs `Plotter.picked_mesh`; confirm when upgrading Pyvista.
    """
    plotter = self.plotter
    plotter.remove_actor("_mesh_picking_selection")
    plotter._picked_mesh = None
[docs]
def reset_selected_node(self):
    """Clear the picked mesh and restore every element's base color."""
    self.clear_picked_mesh()
    # Selecting None resets all highlighting.
    self.select_node(None)
    # self.update_texts()
[docs]
def update_actors(
    self,
    actors: Dict[str, "vtk.vtkActor"],
    new_actors: Dict[str, "vtk.vtkActor"],
):
    """Replace the contents of `actors` with `new_actors`, in place.

    Actors whose names do not appear in `new_actors` are removed from the
    plotter. Same-named actors are left alone, since they were re-added
    under the same name.

    Args:
        actors: Dictionary of actors that are currently displayed; mutated.
        new_actors: Dictionary of actors to replace the existing actors
            with.
    """
    stale = [actor for name, actor in actors.items() if name not in new_actors]
    for actor in stale:
        self.plotter.remove_actor(actor)
    actors.clear()
    actors.update(new_actors)
[docs]
def update_texts(self):
    """Update displayed text.

    The following text is updated.
    * The current root
    * The time range from the current root to the earliest displayed leaf
    * The selected node (if one is selected)

    Text visibility follows `self.show_text`.
    """
    new_actors = {}
    font_size = self.checkbox_size * (1 / 4)

    def add_line(name, text, y):
        # One black text line at viewport x = 0.2 and the given height.
        actor = self.plotter.add_text(
            text,
            position=(0.2, y),
            viewport=True,
            color="black",
            font_size=font_size,
            name=name,
        )
        actor.SetVisibility(self.show_text)
        new_actors[name] = actor

    add_line("root", f"Root: {self.root}", 0.9)
    add_line(
        "times",
        f"Time range: {self.tree.get_time(self.root)} - "
        f"{min(self.tree.get_time(leaf) for leaf in self.leaves)}",
        0.87,
    )
    if self.selected_node is not None:
        add_line("node", f"Selected: {self.selected_node}", 0.84)
    self.update_actors(self.text_actors, new_actors)
[docs]
def update_images(self):
    """Redraw the shown images, stacking each one further below the last.

    The i-th shown image is placed at z = -i * scale * 0.25.
    """
    new_actors = {}
    for depth, key in enumerate(self.shown_images):
        grid = self.images[key]
        grid.origin = (0, 0, -depth * self.scale * 0.25)
        actor_name = f"image:{key}"
        new_actors[actor_name] = self.plotter.add_mesh(
            grid, rgba=True, name=actor_name, pickable=False
        )
    self.update_actors(self.image_actors, new_actors)
[docs]
def update_branches(self):
    """Redraw node spheres and branch tubes for the current root and cut.

    Leaves of the current cut (`self.leaves`) take colors from `self.cmap`
    in order, cycling. Every other node takes the channel-wise average of its
    children's colors (`average_mixing`). Sphere and tube radii grow with
    log2 of the number of tree leaves below the node. An extra
    "synthetic root" branch is drawn beneath the current root. Node spheres
    are visible only when `self.show_nodes` is set.
    """
    root = self.root
    leaves = self.leaves
    cmap = self.cmap
    coordinates = self.place_nodes(root, leaves)
    branches = self.place_branches(root, coordinates)

    def log_size(node):
        # log2 of the leaf count below `node`, floored at log2(2) = 1.
        return np.log2(max(len(self.tree.leaves_in_subtree(node)), 2))

    # NODES
    self.node_colors = {
        leaf: to_rgb(cmap[i % len(cmap)]) for i, leaf in enumerate(leaves)
    }
    queue = deque(leaves)
    rendered = set()
    new_actors = {}
    while queue:
        node = queue.popleft()
        if node in rendered:
            # Reached again through another child; drawn once is enough.
            continue
        if node not in self.node_colors:
            children = self.tree.children(node)
            if not all(child in self.node_colors for child in children):
                queue.append(node)
                continue
            self.node_colors[node] = average_mixing(
                *[self.node_colors[child] for child in children]
            )
        color = self.node_colors[node]
        name = f"node:{node}"
        i = self.node_index[node]
        sphere = self.render_node(
            coordinates[i], self.scale * 0.00175 * log_size(node)
        )
        # Store the node's row index so picked meshes map back to nodes.
        sphere.add_field_data(np.array([i]), "node")
        actor = self.plotter.add_mesh(
            sphere,
            color=color,
            smooth_shading=True,
            name=name,
            pickable=True,
        )
        actor.SetVisibility(self.show_nodes)
        new_actors[name] = actor
        rendered.add(node)
        if node != root:
            queue.append(self.tree.parent(node))
    self.update_actors(self.node_actors, new_actors)

    # BRANCHES
    new_actors = {}
    for (n1, n2), branch_coords in branches.items():
        branch = self.render_branch(
            branch_coords, self.scale * 0.001 * log_size(n2)
        )
        name = f"branch:{n1}-{n2}"
        actor = self.plotter.add_mesh(
            branch,
            color=self.node_colors[n2],
            smooth_shading=True,
            name=name,
            pickable=False,
        )
        new_actors[name] = actor

    # Synthetic root: a short vertical stub under the current root.
    root_coords = coordinates[self.node_index[root]]
    branch_coords = np.array(
        [(root_coords[0], root_coords[1], root_coords[2] - 1), root_coords]
    )
    branch = self.render_branch(
        branch_coords, self.scale * 0.001 * log_size(root)
    )
    name = "branch:synthetic_root"
    actor = self.plotter.add_mesh(
        branch,
        color=self.node_colors[root],
        smooth_shading=True,
        name=name,
        pickable=False,
    )
    new_actors[name] = actor
    self.update_actors(self.branch_actors, new_actors)
[docs]
def update_subclones(self):
    """Recompute and draw the blurred subclone shading for the current cut.

    Each displayed leaf of the cut (`self.leaves`) gets a color from
    `self.cmap`, cycling. That color is painted over the cells of every tree
    leaf in its subtree. The per-pixel leaf index and the resulting colors
    are stored in `self.subclones.point_data` under keys derived from a
    SHA-256 of the leaf names, so revisiting a cut skips the mask
    computation. The colors are then Gaussian-blurred with
    `self.subclone_sigma`. The blurred coverage mask, rescaled to [0, 1],
    becomes the alpha channel.
    """
    leaves = self.leaves
    cmap = self.cmap
    key = sha256(";".join(leaves).encode("utf-8")).hexdigest()
    labels_key = f"labels:{key}"
    colors_key = f"colors:{key}"
    if colors_key not in self.subclones.point_data:
        # Row 0 is the transparent white background; row i + 1 is leaf i.
        # The reshape keeps the table 2-D even when there are no leaves.
        colors = np.array(
            [to_rgba(cmap[i % len(cmap)]) for i in range(len(leaves))],
            dtype=float,
        ).reshape(-1, 4)
        colors = np.insert(colors, 0, [1, 1, 1, 0], axis=0)
        leaf_labels = np.zeros(self.labels.shape, dtype=int)
        for i, leaf in enumerate(leaves):
            leaf_labels[self.get_mask(leaf)] = i + 1
        leaf_labels = cv2.resize(
            leaf_labels,
            None,
            fx=self.downscale,
            fy=self.downscale,
            interpolation=cv2.INTER_NEAREST,
        )
        self.subclones.point_data[labels_key] = leaf_labels.flatten(
            order="F"
        )
        leaf_colors = colors[leaf_labels]
        self.subclones.point_data[colors_key] = leaf_colors.reshape(
            np.prod(self.image_dims), -1, order="F"
        )
    leaf_colors = np.array(self.subclones.point_data[colors_key]).reshape(
        self.image_dims[0], self.image_dims[1], -1, order="F"
    )
    mask = leaf_colors[:, :, 3] > 0
    blur = cv2.GaussianBlur(leaf_colors, (0, 0), sigmaX=self.subclone_sigma)
    alpha = cv2.GaussianBlur(
        mask.astype(float), (0, 0), sigmaX=self.subclone_sigma
    )
    alpha -= alpha.min()
    peak = alpha.max()
    # A constant alpha (no contrast between covered and uncovered pixels)
    # would divide by zero; fall back to the hard coverage mask.
    alpha = alpha / peak if peak > 0 else mask.astype(float)
    blur[:, :, 3] = alpha
    self.subclones.point_data["values"] = blur.reshape(
        np.prod(self.image_dims), -1, order="F"
    )
    self.subclones.set_active_scalars("values")
    self.subclone_actor = self.plotter.add_mesh(
        self.subclones, rgba=True, name="subclones", pickable=False
    )
[docs]
def add_blur_slider(self):
    """Add a "Blur" slider (1 to scale / 20) bound to `set_subclone_sigma`."""
    max_sigma = self.scale / 20
    self.plotter.add_slider_widget(
        self.set_subclone_sigma,
        (1, max_sigma),
        1,
        title="Blur",
        color="black",
        pointa=(0.7, 0.9),
        pointb=(0.9, 0.9),
    )
[docs]
def add_height_slider(self):
    """Add a text slider labelled "1" to "9" that controls the tree height.

    Slider label "k" calls `set_height(k - 1)`, the same mapping used by
    the number keys registered in `add_height_key_events`. The third
    positional argument (2) is passed through unchanged as the slider's
    initial value.
    """
    self.plotter.add_text_slider_widget(
        lambda value: self.set_height(int(value) - 1),
        [str(level) for level in range(1, 10)],
        2,
        color="black",
        pointa=(0.45, 0.9),
        pointb=(0.65, 0.9),
    )
[docs]
def add_time_slider(self):
    """Add a vertical "Time" slider bound to `set_time`.

    The range starts at the current root's time. The upper bound is the
    tree's maximum depth minus that time, capped at 5. The slider starts at
    `self.time`.
    """
    root_time = self.tree.get_time(self.root)
    upper = min(self.tree.get_max_depth_of_tree() - root_time, 5)
    self.plotter.add_slider_widget(
        self.set_time,
        (root_time, upper),
        self.time,
        title="Time",
        color="black",
        pointa=(0.85, 0.6),
        pointb=(0.85, 1.0),
    )
[docs]
def add_height_key_events(self):
    """Bind number keys 1-9 to the tree height.

    Key "k" calls `set_height(k - 1)`.
    """
    for key in "123456789":
        self.plotter.add_key_event(key, partial(self.set_height, int(key) - 1))
[docs]
def add_image_checkboxes(self):
    """Add one "Show <key>" checkbox per stored image.

    Checkboxes are stacked upward from the bottom-left corner, starting one
    row above y = 10. Toggling a checkbox calls `set_shown_image` for its
    image, and it starts checked if the image is currently shown.
    """
    label_x = 10 + self.checkbox_size * 1.1
    for row, key in enumerate(self.images, start=1):
        y = 10.0 + row * self.checkbox_size * 1.1
        self.plotter.add_checkbox_button_widget(
            partial(self.set_shown_image, key),
            value=key in self.shown_images,
            position=(10.0, y),
            size=self.checkbox_size,
            border_size=self.checkbox_border_size,
            color_on="black",
            color_off="lightgrey",
            background_color="grey",
        )
        self.plotter.add_text(
            f"Show {key}",
            position=(label_x, y),
            color="black",
            font_size=self.checkbox_size * (2 / 5),
        )
[docs]
def add_node_picking(self):
    """Enable node selection and its keyboard shortcuts.

    - Picking a node's surface selects it (`select_node_mesh`).
    - "h" clears the selection, "r" re-roots at the tree's real root, and
      "s" re-roots at the picked node.
    - A bottom-left checkbox toggles node visibility and picking
      (`set_node_picking`).
    """
    plotter = self.plotter
    plotter.enable_mesh_picking(
        self.select_node_mesh,
        show=True,
        show_message=False,
        style="surface",
    )
    plotter.add_key_event("h", self.reset_selected_node)
    plotter.add_key_event("r", partial(self.set_root, self.tree.root))
    plotter.add_checkbox_button_widget(
        self.set_node_picking,
        value=self.show_nodes,
        position=(10.0, 10.0),
        size=self.checkbox_size,
        border_size=self.checkbox_border_size,
        color_on="black",
        color_off="lightgrey",
        background_color="grey",
    )
    label_x = 10 + self.checkbox_size * 1.1
    plotter.add_text(
        "Enable node selection",
        position=(label_x, 10.0),
        color="black",
        font_size=self.checkbox_size * (2 / 5),
    )
    plotter.add_key_event("s", self.set_selected_node_as_root)
[docs]
def plot(
    self,
    plot_tree: bool = True,
    add_widgets: bool = True,
    show: bool = True,
):
    """Display 3D render.

    Shown images are always drawn. The scene gets a white background, an
    axes widget, a light kit and anti-aliasing.

    Args:
        plot_tree: Immediately render the tree (subclone shading and
            branches). If False, the initial plot will not have any tree
            rendered.
        add_widgets: Add widgets to scene (via `self.add_widgets()`) and
            turn on text display.
        show: Whether to show the plot immediately.
    """
    self.update_images()
    if add_widgets:
        self.add_widgets()
        self.show_text = True
    if plot_tree:
        self.update_subclones()
        self.update_branches()
        # self.update_texts()

    plotter = self.plotter
    plotter.set_background("white")
    plotter.add_axes(viewport=(0, 0.75, 0.2, 0.95))
    plotter.enable_lightkit()
    plotter.enable_anti_aliasing()
    if show:
        plotter.show()
[docs]
def reset(self):
    """Helper function to reset everything.

    Forgets the branch, subclone and image actors, empties the list of
    shown images, and calls `self.plotter.clear()` (presumably removing all
    actors from the scene — see Pyvista's `Plotter.clear`).
    """
    # NOTE(review): node_actors, text_actors and selected_node are not reset
    # here, so after `plotter.clear()` they may refer to removed actors —
    # confirm whether that is intended.
    self.branch_actors = {}
    self.subclone_actor = None
    self.image_actors = {}
    self.shown_images = []
    self.plotter.clear()