
TL;DR
Class imbalance ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด์ ๋ฐ์ดํฐ์ ๊ฐ ํด๋์ค์ ์ ํจ ๋ฐ์ดํฐ ์๋ฅผ ์ ์ํ๊ณ ์ด๋ฅผ ํ์ฉํ re-weighting๊ธฐ๋ฐ Class Balance Loss ๊ธฐ๋ฒ ์ ์.
๋ฌด์จ ๋ฌธ์ ๋ฅผ ํ๊ณ ์๋?
๋ฅ๋ฌ๋ ๋ชจ๋ธ ํ์ต์ ์ฌ์ฉ๋๋ ์ผ๋ฐ์ ์ธ ๋ฐ์ดํฐ ์ (์๋ฅผ๋ค์ด CIFAR-10, 100, ImageNet ๋ฑ)์ด ํด๋์ค ๋ผ๋ฒจ ๋ถํฌ๊ฐ ๊ท ์ผํ ๊ฒ๊ณผ ๋ฌ๋ฆฌ, ์ค์ ์ํฉ์์๋ ๋ชจ๋ ํด๋์ค์ ๋ฐ์ดํฐ ์๊ฐ ๊ท ์ผํ๊ฒ ์์ง๋์ง ์๋, Long Tail ํ์์ด ๋ฐ์ํ๋ค. ์ฌ๊ธฐ์ Long Tail์ด๋ผ๊ณ ํจ์, ๊ฐ ํ์ต ๋ฐ์ดํฐ ์ ํด๋์ค ๋ณ ์ํ ์์ ๋ํ ๋ถํฌ๋ฅผ ๊ทธ๋ ธ์ ๋ ์๋ ๊ทธ๋ฆผ๊ณผ ๊ฐ์ด, ์์์ ํด๋์ค์ ๋ํด์ ๋ฐ์ดํฐ ์ํ ์๊ฐ ๋ง์ ๋ฐ ๋ฐํด (Head) ๋ค์์ ํด๋์ค์์ ๊ธฐ๋์น ์ดํ์ ์ํ ์๋ฅผ ๊ฐ๋ (Long Tail) ํ์์ ๋งํ๋ค. ๋ง ๊ทธ๋๋ก ํด๋์ค ์ํ ์์ ๋ถํฌ๋ฅผ ๊ทธ๋ ธ์๋ ๊ธด ๊ผฌ๋ฆฌ๋ฅผ ๊ฐ์ง๋ค๋ ๋ป์ด๋ค.
์์ฝํ์๋ฉด, ๋ณธ ๋ ผ๋ฌธ์์๋ ํ์ต ๋ฐ์ดํฐ ์ ๋ด์ ํด๋์ค ๋ถ๊ท ํ ๋ฌธ์ (Class Imbalance Problem) ๋ฅผ ๋ค๋ฃจ๊ณ ์๋ค.

