자동 미분(Automatic Differentiation, AD)은 함수의 기울기(도함수)를 연산 그래프를 기반으로 자동으로 계산하는 기술이다. 심볼릭 미분(symbolic differentiation)이나 수치 미분(numerical differentiation)과 달리, 자동 미분은 기계 오차를 최소화하면서 정확한 기울기를 효율적으로 계산할 수 있어 딥러닝과 최적화 알고리즘의 핵심 기술로 사용된다.
개요
자동 미분은 함수의 구성 요소(덧셈, 곱셈, 행렬 연산 등)에 대해 미분 규칙을 알고 있기 때문에, 프로그래머는 단지 함수를 정의하기만 하면 프레임워크가 역전파(backpropagation)를 통해 기울기를 자동 계산한다.
예: PyTorch, TensorFlow, JAX는 모두 자동 미분 엔진을 핵심으로 갖는다.
자동 미분의 종류
자동 미분은 크게 다음 두 가지 방식으로 나뉜다.
Forward Mode AD
입력 변수에서 출발하여 출력까지 미분 값을 전달하는 방식.
특징:
- 입력 개수가 적고 출력 개수가 많은 함수에 적합
- 계산 흐름이 단순함
- JAX의 `jax.jvp`, `jax.forward_ad` 등이 사용
예: f: ℝ → ℝⁿ 형태에서 효율적
Reverse Mode AD (역전파, Backpropagation)
출력에서부터 역방향으로 기울기를 계산하는 방식.
특징:
- 출력이 스칼라(예: 손실 값)이고 입력 차원이 큰 함수에 최고 효율
- 신경망 학습의 표준 기법
- PyTorch, TensorFlow, JAX 등 딥러닝 프레임워크 대부분이 사용
예: 신경망과 같이 파라미터 수가 매우 많을 때 가장 효율적
→ 딥러닝에서 사용하는 backpropagation은 사실상 Reverse Mode AD의 특수한 형태이다.
자동 미분의 원리
자동 미분은 보통 다음처럼 동작한다:
- 연산을 **그래프(Computation Graph)** 로 기록
- 각 노드(연산)의 미분 규칙을 알고 있으므로
- Chain rule(연쇄 법칙)을 이용해 전체 미분을 계산
예시 연산:
y = x*2 + x*3
dy/dx = (2 + 3) = 5
프레임워크는 사용자가 직접 미분 규칙을 쓸 필요 없이 내부적으로 규칙을 조합해 기울기를 계산한다.
수치 미분 및 심볼릭 미분과의 비교
| 방식 | 설명 | 장점 | 단점 |
|---|---|---|---|
| 수치 미분 (Finite Differences) | h를 이용한 근사 차분 | 구현 쉬움 | 오차 큼, 비용 비쌈 |
| 심볼릭 미분 | 함수식을 기호적 조작 | 정확한 표현식 | 표현식 폭발, 복잡함 |
| 자동 미분 | 그래프 기반 미분 | 정확 + 빠름 | 심볼릭보다 엄밀하진 않음 (추상적) |
주요 프레임워크의 자동 미분
PyTorch (Autograd)
- 동적 그래프 기반
- 연산 수행 시 그래프를 기록
- `.backward()` 호출 시 기울기 계산
- Eager execution이 직관적
TensorFlow (GradientTape)
- 2.x에서 eager + 자동 그래프 변환
- `tf.GradientTape()`로 연산 추적
JAX (Autograd + XLA)
- 순수 함수 기반
- `jax.grad` 로 함수의 미분 생성
- JIT + XLA 최적화로 매우 빠른 성능
- Forward/Reverse/Vector-Jacobian 등 다양한 모드 지원
JAX는 AD가 프레임워크 핵심이자 철학이며, 성능 최적화가 매우 뛰어나다.
딥러닝에서의 자동 미분
딥러닝 학습 과정은 다음과 같이 자동 미분을 필수로 사용한다:
- 손실 함수 L(θ) 계산
- 파라미터 θ에 대한 ∂L/∂θ 계산 (역전파)
- 옵티마이저(Adam, SGD 등)로 θ 업데이트
이 기울기 계산 과정 전체가 자동 미분 엔진에 의해 수행된다.
Jacobian, Hessian 등 고차 미분
자동 미분은 1차 미분뿐 아니라 Jacobian, Hessian 같은 고차 미분도 계산할 수 있다.
대표적 예:
- Hessian-vector product (HVP)
- Jacobian-vector product (JVP)
- JAX는 특히 고차 미분에 강함 (`jax.jacobian`, `jax.hessian`)
장점
- 사용자는 미분 공식을 직접 작성할 필요 없음
- 복잡한 수식·모델에서도 안정적으로 동작
- 딥러닝 모델 학습의 필수 기술
- Forward/Reverse 구조로 고성능을 보장
- XLA/Triton 등과 결합하면 매우 빠름
한계
- 메모리 사용량이 증가할 수 있음
- 함수형 제약(JAX) 등 사용 난이도 증가 가능
- 비미분 가능 연산(discrete ops)은 별도 처리 필요
- 계산 그래프가 너무 크면 성능 저하 발생