Align Your Prompts: Test-Time Prompting with Distribution Alignment for Zero-Shot Generalization

¹Mohamed Bin Zayed University of Artificial Intelligence
  ²Linköping University   ³Australian National University
NeurIPS 2023

"Prompt Align"

Figure: PromptAlign concept.

(a) PromptAlign matches the distribution statistics \(\mu_l(\mathcal{T} ; p)\) and \(\sigma^2_l(\mathcal{T} ; p)\), obtained from multiple augmented views of a single test sample, with the source data distribution statistics \(\hat{\mu}_l\) and \(\hat{\sigma}^2_l\). This effectively brings the test sample closer to the distribution of the source data, where the domain shift is denoted by ▲1 → ▲2. \(\mathcal{T}\) denotes the distribution of the test sample, \(p\) represents the prompts that are updated, and \(l\) refers to the vision-backbone layers. (b) Owing to the distribution matching via prompts, PromptAlign surpasses the existing state-of-the-art prompt learning approaches on 8 out of 10 datasets in cross-dataset generalization benchmarks.

Abstract

The promising zero-shot generalization of vision-language models such as CLIP has led to their adoption, via prompt learning, in numerous downstream tasks. Previous works have used test-time prompt tuning with entropy minimization to adapt text prompts to unseen domains. While effective, this overlooks a key cause of performance degradation on unseen domains: distribution shift. In this work, we explicitly handle this problem by aligning the out-of-distribution (OOD) test sample statistics to those of the source data using prompt tuning. We use a single test sample to adapt multi-modal prompts at test time, minimizing the feature distribution shift to bridge the gap to the test domain. On the domain generalization benchmark, our method improves zero-shot top-1 accuracy beyond existing prompt-learning techniques, with a 3.08% improvement over the baseline MaPLe. In cross-dataset generalization with unseen categories across 10 datasets, our method improves by 1.82% over the existing state of the art. Our source code and models will be publicly released.

PromptAlign Design

Figure: Architecture and design of PromptAlign.

PromptAlign explicitly aligns the token distribution statistics of each test sample with the source data statistics. The source data statistics are computed offline, using ImageNet as a proxy dataset:

\begin{align} \hat{\mu}_{l} = \mu_{l}(\mathcal{D} ; \theta_v) \quad \text{and} \quad \hat{\sigma}^2_{l} = \sigma^2_{l}(\mathcal{D} ; \theta_v), \label{eq:source-stats} \end{align}

where \(\mathcal{D}\) denotes the proxy dataset and \(\theta_v\) the parameters of the frozen vision encoder.
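To make this concrete, below is a minimal PyTorch-style sketch of the offline statistics computation. It is an illustration rather than the released implementation: the `token_embeddings` hook on the vision encoder is a hypothetical accessor, and we assume the statistics are per embedding dimension, accumulated over all samples and token positions.

```python
import torch

@torch.no_grad()
def compute_source_stats(vision_encoder, loader, num_layers, dim, device="cuda"):
    """Offline estimate of the per-layer token mean and variance on the
    proxy dataset (ImageNet in the paper)."""
    s1 = torch.zeros(num_layers, dim, device=device)  # running sum
    s2 = torch.zeros(num_layers, dim, device=device)  # running sum of squares
    n = 0
    for images, _ in loader:
        # `token_embeddings` is a hypothetical hook that returns a list of
        # [B, T, D] token embeddings, one entry per transformer layer.
        feats = vision_encoder.token_embeddings(images.to(device))
        for l, x in enumerate(feats):
            s1[l] += x.sum(dim=(0, 1))
            s2[l] += (x ** 2).sum(dim=(0, 1))
        n += feats[0].shape[0] * feats[0].shape[1]
    mu_hat = s1 / n                 # \hat{mu}_l, shape [L, D]
    var_hat = s2 / n - mu_hat ** 2  # \hat{sigma}^2_l, shape [L, D]
    return mu_hat, var_hat
```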

At test time, multiple augmented views of the sample are passed through the CLIP model, and the token distribution statistics -- mean and variance -- are computed as in Eq.\ref{eq:test-stats} and Eq.\ref{eq:test-stats2}, where \(\tilde{X}^{p}_{l, \mathrm{x}}\) denotes the prompted token embeddings at layer \(l\) for view \(\mathrm{x}\), and \(\mathcal{H}(X)\) is the set of \(N_k\) augmented views of the test sample \(X\).

\begin{align} \mu_{l}(\mathcal{T} ; p) = \frac{1}{N_k} \sum_{\mathrm{x} \in \mathcal{H}(X)} \tilde{X}^{p}_{l, \mathrm{x}} \label{eq:test-stats} \end{align}

\begin{align} \sigma^2_{l}(\mathcal{T} ; p) = \frac{1}{N_k} \sum_{\mathrm{x} \in \mathcal{H}(X)} \Big( \tilde{X}^{p}_{l, \mathrm{x}} - \mu_{l}(\mathcal{T} ; p) \Big)^2 , \label{eq:test-stats2} \end{align}
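A corresponding sketch for the test-time side, under the same assumed `token_embeddings` hook: the \(N_k\) augmented views are batched together, and the statistics keep the computation graph so gradients can reach the learnable prompts \(p\). For simplicity the sketch pools over views and token positions jointly, per embedding dimension.

```python
def compute_test_stats(vision_encoder, views):
    """Token statistics for a single test sample, estimated from its N_k
    augmented views. `views` is an [N_k, C, H, W] batch of augmentations."""
    feats = vision_encoder.token_embeddings(views)  # list of [N_k, T, D]
    # Biased variance (1/N_k normalization) to match the equation above.
    mu = torch.stack([x.mean(dim=(0, 1)) for x in feats])                  # [L, D]
    var = torch.stack([x.var(dim=(0, 1), unbiased=False) for x in feats])  # [L, D]
    return mu, var
```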

The distribution alignment loss is computed between the offline-computed source data statistics and the test sample statistics across the transformer layers, for all tokens, as in Eq.\ref{eq:align-loss}. The resulting alignment loss is combined with the entropy loss to update the multi-modal prompts. For each sample, a single update of the prompts is performed, and the prompts are reset to their original values before the next sample.

\begin{align} \mathcal{L}_{\text{align}} = \frac{1}{L}\sum_{l=1}^{L} \Big( \| \mu_{l}(\mathcal{T} ; p) - \hat{\mu}_{l} \|_1 + \| \sigma^2_{l}(\mathcal{T} ; p) - \hat{\sigma}^2_{l} \|_1 \Big). \label{eq:align-loss} \end{align}
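Putting the pieces together, here is a hedged sketch of the per-sample update: the layer-averaged L1 alignment loss is added to a TPT-style entropy objective (simplified here to the entropy of the view-averaged prediction), a single optimizer step updates the prompts, and the prompts are restored afterwards. The trade-off weight `beta`, the `prompt_state_dict`/`load_prompt_state_dict` accessors, and the `return_tokens` flag are assumptions for illustration, not the paper's API.

```python
import copy

def entropy_loss(logits):
    """Entropy of the averaged prediction over views (the confidence-based
    view selection of TPT is omitted for brevity)."""
    probs = logits.softmax(dim=-1).mean(dim=0)
    return -(probs * probs.clamp_min(1e-12).log()).sum()

def promptalign_step(model, views, mu_hat, var_hat, optimizer, beta=1.0):
    """One test-time adaptation step for a single sample, then a reset."""
    init_prompts = copy.deepcopy(model.prompt_state_dict())  # hypothetical accessor

    logits, feats = model(views, return_tokens=True)  # feats: list of [N_k, T, D]
    mu = torch.stack([x.mean(dim=(0, 1)) for x in feats])
    var = torch.stack([x.var(dim=(0, 1), unbiased=False) for x in feats])

    # Eq. (align-loss): mean over layers of the L1 distances between the
    # test-sample statistics and the offline source statistics.
    l_align = ((mu - mu_hat).abs().sum(-1) + (var - var_hat).abs().sum(-1)).mean()
    loss = entropy_loss(logits) + beta * l_align

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # single prompt update per test sample

    # Predict with the adapted prompts, then restore the original prompts.
    with torch.no_grad():
        pred = model(views[:1]).argmax(dim=-1)
    model.load_prompt_state_dict(init_prompts)  # hypothetical accessor
    return pred
```

Resetting the prompts after every sample keeps the adaptation episodic, so an update on one test sample cannot accumulate and bias predictions on later samples.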

Effect of Distribution Alignment

We evaluate our distribution alignment strategy, PromptAlign, in domain generalization and cross-dataset settings. For domain generalization, we evaluate on the ImageNet variants, as well as on the recently released PUG ImageNet variants.

Effect of token distribution alignment strategy for domain generalization (top-1 accuracy, %).

| Method | ImageNet-V2 | ImageNet-Sketch | ImageNet-A | ImageNet-R | OOD Avg. |
|---|---|---|---|---|---|
| MaPLe | 64.07 | 49.15 | 50.90 | 76.98 | 60.28 |
| MaPLe+TPT | 64.87 | 48.16 | 58.08 | 78.12 | 62.31 |
| PromptAlign | 65.29 | 50.23 | 59.37 | 79.33 | 63.55 |

Comparison of PromptAlign in cross-dataset evaluation (top-1 accuracy, %).

| Method | Caltech | Pets | Cars | Flowers | Food101 | Aircraft | SUN397 | DTD | EuroSAT | UCF101 | Average |
|---|---|---|---|---|---|---|---|---|---|---|---|
| CLIP | 93.35 | 88.25 | 65.48 | 67.44 | 83.65 | 23.67 | 62.59 | 44.27 | 42.01 | 65.13 | 63.58 |
| CLIP+TPT | 94.16 | 87.79 | 66.87 | 68.98 | 84.67 | 24.78 | 65.50 | 47.75 | 42.44 | 68.04 | 65.10 |
| CoOp | 93.70 | 89.14 | 64.51 | 68.71 | 85.30 | 18.47 | 64.15 | 41.92 | 46.39 | 66.55 | 63.88 |
| CoCoOp | 93.79 | 90.46 | 64.90 | 70.85 | 83.97 | 22.29 | 66.89 | 45.45 | 39.23 | 68.44 | 64.63 |
| MaPLe | 93.53 | 90.49 | 65.57 | 72.23 | 86.20 | 24.74 | 67.01 | 46.49 | 48.06 | 68.69 | 66.30 |
| MaPLe+TPT | 93.59 | 90.72 | 66.50 | 72.37 | 86.64 | 24.70 | 67.54 | 45.87 | 47.80 | 69.19 | 66.50 |
| PromptAlign | 94.01 | 90.76 | 68.50 | 72.39 | 86.65 | 24.80 | 67.54 | 47.24 | 47.86 | 69.47 | 66.92 |

BibTeX

@InProceedings{samadh2023align,
    author    = {Samadh, Jameel Hassan Abdul and Gani, Hanan and Hussein, Noor Hazim and Khattak, Muhammad Uzair and Naseer, Muzammal and Khan, Fahad and Khan, Salman},
    title     = {Align Your Prompts: Test-Time Prompting with Distribution Alignment for Zero-Shot Generalization},
    booktitle = {Thirty-seventh Conference on Neural Information Processing Systems},
    year      = {2023}
}