Skip to content

Commit ceaae4b

Browse files
authored
Fix torch bloat16 -> numpy float32 conversion for compile max-autotune (#96)
1 parent 97a69cf commit ceaae4b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

segment_anything_fast/predictor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,8 @@ def predict(
162162
)
163163

164164
masks_np = masks[0].detach().cpu().numpy()
165-
iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
166-
low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
165+
iou_predictions_np = iou_predictions[0].detach().cpu().float().numpy()
166+
low_res_masks_np = low_res_masks[0].detach().cpu().float().numpy()
167167
return masks_np, iou_predictions_np, low_res_masks_np
168168

169169
@torch.no_grad()

0 commit comments

Comments
 (0)