Skip to content

Commit e6f9c06

Browse files
committed
refactor: update batcher implementations to include device parameter for better device management
Signed-off-by: Sammy Oina <sammyoina@gmail.com>
1 parent ae177fb commit e6f9c06

File tree

10 files changed

+35
-28
lines changed

10 files changed

+35
-28
lines changed

burn-algorithms/agnews/src/data.rs

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -196,10 +196,10 @@ pub struct ClassificationInferenceBatch<B: Backend> {
196196
pub mask_pad: Tensor<B, 2, Bool>,
197197
}
198198

199-
impl<B: Backend> Batcher<ClassificationItem, ClassificationTrainingBatch<B>>
199+
impl<B: Backend> Batcher<B, ClassificationItem, ClassificationTrainingBatch<B>>
200200
for ClassificationBatcher<B>
201201
{
202-
fn batch(&self, items: Vec<ClassificationItem>) -> ClassificationTrainingBatch<B> {
202+
fn batch(&self, items: Vec<ClassificationItem>, _device: &B::Device) -> ClassificationTrainingBatch<B> {
203203
let mut tokens_list = Vec::with_capacity(items.len());
204204
let mut labels_list = Vec::with_capacity(items.len());
205205

@@ -226,8 +226,8 @@ impl<B: Backend> Batcher<ClassificationItem, ClassificationTrainingBatch<B>>
226226
}
227227
}
228228

229-
impl<B: Backend> Batcher<String, ClassificationInferenceBatch<B>> for ClassificationBatcher<B> {
230-
fn batch(&self, items: Vec<String>) -> ClassificationInferenceBatch<B> {
229+
impl<B: Backend> Batcher<B, String, ClassificationInferenceBatch<B>> for ClassificationBatcher<B> {
230+
fn batch(&self, items: Vec<String>, _device: &B::Device) -> ClassificationInferenceBatch<B> {
231231
let mut tokens_list = Vec::with_capacity(items.len());
232232

233233
for item in items {
@@ -238,12 +238,12 @@ impl<B: Backend> Batcher<String, ClassificationInferenceBatch<B>> for Classifica
238238
self.tokenizer.pad_token(),
239239
tokens_list,
240240
Some(self.max_seq_length),
241-
&B::Device::default(),
241+
&self.device,
242242
);
243243

244244
ClassificationInferenceBatch {
245-
tokens: mask.tensor.to_device(&self.device),
246-
mask_pad: mask.mask.to_device(&self.device),
245+
tokens: mask.tensor,
246+
mask_pad: mask.mask,
247247
}
248248
}
249249
}

burn-algorithms/cifar10/src/data.rs

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,6 @@ impl<B: Backend> Normalizer<B> {
3737
#[derive(Clone)]
3838
pub struct ClassificationBatcher<B: Backend> {
3939
normalizer: Normalizer<B>,
40-
device: B::Device,
4140
}
4241

4342
#[derive(Clone, Debug)]
@@ -50,13 +49,12 @@ impl<B: Backend> ClassificationBatcher<B> {
5049
pub fn new(device: B::Device) -> Self {
5150
Self {
5251
normalizer: Normalizer::<B>::new(&device),
53-
device,
5452
}
5553
}
5654
}
5755

58-
impl<B: Backend> Batcher<ImageDatasetItem, ClassificationBatch<B>> for ClassificationBatcher<B> {
59-
fn batch(&self, items: Vec<ImageDatasetItem>) -> ClassificationBatch<B> {
56+
impl<B: Backend> Batcher<B, ImageDatasetItem, ClassificationBatch<B>> for ClassificationBatcher<B> {
57+
fn batch(&self, items: Vec<ImageDatasetItem>, device: &B::Device) -> ClassificationBatch<B> {
6058
fn image_as_vec_u8(item: ImageDatasetItem) -> Vec<u8> {
6159
item.image
6260
.into_iter()
@@ -70,7 +68,7 @@ impl<B: Backend> Batcher<ImageDatasetItem, ClassificationBatch<B>> for Classific
7068
if let Annotation::Label(y) = item.annotation {
7169
Tensor::<B, 1, Int>::from_data(
7270
TensorData::from([(y as i64).elem::<B::IntElem>()]),
73-
&self.device,
71+
device,
7472
)
7573
} else {
7674
panic!("Invalid target type")
@@ -82,7 +80,7 @@ impl<B: Backend> Batcher<ImageDatasetItem, ClassificationBatch<B>> for Classific
8280
.into_iter()
8381
.map(|item| TensorData::new(image_as_vec_u8(item), Shape::new([32, 32, 3])))
8482
.map(|data| {
85-
Tensor::<B, 3>::from_data(data.convert::<B::FloatElem>(), &self.device)
83+
Tensor::<B, 3>::from_data(data.convert::<B::FloatElem>(), device)
8684
// permute(2, 0, 1)
8785
.swap_dims(2, 1) // [H, C, W]
8886
.swap_dims(1, 0) // [C, H, W]

burn-algorithms/cifar10/src/training.rs

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -81,6 +81,7 @@ pub fn train<B: AutodiffBackend>(config: TrainingConfig, device: B::Device) {
8181
.metric_valid_numeric(LossMetric::new())
8282
.with_file_checkpointer(CompactRecorder::new())
8383
.early_stopping(MetricEarlyStoppingStrategy::new::<LossMetric<B>>(
84+
&LossMetric::new(),
8485
Aggregate::Mean,
8586
Direction::Lowest,
8687
Split::Valid,
@@ -104,6 +105,7 @@ pub fn train<B: AutodiffBackend>(config: TrainingConfig, device: B::Device) {
104105
.metric_valid_numeric(LossMetric::new())
105106
.with_file_checkpointer(CompactRecorder::new())
106107
.early_stopping(MetricEarlyStoppingStrategy::new::<LossMetric<B>>(
108+
&LossMetric::new(),
107109
Aggregate::Mean,
108110
Direction::Lowest,
109111
Split::Valid,

burn-algorithms/imdb/src/data.rs

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -179,10 +179,10 @@ pub struct ClassificationInferenceBatch<B: Backend> {
179179
pub mask_pad: Tensor<B, 2, Bool>,
180180
}
181181

182-
impl<B: Backend> Batcher<ClassificationItem, ClassificationTrainingBatch<B>>
182+
impl<B: Backend> Batcher<B, ClassificationItem, ClassificationTrainingBatch<B>>
183183
for ClassificationBatcher<B>
184184
{
185-
fn batch(&self, items: Vec<ClassificationItem>) -> ClassificationTrainingBatch<B> {
185+
fn batch(&self, items: Vec<ClassificationItem>, _device: &B::Device) -> ClassificationTrainingBatch<B> {
186186
let mut tokens_list = Vec::with_capacity(items.len());
187187
let mut labels_list = Vec::with_capacity(items.len());
188188

@@ -209,8 +209,8 @@ impl<B: Backend> Batcher<ClassificationItem, ClassificationTrainingBatch<B>>
209209
}
210210
}
211211

212-
impl<B: Backend> Batcher<String, ClassificationInferenceBatch<B>> for ClassificationBatcher<B> {
213-
fn batch(&self, items: Vec<String>) -> ClassificationInferenceBatch<B> {
212+
impl<B: Backend> Batcher<B, String, ClassificationInferenceBatch<B>> for ClassificationBatcher<B> {
213+
fn batch(&self, items: Vec<String>, _device: &B::Device) -> ClassificationInferenceBatch<B> {
214214
let mut tokens_list = Vec::with_capacity(items.len());
215215

216216
for item in items {
@@ -221,12 +221,12 @@ impl<B: Backend> Batcher<String, ClassificationInferenceBatch<B>> for Classifica
221221
self.tokenizer.pad_token(),
222222
tokens_list,
223223
Some(self.max_seq_length),
224-
&B::Device::default(),
224+
&self.device,
225225
);
226226

227227
ClassificationInferenceBatch {
228-
tokens: mask.tensor.to_device(&self.device),
229-
mask_pad: mask.mask.to_device(&self.device),
228+
tokens: mask.tensor,
229+
mask_pad: mask.mask,
230230
}
231231
}
232232
}

burn-algorithms/iris/src/data.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -122,8 +122,8 @@ impl<B: Backend> IrisBatcher<B> {
122122
}
123123
}
124124

125-
impl<B: Backend> Batcher<IrisItem, IrisBatch<B>> for IrisBatcher<B> {
126-
fn batch(&self, items: Vec<IrisItem>) -> IrisBatch<B> {
125+
impl<B: Backend> Batcher<B, IrisItem, IrisBatch<B>> for IrisBatcher<B> {
126+
fn batch(&self, items: Vec<IrisItem>, _device: &B::Device) -> IrisBatch<B> {
127127
let mut inputs: Vec<Tensor<B, 2>> = Vec::new();
128128

129129
for item in items.iter() {

burn-algorithms/iris/src/training.rs

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -81,6 +81,7 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
8181
.metric_valid_numeric(LossMetric::new())
8282
.with_file_checkpointer(CompactRecorder::new())
8383
.early_stopping(MetricEarlyStoppingStrategy::new::<LossMetric<B>>(
84+
&LossMetric::new(),
8485
Aggregate::Mean,
8586
Direction::Lowest,
8687
Split::Valid,
@@ -100,6 +101,7 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
100101
.metric_valid_numeric(LossMetric::new())
101102
.with_file_checkpointer(CompactRecorder::new())
102103
.early_stopping(MetricEarlyStoppingStrategy::new::<LossMetric<B>>(
104+
&LossMetric::new(),
103105
Aggregate::Mean,
104106
Direction::Lowest,
105107
Split::Valid,

burn-algorithms/mnist/src/data.rs

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -20,12 +20,13 @@ impl<B: Backend> MnistBatcher<B> {
2020
}
2121
}
2222

23-
impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
24-
fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
23+
impl<B: Backend> Batcher<B, MnistItem, MnistBatch<B>> for MnistBatcher<B> {
24+
fn batch(&self, items: Vec<MnistItem>, _device: &B::Device) -> MnistBatch<B> {
25+
let device = &self.device;
2526
let images = items
2627
.iter()
2728
.map(|item| TensorData::from(item.image))
28-
.map(|data| Tensor::<B, 2>::from_data(data.convert::<B::FloatElem>(), &self.device))
29+
.map(|data| Tensor::<B, 2>::from_data(data.convert::<B::FloatElem>(), device))
2930
.map(|tensor| tensor.reshape([1, 28, 28]))
3031
// normalize: make between [0,1] and make the mean = 0 and std = 1
3132
// values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example
@@ -38,7 +39,7 @@ impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
3839
.map(|item| {
3940
Tensor::<B, 1, Int>::from_data(
4041
TensorData::from([(item.label as i64).elem::<B::FloatElem>()]),
41-
&self.device,
42+
device,
4243
)
4344
})
4445
.collect();

burn-algorithms/mnist/src/training.rs

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,7 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
7272
.metric_valid_numeric(LossMetric::new())
7373
.with_file_checkpointer(CompactRecorder::new())
7474
.early_stopping(MetricEarlyStoppingStrategy::new::<LossMetric<B>>(
75+
&LossMetric::new(),
7576
Aggregate::Mean,
7677
Direction::Lowest,
7778
Split::Valid,
@@ -95,6 +96,7 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
9596
.metric_valid_numeric(LossMetric::new())
9697
.with_file_checkpointer(CompactRecorder::new())
9798
.early_stopping(MetricEarlyStoppingStrategy::new::<LossMetric<B>>(
99+
&LossMetric::new(),
98100
Aggregate::Mean,
99101
Direction::Lowest,
100102
Split::Valid,

burn-algorithms/winequality/src/data.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -152,8 +152,8 @@ impl<B: Backend> WineQualityBatcher<B> {
152152
}
153153
}
154154

155-
impl<B: Backend> Batcher<WineQualityItem, WineQualityBatch<B>> for WineQualityBatcher<B> {
156-
fn batch(&self, items: Vec<WineQualityItem>) -> WineQualityBatch<B> {
155+
impl<B: Backend> Batcher<B, WineQualityItem, WineQualityBatch<B>> for WineQualityBatcher<B> {
156+
fn batch(&self, items: Vec<WineQualityItem>, _device: &B::Device) -> WineQualityBatch<B> {
157157
let mut inputs: Vec<Tensor<B, 2>> = Vec::new();
158158

159159
// The constants are the min and max values of the dataset

burn-algorithms/winequality/src/training.rs

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -79,6 +79,7 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
7979
.metric_valid_numeric(LossMetric::new())
8080
.with_file_checkpointer(CompactRecorder::new())
8181
.early_stopping(MetricEarlyStoppingStrategy::new::<LossMetric<B>>(
82+
&LossMetric::new(),
8283
Aggregate::Mean,
8384
Direction::Lowest,
8485
Split::Valid,
@@ -96,6 +97,7 @@ pub fn run<B: AutodiffBackend>(device: B::Device) {
9697
.metric_valid_numeric(LossMetric::new())
9798
.with_file_checkpointer(CompactRecorder::new())
9899
.early_stopping(MetricEarlyStoppingStrategy::new::<LossMetric<B>>(
100+
&LossMetric::new(),
99101
Aggregate::Mean,
100102
Direction::Lowest,
101103
Split::Valid,

0 commit comments

Comments (0)