Skip to content

Commit d8ce81d

Browse files
authored
Rc copy ts input (#3540)
* cherry-pick layout numclass * add ts show (#3532) * [cp] deepcopy fix ts show
1 parent d174bca commit d8ce81d

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

paddlex/inference/models/ts_classification/predictor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import numpy as np
1717
import pandas as pd
1818
import os
19+
import copy
1920

2021
from ....modules.ts_classification.model_list import MODELS
2122
from ...common.batch_sampler import TSBatchSampler
@@ -109,6 +110,7 @@ def process(self, batch_data: List[Union[str, pd.DataFrame]]) -> Dict[str, Any]:
109110
Dict[str, Any]: A dictionary containing the paths to the input data, the raw input time series, and the classification results.
110111
"""
111112
batch_raw_ts = self.preprocessors["ReadTS"](ts_list=batch_data)
113+
batch_raw_ts_ori = copy.deepcopy(batch_raw_ts)
112114

113115
if "TSNormalize" in self.preprocessors:
114116
batch_ts = self.preprocessors["TSNormalize"](ts_list=batch_raw_ts)
@@ -127,7 +129,7 @@ def process(self, batch_data: List[Union[str, pd.DataFrame]]) -> Dict[str, Any]:
127129
return {
128130
"input_path": batch_data,
129131
"input_ts": batch_raw_ts,
130-
"input_ts_data": batch_raw_ts,
132+
"input_ts_data": batch_raw_ts_ori,
131133
"classification": batch_ts_preds,
132134
"target_cols": [self.config["info_params"]["target_cols"]]
133135
}

paddlex/inference/models/ts_forecasting/predictor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import numpy as np
1717
import pandas as pd
1818
import os
19+
import copy
1920

2021
from ....modules.ts_forecast.model_list import MODELS
2122
from ...common.batch_sampler import TSBatchSampler
@@ -122,6 +123,7 @@ def process(self, batch_data: List[Union[str, pd.DataFrame]]) -> Dict[str, Any]:
122123
"""
123124

124125
batch_raw_ts = self.preprocessors["ReadTS"](ts_list=batch_data)
126+
batch_raw_ts_ori = copy.deepcopy(batch_raw_ts)
125127
batch_cutoff_ts = self.preprocessors["TSCutOff"](ts_list=batch_raw_ts)
126128

127129
if "TSNormalize" in self.preprocessors:
@@ -152,6 +154,6 @@ def process(self, batch_data: List[Union[str, pd.DataFrame]]) -> Dict[str, Any]:
152154
return {
153155
"input_path": batch_data,
154156
"input_ts": batch_raw_ts,
155-
"cutoff_ts": batch_raw_ts,
157+
"cutoff_ts": batch_raw_ts_ori,
156158
"forecast": batch_ts_preds,
157159
}

0 commit comments

Comments
 (0)