from typing import Union, List, Tuple, Callable, Set
import holoviews as hv
import hvplot.pandas
import numpy as np
import pandas as pd
import scipy.sparse
import scipy.stats
from anndata import AnnData
from holoviews import dim
from holoviews.plotting.links import Link
from holoviews.plotting.bokeh.callbacks import LinkCallback
# def sort_by_values(summarized_df):
# # sort rows by expression
# sorted_df = summarized_df.sort_values(axis=0, by=list(summarized_df.columns.values), ascending=False)
# indices = [summarized_df.index.get_loc(c) for c in sorted_df.index]
# return indices
def __auto_bin(df, nbins, width, height):
    """
    Resolve the automatic bin count.

    ``nbins`` is returned unchanged unless it is the sentinel ``-1`` and ``df`` has at least
    500,000 rows; in that case half of the smaller plot dimension is used, but never fewer
    than 200 bins.

    Args:
        df: Data frame about to be plotted (only its row count is read).
        nbins: Requested number of bins; ``-1`` requests automatic binning.
        width: Plot width in pixels.
        height: Plot height in pixels.
    """
    auto_requested = nbins == -1
    if not (auto_requested and df.shape[0] >= 500000):
        return nbins
    half_smallest_side = min(width, height) / 2
    return int(max(200, half_smallest_side))
def __create_hover_tool(df, keywords: dict, exclude: List, current: str = None, whitelist: List = None):
    """
    Generate hover tool.

    ``keywords`` is updated in place: ``hover_cols`` is replaced and a bokeh ``HoverTool`` is
    appended to a copy of the existing ``tools`` list. Nothing is changed when bokeh or
    holoviews cannot be imported.

    Args:
        df: Data frame whose columns are offered in the tooltip.
        keywords: Keyword dict
        exclude: List of columns in df to exclude.
        current: Key in df that is plotted to show 1st in tooltip
        whitelist: Optional list of the only columns allowed in the tooltip.
    """
    try:
        import bokeh.models
        import holoviews.core.util

        sanitize = holoviews.core.util.dimension_sanitizer
        shown = []
        for name in df.columns:
            if name in exclude or name == current or name in shown:
                continue
            if whitelist is not None and name not in whitelist:
                continue
            shown.append(name)
        keywords['hover_cols'] = shown
        # the plotted column, if any, is listed first in the tooltip
        ordered = shown if current is None else [current] + shown
        tooltips = [(name, '@{' + sanitize(name) + '}') for name in ordered]
        hover = bokeh.models.HoverTool(tooltips=tooltips)
        keywords['tools'] = keywords.get('tools', []) + [hover]
    except ModuleNotFoundError:
        pass
def __create_bounds_stream(source):
    """Return a new ``hv.streams.BoundsXY`` stream whose source is ``source``."""
    return hv.streams.BoundsXY(source=source)
def get_bounds(plot):
    """
    Return the bounds held by the ``bounds_stream`` attached to a plot.

    Args:
        plot: A plot, or a layout containing exactly one plot.

    Returns:
        ``plot.bounds_stream.bounds``, or None when the plot has no ``bounds_stream`` attribute.

    Raises:
        ValueError: If ``plot`` is a layout whose shape is not (1, 1).
    """
    target = plot
    if isinstance(target, hv.Layout):
        if target.shape != (1, 1):
            raise ValueError('Please select the plot in the layout')
        target = target[0, 0]
    if not hasattr(target, 'bounds_stream'):
        return None
    return target.bounds_stream.bounds
def __to_list(vals):
    """
    Coerce ``vals`` to a list.

    Lists are returned as-is (same object), numpy arrays and tuples are converted, and any
    other value is wrapped in a one-element list.
    """
    if isinstance(vals, list):
        return vals
    if isinstance(vals, np.ndarray):
        return vals.tolist()
    if isinstance(vals, tuple):
        return list(vals)
    return [vals]
def __size_legend(size_min, size_max, dot_min, dot_max, size_tick_labels_format, size_ticks):
    """
    Build a small legend that maps dot sizes to values.

    Args:
        size_min: Value drawn with ``dot_min`` pixels.
        size_max: Value drawn with ``dot_max`` pixels.
        dot_min: Minimum dot size in pixels.
        dot_max: Maximum dot size in pixels.
        size_tick_labels_format: ``str.format`` template applied to each tick value.
        size_ticks: Values to show in the legend.

    Returns:
        An overlay of black points (one per tick) and their text labels.
    """
    # TODO improve
    n_ticks = len(size_ticks)
    tick_pixels = np.interp(size_ticks, (size_min, size_max), (dot_min, dot_max))
    tick_text = [size_tick_labels_format.format(tick) for tick in size_ticks]

    dot_data = {'x': np.repeat(0.1, n_ticks), 'y': np.arange(n_ticks, 0, -1), 'size': tick_pixels}
    dots = hv.Points(dot_data, vdims='size').opts(xaxis=None, color='black', yaxis=None, size=dim('size'))

    text_data = {'x': np.repeat(0.2, n_ticks), 'y': np.arange(n_ticks, 0, -1), 'text': tick_text}
    captions = hv.Labels(text_data, ['x', 'y'], 'text').opts(text_align='left', text_font_size='9pt')

    legend = dots * captions
    legend.opts(width=125, height=int(n_ticks * (dot_max + 12)), xlim=(0, 1),
                ylim=(0, n_ticks + 1),
                invert_yaxis=True, shared_axes=False, show_frame=False)
    return legend
