
<ML Paper> The Class Imbalance Problem: Cui et al. "Class-Balanced Loss Based on Effective Number of Samples" (CVPR 2019)

 

TL;DR

To address class imbalance, the paper defines an effective number of samples for each class in a dataset and proposes Class-Balanced Loss, a re-weighting scheme built on that quantity.

 

 

What problem is being solved?

Common benchmark datasets for training deep learning models (e.g., CIFAR-10, CIFAR-100, ImageNet) have roughly uniform class label distributions. In real-world settings, however, data is rarely collected evenly across classes, and a long-tail phenomenon appears. Here, "long tail" means that when the number of samples per class is plotted, as in the figure below, a few classes have many samples (the head) while many classes have far fewer samples than expected (the long tail). Quite literally, the distribution of per-class sample counts has a long tail.

 

In short, the paper tackles the class imbalance problem within a training dataset.

 

Figure source: Cui et al., Figure 1

 

How did prior work handle this problem?

 

Prior work falls broadly into two groups by approach: re-sampling and cost-sensitive re-weighting.

- Re-Sampling

Re-sampling, as the name suggests, changes how often the existing data is drawn: classes with few samples are sampled repeatedly (over-sampling), or classes with many samples are sampled less often (under-sampling). Duplicated samples can lead to overfitting, and over-sampling also increases training time. Because of these limitations, re-weighting is generally preferred over re-sampling.
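As a rough illustration of the over-sampling idea, the sketch below pads every minority class with repeated indices until it matches the largest class. The helper name `oversample_indices` and the toy labels are hypothetical and do not come from the paper.

```python
import random
from collections import Counter, defaultdict

def oversample_indices(labels, seed=0):
    """Return sample indices in which every class appears as often as the largest class.

    Minority classes are padded by drawing their own indices again (with replacement),
    which is why over-sampling repeats samples and can encourage overfitting.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    target = max(len(idxs) for idxs in by_class.values())
    result = []
    for idxs in by_class.values():
        result.extend(idxs)                                                  # keep every original sample
        result.extend(rng.choice(idxs) for _ in range(target - len(idxs)))  # pad with repeats
    return result

labels = [0] * 6 + [1] * 2          # toy data: class 0 is the head, class 1 the tail
balanced = oversample_indices(labels)
counts = Counter(labels[i] for i in balanced)
```

In this toy case the two tail samples are each repeated several times per epoch, which also makes the epoch longer than the original eight samples.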

 

- Re-weighting

Re-weighting assigns weights to the loss function during training according to the number of samples per class, for example giving a higher weight to the loss of relatively rare classes. Thanks to its simplicity this approach has been widely used, but such simple schemes were found not to work well on large-scale datasets. Instead, "smoothed" weighting was studied, in which the weight is inversely proportional to the square root of the class frequency.
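The two weighting rules mentioned above can be compared with a small sketch. The function name, the class counts, and the rescaling so that the weights sum to the number of classes are assumptions made here for easier comparison, not details taken from the paper.

```python
def inverse_frequency_weights(samples_per_cls, power=1.0):
    """Weights proportional to 1 / n_i**power, rescaled so that they sum to the class count.

    power=1.0 gives plain inverse class frequency; power=0.5 gives the
    "smoothed" inverse-square-root variant.
    """
    raw = [1.0 / (n ** power) for n in samples_per_cls]
    scale = len(raw) / sum(raw)
    return [w * scale for w in raw]

samples_per_cls = [5000, 500, 50]   # hypothetical long-tailed class counts
plain = inverse_frequency_weights(samples_per_cls, power=1.0)
smoothed = inverse_frequency_weights(samples_per_cls, power=0.5)
```

With these counts the rarest class is weighted 100 times more than the most frequent one under plain inverse frequency, but only 10 times more under the smoothed rule.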

 

 

How does this paper solve the problem?

Conceptually,

The paper proposes a method that builds on re-weighting. Take the two classes in the figure below: the left one comes from the head and the right one from the long tail. Existing re-weighting methods weight by the raw sample counts of the left and right classes, and in that case the classifier boundary drawn as the black solid line shifts to the red line. However, not every sample in a class adds to the amount of useful data. For example, some samples may carry information that overlaps with other samples, so using the raw number of samples as the loss weight can be problematic. The paper therefore proposes a way to compute the effective number of samples for each class, and re-weighting the loss inversely proportional to this effective number moves the classifier from the biased black line to the desired blue line.

 

Figure source: Cui et al., Figure 1

Effective Number of Samples

The paper proposes a way to compute the effective number of samples for each class in a dataset. To that end, it defines the effective number as follows.

 


Definition 1 (Effective Number).
The effective number of samples is the expected volume of samples.


 

That is, the effective number of samples is defined as the expected volume covered by the samples.

 

To define this expectation, the paper assumes that a newly sampled data point overlaps with the previously observed samples with probability $p$ and does not overlap with probability $1 - p$. (To keep the problem simple, partial overlap is not considered.) The figure below illustrates this.

 

Figure source: Cui et al., Figure 2

 

 

Additional notes


The paper's idea is to measure how the marginal benefit of using more data from a class diminishes. For example, as the number of samples grows, a newly added sample becomes more likely to overlap with the existing ones. (Data obtained through simple augmentation such as rotation or cropping can also be regarded as duplicated data.)

 

The effective number of samples (expected number, or expected volume) can now be written mathematically. Let $S$ be the set of all possible data in the feature space of a class, and assume the volume of $S$ is $N$. Denote the effective number of samples by $E_n$, where $n$ is the number of samples.

 

The effective number of samples can then be written as follows.

 

 

Proposition 1 (Effective Number).
$E_n = (1 - \beta^n)/(1 - \beta)$, where $\beta = (N - 1)/N$.
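One way to get a feel for Proposition 1 is a quick simulation. The sketch below assumes a simplified discrete reading of the model, in which the class consists of $N$ unit-sized prototypes and each sample is drawn uniformly with replacement, so that the effective number is the count of distinct prototypes seen. This reading, the function names, and the values of $N$ and $n$ are assumptions made here for illustration.

```python
import random

def effective_number(n, N):
    """Closed form from Proposition 1: E_n = (1 - beta**n) / (1 - beta), beta = (N - 1) / N."""
    beta = (N - 1) / N
    return (1.0 - beta ** n) / (1.0 - beta)

def simulated_effective_number(n, N, trials=5000, seed=0):
    """Average number of distinct items seen when drawing n times, uniformly with replacement, from N items."""
    rng = random.Random(seed)
    total = 0
    for _ in range(trials):
        total += len({rng.randrange(N) for _ in range(n)})
    return total / trials

N, n = 100, 50
closed_form = effective_number(n, N)         # expected volume according to the proposition
simulated = simulated_effective_number(n, N)  # Monte Carlo estimate under the discrete reading
```

Under this reading the two numbers should agree up to sampling noise, and both are well below $n = 50$ because some draws repeat earlier ones.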

 

 

Proof sketch


The proof is by induction. First, $E_1 = 1$ holds because a single sample has nothing to overlap with, and this agrees with $E_1 = (1 - \beta^1)/(1 - \beta) = 1$. Now suppose $n - 1$ examples have already been sampled and we draw the $n$-th. The probability that the new sample overlaps with the previous ones is the fraction of the total volume $N$ of $S$ that is already covered, which is $E_{n-1}$ out of $N$, i.e. $p = E_{n-1}/N$. The expected effective number for $n$ samples is therefore $E_n = p \cdot E_{n-1} + (1 - p)(E_{n-1} + 1) = 1 + \frac{N-1}{N} E_{n-1}$. Assuming the induction hypothesis $E_{n-1} = (1 - \beta^{n-1})/(1 - \beta)$, we obtain $E_n = 1 + \beta \frac{1 - \beta^{n-1}}{1 - \beta} = \frac{1 - \beta + \beta - \beta^n}{1 - \beta} = \frac{1 - \beta^n}{1 - \beta}$.
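The induction step can also be checked numerically. The sketch below iterates the recurrence from the proof and compares it with the closed form; the function names and the chosen $N$ and $n$ values are illustrative assumptions.

```python
def effective_number_recursive(n, N):
    """Iterate E_k = p * E_{k-1} + (1 - p) * (E_{k-1} + 1) with p = E_{k-1} / N, starting from E_1 = 1."""
    e = 1.0
    for _ in range(2, n + 1):
        p = e / N                          # chance that the new sample overlaps earlier ones
        e = p * e + (1.0 - p) * (e + 1.0)  # overlap keeps E, no overlap adds one unit
    return e

def effective_number_closed_form(n, N):
    """Proposition 1: E_n = (1 - beta**n) / (1 - beta) with beta = (N - 1) / N."""
    beta = (N - 1) / N
    return (1.0 - beta ** n) / (1.0 - beta)

N = 1000
gaps = [abs(effective_number_recursive(n, N) - effective_number_closed_form(n, N))
        for n in (1, 2, 10, 100, 5000)]
```

The entries of `gaps` differ from zero only by floating-point rounding, which matches the algebra in the proof.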

 

Further intuition about the effective number


From the proposition above, the effective number of samples is an exponential function of $n$. The parameter $\beta$ takes values in $[0, 1)$ and controls how fast $E_n$ grows as $n$ increases. How, then, can we obtain the total volume $N$ of the set $S$ that represents a class? Since $E_n = (1 - \beta^n)/(1 - \beta) = \sum_{j=1}^{n} \beta^{j-1}$, $N$ is given by the limit $N = \lim_{n \to \infty} \sum_{j=1}^{n} \beta^{j-1} = 1/(1 - \beta)$. It also follows that $E_n = 1$ when $\beta = 0$, and that $E_n \to n$ as $\beta$ approaches 1.
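These limiting cases are easy to confirm numerically. The sketch below uses the geometric-sum form of $E_n$; the function name and the specific values of $n$ and $\beta$ are arbitrary choices for illustration.

```python
def effective_number_from_beta(n, beta):
    """E_n = sum_{j=1..n} beta**(j-1); written as a sum so that beta = 0 needs no special case."""
    return sum(beta ** (j - 1) for j in range(1, n + 1))

n = 200
e_beta_zero = effective_number_from_beta(n, 0.0)           # every extra sample overlaps: E_n stays at 1
e_beta_near_one = effective_number_from_beta(n, 0.999999)  # almost no overlap: E_n is close to n
e_saturated = effective_number_from_beta(5000, 0.99)       # many samples: E_n approaches N = 1 / (1 - beta) = 100
```

The last value shows the saturation effect: with $\beta = 0.99$, collecting 5000 samples is worth only about 100 effective samples.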

 

 

Class-Balanced Loss

 

As explained above, the paper takes the idea of class-wise re-weighting of the loss from prior work and addresses its limitations by introducing the effective number. A strength of the method is that it is agnostic to the loss function: the paper gives examples for cross-entropy loss (softmax and sigmoid) and for focal loss.

 

Given an input $x$ with label $y \in \{1, 2, \dots, C\}$, where $C$ is the number of classes, assume the model's estimated class probabilities are $p = [p_1, p_2, \dots, p_C]^\top$. The loss is denoted by $L(p, y)$.

 

Let $n_i$ be the number of samples in class $i$. By Proposition 1, the effective number for that class is $E_{n_i} = (1 - \beta_i^{n_i})/(1 - \beta_i)$ with $\beta_i = (N_i - 1)/N_i$. The problem is that $N_i$, which by definition is the volume of the set $S_i$ of all possible samples of the class, cannot be computed. In practice, therefore, $N_i$ is assumed to depend only on the dataset and to be the same for every class $i$: $N_i = N$, $\beta_i = \beta = (N - 1)/N$.

 

To balance the loss using this effective number, the paper introduces a weighting factor $\alpha_i$ that is inversely proportional to the effective number of samples of class $i$: $\alpha_i \propto 1/E_{n_i}$ (compare the prior re-weighting work above). In addition, to keep the scale of the loss comparable once the weights are applied, the weighting factors are normalized over all classes so that $\sum_{i=1}^{C} \alpha_i = C$.
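A minimal sketch of this weighting factor, including the normalization to $C$, could look as follows. The function name, the class counts, and the value of $\beta$ are illustrative assumptions.

```python
def class_balanced_weights(samples_per_cls, beta):
    """alpha_i proportional to 1 / E_{n_i} = (1 - beta) / (1 - beta**n_i), rescaled to sum to C."""
    raw = [(1.0 - beta) / (1.0 - beta ** n) for n in samples_per_cls]
    num_classes = len(samples_per_cls)
    scale = num_classes / sum(raw)
    return [w * scale for w in raw]

samples_per_cls = [5000, 500, 50]   # hypothetical class counts, head to tail
alphas = class_balanced_weights(samples_per_cls, beta=0.999)
```

The resulting weights grow from head to tail, and because they sum to $C$ the average weight per class stays at 1.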

 

Putting this together, the class-balanced (CB) loss is as follows.

 

  • Let $n_i$ be the number of samples in class $i$.
  • For $\beta \in [0, 1)$, the weighting factor of class $i$ is $(1 - \beta)/(1 - \beta^{n_i})$.
  • The CB loss is then $\mathrm{CB}(p, y) = \frac{1}{E_{n_y}} L(p, y) = \frac{1 - \beta}{1 - \beta^{n_y}} L(p, y)$.

Note that $\beta = 0$ is equivalent to no re-weighting, while as $\beta$ approaches 1 the effect approaches re-weighting by inverse class frequency.
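The two extremes of $\beta$ can be seen directly by evaluating the weights. This is a sketch with hypothetical class counts; the normalization to the number of classes is included so that the two settings are comparable.

```python
def cb_weights(samples_per_cls, beta):
    """Class-balanced weights (1 - beta) / (1 - beta**n_i), normalized to sum to the number of classes."""
    raw = [(1.0 - beta) / (1.0 - beta ** n) for n in samples_per_cls]
    scale = len(raw) / sum(raw)
    return [w * scale for w in raw]

samples_per_cls = [5000, 500, 50]                                  # hypothetical class counts
uniform_like = cb_weights(samples_per_cls, beta=0.0)               # all weights equal: no re-weighting
inverse_freq_like = cb_weights(samples_per_cls, beta=0.9999999)    # close to 1 / n_i weighting
tail_to_head_ratio = inverse_freq_like[2] / inverse_freq_like[0]   # close to 5000 / 50 = 100
```

Values of $\beta$ between these extremes interpolate between no re-weighting and inverse-frequency re-weighting.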

 

 

EX) Class-Balanced Softmax Cross-Entropy Loss

 

Let us apply class balancing (CB) to the softmax cross-entropy loss. First, let the model's predicted outputs (logits) for the classes be $z = [z_1, z_2, \dots, z_C]^\top$. The softmax cross-entropy loss is then:

 

$\mathrm{CE}_{\mathrm{softmax}}(z, y) = -\log\left(\frac{\exp(z_y)}{\sum_{j=1}^{C} \exp(z_j)}\right)$

 

Applying class balancing gives the following. If the class of label $y$ has $n_y$ training samples, the CB loss is:

 

$\mathrm{CB}_{\mathrm{softmax}}(z, y) = -\frac{1 - \beta}{1 - \beta^{n_y}} \log\left(\frac{\exp(z_y)}{\sum_{j=1}^{C} \exp(z_j)}\right)$
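For a single example, the formula above can be evaluated with plain Python. This is a minimal sketch: the function names, logits, and class counts are hypothetical, labels are 0-based here (the text uses 1 to $C$), and the weight is left unnormalized exactly as in the formula.

```python
import math

def softmax_cross_entropy(z, y):
    """-log(exp(z_y) / sum_j exp(z_j)), computed with the max-subtraction trick for stability."""
    m = max(z)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in z))
    return log_sum_exp - z[y]

def cb_softmax_cross_entropy(z, y, samples_per_cls, beta):
    """Scale the softmax cross-entropy by (1 - beta) / (1 - beta**n_y)."""
    weight = (1.0 - beta) / (1.0 - beta ** samples_per_cls[y])
    return weight * softmax_cross_entropy(z, y)

z = [2.0, 0.5, -1.0]                # hypothetical logits for C = 3 classes
samples_per_cls = [5000, 500, 50]   # hypothetical class counts
plain_loss = softmax_cross_entropy(z, 2)
cb_loss_value = cb_softmax_cross_entropy(z, 2, samples_per_cls, beta=0.999)
```

Setting `beta=0.0` in `cb_softmax_cross_entropy` returns the plain softmax cross-entropy, consistent with the note that $\beta = 0$ means no re-weighting.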

 

 

 

A look at the code

(Code source: https://github.com/vandit15/Class-balanced-loss-pytorch/blob/master/class_balanced_loss.py)

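import numpy as np
import torch
import torch.nn.functional as F

# focal_loss is not defined in this excerpt; it is assumed to come from the linked source file.
def CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gamma):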
    """Compute the Class Balanced Loss between `logits` and the ground truth `labels`.

    Class Balanced Loss: ((1-beta)/(1-beta^n))*Loss(labels, logits)
    where Loss is one of the standard losses used for Neural Networks.

    Args:
      labels: A int tensor of size [batch].
      logits: A float tensor of size [batch, no_of_classes].
      samples_per_cls: A python list of size [no_of_classes].
      no_of_classes: total number of classes. int
      loss_type: string. One of "sigmoid", "focal", "softmax".
      beta: float. Hyperparameter for Class balanced loss.
      gamma: float. Hyperparameter for Focal loss.

    Returns:
      cb_loss: A float tensor representing class balanced loss
    """
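    # Despite its name, effective_num holds 1 - beta^n_i, the denominator of the weight.
    effective_num = 1.0 - np.power(beta, samples_per_cls)
    # Per-class weight (1 - beta) / (1 - beta^n_i), i.e. the inverse effective number.
    weights = (1.0 - beta) / np.array(effective_num)
    # Normalize so that the weights sum to the number of classes.
    weights = weights / np.sum(weights) * no_of_classes

    labels_one_hot = F.one_hot(labels, no_of_classes).float()

    # Pick each sample's class weight and expand it to shape [batch, no_of_classes].
    weights = torch.tensor(weights).float()
    weights = weights.unsqueeze(0)
    weights = weights.repeat(labels_one_hot.shape[0],1) * labels_one_hot
    weights = weights.sum(1)
    weights = weights.unsqueeze(1)
    weights = weights.repeat(1,no_of_classes)

    if loss_type == "focal":
        cb_loss = focal_loss(labels_one_hot, logits, weights, gamma)
    elif loss_type == "sigmoid":
        # `weight` rescales each element of the loss.
        cb_loss = F.binary_cross_entropy_with_logits(input = logits, target = labels_one_hot, weight = weights)
    elif loss_type == "softmax":
        # As in the linked source, binary cross-entropy is applied to softmax probabilities;
        # this is not identical to the CB softmax formula given earlier.
        pred = logits.softmax(dim = 1)
        cb_loss = F.binary_cross_entropy(input = pred, target = labels_one_hot, weight = weights)
    else:
        raise ValueError('loss_type must be one of "sigmoid", "focal", "softmax"')
    return cb_loss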
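The weight manipulation in the middle of `CB_loss` (repeat, mask with the one-hot labels, sum) amounts to looking up each sample's class weight. The NumPy sketch below mirrors those steps on a hypothetical batch and checks them against direct indexing; the batch, the class counts, and the variable names are assumptions made for illustration, and the sketch does not touch the PyTorch code itself.

```python
import numpy as np

samples_per_cls = [5000, 500, 50]    # hypothetical class counts
no_of_classes = 3
beta = 0.999
labels = np.array([0, 2, 2, 1])      # hypothetical batch of labels

# Same per-class weights as in CB_loss, normalized to sum to no_of_classes.
effective_num = 1.0 - np.power(beta, samples_per_cls)
class_weights = (1.0 - beta) / effective_num
class_weights = class_weights / class_weights.sum() * no_of_classes

# One-hot route, mirroring the repeat / mask / sum steps in CB_loss.
labels_one_hot = np.eye(no_of_classes)[labels]
via_one_hot = (np.tile(class_weights, (len(labels), 1)) * labels_one_hot).sum(axis=1)

# Direct route: index the weight vector with the labels.
via_indexing = class_weights[labels]
```

Both routes give one weight per sample, so the one-hot sequence could in principle be replaced by an index lookup followed by broadcasting.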

 

 

 

Reference

(Cui et al. 2019) Cui et al. "Class-Balanced Loss Based on Effective Number of Samples", CVPR 2019

 

 

 

 

 

 

Comments on the content are always welcome.

End ◼

 

 

 

 

 
