Source code for ktplotspy.plot.plot_cpdb_heatmap

#!/usr/bin/env python
import numpy as np
import pandas as pd
import seaborn as sns

from itertools import product
from matplotlib.colors import ListedColormap
from typing import Optional, Union, Dict, List

from ktplotspy.utils.support import diverging_palette
from ktplotspy.utils.settings import DEFAULT_V5_COL_START, DEFAULT_COL_START, DEFAULT_CLASS_COL, DEFAULT_CPDB_SEP


def plot_cpdb_heatmap(
    pvals: pd.DataFrame,
    cell_types: Optional[List[str]] = None,
    degs_analysis: bool = False,
    log1p_transform: bool = False,
    alpha: float = 0.05,
    linewidths: float = 0.5,
    row_cluster: bool = True,
    col_cluster: bool = True,
    low_col: str = "#104e8b",
    mid_col: str = "#ffdab9",
    high_col: str = "#8b0a50",
    cmap: Optional[Union[str, ListedColormap]] = None,
    title: str = "",
    return_tables: bool = False,
    symmetrical: bool = True,
    default_sep: str = DEFAULT_CPDB_SEP,
    **kwargs
) -> Optional[Union[sns.matrix.ClusterGrid, Dict]]:
    """Plot cellphonedb results as total counts of interactions.

    Parameters
    ----------
    pvals : pd.DataFrame
        Dataframe corresponding to `pvalues.txt` or `relevant_interactions.txt` from cellphonedb.
        Must contain an `interacting_pair` column; the value columns are named
        `<celltype><default_sep><celltype>`.
    cell_types : Optional[List[str]], optional
        List of cell types to include in the heatmap. If `None`, all cell types are included.
    degs_analysis : bool, optional
        Whether `cellphonedb` was run in `deg_analysis` mode. If `True`, a value equal to 1
        is counted as significant; otherwise a value below `alpha` is.
    log1p_transform : bool, optional
        Whether to log1p transform the output.
    alpha : float, optional
        P value threshold value for significance.
    linewidths : float, optional
        Width of lines between each cell.
    row_cluster : bool, optional
        Whether to cluster rows.
    col_cluster : bool, optional
        Whether to cluster columns.
    low_col : str, optional
        Low colour in gradient.
    mid_col : str, optional
        Middle colour in gradient.
    high_col : str, optional
        High colour in gradient.
    cmap : Optional[Union[ListedColormap, str]], optional
        Built-in matplotlib colormap names or custom `ListedColormap`. When given, it is used
        instead of the `low_col`/`mid_col`/`high_col` gradient.
    title : str, optional
        Plot title.
    return_tables : bool, optional
        Whether to return the dataframes storing the interaction network.
    symmetrical : bool, optional
        Whether to return the sum of interactions as symmetrical heatmap.
    default_sep : str, optional
        The default separator used when CellPhoneDB was run. It is used both to split the
        column names into cell types and to build/split the cell type pair keys.
    **kwargs
        Passed to seaborn.clustermap.

    Returns
    -------
    Optional[Union[sns.matrix.ClusterGrid, Dict]]
        Either heatmap of cellphonedb interactions or, if `return_tables` is `True`, a
        dictionary with the keys `count_network`, `interaction_count` and
        `interaction_edges`. `None` is returned (after emitting a warning) when no
        cell type pair has any significant interaction.
    """
    import warnings  # local import: only needed for the "nothing significant" notice

    all_intr = pvals.copy()
    intr_pairs = all_intr["interacting_pair"]
    col_start = (
        DEFAULT_V5_COL_START if all_intr.columns[DEFAULT_CLASS_COL] == "classification" else DEFAULT_COL_START
    )  # in v5, there are 12 columns before the values

    # Rows become cell type pairs, columns become interacting pairs.
    all_int = all_intr.iloc[:, col_start:].T
    all_int.columns = intr_pairs

    if cell_types is None:
        cell_types = sorted({ct for pair in all_intr.columns[col_start:] for ct in pair.split(default_sep)})

    # Build the pair keys with the same separator that the column names use;
    # a hard-coded "|" would match nothing when a different separator is passed.
    cell_types_comb = [default_sep.join(pair) for pair in product(cell_types, cell_types)]
    comb_lookup = set(cell_types_comb)
    cell_types_keep = [ct for ct in all_int.index if ct in comb_lookup]
    empty_celltypes = sorted(comb_lookup - set(cell_types_keep))
    all_int = all_int.loc[cell_types_keep]

    if empty_celltypes:
        # Requested pairs that are absent from the input get a non-significant filler:
        # 0 in degs mode (significant means == 1), 1 otherwise (significant means < alpha).
        fill_value = 0.0 if degs_analysis else 1.0
        filler = pd.DataFrame(
            np.full((len(empty_celltypes), all_int.shape[1]), fill_value),
            index=empty_celltypes,
            columns=all_int.columns,
        )
        all_int = pd.concat([all_int, filler], axis=0)

    all_count = all_int.melt(ignore_index=False).reset_index()
    if degs_analysis:
        all_count["significant"] = all_count["value"] == 1
    else:
        all_count["significant"] = all_count["value"] < alpha

    # Number of significant interactions per cell type pair.
    count1x = all_count[["index", "significant"]].groupby("index").agg({"significant": "sum"})
    tmp = pd.DataFrame([key.split(default_sep) for key in count1x.index])
    count_final = pd.concat([tmp, count1x.reset_index(drop=True)], axis=1)
    count_final.columns = ["SOURCE", "TARGET", "COUNT"]

    if not (count_final["COUNT"] > 0).any():
        # Previously this path fell off the end of the function without any notice.
        warnings.warn(
            "No significant interactions found; returning None.",
            stacklevel=2,
        )
        return None

    count_mat = count_final.pivot_table(index="SOURCE", columns="TARGET", values="COUNT")
    count_mat.columns.name, count_mat.index.name = None, None
    count_mat[pd.isnull(count_mat)] = 0

    if symmetrical:
        # Off-diagonal cells hold A[i, j] + A[j, i]; the diagonal keeps its
        # original value instead of being doubled.
        raw = count_mat.to_numpy()
        sym = raw + raw.T
        sym[np.diag_indices_from(sym)] = np.diag(raw)
        count_mat = pd.DataFrame(sym, index=count_mat.index, columns=count_mat.columns)

    if log1p_transform:
        count_mat = np.log1p(count_mat)

    if cmap is None:
        colmap = diverging_palette(low=low_col, medium=mid_col, high=high_col)
    else:
        colmap = cmap

    if not return_tables:
        g = sns.clustermap(
            count_mat,
            row_cluster=row_cluster,
            col_cluster=col_cluster,
            linewidths=linewidths,
            tree_kws={"linewidths": 0},
            cmap=colmap,
            **kwargs
        )
        if title != "":
            g.fig.suptitle(title)
        return g

    if symmetrical:
        all_sum = count_mat.sum(axis=0).to_frame("total_interactions")
    else:
        count_mat = count_mat.T  # so that the table output is the same layout as the plot
        row_sums = count_mat.sum(axis=0).to_frame("total_interactions_row")
        col_sums = count_mat.sum(axis=1).to_frame("total_interactions_col")
        all_sum = pd.concat([row_sums, col_sums], axis=1)

    return {"count_network": count_mat, "interaction_count": all_sum, "interaction_edges": count_final}