diff --git a/cos_registration_server/api/schema_status.py b/cos_registration_server/api/schema_status.py index 9a58fd2..bd3572c 100644 --- a/cos_registration_server/api/schema_status.py +++ b/cos_registration_server/api/schema_status.py @@ -13,8 +13,28 @@ code_200_device = {200: DeviceSerializer} code_201_device = {201: DeviceSerializer} +code_200_device_certificate = { + 200: { + "type": "object", + "properties": { + "certificate": { + "type": "string", + "description": "PEM-encoded certificate", + }, + "private_key": { + "type": "string", + "description": "PEM-encoded private key", + }, + }, + } +} + code_404_uid_not_found = {404: OpenApiResponse(description="UID not found")} +code_404_device_certificate_not_found = { + 404: OpenApiResponse(description="Device certificate not found") +} + code_200_grafana_dashboard = {200: GrafanaDashboardSerializer} code_201_grafana_dashboard = {201: GrafanaDashboardSerializer} diff --git a/cos_registration_server/api/tests.py b/cos_registration_server/api/tests.py index ebabae3..ed324fd 100644 --- a/cos_registration_server/api/tests.py +++ b/cos_registration_server/api/tests.py @@ -1,6 +1,7 @@ import json from datetime import datetime, timedelta from typing import Any, Dict, Set, Union +from unittest.mock import Mock, patch import yaml from applications.models import ( @@ -523,6 +524,60 @@ def test_delete_device(self) -> None: self.assertEqual(response.status_code, 404) +class DeviceCertificateViewTests(APITestCase): + def setUp(self) -> None: + self.device_uid = "robot-123" + self.device_address = "192.168.0.10" + self.url = reverse( + "api:device-certificate", kwargs={"uid": self.device_uid} + ) + + def create_device(self, **fields: Union[str, Set[str]]) -> HttpResponse: + data = {} + for field, value in fields.items(): + data[field] = value + url = reverse("api:devices") + return self.client.post(url, data, format="json") + + def test_get_certificate_success(self) -> None: + self.create_device(uid=self.device_uid, address=self.device_address) + + response = self.client.get(self.url) + + self.assertEqual(response.status_code, 200) + data = json.loads(response.content) + self.assertIn("certificate", data) + self.assertIn("private_key", data) + + def test_get_certificate_device_not_found(self) -> None: + response = self.client.get(self.url) + + self.assertEqual(response.status_code, 404) + self.assertIn("Device does not exist", str(response.data)) + + def test_get_certificate_no_address(self) -> None: + self.create_device(uid=self.device_uid, address="") + + response = self.client.get(self.url) + self.assertEqual(response.status_code, 404) + self.assertIn("Device does not exist", str(response.data)) + + @patch("api.views.generate_tls_certificate") + def test_get_certificate_missing_cert_data( + self, mock_generate: Mock + ) -> None: + self.create_device(uid=self.device_uid, address=self.device_address) + + mock_generate.return_value = {"certificate": "", "private_key": ""} + + response = self.client.get(self.url) + + self.assertEqual(response.status_code, 404) + self.assertIn( + "Certificate data for device not found", str(response.data) + ) + + class GrafanaDashboardsViewTests(APITestCase): def setUp(self) -> None: self.url = reverse("api:grafana_dashboards") diff --git a/cos_registration_server/api/urls.py b/cos_registration_server/api/urls.py index eaa06f7..22b0c1b 100644 --- a/cos_registration_server/api/urls.py +++ b/cos_registration_server/api/urls.py @@ -18,6 +18,11 @@ path("v1/health/", views.HealthView.as_view(), name="health"), path("v1/devices/", views.DevicesView.as_view(), name="devices"), path("v1/devices//", views.DeviceView.as_view(), name="device"), + path( + "v1/devices//certificate", + views.DeviceCertificateView.as_view(), + name="device-certificate", + ), path( "v1/applications/grafana/dashboards/", views.GrafanaDashboardsView.as_view(), diff --git a/cos_registration_server/api/utils.py b/cos_registration_server/api/utils.py new file mode 100644 index 0000000..9f06bb9 --- /dev/null +++ b/cos_registration_server/api/utils.py @@ -0,0 +1,52 @@ +"""API utils.""" + +import ipaddress +from datetime import datetime, timedelta + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import NameOID + + +def generate_tls_certificate( + device_uid: str, device_ip: str +) -> dict[str, str]: + """Generate a self-signed TLS certificate with device IP in SAN. + + device_uid: the uid of the device. + device_ip: the ip of the device. + """ + private_key = rsa.generate_private_key( + public_exponent=65537, key_size=2048 + ) + + subject = issuer = x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, device_uid), + ] + ) + + san_list = [x509.IPAddress(ipaddress.ip_address(device_ip))] + + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.utcnow()) + .not_valid_after(datetime.utcnow() + timedelta(days=365)) + .add_extension(x509.SubjectAlternativeName(san_list), critical=False) + .sign(private_key, hashes.SHA256()) + ) + + private_key_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + + cert_pem = cert.public_bytes(encoding=serialization.Encoding.PEM).decode() + + return {"certificate": cert_pem, "private_key": private_key_pem} diff --git a/cos_registration_server/api/views.py b/cos_registration_server/api/views.py index 9ded956..9c8618a 100644 --- a/cos_registration_server/api/views.py +++ b/cos_registration_server/api/views.py @@ -38,6 +38,8 @@ from rest_framework.response import Response from rest_framework.views import APIView +from .utils import generate_tls_certificate + class HealthView(APIView): """Health API view.""" @@ -161,6 +163,50 @@ def delete( return super().delete(request, *args, **kwargs) +class DeviceCertificateView(APIView): + """Device Certificate API view.""" + + @extend_schema( + summary="Generate and get a device TLS certificate", + description="Generate and retrieve TLS certificate and " + "private key for a device by UID", + responses={ + **status.code_200_device_certificate, + **status.code_404_uid_not_found, + **status.code_404_device_certificate_not_found, + }, + ) + def get( + self, + request: Request, + uid: str, + *args: Tuple[Any], + **kwargs: Dict[str, Any], + ) -> Response: + """GET a device TLS certificate and private key.""" + try: + device = Device.objects.get(uid=uid) + except Device.DoesNotExist: + raise NotFound("Device does not exist") + + # If the device exists it will have an address, + # the serializer does not allow the creation of a device + # without an ip address. + cert_data = generate_tls_certificate(device.uid, device.address) + + if not cert_data.get("certificate") or not cert_data.get( + "private_key" + ): + raise NotFound("Certificate data for device not found") + + return Response( + { + "certificate": cert_data["certificate"], + "private_key": cert_data["private_key"], + } + ) + + class GrafanaDashboardsView(ListCreateAPIView): # type: ignore[type-arg] """GrafanaDashboards API view.""" diff --git a/cos_registration_server/openapi.yaml b/cos_registration_server/openapi.yaml index 39109e4..425364f 100644 --- a/cos_registration_server/openapi.yaml +++ b/cos_registration_server/openapi.yaml @@ -938,6 +938,38 @@ paths: description: '' '404': description: UID not found + /api/v1/devices/{uid}/certificate: + get: + operationId: devices_certificate_retrieve + description: Generate and retrieve TLS certificate and private key for a device + by UID + summary: Generate and get a device TLS certificate + parameters: + - in: path + name: uid + schema: + type: string + required: true + tags: + - devices + security: + - {} + responses: + '200': + content: + application/json: + schema: + type: object + properties: + certificate: + type: string + description: PEM-encoded certificate + private_key: + type: string + description: PEM-encoded private key + description: '' + '404': + description: Device certificate not found /api/v1/health/: get: operationId: health_retrieve diff --git a/requirements.txt b/requirements.txt index 47fd032..48d8202 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ whitenoise tzdata gunicorn pyyaml +cryptography