Skip to content

Commit 06029dd

Browse files
Merge pull request #1081 from akx/ruff-format
Reformat Python code with Ruff
2 parents fd723b7 + 5a4263f commit 06029dd

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

41 files changed

+2661
-1777
lines changed

.github/scripts/set_platform_tag.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@ def get_platform_tag(architecture):
77
system = platform.system()
88

99
if system == "Linux":
10-
tag = (
11-
"manylinux_2_24_x86_64" if architecture == "x86_64" else "manylinux_2_24_aarch64"
12-
)
10+
tag = "manylinux_2_24_x86_64" if architecture == "x86_64" else "manylinux_2_24_aarch64"
1311
elif system == "Darwin":
1412
tag = "macosx_13_1_x86_64" if architecture == "x86_64" else "macosx_13_1_arm64"
1513
elif system == "Windows":

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
repos:
22
- repo: https://github.com/astral-sh/ruff-pre-commit
3-
rev: v0.2.0
3+
rev: v0.3.2
44
hooks:
55
- id: ruff
66
args:
77
- --fix
8-
# - id: ruff-format # TODO: enable when the time is right
8+
- id: ruff-format
99
- repo: https://github.com/pre-commit/pre-commit-hooks
1010
rev: v4.5.0
1111
hooks:
Lines changed: 69 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
21
import matplotlib.gridspec as gridspec
32
import matplotlib.pyplot as plt
43
import pandas as pd
54

6-
cmap=plt.get_cmap('cool')
7-
8-
if __name__ == '__main__':
5+
cmap = plt.get_cmap("cool")
96

10-
fig = plt.figure(tight_layout=True, figsize=(12,3.5))
7+
if __name__ == "__main__":
8+
fig = plt.figure(tight_layout=True, figsize=(12, 3.5))
119
gs = gridspec.GridSpec(1, 2)
1210

1311
dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096]
@@ -19,25 +17,28 @@
1917
ax = fig.add_subplot(gs[0, 0])
2018

2119
# TODO: change this to what you want.
22-
rdf = pd.read_json('speed_benchmark/info_a100_py2.jsonl', lines=True)
20+
rdf = pd.read_json("speed_benchmark/info_a100_py2.jsonl", lines=True)
2321
df = rdf[rdf.batch_size == batch_size_for_plot1]
2422

