#!/usr/bin/env python
import re
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from collections import defaultdict
from matplotlib.lines import Line2D
from matplotlib.colors import LinearSegmentedColormap
from pycircos import Garc, Gcircle
from typing import Optional, Tuple, Dict, Union
from ktplotspy.utils.settings import DEFAULT_SEP # DEFAULT_PAL
from ktplotspy.utils.support import celltype_fraction, celltype_means, find_complex, flatten, generate_df, present
from ktplotspy.plot import plot_cpdb
def plot_cpdb_chord(
    adata: "AnnData",
    means: pd.DataFrame,
    pvals: pd.DataFrame,
    deconvoluted: pd.DataFrame,
    celltype_key: str,
    face_col_dict: Optional[Dict[str, str]] = None,
    edge_col_dict: Optional[Dict[str, str]] = None,
    edge_cmap: LinearSegmentedColormap = plt.cm.nipy_spectral,
    remove_self: bool = True,
    gap: Union[int, float] = 2,
    scale_lw: Union[int, float] = 10,
    size: Union[int, float] = 50,
    interspace: Union[int, float] = 2,
    raxis_range: Tuple[int, int] = (950, 1000),
    labelposition: Union[int, float] = 80,
    label_visible: bool = True,
    figsize: Tuple[Union[int, float], Union[int, float]] = (8, 8),
    legend_params: Dict = {"loc": "center left", "bbox_to_anchor": (1, 1), "frameon": False},
    layer: Optional[str] = None,
    **kwargs
) -> Gcircle:
    """Plotting cellphonedb results as a chord diagram.

    Parameters
    ----------
    adata : AnnData
        `AnnData` object with the `.obs` storing the `celltype_key`.
        The `.obs_names` must match the first column of the input `meta.txt` used for `cellphonedb`.
    means : pd.DataFrame
        Dataframe corresponding to `means.txt` from cellphonedb.
    pvals : pd.DataFrame
        Dataframe corresponding to `pvalues.txt` or `relevant_interactions.txt` from cellphonedb.
    deconvoluted : pd.DataFrame
        Dataframe corresponding to `deconvoluted.txt` from cellphonedb.
    celltype_key : str
        Column name in `adata.obs` storing the celltype annotations.
        Values in this column should match the second column of the input `meta.txt` used for `cellphonedb`.
    face_col_dict : Optional[Dict[str, str]], optional
        Dictionary of celltype : face colours.
        If not provided, `adata.uns[celltype_key + "_colors"]` is used when present.
        Cell types missing from the dictionary are drawn in grey (`#e7e7e7`).
    edge_col_dict : Optional[Dict[str, str]], optional
        Dictionary of interactions : edge colours. Otherwise, will use edge_cmap option.
    edge_cmap : LinearSegmentedColormap, optional
        A `LinearSegmentedColormap` to generate edge colours, sampled at evenly
        spaced positions in `[0, 1)` over the sorted interaction names.
    remove_self : bool, optional
        Whether to remove self edges.
    gap : Union[int, float], optional
        Relative size of gaps between edges on arc.
    scale_lw : Union[int, float], optional
        Numeric value to scale width of lines.
    size : Union[int, float], optional
        Passed to `Garc` as `size` (width of the arc section). It is also
        subtracted from `raxis_range[0]` to position the chord end points.
    interspace : Union[int, float], optional
        Passed to `Garc` as `interspace` (spacing to the adjacent arc section).
    raxis_range : Tuple[int, int], optional
        Passed to `Garc` as `raxis_range` (radial range of the arcs).
    labelposition : Union[int, float], optional
        Passed to `Garc` as `labelposition` (relative label height from the
        centre of the arc section).
    label_visible : bool, optional
        Passed to `Garc` as `label_visible`; presumably toggles whether the
        arc labels are drawn - confirm against the pycircos documentation.
    figsize : Tuple[Union[int, float], Union[int, float]], optional
        Size of figure.
    legend_params : Dict, optional
        Additional arguments for the figure legend. The default dictionary is
        only read, never modified.
    layer : Optional[str], optional
        Slot in `AnnData.layers` to access. If `None`, uses `.X`.
    **kwargs
        Passed to `plot_cpdb`. `return_table` and `splitby_key` are set by this
        function and must not be supplied.

    Returns
    -------
    Gcircle
        A `Gcircle` object from `pycircos`.
    """
    # The chord diagram needs the result table (not the dot plot) and does not
    # support splitting by an additional key.
    splitby_key, return_table = None, True
    lr_interactions = plot_cpdb(
        adata=adata,
        means=means,
        pvals=pvals,
        celltype_key=celltype_key,
        return_table=return_table,
        splitby_key=splitby_key,
        **kwargs,
    )

    # ---- name wrangling -------------------------------------------------
    subset_clusters = list(set(flatten([x.split("-") for x in lr_interactions.celltype_group])))
    adata_subset = adata[adata.obs[celltype_key].isin(subset_clusters)].copy()
    interactions = means[
        ["id_cp_interaction", "interacting_pair", "gene_a", "gene_b", "partner_a", "partner_b", "receptor_a", "receptor_b"]
    ].copy()
    interactions["use_interaction_name"] = [
        x + DEFAULT_SEP * 3 + y for x, y in zip(interactions.id_cp_interaction, interactions.interacting_pair)
    ]
    interactions["converted"] = [re.sub("_", "-", x) for x in interactions.use_interaction_name]
    lr_interactions["barcode"] = [a + DEFAULT_SEP + b for a, b in zip(lr_interactions.celltype_group, lr_interactions.interaction_group)]
    interactions_subset = interactions[interactions["converted"].isin(list(lr_interactions.interaction_group))].copy()

    # ---- handle complexes gently ------------------------------------------
    # More than two "_"-separated parts means at least one partner name itself
    # contains an underscore (treated here as a complex).
    tm0 = {kx: rx.split("_") for kx, rx in interactions_subset.use_interaction_name.items()}
    if any(len(parts) > 2 for parts in tm0.values()):
        complex_id = [idx for idx, parts in tm0.items() if len(parts) > 2]
        simple_id = [idx for idx, parts in tm0.items() if len(parts) == 2]
        _interactions_subset = interactions_subset.loc[complex_id].copy()
        _interactions_subset_simp = interactions_subset.loc[simple_id].copy()
        complex_idx1 = [i for i, j in _interactions_subset.partner_b.items() if re.search("complex:", j)]
        complex_idx2 = [i for i, j in _interactions_subset.partner_a.items() if re.search("complex:", j)]
        partner_1 = [re.sub("complex:", "", b) for b in _interactions_subset.loc[complex_idx1, "partner_b"]]
        partner_2 = [re.sub("complex:", "", a) for a in _interactions_subset.loc[complex_idx2, "partner_a"]]
        # Strip the complex name (plus the joining underscore) from the pair to
        # recover the other partner. The name is escaped so that characters
        # such as "+" or "(" are matched literally instead of as regex syntax.
        simple_1 = [
            re.sub(re.escape(p) + "_|_" + re.escape(p), "", pair)
            for p, pair in zip(partner_1, _interactions_subset.loc[complex_idx1, "interacting_pair"])
        ]
        simple_2 = [
            re.sub(re.escape(p) + "_|_" + re.escape(p), "", pair)
            for p, pair in zip(partner_2, _interactions_subset.loc[complex_idx2, "interacting_pair"])
        ]
        tmpdf = pd.concat([pd.DataFrame(zip(simple_1, partner_1)), pd.DataFrame(zip(partner_2, simple_2))])
        # Rows are ordered "partner_b is a complex" first, then "partner_a is a
        # complex", so they must be labelled in that same order; labelling with
        # `complex_id` would attach ids to the wrong interaction whenever the
        # two kinds are interleaved in the table.
        tmpdf.index = complex_idx1 + complex_idx2
        tmpdf.columns = ["id_a", "id_b"]
        _interactions_subset = pd.concat([_interactions_subset, tmpdf], axis=1)
        simple_tm0 = pd.DataFrame(
            [rx.split("_") for rx in _interactions_subset_simp.interacting_pair],
            columns=["id_a", "id_b"],
            index=_interactions_subset_simp.index,
        )
        _interactions_subset_simp = pd.concat([_interactions_subset_simp, simple_tm0], axis=1)
        interactions_subset = pd.concat([_interactions_subset_simp, _interactions_subset], axis=0)
    else:
        tm0 = pd.DataFrame(tm0).T
        tm0.columns = ["id_a", "id_b"]
        # drop the "<id_cp_interaction><sep><sep><sep>" prefix from the first partner
        tm0.id_a = [x.split(DEFAULT_SEP * 3)[1] for x in tm0.id_a]
        interactions_subset = pd.concat([interactions_subset, tm0], axis=1)

    # ---- keep only useful genes -------------------------------------------
    geneid = list(set(list(interactions_subset.id_a) + list(interactions_subset.id_b)))
    if not all(g in adata_subset.var.index for g in geneid):
        # fall back to the gene columns when the partner ids are not var names
        geneid = list(set(list(interactions_subset.gene_a) + list(interactions_subset.gene_b)))

    # ---- per-celltype subsets ---------------------------------------------
    adata_subset_tmp = adata_subset[:, adata_subset.var_names.isin(geneid)].copy()
    meta = adata_subset_tmp.obs.copy()
    # Sorted so that column order, edge order and therefore the drawn figure do
    # not depend on set iteration order.
    cells_test = sorted(set(meta[celltype_key]))
    genes_by_celltype, full_by_celltype = {}, {}
    for ct in cells_test:
        genes_by_celltype[ct] = adata_subset_tmp[adata_subset_tmp.obs[celltype_key] == ct].copy()
        full_by_celltype[ct] = adata_subset[adata_subset.obs[celltype_key] == ct].copy()

    # ---- expression and fraction dataframes ---------------------------------
    mean_df = pd.DataFrame(
        {ct: celltype_means(sub, layer) for ct, sub in genes_by_celltype.items()},
        index=adata_subset_tmp.var_names,
    )
    frac_df = pd.DataFrame(
        {ct: celltype_fraction(sub, layer) for ct, sub in genes_by_celltype.items()},
        index=adata_subset_tmp.var_names,
    )
    decon_subset = deconvoluted[deconvoluted.complex_name.isin(find_complex(interactions_subset))].copy()
    # if any interactions are actually complexes, extract them from the deconvoluted dataframe.
    if decon_subset.shape[0] > 0:
        decon_subset_expr = decon_subset.groupby("complex_name").apply(lambda r: r[mean_df.columns].apply(np.mean, axis=0))
        cellfrac = defaultdict(dict)
        zgenes = list(set(decon_subset_expr.index))
        for ct, adat in full_by_celltype.items():
            for zg in zgenes:
                members = adat[:, adat.var_names.isin(zg.split("_"))]
                # honour `layer` here as well, so complex fractions are
                # computed from the same matrix as the per-gene fractions
                values = members.X if layer is None else members.layers[layer]
                cellfrac[ct][zg] = np.mean(values > 0)
        decon_subset_fraction = pd.DataFrame(cellfrac)
        expr_df = pd.concat([mean_df, decon_subset_expr])
        fraction_df = pd.concat([frac_df, decon_subset_fraction])
    else:
        expr_df = mean_df.copy()
        fraction_df = frac_df.copy()

    # ---- edge list ------------------------------------------------------------
    cell_comb = [(c1, c2) for c1 in cells_test for c2 in cells_test if not (remove_self and c1 == c2)]
    cell_type_grid = pd.DataFrame(cell_comb, columns=["source", "target"])

    # ---- final dataframe for plotting -------------------------------------------
    dfx = generate_df(
        interactions_subset=interactions_subset,
        cell_type_grid=cell_type_grid,
        cell_type_means=expr_df,
        cell_type_fractions=fraction_df,
        sep=DEFAULT_SEP,
    )

    # ---- table for pyCircos -----------------------------------------------------
    int_value = {k: r for k, r in zip(lr_interactions.barcode, lr_interactions.y_means) if pd.notnull(r)}
    dfx["interaction_value"] = [int_value.get(y, np.nan) for y in dfx["barcode"]]
    chord_df = dfx[["producer", "receiver", "converted_pair", "interaction_value"]].copy()
    chord_df["interaction_celltype"] = [
        DEFAULT_SEP.join(sorted([a, b, c])) for a, b, c in zip(chord_df.producer, chord_df.receiver, chord_df.converted_pair)
    ]
    celltypes = sorted(set(list(chord_df.producer) + list(chord_df.receiver)))
    celltype_start_dict = {r: k * gap for k, r in enumerate(celltypes)}
    celltype_end_dict = {r: k + gap for k, r in enumerate(celltypes)}
    interaction_keys = sorted(set(chord_df["interaction_celltype"]))
    interaction_start_dict = {r: k * gap for k, r in enumerate(interaction_keys)}
    chord_df["from"] = [celltype_start_dict[x] for x in chord_df.producer]
    chord_df["to"] = [celltype_end_dict[x] for x in chord_df.receiver]
    chord_df["interaction_value"] = [
        j * scale_lw + interaction_start_dict[x] if pd.notnull(j) else np.nan
        for j, x in zip(chord_df.interaction_value, chord_df.interaction_celltype)
    ]
    chord_df["start"] = (chord_df["interaction_value"] + chord_df["from"]).round()
    chord_df["end"] = (chord_df["interaction_value"] + chord_df["to"]).round()

    # ---- colours ------------------------------------------------------------------
    if edge_col_dict is None:
        # Sorted for a reproducible colour assignment; the comprehension also
        # avoids dividing by zero when there are no interactions.
        uni_interactions = sorted(set(chord_df.converted_pair))
        n_pairs = len(uni_interactions)
        edge_col_dict = {pair: edge_cmap(k / n_pairs) for k, pair in enumerate(uni_interactions)}
    circle = Gcircle(figsize=figsize)
    colour_slot = celltype_key + "_colors"
    if face_col_dict is None and colour_slot in adata.uns:
        if adata.obs[celltype_key].dtype.name == "category":
            face_col_dict = dict(zip(adata.obs[celltype_key].cat.categories, adata.uns[colour_slot]))
        else:
            face_col_dict = dict(zip(list(set(adata.obs[celltype_key])), adata.uns[colour_slot]))

    # ---- arcs: one per producer, in order of first appearance -------------------------
    for name in chord_df["producer"].unique():
        # unknown cell types fall back to grey when a colour dict is available
        col = None if face_col_dict is None else face_col_dict.get(name, "#e7e7e7")
        arc = Garc(
            arc_id=name,
            size=size,
            interspace=interspace,
            raxis_range=raxis_range,
            labelposition=labelposition,
            label_visible=label_visible,
            facecolor=col,
        )
        circle.add_garc(arc)
    circle.set_garcs(-180, 180)

    # ---- chords ---------------------------------------------------------------------
    chord_radius = raxis_range[0] - size
    for _, row in chord_df.iterrows():
        if pd.isnull(row["interaction_value"]):
            continue
        lr = row["converted_pair"]
        # clamp both chord end points to a minimum value of 1
        start_size = max(1, row["start"] + row["interaction_value"] / scale_lw)
        end_size = max(1, row["end"] + row["interaction_value"] / scale_lw)
        source = (row["producer"], row["start"] - 1, start_size, chord_radius)
        destination = (row["receiver"], row["end"] - 1, end_size, chord_radius)
        # interactions without an assigned colour are drawn fully transparent
        circle.chord_plot(source, destination, edge_col_dict.get(lr, "#f7f7f700"))

    custom_lines = [Line2D([0], [0], color=val, lw=4) for val in edge_col_dict.values()]
    circle.figure.legend(custom_lines, edge_col_dict.keys(), **legend_params)
    return circle