@@ -57,6 +57,49 @@ def run_diffusers_mmdit(
57
57
58
58
return noise_pred .numpy ()
59
59
60
def run_attn_turbine(q, k, v, args):
    """Run the compiled attention module through IREE and return its output.

    Uploads q/k/v to the runner's device, invokes the ``run_forward``
    entry point of the compiled module, and copies the result back to host.

    Args:
        q, k, v: host arrays for query, key, and value.
        args: parsed CLI options providing ``device`` and ``vmfb_path``.

    Returns:
        The attention output as a host (numpy) array.
    """
    runner = vmfbRunner(
        args.device,
        args.vmfb_path,
        None,
    )
    device = runner.config.device
    device_inputs = [ireert.asdevicearray(device, host_arr) for host_arr in (q, k, v)]
    result = runner.ctx.modules.compiled_attn["run_forward"](*device_inputs)
    return result.to_host()
@torch.no_grad()
def run_attn_torch(q, k, v, args):
    """Compute the reference attention output with the eager torch module.

    Converts q/k/v to float32 tensors, runs ``MMDiTAttention`` forward, and
    returns the result as a numpy array. Runs under ``torch.no_grad`` since
    this is a numerics-comparison path, not training.

    Args:
        q, k, v: host arrays for query, key, and value.
        args: parsed CLI options (unused here; kept for a uniform signature
            with the turbine runner).

    Returns:
        The attention output as a numpy array.
    """
    from turbine_models.custom_models.sd3_inference.sd3_mmdit import MMDiTAttention

    attn_module = MMDiTAttention()

    def _to_f32(host_arr):
        return torch.tensor(host_arr, dtype=torch.float32)

    output = attn_module.forward(_to_f32(q), _to_f32(k), _to_f32(v))
    return output.numpy()
def find_errs(turbine_output, torch_output, dim=None, failed_dims=None, errs=None):
    """Recursively locate the indices at which two arrays diverge.

    Walks down the leading axis of any sub-array that fails an
    ``allclose(rtol=4e-2, atol=4e-2)`` check, recording the multi-dimensional
    index path of every mismatching sub-array along the way.

    Args:
        turbine_output: candidate array (e.g. turbine/IREE result).
        torch_output: reference array of the same shape.
        dim: index path accumulated so far (internal; defaults to []).
        failed_dims: accumulator of failing index paths (defaults to a fresh list).
        errs: accumulator of [abs_error, turbine_slice, torch_slice] triples
            (defaults to a fresh list).

    Returns:
        Tuple ``(failed_dims, errs)``.
    """
    # Fix: the original used mutable default arguments ([]), so results
    # accumulated across separate top-level calls. Use None sentinels instead.
    if dim is None:
        dim = []
    if failed_dims is None:
        failed_dims = []
    if errs is None:
        errs = []
    if not np.allclose(turbine_output, torch_output, rtol=4e-2, atol=4e-2):
        if turbine_output.ndim > 0:
            for idx in range(len(torch_output)):
                sub_dim = [*dim, idx]
                try:
                    np.testing.assert_allclose(
                        turbine_output[idx], torch_output[idx], rtol=4e-2, atol=4e-2
                    )
                except AssertionError:
                    # Record this mismatching slice, then descend into it to
                    # pinpoint the innermost failing indices.
                    err = np.abs(turbine_output[idx] - torch_output[idx])
                    failed_dims.append(sub_dim)
                    errs.append([err, turbine_output[idx], torch_output[idx]])
                    failed_dims, errs = find_errs(
                        turbine_output[idx], torch_output[idx], sub_dim, failed_dims, errs
                    )
    return (failed_dims, errs)
60
103
61
104
if __name__ == "__main__" :
62
105
from turbine_models .custom_models .sd3_inference .sd3_cmd_opts import args
@@ -69,6 +112,29 @@ def run_diffusers_mmdit(
69
112
dtype = torch .float16
70
113
else :
71
114
dtype = torch .float32
115
+
116
+ if args .attn_repro :
117
+ qkv_shape = (2 , 24 , 4250 , 64 )
118
+ example_qkv = [
119
+ np .load ("q.npy" ).astype (np .float16 ),
120
+ np .load ("k.npy" ).astype (np .float16 ),
121
+ np .load ("v.npy" ).astype (np .float16 ),
122
+ ]
123
+ turbine_output = run_attn_turbine (
124
+ * example_qkv ,
125
+ args ,
126
+ )
127
+ torch_output = run_attn_torch (* example_qkv , args ).astype (np .float16 )
128
+ np .save ("turbine_attn_output.npy" , turbine_output )
129
+ np .save ("torch_attn_output.npy" , torch_output )
130
+ failed_dims , errs = find_errs (turbine_output , torch_output )
131
+ for idx , dim in enumerate (failed_dims ):
132
+ if len (dim ) == len (torch_output .shape ):
133
+ print ("Failed dimension: " , dim , " with error: " , errs [idx ][0 ])
134
+ print ("Turbine output: " , errs [idx ][1 ])
135
+ print ("Torch output: " , errs [idx ][2 ])
136
+ print (torch_output .shape )
137
+ exit ()
72
138
73
139
batch_size = args .batch_size * 2 #do classifier free guidance
74
140
hidden_states = torch .randn (
0 commit comments