|
4 | 4 |
|
5 | 5 | #include "dispatch.h"
|
6 | 6 |
|
7 |
| -#define PRINT_MAX_VALUES 6 |
8 | 7 |
|
9 | 8 | namespace ctranslate2 {
|
10 | 9 |
|
@@ -440,44 +439,79 @@ namespace ctranslate2 {
|
440 | 439 | return os;
|
441 | 440 | }
|
442 | 441 |
|
443 |
| - std::ostream& operator<<(std::ostream& os, const StorageView& storage) { |
444 |
| - StorageView printable(storage.dtype()); |
445 |
| - printable.copy_from(storage); |
446 |
| - TYPE_DISPATCH( |
447 |
| - printable.dtype(), |
448 |
| - const auto* values = printable.data<T>(); |
449 |
| - if (printable.size() <= PRINT_MAX_VALUES) { |
450 |
| - for (dim_t i = 0; i < printable.size(); ++i) { |
451 |
| - os << ' '; |
452 |
| - print_value(os, values[i]); |
453 |
| - } |
454 |
| - } |
| 442 | + std::ostream& operator<<(std::ostream& os, const StorageView& storage) { |
| 443 | + // Create a printable copy of the storage |
| 444 | + StorageView printable(storage.dtype()); |
| 445 | + printable.copy_from(storage); |
| 446 | + |
| 447 | + // Check the data type and print accordingly |
| 448 | + TYPE_DISPATCH( |
| 449 | + printable.dtype(), |
| 450 | + const auto* values = printable.data<T>(); |
| 451 | + const auto& shape = printable.shape(); |
| 452 | + |
| 453 | + // Print tensor contents based on dimensionality |
| 454 | + if (shape.empty()) { // Scalar case |
| 455 | + os << "Data (Scalar): " << values[0] << std::endl; |
| 456 | + } else if (shape.size() == 1) { // Vector case |
| 457 | + os << "Data (1D Vector):" << std::endl; |
| 458 | + os << "["; |
| 459 | + for (dim_t i = 0; i < printable.size(); ++i) { |
| 460 | + if (i > 0) os << ", "; |
| 461 | + if (i < PRINT_MAX_VALUES / 2 || i >= printable.size() - PRINT_MAX_VALUES / 2) { |
| 462 | + os << values[i]; |
| 463 | + } else if (i == PRINT_MAX_VALUES / 2) { |
| 464 | + os << "..."; |
| 465 | + i = printable.size() - PRINT_MAX_VALUES / 2 - 1; // Skip to the last part |
| 466 | + } |
| 467 | + } |
| 468 | + os << "]\n"; |
| 469 | + } else if (shape.size() == 2) { // 2D Matrix case |
| 470 | + os << "Data (2D Matrix):" << std::endl; |
| 471 | + os << "[\n"; |
| 472 | + for (dim_t i = 0; i < shape[0]; ++i) { |
| 473 | + if (i > 0) os << ",\n"; |
| 474 | + if (i < PRINT_MAX_VALUES / 2 || i >= shape[0] - PRINT_MAX_VALUES / 2) { |
| 475 | + os << " ["; |
| 476 | + for (dim_t j = 0; j < shape[1]; ++j) { |
| 477 | + if (j > 0) os << ", "; |
| 478 | + if (j < PRINT_MAX_VALUES / 2 || j >= shape[1] - PRINT_MAX_VALUES / 2) { |
| 479 | + os << values[i * shape[1] + j]; |
| 480 | + } else if (j == PRINT_MAX_VALUES / 2) { |
| 481 | + os << "..."; |
| 482 | + j = shape[1] - PRINT_MAX_VALUES / 2 - 1; // Skip to the last part |
| 483 | + } |
| 484 | + } |
| 485 | + os << "]"; |
| 486 | + } else if (i == PRINT_MAX_VALUES / 2) { |
| 487 | + os << " ..."; |
| 488 | + i = shape[0] - PRINT_MAX_VALUES / 2 - 1; // Skip to the last part |
| 489 | + } |
| 490 | + } |
| 491 | + os << "\n]\n"; |
| 492 | + } else { // Higher-dimensional tensors |
| 493 | + os << "Data (" << shape.size() << "D Tensor):" << std::endl; |
| 494 | + storage.print_tensor(os, values, shape, 0, 0, 0); |
| 495 | + os << std::endl; |
| 496 | + } |
| 497 | + ); |
| 498 | + |
| 499 | + // Print metadata |
| 500 | + os << "[device:" << device_to_str(storage.device(), storage.device_index()) |
| 501 | + << ", dtype:" << dtype_name(storage.dtype()) << ", storage viewed as "; |
| 502 | + if (storage.is_scalar()) |
| 503 | + os << "scalar"; |
455 | 504 | else {
|
456 |
| - for (dim_t i = 0; i < PRINT_MAX_VALUES / 2; ++i) { |
457 |
| - os << ' '; |
458 |
| - print_value(os, values[i]); |
| 505 | + for (dim_t i = 0; i < storage.rank(); ++i) { |
| 506 | + if (i > 0) |
| 507 | + os << 'x'; |
| 508 | + os << storage.dim(i); |
459 | 509 | }
|
460 |
| - os << " ..."; |
461 |
| - for (dim_t i = printable.size() - (PRINT_MAX_VALUES / 2); i < printable.size(); ++i) { |
462 |
| - os << ' '; |
463 |
| - print_value(os, values[i]); |
464 |
| - } |
465 |
| - } |
466 |
| - os << std::endl); |
467 |
| - os << "[" << device_to_str(storage.device(), storage.device_index()) |
468 |
| - << " " << dtype_name(storage.dtype()) << " storage viewed as "; |
469 |
| - if (storage.is_scalar()) |
470 |
| - os << "scalar"; |
471 |
| - else { |
472 |
| - for (dim_t i = 0; i < storage.rank(); ++i) { |
473 |
| - if (i > 0) |
474 |
| - os << 'x'; |
475 |
| - os << storage.dim(i); |
476 | 510 | }
|
| 511 | + os << ']'; |
| 512 | + return os; |
477 | 513 | }
|
478 |
| - os << ']'; |
479 |
| - return os; |
480 |
| - } |
| 514 | + |
481 | 515 |
|
482 | 516 | #define DECLARE_IMPL(T) \
|
483 | 517 | template \
|
|
0 commit comments