def __fix_color_by_data_type(df, by):
    """
    Convert a categorical or boolean color column to strings, in place.

    hvplot does not currently handle categorical or boolean type for colors, so such a column
    is replaced by its string representation. Other dtypes are left untouched.

    Args:
        df: Data frame that is modified in place.
        by: Name of the column used for coloring, or None to do nothing.
    """
    if by is None:
        return
    column = df[by]
    # isinstance check replaces the deprecated pd.api.types.is_categorical_dtype
    is_categorical = isinstance(column.dtype, pd.CategoricalDtype)
    if is_categorical or pd.api.types.is_bool_dtype(column):
        df[by] = column.astype(str)
def __get_raw(adata, use_raw):
    """
    Select the object expression values are read from.

    Args:
        adata: Annotated data matrix.
        use_raw: True to require ``adata.raw``, False to never use it, None to use it when present.

    Returns:
        ``adata.raw`` when raw data is selected, otherwise ``adata`` itself.

    Raises:
        ValueError: If ``use_raw`` is truthy but ``adata.raw`` is None.
    """
    wants_raw = use_raw or (use_raw is None and adata.raw is not None)
    if not wants_raw:
        return adata
    if adata.raw is None:
        raise ValueError('Raw data not found')
    return adata.raw
def __get_df(adata, adata_raw, keys, df=None, is_obs=None):
    """
    Collect the requested keys into a data frame, one column per key.

    When ``df`` is not given, a new frame with an ``id`` column is created; whether it is
    observation- or variable-oriented is decided by the first key (variable-oriented when that
    key is a column of ``adata.var``).

    Args:
        adata: Annotated data matrix.
        adata_raw: Object used to look up expression values (``var_names`` and ``obs_vector``).
        keys: Keys to fetch.
        df: Optional existing data frame to add the columns to.
        is_obs: Whether ``df`` has one row per observation; required when ``df`` is given.

    Returns:
        The data frame with the added columns (None if ``keys`` is empty and ``df`` is None).

    Raises:
        ValueError: If ``df`` is given without ``is_obs`` or a key cannot be found.
    """
    if df is not None and is_obs is None:
        raise ValueError('Please provide is_obs when df is provided.')
    for key in keys:
        if df is None:
            is_obs = key not in adata.var
            ids = adata.obs.index.values if is_obs else adata.var.index.values
            df = pd.DataFrame(data=dict(id=ids))
        if is_obs:
            if key in adata_raw.var_names:
                column_values = adata_raw.obs_vector(key)
                if scipy.sparse.issparse(column_values):
                    column_values = column_values.toarray()
            elif key in adata.obs:
                column_values = adata.obs[key].values
            else:
                raise ValueError('{} not found'.format(key))
        elif key in adata.var:
            column_values = adata.var[key].values
        else:
            raise ValueError('{} not found'.format(key))
        df[key] = column_values
    return df
def __bin(df, nbins, coordinate_columns, reduce_function, coordinate_column_to_range=None):
    """
    Summarize a data frame on a grid.

    The coordinate columns of ``df`` are replaced, in place, by integer bin indices in
    ``[0, nbins - 1]``; the frame is then grouped by those indices.

    Args:
        df: Data frame to bin. Its coordinate columns are overwritten with bin indices.
        nbins: Number of bins along each coordinate.
        coordinate_columns: Names of the coordinate columns.
        reduce_function: Aggregation applied to numeric, non-coordinate columns.
        coordinate_column_to_range: Optional dict mapping a coordinate column to a (min, max)
            pair used instead of the observed range.

    Returns:
        Tuple of (aggregated data frame, per-row bin indices of the coordinate columns).
        ``count`` is summed, categorical columns take their most frequent value, other numeric
        columns use ``reduce_function``; columns of any other type are not aggregated.
    """
    # replace coordinates with bin
    for view_column_name in coordinate_columns:
        values = df[view_column_name].values
        view_column_range = None
        if coordinate_column_to_range is not None:
            view_column_range = coordinate_column_to_range.get(view_column_name, None)
        if view_column_range is None:
            column_min, column_max = values.min(), values.max()
        else:
            column_min, column_max = view_column_range[0], view_column_range[1]
        df[view_column_name] = np.floor(
            np.interp(values, [column_min, column_max], [0, nbins - 1])).astype(int)

    agg_func = {}
    for column in df:
        if column == 'count':
            agg_func[column] = 'sum'
        elif isinstance(df[column].dtype, pd.CategoricalDtype):
            # isinstance check replaces the deprecated pd.api.types.is_categorical_dtype
            agg_func[column] = lambda x: x.mode()[0]
        elif column not in coordinate_columns and pd.api.types.is_numeric_dtype(df[column]):
            agg_func[column] = reduce_function
    return df.groupby(coordinate_columns, as_index=False).agg(agg_func), df[coordinate_columns]
