IT용어위키


PyTorch topk

함수 정보
이름 topk
라이브러리 PyTorch
도입 시기 PyTorch 0.x
시그니처 torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None)
설명 주어진 차원을 따라 입력 텐서에서 상위 또는 하위 k개의 원소와 그 인덱스를 반환한다.


topk은 PyTorch에서 텐서의 주어진 차원(dim)에서 상위 k개의 값을 추출하고 그 인덱스를 함께 반환하는 함수이다.[1]

함수 개요

  • 함수명: torch.topk 또는 텐서 객체에서 Tensor.topk 메서드
  • 기능: 입력 텐서에서 특정 차원을 따라 k개의 최댓값 또는 최솟값과 그 인덱스를 반환
  • 반환형: namedtuple(values, indices)
  • 기본 동작: largest=True이면 큰 값부터, sorted=True이면 정렬된 결과를 반환

매개변수

  • input (Tensor): 입력 텐서
  • k (int): 추출할 요소 개수
  • dim (int, optional): 연산할 차원. 기본은 마지막 차원
  • largest (bool, optional): True면 큰 값, False면 작은 값 반환
  • sorted (bool, optional): True면 정렬된 값 반환
  • out (tuple, optional): 출력 텐서 (values, indices 형식)

반환

  • values (Tensor): 상위 k개의 값
  • indices (Tensor): 원본 텐서에서 값의 위치 인덱스

사용 예시

import torch
x = torch.tensor([1., 2., 3., 4., 5.])
values, indices = torch.topk(x, 3)
print(values)   # tensor([5., 4., 3.])
print(indices)  # tensor([4, 3, 2])

2차원 예시:

x = torch.tensor([[1., 3., 2.],
                  [4., 0., 5.]])
values, indices = torch.topk(x, 2, dim=1)
print(values)
# tensor([[3., 2.],
#         [5., 4.]])
print(indices)
# tensor([[1, 2],
#         [2, 0]])

활용

  • 분류 문제의 top-k 정확도 계산 시 유용
  • 이상치 탐지 또는 작은 값 필터링 시 largest=False 설정
  • 다차원 전체에서 추출하려면 flatten() 후 적용 필요
# 전체 텐서에서 상위 5개 값 추출
x = torch.randn(3, 4, 5)
flat = x.flatten()
vals, inds = torch.topk(flat, 5)

유의사항

  • 동일한 값이 있을 경우 indices 순서는 보장되지 않음
  • sorted=False로 설정하면 성능이 더 좋을 수 있음
  • k가 해당 차원 크기보다 클 경우 오류 발생

관련 함수

  • torch.kthvalue: 정확히 k번째 값
  • torch.sort: 전체 정렬
  • torch.argsort: 정렬된 인덱스 반환

같이 보기

참고 문헌

각주


  출처: IT위키 (IT위키에서 최신 문서 보기)

  * 본 페이지는 IT Wiki에서 미러링된 페이지입니다. 일부 오류나 표현의 누락이 있을 수 있습니다. 원본 문서는 IT Wiki에서 확인하세요!