import time
import numpy as np
from scipy.sparse import issparse
from anndata import AnnData
import logging
logger = logging.getLogger("sccloud")
from sccloud.tools import estimate_feature_statistics, select_features
def set_group_attribute(data: AnnData, attribute_string: str = None) -> None:
    """Set group attributes used in batch correction.

    Batch correction assumes the differences in gene expression between channels are due to batch effects.
    However, in many cases, we know that channels can be partitioned into several groups and each group is
    biologically different from others. In this case, *sccloud* will only perform batch correction for
    channels within each group.

    Parameters
    ----------
    data: ``anndata.AnnData``
        Annotated data matrix with rows for cells and columns for genes.

    attribute_string: ``str``, optional, default: ``None``
        Attributes used to construct groups:

        * If ``None``, assume all channels are from one group.

        * ``attr``, where ``attr`` is a keyword in ``data.obs``.
          So the groups are defined by this sample attribute.

        * ``attr1+attr2+...+attrn``, where ``attr1`` to ``attrn`` are keywords in ``data.obs``.
          So the groups are defined by the Cartesian product of these *n* attributes.
          Attribute values are converted to strings before being joined with ``+``.

        * ``attr=value_11,...value_1n_1;value_21,...value_2n_2;...;value_m1,...,value_mn_m``,
          where ``attr`` is a keyword in ``data.obs``.
          In this form, there will be *(m+1)* groups. A cell belongs to group *i* (*i >= 1*) if and only if
          its sample attribute ``attr`` has a value among ``value_i1``, ... ``value_in_i``.
          A cell belongs to group 0 if it does not belong to any other groups.
          Attribute values are compared by their string representation.

    Returns
    -------
    ``None``

    Update ``data.obs``:
        * ``data.obs["Group"]``: Group ID for each cell.

    Raises
    ------
    AssertionError
        If an attribute named in ``attribute_string`` is not a column of ``data.obs``.

    Examples
    --------
    >>> scc.set_group_attribute(adata, attribute_string = "Individual")
    >>> scc.set_group_attribute(adata, attribute_string = "Individual+assignment")
    >>> scc.set_group_attribute(adata, attribute_string = "Channel=1,3,5;2,4,6,8")
    """
    if attribute_string is None:
        # A single constant label places every cell in the same group.
        # NOTE(review): confirm the label matches what downstream statistics code expects.
        data.obs["Group"] = "one_group"
        return

    # AssertionError is kept as the validation error type so existing callers see no change.
    if "=" in attribute_string:
        attr, _, value_str = attribute_string.partition("=")
        assert attr in data.obs.columns, "Attribute '{}' is not in data.obs!".format(attr)
        # The requested values are parsed from text, so compare against the string form
        # of the column; otherwise a numeric column would never match any token.
        attr_values = data.obs[attr].astype(str)
        data.obs["Group"] = "0"
        for group_id, value in enumerate(value_str.split(";"), start=1):
            idx = attr_values.isin(value.split(",")).values
            data.obs.loc[idx, "Group"] = str(group_id)
    elif "+" in attribute_string:
        attrs = attribute_string.split("+")
        missing = [attr for attr in attrs if attr not in data.obs.columns]
        assert not missing, "Attributes {} are not in data.obs!".format(missing)
        # str.join only accepts strings, so stringify non-string columns first.
        data.obs["Group"] = data.obs[attrs].astype(str).apply("+".join, axis=1)
    else:
        assert attribute_string in data.obs.columns, "Attribute '{}' is not in data.obs!".format(
            attribute_string
        )
        data.obs["Group"] = data.obs[attribute_string]
def estimate_adjustment_matrices(data: AnnData) -> bool:
    """Estimate the per-channel adjustment matrices used for batch correction.

    For every channel a per-feature multiplier and offset are derived so that
    ``x * muls + plus`` maps the channel's statistics onto those of the group
    the channel is assigned to (``data.uns["c2gid"]``).

    Reads ``data.varm["means"]``, ``data.varm["partial_sum"]``,
    ``data.varm["gmeans"]``, ``data.varm["gstds"]``, ``data.uns["Channels"]``,
    ``data.uns["ncells"]`` and ``data.uns["c2gid"]``. If ``gmeans`` or ``gstds``
    is missing from ``data.varm``, ``estimate_feature_statistics(data, True)`` is
    called first.

    Parameters
    ----------
    data: ``anndata.AnnData``
        Annotated data matrix with rows for cells and columns for genes.

    Returns
    -------
    ``bool``
        ``False`` if the data contain exactly one channel (a warning is logged and
        nothing is stored); ``True`` otherwise, after storing ``data.varm["plus"]``
        and ``data.varm["muls"]``, each with one row per feature and one column
        per channel.
    """
    have_group_stats = "gmeans" in data.varm and "gstds" in data.varm
    if not have_group_stats:
        estimate_feature_statistics(data, True)

    n_channels = data.uns["Channels"].size
    if n_channels == 1:
        logger.warning(
            "Warning: data only contains 1 channel. Batch correction disabled!"
        )
        return False

    n_features = data.shape[1]
    offsets = np.zeros((n_features, n_channels))
    scales = np.zeros((n_features, n_channels))

    cells_per_channel = data.uns["ncells"]
    channel_to_group = data.uns["c2gid"]
    channel_means = data.varm["means"]
    channel_partial_sum = data.varm["partial_sum"]
    group_means = data.varm["gmeans"]
    group_stds = data.varm["gstds"]

    for ch in range(n_channels):
        gid = channel_to_group[ch]
        # View into the output matrix: writes below land directly in ``scales``.
        column = scales[:, ch]

        # Channel spread, presumably the sample standard deviation (assumes
        # ``partial_sum`` holds summed squared deviations; confirm against
        # ``estimate_feature_statistics``). Channels with at most one cell keep
        # a spread of zero.
        if cells_per_channel[ch] > 1:
            column[:] = (channel_partial_sum[:, ch] / (cells_per_channel[ch] - 1.0)) ** 0.5

        # Features with (near-)zero spread cannot be rescaled; leave them unscaled.
        degenerate = column < 1e-6
        usable = ~degenerate
        column[degenerate] = 1.0
        column[usable] = group_stds[usable, gid] / column[usable]

        offsets[:, ch] = group_means[:, gid] - column * channel_means[:, ch]

    data.varm["plus"] = offsets
    data.varm["muls"] = scales
    return True