def violin(adata: AnnData, keys: Union[str, List[str], Tuple[str]], by: str = None,
           width: int = 300, cmap: Union[str, List[str], Tuple[str]] = 'Category20', cols: int = None,
           use_raw: bool = None, **kwds) -> hv.core.element.Element:
    """
    Generate a violin plot.

    Args:
        adata: Annotated data matrix.
        keys: Keys for accessing variables of adata.var_names, field of adata.var, or field of adata.obs
        by: Group plot by specified observation.
        width: Plot width.
        cmap: Color map name (hv.plotting.list_cmaps()) or a list of hex colors. See http://holoviews.org/user_guide/Styling_Plots.html for more information.
        cols: Number of columns for laying out multiple plots. Defaults to 3.
        use_raw: Use `raw` attribute of `adata` if present.
        **kwds: Additional keyword arguments forwarded to ``hvplot.violin``.

    Returns:
        A layout with one violin plot per key.
    """
    if cols is None:
        cols = 3
    adata_raw = __get_raw(adata, use_raw)
    keywords = dict(padding=0.02, cmap=cmap, rot=90)
    keywords.update(kwds)
    keys = __to_list(keys)
    # fetch the grouping column together with the plotted keys
    df = __get_df(adata, adata_raw, keys + ([] if by is None else [by]))
    __fix_color_by_data_type(df, by)
    plots = [df.hvplot.violin(key, width=width, by=by, violin_color=by, **keywords) for key in keys]
    return hv.Layout(plots).cols(cols)
def heatmap(adata: AnnData, keys: Union[str, List[str], Tuple[str]], by: str,
            reduce_function: Callable[[np.ndarray], float] = np.mean,
            use_raw: bool = None, cmap: Union[str, List[str], Tuple[str]] = 'Reds', **kwds) -> hv.core.element.Element:
    """
    Generate a heatmap.

    Args:
        adata: Annotated data matrix.
        keys: Keys for accessing variables of adata.var_names
        by: Group plot by specified observation.
        cmap: Color map name (hv.plotting.list_cmaps()) or a list of hex colors. See http://holoviews.org/user_guide/Styling_Plots.html for more information.
        reduce_function: Function to summarize an element in the heatmap
        use_raw: Use `raw` attribute of `adata` if present.
        **kwds: Additional keyword arguments forwarded to ``hvplot.heatmap``.

    Raises:
        ValueError: If no keys are provided.
    """
    adata_raw = __get_raw(adata, use_raw)
    keys = __to_list(keys)
    if not keys:
        raise ValueError('Please provide at least one key.')
    keywords = dict(colorbar=True, xlabel='', cmap=cmap, ylabel=str(by), rot=90)
    keywords.update(kwds)
    # long format: one row per (observation, feature); frames are concatenated once at the end
    frames = []
    for key in keys:
        X = adata_raw.obs_vector(key)
        if scipy.sparse.issparse(X):
            X = X.toarray()
        _df = pd.DataFrame(X, columns=['value'])
        _df['feature'] = key
        _df[by] = adata.obs[by].values
        frames.append(_df)
    df = frames[0] if len(frames) == 1 else pd.concat(frames)
    return df.hvplot.heatmap(x='feature', y=by, C='value', reduce_function=reduce_function, **keywords)
def scatter(adata: AnnData, x: str, y: str, color=None, size: Union[int, str] = None,
            dot_min=2, dot_max=14, use_raw: bool = None, sort: bool = True, width: int = 400, height: int = 400,
            nbins: int = -1, reduce_function: Callable[[np.ndarray], float] = np.mean,
            cmap: Union[str, List[str], Tuple[str]] = 'viridis', **kwds) -> hv.core.element.Element:
    """
    Generate a scatter plot.

    Args:
        adata: Annotated data matrix.
        x: Key for accessing variables of adata.var_names, field of adata.var, or field of adata.obs
        y: Key for accessing variables of adata.var_names, field of adata.var, or field of adata.obs
        cmap: Color map name (hv.plotting.list_cmaps()) or a list of hex colors. See http://holoviews.org/user_guide/Styling_Plots.html for more information.
        color: Field in .var_names, adata.var, or adata.obs to color the points by.
        sort: Plot higher color by values on top of lower values.
        size: Field in .var_names, adata.var, or adata.obs to size the points by or a pixel size.
        dot_min: Minimum dot size when sizing points by a field.
        dot_max: Maximum dot size when sizing points by a field.
        use_raw: Use `raw` attribute of `adata` if present.
        width: Plot width.
        height: Plot height.
        nbins: Number of bins used to summarize plot on a grid. Useful for large datasets. Negative one means automatically bin the plot.
        reduce_function: Function used to summarize overlapping cells if nbins is specified
        **kwds: Additional keyword arguments forwarded to the plotting call.
    """
    return __scatter(adata=adata, x=x, y=y, color=color, size=size, dot_min=dot_min, dot_max=dot_max,
                     use_raw=use_raw, sort=sort, width=width, height=height, nbins=nbins,
                     reduce_function=reduce_function, cmap=cmap, is_scatter=True,
                     **kwds)
