Efficient Attention
Introduction
Transformer 구조의 핵심인 Self-Attention은 원래 Attention is All You Need 논문에서 처음 등장한 이후, 다양한 자연어 처리(NLP)와 컴퓨터 비전(CV) 모델에서 널리 사용되고 있습니다.
특히, Self-Attention은 시퀀스 전체의 임베딩을 병렬적으로 계산할 수 있다는 장점 덕분에, 더 큰 컨텍스트를 한 번에 처리할 수 있는 대규모 언어 모델(LLM, Large Language Model)의 기반이 되었습니다.
하지만 LLM의 규모가 커질수록 Self-Attention이 모델 전체 실행 비용에서 차지하는 비중도 커지게 됩니다. 다시 말해, Self-Attention 연산을 얼마나 효율적으로 구현하느냐에 따라 모델의 속도, 메모리 사용량, 비용이 크게 달라질 수 있다는 뜻입니다.
이에 따라 최근에는 기존의 Self-Attention과 수학적으로 동일한 결과를 내지만, 훨씬 더 메모리 효율적인(Memory Efficient) 방식으로 구현하는 방법이 주목받고 있습니다.
이러한 접근법은 연산량을 줄이거나, 중간 결과를 저장하는 메모리 사용을 최소화하여, 대규모 모델을 보다 빠르고 효율적으로 실행할 수 있도록 합니다.
Self-Attention의 효율적인 구현은 단순한 최적화를 넘어, LLM의 실용성을 좌우하는 핵심 기술 중 하나가 되었습니다. 앞으로도 다양한 방식의 Memory Efficient Attention 기법들이 등장하며, 더욱 빠르고 강력한 모델들이 개발될 것으로 기대됩니다.
Background
딥러닝 기반의 시퀀스 모델링에서 오랫동안 RNN, GRU, LSTM 등의 순환 신경망이 사용되어 왔습니다.
이러한 구조들은 시퀀스 데이터를 시간 순서에 따라 처리할 수 있는 장점을 가지고 있었지만, 장기 의존성 문제(long-term dependency), 병렬화의 어려움, 고정된 크기의 컨텍스트 표현 등 여러 한계점을 안고 있었습니다.
특히, 긴 문장을 처리할 때 앞쪽 정보가 뒷쪽으로 잘 전달되지 않는 문제가 있었고, 이로 인해 전체 문맥을 고려한 정교한 표현 학습이 어렵다는 비판이 제기되었습니다.
또한, 순차적으로 연산이 진행되기 때문에 GPU를 활용한 병렬 연산이 어려워 학습 속도와 효율성 측면에서도 한계를 드러냈습니다.
이러한 문제들을 해결하고자 제안된 것이 바로 Attention 메커니즘입니다.
Attention은 입력 시퀀스의 모든 위치를 동시에 고려하고, 각 위치의 중요도를 동적으로 계산할 수 있는 방법으로, 기존 RNN 기반 구조의 한계를 극복하며 시퀀스 모델링의 새로운 가능성을 열었습니다.
Attention
기존의 RNN, GRU, LSTM 기반의 encoder-decoder 구조는 시퀀스 정보를 고정된 크기의 벡터에 압축하면서 정보 손실이 발생하는 한계가 있었습니다. 이를 보완하기 위해 Attention 메커니즘이 등장했습니다.
Attention은 Neural machine translation by jointly learning to align and translate [link]이 처음 제안했으며, 해당 연구는 Dzmitry Bahdanau 와 뉴욕대학교의 조경현 교수님이 주도했습니다.
조 교수님은 강의[link]를 통해 “고정된 크기의 벡터로는 시퀀스의 의미를 완전히 표현할 수 없다”고 결론을 내렸었다고 설명합니다.
즉, 입력 시퀀스를 단일 벡터로 압축하는 것이 아니라, 시퀀스의 길이에 비례한 벡터 정보를 decoder의 입력으로 활용해야 한다는 것입니다.
입력 시퀀스는 토크나이징되어 임베딩 벡터의 시퀀스로 변환됩니다. 이 벡터들은 각 토큰의 의미를 담고 있으며, 다음 그림처럼 각 벡터의 차이와 방향성으로 관계를 나타낼 수 있습니다.
이러한 벡터 표현의 유용성은 Tomas Mikolov의 논문 link1, link2 에서 입증되었습니다.
벡터의 방향과 크기를 통해 단어 간 관계를 파악할 수 있으며, 임베딩 벡터들을 더하거나 평균내는 방식으로 의미 있는 벡터를 생성할 수도 있습니다.
즉, tokenize된 시퀀스는 임베딩 벡터들의 시퀀스로 표현되며, 이 벡터들을 조합하면 문맥 정보를 유지한 의미 벡터를 만들 수 있습니다.
하지만 중요한 점은, decoder가 생성하는 각 시점의 출력은 입력 임베딩 벡터 전체를 동등하게 활용하지 않는다는 것입니다.
각 출력마다 어떤 입력 벡터가 더 중요한지를 판단하기 위해, decoder는 입력 벡터들과의 유사도를 계산하고, 이를 기반으로 가중합(weighted sum)을 수행합니다.
이때 유사도는 보통 dot-product 기반으로 계산되며, 이를 확률 분포로 변환하기 위해 softmax 함수를 적용합니다.
그 결과, decoder는 입력 시퀀스의 의미 정보를 손실 없이 동적으로 선택하여 활용할 수 있게 됩니다. 이는 고정된 컨텍스트 벡터의 한계를 극복하며, 자연어 처리 성능을 크게 향상시켰습니다.
Self-Attention
Self-Attention은 기존의 Attention 구조에서 RNN을 Fully-Connected Network로 대체한 방식입니다.
이러한 구조 변경을 통해 RNN의 순차적인 계산을 피하고, 행렬 곱 연산(GEMM, General Matrix Multiplication)을 사용할 수 있게 되었습니다.
GEMM은 병렬 처리가 가능하며, 하드웨어 수준에서 다양한 최적화를 적용할 수 있기 때문에 RNN 기반 Attention보다 훨씬 빠른 연산이 가능합니다.
이로 인해 Self-Attention은 빠르게 주류로 자리잡았고, 이후 등장한 GPT-2, GPT-3와 같은 대규모 언어 모델(LLM)의 핵심 구조로 채택되었습니다.
또한, Self-Attention은 자연어 처리에 국한되지 않고 Vision Transformer (ViT) 와 같은 비전 모달리티(vision modality) 에도 성공적으로 적용되며, 시각 정보를 처리하는 방식에도 큰 변화를 일으켰습니다.
최근에는 Language, Vision 등 다양한 모달리티를 통합하는 Large Multi-Modal Model에서, Self-Attention이 그 중심 모듈로 사용되고 있습니다.
즉, Self-Attention은 다양한 데이터를 이해하고 연결하는 데 있어 범용적이고 확장 가능한 연산 구조로 진화해가고 있습니다.
Efficient Attention
Efficient Attention의 주요 목적은 연산 복잡도(Computation Complexity)의 Upper Bound를 낮추는 것입니다.
이를 위해 대표적으로 사용되는 방법 중 하나가 Window Attention link으로, 전체 입력 시퀀스를 대상으로 연산하는 것이 아니라 일부 입력(context)을 제한된 범위 내에서만 처리하는 방식입니다.
또 다른 접근 방식인 Linear Attention link은 일반적인 Self-Attention에서 사용하는 $QK^T$ 연산을 재구성합니다.
기존 방식에서는 $(n, d)$와 $(d, n)$ 형태의 행렬을 곱해 복잡도가 $O(n^2)$에 해당했지만, Linear Attention에서는 $K^T V$를 먼저 계산한 뒤, 이를 이용해 $Q (K^T V)$를 구합니다.
즉, $(d, n)$와 $(n, d)$의 곱셈으로 복잡도를 $O(n)$까지 낮추는 것이 가능합니다.
하지만 이러한 Efficient Attention 기법들은 연산 속도나 메모리 측면에서 이점을 얻는 대신, 정확도(Accuracy) 저하라는 단점을 동반하기도 합니다.
이로 인해 최근에는 연산 복잡도를 낮추기보다는, 동일한 수학적 연산 구조를 유지하면서도 더 적은 메모리와 연산량으로 수행할 수 있는 방식, 즉 Memory Efficient Attention에 대한 관심이 높아지고 있습니다.
이러한 방식은 정확도를 유지하면서도 GPU 프로그래밍을 통해 속도 및 메모리 사용을 크게 최적화할 수 있기 때문에, 특히 대규모 모델을 운영하는 데 있어 매우 유용한 접근법으로 주목받고 있습니다.
Softmax
Attention 메커니즘에서는 여러 개의 embedding vector를 결합하지만, 이들을 동일한 비중으로 더하지 않습니다.
각 embedding vector의 중요도를 고려해 가중치를 계산한 뒤, 이를 기반으로 weighted sum을 수행하게 됩니다.
이때 사용되는 함수가 바로 Softmax입니다.
Softmax는 입력된 벡터의 각 요소들을 0과 1 사이의 값으로 변환하며, 전체 요소의 합이 1이 되도록 정규화합니다.
즉, 각 embedding vector가 얼마나 중요한지를 나타내는 확률 분포 형태로 가중치를 표현하게 되는 것이죠.
이러한 특성 덕분에 Softmax는 Attention 구조에서 어떤 입력이 더 중요한지를 반영하는 핵심적인 역할을 하게 됩니다.
Safe-Softmax
Softmax는 벡터의 각 원소에 exponential 함수를 적용한 후, 이를 정규화하여 확률 분포를 만들어주는 함수입니다.
기본적인 형태는 다음과 같습니다:
하지만 이 방식은 입력 값 $x_i$의 크기에 따라 overflow 또는 underflow 문제가 발생할 수 있습니다.
특히 $x_i$가 매우 크거나 작을 경우, $e^{x_i}$의 값이 수치적으로 안정되지 않아 계산 오류를 유발할 수 있습니다.
이러한 문제를 방지하기 위해 사용되는 방식이 바로 Safe-Softmax입니다.
입력 벡터 전체에서 가장 큰 값을 빼줌으로써, 지수 함수 내부의 값을 수치적으로 안정적인 범위로 조정합니다:
이렇게 하면 softmax 결과는 동일하지만, overflow/underflow 없이 보다 안정적으로 연산할 수 있게 됩니다.
Online-Softmax
기존 Softmax는 계산 시 전체 row의 값을 모두 확인해야 합니다.
하지만, 이 방식은 매번 전체 row를 메모리에 올려야 하기 때문에 메모리 사용량이 크고 비효율적입니다.
이를 해결하기 위해 제안된 방식이 Online-Softmax입니다.
Online-Softmax는 입력 전체를 한 번에 처리하는 대신, 입력의 일부만을 순차적으로 확인하며,
분모와 분자를 단계적으로 업데이트하는 방식으로 Softmax를 계산합니다.
이 방식은 메모리 사용을 줄이면서도 동일한 결과를 얻을 수 있어,
특히 메모리 제약이 있는 환경에서 효과적으로 활용됩니다.
Proof
Base case: (V = 1)
-
\(m_1 \leftarrow x_1\)
\(= \max_{k=1}^1 x_k\) -
\(d_1 \leftarrow e^{x_1 - m_1}\)
\(= \sum_{j=1}^1 e^{x_j - m_1}\)
Induction: (V = S)
-
\(Assume \quad m_{S-1} = \max_{k=1}^{S-1} x_k\)
\(m_S \leftarrow \max(m_{S-1}, x_S)\)
\(= \max\bigl(\max_{k=1}^{S-1} x_k,\; x_S\bigr)\)
\(= \max_{k=1}^{S} x_k\) -
\(Assume \quad d_{S-1} = \sum_{j=1}^{S-1} e^{\,x_j - m_S}\)
\(d_S \leftarrow d_{S-1} \, e^{\,m_{S-1} - m_S} + e^{\,x_S - m_S}\)
\(= \left(\sum_{j=1}^{S-1} e^{\,x_j - m_{S-1}}\right) e^{\,m_{S-1} - m_S} + e^{\,x_S - m_S}\)
\(= \sum_{j=1}^{S} e^{\,x_j - m_S}\)
SafeNess
-
\(m_j \in \Bigl[\min_{k=1}^V m_k,\; \max_{k=1}^V m_k\Bigr], \quad \forall j \in \{1, \dots, V\}\)
따라서,, \(m_j\) 는 overflow나 underflow가 될 수 없습니다. -
\(1 \;\le\; d_j \;\le\; j, \quad \forall j \in \{1, \dots, V\}\)
\(d_j\)의 식은 다음과 같습니다.
\(d_S \leftarrow d_{j-1} \, e^{\,m_{j-1} - m_j} + e^{\,x_j - m_j}\)
매번 Softmax 계산에서, 값의 안정성을 유지하기 위해 $d_j$에 1보다 작거나 같은 값을 곱하고, 그 결과를 다시 $d_j$에 더해주는 방식이 사용됩니다.
이러한 누적 계산이 반복되기 때문에, $d_j$가 매우 큰 값을 저장할 수 있어야 합니다.다행히도 32-bit floating point(부동소수점)는 최대 약 $3.4 \times 10^{38}$까지 표현할 수 있으며,
이로 인해 $d_j$는 최대 약 $1.7 \times 10^{37}$까지 저장 가능한 수준으로 안전성이 보장됩니다.이 수치는 32-bit 부동소수점의 최대 표현 범위보다는 작지만, 이는 수치적 안정성과 오차 방지를 위한 보수적인 안전 수치로 판단됩니다.
하지만 처리해야 할 벡터 수가 훨씬 많아져서 $d_j$가 이 범위를 초과할 수 있다면,
64-bit floating point를 사용해 더 넓은 표현 범위를 확보하는 것이 필요합니다.
Block-Update Proof
앞에서 설명한 내용은 Online-Softmax를 하나의 데이터 샘플에 대해 순차적으로 처리하며 업데이트하는 과정을 다룬 것이었습니다.
하지만 실제 환경에서는 데이터를 하나씩 처리하기보다는, 여러 개의 데이터를 동시에 받아서 $d_j$, $m_j$ 등을 계산하게 됩니다.
따라서, 이렇게 새롭게 계산된 값들을 기존에 구해 놓은 $d_i$, $m_i$와 결합하거나 반영할 수 있는 방식이 필요합니다.
즉, Online-Softmax가 단일 입력뿐만 아니라 배치 단위 또는 스트리밍 형태의 연속적인 입력에서도 유효하게 동작하도록 확장되어야 합니다.
\(d_i\) , \(m_i\) 를 미리 구해놓고 다음 input에 대하여 \(d_j\) , \(m_j\) 를 계산하면 위의 공식에 대입하여 update를 할 수 있습니다.
max에 대한 증명은 생략하도록 하겠습니다.
- \(Assume \quad d_{j} = \sum_{k=0}^{j} e^{\,x_k - m_i} \quad d_{i} = \sum_{k=i+1}^{j+i} e^{\,x_k - m_j} \quad m_{i+j} = \max(m_i, m_j)\)
\(d_i \oplus d_{j} = d_i \, e^{\,m_i - \max(m_i, m_j)} \;+\; d_j \, e^{\,m_j - \max(m_i, m_j)}\)
\(= \left(\sum_{k=0}^{i} e^{\,x_k - m_{i}}\right) e^{\,m_{i} - \max(m_i, m_j)} + \left(\sum_{k=i+1}^{i+j} e^{\,x_k - m_{i}}\right) e^{\,m_{j} - \max(m_i, m_j)}\)
\(= \left(\sum_{k=0}^{i} e^{\,x_k - m_{i}}\right) e^{\,m_{i} - m_{i+j}} + \left(\sum_{k=i+1}^{i+j} e^{\,x_k - m_{i}}\right) e^{\,m_{j} - m_{i+j}}\)
\(= \sum_{k=0}^{i+j} e^{\,x_k - m_{i+j}}\)
Online-Self Attention
attention에는 softmax 때문에 input 전체를 봐야한다.
하지만, softmax는 Online-softmax로 전체를 안봐도 계산이 가능하다
self-attention에도 이를 적용한다.[link] \(v^* \in \mathbb{R}^d\) ,\(s^* \in \mathbb{R}\) 를 0으로 초기화하고 , \(m^*\)를 -inf로 초기화합니다. query \(q\), keys \(k_1, \dots, k_n\) 와 values \(v_1, \dots, v_n\) 가 주어질 때 , keys와 values들을 순서대로 사용합니다.\(k_i\) , \(v_i\) 가 주어지면 \(s_i = \mathrm{dot}(q, k_i)\) 를 계산합니다.그리고 나서,\(m^* = \max(m^*,s_i)\), \(v^* \leftarrow v^* e^{m^* - m_i} + v_i e^{s_i - m_i}\) , \(s^* \leftarrow s^* e^{m^* - m_i} + e^{s_i - m_i}\) 를 update합니다.그리고,\(\frac{v^*}{s^*}\)를 계산해줍니다.
Proof
\(m^* = \max(m^*,s_i)\), ,\(s^* \leftarrow s^* e^{m^* - m_i} + e^{s_i - m_i}\) 들은 위에서 증명이 되었습니다.
따라서, 증명을 생략하도록 하겠습니다. \(v^* \leftarrow v^* e^{m^* - m_i} + v_i e^{s_i - m_i}\)에 대해서 증명을 하도록 하겠습니다.
Base case: (V = 1)
- \(v^* \leftarrow v^* e^{m^* - m_1} + v_1 e^{s_1 - m_1}\)
\(= v_1 e^{s_1 - m_1}\)
Induction: (V = i)
- \(Assume \quad v^* = \sum_{j=1}^{i-1} v_je^{\,x_j - m^*}\)
\(m_i = \max(m^*,s_i)\) , \(s_i = \mathrm{dot}(q, k_i)\)
\(v^* \leftarrow v^* e^{m^* - m_i} + v_i e^{s_i - m_i}\)
\(= \left(\sum_{j=1}^{i-1} v_j e^{\,x_j - m^*}\right) e^{\,m^* - m_i} + v_i e^{\,s_i - m_i}\)
\(= \sum_{j=1}^{i} v_j e^{\,s_j - m_i}\)
그리고, \(\frac{v^*}{s^*}\) 를 계산해줍니다.
\(m_i = \max(m^*,s_i)\) 가 현재까지 계산한 \(\max\) 이기 때문에 online self-attention 수식이 정확함이 증명이 됩니다.
FlashAttention
Online-Self Attention을 실제로 구현하여 오픈소스로 제공한 대표적인 사례가 바로 FlashAttention입니다.
FlashAttention은 Tri Dao가 제안한 것으로, 알고리즘 자체에서 완전히 새로운 아이디어가 등장한 것은 아니지만,
Online-Self Attention을 실제 여러 LLM에 적용 가능한 형태로 구현한 첫 번째 사례라는 점에서 큰 의미를 가집니다.
즉, FlashAttention은 기존 이론을 기반으로 하되,
실제 대규모 모델에서 사용할 수 있도록 최적화된 구현체를 제공한 선구적인 작업이라 할 수 있습니다.
Forward
FlashAttention의 forward는 이와 같습니다. algorithm의 line별로 설명을 하도록 하겠습니다.
-
\(B_r\)는 \(Query\)의 block size를 의미하고 , \(B_c\) 는 \(Key,Value\)의 block size를 의미합니다.
-
위의 online-self attention 에서는 \(v^* \in \mathbb{R}^d\) ,\(s^* \in \mathbb{R}\) 를 0으로 초기화 했는데, FlashAttention에서는 이를 \(O , \ell\) 로 각각 표기법을 바꾸었습니다. Batch dimension으로 확장했기에 \(O = \mathbf{0}_{N\times d} \in \mathbb{R}^{N \times d}\) ,\(\ell = \mathbf{0}_N \in \mathbb{R}^N\) ,\(m = (-\infty)_N \in \mathbb{R}^N\) 로 초기화합니다.
-
\(Query ,key , value\) 를 여러 block으로 쪼갭니다.
-
\(O\)는 \(Query\) 의 weighted sum이므로 \(Query\) 와 같은 Block size인 \(B_r\) 으로 쪼개고, \(\ell , m\)는 softmax를 계산할 때 사용되는 중간계산결과 이므로 \(Key,Value\)와 같은 Block size인 \(B_c\) 를 사용합니다.
-
모든 \(Key,Value\) Block에 대해서 반복합니다.
-
$Key,Value$의 block을 먼저 Cache에 load합니다.
-
모든 \(Query\) Block에 대해서 반복합니다.
-
\(Query,O, \ell , m\)의 block을 Cache에 load합니다.
-
\(Query,Key\)를 cache에서 register로 바로 load하여 \(\mathbf{S}_{ij} = \mathbf{Q}_i \mathbf{K}_j^{T} \in \mathbb{R}^{B_r \times B_c}\) 를 계산합니다.
-
\(\tilde{m}_{ij} = \mathrm{rowmax}(\mathbf{S}_{ij}) \in \mathbb{R}^{B_r}, \quad \tilde{\mathbf{P}}_{ij} = \exp(\mathbf{S}_{ij} - \tilde{m}_{ij}) \in \mathbb{R}^{B_r \times B_c} , \tilde{\ell}_{ij} = \mathrm{rowsum}(\tilde{\mathbf{P}}_{ij}) \in \mathbb{R}^{B_r}\) 를 계산합니다. \(\tilde{m}_{ij}\) 는 Online softmax에서의 \(m_i\) ,Online self-attention 에서의 \(m^*\)와 동일한 역할을 합니다. 이번에는 \(\mathbf{S}_{ij} = \mathbf{Q}_i \mathbf{K}_j^{T} \in \mathbb{R}^{B_r \times B_c}\) 이므로 \(\tilde{m}_{ij} \in \mathbb{R}^{B_r}\)가 됩니다. \(\tilde{\ell}_{ij}\) 는 각 Block의 exponential의 합을 의미합니다. \(\tilde{\mathbf{P}}_{ij} \in \mathbb{R}^{B_r \times B_c}\) 이므로 \(B_c\) 차원을 reduction하는 방향으로 max를 계산하므로 \(\tilde{\ell}_{ij} \in \mathbb{R}^{B_r}\)가 됩니다.\(\tilde{\ell}_{ij}\)는 Online softmax에서의 \(d_S\) 와 같은 역할을 합니다.
-
Online self-attention처럼 max,분모를 새로 update해야 합니다. \(m_i^{new}\) 는 새로운 max 입니다 . \(m_i^{\mathrm{new}} = \max(m_i, \tilde{m}_{ij}) \in \mathbb{R}^{B_r}\) 는 Online softmax의 \(m^* = \max(m^*,s_i)\) 와 동등한 역할을 합니다. \(\ell_i^{\mathrm{new}}\) 새로운 분모의 역할을 하게 됩니다. \(\ell_i^{\mathrm{new}} = e^{m_i - m_i^{\mathrm{new}}} \, \ell_i + e^{\tilde{m}_{ij} - m_i^{\mathrm{new}}} \, \tilde{\ell}_{ij} \in \mathbb{R}^{B_r}\) 는 Online softmax에서 \(d_S \leftarrow d_{S-1} \, e^{\,m_{S-1} - m_S} + e^{\,x_S - m_S}\) 와 동등한 역할을 합니다.
-
\(\mathrm{diag}(\ell_i^{\mathrm{new}})^{-1}\) 는 기존 \(O\)에 새로운 분모를 곱해주는 과정입니다. \(\mathrm{diag}(\ell_i)\) 는 기존의 곱해진 분모를 없애주기 위한 용도입니다. \(e^{m_i - m_i^{\mathrm{new}}} \, \mathbf{O}_i\) 는 기존의 \(O\)는 이전의 max로 safe softmax를 적용했기에 기존의 max인 \(m_i\) 를 더해서 없애주고 새로운 max인 \(m_i^{\mathrm{new}}\) 를 빼주는 과정입니다. \(\tilde{\mathbf{P}}_{ij}\) 는 block의 max인 \(\tilde{m}_{ij}\) 를 빼주면서 safe softmax를 적용했습니다. 따라서, 새로운 max를 빼주기 위해서 \(e^{\tilde{m}_{ij} - m_i^{\mathrm{new}}}\) 를 곱해줍니다. 이를 통해 , \(O\)를 update 합니다.
-
\(\ell_i\) , \(m_i\) 를 새로운 max,분모로 update를 해줍니다.
Backward
Conclusion
Self-Attention은 Transformer의 핵심 구성 요소로, LLM을 포함한 다양한 딥러닝 모델의 성능을 크게 향상시켜 왔습니다. 그러나 모델의 크기가 커질수록 Self-Attention이 차지하는 계산 및 메모리 비용도 커지게 되었고, 이를 해결하기 위한 다양한 Efficient Attention 기법들이 등장하게 되었습니다.
특히, 정확도 손실 없이 메모리 사용량과 연산 속도를 줄일 수 있는 Memory Efficient Attention에 대한 관심이 커지고 있으며,
그 대표적인 구현 사례가 FlashAttention입니다. FlashAttention은 Online-Self Attention 아이디어를 실제 구현으로 옮겨, 여러 LLM에 적용 가능한 형태로 제공하였고, GPU 메모리 최적화와 병렬 처리 관점에서도 매우 효과적인 성능을 보여주고 있습니다.
본 글에서는 Softmax의 수치적 안정성을 보장하는 Safe-Softmax, 메모리 효율을 개선한 Online-Softmax, 그리고 이를 Self-Attention에 확장한 Online-Self Attention의 이론과 증명 과정을 살펴보았습니다. 또한, FlashAttention 알고리즘의 forward 과정을 구체적으로 분석하여, 어떻게 이러한 이론이 실제 구현에 반영되었는지를 정리하였습니다.
앞으로도 FlashAttention과 같은 커널 최적화 기법들은 점점 더 커지는 모델 규모에 대응하기 위해 필수적인 요소가 될 것이며,
GPU 프로그래밍과 병렬 연산에 대한 이해는 고성능 AI 모델 구현에 있어 점점 더 중요해질 것입니다.
Leave a comment