๊ธฐ์กด ์ฐ๊ตฌ๋ค์ ์ด ๋ฌธ์ ๋ฅผ ์ด๋ป๊ฒ ๋ค๋ฃจ์๋?
๊ธฐ์กด ์ฐ๊ตฌ๋ค์ ๋ฌธ์ ํด๊ฒฐ ๋ฐฉ์์ ๋ฐ๋ผ ํฌ๊ฒ Re-sampling๋ฐฉ์๊ณผ Cost-sensitive re-weighting ๋ฐฉ์์ผ๋ก ๋๋๋ค.
- Re-Sampling
Re-sampling๋ฐฉ์์ ๋ง ๊ทธ๋๋ก ์ค์ ์๋ ๋ฐ์ดํฐ ์๋ฅผ ์ค๋ณตํด์ samplingํ๋ ๋ฐฉ์์ด๋ค. ์๋ฅผ ๋ค์ด ๋ฐ์ดํฐ ์๊ฐ ์ ์ ํด๋์ค์ ๋ํด์ ์ค๋ณตํด์ ๋ ์ํ๋ง์ ํ๊ฑฐ๋ (over-sampling) ์ํ ์๊ฐ ๋ง์ ํด๋์ค์ ๋ํด์ ๋ ์ํ๋ง ํ๊ฑฐ๋ (under-sampling) ํ๋ ํํ์ด๋ค. ์ด ๊ฒฝ์ฐ ์ํ ์ค๋ณต์ผ๋ก ์ธํ ๊ณผ์ ํฉ (overfitting)์ด ์ผ์ด๋ ์ ์๊ณ , over-sampling์ ํ๋ ๊ฒฝ์ฐ์ ์๋์ ์ผ๋ก ํ์ต ์๊ฐ์ด ์ฆ๊ฐํ๋ ์ด์ ์ญ์ ๊ณ ๋ ค๋์ด์ผํ๋ค. ์ด๋ฌํ ํ๊ณ์ ์ผ๋ก ์ธํด Re-sampling ๋ฐฉ์๋ณด๋ค๋ Re-weighting๋ฐฉ์์ด ์ข ๋ ์ ํธ๋๋ค.
- Re-weighting
Re-weighting ๋ฐฉ์์ ํด๋์ค ๋ณ ์ํ ์์ ๋ฐ๋ผ์ ํ์ต ์ ์์ค ํจ์์ ๊ฐ์ค์น๋ฅผ ์ฃผ๋ ๋ฐฉ์์ ๋งํ๋ค. ์๋ฅผ ๋ค์ด, ์๋์ ์ผ๋ก ์์์ ํด๋์ค์ ๋ํ loss ๊ฐ์ ๋ ๋์ ๊ฐ์ค์น๋ฅผ ์ฃผ๋ ํํ์ด๋ค. ์๋์ ์ผ๋ก ๊ฐ๋จํ๋ค๋ ์ฅ์ ์ผ๋ก ์ธํด ์ด ๋ฐฉ์์ด ๋ง์ด ์ฌ์ฉ๋์์ผ๋, large-scale ๋ฐ์ดํฐ ์ ์ ๋ํด์๋ ์ด๋ฌํ ๋จ์ํ ๋ฐฉ์์ด ์ ๋์ํ์ง ์๋ ๋ค๋ ์ฌ์ค์ด ํ์ธ๋์๋ค. ๋์ "smoothed weight" ์์ค ํจ์์ ๋ํ ์ฐ๊ตฌ๊ฐ ์งํ๋์๋ค. "smoothed" ๋ฒ์ ์ ๊ฒฝ์ฐ์๋ ํด๋์ค ๋ถํฌ์ ์ ๊ณฑ๊ทผ์ ๋ฐ๋น๋กํ๋๋ก weight๋ฅผ ์ค๋ค.
์ด ๋ ผ๋ฌธ์์๋ ์ด ๋ฌธ์ ๋ฅผ ์ด๋ป๊ฒ ํ๊ณ ์๋?
Conceptually,
๋ณธ ๋ ผ๋ฌธ์์๋ Re-weighting๊ธฐ๋ฒ์์ ์ฐฉ์ํ ๋ฐฉ๋ฒ์ ์ ์ํ๋ค. ์๋์ 2๊ฐ์ ํด๋์ค๋ฅผ ์๋ก ๋ค์ด ์ค๋ช ํด๋ณด๋ฉด, ์ผ์ชฝ์ Head์ ํด๋นํ๋ ๋ฐ์ดํฐ ์ํ์ด๊ณ ์ค๋ฅธ์ชฝ์ Long Tail์ ํด๋นํ๋ ๋ฐ์ดํฐ ์ํ์ด๋ผ๊ณ ํ์. ๊ธฐ์กด์ re-weighting ๋ฐฉ๋ฒ์ ์ผ์ชฝ ํด๋์ค ์ํ์ ์์ ์ค๋ฅธ์ชฝ ํด๋์ค ์ํ์ ์๋ฅผ ๊ฐ์ง๊ณ weighting์ ์ํํ์ง๋ง, ์ด ๊ฒฝ์ฐ ๊ฒ์์ ์ค์ ์ผ๋ก ๊ทธ๋ ค์ง classifier ์ ์ด ๋นจ๊ฐ์์ผ๋ก ์น์ฐ์น๊ฒ๋๋ค. ํ์ง๋ง ํด๋์ค ๋ด์ ๋ชจ๋ ์ํ๋ค์ด ์ ํจํ ์๋ฅผ ๊ตฌ์ฑํ๋ ๊ฒ์ ์๋๋ผ๋ ์ ์ ์๊ธฐํด์ผํ๋ค. ์๋ฅผ ๋ค์ด, ์ํ๋ค ๊ฐ์ ์ค๋ณต๋๋ ์ ๋ณด๋ฅผ ๊ฐ์ง๊ณ ์๋ ๋ฐ์ดํฐ๊ฐ ์์ ์๋ ์๊ธฐ ๋๋ฌธ์ ๋จ์ํ '์ํ์ ์'๋ฅผ loss์ weight๋ก ์ ํ๋ ๊ฒ์ ๋ฌธ์ ๊ฐ ๋ ์ ์๋ค. ๋ฐ๋ผ์ ๋ณธ ๋ ผ๋ฌธ์์๋ ๊ฐ ํด๋์ค ๋ณ '์ํ์ ์ ํจ ์ซ์ (effective number)'๋ฅผ ๊ตฌํ๋ ๋ฐฉ๋ฒ์ ์ ์ํ๊ณ ์์ผ๋ฉฐ, ์ด๋ ๊ฒ ๊ตฌํด์ง ์ ํจ ๊ฐ์์ ๋ฐ๋น๋กํ๋๋ก loss๋ฅผ re-weighting์ ํ๋ ๊ฒฝ์ฐ ๊ฒ์์ ์ ์ผ๋ก ์น์ฐ์น classifier๋ฅผ ์ฐ๋ฆฌ๊ฐ ์ํ๋ ํ๋์ ์ ๊ณผ ๊ฐ์ด ์กฐ์ ํ ์ ์๋ค.