2523
# first plot the time occupied by different operations
2624
for k, marker, ls, color, name in [
27-
('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (sum of parts)'),
28-
('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'),
29-
30-
('standard_fwd', '^', '--', 'C2', 'Matmul XW (standard)'),
31-
('standard_gw', '^', '-.', 'C2', 'Matmul GW (standard)'),
32-
('standard_gx', '^', ':', 'gray', 'Matmul GX (both)'),
33-
34-
('global_fwd', '^', '--', 'C4', 'Int8 Matmul XW (switchback)'),
35-
('global_bwd', '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'),
36-
37-
('x_quantize_rowwise', 'P', '--', 'C4', 'Quantize rowwise X (switchback)'),
38-
('g_quantize_rowwise', 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'),
39-
('w_quantize_global', '.', '--', 'C4', 'Quantize global W (switchback)'),
40-
('w_quantize_global_transpose', '.', '-.', 'C4', 'Quantize global and\ntranspose W (switchback)'),
25+
("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (sum of parts)"),
26+
(
27+
"x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd",
28+
"o",
29+
"-",
30+
"C4",
31+
"SwitchBack int8 (sum of parts)",
32+
),
33+
("standard_fwd", "^", "--", "C2", "Matmul XW (standard)"),
34+
("standard_gw", "^", "-.", "C2", "Matmul GW (standard)"),
35+
("standard_gx", "^", ":", "gray", "Matmul GX (both)"),
36+
("global_fwd", "^", "--", "C4", "Int8 Matmul XW (switchback)"),
37+
("global_bwd", "^", "-.", "C4", "Int8 Matmul GW (switchback)"),
38+
("x_quantize_rowwise", "P", "--", "C4", "Quantize rowwise X (switchback)"),
39+
("g_quantize_rowwise", "P", "-.", "C4", "Quantize rowwise G (switchback)"),
40+
("w_quantize_global", ".", "--", "C4", "Quantize global W (switchback)"),
41+
("w_quantize_global_transpose", ".", "-.", "C4", "Quantize global and\ntranspose W (switchback)"),
4142
]:
4243
xs = []
4344
ys = []
@@ -47,89 +48,104 @@
4748
df_ = df_[df_.dim_out == embed_dim * 4]
4849
xs.append(embed_dim)
4950
y_ = 0
50-
for k_ in k.split('+'):
51+
for k_ in k.split("+"):
5152
y_ += df_[k_].values[0]
5253
df_ = df[df.dim_in == embed_dim * 4]
5354
df_ = df_[df_.dim_out == embed_dim]
54-
for k_ in k.split('+'):
55+
for k_ in k.split("+"):
5556
y_ += df_[k_].values[0]
5657
ys.append(y_ * 0.5)
5758

59+
ax.plot(
60+
xs,
61+
ys,
62+
color=color,
63+
label=name,
64+
marker=marker,
65+
markersize=5 if marker == "s" else 5,
66+
linestyle=ls,
67+
linewidth=2 if "+" in k else 1.0,
68+
)
5869

59-
ax.plot(xs, ys, color=color, label=name, marker=marker, markersize=5 if marker=='s' else 5, linestyle=ls, linewidth=2 if '+' in k else 1.)
60-
61-
62-
ax.set_xlabel('dim', fontsize=13)
63-
ax.set_ylabel('time (ms)', fontsize=13)
70+
ax.set_xlabel("dim", fontsize=13)
71+
ax.set_ylabel("time (ms)", fontsize=13)
6472

6573
ax.grid()
6674

67-
ax.set_xscale('log')
75+
ax.set_xscale("log")
6876
if logscale_plot1:
69-
ax.set_yscale('log')
77+
ax.set_yscale("log")
7078

71-
ax.tick_params(axis='x', labelsize=11)
72-
ax.tick_params(axis='y', labelsize=11)
79+
ax.tick_params(axis="x", labelsize=11)
80+
ax.tick_params(axis="y", labelsize=11)
7381

7482
ax.set_xticks(dims_to_xtick)
7583
ax.set_xticklabels(dims_to_xtick)
7684
ax.set_xticks([], minor=True)
7785

78-
leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64, 1.), ncol=1, fontsize=10)
79-
leg.get_texts()[0].set_fontweight('bold')
80-
leg.get_texts()[1].set_fontweight('bold')
86+
leg = ax.legend(loc="upper center", bbox_to_anchor=(-0.64, 1.0), ncol=1, fontsize=10)
87+
leg.get_texts()[0].set_fontweight("bold")
88+
leg.get_texts()[1].set_fontweight("bold")
8189
plt.subplots_adjust(left=0.1)
82-
ax.set_title(' Linear layer, batch * sequence length = 32k', fontsize=10, loc='left', y=1.05, pad=-20)
83-
90+
ax.set_title(" Linear layer, batch * sequence length = 32k", fontsize=10, loc="left", y=1.05, pad=-20)
8491

8592
ax = fig.add_subplot(gs[0, 1])
8693

8794
# now plot the % speedup for different batch sizes
8895
for j, batch_size in enumerate(batch_sizes_for_plot2):
8996
all_xs, all_ys = [], []
9097
for k, marker, ls, color, name in [
91-
('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (total time)'),
92-
('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'),
98+
("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (total time)"),
99+
(
100+
"x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd",
101+
"o",
102+
"-",
103+
"C4",
104+
"SwitchBack int8 (total time)",
105+
),
93106
]:
94-
95107
xs, ys = [], []
96108
df = rdf[rdf.batch_size == batch_size]
97109
for embed_dim in dims_to_consider:
98110
df_ = df[df.dim_in == embed_dim]
99111
df_ = df_[df_.dim_out == embed_dim * 4]
100112
xs.append(embed_dim)
101113
y_ = 0
102-
for k_ in k.split('+'):
114+
for k_ in k.split("+"):
103115
y_ += df_[k_].values[0]
104116
df_ = df[df.dim_in == embed_dim * 4]
105117
df_ = df_[df_.dim_out == embed_dim]
106-
for k_ in k.split('+'):
118+
for k_ in k.split("+"):
107119
y_ += df_[k_].values[0]
108120
ys.append(y_ * 0.5)
109121
all_xs.append(xs)
110122
all_ys.append(ys)
111123

