Skip to content

Commit 94ba46d

Browse files
committed
Attn debugging tools
1 parent d7c709e commit 94ba46d

File tree

2 files changed

+69
-2
lines changed

2 files changed

+69
-2
lines changed

models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,49 @@ def run_diffusers_mmdit(
5757

5858
return noise_pred.numpy()
5959

def run_attn_turbine(q, k, v, args):
    """Execute the compiled attention kernel via IREE and return the result on host.

    q, k, v are host arrays; args supplies the target device and the
    compiled-module path (args.device, args.vmfb_path).
    """
    runner = vmfbRunner(
        args.device,
        args.vmfb_path,
        None,
    )
    # Stage each operand on the runner's device before invoking the module.
    device = runner.config.device
    device_inputs = [ireert.asdevicearray(device, tensor) for tensor in (q, k, v)]
    result = runner.ctx.modules.compiled_attn["run_forward"](*device_inputs)
    return result.to_host()
75+
76+
@torch.no_grad()
def run_attn_torch(q, k, v, args):
    """Compute the reference attention output with the eager PyTorch module.

    Inputs are converted to float32 tensors; the result is returned as a
    numpy array for comparison against the compiled (turbine) output.
    """
    from turbine_models.custom_models.sd3_inference.sd3_mmdit import MMDiTAttention

    attn_module = MMDiTAttention()
    tensors = [torch.tensor(arr, dtype=torch.float32) for arr in (q, k, v)]
    return attn_module.forward(*tensors).numpy()
88+
89+
def find_errs(turbine_output, torch_output, dim=[], failed_dims=[], errs=[]):
90+
if not np.allclose(turbine_output, torch_output, rtol=4e-2, atol=4e-2):
91+
if turbine_output.ndim > 0:
92+
orig_dim = dim
93+
for idx, i in enumerate(torch_output):
94+
dim = [*orig_dim, idx]
95+
try:
96+
np.testing.assert_allclose(turbine_output[idx], torch_output[idx], rtol=4e-2, atol=4e-2)
97+
except Exception as e:
98+
err = np.abs(turbine_output[idx] - torch_output[idx])
99+
failed_dims.append(dim)
100+
errs.append([err, turbine_output[idx], torch_output[idx]])
101+
failed_dims, errs = find_errs(turbine_output[idx], torch_output[idx], dim, failed_dims, errs)
102+
return (failed_dims, errs)
60103

61104
if __name__ == "__main__":
62105
from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args
@@ -69,6 +112,29 @@ def run_diffusers_mmdit(
69112
dtype = torch.float16
70113
else:
71114
dtype = torch.float32
115+
116+
if args.attn_repro:
117+
qkv_shape = (2, 24, 4250, 64)
118+
example_qkv = [
119+
np.load("q.npy").astype(np.float16),
120+
np.load("k.npy").astype(np.float16),
121+
np.load("v.npy").astype(np.float16),
122+
]
123+
turbine_output = run_attn_turbine(
124+
*example_qkv,
125+
args,
126+
)
127+
torch_output = run_attn_torch(*example_qkv, args).astype(np.float16)
128+
np.save("turbine_attn_output.npy", turbine_output)
129+
np.save("torch_attn_output.npy", torch_output)
130+
failed_dims, errs = find_errs(turbine_output, torch_output)
131+
for idx, dim in enumerate(failed_dims):
132+
if len(dim) == len(torch_output.shape):
133+
print("Failed dimension: ", dim, " with error: ", errs[idx][0])
134+
print("Turbine output: ", errs[idx][1])
135+
print("Torch output: ", errs[idx][2])
136+
print(torch_output.shape)
137+
exit()
72138

73139
batch_size = args.batch_size * 2 #do classifier free guidance
74140
hidden_states = torch.randn(

models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,9 @@ def step(self, noise_pred, t, sample, guidance_scale, i):
9393
sample = self.model.step(noise_pred, t, sample, return_dict=False)[0]
9494
return sample.type(self.dtype)
9595

96-
97-
class SharkSchedulerCPUWrapper:
96+
# Wraps a diffusers scheduler running on native pytorch+cpu.
97+
# This allows us to use it interchangeably with compiled schedulers in our pipeline(s).
98+
class TorchCPUFlowSchedulerCompat:
9899
@torch.no_grad()
99100
def __init__(
100101
self, scheduler, batch_size, num_inference_steps, dest_device, latents_dtype

0 commit comments

Comments
 (0)