์ƒˆ์†Œ์‹

IN DEPTH CAKE/Coding-WIKI

[pytorch] RuntimeError: one_hot is only applicable to index tensor

  • -
๋ฐ˜์‘ํ˜•

 

 

import torch.nn.functional as F
bs = 10
F.one_hot(torch.ones(bs)*2, 5)

2์— ๋Œ€์‘๋˜๋Š” ์ „์ฒด ํฌ๊ธฐ๋Š” 5์ธ one hot encoding vector๋ฅผ batch size๊ฐœ์ˆ˜๋งŒํผ ๋งŒ๋“ค๊ณ  ์‹ถ์–ด์„œ ์œ„์™€ ๊ฐ™์ด ์ฝ”๋“œ๋ฅผ ์ž‘์„ฑํ•˜๋ฉด ๋Ÿฐํƒ€์ž„ ์—๋Ÿฌ๊ฐ€ ๋œฌ๋‹ค.

 


RuntimeError: one_hot is only applicable to index tensor.

one_hot ํ•จ์ˆ˜๋Š” index tensor (ํƒ€์ž… ๋ช…์ด torch.int64)๋ฅผ ์ž…๋ ฅ์œผ๋กœ ๋ฐ›๊ธฐ ๋•Œ๋ฌธ์— ๋ฐœ์ƒํ•˜๋Š” ์—๋Ÿฌ์ด๋‹ค. ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์ˆ˜์ •ํ•ด์ฃผ๋ฉด ๋œ๋‹ค.

 

>>> torch.nn.functional.one_hot((torch.ones(bs)*2).to(torch.int64), 5)
tensor([[0, 0, 1, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 1, 0, 0]])

 

 

 

๋ โ—ผ๏ธŽ

๋ฐ˜์‘ํ˜•
Contents

ํฌ์ŠคํŒ… ์ฃผ์†Œ๋ฅผ ๋ณต์‚ฌํ–ˆ์Šต๋‹ˆ๋‹ค

์ด ๊ธ€์ด ๋„์›€์ด ๋˜์—ˆ๋‹ค๋ฉด ๊ณต๊ฐ ๋ถ€ํƒ๋“œ๋ฆฝ๋‹ˆ๋‹ค.

# ๋กœ๋”ฉ ํ™”๋ฉด ๋™์ž‘ ์ฝ”๋“œ(Code) ์„ค์ •ํ•˜๊ธฐ