Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions cos_registration_server/api/schema_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,29 @@
code_200_device = {200: DeviceSerializer}
code_201_device = {201: DeviceSerializer}

code_200_device_certificate = {
200: {
"type": "object",
"properties": {
"certificate": {
"type": "string",
"description": "PEM-encoded certificate",
},
"private_key": {
"type": "string",
"description": "PEM-encoded private key",
},
},
}
}

code_404_uid_not_found = {404: OpenApiResponse(description="UID not found")}
code_404_device_address_not_found = {
404: OpenApiResponse(description="device address not found")
}
code_404_device_certificate_not_found = {
404: OpenApiResponse(description="Device certficate not found")
}

code_200_grafana_dashboard = {200: GrafanaDashboardSerializer}
code_201_grafana_dashboard = {201: GrafanaDashboardSerializer}
Expand Down
52 changes: 52 additions & 0 deletions cos_registration_server/api/tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from datetime import datetime, timedelta
from typing import Any, Dict, Set, Union
from unittest.mock import Mock, patch

import yaml
from applications.models import (
Expand Down Expand Up @@ -523,6 +524,57 @@ def test_delete_device(self) -> None:
self.assertEqual(response.status_code, 404)


class DeviceCertificateViewTests(APITestCase):
def setUp(self) -> None:
self.device_uid = "robot-123"
self.device_address = "192.168.0.10"
self.url = reverse(
"api:device-certificate", kwargs={"uid": self.device_uid}
)

def create_device(self, **fields: Union[str, Set[str]]) -> HttpResponse:
data = {}
for field, value in fields.items():
data[field] = value
url = reverse("api:devices")
return self.client.post(url, data, format="json")

def test_get_certificate_success(self) -> None:
self.create_device(uid=self.device_uid, address=self.device_address)

response = self.client.get(self.url)

self.assertEqual(response.status_code, 200)
data = json.loads(response.content)
self.assertIn("certificate", data)
self.assertIn("private_key", data)

def test_get_certificate_device_not_found(self) -> None:
response = self.client.get(self.url)

self.assertEqual(response.status_code, 404)

def test_get_certificate_no_address(self) -> None:
self.create_device(uid=self.device_uid, address="")
response = self.client.get(self.url)
self.assertEqual(response.status_code, 404)

@patch("api.views.generate_tls_certificate")
def test_get_certificate_missing_cert_data(
self, mock_generate: Mock
) -> None:
self.create_device(uid=self.device_uid, address=self.device_address)

mock_generate.return_value = {"certificate": "", "private_key": ""}

response = self.client.get(self.url)

self.assertEqual(response.status_code, 404)
self.assertIn(
"Certificate data for device not found", str(response.data)
)


class GrafanaDashboardsViewTests(APITestCase):
def setUp(self) -> None:
self.url = reverse("api:grafana_dashboards")
Expand Down
5 changes: 5 additions & 0 deletions cos_registration_server/api/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
path("v1/health/", views.HealthView.as_view(), name="health"),
path("v1/devices/", views.DevicesView.as_view(), name="devices"),
path("v1/devices/<str:uid>/", views.DeviceView.as_view(), name="device"),
path(
"v1/devices/<str:uid>/certificate",
views.DeviceCertificateView.as_view(),
name="device-certificate",
),
path(
"v1/applications/grafana/dashboards/",
views.GrafanaDashboardsView.as_view(),
Expand Down
52 changes: 52 additions & 0 deletions cos_registration_server/api/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""API utils."""

import ipaddress
from datetime import datetime, timedelta

from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID


def generate_tls_certificate(
device_uid: str, device_ip: str
) -> dict[str, str]:
"""Generate a self-signed TLS certificate with device IP in SAN.

device_uid: the uid of the device.
device_ip: the ip of the device.
"""
private_key = rsa.generate_private_key(
public_exponent=65537, key_size=2048
)

subject = issuer = x509.Name(
[
x509.NameAttribute(NameOID.COMMON_NAME, device_uid),
]
)

san_list = [x509.IPAddress(ipaddress.ip_address(device_ip))]

cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(issuer)
.public_key(private_key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.utcnow())
.not_valid_after(datetime.utcnow() + timedelta(days=365))
.add_extension(x509.SubjectAlternativeName(san_list), critical=False)
.sign(private_key, hashes.SHA256())
)

private_key_pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
).decode()

cert_pem = cert.public_bytes(encoding=serialization.Encoding.PEM).decode()

return {"certificate": cert_pem, "private_key": private_key_pem}
47 changes: 47 additions & 0 deletions cos_registration_server/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
from rest_framework.response import Response
from rest_framework.views import APIView

from .utils import generate_tls_certificate


class HealthView(APIView):
"""Health API view."""
Expand Down Expand Up @@ -161,6 +163,51 @@ def delete(
return super().delete(request, *args, **kwargs)


class DeviceCertificateView(APIView):
"""Device Certificate API view."""

@extend_schema(
summary="Get a device TLS certificate",
description="Retrieve TLS certificate and "
"private key for a device by UID",
responses={
**status.code_200_device_certificate,
**status.code_404_uid_not_found,
**status.code_404_device_address_not_found,
**status.code_404_device_certificate_not_found,
},
)
def get(
self,
request: Request,
uid: str,
*args: Tuple[Any],
**kwargs: Dict[str, Any],
) -> Response:
"""GET a device TLS certificate and private key."""
try:
device = Device.objects.get(uid=uid)
except Device.DoesNotExist:
raise NotFound("Device does not exist")

if not device.address:
raise NotFound("Device does not have an ip address")

cert_data = generate_tls_certificate(device.uid, device.address)

if not cert_data.get("certificate") or not cert_data.get(
"private_key"
):
raise NotFound("Certificate data for device not found")

return Response(
{
"certificate": cert_data["certificate"],
"private_key": cert_data["private_key"],
}
)


class GrafanaDashboardsView(ListCreateAPIView): # type: ignore[type-arg]
"""GrafanaDashboards API view."""

Expand Down
31 changes: 31 additions & 0 deletions cos_registration_server/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,37 @@ paths:
description: ''
'404':
description: UID not found
/api/v1/devices/{uid}/certificate:
get:
operationId: devices_certificate_retrieve
description: Retrieve TLS certificate and private key for a device by UID
summary: Get a device TLS certificate
parameters:
- in: path
name: uid
schema:
type: string
required: true
tags:
- devices
security:
- {}
responses:
'200':
content:
application/json:
schema:
type: object
properties:
certificate:
type: string
description: PEM-encoded certificate
private_key:
type: string
description: PEM-encoded private key
description: ''
'404':
description: Device certficate not found
/api/v1/health/:
get:
operationId: health_retrieve
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ whitenoise
tzdata
gunicorn
pyyaml
cryptography
Loading