Skip to content

Commit 60d3241

Browse files
authored
Update to new albumentations (#1209)
* Fixed augmentation imports and test_train.py to use the new import paths. * Add reruns in pytest ini to fix flaky web socket tests.
1 parent 5483490 commit 60d3241

File tree

4 files changed

+11
-7
lines changed

4 files changed

+11
-7
lines changed

donkeycar/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pyfiglet import Figlet
44
import logging
55

6-
__version__ = '5.2.dev2'
6+
__version__ = '5.2.dev3'
77

88
logging.basicConfig(level=os.environ.get('LOGLEVEL', 'INFO').upper())
99

donkeycar/pipeline/augmentations.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import logging
33
import albumentations as A
44
from albumentations import GaussianBlur
5-
from albumentations.augmentations.transforms import RandomBrightnessContrast
5+
from albumentations.augmentations import RandomBrightnessContrast
6+
67

78
from donkeycar.config import Config
89

@@ -11,14 +12,14 @@
1112

1213

1314
class ImageAugmentation:
14-
def __init__(self, cfg, key, prob=0.5, always_apply=False):
15+
def __init__(self, cfg, key, prob=0.5):
1516
aug_list = getattr(cfg, key, [])
16-
augmentations = [ImageAugmentation.create(a, cfg, prob, always_apply)
17+
augmentations = [ImageAugmentation.create(a, cfg, prob)
1718
for a in aug_list]
1819
self.augmentations = A.Compose(augmentations)
1920

2021
@classmethod
21-
def create(cls, aug_type: str, config: Config, prob, always) -> \
22+
def create(cls, aug_type: str, config: Config, prob) -> \
2223
albumentations.core.transforms_interface.BasicTransform:
2324
""" Augmentation factory. Cropping and trapezoidal mask are
2425
transformations which should be applied in training, validation
@@ -30,13 +31,13 @@ def create(cls, aug_type: str, config: Config, prob, always) -> \
3031
logger.info(f'Creating augmentation {aug_type} {b_limit}')
3132
return RandomBrightnessContrast(brightness_limit=b_limit,
3233
contrast_limit=b_limit,
33-
p=prob, always_apply=always)
34+
p=prob)
3435

3536
elif aug_type == 'BLUR':
3637
b_range = getattr(config, 'AUG_BLUR_RANGE', 3)
3738
logger.info(f'Creating augmentation {aug_type} {b_range}')
3839
return GaussianBlur(sigma_limit=b_range, blur_limit=(13, 13),
39-
p=prob, always_apply=always)
40+
p=prob)
4041

4142
# Parts interface
4243
def run(self, img_arr):

donkeycar/tests/pytest.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ filterwarnings =
55

66
log_cli = True
77
log_cli_level = INFO
8+
reruns = 3

donkeycar/tests/test_train.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ def car_dir(tmpdir_factory, base_config, imu_fields) -> str:
106106
record['localizer/location'] = 3 * count // len(tub)
107107
tub_full.write_record(record)
108108
count += 1
109+
tub_full.close()
110+
tub.close()
109111
return car_dir
110112

111113

0 commit comments

Comments
 (0)