Figure: (a) PromptAlign matches the distribution statistics \(\mu_l(\mathcal{T} ; p)\), \(\sigma^2_l(\mathcal{T} ; p)\), obtained from multiple augmented views of a single test sample, with the source data distribution statistics \(\hat{\mu}_l\), \(\hat{\sigma}^2_l\). This effectively brings the test sample closer to the distribution of the source data, where the domain shift is denoted by \(\Delta_1 \rightarrow \Delta_2\). \(\mathcal{T}\) denotes the distribution of the test sample, \(p\) represents the prompts that are updated, and \(l\) refers to the vision-backbone layers. (b) Owing to the distribution matching via prompts, PromptAlign surpasses the existing state-of-the-art prompt learning approaches on 8 out of 10 datasets in cross-dataset generalization benchmarks.
The promising zero-shot generalization of vision-language models such as CLIP has led to their adoption via prompt learning for numerous downstream tasks. Previous works have shown test-time prompt tuning using entropy minimization to adapt text prompts to unseen domains. While effective, this overlooks the key cause of performance degradation on unseen domains: distribution shift. In this work, we explicitly handle this problem by aligning the out-of-distribution (OOD) test sample statistics to those of the source data using prompt tuning. We use a single test sample to adapt multi-modal prompts at test time by minimizing the feature distribution shift, thereby bridging the gap in the test domain. On the domain generalization benchmark, our method improves zero-shot top-1 accuracy beyond existing prompt-learning techniques, with a 3.08% improvement over the baseline MaPLe. In cross-dataset generalization with unseen categories across 10 datasets, our method improves by 1.82% compared to the existing state-of-the-art. Our source code and models will be publicly released.
PromptAlign explicitly aligns the token distribution statistics of each test sample with those of the source data. The source data statistics are computed offline, using ImageNet as a proxy source dataset \(\mathcal{D}\) with the vision encoder parameters \(\theta_v\):
\begin{align} \hat{\mu}_{l} = \mu_{l}(\mathcal{D}, \theta_v) \quad \text{and} \quad \hat{\sigma}^2_{l} = \sigma^2_{l}(\mathcal{D}, \theta_v) \label{eq:source-stats} \end{align}
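As a rough PyTorch sketch of this offline step (not the authors' released implementation): the function below makes one pass over a proxy source loader and accumulates per-layer, per-dimension token means and variances. The `return_layer_tokens` hook and the pooling over the batch and token dimensions are illustrative assumptions.

```python
import torch

@torch.no_grad()
def source_token_stats(encoder, loader, num_layers, device="cuda"):
    # One offline pass over the proxy source dataset (e.g. an ImageNet loader),
    # accumulating running sums for the per-layer token statistics.
    sums = [0.0] * num_layers
    sq_sums = [0.0] * num_layers
    count = 0
    for images, _ in loader:
        # `return_layer_tokens` is an assumed hook exposing the token embeddings
        # at every layer as a list of [batch, tokens, dim] tensors.
        layer_tokens = encoder(images.to(device), return_layer_tokens=True)
        for l, x in enumerate(layer_tokens):
            sums[l] = sums[l] + x.sum(dim=(0, 1))
            sq_sums[l] = sq_sums[l] + (x ** 2).sum(dim=(0, 1))
        count += layer_tokens[0].shape[0] * layer_tokens[0].shape[1]
    mu_hat = [s / count for s in sums]
    var_hat = [sq / count - m ** 2 for sq, m in zip(sq_sums, mu_hat)]  # E[x^2] - E[x]^2
    return mu_hat, var_hat
```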
At test time, multiple augmented views of the sample are passed through the CLIP model and the token distribution statistics -- mean and variance -- are computed as in Eq.\ref{eq:test-stats} and Eq.\ref{eq:test-stats2}, where \(\tilde{X}^p_{l, \mathrm{x}}\) denotes the prompted token embeddings at layer \(l\) for augmented view \(\mathrm{x}\), \(\mathcal{H}(X)\) is the set of augmented views of the test sample \(X\), and \(N_k = |\mathcal{H}(X)|\) is the number of views.
\begin{align} \mu_{l}(\mathcal{T} ; p) = \frac{1}{N_k} \sum_{\mathrm{x} \in \mathcal{H}(X)} \tilde{X}^p_{l, \mathrm{x}} \label{eq:test-stats} \end{align}
\begin{align} \sigma^2_{l}(\mathcal{T} ; p) = \frac{1}{N_k} \sum_{\mathrm{x} \in \mathcal{H}(X)} \Big( \tilde{X}^p_{l, \mathrm{x}} - \mu_{l}(\mathcal{T} ; p) \Big)^2 , \label{eq:test-stats2} \end{align}
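The per-sample statistics can then be computed directly from the augmented views. A minimal sketch, assuming the per-layer prompted token embeddings of the \(N_k\) views are stacked along the first dimension:

```python
import torch

def test_sample_stats(layer_tokens):
    # layer_tokens: list over transformer layers, one [N_k, tokens, dim]
    # tensor of prompted token embeddings per layer.
    mus, variances = [], []
    for x in layer_tokens:
        mu = x.mean(dim=0)                             # average over the N_k views (Eq. test-stats)
        variances.append(((x - mu) ** 2).mean(dim=0))  # biased variance (Eq. test-stats2)
        mus.append(mu)
    return mus, variances
```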
The distribution alignment loss is computed between the offline-computed source data statistics and the test sample statistics, across the transformer layers and for all tokens, as in Eq.\ref{eq:align-loss}. The resulting alignment loss is combined with the entropy loss to update the multi-modal prompts. For each sample, a single prompt update is performed, after which the prompts are reset to their original values for the next sample.
\begin{align} \mathcal{L}_{\text{align}} = \frac{1}{L}\sum_{l=1}^{L} \bigg( \| {\mu}_{l}(\mathcal{T} ; {p}) - {\hat{\mu}}_{l} \|_1 + \| {\sigma^2}_{l}(\mathcal{T} ; {p}) - {\hat{\sigma}^2}_{l}\|_1 \bigg). \label{eq:align-loss} \end{align}
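Putting the pieces together, the test-time loop below is a hedged PyTorch sketch of this procedure, not the released implementation. It assumes a `model` whose forward returns both logits and per-layer token embeddings, an `augment` helper producing the \(N_k\) views, the `test_sample_stats` helper from the sketch above, prompt parameters identifiable by name, and an illustrative loss weight `beta`.

```python
import torch
import torch.nn.functional as F

def prompt_align_step(model, image, mu_hat, var_hat, beta=1.0, lr=5e-4):
    # Collect the learnable prompt parameters (the name filter is an assumption).
    prompts = [p for n, p in model.named_parameters() if "prompt" in n]
    saved = [p.detach().clone() for p in prompts]   # snapshot to reset afterwards
    optimizer = torch.optim.AdamW(prompts, lr=lr)

    views = augment(image)                          # N_k augmented views (assumed helper)
    logits, layer_tokens = model(views)             # assumed to also return per-layer tokens
    mu, var = test_sample_stats(layer_tokens)

    # L_align: mean over layers of the (mean-reduced) L1 gaps in means and
    # variances, following Eq. align-loss.
    l_align = sum(
        F.l1_loss(m, mh) + F.l1_loss(v, vh)
        for m, mh, v, vh in zip(mu, mu_hat, var, var_hat)
    ) / len(mu)

    # Entropy of the marginal prediction over the views (view selection omitted).
    probs = logits.softmax(dim=-1).mean(dim=0)
    l_ent = -(probs * probs.clamp_min(1e-8).log()).sum()

    optimizer.zero_grad()
    (l_ent + beta * l_align).backward()
    optimizer.step()                                # the single prompt update

    with torch.no_grad():
        pred = model(views)[0].softmax(dim=-1).mean(dim=0)  # predict with adapted prompts
        for p, p0 in zip(prompts, saved):                   # then restore the originals
            p.copy_(p0)
    return pred
```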
We evaluate our distribution alignment strategy, PromptAlign, in domain generalization and cross-dataset generalization settings. For domain generalization, we evaluate on the ImageNet variants, as well as on the ImageNet variants of the recently released PUG dataset.
Top-1 accuracy (%) on the ImageNet domain generalization variants:

| Method | ImageNet-V2 | ImageNet-Sketch | ImageNet-A | ImageNet-R | OOD Avg. |
| --- | --- | --- | --- | --- | --- |
| MaPLe | 64.07 | 49.15 | 50.90 | 76.98 | 60.28 |
| MaPLe+TPT | 64.87 | 48.16 | 58.08 | 78.12 | 62.31 |
| PromptAlign | 65.29 | 50.23 | 59.37 | 79.33 | 63.55 |
Top-1 accuracy (%) on the cross-dataset generalization benchmark:

| Method | Caltech | Pets | Cars | Flowers | Food101 | Aircraft | SUN397 | DTD | EuroSAT | UCF101 | Average |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| CLIP | 93.35 | 88.25 | 65.48 | 67.44 | 83.65 | 23.67 | 62.59 | 44.27 | 42.01 | 65.13 | 63.58 |
| CLIP+TPT | 94.16 | 87.79 | 66.87 | 68.98 | 84.67 | 24.78 | 65.50 | 47.75 | 42.44 | 68.04 | 65.10 |
| CoOp | 93.70 | 89.14 | 64.51 | 68.71 | 85.30 | 18.47 | 64.15 | 41.92 | 46.39 | 66.55 | 63.88 |
| CoCoOp | 93.79 | 90.46 | 64.90 | 70.85 | 83.97 | 22.29 | 66.89 | 45.45 | 39.23 | 68.44 | 64.63 |
| MaPLe | 93.53 | 90.49 | 65.57 | 72.23 | 86.20 | 24.74 | 67.01 | 46.49 | 48.06 | 68.69 | 66.30 |
| MaPLe+TPT | 93.59 | 90.72 | 66.50 | 72.37 | 86.64 | 24.70 | 67.54 | 45.87 | 47.80 | 69.19 | 66.50 |
| PromptAlign | 94.01 | 90.76 | 68.50 | 72.39 | 86.65 | 24.80 | 67.54 | 47.24 | 47.86 | 69.47 | 66.92 |
@InProceedings{samadh2023align,
author = {Samadh, Jameel Hassan Abdul and Gani, Hanan and Hussein, Noor Hazim and Khattak, Muhammad Uzair and Naseer, Muzammal and Khan, Fahad and Khan, Salman},
title = {Align Your Prompts: Test-Time Prompting with Distribution Alignment for Zero-Shot Generalization},
booktitle = {Thirty-seventh Conference on Neural Information Processing Systems},
year = {2023}
}