A generative scheme for a single-cell count matrix with multiplicative batch effects

We often encounter single-cell expression data consisting of multiple batches. One of the primary goals is to identify cell types (clusters/factors) and cell-type-specific gene expression patterns. However, distinguishing batch-specific from cell-type-specific genes by a factorization method alone is challenging, and the two are often not identifiable from the data. For each gene $g$ and cell $j$, we assume the expression $Y_{gj}$ is sampled from a Poisson distribution with rate parameter

$$\lambda_{gj} = \lambda_{gj}^{\textsf{unbiased}} \times \prod_{k} \delta_{gk}^{X_{kj}},$$

where the unbiased rate is scaled by the batch effects $\delta_{gk}$. More formally, let $X_{kj}$ be a batch membership matrix that assigns cell $j$ to batch $k$ if and only if $X_{kj}=1$, and write the unbiased rate as $\lambda_{gj}^{\textsf{unbiased}} = \sum_{t} \beta_{gt}\theta_{jt}$. We then assume the batch effects act linearly on the average gene expression in the log-transformed space:

$$\mathbb{E}\!\left[\ln Y_{gj}\right] = \ln \left( \sum_{t} \beta_{gt} \theta_{jt} \right) + \sum_{k} \ln\delta_{gk} X_{kj}.$$


m <- 500 # genes
n <- 1000 # cells
nb <- 2 # batches

## 1. batch membership
X <- matrix(0, n, nb)
batch <- sample(nb, n, replace = TRUE)
for(b in 1:nb){
    X[batch == b, b] <- 1
}

## 2. batch effects (log scale; column j holds the effects of cell j's batch)
W.true <- matrix(rnorm(m*nb), m, nb)
ln.delta <- apply(W.true %*% t(X), 2, scale)

## 3. true effects
K <- 5
.beta <- matrix(rgamma(m * K, 1), m, K)
.theta <- matrix(rgamma(n * K, 1), n, K)
lambda.true <- .beta %*% t(.theta)
kk <- apply(.theta, 1, which.max)

## observed gene x cell Poisson counts with multiplicative batch effects
lambda <- lambda.true * exp(ln.delta)
yy <- apply(lambda, 2, function(l) sapply(l, rpois, n=1))
oo <- order(apply(t(.theta), 2, which.max))

If we can accurately estimate the true batch effect matrix $\delta_{gk}$, adjusting for the differences between batches is straightforward. How can we identify the true batch effect $\delta_{gk}$ for all the genes $g$ specifically expressed in batch $k$? Suppose we match cells $i$ and $j$ in the same cell state, sampled from batches $a$ and $b$ respectively. We then expect the batch-specific difference ($\delta_{ga} \neq \delta_{gb}$) to persist and even stand out, while the difference originating from cell types vanishes. This problem is equivalent to estimating the potential outcome of gene expression in each batch $k$, $\mathbb{E}\!\left[Y_{gj}^{(k)}\right]$.
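Here is a small numeric illustration of this argument. It uses expected rates rather than sampled counts, and the profiles and batch effects are made up. For two cells of the same type from batches $a$ and $b$, the log ratio of their rates is exactly $\ln\delta_{ga} - \ln\delta_{gb}$. For two cells of different types within the same batch, the batch effect cancels instead.

```r
# Hypothetical unbiased profiles (expected rates) of two cell types, 3 genes
type1 <- c(5, 1, 10)
type2 <- c(1, 8, 2)
# Hypothetical batch effects of batches a and b
delta.a <- c(2, 1, 0.5)
delta.b <- c(0.5, 1, 1)

# Matched pair: same cell type, different batches -> only batch effects differ
cell.i <- type1 * delta.a
cell.j <- type1 * delta.b
matched.lr <- log(cell.i / cell.j)      # equals log(delta.a) - log(delta.b)

# Unmatched pair: different cell types, same batch a -> batch effects cancel
cell.k <- type2 * delta.a
unmatched.lr <- log(cell.i / cell.k)    # equals log(type1) - log(type2)
```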

A causal inference approach to identify batch effects

To dissect batch-specific effects in a causal inference (potential outcome) framework, we assume that the confounding variables $Q$ are well distributed across batches:

  • Overlap: $0 < p(X_{kj}=1 \mid Q) < 1$ for all $k$.

Moreover, we assume these covariates are sufficient to make the potential (imputed) gene expression conditionally independent of the batch assignment mechanism:

  • Strong ignorability: $(Y^{(k)}, Y^{(k')}) \perp\!\!\!\perp X \mid Q$ for all pairs $k, k'$.
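The overlap assumption can be probed empirically. One way is to fit a propensity model for batch membership given $Q$ and check that the fitted probabilities stay away from 0 and 1. With two batches, $X_{2j} = 1 - X_{1j}$, so a single logistic regression covers both.

The sketch below uses simulated covariates `Q.sim` that are independent of batch by construction. The 0.05/0.95 thresholds are arbitrary illustrative choices, not part of the method described here.

```r
set.seed(1)
n.cells <- 300
Q.sim <- matrix(rnorm(n.cells * 3), n.cells, 3)  # hypothetical confounders (cell x covariate)
batch.sim <- rbinom(n.cells, 1, 0.5)              # 1 = batch 2; independent of Q here

# Propensity p(X_2j = 1 | Q) from a logistic regression
prop.fit <- glm(batch.sim ~ Q.sim, family = binomial())
propensity <- fitted(prop.fit)

# Cells with propensity near 0 or 1 indicate weak overlap (thresholds are arbitrary)
poor.overlap <- propensity < 0.05 | propensity > 0.95
summary(propensity)
sum(poor.overlap)
```

With more than two batches, a multinomial model would play the same role.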

Estimation of the batch effects by matching

Suppose we can counterfactually estimate the gene expression of a cell $j$ as if it had been measured in batches other than its observed batch $k$. One such estimate is a weighted average over cells $i$ from the other batches:

$$Z_{gj} = \frac{ \sum_{i} (1 - X_{ki}) w_{ji} Y_{gi} }{ \sum_{i} (1 - X_{ki}) w_{ji} },$$

where $k$ is the observed batch of cell $j$ and $w_{ji} \ge 0$ are matching weights between cells $j$ and $i$.
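The weighted average can be computed directly, as in the hedged base-R sketch below. The function name `counterfactual.z` is hypothetical, and its inputs are assumed to be:

- `Y`, a gene x cell count matrix;
- `X`, a batch x cell one-hot matrix, following the $X_{kj}$ convention above;
- `W`, a cell x cell matrix of nonnegative weights $w_{ji}$.

The factor $(1 - X_{ki})$ is implemented by zeroing the weights of cells in cell $j$'s own batch.

```r
# Z[g, j] = sum_i (1 - X[k(j), i]) W[j, i] Y[g, i] / sum_i (1 - X[k(j), i]) W[j, i]
counterfactual.z <- function(Y, X, W) {
  batch.of.cell <- apply(X, 2, which.max)        # observed batch k(j) of each cell
  Z <- matrix(NA_real_, nrow(Y), ncol(Y))
  for (j in seq_len(ncol(Y))) {
    w <- W[j, ] * (1 - X[batch.of.cell[j], ])     # keep only cells from other batches
    Z[, j] <- (Y %*% w) / sum(w)
  }
  Z
}

# Toy input: 2 genes, 4 cells; cells 1-2 in batch 1, cells 3-4 in batch 2
Y.cf <- matrix(c(1, 10,
                 3, 30,
                 5, 50,
                 7, 70), nrow = 2)
X.cf <- rbind(c(1, 1, 0, 0),
              c(0, 0, 1, 1))
W.cf <- matrix(1, 4, 4)                           # uniform weights for illustration
Z.cf <- counterfactual.z(Y.cf, X.cf, W.cf)
Z.cf
```

With uniform weights, each batch-1 cell receives the average of the two batch-2 cells, (6, 60), and vice versa, (2, 20). In practice `W` would be sparse, nonzero only for matched neighbours.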

Like many batch correction methods developed for single-cell RNA-seq analysis, we assume that $Z_{gj}$ reliably captures biologically relevant cell-state information while excluding the effects of the batch to which cell $j$ belongs.

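Consider the cells $\mathcal{C}_s$ pooled into a pseudo-bulk sample $s$ with sample-specific mean $\mu_{gs}$. Since each cell belongs to exactly one batch, $\sum_{k} \delta_{gk} X_{kj} = \prod_{k} \delta_{gk}^{X_{kj}}$, which matches the generative model above. We then have two likelihoods:

Observed likelihood: $\prod_{j \in \mathcal{C}_s} p(Y_{gj} \mid \mu_{gs}, \delta_{gk}, X_{kj}) = \prod_{j \in \mathcal{C}_s} \operatorname{Poisson}\!\left(Y_{gj} \,\middle|\, \mu_{gs} \sum_{k} \delta_{gk} X_{kj}\right)$

Counterfactual likelihood: $\prod_{j \in \mathcal{C}_s} p(Z_{gj} \mid \mu_{gs}, \gamma_{gs}) = \prod_{j \in \mathcal{C}_s} \operatorname{Poisson}\!\left(Z_{gj} \,\middle|\, \mu_{gs} \gamma_{gs}\right)$, where $\gamma_{gs}$ is a sample-specific scaling factor for the imputed expression.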

Local update: Maximize sample $s$-specific parameters

𝔼[ΞΌgs]β‰ˆβˆ‘jβˆˆπ’žsYgj+βˆ‘jβˆˆπ’žsZgjβˆ‘kΞ΄gknsk+nsΞ³gs\mathbb{E}\!\left[\mu_{gs}\right] \approx \frac{ \sum_{j \in \mathcal{C}_{s}} Y_{gj} + \sum_{j \in \mathcal{C}_{s}} Z_{gj} } {\sum_{k} \delta_{gk} n_{sk} + n_{s} \gamma_{gs}}

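Letting $p_{sk} = n_{sk} / n_{s}$ and dividing the numerator and denominator by $n_s$, we obtain $\mu_{gs} \gets \frac{ \bar{Y}_{gs} + \bar{Z}_{gs}} {\sum_{k} \delta_{gk} p_{sk} + \gamma_{gs}}$, where $\bar{Y}_{gs}$ and $\bar{Z}_{gs}$ are the per-cell averages of $Y_{gj}$ and $Z_{gj}$ over $\mathcal{C}_s$.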

If δgk→0\delta_{gk} \to 0 and psk=1p_{sk}=1, meaning that this sample ss is just sampled from the batch kk only, μgs→Y‾gs+Z‾gs\mu_{gs} \to \bar{Y}_{gs} + \bar{Z}_{gs} and Y‾gs→Y‾gsk=0\bar{Y}_{gs} \to \bar{Y}_{gsk} = 0. Therefore, μgs→Z‾gs\mu_{gs} \to \bar{Z}_{gs}.

Global update

𝔼[Ξ΄gk]β‰ˆβˆ‘sβˆ‘jβˆˆπ’žsXkjYgjβˆ‘sΞΌgsβˆ‘jβˆˆπ’žsXkj\mathbb{E}\!\left[\delta_{gk}\right] \approx \frac{\sum_{s} \sum_{j \in \mathcal{C}_{s}} X_{kj} Y_{gj}}{\sum_{s} \mu_{gs} \sum_{j \in \mathcal{C}_{s}} X_{kj}}

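Writing $\bar{Y}_{gsk} = \sum_{j \in \mathcal{C}_s} X_{kj} Y_{gj} / n_{sk}$ for the average over batch-$k$ cells in sample $s$:

$$\delta_{gk} \gets \frac{\sum_{s} \bar{Y}_{gsk} n_{sk}} {\sum_{s} \mu_{gs} n_{sk}}$$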

If $\bar{Y}_{gsk} \to \mu_{gs}$ for all $s$, then $\delta_{gk} \to 1$. If $\bar{Y}_{gsk} < \mu_{gs}$ for all $s$, then $\delta_{gk} < 1$. If $\bar{Y}_{gsk} \to 0$ for all $s$, then $\delta_{gk} \to 0$.
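The global update pools sufficient statistics across pseudo-bulk samples. The sketch below uses a hypothetical helper `delta.update` for one gene, with made-up cell counts $n_{sk}$ (samples in rows, batches in columns) and current means $\mu_{gs}$. It reproduces the limiting case $\bar{Y}_{gsk} = \mu_{gs}$ and also a batch whose averages are uniformly halved.

```r
# Hypothetical statistics for one gene: 3 pseudo-bulk samples (rows) x 2 batches (columns)
counts.sk <- rbind(c(20, 5),
                   c(10, 10),
                   c(0, 30))     # n_sk; sample 3 has no batch-1 cells
mu.s <- c(2, 4, 1)              # current mu_gs for the 3 samples

# Global update: delta_gk = sum_s ybar_gsk n_sk / sum_s mu_gs n_sk
delta.update <- function(ybar.sk, n.sk, mu.s) {
  colSums(ybar.sk * n.sk) / colSums(mu.s * n.sk)   # mu.s recycles down each column
}

# Batch averages equal to mu_gs in every sample -> delta_gk = 1
ybar.equal <- matrix(mu.s, nrow = 3, ncol = 2)
delta.equal <- delta.update(ybar.equal, counts.sk, mu.s)

# Batch-2 averages at half of mu_gs everywhere -> delta_g2 = 0.5
ybar.half <- matrix(c(mu.s, mu.s / 2), nrow = 3)
delta.half <- delta.update(ybar.half, counts.sk, mu.s)
```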


  1. Initialize the batch effect $\delta_{gk} \gets 1$ for each gene $g$ and batch $k$.

  2. Initialize $\gamma_{gs} \gets 1$ for each gene $g$ and sample $s$.

  3. Static global statistic: $S_{gk} \gets 0$.

  4. For each pseudo-bulk sample $s$ with cells $\mathcal{C}_{s}$:

    • $n_{sk} \gets \sum_{j \in \mathcal{C}_{s}} X_{kj}$, $n_{s} \gets \sum_{k} n_{sk}$, $p_{sk} \gets n_{sk}/n_{s}$

    • $\bar{Y}_{gs} \gets \sum_{j \in \mathcal{C}_{s}} Y_{gj} / n_{s}$

    • $\bar{Y}_{gsk} \gets \sum_{j \in \mathcal{C}_{s}} Y_{gj} X_{kj} / n_{sk}$ (0 if $n_{sk} = 0$)

    • $\bar{Z}_{gs} \gets \sum_{j \in \mathcal{C}_{s}} Z_{gj} / n_{s}$ after matching and imputation

    • $S_{gk} \gets S_{gk} + \bar{Y}_{gsk} n_{sk}$

  5. Iteratively updated global statistic: $T_{gk} \gets 0$.

  6. (Local step) For each pseudo-bulk sample $s$:

    • $\bar{\delta}_{gs} \gets \sum_{k} \delta_{gk} p_{sk}$

    • $\mu_{gs} \gets (\bar{Y}_{gs} + \bar{Z}_{gs}) / (\gamma_{gs} + \bar{\delta}_{gs})$

    • $\gamma_{gs} \gets \bar{Z}_{gs} / \mu_{gs}$

    • For each $k$: $T_{gk} \gets T_{gk} + \mu_{gs} n_{sk}$

  7. (Global step) For each batch $k$:

    • $\delta_{gk} \gets S_{gk} / T_{gk}$

  8. Repeat steps 5 to 7 until convergence.
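Steps 1 to 8 reduce to a few matrix products. The function below is a hedged base-R sketch of these updates, not the author's code. Its name `fit.batch.delta`, its arguments, and the fixed iteration count (used instead of a convergence test) are assumptions.

Only the ratios of the batch effects are identified, not their overall scale. Multiplying $\delta_{g\cdot}$ by a constant $c$ and dividing $\mu_{g\cdot}$ by $c$ leaves the observed likelihood unchanged, and $\gamma_{gs}$ absorbs the change in the counterfactual likelihood. The noiseless usage example therefore compares ratios.

```r
# Hypothetical helper implementing steps 1 to 8 with matrix operations.
#   Y, Z      : gene x cell observed counts and imputed counterfactuals
#   X         : cell x batch one-hot membership matrix
#   sample.id : pseudo-bulk sample index (1, 2, ...) of each cell
fit.batch.delta <- function(Y, Z, X, sample.id, n.iter = 500, eps = 1e-8) {
  n.sample <- max(sample.id)
  P <- outer(sample.id, seq_len(n.sample), `==`) * 1  # cell x sample indicator
  n.sk <- t(X) %*% P                                   # batch x sample counts n_sk
  n.s <- pmax(colSums(n.sk), 1)                        # cells per sample (avoid 0/0)
  p.sk <- sweep(n.sk, 2, n.s, `/`)                     # p_sk = n_sk / n_s
  ybar.gs <- sweep(Y %*% P, 2, n.s, `/`)               # step 4: sample means of Y
  zbar.gs <- sweep(Z %*% P, 2, n.s, `/`)               # step 4: sample means of Z
  S.gk <- Y %*% X                                      # step 4: S_gk = sum_s ybar_gsk n_sk

  delta.gk <- matrix(1, nrow(Y), ncol(X))              # step 1
  gamma.gs <- matrix(1, nrow(Y), n.sample)             # step 2
  for (iter in seq_len(n.iter)) {                      # step 8
    delta.bar <- delta.gk %*% p.sk                     # step 6: sum_k delta_gk p_sk
    mu.gs <- (ybar.gs + zbar.gs) / (gamma.gs + delta.bar + eps)
    gamma.gs <- zbar.gs / (mu.gs + eps)
    T.gk <- mu.gs %*% t(n.sk)                          # steps 5-6: T_gk = sum_s mu_gs n_sk
    delta.gk <- S.gk / (T.gk + eps)                    # step 7
  }
  list(delta = delta.gk, mu = mu.gs, gamma = gamma.gs)
}

# Usage on noiseless toy data: 3 genes, 4 samples of 10 cells, 2 batches
mu.true <- matrix(c(2, 5, 1, 3, 4, 2, 1, 2, 6, 4, 3, 5), nrow = 3)  # gene x sample
delta.true <- matrix(c(1, 2, 0.5, 0.5, 1, 3), nrow = 3)              # gene x batch
sample.id <- rep(1:4, each = 10)
batch.id <- unlist(lapply(list(c(8, 2), c(5, 5), c(3, 7), c(6, 4)),
                          function(nk) rep(1:2, nk)))
X.cells <- outer(batch.id, 1:2, `==`) * 1
Y.cells <- mu.true[, sample.id] * delta.true[, batch.id]  # expected counts
Z.cells <- mu.true[, sample.id]                           # stand-in counterfactuals
fit.toy <- fit.batch.delta(Y.cells, Z.cells, X.cells, sample.id, n.iter = 2000)
fit.toy$delta[, 2] / fit.toy$delta[, 1]   # compare with delta.true[, 2] / delta.true[, 1]
```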

A toy example

## 1. project
K <- 5
R <- matrix(rnorm(m * K), K, m)
Q.raw <- R %*% yy # K x n

Before regressing out batch membership, the random projections are strongly correlated with the batch indicators:

cor(t(Q.raw), X)
           [,1]       [,2]
[1,]  0.7617260 -0.7617260
[2,]  0.8283630 -0.8283630
[3,]  0.8099248 -0.8099248
[4,] -0.7250199  0.7250199
[5,]  0.6651915 -0.6651915

## 2. regress out
##  X theta = X inv(X'X) X' Y
##          = U D V' V inv(D^2) V' (U D V')' Y
##          = U inv(D) V' V D U' Y
##          = U U' Y

x.svd <- svd(X)
U <- x.svd$u
U.t <- t(x.svd$u)

Q.t <- t(Q.raw)
Q.t <- Q.t - U %*% U.t %*% Q.t
Q <- t(Q.t)
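The comments above derive that the fitted batch component $X(X^\top X)^{-1}X^\top Y$ equals $UU^\top Y$ for the thin SVD $X = UDV^\top$ of a full-column-rank $X$. Below is a quick numerical check on a small random example; the `.chk` objects are hypothetical and separate from the toy data. The one-hot columns of $X$ sum to one, so the residuals are also centered. This is why their correlation with each batch indicator is zero up to rounding.

```r
set.seed(3)
batch.chk <- rep(1:3, length.out = 12)                # 12 cells in 3 batches
X.chk <- outer(batch.chk, 1:3, `==`) * 1               # cell x batch one-hot
Y.chk <- matrix(rnorm(12 * 4), 12, 4)                  # cell x feature matrix

U.chk <- svd(X.chk)$u
hat.svd <- U.chk %*% t(U.chk) %*% Y.chk                # U U' Y
hat.ols <- X.chk %*% solve(crossprod(X.chk), crossprod(X.chk, Y.chk))  # X (X'X)^{-1} X' Y

# Residuals after regressing out batch are orthogonal to X and centered
resid.chk <- Y.chk - hat.svd
max(abs(crossprod(X.chk, resid.chk)))
```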

After regressing out batch membership, the correlations vanish up to machine precision:

cor(Q.t, X)
              [,1]          [,2]
[1,] -2.276319e-16  2.276319e-16
[2,] -5.524025e-16  5.524025e-16
[3,] -2.934286e-16  2.934286e-16
[4,]  3.075415e-16 -3.075415e-16
[5,] -3.636223e-16  3.636223e-16

q.svd <- svd(Q)

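## 3. sorting: bin cells into pseudo-bulk samples by the sign pattern
##    of the K right singular vectors (up to 2^K samples)
B <- (sign(q.svd$v) + 1)/2                                  # cell x K binary code
ss <- apply(sweep(B, 2, 2^(seq(0,K-1)), `*`), 1, sum) + 1  # binary code -> sample index
feat.dn <- apply(Q, 2, function(x) x / sqrt(sum(x^2)))      # unit-norm cell features (K x n)
knn <- 3                                                    # neighbours per other batch
d <- nrow(feat.dn)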

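## a. construct a nearest-neighbour dictionary (Annoy, angular distance) for each batch
library(RcppAnnoy)
dict.list <- lapply(sort(unique(batch)),
                    function(b) { new(AnnoyAngular, d) })

for(j in 1:length(batch)){
    b <- batch[j]
    dict.list[[b]]$addItem(j, feat.dn[,j])
}

for(dd in dict.list){
    dd$build(50)  # an index must be built before querying; 50 trees is an assumed setting
}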

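## b. a simplified routine to retrieve and estimate counterfactual y
.counterfactual <- function(j){
    v <- feat.dn[,j]

    nn <- c()
    dd <- c()

    ## nearest neighbours of cell j within every other batch
    for(k in 1:nb){
        if(k == batch[j]) next
        .nn <- dict.list[[k]]$getNNsByVector(v, knn)
        .dd <- apply(feat.dn[, .nn, drop = FALSE], 2, function(u) sum((u - v)^2))
        nn <- c(nn, .nn)
        dd <- c(dd, .dd)
    }

    ## normalized exp(-distance) weights; the shift by min(dd) cancels after normalization
    w <- exp(-(dd - min(dd)))
    w <- w/sum(w)

    ## weighted average of the matched cells' counts (gene x 1)
    yy[, nn, drop = FALSE] %*% matrix(w, ncol=1)
}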
ngene <- nrow(yy)
nbatch <- ncol(X)
nsample <- max(ss)

.delta.db <- matrix(1, ngene, nbatch)       # gene x batch effects
.delta.num.db <- matrix(0, ngene, nbatch)   # gene x batch numerators
.delta.denom.db <- matrix(0, ngene, nbatch) # gene x batch denominators
.prob.bs <- matrix(0, nbatch, nsample)      # batch x sample probabilities
.n.bs <- matrix(0, nbatch, nsample)         # batch x sample freq
.ybar.ds <- matrix(0, ngene, nsample)       # gene x sample observed average
.zbar.ds <- matrix(0, ngene, nsample)       # gene x sample imputed average
.mu.ds <- matrix(1, ngene, nsample)         # gene x sample adjusted average

## Precalculate some statistics
for(s in 1:nsample){
    if(sum(ss == s) < 1) next

    .yy <- yy[, ss == s, drop = FALSE]
    .zz <- do.call(cbind, lapply(which(ss == s), .counterfactual))

    .ybar.ds[,s] <- apply(.yy, 1, mean)
    .zbar.ds[,s] <- apply(.zz, 1, mean)
    .prob.bs[,s] <- colMeans(X[ss == s, , drop = FALSE])
    .n.bs[,s] <- colSums(X[ss == s, , drop = FALSE])

    .y.dsb <- .yy %*% X[ss == s, , drop = FALSE]
    .delta.num.db <- .delta.num.db + .y.dsb
}

.gamma.ds <- matrix(1, ngene, nsample)

## Alternate the local (mu, gamma) and global (delta) updates
for(iter in 1:100){
    .mu.ds <- (.ybar.ds + .zbar.ds) / (.delta.db %*% .prob.bs + .gamma.ds + 1e-8)
    .gamma.ds <- .zbar.ds / (.mu.ds + 1e-8)
    .delta.db <- .delta.num.db / (.mu.ds %*% t(.n.bs) + 1e-8)
}

Can we recover the original batch effects?

plot(.delta.db[,1], W.true[,1], pch=19, xlab="estimated delta", ylab="true delta effect", main="batch1")
plot(.delta.db[,2], W.true[,2], pch=19, xlab="estimated delta", ylab="true delta effect", main="batch2")

Are they independent of the cell type effects?

y.true <- sweep(lambda.true %*% X, 2, colSums(X), `/`)
plot(.delta.db[,1], y.true[,1], pch=19, xlab="estimated delta", ylab="true y mean", main="batch1")
plot(.delta.db[,2], y.true[,2], pch=19, xlab="estimated delta", ylab="true y mean", main="batch2")

Can we recover the unbiased cell-type effects by adjusting for the estimated batch effects? Before adjustment:

ybar <- sweep(yy %*% X, 2, colSums(X), `/`)
plot(ybar[,1], y.true[,1], pch=19, xlab="sample mean", ylab="true y mean", log="x", main="batch1")
plot(ybar[,2], y.true[,2], pch=19, xlab="sample mean", ylab="true y mean", log="x", main="batch2")

After dividing the counts by the estimated batch effects:

ybar.adj <- sweep((yy / .delta.db[, batch]) %*% X, 2, colSums(X), `/`)
plot(ybar.adj[,1], y.true[,1], pch=19, xlab="adjusted sample mean", ylab="true y mean", log="x", main="batch1")
plot(ybar.adj[,2], y.true[,2], pch=19, xlab="adjusted sample mean", ylab="true y mean", log="x", main="batch2")

.tsne <- Rtsne::Rtsne(log(1 + t(yy)), num_threads=4)$Y
plot(.tsne[,1], .tsne[,2], col=batch, pch=19, cex=.5, main="before batch adj",
     xlab = "tsne1", ylab = "tsne2")
legend("topleft", c("batch #1", "batch #2"), col=1:2, pch=19)
.tsne <- Rtsne::Rtsne(log(1 + t(yy/.delta.db[,batch])), num_threads=4)$Y
plot(.tsne[,1], .tsne[,2], col=batch, pch=19, cex=.5, main="after batch adj",
     xlab = "tsne1", ylab = "tsne2")