def correct_batch_effects(data: AnnData, keyword: str, features: str = None) -> None:
    """Apply the calculated ``plus`` and ``muls`` to correct batch effects for a dense matrix.

    Every cell of channel ``i`` is transformed feature-wise as
    ``x * muls[:, i] + plus[:, i]``. The matrix stored in ``data.uns[keyword]``
    is modified in place.

    Parameters
    ----------
    data: ``anndata.AnnData``
        Annotated data matrix. ``data.varm["plus"]`` and ``data.varm["muls"]`` must
        already be present, and ``data.obs["Channel"]`` gives the channel of each cell.

    keyword: ``str``
        Key in ``data.uns`` of the matrix to correct (rows are cells, columns are features).

    features: ``str``, optional, default: ``None``
        Column of ``data.var`` used to index the rows of ``plus`` and ``muls`` so that
        they line up with the columns of the matrix. If ``None``, all rows are used.

    Returns
    -------
    ``None``
    """
    X = data.uns[keyword]
    m = X.shape[1]

    plus = data.varm["plus"]
    muls = data.varm["muls"]
    if features is not None:
        selected = data.var[features].values
        plus = plus[selected, :]
        muls = muls[selected, :]

    # Convert the channel column once instead of on every loop iteration.
    cell_channels = np.asarray(data.obs["Channel"])

    for i, channel in enumerate(data.uns["Channels"]):
        idx = np.isin(cell_channels, channel)
        if not idx.any():
            continue
        # Reshaping to (1, m) also verifies that the adjustment vectors have exactly
        # one entry per matrix column. The shape is passed positionally because the
        # ``newshape`` keyword of ``np.reshape`` is deprecated in recent NumPy.
        scale = muls[:, i].reshape(1, m)
        shift = plus[:, i].reshape(1, m)
        X[idx] = X[idx] * scale + shift
    # NOTE: corrected values may be negative; no clipping to zero is applied
    # (a clipping step was present here only as commented-out code).
def correct_batch(data: AnnData, features: str = None) -> None:
    """Batch correction on data.

    Estimates the per-channel adjustment matrices, selects the feature matrix via
    ``select_features`` and, unless correction is disabled (``estimate_adjustment_matrices``
    returned ``False``), applies the adjustment with ``correct_batch_effects``.

    Parameters
    ----------
    data: ``anndata.AnnData``
        Annotated data matrix with rows for cells and columns for genes.

    features: `str`, optional, default: ``None``
        Features to be included in batch correction computation. If ``None``, simply consider all features.

    Returns
    -------
    ``None``

    The matrix stored under the ``data.uns`` key returned by ``select_features`` is
    corrected in place.
    NOTE(review): earlier documentation stated that ``data.X`` is updated; the code here
    only writes ``data.uns[keyword]`` -- confirm whether ``select_features`` ties it to ``data.X``.

    Examples
    --------
    >>> scc.correct_batch(adata, features = "highly_variable_features")
    """
    tot_seconds = 0.0

    # Estimate adjustment parameters. perf_counter is monotonic, so the measured
    # interval cannot be distorted by system clock adjustments.
    start = time.perf_counter()
    can_correct = estimate_adjustment_matrices(data)
    tot_seconds += time.perf_counter() - start
    logger.info("Adjustment parameters are estimated.")

    # Select dense matrix (not included in the reported time).
    keyword = select_features(data, features)
    logger.info("Features are selected.")

    if can_correct:
        start = time.perf_counter()
        correct_batch_effects(data, keyword, features)
        tot_seconds += time.perf_counter() - start
        logger.info("Batch correction is finished. Time spent = %.2fs.", tot_seconds)