Commit f0dca9a

Bit of cleanup
1 parent f2c53ef commit f0dca9a

File tree

14 files changed: +316 -257 lines


exllamav2/compat.py

Lines changed: 9 additions & 5 deletions
@@ -18,8 +18,10 @@ def pairwise(iterable):
 
 tested_peer_copy = None
 
-def test_gpu_peer_copy(device_a: torch.Device,
-                       device_b: torch.Device):
+def test_gpu_peer_copy(
+    device_a: torch.Device,
+    device_b: torch.Device
+):
     global tested_peer_copy
 
     if tested_peer_copy is None:
@@ -47,9 +49,11 @@ def test_gpu_peer_copy(device_a: torch.Device,
         return False
 
 
-def safe_move_tensor(tensor: torch.Tensor | tuple[torch.Tensor],
-                     device: torch.Device | str | int,
-                     non_blocking = False):
+def safe_move_tensor(
+    tensor: torch.Tensor | tuple[torch.Tensor],
+    device: torch.Device | str | int,
+    non_blocking = False
+):
 
     # Accept tensor or tuple of tensors

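For reference, a minimal usage sketch of the two reformatted helpers. The module path exllamav2.compat follows the file name above; that test_gpu_peer_copy returns a bool and that safe_move_tensor falls back to an indirect copy when peer access fails are assumptions inferred from the names and the cached tested_peer_copy flag, not confirmed by these hunks:

    import torch
    from exllamav2.compat import safe_move_tensor, test_gpu_peer_copy

    # Probe once whether direct GPU-to-GPU copies work between two devices;
    # the hunk shows the result cached in the module-level tested_peer_copy.
    can_peer = test_gpu_peer_copy(torch.device("cuda:0"), torch.device("cuda:1"))

    # Move a tensor, or a tuple of tensors, to a device given as a torch
    # device, string or int, per the signature above.
    x = torch.randn(16, 16, device = "cuda:0")
    y = safe_move_tensor(x, "cuda:1")
    a, b = safe_move_tensor((x, y), 0, non_blocking = True)
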
exllamav2/device.py

Lines changed: 6 additions & 5 deletions
@@ -40,11 +40,12 @@ class ExLlamaV2DeviceContext:
 
     stream: torch.cuda.Stream
 
-    def __init__(self,
-                 model: ExLlamaV2,
-                 device_idx: int,
-                 scratch_bytes: int):
-
+    def __init__(
+        self,
+        model: ExLlamaV2,
+        device_idx: int,
+        scratch_bytes: int
+    ):
         self.model = model
         self.device_idx = device_idx
         self.ready = False

exllamav2/embedding.py

Lines changed: 15 additions & 11 deletions
@@ -20,9 +20,11 @@ class ExLlamaV2Embedding(ExLlamaV2Module):
 
     is_tp: bool
 
-    def __init__(self,
-                 model: ExLlamaV2,
-                 key: str):
+    def __init__(
+        self,
+        model: ExLlamaV2,
+        key: str
+    ):
         super().__init__(model, key)
 
         self.is_tp = False
@@ -93,14 +95,16 @@ def scratch_space(self) -> int:
         return 0
 
 
-    def forward(self,
-                hidden_states: torch.Tensor,
-                cache = None,
-                attn_params: ExLlamaV2Attention.Params = None,
-                past_len = None,
-                intermediates: bool = False,
-                loras = None,
-                **kwargs) -> torch.Tensor | dict[str: torch.Tensor]:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cache = None,
+        attn_params: ExLlamaV2Attention.Params = None,
+        past_len = None,
+        intermediates: bool = False,
+        loras = None,
+        **kwargs
+    ) -> torch.Tensor | dict[str: torch.Tensor]:
 
         cfg = self.model.config

exllamav2/ext.py

Lines changed: 45 additions & 41 deletions
@@ -339,19 +339,21 @@ def make_q_matrix(w: dict,
         if "q_group_map" not in w:
             w["q_group_map"] = make_group_map(w["q_groups"], w["q_weight"].shape[0])
 
-        return ext_c.make_q_matrix(w["q_weight"],
-                                   w.get("q_perm", none_tensor),
-                                   w.get("q_invperm", none_tensor),
-                                   w["q_scale"],
-                                   w["q_scale_max"],
-                                   w["q_groups"],
-                                   w["q_group_map"],
-                                   none_tensor,
-                                   none_tensor,
-                                   none_tensor,
-                                   w.get("bias", none_tensor),
-                                   temp_dq,
-                                   max_dq_rows)
+        return ext_c.make_q_matrix(
+            w["q_weight"],
+            w.get("q_perm", none_tensor),
+            w.get("q_invperm", none_tensor),
+            w["q_scale"],
+            w["q_scale_max"],
+            w["q_groups"],
+            w["q_group_map"],
+            none_tensor,
+            none_tensor,
+            none_tensor,
+            w.get("bias", none_tensor),
+            temp_dq,
+            max_dq_rows
+        )
 
     # GPTQ
 
@@ -370,36 +372,38 @@ def make_q_matrix(w: dict,
             w["q_perm"] = torch.empty((w["qweight"].shape[0] * 8,), dtype = torch.short, device = w["qweight"].device)
             w["q_invperm"] = torch.empty_like(w["q_perm"])
 
-            return ext_c.make_q_matrix(w["qweight"],
-                                       w["q_perm"],
-                                       w["q_invperm"],
-                                       none_tensor,
-                                       none_tensor,
-                                       none_tensor,
-                                       none_tensor,
-                                       w["qzeros"],
-                                       w["scales"],
-                                       w["g_idx"].cpu(),
-                                       w.get("bias", none_tensor),
-                                       temp_dq,
-                                       max_dq_rows)
+            return ext_c.make_q_matrix(
+                w["qweight"],
+                w["q_perm"],
+                w["q_invperm"],
+                none_tensor,
+                none_tensor,
+                none_tensor,
+                none_tensor,
+                w["qzeros"],
+                w["scales"],
+                w["g_idx"].cpu(),
+                w.get("bias", none_tensor),
+                temp_dq,
+                max_dq_rows
+            )
 
         # GPTQ without g_idx
 
         else:
 
-            return ext_c.make_q_matrix(w["qweight"],
-                                       none_tensor,
-                                       none_tensor,
-                                       none_tensor,
-                                       none_tensor,
-                                       none_tensor,
-                                       none_tensor,
-                                       w["qzeros"],
-                                       w["scales"],
-                                       none_tensor,
-                                       w.get("bias", none_tensor),
-                                       temp_dq,
-                                       max_dq_rows)
-
-
+            return ext_c.make_q_matrix(
+                w["qweight"],
+                none_tensor,
+                none_tensor,
+                none_tensor,
+                none_tensor,
+                none_tensor,
+                none_tensor,
+                w["qzeros"],
+                w["scales"],
+                none_tensor,
+                w.get("bias", none_tensor),
+                temp_dq,
+                max_dq_rows
+            )

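Read side by side, the three calls differ only in which slots are populated. A sketch of the shared argument layout follows; the slot descriptions are inferred from the dict keys used above, and the extension's own parameter names are not shown in this diff:

    # ext_c.make_q_matrix argument slots, in order:
    #  1  q_weight     packed weights: w["q_weight"] (EXL2) or w["qweight"] (GPTQ)
    #  2  q_perm       row permutation, or none_tensor
    #  3  q_invperm    inverse row permutation, or none_tensor
    #  4  q_scale      EXL2 scale data, none_tensor for GPTQ
    #  5  q_scale_max  EXL2 only, none_tensor for GPTQ
    #  6  q_groups     EXL2 only, none_tensor for GPTQ
    #  7  q_group_map  EXL2 only, none_tensor for GPTQ
    #  8  qzeros       GPTQ zero points, none_tensor for EXL2
    #  9  scales       GPTQ scales, none_tensor for EXL2
    # 10  g_idx        GPTQ act-order index, moved to CPU, none_tensor otherwise
    # 11  bias         w.get("bias", none_tensor) in all three calls
    # 12  temp_dq      dequantization scratch buffer
    # 13  max_dq_rows  row limit, passed through from the caller
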
exllamav2/fasttensors.py

Lines changed: 19 additions & 14 deletions
@@ -52,11 +52,12 @@ class STFile:
     st_context = None
     tensor_remap: dict | None
 
-    def __init__(self,
-                 filename: str,
-                 fast: bool = True,
-                 keymap: list[tuple[str, str]] = None):
-
+    def __init__(
+        self,
+        filename: str,
+        fast: bool = True,
+        keymap: list[tuple[str, str]] = None
+    ):
        global global_stfiles
 
        self.metadata = None
@@ -101,9 +102,11 @@ def __init__(self,
 
 
     @staticmethod
-    def open(filename,
-             fast = True,
-             keymap: list[tuple[str, str]] = None) -> STFile:
+    def open(
+        filename,
+        fast = True,
+        keymap: list[tuple[str, str]] = None
+    ) -> STFile:
         """
         Open safetensors file, scan header and retain handle.
 
@@ -181,12 +184,14 @@ def get_cm(self, device):
         return f
 
 
-    def get_tensor(self,
-                   key: str,
-                   device,
-                   not_fast: bool = False,
-                   cached: bool = False,
-                   out_dtype = None) -> torch.Tensor:
+    def get_tensor(
+        self,
+        key: str,
+        device,
+        not_fast: bool = False,
+        cached: bool = False,
+        out_dtype = None
+    ) -> torch.Tensor:
         global global_tensorcache
 
         torch.cuda.synchronize()

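A short usage sketch of the reformatted API. The file path and tensor key are placeholders; the open behavior is quoted from the docstring above:

    from exllamav2.fasttensors import STFile

    # "Open safetensors file, scan header and retain handle."
    st = STFile.open("/path/to/model.safetensors")

    # Fetch a single tensor onto the target device; not_fast, cached and
    # out_dtype keep the defaults from the signature above.
    w = st.get_tensor("some.tensor.key", device = "cuda:0")
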
exllamav2/headnorm.py

Lines changed: 27 additions & 21 deletions
@@ -21,11 +21,13 @@ class ExLlamaV2HeadNorm(ExLlamaV2Module):
     num_heads: int
 
 
-    def __init__(self,
-                 model: ExLlamaV2,
-                 key: str,
-                 num_heads: int,
-                 head_dim: int):
+    def __init__(
+        self,
+        model: ExLlamaV2,
+        key: str,
+        num_heads: int,
+        head_dim: int
+    ):
         super().__init__(model, key)
 
         self.layernorm = None
@@ -101,14 +103,16 @@ def scratch_space(self) -> int:
         return 0
 
 
-    def forward(self,
-                hidden_states: torch.Tensor,
-                cache = None,
-                attn_params = None,
-                past_len = None,
-                intermediates: bool = False,
-                loras = None,
-                **kwargs) -> torch.Tensor | dict[str: torch.Tensor]:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cache = None,
+        attn_params = None,
+        past_len = None,
+        intermediates: bool = False,
+        loras = None,
+        **kwargs
+    ) -> torch.Tensor | dict[str: torch.Tensor]:
 
         norm = torch.empty_like(hidden_states)
         ext_c.head_norm(hidden_states,
@@ -122,14 +126,16 @@ def forward(self,
         else:
             return hidden_states
 
-    def forward_torch(self,
-                      hidden_states: torch.Tensor,
-                      cache = None,
-                      attn_params = None,
-                      past_len = None,
-                      intermediates: bool = False,
-                      loras = None,
-                      **kwargs) -> torch.Tensor | dict[str: torch.Tensor]:
+    def forward_torch(
+        self,
+        hidden_states: torch.Tensor,
+        cache = None,
+        attn_params = None,
+        past_len = None,
+        intermediates: bool = False,
+        loras = None,
+        **kwargs
+    ) -> torch.Tensor | dict[str: torch.Tensor]:
 
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)

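The norm modules touched by this commit share one dispatch pattern: forward hands the tensors to a fused extension kernel (ext_c.head_norm here), while forward_torch is a plain PyTorch reference path that upcasts to float32. A schematic sketch of that split; the kernel call is elided, and the cast back to the input dtype is an assumption suggested by input_dtype being saved, not shown in the hunk:

    import torch

    class NormSketch:
        # Fast path: preallocate the output and let the compiled kernel fill it.
        def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
            norm = torch.empty_like(hidden_states)
            # ext_c.head_norm(hidden_states, ..., norm)   # fused kernel
            return norm

        # Reference path: same math in eager PyTorch, computed in fp32.
        def forward_torch(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
            input_dtype = hidden_states.dtype
            hidden_states = hidden_states.to(torch.float32)
            # ... normalization in fp32 ...
            return hidden_states.to(input_dtype)
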
exllamav2/layernorm.py

Lines changed: 27 additions & 21 deletions
@@ -18,9 +18,11 @@ class ExLlamaV2LayerNorm(ExLlamaV2Module):
     variance_epsilon: float
 
 
-    def __init__(self,
-                 model: ExLlamaV2,
-                 key: str):
+    def __init__(
+        self,
+        model: ExLlamaV2,
+        key: str
+    ):
         super().__init__(model, key)
 
         self.layernorm = None
@@ -93,15 +95,17 @@ def scratch_space(self) -> int:
         return 0
 
 
-    def forward(self,
-                hidden_states: torch.Tensor,
-                cache = None,
-                attn_params = None,
-                past_len = None,
-                intermediates: bool = False,
-                loras = None,
-                output_fp32 = False, # TODO:
-                **kwargs) -> torch.Tensor | dict[str: torch.Tensor]:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cache = None,
+        attn_params = None,
+        past_len = None,
+        intermediates: bool = False,
+        loras = None,
+        output_fp32 = False, # TODO:
+        **kwargs
+    ) -> torch.Tensor | dict[str: torch.Tensor]:
 
         output_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
@@ -120,15 +124,17 @@ def forward(self,
             return hidden_states
 
 
-    def forward_torch(self,
-                      hidden_states: torch.Tensor,
-                      cache = None,
-                      attn_params = None,
-                      past_len = None,
-                      intermediates: bool = False,
-                      loras = None,
-                      output_fp32 = False, # TODO:
-                      **kwargs) -> torch.Tensor | dict[str: torch.Tensor]:
+    def forward_torch(
+        self,
+        hidden_states: torch.Tensor,
+        cache = None,
+        attn_params = None,
+        past_len = None,
+        intermediates: bool = False,
+        loras = None,
+        output_fp32 = False, # TODO:
+        **kwargs
+    ) -> torch.Tensor | dict[str: torch.Tensor]:
 
         hidden_states = self.layernorm(hidden_states)
