|
1 |
| - |
2 | 1 | import matplotlib.gridspec as gridspec
|
3 | 2 | import matplotlib.pyplot as plt
|
4 | 3 | import pandas as pd
|
5 | 4 |
|
6 |
| -cmap=plt.get_cmap('cool') |
7 |
| - |
8 |
| -if __name__ == '__main__': |
| 5 | +cmap = plt.get_cmap("cool") |
9 | 6 |
|
10 |
| - fig = plt.figure(tight_layout=True, figsize=(12,3.5)) |
| 7 | +if __name__ == "__main__": |
| 8 | + fig = plt.figure(tight_layout=True, figsize=(12, 3.5)) |
11 | 9 | gs = gridspec.GridSpec(1, 2)
|
12 | 10 |
|
13 | 11 | dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096]
|
|
19 | 17 | ax = fig.add_subplot(gs[0, 0])
|
20 | 18 |
|
21 | 19 | # TODO: change this to what you want.
|
22 |
| - rdf = pd.read_json('speed_benchmark/info_a100_py2.jsonl', lines=True) |
| 20 | + rdf = pd.read_json("speed_benchmark/info_a100_py2.jsonl", lines=True) |
23 | 21 | df = rdf[rdf.batch_size == batch_size_for_plot1]
|
24 | 22 |
|
25 | 23 | # first plot the time occupied by different operations
|
26 | 24 | for k, marker, ls, color, name in [
|
27 |
| - ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (sum of parts)'), |
28 |
| - ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'), |
29 |
| - |
30 |
| - ('standard_fwd', '^', '--', 'C2', 'Matmul XW (standard)'), |
31 |
| - ('standard_gw', '^', '-.', 'C2', 'Matmul GW (standard)'), |
32 |
| - ('standard_gx', '^', ':', 'gray', 'Matmul GX (both)'), |
33 |
| - |
34 |
| - ('global_fwd', '^', '--', 'C4', 'Int8 Matmul XW (switchback)'), |
35 |
| - ('global_bwd', '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'), |
36 |
| - |
37 |
| - ('x_quantize_rowwise', 'P', '--', 'C4', 'Quantize rowwise X (switchback)'), |
38 |
| - ('g_quantize_rowwise', 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'), |
39 |
| - ('w_quantize_global', '.', '--', 'C4', 'Quantize global W (switchback)'), |
40 |
| - ('w_quantize_global_transpose', '.', '-.', 'C4', 'Quantize global and\ntranspose W (switchback)'), |
| 25 | + ("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (sum of parts)"), |
| 26 | + ( |
| 27 | + "x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd", |
| 28 | + "o", |
| 29 | + "-", |
| 30 | + "C4", |
| 31 | + "SwitchBack int8 (sum of parts)", |
| 32 | + ), |
| 33 | + ("standard_fwd", "^", "--", "C2", "Matmul XW (standard)"), |
| 34 | + ("standard_gw", "^", "-.", "C2", "Matmul GW (standard)"), |
| 35 | + ("standard_gx", "^", ":", "gray", "Matmul GX (both)"), |
| 36 | + ("global_fwd", "^", "--", "C4", "Int8 Matmul XW (switchback)"), |
| 37 | + ("global_bwd", "^", "-.", "C4", "Int8 Matmul GW (switchback)"), |
| 38 | + ("x_quantize_rowwise", "P", "--", "C4", "Quantize rowwise X (switchback)"), |
| 39 | + ("g_quantize_rowwise", "P", "-.", "C4", "Quantize rowwise G (switchback)"), |
| 40 | + ("w_quantize_global", ".", "--", "C4", "Quantize global W (switchback)"), |
| 41 | + ("w_quantize_global_transpose", ".", "-.", "C4", "Quantize global and\ntranspose W (switchback)"), |
41 | 42 | ]:
|
42 | 43 | xs = []
|
43 | 44 | ys = []
|
|
47 | 48 | df_ = df_[df_.dim_out == embed_dim * 4]
|
48 | 49 | xs.append(embed_dim)
|
49 | 50 | y_ = 0
|
50 |
| - for k_ in k.split('+'): |
| 51 | + for k_ in k.split("+"): |
51 | 52 | y_ += df_[k_].values[0]
|
52 | 53 | df_ = df[df.dim_in == embed_dim * 4]
|
53 | 54 | df_ = df_[df_.dim_out == embed_dim]
|
54 |
| - for k_ in k.split('+'): |
| 55 | + for k_ in k.split("+"): |
55 | 56 | y_ += df_[k_].values[0]
|
56 | 57 | ys.append(y_ * 0.5)
|
57 | 58 |
|
| 59 | + ax.plot( |
| 60 | + xs, |
| 61 | + ys, |
| 62 | + color=color, |
| 63 | + label=name, |
| 64 | + marker=marker, |
| 65 | + markersize=5 if marker == "s" else 5, |
| 66 | + linestyle=ls, |
| 67 | + linewidth=2 if "+" in k else 1.0, |
| 68 | + ) |
58 | 69 |
|
59 |
| - ax.plot(xs, ys, color=color, label=name, marker=marker, markersize=5 if marker=='s' else 5, linestyle=ls, linewidth=2 if '+' in k else 1.) |
60 |
| - |
61 |
| - |
62 |
| - ax.set_xlabel('dim', fontsize=13) |
63 |
| - ax.set_ylabel('time (ms)', fontsize=13) |
| 70 | + ax.set_xlabel("dim", fontsize=13) |
| 71 | + ax.set_ylabel("time (ms)", fontsize=13) |
64 | 72 |
|
65 | 73 | ax.grid()
|
66 | 74 |
|
67 |
| - ax.set_xscale('log') |
| 75 | + ax.set_xscale("log") |
68 | 76 | if logscale_plot1:
|
69 |
| - ax.set_yscale('log') |
| 77 | + ax.set_yscale("log") |
70 | 78 |
|
71 |
| - ax.tick_params(axis='x', labelsize=11) |
72 |
| - ax.tick_params(axis='y', labelsize=11) |
| 79 | + ax.tick_params(axis="x", labelsize=11) |
| 80 | + ax.tick_params(axis="y", labelsize=11) |
73 | 81 |
|
74 | 82 | ax.set_xticks(dims_to_xtick)
|
75 | 83 | ax.set_xticklabels(dims_to_xtick)
|
76 | 84 | ax.set_xticks([], minor=True)
|
77 | 85 |
|
78 |
| - leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64, 1.), ncol=1, fontsize=10) |
79 |
| - leg.get_texts()[0].set_fontweight('bold') |
80 |
| - leg.get_texts()[1].set_fontweight('bold') |
| 86 | + leg = ax.legend(loc="upper center", bbox_to_anchor=(-0.64, 1.0), ncol=1, fontsize=10) |
| 87 | + leg.get_texts()[0].set_fontweight("bold") |
| 88 | + leg.get_texts()[1].set_fontweight("bold") |
81 | 89 | plt.subplots_adjust(left=0.1)
|
82 |
| - ax.set_title(' Linear layer, batch * sequence length = 32k', fontsize=10, loc='left', y=1.05, pad=-20) |
83 |
| - |
| 90 | + ax.set_title(" Linear layer, batch * sequence length = 32k", fontsize=10, loc="left", y=1.05, pad=-20) |
84 | 91 |
|
85 | 92 | ax = fig.add_subplot(gs[0, 1])
|
86 | 93 |
|
87 | 94 | # now plot the % speedup for different batch sizes
|
88 | 95 | for j, batch_size in enumerate(batch_sizes_for_plot2):
|
89 | 96 | all_xs, all_ys = [], []
|
90 | 97 | for k, marker, ls, color, name in [
|
91 |
| - ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (total time)'), |
92 |
| - ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'), |
| 98 | + ("standard_gx+standard_gw+standard_fwd", "s", "-", "C2", "Standard fp16 (total time)"), |
| 99 | + ( |
| 100 | + "x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd", |
| 101 | + "o", |
| 102 | + "-", |
| 103 | + "C4", |
| 104 | + "SwitchBack int8 (total time)", |
| 105 | + ), |
93 | 106 | ]:
|
94 |
| - |
95 | 107 | xs, ys = [], []
|
96 | 108 | df = rdf[rdf.batch_size == batch_size]
|
97 | 109 | for embed_dim in dims_to_consider:
|
98 | 110 | df_ = df[df.dim_in == embed_dim]
|
99 | 111 | df_ = df_[df_.dim_out == embed_dim * 4]
|
100 | 112 | xs.append(embed_dim)
|
101 | 113 | y_ = 0
|
102 |
| - for k_ in k.split('+'): |
| 114 | + for k_ in k.split("+"): |
103 | 115 | y_ += df_[k_].values[0]
|
104 | 116 | df_ = df[df.dim_in == embed_dim * 4]
|
105 | 117 | df_ = df_[df_.dim_out == embed_dim]
|
106 |
| - for k_ in k.split('+'): |
| 118 | + for k_ in k.split("+"): |
107 | 119 | y_ += df_[k_].values[0]
|
108 | 120 | ys.append(y_ * 0.5)
|
109 | 121 | all_xs.append(xs)
|
110 | 122 | all_ys.append(ys)
|
111 | 123 |
|
112 | 124 | color = cmap(j * 0.25)
|
113 | 125 | real_ys = [-((all_ys[1][i] - all_ys[0][i]) / all_ys[0][i]) * 100 for i in range(len(all_ys[0]))]
|
114 |
| - markers = ['^', 'v', 'P', 'o'] |
115 |
| - ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker=='s' else 5) |
| 126 | + markers = ["^", "v", "P", "o"] |
| 127 | + ax.plot( |
| 128 | + all_xs[0], |
| 129 | + real_ys, |
| 130 | + color=color, |
| 131 | + label=f"batch * sequence length = {batch_size}", |
| 132 | + marker=markers[j], |
| 133 | + markersize=5 if marker == "s" else 5, |
| 134 | + ) |
116 | 135 |
|
117 | 136 | ax.legend()
|
118 |
| - ax.set_xlabel('dim', fontsize=13) |
119 |
| - ax.set_xscale('log') |
| 137 | + ax.set_xlabel("dim", fontsize=13) |
| 138 | + ax.set_xscale("log") |
120 | 139 | ax.grid()
|
121 |
| - ax.set_ylabel(r'% speedup', fontsize=13) |
| 140 | + ax.set_ylabel(r"% speedup", fontsize=13) |
122 | 141 |
|
123 |
| - |
124 |
| - ax.tick_params(axis='x', labelsize=11) |
125 |
| - ax.tick_params(axis='y', labelsize=11) |
| 142 | + ax.tick_params(axis="x", labelsize=11) |
| 143 | + ax.tick_params(axis="y", labelsize=11) |
126 | 144 |
|
127 | 145 | ax.set_xticks(dims_to_xtick)
|
128 | 146 | ax.set_xticklabels(dims_to_xtick)
|
129 | 147 | ax.set_xticks([], minor=True)
|
130 | 148 |
|
131 |
| - ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20) |
132 |
| - |
133 |
| - |
| 149 | + ax.set_title(" Linear layer summary, varying dimensions", fontsize=10, loc="left", y=1.05, pad=-20) |
134 | 150 |
|
135 |
| - plt.savefig('speed_benchmark/plot_with_info.pdf', bbox_inches='tight') |
| 151 | + plt.savefig("speed_benchmark/plot_with_info.pdf", bbox_inches="tight") |
0 commit comments