# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/layernorm_gated.py

import torch

from vllm.triton_utils import tl, triton


@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Z,  # pointer to the other branch
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_z_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_N: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_Z: tl.constexpr,
    NORM_BEFORE_GATE: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
):
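    # Each program instance handles one (row, group) pair in a single pass:
    # it optionally gates the input with silu(z), computes mean/rstd (or only
    # rstd for RMSNorm) over the group's N columns, then writes the
    # normalized, scaled, and optionally gated result to Y.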
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    group = tl.program_id(1)
    X += row * stride_x_row + group * N
    Y += row * stride_y_row + group * N
    if HAS_Z:
        Z += row * stride_z_row + group * N
    if not IS_RMS_NORM:
        Mean += group * M
    Rstd += group * M
    W += group * N
    if HAS_BIAS:
        B += group * N
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
    if HAS_Z and not NORM_BEFORE_GATE:
        z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
        x *= z * tl.sigmoid(z)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    if HAS_Z and NORM_BEFORE_GATE:
        z = tl.load(Z + cols, mask=mask).to(tl.float32)
        y *= z * tl.sigmoid(z)
    # Write output
    tl.store(Y + cols, y, mask=mask)


def _layer_norm_fwd(x,
                    weight,
                    bias,
                    eps,
                    z=None,
                    out=None,
                    group_size=None,
                    norm_before_gate=True,
                    is_rms_norm=False):
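    # Host-side launcher: validates shapes and strides, allocates the output
    # and the per-(group, row) mean/rstd buffers, picks BLOCK_N and num_warps,
    # and launches one kernel instance per (row, group).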
    M, N = x.shape
    if group_size is None:
        group_size = N
    assert N % group_size == 0
    ngroups = N // group_size
    assert x.stride(-1) == 1
    if z is not None:
        assert z.stride(-1) == 1
        assert z.shape == (M, N)
    assert weight.shape == (N, )
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N, )
    # allocate output
    if out is not None:
        assert out.shape == x.shape
    else:
        out = torch.empty_like(x)
    assert out.stride(-1) == 1
    mean = torch.empty((ngroups * M, ), dtype=torch.float32,
                       device=x.device) if not is_rms_norm else None
    rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
    if group_size > BLOCK_N:
        raise RuntimeError(
            "This layer norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    num_warps = min(max(BLOCK_N // 256, 1), 8)
    grid = (M, ngroups)
    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[grid](x,
                                           out,
                                           weight,
                                           bias,
                                           z,
                                           mean,
                                           rstd,
                                           x.stride(0),
                                           out.stride(0),
                                           z.stride(0) if z is not None else 0,
                                           M,
                                           group_size,
                                           eps,
                                           BLOCK_N=BLOCK_N,
                                           NORM_BEFORE_GATE=norm_before_gate,
                                           IS_RMS_NORM=is_rms_norm,
                                           num_warps=num_warps)
    return out, mean, rstd


def rms_norm_gated(x,
                   weight,
                   bias,
                   z=None,
                   eps=1e-6,
                   group_size=None,
                   norm_before_gate=True,
                   upcast=True):
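    """Gated RMSNorm over the last dimension of ``x``.

    The input is RMS-normalized, scaled by ``weight`` (and ``bias`` if given),
    and gated with ``silu(z)``: after the norm when ``norm_before_gate=True``,
    or applied to the input before the norm when ``norm_before_gate=False``.
    ``group_size`` splits the feature dimension into independently normalized
    groups. ``upcast`` is unused in this forward path; the kernel accumulates
    in float32 regardless.
    """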
    x_shape_og = x.shape
    # reshape input data into 2D tensor
    x = x.reshape(-1, x.shape[-1])
    if x.stride(-1) != 1:
        x = x.contiguous()
    if z is not None:
        assert z.shape == x_shape_og
        z = z.reshape(-1, z.shape[-1])
        if z.stride(-1) != 1:
            z = z.contiguous()
    weight = weight.contiguous()
    if bias is not None:
        bias = bias.contiguous()
    y, _, _ = _layer_norm_fwd(x,
                              weight,
                              bias,
                              eps,
                              z=z,
                              group_size=group_size,
                              norm_before_gate=norm_before_gate,
                              is_rms_norm=True)

    return y.reshape(x_shape_og)
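
# A minimal usage sketch (illustrative only; the shapes, dtype, and the check
# below are assumptions for this example, not part of the module). Requires a
# CUDA device:
#
#     x = torch.randn(4, 64, 128, device="cuda", dtype=torch.float16)
#     z = torch.randn_like(x)                       # gate branch
#     weight = torch.ones(128, device="cuda", dtype=torch.float16)
#     y = rms_norm_gated(x, weight, bias=None, z=z, eps=1e-6)
#     assert y.shape == x.shape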