์ ํจ ์ํ ์ (Effective Number of Samples)
๋ณธ ๋ ผ๋ฌธ์์๋ ๊ฐ ๋ฐ์ดํฐ ์ ์ ํด๋์ค ๋ณ ์ ํจ ์ํ ์๋ฅผ ์ด๋ป๊ฒ ๊ตฌํ๋์ง๋ฅผ ์ ์ํ๊ณ ์๋ค. ์ด๋ฅผ ์ํด ๋ ผ๋ฌธ์์๋ ์ ํจ ๊ฐ์ (Effective Number) ๋ฅผ ๋ค์๊ณผ ๊ฐ์ด ์ ์ํ๊ณ ์๋ค.
Definition 1 (Effective Number).
The effective number of samples is the expecterd volume of samples.
์ํ ํฌ๊ธฐ์ ๊ธฐ๋๊ฐ์ ์ํ์ ์ ํจ ๊ฐ์๋ก ์ ์ํ๊ณ ์๋ค.
๊ธฐ๋๊ฐ์ ์ ์ํ๊ธฐ ์ํด์ ๋ณธ ๋
ผ๋ฌธ์์๋ ์๋ก ์ํ๋งํ ๋ฐ์ดํฐ๊ฐ ์ง๊ธ๊น์ง ๊ด์ฐฐ๋ ์ํ๊ณผ ๊ฒน์น ํ๋ฅ ์