112124
color = cmap(j * 0.25)
113125
real_ys = [-((all_ys[1][i] - all_ys[0][i]) / all_ys[0][i]) * 100 for i in range(len(all_ys[0]))]
114-
markers = ['^', 'v', 'P', 'o']
115-
ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker=='s' else 5)
126+
markers = ["^", "v", "P", "o"]
127+
ax.plot(
128+
all_xs[0],
129+
real_ys,
130+
color=color,
131+
label=f"batch * sequence length = {batch_size}",
132+
marker=markers[j],
133+
markersize=5 if marker == "s" else 5,
134+
)
116135

117136
ax.legend()
118-
ax.set_xlabel('dim', fontsize=13)
119-
ax.set_xscale('log')
137+
ax.set_xlabel("dim", fontsize=13)
138+
ax.set_xscale("log")
120139
ax.grid()
121-
ax.set_ylabel(r'% speedup', fontsize=13)
140+
ax.set_ylabel(r"% speedup", fontsize=13)
122141

123-
124-
ax.tick_params(axis='x', labelsize=11)
125-
ax.tick_params(axis='y', labelsize=11)
142+
ax.tick_params(axis="x", labelsize=11)
143+
ax.tick_params(axis="y", labelsize=11)
126144

127145
ax.set_xticks(dims_to_xtick)
128146
ax.set_xticklabels(dims_to_xtick)
129147
ax.set_xticks([], minor=True)
130148

131-
ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20)
132-
133-
149+
ax.set_title(" Linear layer summary, varying dimensions", fontsize=10, loc="left", y=1.05, pad=-20)
134150

135-
plt.savefig('speed_benchmark/plot_with_info.pdf', bbox_inches='tight')
151+
plt.savefig("speed_benchmark/plot_with_info.pdf", bbox_inches="tight")

benchmarking/switchback/speed_benchmark.py

Lines changed: 86 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -20,32 +20,31 @@
2020

2121
# KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large.
2222

23-
def get_time(k, fn, info_dict):
2423

24+
def get_time(k, fn, info_dict):
2525
for _ in range(repeat // 2):
26-
fn()
26+
fn()
2727

2828
torch.cuda.synchronize()
2929
start = time.time()
3030
for _ in range(repeat):
31-
fn()
31+
fn()
3232

3333
torch.cuda.synchronize()
3434
end = time.time()
3535
ms = (end - start) / repeat * 1000
3636
print(f"time {k}: {ms:.3f} ms")
3737
info_dict[k] = ms
3838

39-
if __name__ == '__main__':
39+
40+
if __name__ == "__main__":
4041
torch.manual_seed(0)
4142
wm = 4
4243
for dim in [1024, 1280, 1408, 1664, 2048, 4096]:
4344
# note "batch_size" is actually "batch_size * embed_dim", which is why it's large
44-
for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]:
45-
45+
for batch_size in [256 * 32, 256 * 64, 256 * 128, 256 * 256, 256 * 512]:
4646
# switch switches dim_in and dim_out
4747
for switch in [False, True]:
48-
4948
# hparams
5049
repeat = 64
5150
batch_size = batch_size
@@ -73,35 +72,86 @@ def get_time(k, fn, info_dict):
7372
state_w_rowwise = w.max(dim=1)[0]
7473
state_w_global = w.max()
7574

