<ML Paper> Class Imbalance: Cui et al. "Class-Balanced Loss Based on Effective Number of Samples" (CVPR 2019)

 

TL;DR

To address class imbalance, the paper defines an effective number of samples for each class in a dataset and proposes a re-weighting scheme built on it, the Class-Balanced Loss.

 

 

What problem does the paper address?

Common benchmark datasets for training deep learning models (for example CIFAR-10, CIFAR-100, and ImageNet) have a roughly uniform class-label distribution. Real-world data is rarely collected evenly across classes, which leads to the long-tail phenomenon. If you plot the number of training samples per class, as in the figure below, a few classes have many samples (the head) while many classes have far fewer samples than you would hope for (the long tail). As the name suggests, the distribution of per-class sample counts has a long tail.

 

In short, the paper deals with the class imbalance problem within a training dataset.

 

Figure source: Cui et al., Figure 1

 

How did prior work handle this problem?

 

Prior work falls into two broad groups by approach: re-sampling and cost-sensitive re-weighting.

- Re-Sampling

Re-sampling changes how often the existing data is drawn. For example, classes with few samples are drawn repeatedly (over-sampling), or classes with many samples are drawn less often (under-sampling). Duplicated samples can cause overfitting, and over-sampling also increases training time. Because of these limitations, re-weighting is generally preferred over re-sampling.

 

- Re-weighting

Re-weighting assigns weights to the loss function during training according to the number of samples per class. For example, the loss for relatively rare classes receives a higher weight. This approach has been widely used because it is simple, but such naive weighting has been found to work poorly on large-scale datasets. Later work therefore studied "smoothed" weights, which are inversely proportional to the square root of the class frequency.
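As a rough illustration of the two weighting schemes described above, the sketch below computes inverse-frequency weights and square-root "smoothed" weights for a hypothetical three-class dataset. The class counts, the function names, and the rescaling of the weights so that they sum to the number of classes are assumptions made for this example, not details taken from the paper.

```python
import math


def inverse_frequency_weights(samples_per_class):
    """Weight each class by 1 / n_i, rescaled so the weights sum to the class count."""
    raw = [1.0 / n for n in samples_per_class]
    scale = len(raw) / sum(raw)
    return [w * scale for w in raw]


def smoothed_weights(samples_per_class):
    """Weight each class by 1 / sqrt(n_i), rescaled the same way."""
    raw = [1.0 / math.sqrt(n) for n in samples_per_class]
    scale = len(raw) / sum(raw)
    return [w * scale for w in raw]


counts = [5000, 500, 50]  # hypothetical long-tailed class counts
inv = inverse_frequency_weights(counts)
smooth = smoothed_weights(counts)

for n, w_inv, w_smooth in zip(counts, inv, smooth):
    print(f"n={n:5d}  inverse={w_inv:.4f}  smoothed={w_smooth:.4f}")
```

With these counts, the rarest class is weighted 100 times more than the most frequent one under inverse frequency, but only 10 times more under the smoothed scheme.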

 

 

How does this paper solve the problem?

Conceptually,

The paper proposes a method that builds on re-weighting. Take the two classes in the figure below: the left one is a head class and the right one is a long-tail class. A model trained on these samples as they are is biased toward the head class (black solid line), and existing re-weighting, which uses the raw sample counts of the two classes, moves the classifier to the red line instead. Not every sample in a class contributes to its effective count, however. Samples can carry overlapping information, so using the raw number of samples to set the loss weight can be problematic. The paper therefore proposes a way to compute the effective number of samples for each class. Re-weighting the loss in inverse proportion to this effective number moves the biased classifier (black line) to the desired blue line.

 

Figure source: Cui et al., Figure 1

Effective Number of Samples

The paper proposes how to compute the effective number of samples for each class in a dataset. To do so, it defines the effective number as follows.

 


Definition 1 (Effective Number).
The effective number of samples is the expected volume of samples.


 

That is, the effective number of samples is defined as the expected volume of the samples.

 

To define this expectation, the paper assumes that a newly sampled data point overlaps with the previously observed samples with probability $p$ and does not overlap with probability $1-p$. (To keep the problem simple, partial overlap is not considered.) The figure below illustrates this.

 

Figure source: Cui et al., Figure 2

 

 

Additional notes

The idea is to capture the diminishing marginal benefit of using more data for a class. As the number of samples grows, a newly added sample becomes more likely to overlap with existing ones. (Data obtained through simple augmentation, such as rotation or cropping, can also be regarded as duplicated data.)

 

We can now express the effective number of samples (the expected number, or expected volume) mathematically. Let $\mathcal{S}$ be the set of all possible data in the feature space of a given class, and assume that $\mathcal{S}$ has volume $N$. We write $E_n$ for the effective number of samples, where $n$ is the number of samples.

 

The effective number of samples can then be written as follows.

 

 

Proposition 1 (Effective Number).
$E_n = (1-\beta^n)/(1-\beta)$, where $\beta = (N-1)/N$.

 

 

Proof sketch

The proof is by induction. First, $E_1 = 1$ because a single sample has nothing to overlap with, which matches $E_1 = (1-\beta^1)/(1-\beta)=1$. Now suppose $n-1$ examples have already been sampled and we draw the $n$-th one. The probability that the new sample overlaps with the previous ones is the fraction of the total volume $N$ of $\mathcal{S}$ that is already covered, that is, $p=E_{n-1}/N$. The expected effective number after $n$ samples is therefore $E_n = p \cdot E_{n-1} + (1-p)(E_{n-1} + 1) = 1 + \frac{N-1}{N}E_{n-1}$. Assuming the induction hypothesis $E_{n-1} = (1-\beta^{n-1})/(1-\beta)$, we obtain $E_n = 1+\beta \frac{1-\beta^{n-1}}{1-\beta} = \frac{1-\beta+\beta-\beta^n}{1-\beta} = \frac{1-\beta^n}{1-\beta}$.
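The induction step can also be checked numerically. The following minimal sketch iterates the recurrence $E_n = p \cdot E_{n-1} + (1-p)(E_{n-1}+1)$ with $p = E_{n-1}/N$ and compares it with the closed form of Proposition 1. The function names and the value $N = 100$ are hypothetical choices for this illustration.

```python
def effective_number_closed_form(n, N):
    """Closed form from Proposition 1: (1 - beta^n) / (1 - beta), beta = (N - 1) / N."""
    beta = (N - 1) / N
    return (1 - beta ** n) / (1 - beta)


def effective_number_recurrence(n, N):
    """Iterate E_k = p * E_{k-1} + (1 - p) * (E_{k-1} + 1) with p = E_{k-1} / N."""
    e = 1.0  # E_1 = 1: the first sample cannot overlap with anything
    for _ in range(2, n + 1):
        p = e / N  # probability that the new sample overlaps with earlier ones
        e = p * e + (1 - p) * (e + 1)
    return e


N = 100  # hypothetical volume of the set of all possible samples
max_gap = max(
    abs(effective_number_closed_form(n, N) - effective_number_recurrence(n, N))
    for n in range(1, 51)
)
print(f"largest gap between recurrence and closed form: {max_gap:.2e}")
```

The two agree up to floating-point rounding, as the proof predicts.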

 

More on the effective number

From the proposition above, the effective number of samples is an exponential function of $n$. Here $\beta$ takes values in $[0,1)$ and controls how fast $E_n$ grows as $n$ increases. How, then, can we obtain $N$, the total volume of the set $\mathcal{S}$ that represents the class? Since $E_n = (1-\beta^n)/(1-\beta)=\sum_{j=1}^{n} \beta^{j-1}$, $N$ is given by the limit $N = \lim_{n\rightarrow \infty} \sum_{j=1}^{n} \beta^{j-1} = 1/(1-\beta)$. It also follows that $E_n=1$ when $\beta = 0$, and that $E_n \rightarrow n$ as $\beta$ approaches 1.
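To make these limits concrete, the short sketch below evaluates $E_n$ for a fixed $n$ and several values of $\beta$. The function name, the sample count $n = 1000$, and the chosen $\beta$ values are assumptions for illustration only.

```python
def effective_number(n, beta):
    """E_n = (1 - beta^n) / (1 - beta) for beta in [0, 1)."""
    return (1.0 - beta ** n) / (1.0 - beta)


n = 1000  # hypothetical number of samples in one class
for beta in (0.0, 0.9, 0.99, 0.999, 0.999999):
    limit = 1.0 / (1.0 - beta)  # N = 1 / (1 - beta), the value E_n saturates at
    print(f"beta={beta:<9} E_n={effective_number(n, beta):10.3f}  N={limit:12.1f}")
```

For small $\beta$ the effective number saturates at $N = 1/(1-\beta)$ long before $n$ samples are reached, while for $\beta$ very close to 1 it stays close to $n$ itself.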

 

 

Class-Balanced Loss

 

As explained above, the paper improves on earlier methods that re-weight the loss per class by introducing the concept of the effective number. One advantage is that the approach is agnostic to the choice of loss function. The paper gives examples for cross-entropy loss (softmax and sigmoid) and for focal loss.

 

Given an input $\boldsymbol{x}$ with label $y \in \{1, 2, \cdots, C\}$, where $C$ is the number of classes, suppose the model's estimated class probabilities are $\boldsymbol{p} = [p_1, p_2, \cdots , p_C]^T$. We denote the loss function by $\mathcal{L}(\boldsymbol{p}, y)$.

 

Let $n_i$ be the number of samples in class $i$. By the proposition above, the effective number for that class is $E_{n_i} = (1-\beta_i^{n_i})/(1-\beta_i)$, where $\beta_i = (N_i - 1)/N_i$. The problem is that $N_i$ is, by definition, the volume of the set $\mathcal{S}_i$ of all possible samples of the class, so it cannot be computed. In practice, the paper therefore assumes that $N_i$ depends only on the dataset and is the same for every class $i$: $N_i = N, \beta_i = \beta = (N-1)/N$.

 

To balance the loss using this effective number, the paper introduces a weighting factor $\alpha_i$ that is inversely proportional to the effective number of samples for class $i$: $\alpha_i \propto 1/E_{n_i}$ (compare the earlier re-weighting methods described above). In addition, to keep the overall scale of the loss unchanged after the weighting factors are applied, they are normalized across all classes so that $\sum_{i=1}^{C} \alpha_i = C$.

 

Putting this together, the class-balanced (CB) loss is as follows.

 

  • Let $n_i$ be the number of samples in class $i$.
  • For $\beta \in [0,1)$, the weighting factor for class $i$ is $(1-\beta)/(1-\beta^{n_i})$.
  • The CB loss is then: $CB(\boldsymbol{p}, y) = \frac{1}{E_{n_y}}\mathcal{L}(\boldsymbol{p},y) = \frac{1-\beta}{1-\beta^{n_y}}\mathcal{L}(\boldsymbol{p},y)$

Note that $\beta = 0$ corresponds to no re-weighting, and as $\beta$ approaches 1 the effect becomes similar to re-weighting by inverse class frequency.
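The behaviour at the two ends of $\beta$ can be seen in a small sketch. It computes the weighting factors $(1-\beta)/(1-\beta^{n_i})$ and normalizes them so that they sum to $C$, as described above. The class counts and the function name are hypothetical.

```python
def class_balanced_weights(samples_per_class, beta):
    """Weights (1 - beta) / (1 - beta^n_i), normalized to sum to the number of classes."""
    raw = [(1.0 - beta) / (1.0 - beta ** n) for n in samples_per_class]
    scale = len(raw) / sum(raw)
    return [w * scale for w in raw]


counts = [5000, 500, 50]  # hypothetical long-tailed class counts

uniform = class_balanced_weights(counts, beta=0.0)         # no re-weighting
moderate = class_balanced_weights(counts, beta=0.999)      # in between
near_inverse = class_balanced_weights(counts, beta=0.999999)  # close to 1 / n_i

for name, w in (("beta=0", uniform), ("beta=0.999", moderate), ("beta=0.999999", near_inverse)):
    print(name, [round(x, 4) for x in w], "tail/head ratio:", round(w[2] / w[0], 2))
```

With $\beta = 0$ every class receives weight 1, and as $\beta$ moves toward 1 the ratio between the tail and head weights approaches the inverse-frequency ratio $5000/50 = 100$.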

 

 

EX) Class-Balanced Softmax Cross-Entropy Loss

 

Let us apply class balancing (CB) to the softmax cross-entropy loss. First, let $\boldsymbol{z} = [z_1, z_2, \cdots, z_C]^T$ be the model's predicted outputs (logits) for each class. The softmax cross-entropy loss is then:

 

$CE_{\text{softmax}}(\boldsymbol{z}, y) = - \log \Big( \frac{\exp(z_y)}{\sum_{j=1}^{C} \exp(z_j)} \Big)$

 

Applying class balancing gives the following. If the class corresponding to label $y$ has $n_y$ training samples, the CB loss is:

 

$CB_{\text{softmax}}(\boldsymbol{z}, y) = - \frac{1-\beta}{1-\beta^{n_y}} \log \Big( \frac{\exp(z_y)}{\sum_{j=1}^{C} \exp(z_j)} \Big)$
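The formula above can be evaluated for a single example in a few lines of plain Python. This is a minimal sketch and not the reference implementation: the function names are hypothetical, labels are 0-indexed here (the text uses $1, \cdots, C$), and the weight is the unnormalized factor $(1-\beta)/(1-\beta^{n_y})$ exactly as it appears in the formula.

```python
import math


def softmax_cross_entropy(logits, label):
    """-log(exp(z_y) / sum_j exp(z_j)), computed with the log-sum-exp trick for stability."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def cb_softmax_loss(logits, label, samples_per_class, beta):
    """Class-balanced softmax loss for one example (0-indexed label)."""
    n_y = samples_per_class[label]
    weight = (1.0 - beta) / (1.0 - beta ** n_y)  # 1 / E_{n_y}
    return weight * softmax_cross_entropy(logits, label)


counts = [5000, 500, 50]  # hypothetical class counts
logits = [0.0, 0.0, 0.0]  # hypothetical logits: the model is equally unsure about all classes

head_loss = cb_softmax_loss(logits, 0, counts, beta=0.999)
tail_loss = cb_softmax_loss(logits, 2, counts, beta=0.999)
print(f"head-class loss: {head_loss:.6f}  tail-class loss: {tail_loss:.6f}")
```

For identical logits, a mistake on the tail class contributes more to the loss than the same mistake on the head class, which is the intended effect of the class-balanced term.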

 

 

 

A look at the code

(Code source: https://github.com/vandit15/Class-balanced-loss-pytorch/blob/master/class_balanced_loss.py)

import numpy as np
import torch
import torch.nn.functional as F


def CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gamma):
    """Compute the Class Balanced Loss between `logits` and the ground truth `labels`.

    Class Balanced Loss: ((1-beta)/(1-beta^n))*Loss(labels, logits)
    where Loss is one of the standard losses used for Neural Networks.

    Args:
      labels: A int tensor of size [batch].
      logits: A float tensor of size [batch, no_of_classes].
      samples_per_cls: A python list of size [no_of_classes].
      no_of_classes: total number of classes. int
      loss_type: string. One of "sigmoid", "focal", "softmax".
      beta: float. Hyperparameter for Class balanced loss.
      gamma: float. Hyperparameter for Focal loss.

    Returns:
      cb_loss: A float tensor representing class balanced loss
    """
    # 1 - beta^{n_i} for every class; (1 - beta) divided by this is 1 / E_{n_i}
    effective_num = 1.0 - np.power(beta, samples_per_cls)
    weights = (1.0 - beta) / np.array(effective_num)
    # normalization: rescale so that the class weights sum to no_of_classes
    weights = weights / np.sum(weights) * no_of_classes

    labels_one_hot = F.one_hot(labels, no_of_classes).float()

    # Select each sample's class weight and repeat it across the class dimension,
    # giving a [batch, no_of_classes] tensor that matches labels_one_hot
    weights = torch.tensor(weights).float()
    weights = weights.unsqueeze(0)
    weights = weights.repeat(labels_one_hot.shape[0],1) * labels_one_hot
    weights = weights.sum(1)
    weights = weights.unsqueeze(1)
    weights = weights.repeat(1,no_of_classes)

    if loss_type == "focal":
        # focal_loss is a helper from the linked source file; it is not reproduced here
        cb_loss = focal_loss(labels_one_hot, logits, weights, gamma)
    elif loss_type == "sigmoid":
        cb_loss = F.binary_cross_entropy_with_logits(input = logits,target = labels_one_hot, weight = weights)
    elif loss_type == "softmax":
        pred = logits.softmax(dim = 1)
        cb_loss = F.binary_cross_entropy(input = pred, target = labels_one_hot, weight = weights)
    else:
        raise ValueError(f"unknown loss_type: {loss_type}")
    return cb_loss

 

 

 

Reference

(Cui et al. 2019) Cui et al. "Class-Balanced Loss Based on Effective Number of Samples", CVPR 2019

 

 

 

 

 

 

Comments on the content are always welcome.

 

 

 

 

 
