Skip to content

Commit 921240e

Browse files
committed
update doc
1 parent 308f19d commit 921240e

File tree

1 file changed

+123
-0
lines changed

1 file changed

+123
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# 多模型组合
2+
3+
PaddleX 提供了十分丰富的模型,以及针对不同任务的模型产线,同时PaddleX也支持用户多模型组合使用,以解决复杂、特定的任务。
4+
5+
## 一、任务分析
6+
7+
下面通过组合使用PaddleX提供的版面分析模型、表格识别模型和OCR产线,解决表格识别任务,具体来说,该任务分为以下几个步骤:
8+
9+
1. 使用版面分析模型,检测出文档图片中的表格区域位置;
10+
2. 使用OpenCV等操作,裁剪出文档图片中表格区域图片;
11+
3. 使用表格识别模型,对表格图片进行识别,得到表格结构的html表示;
12+
4. 使用OCR产线,对表格图片进行识别,得到表格区域的文字;
13+
14+
## 二、示例代码
15+
16+
根据任务分析,即可使用PaddleX进行开发,完整示例代码如下:
17+
18+
```python
19+
import cv2
20+
from paddlex import create_model, create_pipeline
21+
22+
23+
class TableRec:
24+
25+
def __init__(self):
26+
self.layout_model = create_model("PicoDet_layout_1x")
27+
self.table_model = create_model("SLANet_plus")
28+
self.ocr_pipeline = create_pipeline("OCR")
29+
30+
def crop_table(self, layout_res):
31+
img_path = layout_res["input_path"]
32+
img = cv2.imread(img_path)
33+
34+
table_img_list = []
35+
for box in layout_res["boxes"]:
36+
if box["label"] != "Table":
37+
continue
38+
xmin, ymin, xmax, ymax = [int(i) for i in box["coordinate"]]
39+
table_img = img[ymin:ymax, xmin:xmax]
40+
table_img_list.append({"input": table_img})
41+
return table_img_list
42+
43+
def predict(self, data):
44+
for layout_res in self.layout_model(data):
45+
final_res = {}
46+
table_img_list = self.crop_table(layout_res)
47+
table_res = list(self.table_model(table_img_list))
48+
ocr_res = list(self.ocr_pipeline(table_img_list))
49+
final_res["structure"] = table_res["structure"]
50+
final_res["ocr_box"] = ocr_res["dt_polys"]
51+
final_res["rec_text"] = ocr_res["rec_text"]
52+
yield final_res
53+
54+
55+
if __name__ == "__main__":
56+
solution = TableRec()
57+
output = solution.predict("https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/table_recognition.jpg")
58+
for res in output:
59+
print(res)
60+
```
61+
62+
## 三、代码详解
63+
64+
接下来对以上代码进行详解:
65+
66+
1. 实例化模型和产线
67+
68+
```python
69+
self.layout_model = create_model("PicoDet_layout_1x") # 实例化版面分析模型
70+
self.table_model = create_model("SLANet_plus") # 实例化表格识别模型
71+
self.ocr_pipeline = create_pipeline("OCR") # 实例化OCR模型产线
72+
```
73+
74+
PaddleX提供了`create_model(model)`函数用于实例化模型,只需通过参数`model`指定模型名称即可使用PaddleX提供的官方预训练模型,或是通过参数`model`指定本地模型路径,即可使用训练好的本地模型。关于PaddleX支持的全部官方预训练模型请参考[模型列表](../../support_list/models_list.md),关于`create_model`方法的更详细使用方式,请参考文档[模型使用 Python API](./model_python_API.md)
75+
76+
PaddleX提供了`create_pipeline(pipeline)`函数用于实例化产线,只需通过参数`pipeline`指定产线名称即可使用PaddleX提供的模型产线,或是通过参数`pipeline`指定本地产线配置文件(`*.yaml`),即可使用自定义的模型产线。关于PaddleX内置的模型产线可以查看[产线列表](../../support_list/pipelines_list.md),关于`create_pipeline`方法的更详细使用方式,请参考文档[产线使用 Python API](../../pipeline_usage/instructions/pipeline_python_API.md)
77+
78+
2. 调用版面分析模型
79+
80+
```python
81+
for layout_res in self.layout_model(data):
82+
pass
83+
```
84+
85+
PaddleX推理预测功能中的模型类(`Predictor`)和产线类(`Pipeline`)均实现了`__call__(input)``predict(input)`方法,支持通过参数`input`传入待预测数据,同时上述两种方法均基于`yield`实现,因此需要作为`generator`调用。关于模型预测和产线预测的详细说明请参考[模型使用 Python API](./model_python_API.md)[产线使用 Python API](../../pipeline_usage/instructions/pipeline_python_API.md)
86+
87+
对于版面分析模型`self.layout_model`,首先传入待预测数据`data`,并通过`for-in`的方式得到每张图片的版面分析预测结果`layout_res`
88+
89+
3. 处理版面分析预测结果
90+
91+
```python
92+
table_img_list = self.crop_table(layout_res)
93+
94+
95+
def crop_table(self, layout_res):
96+
img_path = layout_res["input_path"]
97+
img = cv2.imread(img_path)
98+
99+
table_img_list = []
100+
for box in layout_res["boxes"]:
101+
if box["label"] != "Table":
102+
continue
103+
xmin, ymin, xmax, ymax = [int(i) for i in box["coordinate"]]
104+
table_img = img[ymin:ymax, xmin:xmax]
105+
table_img_list.append({"input": table_img})
106+
return table_img_list
107+
```
108+
109+
在得到版面分析预测结果`layout_res`后,需要对其进行处理,并构造符合后续模型预测输入格式的数据。
110+
111+
首先读取原始图像`img`,然后依据预测结果中每个目标(`box`)的类别信息(`box["label"]`)和位置坐标信息(`box["coordinate"]`),从原始图像裁剪(`img[ymin:ymax, xmin:xmax]`)得到表格子图(`table_img`),表格子图就是后续表格识别模型和OCR模型的待预测数据,因此需要将待预测数据按要求进行整理,具体来说需要整理为`{"input": table_img}`形式的字典,其中字典的key必须为`input`,因为PaddleX中模型或产线的预测输入函数的参数即为`input`,而对应的value即为待预测数据,如果需要预测一批数据,则应为上述字典组成的list(`table_img_list`)。
112+
113+
4. 调用表格识别模型与OCR模型
114+
115+
```python
116+
table_res = list(self.table_model(table_img_list))
117+
ocr_res = list(self.ocr_pipeline(table_img_list))
118+
final_res["structure"] = table_res["structure"]
119+
final_res["ocr_box"] = ocr_res["dt_polys"]
120+
final_res["rec_text"] = ocr_res["rec_text"]
121+
```
122+
123+
得到待预测的表格图像(`table_img_list`)后,即可进行表格结构识别和OCR识别,得到所需预测结果。

0 commit comments

Comments
 (0)