76-
info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch}
77-
78-
get_time('standard_fwd', lambda : x.matmul(w.t()), info)
79-
get_time('standard_gw', lambda : g.t().matmul(x), info)
80-
get_time('standard_gx', lambda : g.matmul(w), info)
81-
get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info)
82-
get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info)
83-
get_time('global_fwd', lambda : int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
84-
get_time('global_bwd', lambda : int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
85-
get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info)
86-
get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info)
87-
get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info)
88-
get_time('w_quantize_colwise_transpose', lambda : quantize_columnwise_and_transpose(w), info)
89-
get_time('w_quantize_global', lambda : quantize_global(w), info)
90-
get_time('w_quantize_global_transpose', lambda : quantize_global_transpose(w), info)
91-
92-
time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw']
93-
time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd']
94-
time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd']
95-
96-
print('TOTAL STANDARD', time_standard)
97-
print('TOTAL ROWWISE', time_rowwise)
98-
print('TOTAL GLOBAL', time_global)
99-
100-
print('speedup', -100*(time_global - time_standard)/time_standard)
101-
102-
info['time_standard'] = time_standard
103-
info['time_rowwise'] = time_rowwise
104-
info['time_global'] = time_global
75+
info = {
76+
"repeat": repeat,
77+
"batch_size": batch_size,
78+
"dim_out": dim_out,
79+
"dim_in": dim_in,
80+
"wm": wm,
81+
"switch": switch,
82+
}
83+
84+
get_time("standard_fwd", lambda: x.matmul(w.t()), info)
85+
get_time("standard_gw", lambda: g.t().matmul(x), info)
86+
get_time("standard_gx", lambda: g.matmul(w), info)
87+
get_time(
88+
"rowwise_fwd",
89+
lambda: int8_matmul_rowwise_dequantize(
90+
x_int8,
91+
w_int8.t(),
92+
state_x_rowwise,
93+
state_w_columnwise,
94+
None,
95+
),
96+
info,
97+
)
98+
get_time(
99+
"rowwise_bwd",
100+
lambda: int8_matmul_rowwise_dequantize(
101+
g_int8,
102+
wt_int8.t(),
103+
state_x_rowwise,
104+
state_w_rowwise,
105+
None,
106+
),
107+
info,
108+
)
109+
get_time(
110+
"global_fwd",
111+
lambda: int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None),
112+
info,
113+
)
114+
get_time(
115+
"global_bwd",
116+
lambda: int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None),
117+
info,
118+
)
119+
get_time("x_quantize_rowwise", lambda: quantize_rowwise(x), info)
120+
get_time("g_quantize_rowwise", lambda: quantize_rowwise(g), info)
121+
get_time("w_quantize_rowwise", lambda: quantize_rowwise(w), info)
122+
get_time("w_quantize_colwise_transpose", lambda: quantize_columnwise_and_transpose(w), info)
123+
get_time("w_quantize_global", lambda: quantize_global(w), info)
124+
get_time("w_quantize_global_transpose", lambda: quantize_global_transpose(w), info)
125+
126+
time_standard = info["standard_fwd"] + info["standard_gx"] + info["standard_gw"]
127+
time_rowwise = (
128+
info["x_quantize_rowwise"]
129+
+ info["g_quantize_rowwise"]
130+
+ info["w_quantize_colwise_transpose"]
131+
+ info["w_quantize_rowwise"]
132+
+ info["standard_gw"]
133+
+ info["rowwise_fwd"]
134+
+ info["rowwise_bwd"]
135+
)
136+
time_global = (
137+
info["x_quantize_rowwise"]
138+
+ info["g_quantize_rowwise"]
139+
+ info["w_quantize_global"]
140+
+ info["w_quantize_global_transpose"]
141+
+ info["standard_gw"]
142+
+ info["global_fwd"]
143+
+ info["global_bwd"]
144+
)
145+
146+
print("TOTAL STANDARD", time_standard)
147+
print("TOTAL ROWWISE", time_rowwise)
148+
print("TOTAL GLOBAL", time_global)
149+
150+
print("speedup", -100 * (time_global - time_standard) / time_standard)
151+
152+
info["time_standard"] = time_standard
153+
info["time_rowwise"] = time_rowwise
154+
info["time_global"] = time_global
105155

106156
info_json = json.dumps(info)
107157

0 commit comments

Comments (0)