#!/usr/bin/env python
from itertools import product
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.colors import ListedColormap
from ktplotspy.utils.settings import DEFAULT_CLASS_COL, DEFAULT_COL_START, DEFAULT_CPDB_SEP, DEFAULT_V5_COL_START
from ktplotspy.utils.support import diverging_palette
def _count_significant_interactions(
    pvals: pd.DataFrame,
    cell_types: list[str] | None,
    degs_analysis: bool,
    alpha: float,
    default_sep: str,
) -> pd.DataFrame:
    """Count the significant interactions for every ordered pair of cell types.

    Returns a dataframe with one row per cell-type pair and the columns
    `SOURCE`, `TARGET` and `COUNT`, sorted by the joined pair label.
    """
    col_start = (
        DEFAULT_V5_COL_START if pvals.columns[DEFAULT_CLASS_COL] == "classification" else DEFAULT_COL_START
    )  # in v5, there are 12 columns before the values
    # Transpose so that rows are cell-type pairs and columns are interacting pairs.
    all_int = pvals.iloc[:, col_start:].T
    all_int.columns = pvals.interacting_pair
    if cell_types is None:
        cell_types = sorted({ct for col in pvals.columns[col_start:] for ct in col.split(default_sep)})
    cell_types_comb = {default_sep.join(pair) for pair in product(cell_types, cell_types)}
    all_int = all_int.loc[all_int.index.isin(cell_types_comb)]
    # Requested pairs that are absent from the input still need a row, otherwise
    # the SOURCE x TARGET matrix built from these counts would not be square.
    empty_celltypes = sorted(cell_types_comb.difference(all_int.index))
    if empty_celltypes:
        # Filler values can never pass the significance test below:
        # 0 is not == 1 in DEG mode, and 1 is not < alpha otherwise.
        fill_value = 0.0 if degs_analysis else 1.0
        filler = pd.DataFrame(
            np.full((len(empty_celltypes), all_int.shape[1]), fill_value),
            index=empty_celltypes,
            columns=all_int.columns,
        )
        all_int = pd.concat([all_int, filler], axis=0)
    significant = (all_int == 1) if degs_analysis else (all_int < alpha)
    # Row-wise sum counts the significant interactions per pair; grouping by label
    # merges any repeated labels and sorts the result.
    counts = significant.sum(axis=1).groupby(level=0).sum()
    count_final = pd.DataFrame([pair.split(default_sep) for pair in counts.index], columns=["SOURCE", "TARGET"])
    count_final["COUNT"] = counts.to_numpy()
    return count_final


def _symmetrise_counts(count_mat: pd.DataFrame) -> pd.DataFrame:
    """Add both directions of every cell-type pair while keeping the diagonal unchanged."""
    values = count_mat.to_numpy()
    symmetric = values + values.T
    # A cell type paired with itself has a single direction, so it must not be doubled.
    np.fill_diagonal(symmetric, np.diag(values))
    return pd.DataFrame(symmetric, index=count_mat.index, columns=count_mat.columns)


def plot_cpdb_heatmap(
    pvals: pd.DataFrame,
    cell_types: list[str] | None = None,
    degs_analysis: bool = False,
    log1p_transform: bool = False,
    alpha: float = 0.05,
    linewidths: float = 0.5,
    row_cluster: bool = True,
    col_cluster: bool = True,
    low_col: str = "#104e8b",
    mid_col: str = "#ffdab9",
    high_col: str = "#8b0a50",
    cmap: str | ListedColormap | None = None,
    title: str = "",
    return_tables: bool = False,
    symmetrical: bool = True,
    default_sep: str = DEFAULT_CPDB_SEP,
    **kwargs,
) -> sns.matrix.ClusterGrid | dict:
    """Plot cellphonedb results as total counts of interactions.

    Parameters
    ----------
    pvals : pd.DataFrame
        Dataframe corresponding to `pvalues.txt` or `relevant_interactions.txt` from cellphonedb.
    cell_types : list[str] | None, optional
        List of cell types to include in the heatmap. If `None`, all cell types found in the
        column names of `pvals` are included.
    degs_analysis : bool, optional
        Whether `cellphonedb` was run in `deg_analysis` mode. If `True`, a value equal to 1 is
        counted as significant; otherwise a value below `alpha` is.
    log1p_transform : bool, optional
        Whether to log1p transform the output.
    alpha : float, optional
        P value threshold value for significance. Not used when `degs_analysis` is `True`.
    linewidths : float, optional
        Width of lines between each cell.
    row_cluster : bool, optional
        Whether to cluster rows.
    col_cluster : bool, optional
        Whether to cluster columns.
    low_col : str, optional
        Low colour in gradient.
    mid_col : str, optional
        Middle colour in gradient.
    high_col : str, optional
        High colour in gradient.
    cmap : str | ListedColormap | None, optional
        Built-in matplotlib colormap names or custom `ListedColormap`. If `None`, a diverging
        palette is built from `low_col`, `mid_col` and `high_col`.
    title : str, optional
        Plot title.
    return_tables : bool, optional
        Whether to return the dataframes storing the interaction network.
    symmetrical : bool, optional
        Whether to return the sum of interactions as symmetrical heatmap.
    default_sep : str, optional
        The default separator used when CellPhoneDB was run.
    **kwargs
        Passed to seaborn.clustermap.

    Returns
    -------
    sns.matrix.ClusterGrid | dict
        Either heatmap of cellphonedb interactions or, if `return_tables` is `True`, a dictionary
        with the keys `count_network`, `interaction_count` and `interaction_edges`.
        `None` is returned when no cell-type pair has any significant interaction.
    """
    count_final = _count_significant_interactions(
        pvals=pvals,
        cell_types=cell_types,
        degs_analysis=degs_analysis,
        alpha=alpha,
        default_sep=default_sep,
    )
    if not (count_final["COUNT"] > 0).any():
        # Nothing is significant, so there is no network to plot or tabulate.
        return None

    count_mat = count_final.pivot_table(index="SOURCE", columns="TARGET", values="COUNT")
    count_mat.columns.name, count_mat.index.name = None, None
    count_mat = count_mat.fillna(0)
    if symmetrical:
        count_mat = _symmetrise_counts(count_mat)
    if log1p_transform:
        count_mat = np.log1p(count_mat)

    if return_tables:
        if symmetrical:
            all_sum = count_mat.sum(axis=0).to_frame("total_interactions")
        else:
            count_mat = count_mat.T  # so that the table output is the same layout as the plot
            row_sums = count_mat.sum(axis=0).to_frame("total_interactions_row")
            col_sums = count_mat.sum(axis=1).to_frame("total_interactions_col")
            all_sum = pd.concat([row_sums, col_sums], axis=1)
        return {"count_network": count_mat, "interaction_count": all_sum, "interaction_edges": count_final}

    colmap = diverging_palette(low=low_col, medium=mid_col, high=high_col) if cmap is None else cmap
    g = sns.clustermap(
        count_mat,
        row_cluster=row_cluster,
        col_cluster=col_cluster,
        linewidths=linewidths,
        tree_kws={"linewidths": 0},
        cmap=colmap,
        **kwargs,
    )
    if title:
        g.fig.suptitle(title)
    return g