Skip to content

Commit 8b8de3c

Browse files
pandafynemesifier
authored andcommitted
[fix] Fixed enforce_required_templates bug in change of device org #445
Closes #445
1 parent f1a12fd commit 8b8de3c

File tree

5 files changed

+109
-22
lines changed

5 files changed

+109
-22
lines changed

openwisp_controller/config/admin.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,14 @@ def clean_templates(self):
347347
reverse=False,
348348
model=config.templates.model,
349349
pk_set=templates,
350+
# The template validation retrieves the device object
351+
# from the database. Even if the organization of the device
352+
# is changed by the user, the validation uses the old
353+
# organization of the device because the device is not
354+
# saved yet. The raw POST data is passed here so that
355+
# validation can be performed using up to date data of
356+
# the device object.
357+
raw_data=self.data,
350358
)
351359
return templates
352360

openwisp_controller/config/api/serializers.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -201,22 +201,23 @@ class Meta(BaseMeta):
201201
]
202202

203203
def update(self, instance, validated_data):
204-
config_data = None
205-
204+
config_data = validated_data.pop('config', {})
205+
config_templates = [
206+
template.pk for template in config_data.get('templates', [])
207+
]
208+
raw_data_for_signal_handlers = {
209+
'organization': validated_data.get('organization', instance.organization)
210+
}
206211
if self.initial_data.get('config.backend') and instance._has_config() is False:
207-
config_data = dict(validated_data.pop('config'))
208-
config_templates = [
209-
template.pk for template in config_data.pop('templates')
210-
]
212+
config_data = dict(config_data)
211213
with transaction.atomic():
212-
config = Config.objects.create(device=instance, **config_data)
214+
config = Config(device=instance, **config_data)
213215
config.templates.add(*config_templates)
214216
config.full_clean()
215217
config.save()
216218
return super().update(instance, validated_data)
217219

218-
if validated_data.get('config'):
219-
config_data = validated_data.pop('config')
220+
if config_data:
220221
instance.config.backend = config_data.get(
221222
'backend', instance.config.backend
222223
)
@@ -227,9 +228,7 @@ def update(self, instance, validated_data):
227228

