@@ -37,7 +37,6 @@ impl<B: Backend> Normalizer<B> {
3737#[ derive( Clone ) ]
3838pub struct ClassificationBatcher < B : Backend > {
3939 normalizer : Normalizer < B > ,
40- device : B :: Device ,
4140}
4241
4342#[ derive( Clone , Debug ) ]
@@ -50,13 +49,12 @@ impl<B: Backend> ClassificationBatcher<B> {
5049 pub fn new ( device : B :: Device ) -> Self {
5150 Self {
5251 normalizer : Normalizer :: < B > :: new ( & device) ,
53- device,
5452 }
5553 }
5654}
5755
58- impl < B : Backend > Batcher < ImageDatasetItem , ClassificationBatch < B > > for ClassificationBatcher < B > {
59- fn batch ( & self , items : Vec < ImageDatasetItem > ) -> ClassificationBatch < B > {
56+ impl < B : Backend > Batcher < B , ImageDatasetItem , ClassificationBatch < B > > for ClassificationBatcher < B > {
57+ fn batch ( & self , items : Vec < ImageDatasetItem > , device : & B :: Device ) -> ClassificationBatch < B > {
6058 fn image_as_vec_u8 ( item : ImageDatasetItem ) -> Vec < u8 > {
6159 item. image
6260 . into_iter ( )
@@ -70,7 +68,7 @@ impl<B: Backend> Batcher<ImageDatasetItem, ClassificationBatch<B>> for Classific
7068 if let Annotation :: Label ( y) = item. annotation {
7169 Tensor :: < B , 1 , Int > :: from_data (
7270 TensorData :: from ( [ ( y as i64 ) . elem :: < B :: IntElem > ( ) ] ) ,
73- & self . device ,
71+ device,
7472 )
7573 } else {
7674 panic ! ( "Invalid target type" )
@@ -82,7 +80,7 @@ impl<B: Backend> Batcher<ImageDatasetItem, ClassificationBatch<B>> for Classific
8280 . into_iter ( )
8381 . map ( |item| TensorData :: new ( image_as_vec_u8 ( item) , Shape :: new ( [ 32 , 32 , 3 ] ) ) )
8482 . map ( |data| {
85- Tensor :: < B , 3 > :: from_data ( data. convert :: < B :: FloatElem > ( ) , & self . device )
83+ Tensor :: < B , 3 > :: from_data ( data. convert :: < B :: FloatElem > ( ) , device)
8684 // permute(2, 0, 1)
8785 . swap_dims ( 2 , 1 ) // [H, C, W]
8886 . swap_dims ( 1 , 0 ) // [C, H, W]
0 commit comments