
Commit b05902b

Merge pull request #2247 from Trusted-AI/dev_1.15.1
Update to ART 1.15.1
2 parents 75248e1 + 0a01458, commit b05902b

File tree

10 files changed: +100 / -25 lines


art/attacks/evasion/auto_conjugate_gradient.py
Lines changed: 4 additions & 2 deletions

@@ -463,7 +463,9 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.ndarray:

  # self.eta = np.full((self.batch_size, 1, 1, 1), 2 * self.eps_step).astype(ART_NUMPY_DTYPE)
  _batch_size = x_k.shape[0]
- eta = np.full((_batch_size, 1, 1, 1), self.eps_step).astype(ART_NUMPY_DTYPE)
+ eta = np.full((_batch_size,) + (1,) * len(self.estimator.input_shape), self.eps_step).astype(
+     ART_NUMPY_DTYPE
+ )
  self.count_condition_1 = np.zeros(shape=(_batch_size,))
  gradk_1 = np.zeros_like(x_k)
  cgradk_1 = np.zeros_like(x_k)

@@ -650,4 +652,4 @@ def get_beta(gradk, gradk_1, cgradk_1):
  betak = -(_gradk * delta_gradk).sum(axis=1) / (
      (_cgradk_1 * delta_gradk).sum(axis=1) + np.finfo(ART_NUMPY_DTYPE).eps
  )
- return betak.reshape((_batch_size, 1, 1, 1))
+ return betak.reshape((_batch_size,) + (1,) * (len(gradk.shape) - 1))
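The same shape generalization is applied in art/attacks/evasion/auto_projected_gradient_descent.py below. As a minimal sketch (not part of the commit, with illustrative values), the new shape gives one step size per sample and broadcasts against inputs of any rank, not only 4-D image batches:

    # Illustrative only: `input_shape` stands in for self.estimator.input_shape.
    import numpy as np

    batch_size = 4
    input_shape = (3, 32, 32)  # could equally be (16000,) for audio or (T, H, W, C) for video

    x = np.zeros((batch_size,) + input_shape, dtype=np.float32)

    # One step size per sample, with a singleton axis for every input dimension.
    eta = np.full((batch_size,) + (1,) * len(input_shape), 0.1, dtype=np.float32)

    # Broadcasting applies each sample's step size regardless of input rank;
    # the old hard-coded (batch, 1, 1, 1) shape only worked for rank-4 inputs.
    assert (x + eta).shape == x.shape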

art/attacks/evasion/auto_projected_gradient_descent.py
Lines changed: 3 additions & 1 deletion

@@ -458,7 +458,9 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.ndarray:

  # modification for image-wise stepsize update
  _batch_size = x_k.shape[0]
- eta = np.full((_batch_size, 1, 1, 1), self.eps_step).astype(ART_NUMPY_DTYPE)
+ eta = np.full((_batch_size,) + (1,) * len(self.estimator.input_shape), self.eps_step).astype(
+     ART_NUMPY_DTYPE
+ )
  self.count_condition_1 = np.zeros(shape=(_batch_size,))

  for k_iter in trange(self.max_iter, desc="AutoPGD - iteration", leave=False, disable=not self.verbose):

art/attacks/poisoning/perturbations/image_perturbations.py
Lines changed: 2 additions & 1 deletion

@@ -21,7 +21,6 @@
  from typing import Optional, Tuple

  import numpy as np
- from PIL import Image


  def add_single_bd(x: np.ndarray, distance: int = 2, pixel_value: int = 1) -> np.ndarray:

@@ -112,6 +111,8 @@ def insert_image(
  :param blend: The blending factor
  :return: Backdoored image.
  """
+ from PIL import Image
+
  n_dim = len(x.shape)
  if n_dim == 4:
      return np.array(
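Here (and in art/visualization.py further down) the module-level PIL import is replaced by an import inside the function that uses it, so importing the module no longer requires Pillow to be installed. A minimal sketch of the pattern, with a hypothetical function name:

    # Illustrative only: `load_as_image` is a made-up name for the pattern.
    import numpy as np

    def load_as_image(x: np.ndarray):
        # Deferred import: Pillow is only needed (and an ImportError only raised)
        # when this function is actually called.
        from PIL import Image

        return Image.fromarray(x)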

art/defences/preprocessor/spatial_smoothing.py
Lines changed: 1 addition & 1 deletion

@@ -30,7 +30,7 @@
  from typing import Optional, Tuple

  import numpy as np
- from scipy.ndimage.filters import median_filter
+ from scipy.ndimage import median_filter

  from art.utils import CLIP_VALUES_TYPE
  from art.defences.preprocessor.preprocessor import Preprocessor
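The scipy.ndimage.filters namespace is deprecated in recent SciPy releases; the same function is available directly under scipy.ndimage. A quick sketch of the updated import in use (shapes are illustrative, not from the commit):

    import numpy as np
    from scipy.ndimage import median_filter

    # NHWC batch of one image; smooth over a 3x3 spatial window only.
    x = np.random.rand(1, 28, 28, 1).astype(np.float32)
    x_smoothed = median_filter(x, size=(1, 3, 3, 1))
    assert x_smoothed.shape == x.shape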

art/defences/trainer/adversarial_trainer_trades_pytorch.py
Lines changed: 15 additions & 5 deletions

@@ -33,6 +33,7 @@
  from art.estimators.classification.pytorch import PyTorchClassifier
  from art.data_generators import DataGenerator
  from art.attacks.attack import EvasionAttack
+ from art.utils import check_and_transform_label_format

  if TYPE_CHECKING:
      import torch

@@ -97,6 +98,15 @@ def fit(
  ind = np.arange(len(x))

  logger.info("Adversarial Training TRADES")
+ y = check_and_transform_label_format(y, nb_classes=self.classifier.nb_classes)
+
+ if validation_data is not None:
+     (x_test, y_test) = validation_data
+     y_test = check_and_transform_label_format(y_test, nb_classes=self.classifier.nb_classes)
+
+     x_preprocessed_test, y_preprocessed_test = self._classifier._apply_preprocessing(  # pylint: disable=W0212
+         x_test, y_test, fit=True
+     )

  for i_epoch in trange(nb_epochs, desc="Adversarial Training TRADES - Epochs"):
      # Shuffle the examples

@@ -107,7 +117,6 @@
  train_n = 0.0

  for batch_id in range(nb_batches):
-
      # Create batch data
      x_batch = x[ind[batch_id * batch_size : min((batch_id + 1) * batch_size, x.shape[0])]].copy()
      y_batch = y[ind[batch_id * batch_size : min((batch_id + 1) * batch_size, x.shape[0])]]

@@ -125,9 +134,9 @@

  # compute accuracy
  if validation_data is not None:
-     (x_test, y_test) = validation_data
-     output = np.argmax(self.predict(x_test), axis=1)
-     nb_correct_pred = np.sum(output == np.argmax(y_test, axis=1))
+     output = np.argmax(self.predict(x_preprocessed_test), axis=1)
+     nb_correct_pred = np.sum(output == np.argmax(y_preprocessed_test, axis=1))
+
      logger.info(
          "epoch: %s time(s): %.1f loss: %.4f acc(tr): %.4f acc(val): %.4f",
          i_epoch,

@@ -188,7 +197,6 @@ def fit_generator(
  train_n = 0.0

  for batch_id in range(nb_batches):  # pylint: disable=W0612
-
      # Create batch data
      x_batch, y_batch = generator.get_batch()
      x_batch = x_batch.copy()

@@ -232,6 +240,8 @@ def _batch_process(self, x_batch: np.ndarray, y_batch: np.ndarray) -> Tuple[floa
  x_batch_pert = self._attack.generate(x_batch, y=y_batch)

  # Apply preprocessing
+ y_batch = check_and_transform_label_format(y_batch, nb_classes=self.classifier.nb_classes)
+
  x_preprocessed, y_preprocessed = self._classifier._apply_preprocessing(  # pylint: disable=W0212
      x_batch, y_batch, fit=True
  )

art/estimators/certification/__init__.py
Lines changed: 18 additions & 4 deletions

@@ -2,12 +2,26 @@
  This module contains certified classifiers.
  """
  import importlib
- from art.estimators.certification import randomized_smoothing
- from art.estimators.certification import derandomized_smoothing
+ from art.estimators.certification.randomized_smoothing.randomized_smoothing import RandomizedSmoothingMixin
+ from art.estimators.certification.randomized_smoothing.numpy import NumpyRandomizedSmoothing
+ from art.estimators.certification.randomized_smoothing.tensorflow import TensorFlowV2RandomizedSmoothing
+ from art.estimators.certification.randomized_smoothing.pytorch import PyTorchRandomizedSmoothing
+ from art.estimators.certification.derandomized_smoothing.derandomized_smoothing import DeRandomizedSmoothingMixin
+ from art.estimators.certification.derandomized_smoothing.pytorch import PyTorchDeRandomizedSmoothing
+ from art.estimators.certification.derandomized_smoothing.tensorflow import TensorFlowV2DeRandomizedSmoothing

  if importlib.util.find_spec("torch") is not None:
-     from art.estimators.certification import deep_z
-     from art.estimators.certification import interval
+     from art.estimators.certification.deep_z.deep_z import ZonoDenseLayer
+     from art.estimators.certification.deep_z.deep_z import ZonoBounds
+     from art.estimators.certification.deep_z.deep_z import ZonoConv
+     from art.estimators.certification.deep_z.deep_z import ZonoReLU
+     from art.estimators.certification.deep_z.pytorch import PytorchDeepZ
+     from art.estimators.certification.interval.interval import PyTorchIntervalDense
+     from art.estimators.certification.interval.interval import PyTorchIntervalConv2D
+     from art.estimators.certification.interval.interval import PyTorchIntervalReLU
+     from art.estimators.certification.interval.interval import PyTorchIntervalFlatten
+     from art.estimators.certification.interval.interval import PyTorchIntervalBounds
+     from art.estimators.certification.interval.pytorch import PyTorchIBPClassifier
  else:
      import warnings
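With these re-exports in place, the concrete classes become attributes of the package itself, so they can be imported directly rather than via their submodules (a usage sketch; the deep_z and interval names additionally require torch to be installed):

    from art.estimators.certification import (
        PyTorchRandomizedSmoothing,
        PyTorchDeRandomizedSmoothing,
    )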

art/estimators/object_detection/pytorch_object_detector.py
Lines changed: 10 additions & 1 deletion

@@ -219,7 +219,10 @@ def _preprocess_and_convert_inputs(

  # Set gradients
  if not no_grad:
-     x_tensor.requires_grad = True
+     if x_tensor.is_leaf:
+         x_tensor.requires_grad = True
+     else:
+         x_tensor.retain_grad()

  # Apply framework-specific preprocessing
  x_preprocessed, y_preprocessed = self._apply_preprocessing(x=x_tensor, y=y_tensor, fit=fit, no_grad=no_grad)

@@ -267,6 +270,12 @@ def _get_losses(
  x_preprocessed = x_preprocessed.to(self.device)
  y_preprocessed = [{k: v.to(self.device) for k, v in y_i.items()} for y_i in y_preprocessed]

+ # Set gradients again after inputs are moved to another device
+ if x_preprocessed.is_leaf:
+     x_preprocessed.requires_grad = True
+ else:
+     x_preprocessed.retain_grad()
+
  loss_components = self._model(x_preprocessed, y_preprocessed)

  return loss_components, x_preprocessed
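The same guard is added in art/estimators/object_detection/pytorch_yolo.py below. A minimal PyTorch sketch (not ART code) of why the leaf check matters: requires_grad can only be changed on leaf tensors, while a tensor produced by an operation such as .to(device) is non-leaf and needs retain_grad() to keep its .grad after backward():

    import torch

    x = torch.zeros(2, 3)
    x.requires_grad = True          # allowed: x is a leaf tensor

    y = x.to(torch.float64)         # result of an op -> non-leaf tensor
    # y.requires_grad = True        # would raise: flags can only be changed on leaf tensors
    y.retain_grad()                 # instead, ask autograd to keep y.grad

    y.sum().backward()
    assert x.grad is not None and y.grad is not None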

art/estimators/object_detection/pytorch_yolo.py
Lines changed: 6 additions & 0 deletions

@@ -358,6 +358,12 @@ def _get_losses(
  x_preprocessed = x_preprocessed.to(self.device)
  y_preprocessed_yolo = y_preprocessed_yolo.to(self.device)

+ # Set gradients again after inputs are moved to another device
+ if x_preprocessed.is_leaf:
+     x_preprocessed.requires_grad = True
+ else:
+     x_preprocessed.retain_grad()
+
  # Calculate loss components
  loss_components = self._model(x_preprocessed, y_preprocessed_yolo)

art/visualization.py
Lines changed: 2 additions & 1 deletion

@@ -25,7 +25,6 @@
  from typing import List, Optional, TYPE_CHECKING

  import numpy as np
- from PIL import Image

  from art import config

@@ -97,6 +96,8 @@ def save_image(image_array: np.ndarray, f_name: str) -> None:
  :param image_array: Image to be saved.
  :param f_name: File name containing extension e.g., my_img.jpg, my_img.png, my_images/my_img.png.
  """
+ from PIL import Image
+
  file_name = os.path.join(config.ART_DATA_PATH, f_name)
  folder = os.path.split(file_name)[0]
  if not os.path.exists(folder):

tests/defences/trainer/test_adversarial_trainer_trades_pytorch.py
Lines changed: 39 additions & 9 deletions

@@ -34,7 +34,7 @@ def _get_adv_trainer():
  if framework in ["tensorflow", "tensorflow2v1"]:
      trainer = None
  if framework == "pytorch":
-     classifier, _ = image_dl_estimator()
+     classifier, _ = image_dl_estimator(from_logits=True)
      attack = ProjectedGradientDescent(
          classifier,
          norm=np.inf,

@@ -63,22 +63,38 @@ def fix_get_mnist_subset(get_mnist_dataset):
  yield x_train_mnist[:n_train], y_train_mnist[:n_train], x_test_mnist[:n_test], y_test_mnist[:n_test]


- @pytest.mark.skip_framework("tensorflow", "keras", "scikitlearn", "mxnet", "kerastf")
- def test_adversarial_trainer_trades_pytorch_fit_and_predict(get_adv_trainer, fix_get_mnist_subset):
+ @pytest.mark.only_with_platform("pytorch")
+ @pytest.mark.parametrize("label_format", ["one_hot", "numerical"])
+ def test_adversarial_trainer_trades_pytorch_fit_and_predict(get_adv_trainer, fix_get_mnist_subset, label_format):
      (x_train_mnist, y_train_mnist, x_test_mnist, y_test_mnist) = fix_get_mnist_subset
      x_test_mnist_original = x_test_mnist.copy()

+     if label_format == "one_hot":
+         assert y_train_mnist.shape[-1] == 10
+         assert y_test_mnist.shape[-1] == 10
+     if label_format == "numerical":
+         y_test_mnist = np.argmax(y_test_mnist, axis=1)
+         y_train_mnist = np.argmax(y_train_mnist, axis=1)
+
      trainer = get_adv_trainer()
      if trainer is None:
          logging.warning("Couldn't perform this test because no trainer is defined for this framework configuration")
          return

      predictions = np.argmax(trainer.predict(x_test_mnist), axis=1)
-     accuracy = np.sum(predictions == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
+
+     if label_format == "one_hot":
+         accuracy = np.sum(predictions == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
+     else:
+         accuracy = np.sum(predictions == y_test_mnist) / x_test_mnist.shape[0]

      trainer.fit(x_train_mnist, y_train_mnist, nb_epochs=20)
      predictions_new = np.argmax(trainer.predict(x_test_mnist), axis=1)
-     accuracy_new = np.sum(predictions_new == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
+
+     if label_format == "one_hot":
+         accuracy_new = np.sum(predictions_new == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
+     else:
+         accuracy_new = np.sum(predictions_new == y_test_mnist) / x_test_mnist.shape[0]

      np.testing.assert_array_almost_equal(
          float(np.mean(x_test_mnist_original - x_test_mnist)),

@@ -92,13 +108,20 @@ def test_adversarial_trainer_trades_pytorch_fit_and_predict(get_adv_trainer, fix
  trainer.fit(x_train_mnist, y_train_mnist, nb_epochs=20, validation_data=(x_train_mnist, y_train_mnist))


- @pytest.mark.skip_framework("tensorflow", "keras", "scikitlearn", "mxnet", "kerastf")
+ @pytest.mark.only_with_platform("pytorch")
+ @pytest.mark.parametrize("label_format", ["one_hot", "numerical"])
  def test_adversarial_trainer_trades_pytorch_fit_generator_and_predict(
-     get_adv_trainer, fix_get_mnist_subset, image_data_generator
+     get_adv_trainer, fix_get_mnist_subset, image_data_generator, label_format
  ):
      (x_train_mnist, y_train_mnist, x_test_mnist, y_test_mnist) = fix_get_mnist_subset
      x_test_mnist_original = x_test_mnist.copy()

+     if label_format == "one_hot":
+         assert y_train_mnist.shape[-1] == 10
+         assert y_test_mnist.shape[-1] == 10
+     if label_format == "numerical":
+         y_test_mnist = np.argmax(y_test_mnist, axis=1)
+
      generator = image_data_generator()

      trainer = get_adv_trainer()

@@ -107,11 +130,18 @@ def test_adversarial_trainer_trades_pytorch_fit_generator_and_predict(
          return

      predictions = np.argmax(trainer.predict(x_test_mnist), axis=1)
-     accuracy = np.sum(predictions == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
+     if label_format == "one_hot":
+         accuracy = np.sum(predictions == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
+     else:
+         accuracy = np.sum(predictions == y_test_mnist) / x_test_mnist.shape[0]

      trainer.fit_generator(generator=generator, nb_epochs=20)
      predictions_new = np.argmax(trainer.predict(x_test_mnist), axis=1)
-     accuracy_new = np.sum(predictions_new == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
+
+     if label_format == "one_hot":
+         accuracy_new = np.sum(predictions_new == np.argmax(y_test_mnist, axis=1)) / x_test_mnist.shape[0]
+     else:
+         accuracy_new = np.sum(predictions_new == y_test_mnist) / x_test_mnist.shape[0]

      np.testing.assert_array_almost_equal(
          float(np.mean(x_test_mnist_original - x_test_mnist)),
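As a brief sketch (illustrative, not from the commit) of the parametrization added above: pytest runs the test body once per label_format value, and the numerical case collapses one-hot labels to class indices before they reach the trainer:

    import numpy as np
    import pytest

    @pytest.mark.parametrize("label_format", ["one_hot", "numerical"])
    def test_label_format_handling(label_format):
        y = np.eye(10)[[3, 1, 4]]          # one-hot labels, shape (3, 10)
        if label_format == "numerical":
            y = np.argmax(y, axis=1)       # index labels, shape (3,)
        assert y.shape[0] == 3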
