Skip to content

Added a format_keycloak flag to build_rs256_token helper #60

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/).

## Unreleased

## v2.1.1 - 2024-10-23

- Added `format_keycloak` to `build_rs256_token()` helper

## v2.1.0 - 2024-10-13

- Refactored pemission claim mapping
Expand Down
45 changes: 42 additions & 3 deletions armasec/pytest_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

from collections import namedtuple
from contextlib import _GeneratorContextManager, contextmanager
from datetime import datetime
from datetime import datetime, timezone
from typing import Callable, Optional
from uuid import uuid4

import httpx
import pytest
Expand Down Expand Up @@ -165,24 +166,62 @@ def build_rs256_token(rs256_private_key, rs256_iss, rs256_sub, rs256_kid):
base_claims = dict(
iss=rs256_iss,
sub=rs256_sub,
permissions=[],
)
base_headers = dict(kid=rs256_kid)

def _helper(
claim_overrides: Optional[dict] = None,
headers_overrides: Optional[dict] = None,
format_keycloak: bool = False,
):
"""
Encode a jwt token with the default claims and headers overridden with user supplied values.

Args:
claim_overrides: A dictionary of claims to add to the token.
Will override any existing values upon collision
header_overrides: A dictionary of headers to add to the token.
Will override any existing values upon collision
format_keycloak: If set, will remap "permissions" provided as a part of the
claim_overrides to the expected position for keycloak. If an "azp"
claim is not provided in the claim_overrides, it will generate
a random test client_id in the "azp" claim that matches the keycloak
structure. Example keycloak structure:

```
{
"exp": 1728627701,
"iat": 1728626801,
"jti": "24fdb7ef-d773-4e6b-982a-b8126dd58af7",
"sub": "dfa64115-40b5-46ab-924c-c376e73f631d",
"azp": "my-client",
"resource_access": {
"my-client": {
"roles": [
"read:stuff"
]
},
},
}
```

"""
if claim_overrides is None:
claim_overrides = dict()

if headers_overrides is None:
headers_overrides = dict()

now = int(datetime.utcnow().timestamp())
now = int(datetime.now(timezone.utc).timestamp())

if format_keycloak and "permissions" in claim_overrides:
test_client = claim_overrides.get("azp", f"test-client-{uuid4()}")
claim_overrides["azp"] = test_client
claim_overrides["resource_access"] = {
test_client: {
"roles": claim_overrides.pop("permissions"),
}
}

return jwt.encode(
{
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "armasec"
version = "2.1.0"
version = "2.1.1"
description = "Injectable FastAPI auth via OIDC"
authors = ["Omnivector Engineering Team <info@omnivector.solutions>"]
license = "MIT"
Expand Down
32 changes: 32 additions & 0 deletions tests/test_token_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,38 @@ def test_decode__success(rs256_jwk, build_rs256_token):
assert token_payload.original_token == token


def test_decode__with_token_built_with_format_keycloak(rs256_jwk, build_rs256_token):
"""
Verify that an RS256Decoder can successfully decode a valid jwt that was encoded with the
format_keycloak flag.
"""
decoder = TokenDecoder(
JWKs(keys=[rs256_jwk]),
permission_extractor=extract_keycloak_permissions,
)
token = build_rs256_token(
claim_overrides=dict(
permissions=["read:stuff", "write:stuff"],
),
format_keycloak=True,
)
token_payload = decoder.decode(token)
assert token_payload.permissions == ["read:stuff", "write:stuff"]
assert token_payload.client_id is not None
assert token_payload.client_id.startswith("test-client")

token = build_rs256_token(
claim_overrides=dict(
azp="my-client-id",
permissions=["read:stuff", "write:stuff"],
),
format_keycloak=True,
)
token_payload = decoder.decode(token)
assert token_payload.permissions == ["read:stuff", "write:stuff"]
assert token_payload.client_id == "my-client-id"


def test_decode__fails_when_jwt_decode_throws_an_error(rs256_jwk):
"""
This test verifies that the ``decode()`` raises an exception with a helpful message when it
Expand Down
Loading