
Commit 9f4381b

[cherry-pick] Refine table pipes for layout (#3649)
1 parent 6f7bf27 commit 9f4381b

3 files changed: +85 −38 lines changed

paddlex/configs/pipelines/PP-StructureV3.yaml

Lines changed: 30 additions & 1 deletion
@@ -13,7 +13,9 @@ SubModules:
     model_name: PP-DocLayout-L
     model_dir: null
     threshold:
-      7: 0.3
+      0: 0.3 # paragraph_title
+      7: 0.3 # formula
+      16: 0.3 # seal
     layout_nms: True
     layout_unclip_ratio: 1.0
     layout_merge_bboxes_mode:
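
The per-category form above maps PP-DocLayout-L category ids to detection thresholds, with the inline comments naming each id. As a minimal sketch, the same mapping as a plain Python dict (the id-to-label pairing comes from those comments; it is assumed that ids without an entry keep the model's default threshold):

# Per-category layout-detection thresholds mirroring the YAML above.
# 0 = paragraph_title, 7 = formula, 16 = seal (per the inline comments);
# category ids without an entry are assumed to fall back to the default.
layout_threshold = {0: 0.3, 7: 0.3, 16: 0.3}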
@@ -94,6 +96,33 @@ SubPipelines:
         module_name: table_cells_detection
         model_name: RT-DETR-L_wireless_table_cell_det
         model_dir: null
+    SubPipelines:
+      GeneralOCR:
+        pipeline_name: OCR
+        text_type: general
+        use_doc_preprocessor: False
+        use_textline_orientation: True
+        SubModules:
+          TextDetection:
+            module_name: text_detection
+            model_name: PP-OCRv4_server_det
+            model_dir: null
+            limit_side_len: 1200
+            limit_type: max
+            thresh: 0.3
+            box_thresh: 0.4
+            unclip_ratio: 2.0
+          TextLineOrientation:
+            module_name: textline_orientation
+            model_name: PP-LCNet_x0_25_textline_ori
+            model_dir: null
+            batch_size: 1
+          TextRecognition:
+            module_name: text_recognition
+            model_name: PP-OCRv4_server_rec_doc
+            model_dir: null
+            batch_size: 6
+            score_thresh: 0.0
 
   SealRecognition:
     pipeline_name: seal_recognition
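
This change gives table recognition its own GeneralOCR sub-pipeline (text detection, textline orientation, and recognition) rather than relying solely on the document-level OCR pass. A hedged usage sketch, assuming the standard paddlex.create_pipeline entry point from the PaddleX docs (which also accepts a path to a locally edited copy of this YAML); the image path is a placeholder:

# Minimal sketch; "PP-StructureV3" can be swapped for a local YAML path.
from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="PP-StructureV3")
for res in pipeline.predict("document_page.png"):  # placeholder image
    res.print()  # inspect layout, table, and OCR results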

paddlex/inference/pipelines/table_recognition/pipeline.py

Lines changed: 38 additions & 9 deletions
@@ -88,6 +88,11 @@ def __init__(
                 {"pipeline_config_error": "config error for general_ocr_pipeline!"},
             )
             self.general_ocr_pipeline = self.create_pipeline(general_ocr_config)
+        else:
+            self.general_ocr_config_bak = config.get("SubPipelines", {}).get(
+                "GeneralOCR",
+                None
+            )
 
         self._crop_by_boxes = CropByBoxes()
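
The new else branch stashes the GeneralOCR sub-config instead of building the OCR pipeline eagerly; predict() materializes it later, only when use_table_cells_ocr_results is requested (see the new elif further down). A self-contained sketch of that deferred-construction pattern, using hypothetical names rather than the actual PaddleX classes:

# Deferred-construction sketch; LazyOCRHolder and `build` are hypothetical
# stand-ins, not PaddleX APIs.
class LazyOCRHolder:
    def __init__(self, config: dict, build):
        # Stash the sub-config rather than building the OCR pipeline up front.
        self._ocr_config_bak = config.get("SubPipelines", {}).get("GeneralOCR")
        self._build = build  # factory; create_pipeline plays this role in PaddleX
        self._ocr = None

    def get_ocr(self):
        # Build on first use; fail loudly if the config was never provided.
        if self._ocr is None:
            assert self._ocr_config_bak is not None, "GeneralOCR config missing"
            self._ocr = self._build(self._ocr_config_bak)
        return self._ocr

holder = LazyOCRHolder(
    {"SubPipelines": {"GeneralOCR": {"pipeline_name": "OCR"}}},
    build=lambda cfg: cfg,  # identity factory keeps the sketch runnable
)
print(holder.get_ocr())  # {'pipeline_name': 'OCR'}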

@@ -217,6 +222,33 @@ def predict_doc_preprocessor_res(
             doc_preprocessor_res = {}
             doc_preprocessor_image = image_array
         return doc_preprocessor_res, doc_preprocessor_image
+
+    def split_ocr_bboxes_by_table_cells(self, ori_img, cells_bboxes):
+        """
+        Splits OCR bounding boxes by table cells and retrieves text.
+
+        Args:
+            ori_img (ndarray): The original image from which text regions will be extracted.
+            cells_bboxes (list or ndarray): Detected cell bounding boxes to extract text from.
+
+        Returns:
+            list: A list containing the recognized texts from each cell.
+        """
+
+        # Check if cells_bboxes is a list and convert it if not.
+        if not isinstance(cells_bboxes, list):
+            cells_bboxes = cells_bboxes.tolist()
+        texts_list = []  # Initialize a list to store the recognized texts.
+        # Process each bounding box provided in cells_bboxes.
+        for i in range(len(cells_bboxes)):
+            # Extract and round up the coordinates of the bounding box.
+            x1, y1, x2, y2 = [math.ceil(k) for k in cells_bboxes[i]]
+            # Perform OCR on the defined region of the image and get the recognized text.
+            rec_te = next(self.general_ocr_pipeline(ori_img[y1:y2, x1:x2, :]))
+            # Concatenate the texts and append them to the texts_list.
+            texts_list.append(''.join(rec_te["rec_texts"]))
+        # Return the list of recognized texts from each cell.
+        return texts_list
 
     def split_ocr_bboxes_by_table_cells(self, ori_img, cells_bboxes):
         """
@@ -270,15 +302,9 @@ def predict_single_table_recognition_res(
         """
         table_structure_pred = next(self.table_structure_model(image_array))
         if use_table_cells_ocr_results == True:
-            table_cells_result = list(
-                map(lambda arr: arr.tolist(), table_structure_pred["bbox"])
-            )
-            table_cells_result = [
-                [rect[0], rect[1], rect[4], rect[5]] for rect in table_cells_result
-            ]
-            cells_texts_list = self.split_ocr_bboxes_by_table_cells(
-                image_array, table_cells_result
-            )
+            table_cells_result = list(map(lambda arr: arr.tolist(), table_structure_pred["bbox"]))
+            table_cells_result = [[rect[0], rect[1], rect[4], rect[5]] for rect in table_cells_result]
+            cells_texts_list = self.split_ocr_bboxes_by_table_cells(image_array, table_cells_result)
         else:
             cells_texts_list = []
         single_table_recognition_res = get_table_recognition_res(
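
The indexing [rect[0], rect[1], rect[4], rect[5]] reduces each predicted cell box to an axis-aligned rectangle. This assumes the structure model emits 8-value quads (x1, y1, x2, y2, x3, y3, x4, y4), so indices 0, 1, 4, 5 pick the top-left and bottom-right corners. A worked example under that assumption:

# Assumed 8-value quad, clockwise from the top-left corner.
quad = [10.0, 5.0, 110.0, 5.0, 110.0, 45.0, 10.0, 45.0]
top_left_bottom_right = [quad[0], quad[1], quad[4], quad[5]]
print(top_left_bottom_right)  # [10.0, 5.0, 110.0, 45.0]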
@@ -381,6 +407,9 @@ def predict(
                     text_rec_score_thresh=text_rec_score_thresh,
                 )
             )
+        elif use_table_cells_ocr_results == True:
+            assert self.general_ocr_config_bak != None
+            self.general_ocr_pipeline = self.create_pipeline(self.general_ocr_config_bak)
 
         table_res_list = []
         table_region_id = 1
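
With this elif, per-cell OCR works even when the pipeline was constructed without the eager OCR model: the stashed config is turned into a pipeline on demand. A hedged end-to-end sketch; the entry point follows the PaddleX docs, but the exact keyword plumbing for use_table_cells_ocr_results is inferred from this diff, not verified:

# Usage sketch with placeholder paths; requests per-cell OCR so the
# lazily created general_ocr_pipeline is exercised.
from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="table_recognition")
for res in pipeline.predict("table.png", use_table_cells_ocr_results=True):
    res.print()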

paddlex/inference/pipelines/table_recognition/pipeline_v2.py

Lines changed: 17 additions & 28 deletions
@@ -128,6 +128,11 @@ def __init__(
                 {"pipeline_config_error": "config error for general_ocr_pipeline!"},
             )
             self.general_ocr_pipeline = self.create_pipeline(general_ocr_config)
+        else:
+            self.general_ocr_config_bak = config.get("SubPipelines", {}).get(
+                "GeneralOCR",
+                None
+            )
 
         self._crop_by_boxes = CropByBoxes()

@@ -595,35 +600,23 @@ def predict_single_table_recognition_res(
                 use_e2e_model = True
             else:
                 table_cells_pred = next(
-                    self.wireless_table_cells_detection_model(
-                        image_array, threshold=0.3
-                    )
+                    self.wireless_table_cells_detection_model(image_array, threshold=0.3)
                 )  # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
                 # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.
 
         if use_e2e_model == False:
-            table_structure_result = self.extract_results(
-                table_structure_pred, "table_stru"
-            )
-            table_cells_result, table_cells_score = self.extract_results(
-                table_cells_pred, "det"
-            )
-            table_cells_result, table_cells_score = self.cells_det_results_nms(
-                table_cells_result, table_cells_score
-            )
-            ocr_det_boxes = self.get_region_ocr_det_boxes(
-                overall_ocr_res["rec_boxes"].tolist(), table_box
-            )
+            table_structure_result = self.extract_results(table_structure_pred, "table_stru")
+            table_cells_result, table_cells_score = self.extract_results(table_cells_pred, "det")
+            table_cells_result, table_cells_score = self.cells_det_results_nms(table_cells_result, table_cells_score)
+            ocr_det_boxes = self.get_region_ocr_det_boxes(overall_ocr_res["rec_boxes"].tolist(), table_box)
             table_cells_result = self.cells_det_results_reprocessing(
                 table_cells_result,
                 table_cells_score,
                 ocr_det_boxes,
                 len(table_structure_pred["bbox"]),
             )
             if use_table_cells_ocr_results == True:
-                cells_texts_list = self.split_ocr_bboxes_by_table_cells(
-                    image_array, table_cells_result
-                )
+                cells_texts_list = self.split_ocr_bboxes_by_table_cells(image_array, table_cells_result)
             else:
                 cells_texts_list = []
             single_table_recognition_res = get_table_recognition_res(
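
cells_det_results_nms suppresses overlapping cell detections before they are reconciled with the structure prediction. A generic, self-contained illustration of score-ordered IoU suppression (an assumption about the idea, not the actual cells_det_results_nms implementation):

import numpy as np

def nms_boxes(boxes, scores, iou_thresh=0.5):
    """Keep indices of [x1, y1, x2, y2] boxes that survive greedy IoU suppression."""
    boxes = np.asarray(boxes, dtype=float)
    order = np.asarray(scores, dtype=float).argsort()[::-1]  # best score first
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        # Intersection of the highest-scoring box with the remaining boxes.
        xx1 = np.maximum(boxes[i, 0], boxes[order[1:], 0])
        yy1 = np.maximum(boxes[i, 1], boxes[order[1:], 1])
        xx2 = np.minimum(boxes[i, 2], boxes[order[1:], 2])
        yy2 = np.minimum(boxes[i, 3], boxes[order[1:], 3])
        inter = np.clip(xx2 - xx1, 0, None) * np.clip(yy2 - yy1, 0, None)
        area_i = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
        areas = (boxes[order[1:], 2] - boxes[order[1:], 0]) * (
            boxes[order[1:], 3] - boxes[order[1:], 1]
        )
        order = order[1:][inter / (area_i + areas - inter) <= iou_thresh]
    return keep

print(nms_boxes([[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]],
                [0.9, 0.8, 0.7]))  # [0, 2]: the near-duplicate box is dropped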
@@ -636,16 +629,9 @@ def predict_single_table_recognition_res(
             )
         else:
             if use_table_cells_ocr_results == True:
-                table_cells_result_e2e = list(
-                    map(lambda arr: arr.tolist(), table_structure_pred["bbox"])
-                )
-                table_cells_result_e2e = [
-                    [rect[0], rect[1], rect[4], rect[5]]
-                    for rect in table_cells_result_e2e
-                ]
-                cells_texts_list = self.split_ocr_bboxes_by_table_cells(
-                    image_array, table_cells_result_e2e
-                )
+                table_cells_result_e2e = list(map(lambda arr: arr.tolist(), table_structure_pred["bbox"]))
+                table_cells_result_e2e = [[rect[0], rect[1], rect[4], rect[5]] for rect in table_cells_result_e2e]
+                cells_texts_list = self.split_ocr_bboxes_by_table_cells(image_array, table_cells_result_e2e)
             else:
                 cells_texts_list = []
             single_table_recognition_res = get_table_recognition_res_e2e(
@@ -749,6 +735,9 @@ def predict(
                     text_rec_score_thresh=text_rec_score_thresh,
                 )
             )
+        elif use_table_cells_ocr_results == True:
+            assert self.general_ocr_config_bak != None
+            self.general_ocr_pipeline = self.create_pipeline(self.general_ocr_config_bak)
 
         table_res_list = []
         table_region_id = 1
