Source code for ktplotspy.plot.plot_cpdb_chord

#!/usr/bin/env python
import re

import numpy as np
import pandas as pd

from collections import defaultdict
from itertools import product
from matplotlib.lines import Line2D
from pycirclize import Circos

from ktplotspy.plot import plot_cpdb
from ktplotspy.utils.settings import DEFAULT_PAL, DEFAULT_PAL_CYCLER, DEFAULT_SEP
from ktplotspy.utils.support import celltype_fraction, celltype_means, find_complex, flatten, generate_df


[docs] def plot_cpdb_chord( adata: "AnnData", means: pd.DataFrame, pvals: pd.DataFrame, deconvoluted: pd.DataFrame, celltype_key: str, interaction: str | list[str] | None = None, cell_type1: str | None = None, cell_type2: str | None = None, keep_celltypes: list[str] | None = None, remove_self: bool = True, layer: str | None = None, sector_colors: dict[str, str] | None = None, sector_text_kwargs: dict = {"color": "black", "size": 8, "r": 105, "adjust_rotation": True}, sector_radius_limit: tuple[float, float] = (95, 100), equal_sector_size: bool = False, same_producer_colors: bool = False, link_colors: str | dict[str, str] | None = None, link_offset: float = 0, link_kwargs: dict = {"direction": 1, "allow_twist": True, "r1": 95, "r2": 90}, legend_width: float = 2, legend_kwargs: dict = {"loc": "center", "bbox_to_anchor": (1, 1), "fontsize": 8}, **plot_cpdb_kwargs, ): """Plotting cellphonedb results as a chord diagram. Parameters ---------- adata : AnnData `AnnData` object with the `.obs` storing the `celltype_key` with or without `splitby_key`. The `.obs_names` must match the first column of the input `meta.txt` used for `cellphonedb`. means : pd.DataFrame Dataframe corresponding to `means.txt` from cellphonedb. pvals : pd.DataFrame Dataframe corresponding to `pvalues.txt` or `relevant_interactions.txt` from cellphonedb. deconvoluted : pd.DataFrame Dataframe corresponding to `deconvoluted.txt` from cellphonedb. celltype_key : str Column name in `adata.obs` storing the celltype annotations. Values in this column should match the second column of the input `meta.txt` used for `cellphonedb`. interaction : str | list[str] | None, optional Interaction(s) to plot. If None, all interactions will be plotted. If provided as a '-' separated string between L-R pair, only the specific interaction(s) will be plotted (more than 1 pair is accepted). If provided as a list of strings without '-' separator, all relevant interactions in the list will be plotted. cell_type1 : str | None, optional Name of cell type 1. Accepts regex pattern. None defaults to "." i.e. all cell types. If both cell_type1 and cell_type2 are provided, only interactions contained in these cell types will be plotted. cell_type2 : str | None, optional Name of cell type 2. Accepts regex pattern. None defaults to "." i.e. all cell types. If both cell_type1 and cell_type2 are provided, only interactions contained in these cell types will be plotted. keep_celltypes : list[str] | None, optional If provided, only interactions to/from these celltypes will be plotted. remove_self : bool, optional whether to remove self edges. layer : str | None, optional slot in `AnnData.layers` to access. If `None`, uses `.X`. sector_colors : dict[str, str] | None, optional dictionary of celltype (sector) colours. If not provided, will try and use `.uns` from `adata` if correct slot is present. sector_text_kwargs : dict, optional additional arguments for `track.text`. See https://moshi4.github.io/pyCirclize/api-docs/sector/#pycirclize.sector.Sector.text sector_radius_limit : tuple[float, float], optional radius limits for the sectors. equal_sector_size : bool, optional whether to use equal sector sizes. same_producer_colors : bool, optional whether to use the same colours for sector and links for outgoing interactions. link_colors : dict[str, str] | None, optional String or dictionary of L-R interaction colours. If not provided, will use a default colour, will use a random colour for each unique interaction. link_offset : float, optional offset to add to each interaction value. link_kwargs : dict, optional additional arguments for `circos.link`. See https://moshi4.github.io/pyCirclize/api-docs/circos/#pycirclize.circos.Circos.link legend_width : float, optional width of the legend line. legend_kwargs : dict, optional additional arguments for `Axes.Legend`. See https://matplotlib.org/stable/tutorials/intermediate/legend_guide.html **plot_cpdb_kwargs passed to `plot_cpdb`. Returns ------- Gcircle a `Gcircle` object from `pycircos`. """ # assert splitby = False splitby_key, return_table = None, True if interaction is not None: if isinstance(interaction, str): # check if there's '-' in the interaction if "-" in interaction: lr_intx = interaction.split("-") else: lr_intx = [interaction] elif isinstance(interaction, list): lr_intx = [] for intx in interaction: if "-" in intx: tmp_lr_intx = intx.split("-") for tmp_intx in tmp_lr_intx: lr_intx.append(tmp_intx) else: lr_intx.append(intx) else: lr_intx = None cell_type1 = cell_type1 if cell_type1 is not None else "." cell_type2 = cell_type2 if cell_type2 is not None else "." # run plot_cpdb lr_interactions = plot_cpdb( adata=adata, means=means, pvals=pvals, genes=lr_intx, celltype_key=celltype_key, return_table=return_table, splitby_key=splitby_key, cell_type1=cell_type1, cell_type2=cell_type2, **plot_cpdb_kwargs, ) # do some name wrangling subset_clusters = list(set(flatten([x.split("-") for x in lr_interactions.celltype_group]))) adata_subset = adata[adata.obs[celltype_key].isin(subset_clusters)].copy() interactions = means[ ["id_cp_interaction", "interacting_pair", "gene_a", "gene_b", "partner_a", "partner_b", "receptor_a", "receptor_b"] ].copy() interactions["use_interaction_name"] = [ x + DEFAULT_SEP * 3 + y for x, y in zip(interactions.id_cp_interaction, interactions.interacting_pair) ] interactions["converted"] = [re.sub("_", "-", x) for x in interactions.use_interaction_name] lr_interactions["barcode"] = [a + DEFAULT_SEP + b for a, b in zip(lr_interactions.celltype_group, lr_interactions.interaction_group)] interactions_subset = interactions[interactions["converted"].isin(list(lr_interactions.interaction_group))].copy() # handle complexes gently tm0 = {kx: rx.split("_") for kx, rx in interactions_subset.use_interaction_name.items()} if any([len(x) > 2 for x in tm0.values()]): complex_id, simple_id = [], [] for i, j in tm0.items(): if len(j) > 2: complex_id.append(i) elif len(j) == 2: simple_id.append(i) _interactions_subset = interactions_subset.loc[complex_id].copy() _interactions_subset_simp = interactions_subset.loc[simple_id].copy() complex_idx1 = [i for i, j in _interactions_subset.partner_b.items() if re.search("complex:", j)] complex_idx2 = [i for i, j in _interactions_subset.partner_a.items() if re.search("complex:", j) and i not in complex_idx1] # complex_idx simple_1 = list(_interactions_subset.loc[complex_idx1, "interacting_pair"]) simple_2 = list(_interactions_subset.loc[complex_idx2, "interacting_pair"]) partner_1 = [re.sub("complex:", "", b) for b in _interactions_subset.loc[complex_idx1, "partner_b"]] partner_2 = [re.sub("complex:", "", a) for a in _interactions_subset.loc[complex_idx2, "partner_a"]] for i, _ in enumerate(simple_1): simple_1[i] = re.sub(partner_1[i] + "_|_" + partner_1[i], "", simple_1[i]) for i, _ in enumerate(simple_2): simple_2[i] = re.sub(partner_2[i] + "_|_" + partner_2[i], "", simple_2[i]) tmpdf = pd.concat([pd.DataFrame(zip(simple_1, partner_1)), pd.DataFrame(zip(partner_2, simple_2))]) tmpdf.index = complex_idx1 + complex_idx2 tmpdf = tmpdf.sort_index() tmpdf.columns = ["id_a", "id_b"] _interactions_subset = pd.concat([_interactions_subset, tmpdf], axis=1) simple_tm0 = pd.DataFrame( [rx.split("_") for rx in _interactions_subset_simp.interacting_pair], columns=["id_a", "id_b"], index=_interactions_subset_simp.index, ) _interactions_subset_simp = pd.concat([_interactions_subset_simp, simple_tm0], axis=1) interactions_subset = pd.concat([_interactions_subset_simp, _interactions_subset], axis=0) else: tm0 = pd.DataFrame(tm0).T tm0.columns = ["id_a", "id_b"] tm0.id_a = [x.split(DEFAULT_SEP * 3)[1] for x in tm0.id_a] interactions_subset = pd.concat([interactions_subset, tm0], axis=1) # keep only useful genes geneid = list(set(list(interactions_subset.id_a) + list(interactions_subset.id_b))) if not all([g in adata_subset.var.index for g in geneid]): geneid = list(set(list(interactions_subset.gene_a) + list(interactions_subset.gene_b))) # create a subet anndata adata_subset_tmp = adata_subset[:, adata_subset.var_names.isin(geneid)].copy() meta = adata_subset_tmp.obs.copy() adata_list, adata_list_alt = {}, {} for x in list(set(meta[celltype_key])): adata_list[x] = adata_subset_tmp[adata_subset_tmp.obs[celltype_key] == x].copy() adata_list_alt[x] = adata_subset[adata_subset.obs[celltype_key] == x].copy() # create expression and fraction dataframes. adata_list2, adata_list3 = {}, {} for x in adata_list: adata_list2[x] = celltype_means(adata_list[x], layer) adata_list3[x] = celltype_fraction(adata_list[x], layer) adata_list2 = pd.DataFrame(adata_list2, index=adata_subset_tmp.var_names) adata_list3 = pd.DataFrame(adata_list3, index=adata_subset_tmp.var_names) decon_subset = deconvoluted[deconvoluted.complex_name.isin(find_complex(interactions_subset))].copy() # if any interactions are actually complexes, extract them from the deconvoluted dataframe. if decon_subset.shape[0] > 0: decon_subset_expr = decon_subset.groupby("complex_name").apply(lambda r: r[adata_list2.columns].apply(np.mean, axis=0)) cellfrac = defaultdict(dict) zgenes = list(set(decon_subset_expr.index)) for ct, adat in adata_list_alt.items(): for zg in zgenes: zg_mask = adata.var_names.isin(zg.split("_")) cellfrac[ct][zg] = np.mean(adat[:, zg_mask].X > 0) if np.sum(zg_mask) > 0 else 0 decon_subset_fraction = pd.DataFrame(cellfrac) expr_df = pd.concat([adata_list2, decon_subset_expr]) fraction_df = pd.concat([adata_list3, decon_subset_fraction]) else: expr_df = adata_list2.copy() fraction_df = adata_list3.copy() # create edge list cells_test = list(set(meta[celltype_key])) cell_comb = [] for c1 in cells_test: for c2 in cells_test: if remove_self: if c1 != c2: cell_comb.append((c1, c2)) else: cell_comb.append((c1, c2)) cell_comb = list(set(cell_comb)) cell_type_grid = pd.DataFrame(cell_comb, columns=["source", "target"]) # create the final dataframe for plotting dfx = generate_df( interactions_subset=interactions_subset, cell_type_grid=cell_type_grid, cell_type_means=expr_df, cell_type_fractions=fraction_df, sep=DEFAULT_SEP, ) # ok form the table for pyCircos int_value = dict(zip(lr_interactions.barcode, lr_interactions.y_means)) int_value = {k: r for k, r in int_value.items() if pd.notnull(r)} dfx["interaction_value"] = [int_value[y] if y in int_value else np.nan for y in dfx["barcode"]] tmpdf = dfx[["producer", "receiver", "converted_pair", "interaction_value"]].copy() tmpdf["interaction_celltype"] = [ DEFAULT_SEP.join(sorted([a, b, c])) for a, b, c in zip(tmpdf.producer, tmpdf.receiver, tmpdf.converted_pair) ] celltypes = ( adata.obs[celltype_key].cat.categories if adata.obs[celltype_key].dtype.name == "category" else list(set(adata.obs[celltype_key])) ) interactions = sorted(list(set(tmpdf["interaction_celltype"]))) interaction_start_dict = {r: k for k, r in enumerate(interactions)} tmpdf["interaction_value"] = [ j + interaction_start_dict[x] if pd.notnull(j) else np.nan for j, x in zip(tmpdf.interaction_value, tmpdf.interaction_celltype) ] # filter to specific interactions if interaction is provided with '-' if interaction is not None: if isinstance(interaction, str): if "-" in interaction: # first check if the interaction is exactly in the converted_pair int_check = interaction in list(tmpdf.converted_pair.unique()) if not int_check: # flip the so that the first becomes last interaction_ = "-".join(interaction.split("-")[::-1]) int_check = interaction_ in list(tmpdf.converted_pair.unique()) if not int_check: raise ValueError(f"Neither {interaction} nor {interaction_} are found in the data.") else: intx = interaction_ else: intx = interaction tmpdf = tmpdf[tmpdf.converted_pair.isin([intx])] else: tmpdf = tmpdf[tmpdf.converted_pair.str.contains(interaction)] elif isinstance(interaction, list): tmpint = [] for intx in interaction: if "-" in intx: # first check if the interaction is exactly in the converted_pair int_check = intx in list(tmpdf.converted_pair.unique()) if not int_check: # flip the so that the first becomes last intx_ = "-".join(intx.split("-")[::-1]) int_check = intx_ in list(tmpdf.converted_pair.unique()) if not int_check: raise ValueError(f"Neither {intx} nor {intx_} are found in the data.") else: intx = intx_ else: pass # intx is already assigned tmpint.append(intx) if len(tmpint) > 0: tmpdf = tmpdf[tmpdf.converted_pair.isin(tmpint)] else: tmpdf = tmpdf[tmpdf.converted_pair.str.contains("|".join(lr_intx))] if link_colors is None: uni_interactions = list(set(tmpdf.converted_pair)) if len(uni_interactions) > len(DEFAULT_PAL): link_colors = dict(zip(uni_interactions, [next(DEFAULT_PAL_CYCLER) for i in range(len(uni_interactions))])) else: link_colors = dict(zip(uni_interactions, [DEFAULT_PAL[i] for i in range(len(uni_interactions))])) celltypes = sorted(list(set(list(tmpdf.producer) + list(tmpdf.receiver)))) # Sort the 'converted_pair' values alphabetically tmpdf = tmpdf.sort_values("converted_pair") # add optional offset to each interaction_value and rounding to the nearest integer tmpdf["interaction_value"] = round(tmpdf["interaction_value"].fillna(0) + link_offset) # For each celltype ('producer'), create a cumulative interaction value tracker cumulative_interaction_dict = defaultdict(int) start_values, end_values = [], [] for _, row in tmpdf.iterrows(): start_values.append(cumulative_interaction_dict[row["producer"]]) if row["interaction_value"] > 0: cumulative_interaction_dict[row["producer"]] += row["interaction_value"] end_values.append(cumulative_interaction_dict[row["producer"]]) # Add the 'start' and 'end' columns tmpdf["start"] = start_values tmpdf["end"] = end_values # Set the starting value for the receiver rec_start = {} for celltype in celltypes: rec_start[celltype] = max(tmpdf[tmpdf.producer == celltype].end) # For each row in the dataframe, calculate the 'start' and 'end' values for the receiver start_values, end_values = [], [] for _, r in tmpdf.iterrows(): start_values.append(rec_start[r.receiver]) rec_start[r.receiver] += r.interaction_value end_values.append(rec_start[r.receiver]) # This is the position of the end of the links (i.e. the target) tmpdf["start_2"] = start_values tmpdf["end_2"] = end_values sectors = {} # the maximum size of each sector is the maximum value of 'end_2' for each celltype for celltype in celltypes: tmp = tmpdf[tmpdf.receiver == celltype] if tmp.shape[0] > 0: sectors[celltype] = max(tmp.end_2) else: tmp = tmpdf[tmpdf.producer == celltype] if tmp.shape[0] > 0: sectors[celltype] = max(tmp.end_2) if equal_sector_size: max_sector = max(sectors.values()) sectors = {k: max_sector for k in sectors} else: # offset the sector by 10 if any of the values are 0 for sector in sectors: if sectors[sector] == 0: sectors[sector] = 10 tmpdf = tmpdf[tmpdf.interaction_value > link_offset] tmpdf = tmpdf.reset_index(drop=True) if keep_celltypes is not None: tmpdf["celltype_test1"] = tmpdf["producer"] + DEFAULT_SEP + tmpdf["receiver"] tmpdf["celltype_test2"] = tmpdf["receiver"] + DEFAULT_SEP + tmpdf["producer"] if len(keep_celltypes) == 2: test_celltype = keep_celltypes[0] + DEFAULT_SEP + keep_celltypes[1] tmpdf = tmpdf[tmpdf["celltype_test1"].isin([test_celltype]) | tmpdf["celltype_test2"].isin([test_celltype])] tmpdf = tmpdf.drop(["celltype_test1", "celltype_test2"], axis=1) elif len(keep_celltypes) > 2: tests = list(product(keep_celltypes, keep_celltypes)) test2 = [] for test in tests: test2.append(test[0] + DEFAULT_SEP + test[1]) tmpdf = tmpdf[tmpdf["celltype_test1"].isin(test2) | tmpdf["celltype_test2"].isin(test2)] tmpdf = tmpdf.drop(["celltype_test1", "celltype_test2"], axis=1) else: keep_celltypes = [keep_celltypes] if isinstance(keep_celltypes, str) else keep_celltypes tmpdf = tmpdf[tmpdf.producer.isin(keep_celltypes) | tmpdf.receiver.isin(keep_celltypes)] if sector_colors is None: if celltype_key + "_colors" in adata.uns: if adata.obs[celltype_key].dtype.name == "category": sector_colors = dict(zip(adata.obs[celltype_key].cat.categories, adata.uns[celltype_key + "_colors"])) else: sector_colors = dict(zip(list(set(adata.obs[celltype_key])), adata.uns[celltype_key + "_colors"])) else: if len(celltypes) > len(DEFAULT_PAL): sector_colors = dict(zip(celltypes, [next(DEFAULT_PAL_CYCLER) for i in range(len(celltypes))])) else: sector_colors = dict(zip(celltypes, [DEFAULT_PAL[i] for i in range(len(celltypes))])) for celltype in celltypes: if celltype not in sector_colors: sector_colors[celltype] = "#f7f7f7" # just make it off-white if not present circos = Circos(sectors, space=5) for sector in circos.sectors: track = sector.add_track(r_lim=sector_radius_limit) track.axis(fc=sector_colors[sector.name]) track.text(text=sector.name, **sector_text_kwargs) for _, r in tmpdf.iterrows(): if not same_producer_colors: if link_colors is not None: if isinstance(link_colors, dict): link_color = link_colors[r.converted_pair] if r.converted_pair in link_colors else "#f7f7f700" elif isinstance(link_colors, str): link_color = link_colors else: link_color = sector_colors[r.producer] circos.link( sector_region1=(r.producer, r.start, r.end), sector_region2=(r.receiver, r.start_2, r.end_2), color=link_color, **link_kwargs ) circos.plotfig() # plot legend if isinstance(link_colors, dict): line_handles = [Line2D([], [], color=value, label=key, linewidth=legend_width) for key, value in link_colors.items()] line_legend = circos.ax.legend(handles=line_handles, **legend_kwargs) circos.ax.add_artist(line_legend) return circos