Commit 8584b6d

add layers
1 parent 37d2931 commit 8584b6d

File tree

6 files changed: +214, -0 lines changed


backends/python/server/text_embeddings_server/layers/__init__.py

Whitespace-only changes.
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
from text_embeddings_server.utils.import_utils import SYSTEM
import os

if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":
    from .cuda import attention
elif SYSTEM == "rocm":
    from .rocm import attention
else:
    raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
import os
import torch

from loguru import logger

if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")

if not torch.cuda.is_available():
    raise ImportError("CUDA is not available")

major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0

HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2 = False
try:
    try:
        import flash_attn_2_cuda
    except ImportError:
        raise ImportError(
            "Flash Attention V2 is not installed.\n"
            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
            "or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
        )
    if not (is_sm8x or is_sm90):
        raise ImportError(
            f"GPU with CUDA capability {major} {minor} is not supported for "
            "Flash Attention V2"
        )
    HAS_FLASH_ATTN_V2 = True
except ImportError as e:
    try:
        import flash_attn_cuda
    except ImportError:
        raise ImportError(
            "Flash Attention is not installed.\n"
            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
            "or install flash attention with `cd server && make install install-flash-attention`"
        ) from e

    if not (is_sm75 or is_sm8x or is_sm90):
        raise ImportError(
            f"GPU with CUDA capability {major} {minor} is not supported"
        ) from e
    logger.warning(f"Unable to use Flash Attention V2: {e}")
    HAS_FLASH_ATTN = True


def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False):
    if HAS_FLASH_ATTN_V2:
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            is_causal,
            -1,
            -1,
            False,
            None,
        )

    if HAS_FLASH_ATTN:
        return flash_attn_cuda.fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            is_causal,
            False,
            0,
            None,
        )

    raise NotImplementedError("flash attention is not installed")
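
For context, a minimal usage sketch (not part of this commit) of the varlen attention entry point above; the module path, shapes, and sequence lengths are illustrative assumptions:

# Hypothetical sketch: run varlen flash attention over two packed sequences of
# lengths 3 and 5, assuming a CUDA GPU with flash attention installed.
import torch

from text_embeddings_server.layers.attention import attention  # assumed module path

num_heads, head_dim = 12, 64
seq_lens = [3, 5]
total_tokens = sum(seq_lens)

# Packed (total_tokens, num_heads, head_dim) tensors in fp16, as flash attention expects.
q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
out = torch.empty_like(q)

# Cumulative sequence lengths [0, 3, 8] and the longest sequence length.
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
max_s = max(seq_lens)

attention(q, k, v, out, cu_seqlens, max_s, softmax_scale=head_dim**-0.5)
# `out` now holds the attention output, still packed as (total_tokens, num_heads, head_dim).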
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
import os
import torch
from text_embeddings_server.utils.import_utils import SYSTEM
from loguru import logger

major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5

if SYSTEM == "rocm":
    try:
        import flash_attn_2_cuda

        logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
    except ImportError as e:
        if major >= 8 or is_sm75:
            architecture_suffix = f"-{SYSTEM}"
            raise ImportError(f"Flash Attention V2 is not installed. {e}")
        else:
            for idx in range(torch.cuda.device_count()):
                name = torch.cuda.get_device_name(idx)
                if "MI210" not in name and "MI250" not in name and "MI300" not in name:
                    raise ImportError(
                        f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
                    )
            raise ImportError(
                f"AMD GPU with ROCm capability {major} {minor} is not supported"
            ) from e


def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False):
    return flash_attn_2_cuda.varlen_fwd(
        q,
        k,
        v,
        out,
        cu_seqlens,
        cu_seqlens,
        max_s,
        max_s,
        0.0,
        softmax_scale,
        False,
        is_causal,
        False,
        None,
    )
Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
import torch
from text_embeddings_server.utils.import_utils import SYSTEM

from transformers.models.bert import BertConfig

if SYSTEM == "cuda":
    import dropout_layer_norm

    class FastLayerNorm:
        def __init__(self, prefix, handle, device, dtype, config: BertConfig):
            self.weight = handle.get_tensor(f"{prefix}.weight").to(dtype).to(device)
            self.bias = handle.get_tensor(f"{prefix}.bias").to(dtype).to(device)
            self.variance_epsilon = config.layer_norm_eps

        def forward(self, hidden_states, residual=None):
            normed_hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.weight,
                self.bias,
                None,
                None,
                None,
                None,
                0.0,
                self.variance_epsilon,
                1.0,
                0,
                None,
                False,
                False,
            )
            if residual is None:
                residual = hidden_states

            return normed_hidden_states, residual

elif SYSTEM == "rocm":

    class FastLayerNorm:
        def __init__(self, prefix, handle, device, dtype, config: BertConfig):
            self.weight = handle.get_tensor(f"{prefix}.weight").to(dtype).to(device)
            self.bias = handle.get_tensor(f"{prefix}.bias").to(dtype).to(device)
            self.variance_epsilon = config.layer_norm_eps

        def forward(self, hidden_states, residual=None):
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

            hidden_states = torch.nn.functional.layer_norm(
                hidden_states,
                self.weight.shape,
                self.weight,
                self.bias,
                eps=self.variance_epsilon,
            )

            return hidden_states, residual

else:
    raise ValueError("System not recognized")
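
For context, a minimal usage sketch (not part of this commit) of FastLayerNorm; the checkpoint path, tensor prefix, and model name are illustrative assumptions:

# Hypothetical sketch: load one BERT LayerNorm from a safetensors checkpoint and
# apply it with a residual, assuming a CUDA or ROCm system.
import torch
from safetensors import safe_open
from transformers.models.bert import BertConfig

from text_embeddings_server.layers.layernorm import FastLayerNorm  # assumed module path

config = BertConfig.from_pretrained("bert-base-uncased")  # assumed model
with safe_open("model.safetensors", framework="pt") as handle:  # assumed checkpoint path
    ln = FastLayerNorm(
        "encoder.layer.0.attention.output.LayerNorm",  # assumed tensor prefix
        handle,
        device="cuda",
        dtype=torch.float16,
        config=config,
    )

hidden_states = torch.randn(8, config.hidden_size, dtype=torch.float16, device="cuda")
residual = torch.randn_like(hidden_states)

# Both backends return (normed_hidden_states, residual) so the caller can reuse the residual.
normed, new_residual = ln.forward(hidden_states, residual)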
Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
import torch
from loguru import logger

SYSTEM = None
if torch.version.hip is not None:
    SYSTEM = "rocm"
elif torch.version.cuda is not None and torch.cuda.is_available():
    SYSTEM = "cuda"
else:
    SYSTEM = "cpu"

logger.info(f"Python backend: detected system {SYSTEM}")