def line(adata: AnnData, x: str, y: str,
         use_raw: bool = None, width: int = 400, height: int = 400,
         nbins: int = None, reduce_function: Callable[[np.array], float] = np.mean, **kwds) -> hv.core.element.Element:
    """
    Generate a line plot.

    Args:
        adata: Annotated data matrix.
        x: Key for accessing variables of adata.var_names, field of adata.var, or field of adata.obs
        y: Key for accessing variables of adata.var_names, field of adata.var, or field of adata.obs
        use_raw: Use `raw` attribute of `adata` if present.
        width: Plot width.
        height: Plot height.
        nbins: Number of bins used to summarize plot on a grid. Useful for large datasets.
        reduce_function: Function used to summarize overlapping cells if nbins is specified
        **kwds: Additional keyword arguments forwarded to the plotting call.
    """
    # same pipeline as the scatter plot, drawn as a line and never sorted by color
    return __scatter(adata=adata, x=x, y=y, is_scatter=False, sort=False,
                     use_raw=use_raw, nbins=nbins, reduce_function=reduce_function,
                     width=width, height=height,
                     **kwds)
def __scatter(adata: AnnData, x: str, y: str, color=None, size: Union[int, str] = None,
              dot_min=2, dot_max=14, use_raw: bool = None, sort: bool = True, width: int = 400, height: int = 400,
              nbins: int = None, reduce_function: Callable[[np.ndarray], float] = np.mean,
              cmap: Union[str, List[str], Tuple[str]] = 'viridis', is_scatter=True, **kwds) -> hv.core.element.Element:
    """
    Generate a scatter or line plot.

    Args:
        adata: Annotated data matrix.
        x: Key for accessing variables of adata.var_names, field of adata.var, or field of adata.obs
        y: Key for accessing variables of adata.var_names, field of adata.var, or field of adata.obs
        cmap: Color map name (hv.plotting.list_cmaps()) or a list of hex colors. See http://holoviews.org/user_guide/Styling_Plots.html for more information.
        color: Field in .var_names, adata.var, or adata.obs to color the points by.
        sort: Plot higher color by values on top of lower values.
        size: Field in .var_names, adata.var, or adata.obs to size the points by or a pixel size.
        dot_min: Minimum dot size when sizing points by a field.
        dot_max: Maximum dot size when sizing points by a field.
        use_raw: Use `raw` attribute of `adata` if present.
        width: Plot width.
        height: Plot height.
        nbins: Number of bins used to summarize plot on a grid. Useful for large datasets.
        reduce_function: Function used to summarize overlapping cells if nbins is specified
        is_scatter: Draw a scatter plot if True, otherwise a line plot (color and size keys are
            only fetched for scatter plots).
        **kwds: Additional keyword arguments forwarded to the plotting call.

    Returns:
        The plot, or a layout of the plot and a size legend when sizing by a field. The data
        frame holding the plotted coordinates is attached as the ``df`` attribute.
    """
    adata_raw = __get_raw(adata, use_raw)
    keywords = dict(fontsize=dict(title=9), nonselection_alpha=0.1, padding=0.02, xaxis=True, yaxis=True, width=width,
                    height=height, alpha=1, tools=['box_select'], cmap=cmap)
    keywords.update(kwds)
    color_by = color is not None and is_scatter
    is_size_by = isinstance(size, str)
    keys = [x, y]
    if color_by:
        keys.append(color)
    if is_size_by and is_scatter:
        keys.append(size)
    df = __get_df(adata, adata_raw, keys)
    nbins = __auto_bin(df, nbins, width, height)
    df_with_coords = df
    # copy so that a hover_cols list supplied by the caller is never modified in place
    hover_cols = list(keywords.get('hover_cols', []))
    if nbins is not None and nbins > 0:
        df['count'] = 1.0
        hover_cols.append('count')
        df, df_with_coords = __bin(df, nbins=nbins, coordinate_columns=[x, y], reduce_function=reduce_function)
    else:
        hover_cols.append('id')
    keywords['hover_cols'] = hover_cols

    if color_by:
        __fix_color_by_data_type(df, color)
        if pd.api.types.is_numeric_dtype(df[color]):
            keywords.update(dict(colorbar=True, c=color))
            if sort:
                # ascending sort draws the highest values last, i.e. on top
                df = df.sort_values(by=color)
        else:
            keywords.update(dict(by=color))

    if is_size_by:
        size_min = df[size].min()
        size_max = df[size].max()
        df['pixels'] = np.interp(df[size], (size_min, size_max), (dot_min, dot_max))
        keywords['s'] = 'pixels'
        hover_cols.append(size)
    elif size is not None:
        keywords['size'] = size

    if is_scatter:
        p = df.hvplot.scatter(x=x, y=y, **keywords)
    else:
        df = df.sort_values(by=x)
        p = df.hvplot.line(x=x, y=y, **keywords)

    if is_size_by:
        return_value = (p + __size_legend(size_min=size_min, size_max=size_max, dot_min=dot_min, dot_max=dot_max,
                                          size_tick_labels_format='{0:.1f}',
                                          size_ticks=np.array([size_min, (size_min + size_max) / 2, size_max])))
    else:
        return_value = p
    return_value.df = df_with_coords
    return return_value
