Skip to content

Commit 097e2df

Browse files
committed
fix: fix preprocess for yolo obj det
1 parent 71e14c4 commit 097e2df

File tree

3 files changed

+28
-29
lines changed

3 files changed

+28
-29
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
🔍 使用在线体验找到适合你场景的模型组合
4141

4242
### 在线体验
43-
43+
[modelscope](https://www.modelscope.cn/studios/jockerK/RapidTableDetDemo)
4444
### 效果展示
4545

4646
![res_show.jpg](readme_resource/res_show.jpg)![res_show2.jpg](readme_resource/res_show2.jpg)
@@ -102,6 +102,7 @@ print(
102102
# from rapid_table_det.utils.visuallize import img_loader, visuallize, extract_table_img
103103
#
104104
# img = img_loader(img_path)
105+
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
105106
# file_name_with_ext = os.path.basename(img_path)
106107
# file_name, file_ext = os.path.splitext(file_name_with_ext)
107108
# out_dir = "rapid_table_det/outputs"

demo_onnx.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from rapid_table_det.inference import TableDetector
22

3-
img_path = f"images/weixin.png"
3+
img_path = f"images/WechatIMG149.jpeg"
44
table_det = TableDetector(
5-
obj_model_type="paddle_obj_det_s", edge_model_type="paddle_edge_det_s"
5+
edge_model_type="yolo_edge_det", obj_model_type="yolo_obj_det"
66
)
77

88
result, elapse = table_det(img_path)
@@ -11,23 +11,24 @@
1111
f"obj_det_elapse:{obj_det_elapse}, edge_elapse={edge_elapse}, rotate_det_elapse={rotate_det_elapse}"
1212
)
1313
# 输出可视化
14-
# import os
15-
# import cv2
16-
# from rapid_table_det.utils.visuallize import img_loader, visuallize, extract_table_img
17-
#
18-
# img = img_loader(img_path)
19-
# file_name_with_ext = os.path.basename(img_path)
20-
# file_name, file_ext = os.path.splitext(file_name_with_ext)
21-
# out_dir = "rapid_table_det/outputs"
22-
# if not os.path.exists(out_dir):
23-
# os.makedirs(out_dir)
24-
# extract_img = img.copy()
25-
# for i, res in enumerate(result):
26-
# box = res["box"]
27-
# lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"]
28-
# # 带识别框和左上角方向位置
29-
# img = visuallize(img, box, lt, rt, rb, lb)
30-
# # 透视变换提取表格图片
31-
# wrapped_img = extract_table_img(extract_img.copy(), lt, rt, rb, lb)
32-
# cv2.imwrite(f"{out_dir}/{file_name}-extract-{i}.jpg", wrapped_img)
33-
# cv2.imwrite(f"{out_dir}/{file_name}-visualize.jpg", img)
14+
import os
15+
import cv2
16+
from rapid_table_det.utils.visuallize import img_loader, visuallize, extract_table_img
17+
18+
img = img_loader(img_path)
19+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
20+
file_name_with_ext = os.path.basename(img_path)
21+
file_name, file_ext = os.path.splitext(file_name_with_ext)
22+
out_dir = "rapid_table_det/outputs"
23+
if not os.path.exists(out_dir):
24+
os.makedirs(out_dir)
25+
extract_img = img.copy()
26+
for i, res in enumerate(result):
27+
box = res["box"]
28+
lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"]
29+
# 带识别框和左上角方向位置
30+
img = visuallize(img, box, lt, rt, rb, lb)
31+
# 透视变换提取表格图片
32+
wrapped_img = extract_table_img(extract_img.copy(), lt, rt, rb, lb)
33+
cv2.imwrite(f"{out_dir}/{file_name}-extract-{i}.jpg", wrapped_img)
34+
cv2.imwrite(f"{out_dir}/{file_name}-visualize.jpg", img)

rapid_table_det/predictor.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,7 @@ def __call__(self, img, **kwargs):
9191
return result, time.time() - start
9292

9393
def img_preprocess(self, img, resize_shape=[928, 928]):
94-
# im, new_w, new_h, left, top = ResizePad(img, resize_shape[0])
95-
new_w, new_h = resize_shape
96-
left, top = 0, 0
97-
im = cv2.resize(img, resize_shape, cv2.INTER_LINEAR)
94+
im, new_w, new_h, left, top = ResizePad(img, resize_shape[0])
9895
im = im / 255.0
9996
im = im.transpose((2, 0, 1)).copy()
10097
im = im[None, :].astype("float32")
@@ -118,8 +115,8 @@ def img_postprocess(self, predict_maps, x_factor, y_factor, left, top, score):
118115
# 从当前行提取边界框坐标
119116
x, y, w, h = outputs[i][0], outputs[i][1], outputs[i][2], outputs[i][3]
120117
# 计算边界框的缩放坐标
121-
xmin = max(int((x - w / 2) * x_factor) - left, 0)
122-
ymin = max(int((y - h / 2) * y_factor) - top, 0)
118+
xmin = max(int((x - w / 2 - left) * x_factor), 0)
119+
ymin = max(int((y - h / 2 - top) * y_factor), 0)
123120
xmax = xmin + int(w * x_factor)
124121
ymax = ymin + int(h * y_factor)
125122
# 将类别ID、得分和框坐标添加到各自的列表中

0 commit comments

Comments
 (0)