TL;DR
Class imbalance ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด์ ๋ฐ์ดํฐ์ ๊ฐ ํด๋์ค์ ์ ํจ ๋ฐ์ดํฐ ์๋ฅผ ์ ์ํ๊ณ ์ด๋ฅผ ํ์ฉํ re-weighting๊ธฐ๋ฐ Class Balance Loss ๊ธฐ๋ฒ ์ ์.
๋ฌด์จ ๋ฌธ์ ๋ฅผ ํ๊ณ ์๋?
๋ฅ๋ฌ๋ ๋ชจ๋ธ ํ์ต์ ์ฌ์ฉ๋๋ ์ผ๋ฐ์ ์ธ ๋ฐ์ดํฐ ์ (์๋ฅผ๋ค์ด CIFAR-10, 100, ImageNet ๋ฑ)์ด ํด๋์ค ๋ผ๋ฒจ ๋ถํฌ๊ฐ ๊ท ์ผํ ๊ฒ๊ณผ ๋ฌ๋ฆฌ, ์ค์ ์ํฉ์์๋ ๋ชจ๋ ํด๋์ค์ ๋ฐ์ดํฐ ์๊ฐ ๊ท ์ผํ๊ฒ ์์ง๋์ง ์๋, Long Tail ํ์์ด ๋ฐ์ํ๋ค. ์ฌ๊ธฐ์ Long Tail์ด๋ผ๊ณ ํจ์, ๊ฐ ํ์ต ๋ฐ์ดํฐ ์ ํด๋์ค ๋ณ ์ํ ์์ ๋ํ ๋ถํฌ๋ฅผ ๊ทธ๋ ธ์ ๋ ์๋ ๊ทธ๋ฆผ๊ณผ ๊ฐ์ด, ์์์ ํด๋์ค์ ๋ํด์ ๋ฐ์ดํฐ ์ํ ์๊ฐ ๋ง์ ๋ฐ ๋ฐํด (Head) ๋ค์์ ํด๋์ค์์ ๊ธฐ๋์น ์ดํ์ ์ํ ์๋ฅผ ๊ฐ๋ (Long Tail) ํ์์ ๋งํ๋ค. ๋ง ๊ทธ๋๋ก ํด๋์ค ์ํ ์์ ๋ถํฌ๋ฅผ ๊ทธ๋ ธ์๋ ๊ธด ๊ผฌ๋ฆฌ๋ฅผ ๊ฐ์ง๋ค๋ ๋ป์ด๋ค.
์์ฝํ์๋ฉด, ๋ณธ ๋ ผ๋ฌธ์์๋ ํ์ต ๋ฐ์ดํฐ ์ ๋ด์ ํด๋์ค ๋ถ๊ท ํ ๋ฌธ์ (Class Imbalance Problem) ๋ฅผ ๋ค๋ฃจ๊ณ ์๋ค.
๊ธฐ์กด ์ฐ๊ตฌ๋ค์ ์ด ๋ฌธ์ ๋ฅผ ์ด๋ป๊ฒ ๋ค๋ฃจ์๋?
๊ธฐ์กด ์ฐ๊ตฌ๋ค์ ๋ฌธ์ ํด๊ฒฐ ๋ฐฉ์์ ๋ฐ๋ผ ํฌ๊ฒ Re-sampling๋ฐฉ์๊ณผ Cost-sensitive re-weighting ๋ฐฉ์์ผ๋ก ๋๋๋ค.
- Re-Sampling
Re-sampling๋ฐฉ์์ ๋ง ๊ทธ๋๋ก ์ค์ ์๋ ๋ฐ์ดํฐ ์๋ฅผ ์ค๋ณตํด์ samplingํ๋ ๋ฐฉ์์ด๋ค. ์๋ฅผ ๋ค์ด ๋ฐ์ดํฐ ์๊ฐ ์ ์ ํด๋์ค์ ๋ํด์ ์ค๋ณตํด์ ๋ ์ํ๋ง์ ํ๊ฑฐ๋ (over-sampling) ์ํ ์๊ฐ ๋ง์ ํด๋์ค์ ๋ํด์ ๋ ์ํ๋ง ํ๊ฑฐ๋ (under-sampling) ํ๋ ํํ์ด๋ค. ์ด ๊ฒฝ์ฐ ์ํ ์ค๋ณต์ผ๋ก ์ธํ ๊ณผ์ ํฉ (overfitting)์ด ์ผ์ด๋ ์ ์๊ณ , over-sampling์ ํ๋ ๊ฒฝ์ฐ์ ์๋์ ์ผ๋ก ํ์ต ์๊ฐ์ด ์ฆ๊ฐํ๋ ์ด์ ์ญ์ ๊ณ ๋ ค๋์ด์ผํ๋ค. ์ด๋ฌํ ํ๊ณ์ ์ผ๋ก ์ธํด Re-sampling ๋ฐฉ์๋ณด๋ค๋ Re-weighting๋ฐฉ์์ด ์ข ๋ ์ ํธ๋๋ค.
- Re-weighting
Re-weighting ๋ฐฉ์์ ํด๋์ค ๋ณ ์ํ ์์ ๋ฐ๋ผ์ ํ์ต ์ ์์ค ํจ์์ ๊ฐ์ค์น๋ฅผ ์ฃผ๋ ๋ฐฉ์์ ๋งํ๋ค. ์๋ฅผ ๋ค์ด, ์๋์ ์ผ๋ก ์์์ ํด๋์ค์ ๋ํ loss ๊ฐ์ ๋ ๋์ ๊ฐ์ค์น๋ฅผ ์ฃผ๋ ํํ์ด๋ค. ์๋์ ์ผ๋ก ๊ฐ๋จํ๋ค๋ ์ฅ์ ์ผ๋ก ์ธํด ์ด ๋ฐฉ์์ด ๋ง์ด ์ฌ์ฉ๋์์ผ๋, large-scale ๋ฐ์ดํฐ ์ ์ ๋ํด์๋ ์ด๋ฌํ ๋จ์ํ ๋ฐฉ์์ด ์ ๋์ํ์ง ์๋ ๋ค๋ ์ฌ์ค์ด ํ์ธ๋์๋ค. ๋์ "smoothed weight" ์์ค ํจ์์ ๋ํ ์ฐ๊ตฌ๊ฐ ์งํ๋์๋ค. "smoothed" ๋ฒ์ ์ ๊ฒฝ์ฐ์๋ ํด๋์ค ๋ถํฌ์ ์ ๊ณฑ๊ทผ์ ๋ฐ๋น๋กํ๋๋ก weight๋ฅผ ์ค๋ค.
์ด ๋ ผ๋ฌธ์์๋ ์ด ๋ฌธ์ ๋ฅผ ์ด๋ป๊ฒ ํ๊ณ ์๋?
Conceptually,
๋ณธ ๋ ผ๋ฌธ์์๋ Re-weighting๊ธฐ๋ฒ์์ ์ฐฉ์ํ ๋ฐฉ๋ฒ์ ์ ์ํ๋ค. ์๋์ 2๊ฐ์ ํด๋์ค๋ฅผ ์๋ก ๋ค์ด ์ค๋ช ํด๋ณด๋ฉด, ์ผ์ชฝ์ Head์ ํด๋นํ๋ ๋ฐ์ดํฐ ์ํ์ด๊ณ ์ค๋ฅธ์ชฝ์ Long Tail์ ํด๋นํ๋ ๋ฐ์ดํฐ ์ํ์ด๋ผ๊ณ ํ์. ๊ธฐ์กด์ re-weighting ๋ฐฉ๋ฒ์ ์ผ์ชฝ ํด๋์ค ์ํ์ ์์ ์ค๋ฅธ์ชฝ ํด๋์ค ์ํ์ ์๋ฅผ ๊ฐ์ง๊ณ weighting์ ์ํํ์ง๋ง, ์ด ๊ฒฝ์ฐ ๊ฒ์์ ์ค์ ์ผ๋ก ๊ทธ๋ ค์ง classifier ์ ์ด ๋นจ๊ฐ์์ผ๋ก ์น์ฐ์น๊ฒ๋๋ค. ํ์ง๋ง ํด๋์ค ๋ด์ ๋ชจ๋ ์ํ๋ค์ด ์ ํจํ ์๋ฅผ ๊ตฌ์ฑํ๋ ๊ฒ์ ์๋๋ผ๋ ์ ์ ์๊ธฐํด์ผํ๋ค. ์๋ฅผ ๋ค์ด, ์ํ๋ค ๊ฐ์ ์ค๋ณต๋๋ ์ ๋ณด๋ฅผ ๊ฐ์ง๊ณ ์๋ ๋ฐ์ดํฐ๊ฐ ์์ ์๋ ์๊ธฐ ๋๋ฌธ์ ๋จ์ํ '์ํ์ ์'๋ฅผ loss์ weight๋ก ์ ํ๋ ๊ฒ์ ๋ฌธ์ ๊ฐ ๋ ์ ์๋ค. ๋ฐ๋ผ์ ๋ณธ ๋ ผ๋ฌธ์์๋ ๊ฐ ํด๋์ค ๋ณ '์ํ์ ์ ํจ ์ซ์ (effective number)'๋ฅผ ๊ตฌํ๋ ๋ฐฉ๋ฒ์ ์ ์ํ๊ณ ์์ผ๋ฉฐ, ์ด๋ ๊ฒ ๊ตฌํด์ง ์ ํจ ๊ฐ์์ ๋ฐ๋น๋กํ๋๋ก loss๋ฅผ re-weighting์ ํ๋ ๊ฒฝ์ฐ ๊ฒ์์ ์ ์ผ๋ก ์น์ฐ์น classifier๋ฅผ ์ฐ๋ฆฌ๊ฐ ์ํ๋ ํ๋์ ์ ๊ณผ ๊ฐ์ด ์กฐ์ ํ ์ ์๋ค.
์ ํจ ์ํ ์ (Effective Number of Samples)
๋ณธ ๋ ผ๋ฌธ์์๋ ๊ฐ ๋ฐ์ดํฐ ์ ์ ํด๋์ค ๋ณ ์ ํจ ์ํ ์๋ฅผ ์ด๋ป๊ฒ ๊ตฌํ๋์ง๋ฅผ ์ ์ํ๊ณ ์๋ค. ์ด๋ฅผ ์ํด ๋ ผ๋ฌธ์์๋ ์ ํจ ๊ฐ์ (Effective Number) ๋ฅผ ๋ค์๊ณผ ๊ฐ์ด ์ ์ํ๊ณ ์๋ค.
Definition 1 (Effective Number).
The effective number of samples is the expecterd volume of samples.
์ํ ํฌ๊ธฐ์ ๊ธฐ๋๊ฐ์ ์ํ์ ์ ํจ ๊ฐ์๋ก ์ ์ํ๊ณ ์๋ค.
๊ธฐ๋๊ฐ์ ์ ์ํ๊ธฐ ์ํด์ ๋ณธ ๋ ผ๋ฌธ์์๋ ์๋ก ์ํ๋งํ ๋ฐ์ดํฐ๊ฐ ์ง๊ธ๊น์ง ๊ด์ฐฐ๋ ์ํ๊ณผ ๊ฒน์น ํ๋ฅ ์ $p$, ๊ฒน์น์ง ์์ ํ๋ฅ ์ $1-p$๋ก ์์ ํ๋ค (์ด ๋, ๋ฌธ์ ๋ฅผ ๋จ์ํํ๊ธฐ ์ํด์ ์ผ๋ถ๋ง ๊ฒน์น๋ ๊ฒฝ์ฐ์ ๋ํด์๋ ๊ณ ๋ คํ๊ณ ์์ง ์์ต๋๋ค) ์ด๋ฅผ ๊ทธ๋ฆผ์ผ๋ก ๊ทธ๋ ค๋ณด๋ฉด ๋ค์๊ณผ ๊ฐ๋ค.
๋ถ์ฐ ์ค๋ช
๋ณธ ๋ ผ๋ฌธ์ ์์ด๋์ด๋ ํด๋์ค์ ๋ฐ์ดํฐ๋ฅผ ๋ ๋ง์ด ์ฌ์ฉํ ๋ marginal benefit์ด ์ค์ด๋๋ ์ ๋๋ฅผ ์ธก์ ํ๋ ๊ฒ์ด๋ค. ์๋ฅผ ๋ค์ด, ์ํ์ ์๊ฐ ์ฆ๊ฐํ๋ฉด ์๋ก ์ถ๊ฐ๋๋ ์ํ์ด ํ์ฌ ์กด์ฌํ๋ ์ํ๊ณผ ๊ฒน์น ํ๋ฅ ์ด ๋์์ง๋ค. (์ฌํํ ๋ฐ์ดํฐ ์ฆ๊ฐ - ํ์ , ํฌ๋กํ ๋ฑ - ์ผ๋ก ํ๋ณด๋ ๋ฐ์ดํฐ ์ญ์ ์ค๋ณต๋๋ ๋ฐ์ดํฐ๋ก ๋ณผ ์ ์์ต๋๋ค.)
์ด์ ์ํ์ ์ ํจ ๊ฐ์ (expected number or expected volume)๋ฅผ ์ํ์ ์ผ๋ก ํํํด๋ณผ ์ ์๋ค. ๋จผ์ ํด๋น ํด๋์ค์ ํผ์ณ ๊ณต๊ฐ์์ ๋ชจ๋ ๊ฐ๋ฅํ ๋ฐ์ดํฐ์ ์งํฉ์ $\mathcal{S}$๋ผ๊ณ ํ๊ณ , ์ด๋ฌํ ์งํฉ $S$์ ํฌ๊ธฐ๋ฅผ $N$์ด๋ผ๊ณ ๊ฐ์ ํ๋ค. ๊ทธ๋ฆฌ๊ณ ์ฐ๋ฆฌ๊ฐ ํํํ๊ณ ์ํ๋ ์ํ์ ์ ํจ ๊ฐ์๋ฅผ $E_n$์ผ๋ก ํ๊ธฐํ๊ฒ ์ต๋๋ค. ์ฌ๊ธฐ์ $n$์ ์ํ์ ์๋ฅผ ์๋ฏธํ๋ค.
๊ทธ๋ฌ๋ฉด ์ํ์ ์ ํจ๊ฐ์๋ ๋ค์๊ณผ ๊ฐ์ด ํ๊ธฐํ ์ ์๋ค.
Proposition 1 (Effective Number).
$E_n = (1-\beta^n)/(1-\beta)$, where $\beta = (N-1)/N$.
์ฆ๋ช ์ค๋ช
๊ท๋ฉ๋ฒ์ผ๋ก ์ด๋ฅผ ์ค๋ช ํด๋ณด๋ฉด ๋ค์๊ณผ ๊ฐ์ต๋๋ค. ๋จผ๋ $E_1 = 1$์ ๋ง์กฑํ๊ณ (overlap์ด ์๊ธฐ ๋๋ฌธ์) ์ด๋ $E_1 = (1-\beta^1)/(1-\beta)=1$์ ๋ง์กฑํฉ๋๋ค. ์ด์ ๊ณผ๊ฑฐ์ ์ํ๋ $n-1$๊ฐ์ ์์ ๋ก๋ถํฐ $n$๋ฒ์งธ ์ํ์ ์ํ๋งํ๋ค๊ณ ๊ฐ์ ํด๋ณด๊ฒ ์ต๋๋ค. ๊ทธ๋ฌ๋ฉด ์๋ก ์ํ๋ ๋ฐ์ดํฐ๊ฐ ์ด์ ์ ์ํ๋ค๊ณผ overlap๋ ํ๋ฅ ์ ์ ์ฒด ์งํฉ$S$์ ๊ฐ์ $N$์ค $E_{n-1}$์ผ ํ๋ฅ , ์ฆ $p=E_{n-1}/N$์ ๋ง์กฑํฉ๋๋ค. ๋ฐ๋ผ์ $n$๊ฐ์ ์ํ์ ๋ํ ์ ํจ ๊ฐ์์ ๊ธฐ๋๊ฐ์ $E_n = p \cdot E_{n-1} + (1-p)(E_{n-1} + 1) = 1 + \frac{N-1}{N}E_{n-1}$์ ๋ง์กฑํฉ๋๋ค. ๊ท๋ฉ๋ฒ์ ์ํด์ $E_{n-1} = (1-\beta^{n-1})/(1-\beta)$๋ฅผ ๋ง์กฑํ๋ค๊ณ ๊ฐ์ ํ์ ๋, $E_n = 1+\beta \frac{1-\beta^{n-1}}{1-\beta} = \frac{1-\beta+\beta-\beta^n}{1-\beta} = \frac{1-\beta^n}{1-\beta}$ ๋ฅผ ๋ง์กฑํจ์ ํ์ธํ ์ ์์ต๋๋ค.
Effective Number์ ๋ํ ์ถ๊ฐ์ ์ธ ์ดํด
๊ทธ๋ฆฌ๊ณ ์์ proposition์ผ๋ก๋ถํฐ ์ฐ๋ฆฌ๋ ์ํ์ ์ ํจ ๊ฐ์๊ฐ $n$์ ๋ํ ์ง์ํจ์์์ ํ์ธํ ์ ์์ต๋๋ค. (์ฐธ๊ณ ๋ก, $\beta$๋ 0์์ 1 ์ฌ์ด $[0,1)$ ํ๋ฅ ์ ์๋ฏธํฉ๋๋ค) ๊ทธ๋ฆฌ๊ณ ์ฐธ๊ณ ๋ก $\beta$๋ $n$์ด ์ฆ๊ฐํจ์๋ฐ๋ผ $E_n$์ด ์ผ๋ง๋ ๋นจ๋ฆฌ ์ฆ๊ฐํ๋์ง๋ฅผ ๊ฒฐ์ ํ๋ค๊ณ ๋ณผ ์ ์์ต๋๋ค. ๊ทธ๋ฌ๋ฉด class๋ฅผ ํํํ ์ ์๋ ์งํฉ $S$ ์ ์ ์ฒด ํฌ๊ธฐ $N$์ ์ด๋ป๊ฒ ๊ตฌํ ์ ์์๊น์? $E_n = (1-\beta^n)/(1-\beta)=\sum_{j=1}^{n} \beta^{j-1}$๋ก๋ถํฐ $N$์ ๋ค์์ ๊ทนํ์ผ๋ก ๊ตฌํ ์ ์์ต๋๋ค. $N = \lim_{n\rightarrow \infty} \sum_{j=1}^{n} \beta^{j-1} = 1/(1-\beta)$. ์ด๋ก๋ถํฐ $\beta = 0$์ด๋ฉด $E_n=1$ ์ ๋ง์กฑํ๊ณ $\beta$๊ฐ 1์ ๊ฐ๊น์์ง์๋ก $E_n \rightarrow n$๋ฅผ ๋ง์กฑํจ์ ์ ์ ์์ต๋๋ค.
Class-Balanced Loss
์์ ์ค๋ช ํ๋ฏ์ด, ๋ณธ ๋ ผ๋ฌธ์ ๊ธฐ์กด์ ์ฐ๊ตฌ๋ค ์ค loss์ class๋ณ re-weight๋ฅผ ์ฃผ๋ ๋ฐฉ๋ฒ๋ก ๋ค์ ํ๊ณ์ ์์ด๋์ด๋ฅผ Effective Number๋ผ๋ ๊ฐ๋ ์ ๋์ ํจ์ผ๋ก์จ ๊ฐ์ ํ๊ณ ์๋ค. ์ด ๋ ผ๋ฌธ์ ์ฅ์ ์ loss function์ agnostic ํ๋ค๋ ์ ์ด๋ค. ์ค์ ๋ก ๋ณธ ๋ ผ๋ฌธ์์๋ Cross-Entropy Loss (Softmax, Sigmoid)์ Focal Loss์ ๋ํ ์์ ๋ฅผ ํจ๊ป ์ ์ํ๊ณ ์๋ค.
์ ๋ ฅ $\boldsymbol{x}$์ ๋ผ๋ฒจ $y \in \{1, 2, \cdots, C$๊ฐ ์ฃผ์ด์ง๊ณ $C$๊ฐ์ ํด๋์ค๊ฐ ์๋ค๊ณ ๊ฐ์ ํ์ ๋, ๋ชจ๋ธ์ ์ถ์ class probability๋ฅผ $\boldsymbol{p} = [p_1, p_2, \cdots , p_C]^T$ ๋ก ์ฃผ์ด์ง๋ค๊ณ ๊ฐ์ ํ๋ค. ์ด ๋์ ์์ ํจ์๋ฅผ $\mathcal{L}(\boldsymbol{p}, y)$๋ฅผ ํ๊ธฐํ๋ค.
๊ทธ๋ฌ๋ฉด ํด๋์ค $i$์ ์ํ ์๋ฅผ $n_i$๋ผ๊ณ ํ์ ๋, ์์ Proposition์ ์ํด์ ํด๋น ํด๋์ค์ effective number๋ $E_{n_i} = (1-\beta_i^{n_i})(1-\beta_i)$๋ฅผ ๋ง์กฑํ๊ณ ์ด ๋ $\beta_i = (N_i - 1)/N_i$๋ฅผ ๋ง์กฑํ๋ค. ๊ทผ๋ฐ ๋ฌธ์ ๋ $N_i$๋ ๊ทธ ์ ์์ ๋ฐ๋ฅด๋ฉด ํด๋์ค์ ๊ฐ๋ฅํ ๋ชจ๋ ์ํ ์งํฉ $S_i$์ ํฌ๊ธฐ์ด๊ธฐ ๋๋ฌธ์ ๊ตฌํ ์ ์๋ค. ๋ฐ๋ผ์, ์ค์ ํ์ต์์๋ $N_i$๊ฐ ์ค์ง ๋ฐ์ดํฐ์ ์ ์์กด์ ์ด๊ณ , ๋ชจ๋ ํด๋์ค $i$์ $N_i$๊ฐ ๋ค์๊ณผ ๊ฐ์ด ๋์ผํ๋ค๊ณ ๊ฐ์ ํ๋ค $N_i = N, \beta_i = \beta = (N-1)/N$.
์ด๋ ๊ฒ ๊ตฌํ effective number๋ฅผ ์ฌ์ฉํ์ฌ ์์คํจ์๋ฅผ balancing ํ๊ธฐ ์ํด์ ๋ ผ๋ฌธ์์๋ "weighting factor $\alpha_i$"๋ฅผ ๋์ ํฉ๋๋ค. weigthing factor $\alpha_i$๋ class $i$์ ์ ํจ ์ํ ์์ ๋ฐ๋น๋กํ๋ ํญ์ด๋ค: $\alpha_i \propto 1/E_{n_i}$. (์์ Re-weighting ๊ธฐ์กด ์ฐ๊ตฌ ์ฐธ๊ณ ) ์ถ๊ฐ์ ์ผ๋ก, Weigting factor๋ฅผ ์ ์ฉํ์ ๋ ์ค์ผ์ผ์ ์กฐ์ ํด์ฃผ๊ธฐ ์ํด์ ์ ์ฒด ํด๋์ค์ ๋ํด์ weighting factor๋ค์ ์ ๊ทํํด์ค๋ค. ($\sum_{i=1}^{C} \alpha_i = C$)
์ด๋ฅผ ์ข ํฉํ class-balanced (CB) loss๋ ๋ค์๊ณผ ๊ฐ๋ค.
- class $i$์ ์ํ ์๋ฅผ $n_i$๋ผ๊ณ ํ์
- $\beta \in [0,1)$์ผ ๋, class $i$์ weighting factor๋ $(1-\beta)/(1-\beta^{n_i})$์ด๋ค.
- ์ด ๋์ CB loss๋ ๋ค์๊ณผ ๊ฐ๋ค: $CB(\boldsymbol{p}, y) = \frac{1}{E_{n_y}}\mathcal{L}(\boldsymbol{p},y) = \frac{1-\beta}{1-\beta^{n_y}}\mathcal{L}(\boldsymbol{p},y)$
์ฐธ๊ณ ๋ก, $\beta = 0$์ด๋ฉด weighting์ ์ฃผ์ง ์๋๊ฒ๊ณผ ๊ฐ๊ณ , $\beta$ ๊ฐ์ด 1์ ๊ฐ๊น์์ง์๋ก class frequency๋ง์ผ๋ก re-weightingํ๋ ๊ฒ๊ณผ ์ ์ฌํ ํจ๊ณผ๋ฅผ ๋ํ๋ธ๋ค.
EX) Class-Balanced Softmax Cross-Entropy Loss
Softmax Cross-Entropy loss์ Class Balance (CB)๋ฅผ ์ ์ฉํด๋ณด์. ๋จผ์ , ๋ชจ๋ธ๋ก๋ถํฐ ์ฃผ์ด์ง ๊ฐ ํด๋์ค๋ณ ์์ธก ๊ฒฐ๊ณผ๋ฅผ $\boldsymbol{z} = [z_1, z_2, \cdots, z_C]^T$๋ผ๊ณ ํ์. ๊ทธ๋ฌ๋ฉด softmax cross-entropy loss๋ ๋ค์๊ณผ ๊ฐ์ด ์ฃผ์ด์ง๋ค:
$CE_{softmax}(\boldsymbol{z}, y) = - \log \Big( exp(z_y)/\sum_{j=1}^{C} exp(z_j) \Big)$
์ฌ๊ธฐ์ Class Balance๋ฅผ ์ ์ฉ์ํค๋ฉด ๋ค์๊ณผ ๊ฐ๋ค. ๋ผ๋ฒจ $y$์ ๋์๋๋ ํด๋์ค ํ์ต ์ํ ์๊ฐ $n_y$๋ผ๊ณ ํ๋ฉด, CB loss๋ ๋ค์๊ณผ ๊ฐ์ด ์ฃผ์ด์ง๋ค.
$CB_{softmax}(\boldsymbol{z}, y) = - (1-\beta)/(1-\beta^{n_y}) \cdot \log \Big( exp(z_y)/\sum_{j=1}^{C} exp(z_j) \Big)$
Code ๋ค์ฌ๋ค๋ณด๊ธฐ
(์ฝ๋ ์ถ์ฒ: https://github.com/vandit15/Class-balanced-loss-pytorch/blob/master/class_balanced_loss.py)
def CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gamma):
"""Compute the Class Balanced Loss between `logits` and the ground truth `labels`.
Class Balanced Loss: ((1-beta)/(1-beta^n))*Loss(labels, logits)
where Loss is one of the standard losses used for Neural Networks.
Args:
labels: A int tensor of size [batch].
logits: A float tensor of size [batch, no_of_classes].
samples_per_cls: A python list of size [no_of_classes].
no_of_classes: total number of classes. int
loss_type: string. One of "sigmoid", "focal", "softmax".
beta: float. Hyperparameter for Class balanced loss.
gamma: float. Hyperparameter for Focal loss.
Returns:
cb_loss: A float tensor representing class balanced loss
"""
effective_num = 1.0 - np.power(beta, samples_per_cls)
weights = (1.0 - beta) / np.array(effective_num)
# normalization
weights = weights / np.sum(weights) * no_of_classes
labels_one_hot = F.one_hot(labels, no_of_classes).float()
weights = torch.tensor(weights).float()
weights = weights.unsqueeze(0)
weights = weights.repeat(labels_one_hot.shape[0],1) * labels_one_hot
weights = weights.sum(1)
weights = weights.unsqueeze(1)
weights = weights.repeat(1,no_of_classes)
if loss_type == "focal":
cb_loss = focal_loss(labels_one_hot, logits, weights, gamma)
elif loss_type == "sigmoid":
cb_loss = F.binary_cross_entropy_with_logits(input = logits,target = labels_one_hot, weights = weights)
elif loss_type == "softmax":
pred = logits.softmax(dim = 1)
cb_loss = F.binary_cross_entropy(input = pred, target = labels_one_hot, weight = weights)
return cb_loss
Reference
(Cui et al. 2019) Cui et al. "Class-Balanced Loss Based on Effective Number of Samples", CVPR 2019
๋ด์ฉ์ ๋ํ ์ฝ๋ฉํธ๋
์ธ์ ๋ ์ง ํ์์ ๋๋ค
๋ โผ๏ธ
'IN DEPTH CAKE > ML-WIKI' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
Inductive Bias, ๊ทธ๋ฆฌ๊ณ Vision Transformer (ViT) (16) | 2023.08.22 |
---|---|
<ML๋ ผ๋ฌธ> CVAE์ ๋ํ์ฌ (feat. ๋๊ฐ ์ง์ง CVAE์ธ๊ฐ? ํ๋์ ์ด๋ฆ, ๋ ๊ฐ์ ๊ธฐ๋ฒ) (6) | 2023.03.10 |