[PyTorch] 텐서 반복하기 torch.Tensor.repeat

2023. 7. 18. 17:52

PyTorch Docs

torch.Tensor.repeat

Repeats this tensor along the specified dimensions. 특정 dimension에 따라 텐서를 반복함.

Parameters:

sizes (torch.Size or int...) – The number of times to repeat this tensor along each dimension. 각 dimension에 따라 이 텐서를 반복하는 횟수.

Example:

x = torch.tensor([1, 2, 3]) # torch.Size([3])

x1 = x.repeat(4, 2) # torch.Size([4, 6])

x2 = x.repeat(4, 2, 1) # torch.Size([4, 2, 3])

x3 = x.repeat(3) # torch.Size([9])
print(x3) # tensor([1, 2, 3, 1, 2, 3, 1, 2, 3])

x 텐서의 사이즈가 3일 때, repeat(4, 2)을 사용하면 사이즈는 (4, 6)= (1, 3) * (4, 2)_이 된다.
repeat(4, 2, 1)을 사용하면 사이즈는 (4, 2, 3)
= (1, 1, 3) * (4, 2, 3)이 된다.
repeat(3)을 사용하면 사이즈는 (9)
= (3) * (3)_가 된다.

2차원 텐서를 특정 dimension으로 늘려서 3차원 텐서로 만들기

bs = 8
n = 4

y_hat = torch.rand(bs, n, 1) # bs, n, 1
y_batch = torch.rand(bs, 1) # bs, 1


y_batch = y_batch.unsqueeze(1).repeat(1, n, 1) # bs, n, 1
# y_batch1 = y_batch.unsqueeze(1) # bs, 1, 1
# y_batch2 = y_batch1.repeat(1, n, 1) # bs, n, 1

(bs, 1) shape을 가진 y_batch를 (bs, n, 1) shape으로 늘리고 싶을 때, unsqueeze와 repeat을 사용하자.

BELATED ARTICLES

more