def dotplot(adata: AnnData, keys: Union[str, List[str], Tuple[str]], by: str,
            reduce_function: Callable[[np.ndarray], float] = np.mean,
            fraction_min: float = 0, fraction_max: float = None, dot_min: int = 0, dot_max: int = 14,
            use_raw: bool = None, cmap: Union[str, List[str], Tuple[str]] = 'Reds',
            sort_function: Callable[[pd.DataFrame], List[str]] = None, **kwds) -> hv.core.element.Element:
    """
    Generate a dot plot.

    Dot color encodes the summarized value of a feature within a group and dot size encodes the
    fraction of non-zero values.

    Args:
        adata: Annotated data matrix.
        keys: Keys for accessing variables of adata.var_names
        by: Group plot by specified observation.
        cmap: Color map name (hv.plotting.list_cmaps()) or a list of hex colors. See http://holoviews.org/user_guide/Styling_Plots.html for more information.
        reduce_function: Function to summarize an element in the heatmap
        fraction_min: Minimum fraction expressed value.
        fraction_max: Maximum fraction expressed value. Defaults to the largest observed fraction.
        dot_min: Minimum pixel dot size.
        dot_max: Maximum pixel dot size.
        use_raw: Use `raw` attribute of `adata` if present.
        sort_function: Optional function that accepts summarized data frame and returns a list of row indices in the order to render in the heatmap.
        **kwds: Additional keyword arguments forwarded to ``hvplot.scatter``.

    Returns:
        A layout of the dot plot and its size legend. The plotted data frame is attached as the
        ``df`` attribute.
    """
    adata_raw = __get_raw(adata, use_raw)
    keys = __to_list(keys)
    keywords = dict(colorbar=True, ylabel=str(by), xlabel='', padding=0, rot=90, cmap=cmap)
    keywords.update(kwds)
    X = adata_raw[:, keys].X
    if scipy.sparse.issparse(X):
        X = X.toarray()
    df = pd.DataFrame(data=X, columns=keys)
    df[by] = adata.obs[by].values

    def non_zero(g):
        return np.count_nonzero(g) / g.shape[0]

    summarized_df = df.groupby(by).aggregate([reduce_function, non_zero])
    if sort_function is not None:
        row_indices = sort_function(summarized_df)
        summarized_df = summarized_df.iloc[row_indices]

    # columns alternate (feature, summary), (feature, non_zero) for every feature
    mean_columns = list(summarized_df.columns[0::2])
    frac_columns = list(summarized_df.columns[1::2])
    # features on columns, by on rows
    fraction_df = summarized_df[frac_columns]
    mean_df = summarized_df[mean_columns]

    y, x = np.indices(mean_df.shape)
    y = y.flatten()
    x = x.flatten()
    fraction = fraction_df.values.flatten()
    if fraction_max is None:
        fraction_max = fraction.max()
    size = np.interp(fraction, (fraction_min, fraction_max), (dot_min, dot_max))
    summary_values = mean_df.values.flatten()

    xlabel = list(keys)
    ylabel = [str(group) for group in summarized_df.index]
    dotplot_df = pd.DataFrame(
        data=dict(x=x, y=y, value=summary_values, pixels=size, fraction=fraction, xlabel=np.array(xlabel)[x],
                  ylabel=np.array(ylabel)[y]))
    xticks = list(enumerate(xlabel))
    yticks = list(enumerate(ylabel))
    keywords['width'] = int(np.ceil(dot_max * len(xticks) + 150))
    keywords['height'] = int(np.ceil(dot_max * len(yticks) + 100))
    try:
        import bokeh.models
        keywords['hover_cols'] = ['fraction', 'xlabel', 'ylabel']
        keywords['tools'] = [bokeh.models.HoverTool(tooltips=[
            ('fraction', '@fraction'),
            ('value', '@value'),
            ('x', '@xlabel'),
            ('y', '@ylabel')
        ])]
    except ModuleNotFoundError:
        pass
    p = dotplot_df.hvplot.scatter(x='x', y='y', xlim=(-0.5, len(xticks) + 0.5), ylim=(-0.5, len(yticks) + 0.5),
                                  c='value', s='pixels', xticks=xticks, yticks=yticks, **keywords)

    # legend tick spacing depends on the width of the displayed fraction range
    size_range = fraction_max - fraction_min
    if size_range <= 0.3:
        size_legend_step = 0.05
    elif size_range <= 0.6:
        size_legend_step = 0.1
    else:
        size_legend_step = 0.2
    # skip the tick at fraction_min when fraction_min is not positive
    first_tick = fraction_min if fraction_min > 0 else fraction_min + size_legend_step
    size_ticks = np.arange(first_tick, fraction_max + size_legend_step, size_legend_step)
    result = p + __size_legend(size_min=fraction_min, size_max=fraction_max, dot_min=dot_min, dot_max=dot_max,
                               size_tick_labels_format='{:.0%}', size_ticks=size_ticks)
    result.df = dotplot_df
    return result
def scatter_matrix(adata: AnnData, keys: Union[str, List[str], Tuple[str]], color=None, use_raw: bool = None,
                   **kwds) -> hv.core.element.Element:
    """
    Generate a scatter plot matrix.

    Args:
        adata: Annotated data matrix.
        keys: Key for accessing variables of adata.var_names or a field of adata.obs
        color: Key in adata.obs to color points by.
        use_raw: Use `raw` attribute of `adata` if present.
        **kwds: Additional keyword arguments forwarded to ``hvplot.scatter_matrix``.

    Returns:
        The scatter matrix, with the plotted data frame attached as the ``df`` attribute.
    """
    adata_raw = __get_raw(adata, use_raw)
    keys = __to_list(keys)
    if color is not None:
        # build a new list so a list passed in by the caller is not modified
        keys = keys + [color]
    df = __get_df(adata, adata_raw, keys)
    __fix_color_by_data_type(df, color)
    p = hvplot.scatter_matrix(df, c=color, **kwds)
    p.df = df
    return p
