-
Notifications
You must be signed in to change notification settings - Fork 229
Description
There is a bug in Line 296-312
void DataHandler::LoadChunk(DataIterator& it, Matrix& mat, vector& random_rows) {
float* data_ptr = mat.GetHostData();
int num_dims = it.GetDims();
int num_rand = (chunk_size_ + random_access_chunk_size_ - 1) / random_access_chunk_size_;
int row, end;
for (int i = 0; i < num_rand; i++) {
row = random_rows[i];
end = (row + random_access_chunk_size_) % dataset_size_;
if (end < row) {
it.Get(data_ptr, row, dataset_size_);
it.Get(data_ptr + num_dims * (dataset_size_ - row), 0, end);
} else {
it.Get(data_ptr, row, end);
}
data_ptr += num_dims * random_access_chunk_size_;
}
}
One possible way to fix it could be :
void DataHandler::LoadChunk(DataIterator& it, Matrix& mat, vector& random_rows) {
float* data_ptr = mat.GetHostData();
int num_dims = it.GetDims();
int num_rand = (chunk_size_ + random_access_chunk_size_ - 1) / random_access_chunk_size_;
int row, end;
for (int i = 0; i < num_rand; i++) {
row = random_rows[i];
end = (row + random_access_chunk_size_) % dataset_size_;
if (end < row) {
it.Get(data_ptr, row, dataset_size_);
int remain_size = random_access_chunk_size_ - (dataset_size_ - row);
if (remain_size > 0)
{
it.Get(data_ptr + num_dims * (dataset_size_ - row), 0, remain_size);
}
} else {
it.Get(data_ptr, row, end);
}
data_ptr += num_dims * random_access_chunk_size_;
}
}