Inside a Transformer
Posted: Mon Jul 15, 2024 4:10 pm
Inside a Language Model Transformer:
Q,K,V
V, the Value, is the relevance of the word at that position in the sentence.
Q, the Query, is used to reference a word and compute a score for its relation to another word referenced by K, the Key.
The output of a single head is
Q = Query
KT = Transpose of Key
DimK = Dimension of the Key vectors (here equal to the dimension of the embedded sentence)
Output(Head) = Softmax( (Q * KT) / sqrt(DimK) ) * V
A full model has many heads per layer and many such layers.
Each head learns a set of connections between words, referenced by Q and K, and it also learns the relevance V of each word at its position in the sentence.
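To make the Softmax in that formula explicit, here is a minimal single-query reference on the host in plain C++ (a sketch with hypothetical names such as attend_one_query, not taken from the kernels below): it scores one query against all T keys, normalizes the exponentiated scores, and takes the weighted sum of the Values.
- Code: Select all
#include <algorithm>
#include <cmath>
#include <vector>

// Attention for one query vector q against T key/value vectors of dimension DimK:
// out = Softmax(q * K^T / sqrt(DimK)) * V
void attend_one_query(const float* q, const float* K, const float* V,
                      int T, int DimK, float* out) {
    std::vector<float> scores(T);
    float max_s = -1e30f;
    for (int t = 0; t < T; ++t) {            // score_t = q . K[t] / sqrt(DimK)
        float dot = 0.0f;
        for (int i = 0; i < DimK; ++i) dot += q[i] * K[t * DimK + i];
        scores[t] = dot / std::sqrt((float)DimK);
        max_s = std::max(max_s, scores[t]);
    }
    float denom = 0.0f;                      // Softmax: exponentiate (shifted by the max) and normalize
    for (int t = 0; t < T; ++t) { scores[t] = std::exp(scores[t] - max_s); denom += scores[t]; }
    for (int i = 0; i < DimK; ++i) {         // output = Softmax-weighted sum of the Value vectors
        float s = 0.0f;
        for (int t = 0; t < T; ++t) s += (scores[t] / denom) * V[t * DimK + i];
        out[i] = s;
    }
}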
Transform Embedded Sentence X into Q,K,V values
- Code: Select all
__global__ void transform_embeddings_kernel(
    const int N,                          // batch size
    const int T,                          // sequence length
    const int H,                          // number of heads
    const int D,                          // embedding dimension
    const float * __restrict__ input,     // input embeddings. shape = (N,T,D)
    const float * __restrict__ W_q,       // query weights. shape = (D,H*D)
    const float * __restrict__ W_k,       // key weights. shape = (D,H*D)
    const float * __restrict__ W_v,       // value weights. shape = (D,H*D)
    float * __restrict__ Q,               // output queries. shape = (N,T,H,D)
    float * __restrict__ K,               // output keys. shape = (N,T,H,D)
    float * __restrict__ V                // output values. shape = (N,T,H,D)
) {
    int n = blockIdx.z;   // batch index
    int t = blockIdx.y;   // sequence position
    int h = blockIdx.x;   // head index
    int d = threadIdx.x;  // output dimension handled by this thread
    if (n >= N || t >= T || h >= H || d >= D) return;

    // Each (n,t,h,d) output element is the dot product of the embedding row
    // input[n,t,:] with column (h*D + d) of the corresponding weight matrix.
    float q_sum = 0.0f, k_sum = 0.0f, v_sum = 0.0f;
    for (int i = 0; i < D; ++i) {
        float input_val = input[n*T*D + t*D + i];
        q_sum += input_val * W_q[i*H*D + h*D + d];
        k_sum += input_val * W_k[i*H*D + h*D + d];
        v_sum += input_val * W_v[i*H*D + h*D + d];
    }
    Q[n*T*H*D + t*H*D + h*D + d] = q_sum;
    K[n*T*H*D + t*H*D + h*D + d] = k_sum;
    V[n*T*H*D + t*H*D + h*D + d] = v_sum;
}
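The index math above assumes one block per (head, position, batch) triple and one thread per output dimension, so a launch along the lines of the following sketch should work. The device pointers (d_input, d_Wq, and so on) are assumed to have been allocated and filled elsewhere, and D must fit in a single block (D <= 1024).
- Code: Select all
// Hypothetical launch configuration for the projection kernel.
dim3 grid(H, T, N);    // blockIdx.x = head, blockIdx.y = position, blockIdx.z = batch
dim3 block(D);         // threadIdx.x = output dimension (requires D <= 1024)
transform_embeddings_kernel<<<grid, block>>>(N, T, H, D,
                                             d_input, d_Wq, d_Wk, d_Wv,
                                             d_Q, d_K, d_V);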
Calculate AttentionScore:
- Code: Select all
__global__ void attention_kernel(
    const int N,                          // batch size
    const int T,                          // sequence length
    const int H,                          // number of heads
    const int D,                          // embedding dimension
    const float * __restrict__ Q,         // query. shape = (N,T,H,D)
    const float * __restrict__ K,         // key. shape = (N,T,H,D)
    const float * __restrict__ V,         // value. shape = (N,T,H,D)
    float * __restrict__ output           // output. shape = (N,T,H,D)
) {
    int n   = blockIdx.z;   // batch index
    int h   = blockIdx.y;   // head index
    int t_q = blockIdx.x;   // query position
    int d   = threadIdx.x;  // output dimension handled by this thread
    if (n >= N || h >= H || t_q >= T || d >= D) return;

    float sum = 0.0f;    // numerator:   sum_k exp(score_k) * V_k[d]
    float denom = 0.0f;  // denominator: sum_k exp(score_k), the Softmax normalizer
    for (int t_k = 0; t_k < T; ++t_k) {
        // dot product Q[t_q] . K[t_k] for this head
        float dot_product = 0.0f;
        for (int i = 0; i < D; ++i) {
            dot_product += Q[n*T*H*D + t_q*H*D + h*D + i] * K[n*T*H*D + t_k*H*D + h*D + i];
        }
        // scaled score; for numerical stability the max score would normally be subtracted first
        float attention_score = expf(dot_product / sqrtf((float)D));
        sum   += attention_score * V[n*T*H*D + t_k*H*D + h*D + d];
        denom += attention_score;
    }
    output[n*T*H*D + t_q*H*D + h*D + d] = sum / denom;  // Softmax-weighted sum of Values
}
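The same one-thread-per-output-element pattern applies here, so a hedged host-side sketch of the launch looks like this (again assuming the device buffers d_Q, d_K, d_V, d_output already exist):
- Code: Select all
// Hypothetical launch configuration for the attention kernel.
dim3 grid(T, H, N);    // blockIdx.x = query position, blockIdx.y = head, blockIdx.z = batch
dim3 block(D);         // threadIdx.x = output dimension (requires D <= 1024)
attention_kernel<<<grid, block>>>(N, T, H, D, d_Q, d_K, d_V, d_output);
cudaDeviceSynchronize();  // wait for the kernel before reading d_output back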