def embedding(adata: AnnData, basis: str, keys: Union[None, str, List[str], Tuple[str]] = None,
              cmap: Union[str, List[str], Tuple[str]] = 'viridis',
              alpha: float = 1, size: int = 12,
              width: int = 400, height: int = 400,
              sort: bool = True, cols: int = 2,
              use_raw: bool = None, nbins: int = -1, reduce_function: Callable[[np.ndarray], float] = np.mean,
              labels_on_data: bool = False, tooltips: Union[str, List[str], Tuple[str]] = None,
              **kwds) -> hv.core.element.Element:
    """
    Generate an embedding plot.

    Args:
        adata: Annotated data matrix.
        keys: Key for accessing variables of adata.var_names or a field of adata.obs used to color the plot. Can also use `count` to plot cell count when binning.
        basis: String in adata.obsm containing coordinates.
        alpha: Points alpha value.
        size: Point pixel size.
        sort: Plot higher values on top of lower values.
        cmap: Color map name (hv.plotting.list_cmaps()) or a list of hex colors. See http://holoviews.org/user_guide/Styling_Plots.html for more information.
        nbins: Number of bins used to summarize plot on a grid. Useful for large datasets. Negative one means automatically bin the plot.
        reduce_function: Function used to summarize overlapping cells if nbins is specified.
        cols: Number of columns for laying out multiple plots
        width: Plot width.
        height: Plot height.
        tooltips: List of additional fields to show on hover.
        labels_on_data: Whether to draw labels for categorical features on the plot.
        use_raw: Use `raw` attribute of `adata` if present.
        **kwds: Additional keyword arguments forwarded to ``hvplot.scatter``.

    Returns:
        A layout with one plot per key (a single cell-count plot when no keys are given). The data
        frame holding the plotted coordinates is attached as the ``df`` attribute.
    """
    if keys is None:
        keys = []
    adata_raw = __get_raw(adata, use_raw)
    keys = __to_list(keys)
    if tooltips is None:
        tooltips = []
    tooltips = __to_list(tooltips)
    keywords = dict(fontsize=dict(title=9), padding=0.02, xaxis=False, yaxis=False, nonselection_alpha=0.1,
                    tools=['box_select'], cmap=cmap, legend=not labels_on_data)
    keywords.update(kwds)
    coordinate_columns = ['X_' + basis + c for c in ['1', '2']]

    # 'count' is generated here (cells per point/bin) unless adata really contains such a field,
    # so it must not be looked up in adata
    synthetic_count = 'count' in keys and 'count' not in adata.obs and 'count' not in adata_raw.var_names
    lookup_keys = [key for key in keys + tooltips if not (synthetic_count and key == 'count')]
    df = __get_df(adata, adata_raw, lookup_keys,
                  pd.DataFrame(adata.obsm['X_' + basis][:, 0:2], columns=coordinate_columns),
                  is_obs=True)
    nbins = __auto_bin(df, nbins, width, height)
    df_with_coords = df
    density = len(keys) == 0
    if density:
        keys = ['count']
    bin_data = nbins is not None and nbins > 0
    plots = []
    if bin_data or density or synthetic_count:
        df['count'] = 1.0
    if bin_data:
        df, df_with_coords = __bin(df, nbins=nbins, coordinate_columns=coordinate_columns,
                                   reduce_function=reduce_function)

    for key in keys:
        __fix_color_by_data_type(df, key)
        is_color_by_numeric = pd.api.types.is_numeric_dtype(df[key])
        df_to_plot = df
        if sort and is_color_by_numeric:
            df_to_plot = df.sort_values(by=key)
        # fresh copy per plot: otherwise every plot would also carry the hover tools of the previous ones
        plot_keywords = dict(keywords)
        __create_hover_tool(df, plot_keywords, exclude=coordinate_columns, current=key)
        p = df_to_plot.hvplot.scatter(
            x=coordinate_columns[0],
            y=coordinate_columns[1],
            title=str(key),
            c=key if is_color_by_numeric else None,
            by=key if not is_color_by_numeric else None,
            size=size,
            alpha=alpha,
            colorbar=is_color_by_numeric,
            width=width, height=height, **plot_keywords)
        bounds_stream = __create_bounds_stream(p)
        if not is_color_by_numeric and labels_on_data:
            # place each category label at the median coordinate of its points
            labels_df = df_to_plot[[coordinate_columns[0], coordinate_columns[1], key]].groupby(key).aggregate(
                np.median)
            labels = hv.Labels({('x', 'y'): labels_df, 'text': labels_df.index.values}, ['x', 'y'], 'text')
            p = p * labels
        p.bounds_stream = bounds_stream
        plots.append(p)

    # note that we can't link brushing because points are plotted in different order for each plot
    layout = hv.Layout(plots).cols(cols)
    layout.df = df_with_coords
    return layout
