Skip to content

Commit 4c89a1c

Browse files
Bycobmergify[bot]
authored andcommitted
fix(torch): csvts forecast mode needs sequence of length backcast during predict
1 parent 5a5c8f1 commit 4c89a1c

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

src/backends/torch/torchinputconns.cc

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -848,20 +848,21 @@ namespace dd
848848
{
849849
vecindex++;
850850
long int tstart = 0;
851-
if (static_cast<long int>(seq.size())
852-
< _backcast_timesteps + _forecast_timesteps)
851+
long int timesteps = _train ? _backcast_timesteps + _forecast_timesteps
852+
: _backcast_timesteps;
853+
if (static_cast<long int>(seq.size()) < timesteps)
853854
{
854855
discard_warn(vecindex, seq.size(), test);
855856
continue;
856857
}
857-
for (; tstart + _backcast_timesteps + _forecast_timesteps
858-
< static_cast<long int>(seq.size());
858+
for (; tstart + timesteps < static_cast<long int>(seq.size());
859859
tstart += _offset)
860-
add_data_instance_forecast(tstart, vecindex, dataset, seq);
860+
{
861+
add_data_instance_forecast(tstart, vecindex, dataset, seq);
862+
}
861863
if (tstart < static_cast<long int>(seq.size()) - 1)
862-
add_data_instance_forecast(seq.size() - _backcast_timesteps
863-
- _forecast_timesteps,
864-
vecindex, dataset, seq);
864+
add_data_instance_forecast(seq.size() - timesteps, vecindex, dataset,
865+
seq);
865866
}
866867
}
867868

0 commit comments

Comments
 (0)