Source code for ktplotspy.plot.plot_cpdb

#!/usr/bin/env python
import numpy as np
import pandas as pd
import re

from plotnine import (
    aes,
    element_blank,
    element_rect,
    element_text,
    geom_point,
    ggplot,
    ggtitle,
    guide_colourbar,
    guide_legend,
    guides,
    options,
    scale_alpha_continuous,
    scale_colour_continuous,
    scale_colour_manual,
    scale_fill_continuous,
    scale_fill_manual,
    scale_size_continuous,
    theme,
    theme_bw,
)
from typing import List, Literal, Optional, Union, Tuple, Dict

from ktplotspy.utils.settings import (
    DEFAULT_V5_COL_START,
    DEFAULT_COL_START,
    DEFAULT_CLASS_COL,
    DEFAULT_SEP,
    DEFAULT_SPEC_PAT,
    DEFAULT_CELLSIGN_ALPHA,
    DEFAULT_COLUMNS,
)
from ktplotspy.utils.support import (
    ensure_categorical,
    filter_interaction_and_celltype,
    hclust,
    prep_celltype_query,
    prep_query_group,
    prep_table,
    set_x_stroke,
    sub_pattern,
)


[docs]def plot_cpdb( adata: "AnnData", cell_type1: str, cell_type2: str, means: pd.DataFrame, pvals: pd.DataFrame, celltype_key: str, interaction_scores: Optional[pd.DataFrame] = None, cellsign: Optional[pd.DataFrame] = None, degs_analysis: bool = False, splitby_key: Optional[str] = None, alpha: float = 0.05, keep_significant_only: bool = True, genes: Optional[str] = None, gene_family: Optional[Literal["chemokines", "th1", "th2", "th17", "treg", "costimulatory", "coinhibitory"]] = None, custom_gene_family: Optional[Dict[str, List[str]]] = None, standard_scale: bool = True, cluster_rows: bool = True, cmap_name: str = "viridis", max_size: int = 8, max_highlight_size: int = 3, default_style: bool = True, highlight_col: str = "#d62728", highlight_size: Optional[int] = None, special_character_regex_pattern: Optional[str] = None, exclude_interactions: Optional[Union[List, str]] = None, title: str = "", return_table: bool = False, figsize: Tuple[Union[int, float], Union[int, float]] = (6.4, 4.8), min_interaction_score: int = 0, scale_alpha_by_interaction_scores: bool = False, scale_alpha_by_cellsign: bool = False, filter_by_cellsign: bool = False, keep_id_cp_interaction: bool = False, ) -> Union[ggplot, pd.DataFrame]: """Plotting CellPhoneDB results as a dot plot. Parameters ---------- adata : AnnData `AnnData` object with the `.obs` storing the `celltype_key` with or without `splitby_key`. The `.obs_names` must match the first column of the input `meta.txt` used for CellPhoneDB. cell_type1 : str Name of cell type 1. Accepts regex pattern. cell_type2 : str Name of cell type 1. Accepts regex pattern. means : pd.DataFrame Data frame corresponding to `means.txt` from CellPhoneDB. pvals : pd.DataFrame Data frame corresponding to `pvalues.txt` or `relevant_interactions.txt` from CellPhoneDB. celltype_key : str Column name in `adata.obs` storing the celltype annotations. Values in this column should match the second column of the input `meta.txt` used for CellPhoneDB. interaction_scores : Optional[pd.DataFrame], optional Data frame corresponding to `interaction_scores.txt` from CellPhoneDB version 5 onwards. cellsign : Optional[pd.DataFrame], optional Data frame corresponding to `CellSign.txt` from CellPhoneDB version 5 onwards. degs_analysis : bool, optional Whether CellPhoneDB was run in `deg_analysis` mode. splitby_key : Optional[str], optional If provided, will attempt to split the output plot/table by groups. In order for this to work, the second column of the input `meta.txt` used for CellPhoneDB MUST be this format: {splitby}_{celltype}. alpha : float, optional P value threshold value for significance. keep_significant_only : bool, optional Whether or not to trim to significant (p<0.05) hits. genes : Optional[str], optional If provided, will attempt to plot only interactions containing the specified gene(s). gene_family : Optional[Literal["chemokines", "th1", "th2", "th17", "treg", "costimulatory", "coinhibitory"]], optional If provided, will attempt to plot a predetermined set of chemokines or genes associated with Th1, Th2, Th17, Treg, costimulatory or coinhibitory molecules. custom_gene_family : Optional[Dict[str, List[str]]], optional If provided, will update the gene_family dictionary with this custom dictionary. Both `gene_family` (name of the custom family) and `custom_gene_family` (dictionary holding this new family) must be specified for this to work. standard_scale : bool, optional Whether or not to scale the mean interaction values from 0 to 1 per receptor-ligand variable. cluster_rows : bool, optional Whether or not to cluster the rows (interactions). cmap_name : str, optional Matplotlib built-in colormap names. max_size : int, optional Maximum size of points in plot. max_highlight_size : int, optional Maximum highlight size of points in plot. default_style : bool, optional Whether or not to plot in default style or inspired from `squidpy`'s plotting style. highlight_col : str, optional Colour of highlights marking significant hits. highlight_size : Optional[int], optional Size of highlights marking significant hits. special_character_regex_pattern : Optional[str], optional Regex string pattern to perform substitution. This option should not realy be used unless there is really REALLY special characters that you really REALLY want to keep. Rather than using this option, the easiest way is to not your celltypes with weird characters. Just use alpha numeric characters and underscores if necessary. exclude_interactions : Optional[Union[List, str]], optional If provided, the interactions will be removed from the output. title : str, optional Plot title. return_table : bool, optional Whether or not to return the results as a dataframe. figsize : Tuple[Union[int, float], Union[int, float]], optional Figure size. min_interaction_score: int, optional Filtering the interactions shown by including only those above the given interaction score. scale_alpha_by_interaction_scores: bool, optional Whether or not to filter values by the interaction score. scale_alpha_by_cellsign: bool, optional Whether or not to filter the transparency of interactions by the cellsign. filter_by_cellsign: bool, optional Filter out interactions with a 0 value cellsign. keep_id_cp_interaction: bool, optional Whether to keep the original `id_cp_interaction` value when plotting. Returns ------- Union[ggplot, pd.DataFrame] Either a plotnine `ggplot` plot or a pandas `Data frame` holding the results. Raises ------ KeyError If genes and gene_family are both provided, wrong key for gene family provided, or if interaction_score and cellsign are both provided the error will occur. """ if special_character_regex_pattern is None: special_character_regex_pattern = DEFAULT_SPEC_PAT # prepare data metadata = adata.obs.copy() means_mat = prep_table(data=means) pvals_mat = prep_table(data=pvals) col_start = ( DEFAULT_V5_COL_START if pvals_mat.columns[DEFAULT_CLASS_COL] == "classification" else DEFAULT_COL_START ) # in v5, there are 12 columns before the values if pvals_mat.shape != means_mat.shape: tmp_pvals_mat = pd.DataFrame(index=means_mat.index, columns=means_mat.columns) # Copy the values from means_mat to new_df tmp_pvals_mat.iloc[:, :col_start] = means_mat.iloc[:, :col_start] tmp_pvals_mat.update(pvals_mat) if degs_analysis: tmp_pvals_mat.fillna(0, inplace=True) else: tmp_pvals_mat.fillna(1, inplace=True) pvals_mat = tmp_pvals_mat.copy() if (interaction_scores is not None) & (cellsign is not None): raise KeyError("Please specify either interaction scores or cellsign, not both.") if interaction_scores is not None: interaction_scores_mat = prep_table(data=interaction_scores) elif cellsign is not None: cellsign_mat = prep_table(data=cellsign) if degs_analysis: pvals_mat.iloc[:, col_start : pvals_mat.shape[1]] = 1 - pvals_mat.iloc[:, col_start : pvals_mat.shape[1]] # front load the dictionary construction here if col_start == DEFAULT_V5_COL_START: tmp = means_mat.melt(id_vars=means_mat.columns[:col_start]) direc, classif, is_int = {}, {}, {} for _, r in tmp.iterrows(): key = r.id_cp_interaction + DEFAULT_SEP * 3 + r.interacting_pair.replace("_", "-") + DEFAULT_SEP * 3 + r.variable direc[key] = r.directionality classif[key] = r.classification is_int[key] = r.is_integrin # ensure celltypes are ok cell_type1 = sub_pattern(cell_type=cell_type1, pattern=special_character_regex_pattern) cell_type2 = sub_pattern(cell_type=cell_type2, pattern=special_character_regex_pattern) # check for query if genes is None: if gene_family is not None: query_group = prep_query_group(means_mat, custom_gene_family) if isinstance(gene_family, list): query = [] for gf in gene_family: if gf.lower() in query_group: for gfg in query_group[gf.lower()]: query.append(gfg) else: raise KeyError("gene_family needs to be one of the following: {}".format(query_group.keys())) query = list(set(query)) else: if gene_family.lower() in query_group: query = query_group[gene_family.lower()] else: raise KeyError("gene_family needs to be one of the following: {}".format(query_group.keys())) else: query = [i for i in means_mat.interacting_pair if re.search("", i)] elif genes is not None: if gene_family is not None: raise KeyError("Please specify either genes or gene_family, not both.") else: query = [i for i in means_mat.interacting_pair if re.search("|".join(genes), i)] metadata = ensure_categorical(meta=metadata, key=celltype_key) # prepare regex query for celltypes if splitby_key is not None: metadata = ensure_categorical(meta=metadata, key=splitby_key) groups = list(metadata[splitby_key].cat.categories) metadata["_labels"] = [s + "_" + c for s, c in zip(metadata[splitby_key], metadata[celltype_key])] metadata["_labels"] = metadata["_labels"].astype("category") cat_orders = [] for s in metadata[splitby_key].cat.categories: for c in metadata[celltype_key].cat.categories: cat_orders.append(s + "_" + c) cat_orders = [x for x in cat_orders if x in list(metadata._labels)] metadata["_labels"] = metadata["_labels"].cat.reorder_categories(cat_orders) celltype = prep_celltype_query( meta=metadata, cell_type1=cell_type1, cell_type2=cell_type2, pattern=special_character_regex_pattern, split_by=splitby_key, ) else: metadata["_labels"] = metadata[celltype_key] celltype = prep_celltype_query( meta=metadata, cell_type1=cell_type1, cell_type2=cell_type2, pattern=special_character_regex_pattern, split_by=splitby_key, ) cell_type = "|".join(celltype) # keep cell types ct_columns = [ct for ct in means_mat.columns if re.search(cell_type, ct)] # filter means_matx = filter_interaction_and_celltype(data=means_mat, genes=query, celltype_pairs=ct_columns) pvals_matx = filter_interaction_and_celltype(data=pvals_mat, genes=query, celltype_pairs=ct_columns) if interaction_scores is not None: interaction_scores_matx = filter_interaction_and_celltype(data=interaction_scores_mat, genes=query, celltype_pairs=ct_columns) elif cellsign is not None: cellsign_matx = filter_interaction_and_celltype(data=cellsign_mat, genes=query, celltype_pairs=ct_columns) # reorder the columns col_order = [] if splitby_key is not None: for g in groups: for c in means_matx.columns: if re.search(g, c): col_order.append(c) else: col_order = means_matx.columns means_matx = means_matx[col_order] pvals_matx = pvals_matx[col_order] if interaction_scores is not None: interaction_scores_matx = interaction_scores_matx[col_order] elif cellsign is not None: cellsign_matx = cellsign_matx[col_order] # whether or not to filter to only significant hits if keep_significant_only: keep_rows = pvals_matx.apply(lambda r: any(r < alpha), axis=1) keep_rows = [r for r, k in keep_rows.items() if k] if len(keep_rows) > 0: pvals_matx = pvals_matx.loc[keep_rows] means_matx = means_matx.loc[keep_rows] if interaction_scores is not None: interaction_scores_matx = interaction_scores_matx.loc[keep_rows] if cellsign is not None: # cellsign data is actually a subset so let's do keep_rows = [r for r in keep_rows if r in cellsign_matx.index] if len(keep_rows) > 0: cellsign_matx = cellsign_matx.loc[keep_rows] else: raise ValueError("Your cellsign data may not contain significant hits.") # run hierarchical clustering on the rows based on interaction value. if cluster_rows: if means_matx.shape[0] > 2: h_order = hclust(means_matx, axis=0) means_matx = means_matx.loc[h_order] pvals_matx = pvals_matx.loc[h_order] if interaction_scores is not None: interaction_scores_matx = interaction_scores_matx.loc[h_order] elif cellsign is not None: h_order = [h for h in h_order if h in cellsign_matx.index] if len(h_order) > 0: cellsign_matx = cellsign_matx.loc[h_order] else: raise ValueError("Your cellsign data may not contain significant hits.") if standard_scale: means_matx = means_matx.apply(lambda r: (r - np.min(r)) / (np.max(r) - np.min(r)), axis=1) means_matx.fillna(0, inplace=True) # prepare final table colm = "scaled_means" if standard_scale else "means" df = means_matx.melt(ignore_index=False).reset_index() df.index = df["index"] + DEFAULT_SEP * 3 + df["variable"] df.columns = DEFAULT_COLUMNS + [colm] df_pvals = pvals_matx.melt(ignore_index=False).reset_index() df_pvals.index = df_pvals["index"] + DEFAULT_SEP * 3 + df_pvals["variable"] df_pvals.columns = DEFAULT_COLUMNS + ["pvals"] df.celltype_group = [re.sub(DEFAULT_SEP, "-", c) for c in df.celltype_group] df["pvals"] = df_pvals["pvals"] if interaction_scores is not None: df_interaction_scores = interaction_scores_matx.melt(ignore_index=False).reset_index() df_interaction_scores.index = df_interaction_scores["index"] + DEFAULT_SEP * 3 + df_interaction_scores["variable"] df_interaction_scores.columns = DEFAULT_COLUMNS + ["interaction_scores"] df["interaction_scores"] = df_interaction_scores["interaction_scores"] elif cellsign is not None: df_cellsign = cellsign_matx.melt(ignore_index=False).reset_index() df_cellsign.index = df_cellsign["index"] + DEFAULT_SEP * 3 + df_cellsign["variable"] df_cellsign.columns = DEFAULT_COLUMNS + ["cellsign"] # same as above. df["cellsign"] = df_cellsign["cellsign"] # set factors df.celltype_group = df.celltype_group.astype("category") # prepare for non-default style plotting for i in df.index: if df.at[i, colm] == 0: df.at[i, colm] = np.nan df["x_means"] = df[colm] df["y_means"] = df[colm] for i in df.index: if df.at[i, "pvals"] < alpha: df.at[i, "x_means"] = np.nan if df.at[i, "pvals"] == 0: df.at[i, "pvals"] = 0.001 if df.at[i, "pvals"] >= alpha: if keep_significant_only: df.at[i, "y_means"] = np.nan if interaction_scores is not None: if df.at[i, "interaction_scores"] < 1: df.at[i, "x_means"] = np.nan elif cellsign is not None: if df.at[i, "cellsign"] < 1: df.at[i, "cellsign"] = DEFAULT_CELLSIGN_ALPHA df["x_stroke"] = df["x_means"] set_x_stroke(df=df, isnull=False, stroke=0) set_x_stroke(df=df, isnull=True, stroke=highlight_size) if exclude_interactions is not None: if not isinstance(exclude_interactions, list): exclude_interactions = [exclude_interactions] df = df[~df.interaction_group.isin(exclude_interactions)] df["neglog10p"] = abs(-1 * np.log10(df.pvals)) df["neglog10p"] = [0 if x >= 0.05 else j for x, j in zip(df["pvals"], df["neglog10p"])] df["significant"] = ["yes" if x < alpha else np.nan for x in df.pvals] if all(pd.isnull(df["significant"])): df["significant"] = "no" highlight_col = "#FFFFFF" # append the initial data if col_start == DEFAULT_V5_COL_START: df["is_integrin"] = [is_int[i] for i in df.index] df["directionality"] = [direc[i] for i in df.index] df["classification"] = [classif[i] for i in df.index] if df.shape[0] == 0: raise ValueError("No significant results found.") if return_table: return df else: # change the labelling of interaction_group if keep_id_cp_interaction: df.interaction_group = [re.sub(DEFAULT_SEP * 3, "_", c) for c in df.interaction_group] else: df.interaction_group = [c.split(DEFAULT_SEP * 3)[1] for c in df.interaction_group] # set global figure size options.figure_size = figsize if highlight_size is not None: max_highlight_size = highlight_size stroke = "x_stroke" else: stroke = "neglog10p" # plotting if interaction_scores is not None: df = df[df.interaction_scores >= min_interaction_score] if scale_alpha_by_interaction_scores: if default_style: g = ggplot( df, aes( x="celltype_group", y="interaction_group", colour="significant", fill=colm, size=colm, stroke=stroke, alpha="interaction_scores", ), ) else: if all(df["significant"] == "no"): g = ggplot( df, aes( x="celltype_group", y="interaction_group", colour="significant", fill=colm, size=colm, stroke=stroke, alpha="interaction_scores", ), ) default_style = True else: highlight_col = "#FFFFFF" # enforce this g = ggplot( df, aes( x="celltype_group", y="interaction_group", colour=colm, fill="significant", size=colm, stroke=stroke, alpha="interaction_scores", ), ) else: g = None else: if cellsign is not None: if filter_by_cellsign: df = df[df.cellsign >= DEFAULT_CELLSIGN_ALPHA] if scale_alpha_by_cellsign: if default_style: g = ggplot( df, aes( x="celltype_group", y="interaction_group", colour="significant", fill=colm, size=colm, stroke=stroke, alpha="cellsign", ), ) else: if all(df["significant"] == "no"): g = ggplot( df, aes( x="celltype_group", y="interaction_group", colour="significant", fill=colm, size=colm, stroke=stroke, alpha="cellsign", ), ) default_style = True else: highlight_col = "#FFFFFF" # enforce this g = ggplot( df, aes( x="celltype_group", y="interaction_group", colour=colm, fill="significant", size=colm, stroke=stroke, alpha="cellsign", ), ) else: g = None else: g = None if g is None: if default_style: g = ggplot( df, aes( x="celltype_group", y="interaction_group", colour="significant", fill=colm, size=colm, stroke=stroke, ), ) else: if all(df["significant"] == "no"): g = ggplot( df, aes( x="celltype_group", y="interaction_group", colour="significant", fill=colm, size=colm, stroke=stroke, ), ) default_style = True else: highlight_col = "#FFFFFF" # enforce this g = ggplot( df, aes( x="celltype_group", y="interaction_group", colour=colm, fill="significant", size=colm, stroke=stroke, ), ) g = ( g + geom_point( na_rm=True, ) + theme_bw() + theme( axis_text_x=element_text(angle=90, hjust=0, colour="#000000"), axis_text_y=element_text(colour="#000000"), axis_ticks=element_blank(), axis_title_x=element_blank(), axis_title_y=element_blank(), legend_key=element_rect(alpha=0, width=0, height=0), legend_direction="vertical", legend_box="horizontal", ) + scale_size_continuous(range=(0, max_size), aesthetics=["size"]) + scale_size_continuous(range=(0, max_highlight_size), aesthetics=["stroke"]) ) if default_style: g = ( g + scale_colour_manual(values=highlight_col, na_translate=False) + guides( fill=guide_colourbar(barwidth=4, label=True, ticks=True, draw_ulim=True, draw_llim=True, order=1), size=guide_legend( reverse=True, order=2, ), stroke=guide_legend( reverse=True, order=3, ), ) + scale_fill_continuous(cmap_name=cmap_name) ) else: g = ( g + scale_fill_manual(values=highlight_col, na_translate=False) + guides( colour=guide_colourbar(barwidth=4, label=True, ticks=True, draw_ulim=True, draw_llim=True, order=1), size=guide_legend( reverse=True, order=2, ), stroke=guide_legend( reverse=True, order=3, ), ) ) df2 = df.copy() for i in df2.index: if df2.at[i, "pvals"] < alpha: df2.at[i, colm] = np.nan g = ( g + geom_point(aes(x="celltype_group", y="interaction_group", colour=colm, size=colm), df2, inherit_aes=False, na_rm=True) + scale_colour_continuous(cmap_name=cmap_name) ) if highlight_size is not None: g = g + guides(stroke=None) if (interaction_scores is not None) and scale_alpha_by_interaction_scores: g = g + scale_alpha_continuous(breaks=(0, 25, 50, 75, 100)) if (cellsign is not None) and scale_alpha_by_cellsign: g = g + scale_alpha_continuous(breaks=(0, 1)) if title != "": g = g + ggtitle(title) elif gene_family is not None: if isinstance(gene_family, list): gene_family = ", ".join(gene_family) g = g + ggtitle(gene_family) return g