def variable_feature_plot(adata: AnnData, **kwds) -> hv.core.element.Element:
    """
    Generate a variable feature plot.

    The default fields depend on whether ``hvf_loess`` is present in ``adata.var``; any of
    ``x``, ``y``, ``y_fit``, ``color``, ``xlabel``, ``ylabel`` and ``line_color`` can be
    overridden through ``kwds``. Remaining keyword arguments are forwarded to ``scatter``.

    Args:
        adata: Annotated data matrix.
        **kwds: Overrides for the defaults and additional keyword arguments for ``scatter``.

    Returns:
        A scatter plot, overlaid with the fitted line when ``y_fit`` is a field of ``adata.var``.
    """
    if 'hvf_loess' in adata.var:
        keywords = dict(x='mean', y='var', y_fit='hvf_loess', color='highly_variable_features',
                        xlabel='Mean log expression', ylabel='Variance of log expression')
    else:
        keywords = dict(x='means', y='dispersions_norm', y_fit=None, color='highly_variable',
                        xlabel='Mean log expression', ylabel='Normalized dispersion')
    keywords.update(kwds)
    x = keywords.pop('x')
    y = keywords.pop('y')
    color = keywords.pop('color')
    xlabel = keywords.pop('xlabel')
    ylabel = keywords.pop('ylabel')
    y_fit = keywords.pop('y_fit')
    line_color = keywords.pop('line_color', 'black')

    # remaining keywords are forwarded to the scatter plot whether or not a fit line is drawn
    points = scatter(adata, x=x, y=y, color=color, xlabel=xlabel, ylabel=ylabel, **keywords)
    if y_fit is not None and y_fit in adata.var:
        return points * line(adata, x=x, y=y_fit, line_color=line_color)
    return points
class __BrushLink(Link):
    """Link between two plots; a target plot is required."""

    _requires_target = True


class __BrushLinkCallback(LinkCallback):
    """Bokeh callback that copies the selected indices of the source onto the target."""

    # run whenever the source's selected indices change
    source_model = 'selected'
    target_model = 'selected'
    source_handles = ['cds']
    on_source_changes = ['indices']

    source_code = "\ntarget_selected.indices = source_selected.indices;\n"


__BrushLink.register_callback('bokeh', __BrushLinkCallback)
def volcano(adata: AnnData, basis: str = 'de_res', x: str = 'log_fold_change', y: str = 't_qval',
            x_cutoff: float = 1, y_cutoff: float = 0.05, cluster_ids: Union[List, Tuple, Set] = None,
            **kwds) -> hv.core.element.Element:
    """
    Generate a volcano plot.

    Args:
        adata: Annotated data matrix.
        basis: String in adata.varm containing statistics to plot.
        x: Field in basis to plot on x-axis. Field is assumed to end with :cluster_id (e.g. log_fold_change:1).
        y: Field in basis to plot on y-axis. Field is assumed to end with :cluster_id (e.g. t_qval:1).
        x_cutoff: Highlight items >= x_cutoff or <= -x_cutoff
        y_cutoff: Highlight items <= y_cutoff
        cluster_ids: Optional list of cluster ids to include. If unspecified, plots are shown for all clusters.
            Ids are compared as strings, so ``[1, 2]`` and ``['1', '2']`` are equivalent.
        **kwds: Additional keyword arguments forwarded to ``hvplot.scatter``.

    Returns:
        A one-column layout with one brush-linked plot per cluster. The plotted data frame is
        attached as the ``df`` attribute.
    """
    de_results = adata.varm[basis]
    names = de_results.dtype.names  # stat:cluster e.g. 'mwu_pval:13'
    # cluster ids parsed from field names are strings, so normalise the requested ids to strings
    wanted_ids = None if cluster_ids is None else {str(cluster_id) for cluster_id in cluster_ids}
    keywords = dict(fontsize=dict(title=9), nonselection_line_color=None, line_color='black',
                    selection_line_color='black', line_width=0.3, nonselection_alpha=0.05,
                    padding=0.02, xaxis=True, yaxis=True, alpha=0.9,
                    tools=['box_select'], hover_cols=['id'],
                    cmap={'Up': '#e41a1c', 'Down': '#377eb8', 'Not significant': '#bdbdbd'})
    keywords.update(kwds)

    # cluster id -> [x field name, y field name]
    cluster_to_xy = {}
    for name in names:
        if name.startswith(x):
            xy_index = 0
        elif name.startswith(y):
            xy_index = 1
        else:
            continue
        cluster_id = name[name.rindex(':') + 1:]
        if wanted_ids is not None and cluster_id not in wanted_ids:
            continue
        cluster_to_xy.setdefault(cluster_id, [None, None])[xy_index] = name

    df = pd.DataFrame(dict(id=adata.var.index.values))
    filtered_cluster_ids = []
    for cluster_id, (x_name, y_name) in cluster_to_xy.items():
        if x_name is None or y_name is None:
            continue
        filtered_cluster_ids.append(cluster_id)
        x_column = '{}_{}'.format(x, cluster_id)
        y_column = '{}_{}'.format(y, cluster_id)
        y_log_column = '{}_{}_log'.format(y, cluster_id)
        status_column = '{}_status'.format(cluster_id)
        df[x_column] = de_results[x_name]
        df[y_column] = de_results[y_name]
        df[status_column] = 'Not significant'
        significant = df[y_column] <= y_cutoff
        df.loc[significant & (df[x_column] >= x_cutoff), status_column] = 'Up'
        df.loc[significant & (df[x_column] <= -x_cutoff), status_column] = 'Down'
        # small offset keeps the log finite when the value is exactly zero
        df[y_log_column] = -np.log10(df[y_column] + 1e-12)

    plots = []
    for cluster_id in filtered_cluster_ids:
        x_column = '{}_{}'.format(x, cluster_id)
        y_column = '{}_{}'.format(y, cluster_id)
        y_log_column = '{}_{}_log'.format(y, cluster_id)
        status_column = '{}_status'.format(cluster_id)
        # fresh copy per plot: otherwise every plot would also carry the hover tools of the previous ones
        plot_keywords = dict(keywords)
        __create_hover_tool(df, plot_keywords, exclude=[], whitelist=['id', x_column, y_column])
        p = df.hvplot.scatter(x=x_column, y=y_log_column, title=str(
            cluster_id), c=status_column, xlabel=str(x), ylabel='-log10 ' + str(y), **plot_keywords)
        plots.append(p)

    # shared_datasource for linked brushing colors points incorrectly
    for i, plot_i in enumerate(plots):
        for plot_j in plots[:i]:
            __BrushLink(plot_i, plot_j)
            __BrushLink(plot_j, plot_i)
    result = hv.Layout(plots).cols(1)
    result.df = df
    return result
