Skip to content

Commit 2418edc

Browse files
[cherry-pick] check batchsize 1 for latex_rec (#3518)
1 parent 3466ebd commit 2418edc

File tree

1 file changed

+37
-3
lines changed

1 file changed

+37
-3
lines changed

paddlex/inference/models/formula_recognition/predictor.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import numpy as np
1516
from ....utils import logging
1617
from ....utils.func_register import FuncRegister
1718
from ....modules.formula_recognition.model_list import MODELS
@@ -38,14 +39,31 @@
3839

3940

4041
class FormulaRecPredictor(BasicPredictor):
42+
"""FormulaRecPredictor that inherits from BasicPredictor."""
4143

4244
entities = MODELS
4345

4446
_FUNC_MAP = {}
4547
register = FuncRegister(_FUNC_MAP)
4648

4749
def __init__(self, *args, **kwargs):
    """Initializes FormulaRecPredictor.

    Args:
        *args: Arbitrary positional arguments passed to the superclass.
        **kwargs: Arbitrary keyword arguments passed to the superclass.
    """
    super().__init__(*args, **kwargs)

    # Models whose inference backend can only handle a batch size of one.
    # `process` checks this set and falls back to per-image inference,
    # so larger batch sizes work but run slower for these models.
    self.model_names_only_supports_batchsize_of_one = {
        "LaTeX_OCR_rec",
    }

    if self.model_name in self.model_names_only_supports_batchsize_of_one:
        affected = ", ".join(self.model_names_only_supports_batchsize_of_one)
        warn_msg = (
            f"Formula Recognition Models: \"{affected}\" only supports prediction with a batch_size of one, "
            "if you set the predictor with a batch_size larger than one, no error will occur, however, it will actually inference with a batch_size of one, "
            f"which will lead to a slower inference speed. You are now using {self.config['Global']['model_name']}."
        )
        logging.warning(warn_msg)

    # Build the preprocessing transforms, inference engine, and postprocess op.
    self.pre_tfs, self.infer, self.post_op = self._build()
5068

5169
def _build_batch_sampler(self):
@@ -100,9 +118,25 @@ def process(self, batch_data):
100118
batch_imgs = self.pre_tfs["UniMERNetTestTransform"](imgs=batch_imgs)
101119
batch_imgs = self.pre_tfs["LatexImageFormat"](imgs=batch_imgs)
102120

103-
x = self.pre_tfs["ToBatch"](imgs=batch_imgs)
104-
batch_preds = self.infer(x=x)
105-
batch_preds = [p.reshape([-1]) for p in batch_preds[0]]
121+
if self.model_name in self.model_names_only_supports_batchsize_of_one:
122+
batch_preds = []
123+
max_length = 0
124+
for batch_img in batch_imgs:
125+
batch_pred_ = self.infer([batch_img])[0].reshape([-1])
126+
max_length = max(max_length, batch_pred_.shape[0])
127+
batch_preds.append(batch_pred_)
128+
for i in range(len(batch_preds)):
129+
batch_preds[i] = np.pad(
130+
batch_preds[i],
131+
(0, max_length - batch_preds[i].shape[0]),
132+
mode="constant",
133+
constant_values=0,
134+
)
135+
else:
136+
x = self.pre_tfs["ToBatch"](imgs=batch_imgs)
137+
batch_preds = self.infer(x=x)
138+
batch_preds = [p.reshape([-1]) for p in batch_preds[0]]
139+
106140
rec_formula = self.post_op(batch_preds)
107141
return {
108142
"input_path": batch_data.input_paths,

0 commit comments

Comments (0)