Skip to content

Commit b3c01d5

Browse files
committed
[*] improve log message for storage view content
1 parent 5eb5d5a commit b3c01d5

File tree

2 files changed

+101
-35
lines changed

2 files changed

+101
-35
lines changed

include/ctranslate2/storage_view.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#include "types.h"
88
#include "utils.h"
99

10+
#define PRINT_MAX_VALUES 6
11+
1012
namespace ctranslate2 {
1113

1214
#define ASSERT_DTYPE(DTYPE) \
@@ -238,6 +240,36 @@ namespace ctranslate2 {
238240

239241
friend std::ostream& operator<<(std::ostream& os, const StorageView& storage);
240242

243+
template <typename T>
244+
void print_tensor(std::ostream& os, const T* data, const std::vector<dim_t>& shape, size_t dim, size_t offset, int indent) const {
245+
std::string indentation(indent, ' ');
246+
if (dim == shape.size() - 1) {
247+
os << indentation << "[";
248+
for (dim_t i = 0; i < shape[dim]; ++i) {
249+
if (i > 0) os << ", ";
250+
if (i < PRINT_MAX_VALUES / 2 || i >= shape[dim] - PRINT_MAX_VALUES / 2) {
251+
os << data[offset + i];
252+
} else if (i == PRINT_MAX_VALUES / 2) {
253+
os << "...";
254+
i = shape[dim] - PRINT_MAX_VALUES / 2 - 1; // Skip to the last part
255+
}
256+
}
257+
os << "]";
258+
} else {
259+
os << indentation << "[\n";
260+
for (dim_t i = 0; i < shape[dim]; ++i) {
261+
if (i > 0) os << ",\n";
262+
if (i < PRINT_MAX_VALUES / 2 || i >= shape[dim] - PRINT_MAX_VALUES / 2) {
263+
print_tensor(os, data, shape, dim + 1, offset + i * shape[dim + 1], indent + 2);
264+
} else if (i == PRINT_MAX_VALUES / 2) {
265+
os << indentation << " ...";
266+
i = shape[dim] - PRINT_MAX_VALUES / 2 - 1; // Skip to the last part
267+
}
268+
}
269+
os << "\n" << indentation << "]";
270+
}
271+
}
272+
241273
protected:
242274
DataType _dtype = DataType::FLOAT32;
243275
Device _device = Device::CPU;

src/storage_view.cc

Lines changed: 69 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
#include "dispatch.h"
66

7-
#define PRINT_MAX_VALUES 6
87

98
namespace ctranslate2 {
109

@@ -440,44 +439,79 @@ namespace ctranslate2 {
440439
return os;
441440
}
442441

443-
std::ostream& operator<<(std::ostream& os, const StorageView& storage) {
444-
StorageView printable(storage.dtype());
445-
printable.copy_from(storage);
446-
TYPE_DISPATCH(
447-
printable.dtype(),
448-
const auto* values = printable.data<T>();
449-
if (printable.size() <= PRINT_MAX_VALUES) {
450-
for (dim_t i = 0; i < printable.size(); ++i) {
451-
os << ' ';
452-
print_value(os, values[i]);
453-
}
454-
}
442+
std::ostream& operator<<(std::ostream& os, const StorageView& storage) {
443+
// Create a printable copy of the storage
444+
StorageView printable(storage.dtype());
445+
printable.copy_from(storage);
446+
447+
// Check the data type and print accordingly
448+
TYPE_DISPATCH(
449+
printable.dtype(),
450+
const auto* values = printable.data<T>();
451+
const auto& shape = printable.shape();
452+
453+
// Print tensor contents based on dimensionality
454+
if (shape.empty()) { // Scalar case
455+
os << "Data (Scalar): " << values[0] << std::endl;
456+
} else if (shape.size() == 1) { // Vector case
457+
os << "Data (1D Vector):" << std::endl;
458+
os << "[";
459+
for (dim_t i = 0; i < printable.size(); ++i) {
460+
if (i > 0) os << ", ";
461+
if (i < PRINT_MAX_VALUES / 2 || i >= printable.size() - PRINT_MAX_VALUES / 2) {
462+
os << values[i];
463+
} else if (i == PRINT_MAX_VALUES / 2) {
464+
os << "...";
465+
i = printable.size() - PRINT_MAX_VALUES / 2 - 1; // Skip to the last part
466+
}
467+
}
468+
os << "]\n";
469+
} else if (shape.size() == 2) { // 2D Matrix case
470+
os << "Data (2D Matrix):" << std::endl;
471+
os << "[\n";
472+
for (dim_t i = 0; i < shape[0]; ++i) {
473+
if (i > 0) os << ",\n";
474+
if (i < PRINT_MAX_VALUES / 2 || i >= shape[0] - PRINT_MAX_VALUES / 2) {
475+
os << " [";
476+
for (dim_t j = 0; j < shape[1]; ++j) {
477+
if (j > 0) os << ", ";
478+
if (j < PRINT_MAX_VALUES / 2 || j >= shape[1] - PRINT_MAX_VALUES / 2) {
479+
os << values[i * shape[1] + j];
480+
} else if (j == PRINT_MAX_VALUES / 2) {
481+
os << "...";
482+
j = shape[1] - PRINT_MAX_VALUES / 2 - 1; // Skip to the last part
483+
}
484+
}
485+
os << "]";
486+
} else if (i == PRINT_MAX_VALUES / 2) {
487+
os << " ...";
488+
i = shape[0] - PRINT_MAX_VALUES / 2 - 1; // Skip to the last part
489+
}
490+
}
491+
os << "\n]\n";
492+
} else { // Higher-dimensional tensors
493+
os << "Data (" << shape.size() << "D Tensor):" << std::endl;
494+
storage.print_tensor(os, values, shape, 0, 0, 0);
495+
os << std::endl;
496+
}
497+
);
498+
499+
// Print metadata
500+
os << "[device:" << device_to_str(storage.device(), storage.device_index())
501+
<< ", dtype:" << dtype_name(storage.dtype()) << ", storage viewed as ";
502+
if (storage.is_scalar())
503+
os << "scalar";
455504
else {
456-
for (dim_t i = 0; i < PRINT_MAX_VALUES / 2; ++i) {
457-
os << ' ';
458-
print_value(os, values[i]);
505+
for (dim_t i = 0; i < storage.rank(); ++i) {
506+
if (i > 0)
507+
os << 'x';
508+
os << storage.dim(i);
459509
}
460-
os << " ...";
461-
for (dim_t i = printable.size() - (PRINT_MAX_VALUES / 2); i < printable.size(); ++i) {
462-
os << ' ';
463-
print_value(os, values[i]);
464-
}
465-
}
466-
os << std::endl);
467-
os << "[" << device_to_str(storage.device(), storage.device_index())
468-
<< " " << dtype_name(storage.dtype()) << " storage viewed as ";
469-
if (storage.is_scalar())
470-
os << "scalar";
471-
else {
472-
for (dim_t i = 0; i < storage.rank(); ++i) {
473-
if (i > 0)
474-
os << 'x';
475-
os << storage.dim(i);
476510
}
511+
os << ']';
512+
return os;
477513
}
478-
os << ']';
479-
return os;
480-
}
514+
481515

482516
#define DECLARE_IMPL(T) \
483517
template \

0 commit comments

Comments
 (0)