def composition_plot(adata: AnnData, by: str, condition: str, stacked: bool = True, normalize: bool = True,
                     stats: bool = True, **kwds) -> hv.core.element.Element:
    """
    Generate a composition plot, which shows the percentage of observations from every condition within each cluster (by).

    Args:
        adata: Annotated data matrix.
        by: Key for accessing variables of adata.var_names or a field of adata.obs used to group the data.
        condition: Key for accessing variables of adata.var_names or a field of adata.obs used to compute counts within a group.
        stacked: Whether bars are stacked.
        normalize: Normalize counts within each group to sum to one.
        stats: Compute statistics for each group using the fisher exact test when condition has two groups and the chi square test otherwise.
        **kwds: Additional keyword arguments forwarded to ``hvplot.bar``.

    Returns:
        The bar plot, laid out in one column with a table of statistics when ``stats`` is True.
        The plotted data frame, the per-cluster contingency tables and the statistics are attached
        as the ``df``, ``obs`` and ``stats`` attributes.
    """
    adata_raw = __get_raw(adata, False)
    keys = [by, condition]
    df = __get_df(adata, adata_raw, keys)
    keywords = dict(stacked=stacked, group_label=condition)
    keywords.update(kwds)
    invert = keywords.get('invert', False)
    if not invert and 'rot' not in keywords:
        keywords['rot'] = 90

    dummy_df = pd.get_dummies(df[condition])
    # count observations per (group, condition); only the indicator columns are summed so that
    # the string id/condition columns never end up in the contingency table
    df = pd.concat([df[[by]], dummy_df], axis=1).groupby(by).sum()

    cluster_p_values = None
    obs = None
    if stats:
        # contingency table per cluster:
        #            condition columns
        # 'a' (all other clusters)
        # 'b' (this cluster)
        n_clusters = df.shape[0]
        two_conditions = df.shape[1] == 2
        cluster_p_values = np.ones(shape=n_clusters)
        scores = np.ones(shape=n_clusters)
        obs = []
        p_value_func = scipy.stats.fisher_exact if two_conditions else scipy.stats.chi2_contingency
        # pick a temporary column name that does not clash with a condition name
        group_clusters_by_name = 'a_b'
        counter = 1
        while group_clusters_by_name in df.columns:
            group_clusters_by_name = 'a_b-' + str(counter)
            counter += 1
        for i in range(n_clusters):  # each cluster
            obs_df = df.copy()
            cluster_in_out = ['a'] * n_clusters
            cluster_in_out[i] = 'b'
            obs_df[group_clusters_by_name] = cluster_in_out
            obs_df = obs_df.groupby(group_clusters_by_name).sum()
            p_value_result = p_value_func(obs_df.values)
            obs.append(obs_df.values)
            # both tests return the statistic first and the p-value second
            cluster_p_values[i] = p_value_result[1]
            scores[i] = p_value_result[0]
        from statsmodels.stats.multitest import multipletests
        _, fdr, _, _ = multipletests(cluster_p_values, alpha=0.05, method='fdr_bh')
        bonferroni = np.minimum(cluster_p_values * len(cluster_p_values), 1.0)
        cluster_p_values = pd.DataFrame(
            data=dict(cluster=df.index, fdr=fdr, bonferroni=bonferroni, p_value=cluster_p_values))
        cluster_p_values['fisher_exact_odds_ratio' if two_conditions else 'chi2'] = scores
        cluster_p_values.sort_index(inplace=True, ascending=False)  # match order of bar plot

    if normalize:
        df = df.T.div(df.sum(axis=1)).T
    p = df.hvplot.bar(by, list(dummy_df.columns.values), **keywords)
    if cluster_p_values is not None:
        p = p + hv.Table(cluster_p_values)
        # only set columns on the layout branch; a lone bar plot is not assumed to have cols()
        p.cols(1)
    p.df = df
    p.obs = obs
    p.stats = cluster_p_values
    return p