#include <math_constants.h>
#include <math.h>
#ifndef M_LOG2E
#define M_LOG2E 1.4426950408889634
#endif
// Computes log2(1 + ln(x)) = log1p(ln(x)) * log2(e); named to avoid clashing with the
// CUDA intrinsic __log2f (double-underscore identifiers are reserved).
__device__ float log2_1p_logf(float x){return log1pf(logf(x))*M_LOG2E;}
__global__ void LLM_forward_kernel(
const int N, //batch size (the B in the shape comments below)
const int T, //maximum text length
const int H, //number of heads (also used as the beam-search hyperparameter)
const float * __restrict__ x, //source embeddings, shape (B, T, D)
const float * __restrict__ y, //target embeddings, shape (B, T, D)
const float * __restrict__ w, //input weights, shape (B, T, D, H, D): batch x sequence length x depth x number of heads x depth
//D = depth = size of the word dimension = size of the key/query/value dimension, usually 64
float * __restrict__ output //output logits, shape (B, T, H, D)
){
int bid = blockIdx.y;
int tid = threadIdx.x + threadIdx.y*blockDim.x;
if (bid >= N || tid >= (T-1)*H) return; //>= so that hidx stays in [0, H)
int bidx = bid%N, tidx = tid%(T-1), hidx = tid/(T-1); //each thread handles one position in the source text and one head
const float *wx, *yx;
if (bid==0){
wx=w; yx=&y[bid*(T-1)*H*3+tidx*H];
}else{
wx=&w[(bid-1)*(T-1)*H*3+tidx*H]; yx=&y[bid*(T-1)*H+tidx*H];
}
//the first element in logits
float logit = log2_1p_logf(__expf(wx[0]));
if (tid == 0) output[(bidx*T + tidx)*H] = logit;
wx+=H; yx+=3*H;
//the following elements in logits
for (int c=1;c<H;++c){
float v1,v2;
if(hidx==0) {
v1 = log2_1p_logf(__expf(wx[0]));
wx+=3*H; yx += 3*H;
v2 = *yx++;
} else if (hidx == c-1){
v1=output[(bidx*T + tidx)*H+c-1];
wx+=3*H;
v2 = log2_1p_logf(__expf(wx[0]));
yx += 3*H;
} else if (hidx == c){
v1=output[(bidx*T + tidx)*H+c-1]; wx+=3*H;
v2 = *yx++;
}else{
v1=output[(bidx*T + tidx)*H+c-1]; wx+=3*H;
v2 = output[(bidx*T + tidx)*H+hidx-1]; yx += 3*H;
}
logit = log2_1p_logf(expf(v1) * (expf(v2) - 1));
output[(bidx*T + tidx)*H+c]=logit;
}
}
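/*
A minimal host-side launch sketch (an assumption for illustration, not part of the
original source). The kernel reads blockIdx.y for the batch entry and builds its work
index purely from threadIdx, so each batch entry needs a single block whose threads
cover all (T-1)*H (position, head) pairs. The helper name launch_llm_forward and the
device-pointer argument names are hypothetical; (T-1)*H must stay within the
1024-thread block limit for this configuration to be valid.
*/
void launch_llm_forward(int N, int T, int H,
                        const float *d_x, const float *d_y, const float *d_w,
                        float *d_out){
    dim3 block(T - 1, H); // tid = threadIdx.x + threadIdx.y*blockDim.x, so tidx = threadIdx.x, hidx = threadIdx.y
    dim3 grid(1, N);      // blockIdx.y = batch index; blockIdx.x is never read by the kernel
    LLM_forward_kernel<<<grid, block>>>(N, T, H, d_x, d_y, d_w, d_out);
}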
/*
dim_y = number of batches (batch size) -- the kernel reads this as blockIdx.y
dim_x = T -- covered by blockDim.x, one thread per sequence step
dim_z = H -- not referenced by this kernel
(a host-side launch sketch follows the kernel below)
*/
__global__ void rnnt_loss(const double *output, const int* labels, double *costs, int T, int N){
// T = maximum label length; N = size of the label (last) dimension of output
int bidx = blockIdx.y; // batch idx
int tidx = threadIdx.x; // sequence step idx
if (tidx < T){
int label = labels[bidx*T+tidx];
double loss = -2 * output[(bidx*T + tidx) * N + label];
//all threads of one batch entry add into the same cost slot, so the update must be
//atomic (atomicAdd on double requires compute capability 6.0 or newer)
atomicAdd(&costs[bidx], loss);
}
}
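/*
Host-side launch sketch for rnnt_loss (an assumption for illustration; the
launch_rnnt_loss name is hypothetical). Because the kernel only ever accumulates into
costs, the buffer has to be zero-initialised before the launch; the grid/block shape
follows the comment above: one block per batch entry, one thread per sequence step.
*/
void launch_rnnt_loss(const double *d_output, const int *d_labels, double *d_costs,
                      int batch, int T, int N){
    cudaMemset(d_costs, 0, batch * sizeof(double)); // kernel accumulates, so start from zero
    dim3 block(T, 1);     // threadIdx.x = sequence step
    dim3 grid(1, batch);  // blockIdx.y = batch index
    rnnt_loss<<<grid, block>>>(d_output, d_labels, d_costs, T, N);
}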
/*
dim_y = number of batches (batch size) -- read as blockIdx.y
dim_x = T -- covered by blockDim.x; note the tidx == T branch below only runs if at least T+1 threads are launched
dim_z = H -- not referenced by this kernel
*/
__global__ void rnnt_backward(const double* output, const int* labels, const double *gradient, double *costs, int T, int N){
// cost gradient = -2 * (output - y) * delta^L  +  -2 * (delta^L + delta^{L+1} + ... + delta^{T-1})
//                 ---------------------------     ------------------------------------------------
//                           dC/dy                            dC/dy * dy/ddelta^{L}
int bidx = blockIdx.y; // batch idx
int tidx = threadIdx.x; // sequence step idx
if (tidx < T){
int idx = bidx*T + tidx;
//note: the labels[idx + 1] lookup assumes a label one step ahead exists; at tidx == T-1
//it reads the first label of the next batch entry (and runs past the array for the last one)
atomicAdd(&costs[bidx], -gradient[idx * N + labels[idx]]);
atomicAdd(&costs[bidx], gradient[idx * N + labels[idx]] -
(output[idx*N + labels[idx]] - output[idx*N + labels[idx + 1]]) *
gradient[idx*N + labels[idx + 1]]);
} else if (tidx == T){
atomicAdd(&costs[bidx], -gradient[(bidx*T+tidx-1) * N + labels[bidx*T+tidx - 1]]);
}
}
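/*
Matching host-side sketch for rnnt_backward (again an assumption; the helper name is
hypothetical). The only difference from the rnnt_loss launch is one extra thread per
block, so that the tidx == T boundary branch actually executes.
*/
void launch_rnnt_backward(const double *d_output, const int *d_labels,
                          const double *d_gradient, double *d_costs,
                          int batch, int T, int N){
    dim3 block(T + 1, 1); // threads 0..T-1 handle the steps, thread T handles the boundary term
    dim3 grid(1, batch);  // blockIdx.y = batch index
    rnnt_backward<<<grid, block>>>(d_output, d_labels, d_gradient, d_costs, T, N);
}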