| 이름 | topk |
|---|---|
| 라이브러리 | PyTorch |
| 도입 시기 | PyTorch 0.x |
| 시그니처 | torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None) |
| 설명 | 주어진 차원을 따라 입력 텐서에서 상위 또는 하위 k개의 원소와 그 인덱스를 반환한다. |
topk은 PyTorch에서 텐서의 주어진 차원(dim)에서 상위 k개의 값을 추출하고 그 인덱스를 함께 반환하는 함수이다.[1]
함수 개요
- 함수명:
torch.topk또는 텐서 객체에서Tensor.topk메서드 - 기능: 입력 텐서에서 특정 차원을 따라 k개의 최댓값 또는 최솟값과 그 인덱스를 반환
- 반환형:
namedtuple(values, indices) - 기본 동작:
largest=True이면 큰 값부터,sorted=True이면 정렬된 결과를 반환
매개변수
- input (Tensor): 입력 텐서
- k (int): 추출할 요소 개수
- dim (int, optional): 연산할 차원. 기본은 마지막 차원
- largest (bool, optional):
True면 큰 값,False면 작은 값 반환 - sorted (bool, optional):
True면 정렬된 값 반환 - out (tuple, optional): 출력 텐서 (
values, indices형식)
반환
values(Tensor): 상위 k개의 값indices(Tensor): 원본 텐서에서 값의 위치 인덱스
사용 예시
import torch
x = torch.tensor([1., 2., 3., 4., 5.])
values, indices = torch.topk(x, 3)
print(values) # tensor([5., 4., 3.])
print(indices) # tensor([4, 3, 2])
2차원 예시:
x = torch.tensor([[1., 3., 2.],
[4., 0., 5.]])
values, indices = torch.topk(x, 2, dim=1)
print(values)
# tensor([[3., 2.],
# [5., 4.]])
print(indices)
# tensor([[1, 2],
# [2, 0]])
활용
- 분류 문제의 top-k 정확도 계산 시 유용
- 이상치 탐지 또는 작은 값 필터링 시
largest=False설정 - 다차원 전체에서 추출하려면
flatten()후 적용 필요
# 전체 텐서에서 상위 5개 값 추출
x = torch.randn(3, 4, 5)
flat = x.flatten()
vals, inds = torch.topk(flat, 5)
유의사항
- 동일한 값이 있을 경우
indices순서는 보장되지 않음 sorted=False로 설정하면 성능이 더 좋을 수 있음k가 해당 차원 크기보다 클 경우 오류 발생
관련 함수
torch.kthvalue: 정확히 k번째 값torch.sort: 전체 정렬torch.argsort: 정렬된 인덱스 반환
같이 보기
- PyTorch 텐서(Tensor)
- PyTorch 모델 분류(Classification)
- Top‑k 정확도(Top‑K Accuracy)
- PyTorch 정렬(Sort) 연산
- torch.kthvalue
참고 문헌
- PyTorch 공식 문서: https://pytorch.org/docs/stable/generated/torch.topk.html
- GeeksforGeeks. “How to find the k‑th and the top 'k' elements of a tensor in PyTorch?” https://www.geeksforgeeks.org/how-to-find-the-k-th-and-the-top-k-elements-of-a-tensor-in-pytorch/
- PyTorch GitHub Issue #88184. https://github.com/pytorch/pytorch/issues/88184