@@ -440,44 +440,77 @@ namespace ctranslate2 {
440
440
return os;
441
441
}
442
442
443
- std::ostream& operator <<(std::ostream& os, const StorageView& storage) {
444
- StorageView printable (storage.dtype ());
445
- printable.copy_from (storage);
446
- TYPE_DISPATCH (
447
- printable.dtype (),
448
- const auto * values = printable.data <T>();
449
- if (printable.size () <= PRINT_MAX_VALUES) {
450
- for (dim_t i = 0 ; i < printable.size (); ++i) {
451
- os << ' ' ;
452
- print_value (os, values[i]);
443
+ std::ostream& operator <<(std::ostream& os, const StorageView& storage) {
444
+ // Create a printable copy of the storage
445
+ StorageView printable (storage.dtype ());
446
+ printable.copy_from (storage);
447
+
448
+ // Check the data type and print accordingly
449
+ TYPE_DISPATCH (
450
+ printable.dtype (),
451
+ const auto * values = printable.data <T>();
452
+ const auto & shape = printable.shape ();
453
+
454
+ // Print tensor contents based on dimensionality
455
+ if (shape.empty ()) { // Scalar case
456
+ os << " Data (Scalar): " << values[0 ] << std::endl;
457
+ } else {
458
+ os << " Data (" << shape.size () << " D " ;
459
+ if (shape.size () == 1 )
460
+ os << " Vector" ;
461
+ else if (shape.size () == 2 )
462
+ os << " Matrix" ;
463
+ else
464
+ os << " Tensor" ;
465
+ os << " ):" << std::endl;
466
+ printable.print_tensor (os, values, shape, 0 , 0 , 0 );
467
+ os << std::endl;
468
+ }
469
+ );
470
+
471
+ // Print metadata
472
+ os << " [device:" << device_to_str (storage.device (), storage.device_index ())
473
+ << " , dtype:" << dtype_name (storage.dtype ()) << " , storage viewed as " ;
474
+ if (storage.is_scalar ())
475
+ os << " scalar" ;
476
+ else {
477
+ for (dim_t i = 0 ; i < storage.rank (); ++i) {
478
+ if (i > 0 )
479
+ os << ' x' ;
480
+ os << storage.dim (i);
453
481
}
454
482
}
455
- else {
456
- for (dim_t i = 0 ; i < PRINT_MAX_VALUES / 2 ; ++i) {
457
- os << ' ' ;
458
- print_value (os, values[i]);
483
+ os << ' ]' ;
484
+ return os;
485
+ }
486
+
487
+ template <typename T>
488
+ void StorageView::print_tensor (std::ostream& os, const T* data, const std::vector<dim_t >& shape, size_t dim, size_t offset, int indent) const {
489
+ std::string indentation (indent, ' ' );
490
+
491
+ os << indentation << " [" ;
492
+ bool is_last_dim = (dim == shape.size () - 1 );
493
+ for (dim_t i = 0 ; i < shape[dim]; ++i) {
494
+ if (i > 0 ) {
495
+ os << " , " ;
496
+ if (!is_last_dim) {
497
+ os << " \n " << std::string (indent, ' ' );
498
+ }
459
499
}
460
- os << " ..." ;
461
- for (dim_t i = printable.size () - (PRINT_MAX_VALUES / 2 ); i < printable.size (); ++i) {
462
- os << ' ' ;
463
- print_value (os, values[i]);
500
+
501
+ if (i == PRINT_MAX_VALUES / 2 ) {
502
+ os << " ..." ;
503
+ i = shape[dim] - PRINT_MAX_VALUES / 2 - 1 ; // Skip to the last part
504
+ } else {
505
+ if (is_last_dim) {
506
+ os << data[offset + i];
507
+ } else {
508
+ print_tensor (os, data, shape, dim + 1 , offset + i * shape[dim + 1 ], indent);
509
+ }
464
510
}
465
511
}
466
- os << std::endl);
467
- os << " [" << device_to_str (storage.device (), storage.device_index ())
468
- << " " << dtype_name (storage.dtype ()) << " storage viewed as " ;
469
- if (storage.is_scalar ())
470
- os << " scalar" ;
471
- else {
472
- for (dim_t i = 0 ; i < storage.rank (); ++i) {
473
- if (i > 0 )
474
- os << ' x' ;
475
- os << storage.dim (i);
476
- }
512
+ os << " ]" ;
477
513
}
478
- os << ' ]' ;
479
- return os;
480
- }
481
514
482
515
#define DECLARE_IMPL (T ) \
483
516
template \
0 commit comments