Skip to content

Commit 72263c8

Browse files
committed
[*] improve log message for storage view content
1 parent 5eb5d5a commit 72263c8

File tree

2 files changed

+68
-32
lines changed

2 files changed

+68
-32
lines changed

include/ctranslate2/storage_view.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,9 @@ namespace ctranslate2 {
238238

239239
friend std::ostream& operator<<(std::ostream& os, const StorageView& storage);
240240

241+
template <typename T>
242+
void print_tensor(std::ostream& os, const T* data, const std::vector<dim_t>& shape, size_t dim, size_t offset, int indent) const;
243+
241244
protected:
242245
DataType _dtype = DataType::FLOAT32;
243246
Device _device = Device::CPU;

src/storage_view.cc

Lines changed: 65 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -440,44 +440,77 @@ namespace ctranslate2 {
440440
return os;
441441
}
442442

443-
std::ostream& operator<<(std::ostream& os, const StorageView& storage) {
444-
StorageView printable(storage.dtype());
445-
printable.copy_from(storage);
446-
TYPE_DISPATCH(
447-
printable.dtype(),
448-
const auto* values = printable.data<T>();
449-
if (printable.size() <= PRINT_MAX_VALUES) {
450-
for (dim_t i = 0; i < printable.size(); ++i) {
451-
os << ' ';
452-
print_value(os, values[i]);
443+
std::ostream& operator<<(std::ostream& os, const StorageView& storage) {
444+
// Create a printable copy of the storage
445+
StorageView printable(storage.dtype());
446+
printable.copy_from(storage);
447+
448+
// Check the data type and print accordingly
449+
TYPE_DISPATCH(
450+
printable.dtype(),
451+
const auto* values = printable.data<T>();
452+
const auto& shape = printable.shape();
453+
454+
// Print tensor contents based on dimensionality
455+
if (shape.empty()) { // Scalar case
456+
os << "Data (Scalar): " << values[0] << std::endl;
457+
} else {
458+
os << "Data (" << shape.size() << "D ";
459+
if (shape.size() == 1)
460+
os << "Vector";
461+
else if (shape.size() == 2)
462+
os << "Matrix";
463+
else
464+
os << "Tensor";
465+
os << "):" << std::endl;
466+
printable.print_tensor(os, values, shape, 0, 0, 0);
467+
os << std::endl;
468+
}
469+
);
470+
471+
// Print metadata
472+
os << "[device:" << device_to_str(storage.device(), storage.device_index())
473+
<< ", dtype:" << dtype_name(storage.dtype()) << ", storage viewed as ";
474+
if (storage.is_scalar())
475+
os << "scalar";
476+
else {
477+
for (dim_t i = 0; i < storage.rank(); ++i) {
478+
if (i > 0)
479+
os << 'x';
480+
os << storage.dim(i);
453481
}
454482
}
455-
else {
456-
for (dim_t i = 0; i < PRINT_MAX_VALUES / 2; ++i) {
457-
os << ' ';
458-
print_value(os, values[i]);
483+
os << ']';
484+
return os;
485+
}
486+
487+
template <typename T>
488+
void StorageView::print_tensor(std::ostream& os, const T* data, const std::vector<dim_t>& shape, size_t dim, size_t offset, int indent) const {
489+
std::string indentation(indent, ' ');
490+
491+
os << indentation << "[";
492+
bool is_last_dim = (dim == shape.size() - 1);
493+
for (dim_t i = 0; i < shape[dim]; ++i) {
494+
if (i > 0) {
495+
os << ", ";
496+
if (!is_last_dim) {
497+
os << "\n" << std::string(indent, ' ');
498+
}
459499
}
460-
os << " ...";
461-
for (dim_t i = printable.size() - (PRINT_MAX_VALUES / 2); i < printable.size(); ++i) {
462-
os << ' ';
463-
print_value(os, values[i]);
500+
501+
if (i == PRINT_MAX_VALUES / 2) {
502+
os << "...";
503+
i = shape[dim] - PRINT_MAX_VALUES / 2 - 1; // Skip to the last part
504+
} else {
505+
if (is_last_dim) {
506+
os << data[offset + i];
507+
} else {
508+
print_tensor(os, data, shape, dim + 1, offset + i * shape[dim + 1], indent);
509+
}
464510
}
465511
}
466-
os << std::endl);
467-
os << "[" << device_to_str(storage.device(), storage.device_index())
468-
<< " " << dtype_name(storage.dtype()) << " storage viewed as ";
469-
if (storage.is_scalar())
470-
os << "scalar";
471-
else {
472-
for (dim_t i = 0; i < storage.rank(); ++i) {
473-
if (i > 0)
474-
os << 'x';
475-
os << storage.dim(i);
476-
}
512+
os << "]";
477513
}
478-
os << ']';
479-
return os;
480-
}
481514

482515
#define DECLARE_IMPL(T) \
483516
template \

0 commit comments

Comments
 (0)