Invariant-Feature Subspace Recovery: A New Class of Provable Domain Generalization Algorithms.
CoRR (2023)
Abstract
Domain generalization asks for models trained over a set of training
environments to generalize well in unseen test environments. Recently, a series
of algorithms such as Invariant Risk Minimization (IRM) have been proposed for
domain generalization. However, Rosenfeld et al. (2021) show that in a simple
linear data model, even if non-convexity issues are ignored, IRM and its
extensions cannot generalize to unseen environments with fewer than $d_s+1$
training environments, where $d_s$ is the dimension of the spurious-feature
subspace. In this work, we propose Invariant-feature Subspace Recovery (ISR): a
new class of algorithms to achieve provable domain generalization across the
settings of classification and regression problems. First, in the binary
classification setup of Rosenfeld et al. (2021), we show that our first
algorithm, ISR-Mean, can identify the subspace spanned by invariant features
from the first-order moments of the class-conditional distributions, and
achieve provable domain generalization with $d_s+1$ training environments. Our
second algorithm, ISR-Cov, further reduces the required number of training
environments to $O(1)$ using the information of second-order moments. Notably,
unlike IRM, our algorithms bypass non-convexity issues and enjoy global
convergence guarantees. Next, we extend ISR-Mean to the more general setting of
multi-class classification and propose ISR-Multiclass, which leverages class
information and provably recovers the invariant-feature subspace with $\lceil
d_s/k\rceil+1$ training environments for $k$-class classification. Finally, for
regression problems, we propose ISR-Regression, which can identify the
invariant-feature subspace with $d_s+1$ training environments. Empirically, we
demonstrate the superior performance of our ISR algorithms on synthetic
benchmarks. Further, ISR can be used as a post-processing method for feature
extractors such as neural nets.
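
The ISR-Mean idea described above can be illustrated with a minimal NumPy sketch. It is not the authors' implementation. It assumes a linear data model in which the observation is x = R [z_c; z_s], the class-conditional mean of the invariant features z_c is the same in every environment, and the mean of the spurious features z_s changes across environments. Under that assumption the per-environment means of one class vary only along spurious directions, so the orthogonal complement of those directions serves as an estimate of the invariant-feature subspace. The names `isr_mean_projection` and `make_env`, the label encoding in {-1, +1}, and the unit noise scale are all assumptions made for illustration.

```python
import numpy as np


def make_env(rng, R, mu_c, mu_s, n, sigma=1.0):
    """Hypothetical generator for one environment of a linear data model.

    Invariant features have mean y * mu_c (shared by all environments);
    spurious features have mean y * mu_s (specific to this environment).
    """
    y = rng.choice([-1, 1], size=n)
    z_c = y[:, None] * mu_c + sigma * rng.standard_normal((n, len(mu_c)))
    z_s = y[:, None] * mu_s + sigma * rng.standard_normal((n, len(mu_s)))
    x = np.hstack([z_c, z_s]) @ R.T  # mix latent features into observations
    return x, y


def isr_mean_projection(xs, ys, d_spurious):
    """Sketch of subspace recovery from first-order moments.

    xs: list of (n_e, d) arrays, one per training environment.
    ys: list of (n_e,) label arrays with values in {-1, +1} (assumed encoding).
    Returns a (d - d_spurious, d) matrix whose orthonormal rows span the
    estimated invariant-feature subspace.
    """
    # One mean of the positive class per environment.
    means = np.stack([x[y == 1].mean(axis=0) for x, y in zip(xs, ys)])
    # Centering removes the component shared by all environments.
    centered = means - means.mean(axis=0)
    # Leading right singular vectors span the directions in which the
    # class-conditional means change, i.e. the estimated spurious directions.
    _, _, vt = np.linalg.svd(centered, full_matrices=True)
    # The remaining right singular vectors form their orthogonal complement.
    return vt[d_spurious:]
```

With `P = isr_mean_projection(xs, ys, d_spurious)`, the features `x @ P.T` can be passed to any standard classifier. This matches the post-processing use mentioned above, where `x` would be the output of a feature extractor. In this sketch, `d_spurious + 1` environments with means in general position are needed for the centered mean matrix to reach rank `d_spurious`.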
Keywords
generalization, provable domain