[PyTorch] 텐서 반복하기 torch.Tensor.repeat
2023. 7. 18. 17:52
torch.Tensor.repeat
Repeats this tensor along the specified dimensions. 특정 dimension에 따라 텐서를 반복함.
Parameters:
sizes
(torch.Size or int...) – The number of times to repeat this tensor along each dimension. 각 dimension에 따라 이 텐서를 반복하는 횟수.
Example:
x = torch.tensor([1, 2, 3]) # torch.Size([3])
x1 = x.repeat(4, 2) # torch.Size([4, 6])
x2 = x.repeat(4, 2, 1) # torch.Size([4, 2, 3])
x3 = x.repeat(3) # torch.Size([9])
print(x3) # tensor([1, 2, 3, 1, 2, 3, 1, 2, 3])
x 텐서의 사이즈가 3일 때, repeat(4, 2)
을 사용하면 사이즈는 (4, 6)= (1, 3) * (4, 2)_이 된다.repeat(4, 2, 1)
을 사용하면 사이즈는 (4, 2, 3)= (1, 1, 3) * (4, 2, 3)이 된다.repeat(3)
을 사용하면 사이즈는 (9)= (3) * (3)_가 된다.
2차원 텐서를 특정 dimension으로 늘려서 3차원 텐서로 만들기
bs = 8
n = 4
y_hat = torch.rand(bs, n, 1) # bs, n, 1
y_batch = torch.rand(bs, 1) # bs, 1
y_batch = y_batch.unsqueeze(1).repeat(1, n, 1) # bs, n, 1
# y_batch1 = y_batch.unsqueeze(1) # bs, 1, 1
# y_batch2 = y_batch1.repeat(1, n, 1) # bs, n, 1
(bs, 1) shape을 가진 y_batch를 (bs, n, 1) shape으로 늘리고 싶을 때, unsqueeze와 repeat을 사용하자.
'* ML | DL > python' 카테고리의 다른 글
[Python] getattr() 특정 문자열의 이름을 가지는 attribute 반환 (0) | 2023.07.20 |
---|---|
[Error / PyTorch] RuntimeError: GET was unable to find an engine to execute this computation (0) | 2023.07.19 |
[Python] 상속(Inheritance) (0) | 2023.07.17 |
pytorch DDP (0) | 2023.07.11 |
pandas apply (0) | 2023.05.11 |