import torch.nn.functional as F
bs = 10
F.one_hot(torch.ones(bs)*2, 5)
2์ ๋์๋๋ ์ ์ฒด ํฌ๊ธฐ๋ 5์ธ one hot encoding vector๋ฅผ batch size๊ฐ์๋งํผ ๋ง๋ค๊ณ ์ถ์ด์ ์์ ๊ฐ์ด ์ฝ๋๋ฅผ ์์ฑํ๋ฉด ๋ฐํ์ ์๋ฌ๊ฐ ๋ฌ๋ค.
RuntimeError: one_hot is only applicable to index tensor.
one_hot ํจ์๋ index tensor (ํ์
๋ช
์ด torch.int64)๋ฅผ ์
๋ ฅ์ผ๋ก ๋ฐ๊ธฐ ๋๋ฌธ์ ๋ฐ์ํ๋ ์๋ฌ์ด๋ค. ๋ค์๊ณผ ๊ฐ์ด ์์ ํด์ฃผ๋ฉด ๋๋ค.
>>> torch.nn.functional.one_hot((torch.ones(bs)*2).to(torch.int64), 5)
tensor([[0, 0, 1, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 1, 0, 0]])
๋ โผ๏ธ