๋ถ์ฐ ์ค๋ช
๋ณธ ๋ ผ๋ฌธ์ ์์ด๋์ด๋ ํด๋์ค์ ๋ฐ์ดํฐ๋ฅผ ๋ ๋ง์ด ์ฌ์ฉํ ๋ marginal benefit์ด ์ค์ด๋๋ ์ ๋๋ฅผ ์ธก์ ํ๋ ๊ฒ์ด๋ค. ์๋ฅผ ๋ค์ด, ์ํ์ ์๊ฐ ์ฆ๊ฐํ๋ฉด ์๋ก ์ถ๊ฐ๋๋ ์ํ์ด ํ์ฌ ์กด์ฌํ๋ ์ํ๊ณผ ๊ฒน์น ํ๋ฅ ์ด ๋์์ง๋ค. (์ฌํํ ๋ฐ์ดํฐ ์ฆ๊ฐ - ํ์ , ํฌ๋กํ ๋ฑ - ์ผ๋ก ํ๋ณด๋ ๋ฐ์ดํฐ ์ญ์ ์ค๋ณต๋๋ ๋ฐ์ดํฐ๋ก ๋ณผ ์ ์์ต๋๋ค.)
์ด์ ์ํ์ ์ ํจ ๊ฐ์ (expected number or expected volume)๋ฅผ ์ํ์ ์ผ๋ก ํํํด๋ณผ ์ ์๋ค. ๋จผ์ ํด๋น ํด๋์ค์ ํผ์ณ ๊ณต๊ฐ์์ ๋ชจ๋ ๊ฐ๋ฅํ ๋ฐ์ดํฐ์ ์งํฉ์
๊ทธ๋ฌ๋ฉด ์ํ์ ์ ํจ๊ฐ์๋ ๋ค์๊ณผ ๊ฐ์ด ํ๊ธฐํ ์ ์๋ค.
Proposition 1 (Effective Number)., where .
์ฆ๋ช ์ค๋ช
๊ท๋ฉ๋ฒ์ผ๋ก ์ด๋ฅผ ์ค๋ช
ํด๋ณด๋ฉด ๋ค์๊ณผ ๊ฐ์ต๋๋ค. ๋จผ๋
Effective Number์ ๋ํ ์ถ๊ฐ์ ์ธ ์ดํด
๊ทธ๋ฆฌ๊ณ ์์ proposition์ผ๋ก๋ถํฐ ์ฐ๋ฆฌ๋ ์ํ์ ์ ํจ ๊ฐ์๊ฐ
Class-Balanced Loss
์์ ์ค๋ช ํ๋ฏ์ด, ๋ณธ ๋ ผ๋ฌธ์ ๊ธฐ์กด์ ์ฐ๊ตฌ๋ค ์ค loss์ class๋ณ re-weight๋ฅผ ์ฃผ๋ ๋ฐฉ๋ฒ๋ก ๋ค์ ํ๊ณ์ ์์ด๋์ด๋ฅผ Effective Number๋ผ๋ ๊ฐ๋ ์ ๋์ ํจ์ผ๋ก์จ ๊ฐ์ ํ๊ณ ์๋ค. ์ด ๋ ผ๋ฌธ์ ์ฅ์ ์ loss function์ agnostic ํ๋ค๋ ์ ์ด๋ค. ์ค์ ๋ก ๋ณธ ๋ ผ๋ฌธ์์๋ Cross-Entropy Loss (Softmax, Sigmoid)์ Focal Loss์ ๋ํ ์์ ๋ฅผ ํจ๊ป ์ ์ํ๊ณ ์๋ค.
์
๋ ฅ
๊ทธ๋ฌ๋ฉด ํด๋์ค
์ด๋ ๊ฒ ๊ตฌํ effective number๋ฅผ ์ฌ์ฉํ์ฌ ์์คํจ์๋ฅผ balancing ํ๊ธฐ ์ํด์ ๋
ผ๋ฌธ์์๋ "weighting factor
์ด๋ฅผ ์ข ํฉํ class-balanced (CB) loss๋ ๋ค์๊ณผ ๊ฐ๋ค.
- class
์ ์ํ ์๋ฅผ ๋ผ๊ณ ํ์ ์ผ ๋, class ์ weighting factor๋ ์ด๋ค.- ์ด ๋์ CB loss๋ ๋ค์๊ณผ ๊ฐ๋ค:
์ฐธ๊ณ ๋ก,
EX) Class-Balanced Softmax Cross-Entropy Loss
Softmax Cross-Entropy loss์ Class Balance (CB)๋ฅผ ์ ์ฉํด๋ณด์. ๋จผ์ , ๋ชจ๋ธ๋ก๋ถํฐ ์ฃผ์ด์ง ๊ฐ ํด๋์ค๋ณ ์์ธก ๊ฒฐ๊ณผ๋ฅผ
์ฌ๊ธฐ์ Class Balance๋ฅผ ์ ์ฉ์ํค๋ฉด ๋ค์๊ณผ ๊ฐ๋ค. ๋ผ๋ฒจ
Code ๋ค์ฌ๋ค๋ณด๊ธฐ
(์ฝ๋ ์ถ์ฒ: https://github.com/vandit15/Class-balanced-loss-pytorch/blob/master/class_balanced_loss.py)
def CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gamma):
"""Compute the Class Balanced Loss between `logits` and the ground truth `labels`.
Class Balanced Loss: ((1-beta)/(1-beta^n))*Loss(labels, logits)
where Loss is one of the standard losses used for Neural Networks.
Args:
labels: A int tensor of size [batch].
logits: A float tensor of size [batch, no_of_classes].
samples_per_cls: A python list of size [no_of_classes].
no_of_classes: total number of classes. int
loss_type: string. One of "sigmoid", "focal", "softmax".
beta: float. Hyperparameter for Class balanced loss.
gamma: float. Hyperparameter for Focal loss.
Returns:
cb_loss: A float tensor representing class balanced loss
"""
effective_num = 1.0 - np.power(beta, samples_per_cls)
weights = (1.0 - beta) / np.array(effective_num)
# normalization
weights = weights / np.sum(weights) * no_of_classes
labels_one_hot = F.one_hot(labels, no_of_classes).float()
weights = torch.tensor(weights).float()
weights = weights.unsqueeze(0)
weights = weights.repeat(labels_one_hot.shape[0],1) * labels_one_hot
weights = weights.sum(1)
weights = weights.unsqueeze(1)
weights = weights.repeat(1,no_of_classes)
if loss_type == "focal":
cb_loss = focal_loss(labels_one_hot, logits, weights, gamma)
elif loss_type == "sigmoid":
cb_loss = F.binary_cross_entropy_with_logits(input = logits,target = labels_one_hot, weights = weights)
elif loss_type == "softmax":
pred = logits.softmax(dim = 1)
cb_loss = F.binary_cross_entropy(input = pred, target = labels_one_hot, weight = weights)
return cb_loss
Reference
(Cui et al. 2019) Cui et al. "Class-Balanced Loss Based on Effective Number of Samples", CVPR 2019

๋ด์ฉ์ ๋ํ ์ฝ๋ฉํธ๋
์ธ์ ๋ ์ง ํ์์ ๋๋ค
๋ โผ๏ธ