simons blog

Gated Delta Net Decoding

In this blogpost we will briefly explain Gated Delta Net decoding. This workload is one of those to be optimised in the Flashinfer Competition, and understanding its mathematical properties can help in optimising the corresponding kernel. I hope this note helps others onboard quickly with the necessary understanding of Gated Delta Net decoding.

Reference Code and its interpretation

import math
import torch
import torch.nn.functional as F


def matmul(a: torch.Tensor, b: torch.Tensor):
    """Float32 matmul for numerical stability."""
    return a.float() @ b.float()


@torch.no_grad()
def run(q, k, v, state, A_log, a, dt_bias, b, scale):
    """
    Gated Delta Net decode reference implementation (k-last layout).

    State layout: [B, H, V, K] (k-last, K dimension at the end)

    Gate computation:
    g = exp(-exp(A_log) * softplus(a + dt_bias))
    beta = sigmoid(b)

    Delta rule update:
    state_new = g * state_old + k^T @ (beta * v + (1-beta) * v_o) - k^T @ v_o,
    where v_o = k @ (g * state_old)
    output = scale * q @ state_new
    """
    B, T, num_q_heads, K = q.shape
    _, _, num_k_heads, _ = k.shape
    _, _, num_v_heads, V = v.shape
    num_heads = num_v_heads
    device = q.device

    assert K == 128 and V == 128
    assert T == 1
    assert num_v_heads % num_q_heads == 0
    assert num_v_heads % num_k_heads == 0

    if scale is None or scale == 0.0:
        scale = 1.0 / math.sqrt(K)

    # Compute g and beta from raw parameters
    x = a.float() + dt_bias.float()  # [B, 1, HV]
    g = torch.exp(-torch.exp(A_log.float()) * F.softplus(x))  # [B, 1, HV]
    beta = torch.sigmoid(b.float())  # [B, 1, HV]

    q_f32 = q.squeeze(1).float()  # [B, Hq, K]
    k_f32 = k.squeeze(1).float()  # [B, Hk, K]
    v_f32 = v.squeeze(1).float()  # [B, Hv, V]
    g_f32 = g.squeeze(1).float()  # [B, Hv]
    beta_f32 = beta.squeeze(1).float()  # [B, Hv]

    if state is not None:
        state_f32 = state.float()  # [B, Hv, V, K]
    else:
        state_f32 = torch.zeros(B, num_heads, V, K, dtype=torch.float32, device=device)

    # Expand Q/K heads to V-head granularity for GVA-style setups.
    q_exp = q_f32.repeat_interleave(num_v_heads // num_q_heads, dim=1)
    k_exp = k_f32.repeat_interleave(num_v_heads // num_k_heads, dim=1)

    new_state = torch.zeros_like(state_f32)
    output = torch.zeros(B, num_heads, V, dtype=torch.float32, device=device)

    for b_idx in range(B):
        for h_idx in range(num_heads):
            q_h = q_exp[b_idx, h_idx]  # [K]
            k_h = k_exp[b_idx, h_idx]  # [K]
            v_h = v_f32[b_idx, h_idx]  # [V]
            h_state = state_f32[b_idx, h_idx].clone().transpose(-1, -2)  # [V,K] -> [K,V]
            g_val = g_f32[b_idx, h_idx]
            beta_val = beta_f32[b_idx, h_idx]

            old_state = g_val * h_state
            old_v = matmul(k_h, old_state)  # [V]
            new_v = beta_val * v_h + (1.0 - beta_val) * old_v
            state_remove = matmul(k_h.unsqueeze(1), old_v.unsqueeze(0))
            state_update = matmul(k_h.unsqueeze(1), new_v.unsqueeze(0))
            h_state = old_state - state_remove + state_update

            output[b_idx, h_idx] = scale * matmul(q_h, h_state)
            new_state[b_idx, h_idx] = h_state.transpose(-1, -2)  # [K,V] -> [V,K]

    output = output.unsqueeze(1).to(torch.bfloat16)  # [B,1,Hv,V]
    return output, new_state
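For kernel work it can help to see the same decode step without the per-(batch, head) Python loop. Below is a minimal vectorized sketch in numpy; the helper name `gdn_decode_step` and the `[B, H, K, V]` state layout are my own choices for illustration (the reference stores the state k-last as `[B, H, V, K]` and transposes before updating), and the head-expansion logic is omitted.

```python
import numpy as np

def gdn_decode_step(q, k, v, state, g, beta, scale):
    """One gated-delta-rule decode step, vectorized over batch and head.

    q, k: [B, H, K]   v: [B, H, V]   state: [B, H, K, V]
    g, beta: [B, H]   (gating and mixing factors, both in (0, 1))
    """
    g_ = g[..., None, None]                              # [B, H, 1, 1]
    beta_ = beta[..., None]                              # [B, H, 1]
    old_state = g_ * state                               # gate the previous state
    old_v = np.einsum("bhk,bhkv->bhv", k, old_state)     # v_o: estimate of v_t
    new_v = beta_ * v + (1.0 - beta_) * old_v            # v_n: interpolation
    # outer-product update: S_t = g * S_{t-1} + k^T (v_n - v_o)
    new_state = old_state + np.einsum("bhk,bhv->bhkv", k, new_v - old_v)
    out = scale * np.einsum("bhk,bhkv->bhv", q, new_state)
    return out, new_state

# smoke test against an explicit single-(b, h) computation
rng = np.random.default_rng(0)
B, H, K, V = 2, 3, 4, 5
q = rng.standard_normal((B, H, K)); k = rng.standard_normal((B, H, K))
v = rng.standard_normal((B, H, V)); S = rng.standard_normal((B, H, K, V))
g = rng.uniform(0, 1, (B, H)); beta = rng.uniform(0, 1, (B, H))
out, S_new = gdn_decode_step(q, k, v, S, g, beta, 0.5)

b, h = 1, 2
oS = g[b, h] * S[b, h]
vo = k[b, h] @ oS
vn = beta[b, h] * v[b, h] + (1 - beta[b, h]) * vo
S_ref = oS + np.outer(k[b, h], vn - vo)
assert np.allclose(S_new[b, h], S_ref)
assert np.allclose(out[b, h], 0.5 * (q[b, h] @ S_ref))
```

The einsum formulation makes the data movement explicit: both the `v_o` estimate and the output are batched vector–matrix products against the state, and the update itself is a rank-1 (outer-product) correction per head.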

Note we first compute β and g:

x = a.float() + dt_bias.float()  # [B, 1, HV]
g = torch.exp(-torch.exp(A_log.float()) * F.softplus(x))  # [B, 1, HV]
beta = torch.sigmoid(b.float())  # [B, 1, HV]

Note that both the gating factors g and the mixing factors β lie strictly between 0 and 1.
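Concretely: softplus is strictly positive and so is exp(A_log), so the exponent in g is strictly negative, placing g in (0, 1); sigmoid maps any real b into (0, 1) as well. A quick scalar check in pure Python (the sample parameter values are arbitrary):

```python
import math

def softplus(x: float) -> float:
    return math.log1p(math.exp(x))

def gate(A_log: float, a: float, dt_bias: float) -> float:
    # g = exp(-exp(A_log) * softplus(a + dt_bias)): exponent < 0, so g in (0, 1)
    return math.exp(-math.exp(A_log) * softplus(a + dt_bias))

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

for A_log, a, dt_bias, b in [(-1.0, 0.5, 0.1, 2.0), (0.3, -2.0, 0.0, -3.0)]:
    g = gate(A_log, a, dt_bias)
    beta = sigmoid(b)
    assert 0.0 < g < 1.0 and 0.0 < beta < 1.0
```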

Let us now focus on the operation for a single batch element and head (it is structurally the same for each such pair). For a fixed batch element and head we denote the corresponding scalars by $g_t$ and $\beta_t$.

First we form $\tilde{S}_{t-1} = g_t \cdot S_{t-1}$, i.e. we just scale (or "gate", if you want to call it that) the state from the previous timestep. The smaller the gating factor, the more we scale $S_{t-1}$ down.

We'll then compute $v_o = k_t \tilde{S}_{t-1} \in \mathbb{R}^{1 \times V}$ (treating $k_t$ as a $1 \times K$ row vector). This can be seen as an estimate of the current $v_t$ given the information up to timestep $t-1$, assuming $g_t = 1$.
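In code this estimate is just a vector–matrix product of the key against the gated state. A small numpy sketch (shapes and values invented for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
K, V = 4, 5
k_t = rng.standard_normal(K)           # current key, 1 x K row vector
S_prev = rng.standard_normal((K, V))   # previous state S_{t-1}
g_t = 0.9                              # gating factor in (0, 1)

S_tilde = g_t * S_prev                 # gated state
v_o = k_t @ S_tilde                    # estimate of v_t from past information
assert v_o.shape == (V,)
```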

We'll then compute a mixture of the estimate above and the $v_t$ for the current timestep:

$$v_n = \beta_t \cdot v_t + (1 - \beta_t) \cdot v_o \in \mathbb{R}^{1 \times V}.$$

Note that $\beta_t$ can be thought of as an interpolation coefficient between the estimate and the current value: if the confidence in the estimate is high we decrease $\beta_t$, if it is low we increase it.

Observe that

$$v_n - v_o = \beta_t (v_t - v_o).$$
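This identity is easy to sanity-check numerically (numpy, with arbitrary small vectors):

```python
import numpy as np

rng = np.random.default_rng(1)
V = 5
v_t = rng.standard_normal(V)   # current value
v_o = rng.standard_normal(V)   # estimate from the gated state
beta = 0.3

v_n = beta * v_t + (1.0 - beta) * v_o
# v_n - v_o collapses to beta * (v_t - v_o)
assert np.allclose(v_n - v_o, beta * (v_t - v_o))
```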

We'll then unsqueeze the vectors in PyTorch so that the matmul computes an outer product, which gives the new state:

$$S_t = \tilde{S}_{t-1} - k_t^T v_o + k_t^T v_n,$$

which can be written more compactly as

$$S_t = \tilde{S}_{t-1} + k_t^T (v_n - v_o),$$

and using the relation above

$$S_t = \tilde{S}_{t-1} + \beta_t \cdot k_t^T (v_t - v_o).$$

Note that we can read off from this formula that the state is not updated (apart from the gating) iff $\beta_t = 0$, since in that case

$$S_t = \tilde{S}_{t-1}.$$

On the other hand it will take the form

$$S_t = \tilde{S}_{t-1} + k_t^T (v_t - v_o)$$

iff $\beta_t = 1$. Note that, up to the gating of the state, this is the delta-rule update used in DeltaNet-style Linear Attention; the vanilla Linear Attention update would add $k_t^T v_t$ without subtracting the estimate $v_o$.
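Both limiting cases can be checked against the general update. In the numpy sketch below, `update` is a hypothetical helper implementing $S_t = \tilde{S}_{t-1} + \beta_t \, k_t^T (v_t - v_o)$:

```python
import numpy as np

def update(S_tilde, k, v, beta):
    # S_t = S~_{t-1} + beta * k^T (v_t - v_o), with v_o = k S~_{t-1}
    v_o = k @ S_tilde
    return S_tilde + beta * np.outer(k, v - v_o)

rng = np.random.default_rng(2)
K, V = 4, 5
S_tilde = rng.standard_normal((K, V))  # already-gated state
k = rng.standard_normal(K)
v = rng.standard_normal(V)

# beta = 0: the state passes through untouched (apart from the gating
# already baked into S_tilde)
assert np.allclose(update(S_tilde, k, v, 0.0), S_tilde)

# beta = 1: the full delta-rule step S~ + k^T (v_t - v_o)
v_o = k @ S_tilde
assert np.allclose(update(S_tilde, k, v, 1.0), S_tilde + np.outer(k, v - v_o))
```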

The last step is the computation of the output, which is simply

$$o_t = \eta \cdot q_t S_t,$$

where $\eta$ is the scale factor ($1/\sqrt{K}$ by default in the reference code).
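Putting the steps above together, one decode step for a single head is only a few lines. A numpy sketch in the row-vector convention used above (shapes and parameter values invented for illustration):

```python
import numpy as np

rng = np.random.default_rng(3)
K, V = 4, 5
q = rng.standard_normal(K)
k = rng.standard_normal(K)
v = rng.standard_normal(V)
S = rng.standard_normal((K, V))        # state S_{t-1}
g, beta = 0.8, 0.6                     # gating and mixing factors
eta = 1.0 / np.sqrt(K)                 # default scale from the reference

S_tilde = g * S                        # 1. gate the state
v_o = k @ S_tilde                      # 2. estimate v_t from the past
S_new = S_tilde + beta * np.outer(k, v - v_o)  # 3. delta-rule update
o = eta * (q @ S_new)                  # 4. output, shape [V]
assert o.shape == (V,)
```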

Conclusion

I hope this short note shed some light on Gated Delta Net decoding. Feel free to contact me on LinkedIn or X to exchange ideas.