|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
| 15 | +import numpy as np |
15 | 16 | from ....utils import logging
|
16 | 17 | from ....utils.func_register import FuncRegister
|
17 | 18 | from ....modules.formula_recognition.model_list import MODELS
|
|
38 | 39 |
|
39 | 40 |
|
40 | 41 | class FormulaRecPredictor(BasicPredictor):
|
| 42 | + """FormulaRecPredictor that inherits from BasicPredictor.""" |
41 | 43 |
|
42 | 44 | entities = MODELS
|
43 | 45 |
|
44 | 46 | _FUNC_MAP = {}
|
45 | 47 | register = FuncRegister(_FUNC_MAP)
|
46 | 48 |
|
47 | 49 | def __init__(self, *args, **kwargs):
|
| 50 | + """Initializes FormulaRecPredictor. |
| 51 | + Args: |
| 52 | + *args: Arbitrary positional arguments passed to the superclass. |
| 53 | + **kwargs: Arbitrary keyword arguments passed to the superclass. |
| 54 | + """ |
48 | 55 | super().__init__(*args, **kwargs)
|
| 56 | + |
| 57 | + self.model_names_only_supports_batchsize_of_one = { |
| 58 | + "LaTeX_OCR_rec", |
| 59 | + } |
| 60 | + if self.model_name in self.model_names_only_supports_batchsize_of_one: |
| 61 | + logging.warning( |
| 62 | + f"Formula Recognition Models: \"{', '.join(list(self.model_names_only_supports_batchsize_of_one))}\" only supports prediction with a batch_size of one, " |
| 63 | + "if you set the predictor with a batch_size larger than one, no error will occur, however, it will actually inference with a batch_size of one, " |
| 64 | + f"which will lead to a slower inference speed. You are now using {self.config['Global']['model_name']}." |
| 65 | + ) |
| 66 | + |
49 | 67 | self.pre_tfs, self.infer, self.post_op = self._build()
|
50 | 68 |
|
51 | 69 | def _build_batch_sampler(self):
|
@@ -100,9 +118,25 @@ def process(self, batch_data):
|
100 | 118 | batch_imgs = self.pre_tfs["UniMERNetTestTransform"](imgs=batch_imgs)
|
101 | 119 | batch_imgs = self.pre_tfs["LatexImageFormat"](imgs=batch_imgs)
|
102 | 120 |
|
103 |
| - x = self.pre_tfs["ToBatch"](imgs=batch_imgs) |
104 |
| - batch_preds = self.infer(x=x) |
105 |
| - batch_preds = [p.reshape([-1]) for p in batch_preds[0]] |
| 121 | + if self.model_name in self.model_names_only_supports_batchsize_of_one: |
| 122 | + batch_preds = [] |
| 123 | + max_length = 0 |
| 124 | + for batch_img in batch_imgs: |
| 125 | + batch_pred_ = self.infer([batch_img])[0].reshape([-1]) |
| 126 | + max_length = max(max_length, batch_pred_.shape[0]) |
| 127 | + batch_preds.append(batch_pred_) |
| 128 | + for i in range(len(batch_preds)): |
| 129 | + batch_preds[i] = np.pad( |
| 130 | + batch_preds[i], |
| 131 | + (0, max_length - batch_preds[i].shape[0]), |
| 132 | + mode="constant", |
| 133 | + constant_values=0, |
| 134 | + ) |
| 135 | + else: |
| 136 | + x = self.pre_tfs["ToBatch"](imgs=batch_imgs) |
| 137 | + batch_preds = self.infer(x=x) |
| 138 | + batch_preds = [p.reshape([-1]) for p in batch_preds[0]] |
| 139 | + |
106 | 140 | rec_formula = self.post_op(batch_preds)
|
107 | 141 | return {
|
108 | 142 | "input_path": batch_data.input_paths,
|
|
0 commit comments