228229
if 'templates' in config_data:
229230
if config_data.get('templates'):
230-
new_config_templates = [
231-
template.pk for template in config_data.get('templates')
232-
]
231+
new_config_templates = config_templates
233232
old_config_templates = [
234233
template
235234
for template in instance.config.templates.values_list(
@@ -256,8 +255,23 @@ def update(self, instance, validated_data):
256255
instance.config.templates.clear()
257256
instance.config.templates.add(*[])
258257

259-
instance.config.full_clean()
260-
instance.config.save()
258+
elif hasattr(instance, 'config') and validated_data.get('organization'):
259+
if instance.organization != validated_data.get('organization'):
260+
# config.device.organization is used for validating
261+
# the organization of templates. It is also used for adding
262+
# default and required templates configured for an organization.
263+
# The value of the organization field is set here to
264+
# prevent access of the old value stored in the database
265+
# while performing above operations.
266+
instance.config.device.organization = validated_data.get('organization')
267+
instance.config.templates.clear()
268+
Config.enforce_required_templates(
269+
action='post_clear',
270+
instance=instance.config,
271+
sender=instance.config.templates,
272+
pk_set=None,
273+
raw_data=raw_data_for_signal_handlers,
274+
)
261275
return super().update(instance, validated_data)
262276

263277

openwisp_controller/config/base/config.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -163,15 +163,21 @@ def _get_templates_from_pk_set(cls, pk_set):
163163
return templates
164164

165165
@classmethod
166-
def clean_templates(cls, action, instance, pk_set, **kwargs):
166+
def clean_templates(cls, action, instance, pk_set, raw_data=None, **kwargs):
167167
"""
168168
validates resulting configuration of config + templates
169169
raises a ValidationError if invalid
170170
must be called from forms or APIs
171171
this method is called from a django signal (m2m_changed)
172172
see config.apps.ConfigConfig.connect_signals
173+
174+
raw_data contains the non-validated data that is submitted to
175+
a form or API.
173176
"""
174-
templates = cls.clean_templates_org(action, instance, pk_set, **kwargs)
177+
raw_data = raw_data or {}
178+
templates = cls.clean_templates_org(
179+
action, instance, pk_set, raw_data=raw_data, **kwargs
180+
)
175181
if not templates:
176182
return
177183
backend = instance.get_backend_instance(template_instances=templates)
@@ -254,19 +260,25 @@ def manage_vpn_clients(cls, action, instance, pk_set, **kwargs):
254260
client.delete()
255261

256262
@classmethod
257-
def clean_templates_org(cls, action, instance, pk_set, **kwargs):
263+
def clean_templates_org(cls, action, instance, pk_set, raw_data=None, **kwargs):
264+
"""
265+
raw_data contains the non-validated data that is submitted to
266+
a form or API.
267+
"""
258268
if action != 'pre_add':
259269
return False
270+
raw_data = raw_data or {}
260271
templates = cls._get_templates_from_pk_set(pk_set)
261272
# when using the admin, templates will be a list
262273
# we need to get the queryset from this list in order to proceed
263274
if not isinstance(templates, models.QuerySet):
264275
template_model = cls.templates.rel.model
265276
pk_list = [template.pk for template in templates]
266277
templates = template_model.objects.filter(pk__in=pk_list)
267-
# lookg for invalid templates
278+
# looking for invalid templates
279+
organization = raw_data.get('organization', instance.device.organization)
268280
invalids = (
269-
templates.exclude(organization=instance.device.organization)
281+
templates.exclude(organization=organization)
270282
.exclude(organization=None)
271283
.values('name')
272284
)
@@ -287,17 +299,23 @@ def clean_templates_org(cls, action, instance, pk_set, **kwargs):
287299
return templates
288300

289301
@classmethod
290-
def enforce_required_templates(cls, action, instance, pk_set, **kwargs):
302+
def enforce_required_templates(
303+
cls, action, instance, pk_set, raw_data=None, **kwargs
304+
):
291305
"""
292306
This method is called from a django signal (m2m_changed),
293307
see config.apps.ConfigConfig.connect_signals.
294308
It raises a PermissionDenied if a required template
295309
is unassigned from a config.
296310
It adds back required templates on post_clear events
297311
(post-clear is used by sortedm2m to assign templates).
312+
313+
raw_data contains the non-validated data that is submitted to
314+
a form or API.
298315
"""
299316
if action not in ['pre_remove', 'post_clear']:
300317
return False
318+
raw_data = raw_data or {}
301319
template_query = models.Q(required=True, backend=instance.backend)
302320
# trying to remove a required template will raise PermissionDenied
303321
if action == 'pre_remove':
@@ -309,12 +327,12 @@ def enforce_required_templates(cls, action, instance, pk_set, **kwargs):
309327
if action == 'post_clear':
310328
# retrieve required templates related to this
311329
# device and ensure they're always present
330+
organization = raw_data.get('organization', instance.device.organization)
312331
required_templates = (
313332
cls.get_template_model()
314333
.objects.filter(template_query)
315334
.filter(
316-
models.Q(organization=instance.device.organization)
317-
| models.Q(organization=None)
335+
models.Q(organization=organization) | models.Q(organization=None)
318336
)
319337
)
320338
if required_templates.exists():

openwisp_controller/config/tests/test_admin.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,29 @@ def test_change_device_required_template(self):
551551
self.assertEqual(c.name, 'test-device-templates-cleared')
552552
self.assertTrue(c.templates.filter(pk=t.pk).exists())
553553

554+
def test_change_device_org_required_templates(self):
555+
org1 = self._create_org(name='org1')
556+
org2 = self._create_org(name='org2')
557+
template = self._create_template(organization=org1, config={'interfaces': []})
558+
device = self._create_device(organization=org1)
559+
config = self._create_config(device=device)
560+
path = reverse(f'admin:{self.app_label}_device_change', args=[device.pk])
561+
params = self._get_device_params(org=org1)
562+
params.update(
563+
{
564+
'name': 'test-device-changed',
565+
'config-0-id': str(config.pk),
566+
'config-0-device': str(device.pk),
567+
'config-0-templates': str(template.pk),
568+
'config-INITIAL_FORMS': 1,
569+
'organization': str(org2.pk),
570+
}
571+
)
572+
response = self.client.post(path, params, follow=True)
573+
self.assertEqual(response.status_code, 200)
574+
config.refresh_from_db()
575+
self.assertEqual(config.templates.count(), 0)
576+
554577
def test_download_device_config(self):
555578
d = self._create_device(name='download')
556579
self._create_config(device=d)

openwisp_controller/config/tests/test_api.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,30 @@ def test_device_patch_with_templates_of_different_org(self):
265265
'''
266266
self.assertTrue(' '.join(validation_msg.split()) in error.exception.message)
267267

268+
def test_device_change_organization_required_templates(self):
269+
org1 = self._create_org(name='org1')
270+
org2 = self._create_org(name='org2')
271+
org1_template = self._create_template(
272+
name='org1-template', organization=org1, required=True
273+
)
274+
org2_template = self._create_template(
275+
name='org2-template', organization=org2, required=True
276+
)
277+
device = self._create_device(organization=org1)
278+
config = self._create_config(device=device)
279+
self.assertEqual(config.templates.count(), 1)
280+
self.assertEqual(config.templates.first(), org1_template)
281+
282+
path = reverse('config_api:device_detail', args=[device.pk])
283+
data = {'organization': org2.pk}
284+
response = self.client.patch(path, data=data, content_type='application/json')
285+
self.assertEqual(response.status_code, 200)
286+
device.refresh_from_db()
287+
config.refresh_from_db()
288+
self.assertEqual(device.organization, org2)
289+
self.assertEqual(config.templates.count(), 1)
290+
self.assertEqual(config.templates.first(), org2_template)
291+
268292
def test_device_patch_api(self):
269293
d1 = self._create_device(name='test-device')
270294
path = reverse('config_api:device_detail', args=[d1.pk])

0 commit comments

Comments
 (0)