Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/labeler.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
uses: actions/checkout@v4
-
name: Run Labeler
uses: crazy-max/ghaction-github-labeler@31674a3852a9074f2086abcf1c53839d466a47e7
uses: crazy-max/ghaction-github-labeler@24d110aa46a59976b8a7f35518cb7f14f434c916
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
yaml-file: .github/labels.yml
Expand Down
4 changes: 2 additions & 2 deletions linode_api4/objects/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ def entity(self):
)
return self.cls(self._client, self.id)

def _serialize(self):
def _serialize(self, *args, **kwargs):
"""
Returns this grant in as JSON the api will accept. This is only relevant
in the context of UserGrants.save
Expand Down Expand Up @@ -668,7 +668,7 @@ def _grants_dict(self):

return grants

def _serialize(self):
def _serialize(self, *args, **kwargs):
"""
Returns the user grants in as JSON the api will accept.
This is only relevant in the context of UserGrants.save
Expand Down
26 changes: 16 additions & 10 deletions linode_api4/objects/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ def _flatten_base_subclass(obj: "Base") -> Optional[Dict[str, Any]]:

@property
def dict(self):
return self._serialize()

def _serialize(self, is_put: bool = False) -> Dict[str, Any]:
result = vars(self).copy()
cls = type(self)

Expand All @@ -123,7 +126,7 @@ def dict(self):
elif isinstance(v, list):
result[k] = [
(
item.dict
item._serialize(is_put=is_put)
if isinstance(item, (cls, JSONObject))
else (
self._flatten_base_subclass(item)
Expand All @@ -136,7 +139,7 @@ def dict(self):
elif isinstance(v, Base):
result[k] = self._flatten_base_subclass(v)
elif isinstance(v, JSONObject):
result[k] = v.dict
result[k] = v._serialize(is_put=is_put)

return result

Expand Down Expand Up @@ -278,9 +281,9 @@ def save(self, force=True) -> bool:
data[key] = None

# Ensure we serialize any values that may not be already serialized
data = _flatten_request_body_recursive(data)
data = _flatten_request_body_recursive(data, is_put=True)
else:
data = self._serialize()
data = self._serialize(is_put=True)

resp = self._client.put(type(self).api_endpoint, model=self, data=data)

Expand Down Expand Up @@ -316,7 +319,7 @@ def invalidate(self):

self._set("_populated", False)

def _serialize(self):
def _serialize(self, is_put: bool = False):
"""
A helper method to build a dict of all mutable Properties of
this object
Expand Down Expand Up @@ -345,7 +348,7 @@ def _serialize(self):

# Resolve the underlying IDs of results
for k, v in result.items():
result[k] = _flatten_request_body_recursive(v)
result[k] = _flatten_request_body_recursive(v, is_put=is_put)

return result

Expand Down Expand Up @@ -503,7 +506,7 @@ def make_instance(cls, id, client, parent_id=None, json=None):
return Base.make(id, client, cls, parent_id=parent_id, json=json)


def _flatten_request_body_recursive(data: Any) -> Any:
def _flatten_request_body_recursive(data: Any, is_put: bool = False) -> Any:
"""
This is a helper recursively flatten the given data for use in an API request body.

Expand All @@ -515,15 +518,18 @@ def _flatten_request_body_recursive(data: Any) -> Any:
"""

if isinstance(data, dict):
return {k: _flatten_request_body_recursive(v) for k, v in data.items()}
return {
k: _flatten_request_body_recursive(v, is_put=is_put)
for k, v in data.items()
}

if isinstance(data, list):
return [_flatten_request_body_recursive(v) for v in data]
return [_flatten_request_body_recursive(v, is_put=is_put) for v in data]

if isinstance(data, Base):
return data.id

if isinstance(data, MappedObject) or issubclass(type(data), JSONObject):
return data.dict
return data._serialize(is_put=is_put)

return data
12 changes: 6 additions & 6 deletions linode_api4/objects/linode.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ class ConfigInterface(JSONObject):
def __repr__(self):
return f"Interface: {self.purpose}"

def _serialize(self):
def _serialize(self, *args, **kwargs):
purpose_formats = {
"public": {"purpose": "public", "primary": self.primary},
"vlan": {
Expand Down Expand Up @@ -510,16 +510,16 @@ def _populate(self, json):

self._set("devices", MappedObject(**devices))

def _serialize(self):
def _serialize(self, is_put: bool = False):
"""
Overrides _serialize to transform interfaces into json
"""
partial = DerivedBase._serialize(self)
partial = DerivedBase._serialize(self, is_put=is_put)
interfaces = []

for c in self.interfaces:
if isinstance(c, ConfigInterface):
interfaces.append(c._serialize())
interfaces.append(c._serialize(is_put=is_put))
else:
interfaces.append(c)

Expand Down Expand Up @@ -1927,8 +1927,8 @@ def _populate(self, json):
ndist = [Image(self._client, d) for d in self.images]
self._set("images", ndist)

def _serialize(self):
dct = Base._serialize(self)
def _serialize(self, is_put: bool = False):
dct = Base._serialize(self, is_put=is_put)
dct["images"] = [d.id for d in self.images]
return dct

Expand Down
24 changes: 21 additions & 3 deletions linode_api4/objects/serializable.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import inspect
from dataclasses import dataclass
from dataclasses import dataclass, fields
from enum import Enum
from types import SimpleNamespace
from typing import (
Expand All @@ -9,6 +9,7 @@
List,
Optional,
Set,
Type,
Union,
get_args,
get_origin,
Expand Down Expand Up @@ -71,6 +72,13 @@ class JSONObject(metaclass=JSONFilterableMetaclass):
are None.
"""

put_class: ClassVar[Optional[Type["JSONObject"]]] = None
"""
An alternative JSONObject class to use as the schema for PUT requests.
This prevents read-only fields from being included in PUT request bodies,
which in theory will result in validation errors from the API.
"""

def __init__(self):
raise NotImplementedError(
"JSONObject is not intended to be constructed directly"
Expand Down Expand Up @@ -154,19 +162,25 @@ def from_json(cls, json: Dict[str, Any]) -> Optional["JSONObject"]:

return obj

def _serialize(self) -> Dict[str, Any]:
def _serialize(self, is_put: bool = False) -> Dict[str, Any]:
"""
Serializes this object into a JSON dict.
"""
cls = type(self)

if is_put and cls.put_class is not None:
cls = cls.put_class

cls_field_keys = {field.name for field in fields(cls)}

type_hints = get_type_hints(cls)

def attempt_serialize(value: Any) -> Any:
"""
Attempts to serialize the given value, else returns the value unchanged.
"""
if issubclass(type(value), JSONObject):
return value._serialize()
return value._serialize(is_put=is_put)

return value

Expand All @@ -175,6 +189,10 @@ def should_include(key: str, value: Any) -> bool:
Returns whether the given key/value pair should be included in the resulting dict.
"""

# During PUT operations, keys not present in the put_class should be excluded
if key not in cls_field_keys:
return False

if cls.include_none_values or key in cls.always_include:
return True

Expand Down
1 change: 1 addition & 0 deletions test/integration/login_client/test_login_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def test_linode_login_client_generate_login_url_with_scope(linode_login_client):
assert "scopes=linodes%3Aread_write" in url


@pytest.mark.skip("Endpoint may be deprecated")
def test_linode_login_client_expire_token(
linode_login_client, test_oauth_client
):
Expand Down
2 changes: 0 additions & 2 deletions test/integration/models/domain/test_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ def test_save_null_values_excluded(test_linode_client, test_domain):
domain.master_ips = ["127.0.0.1"]
res = domain.save()

assert res


def test_zone_file_view(test_linode_client, test_domain):
domain = test_linode_client.load(Domain, test_domain.id)
Expand Down
12 changes: 6 additions & 6 deletions test/integration/models/linode/test_linode.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import pytest

from linode_api4 import VPCIPAddress
from linode_api4.errors import ApiError
from linode_api4.objects import (
Config,
Expand Down Expand Up @@ -181,7 +180,7 @@ def create_linode_for_long_running_tests(test_linode_client, e2e_test_firewall):
def linode_with_disk_encryption(test_linode_client, request):
client = test_linode_client

target_region = get_region(client, {"Disk Encryption"})
target_region = get_region(client, {"LA Disk Encryption"})
label = get_test_label(length=8)

disk_encryption = request.param
Expand Down Expand Up @@ -236,7 +235,7 @@ def test_linode_transfer(test_linode_client, linode_with_volume_firewall):
def test_linode_rebuild(test_linode_client):
client = test_linode_client

region = get_region(client, {"Disk Encryption"})
region = get_region(client, {"LA Disk Encryption"})

label = get_test_label() + "_rebuild"

Expand Down Expand Up @@ -535,6 +534,7 @@ def test_linode_create_disk(test_linode_client, linode_for_disk_tests):
assert disk.linode_id == linode.id


@pytest.mark.flaky(reruns=3, reruns_delay=2)
def test_linode_instance_password(create_linode_for_pass_reset):
linode = create_linode_for_pass_reset[0]
password = create_linode_for_pass_reset[1]
Expand Down Expand Up @@ -775,10 +775,10 @@ def test_create_vpc(
assert vpc_range_ip.address_range == "10.0.0.5/32"
assert not vpc_range_ip.active

# TODO:: Add `VPCIPAddress.filters.linode_id == linode.id` filter back

# Attempt to resolve the IP from /vpcs/ips
all_vpc_ips = test_linode_client.vpcs.ips(
VPCIPAddress.filters.linode_id == linode.id
)
all_vpc_ips = test_linode_client.vpcs.ips()
assert all_vpc_ips[0].dict == vpc_ip.dict

# Test getting the ips under this specific VPC
Expand Down
12 changes: 8 additions & 4 deletions test/integration/models/lke/test_lke.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def lke_cluster(test_linode_client):
node_type = test_linode_client.linode.types()[1] # g6-standard-1
version = test_linode_client.lke.versions()[0]

region = get_region(test_linode_client, {"Kubernetes", "Disk Encryption"})
region = get_region(
test_linode_client, {"Kubernetes", "LA Disk Encryption"}
)

node_pools = test_linode_client.lke.node_pool(node_type, 3)
label = get_test_label() + "_cluster"
Expand Down Expand Up @@ -115,7 +117,9 @@ def lke_cluster_with_labels_and_taints(test_linode_client):
def lke_cluster_with_apl(test_linode_client):
version = test_linode_client.lke.versions()[0]

region = get_region(test_linode_client, {"Kubernetes", "Disk Encryption"})
region = get_region(
test_linode_client, {"Kubernetes", "LA Disk Encryption"}
)

# NOTE: g6-dedicated-4 is the minimum APL-compatible Linode type
node_pools = test_linode_client.lke.node_pool("g6-dedicated-4", 3)
Expand Down Expand Up @@ -145,7 +149,7 @@ def lke_cluster_enterprise(test_linode_client):
)[0]

region = get_region(
test_linode_client, {"Kubernetes Enterprise", "Disk Encryption"}
test_linode_client, {"Kubernetes Enterprise", "LA Disk Encryption"}
)

node_pools = test_linode_client.lke.node_pool(
Expand Down Expand Up @@ -204,7 +208,7 @@ def _to_comparable(p: LKENodePool) -> Dict[str, Any]:

assert _to_comparable(cluster.pools[0]) == _to_comparable(pool)

assert pool.disk_encryption == InstanceDiskEncryptionType.enabled
assert pool.disk_encryption == InstanceDiskEncryptionType.disabled


def test_cluster_dashboard_url_view(lke_cluster):
Expand Down
55 changes: 54 additions & 1 deletion test/unit/objects/serializable_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from test.unit.base import ClientBaseCase
from typing import Optional

from linode_api4 import JSONObject
from linode_api4 import Base, JSONObject, Property


class JSONObjectTest(ClientBaseCase):
Expand Down Expand Up @@ -47,3 +47,56 @@ class Foo(JSONObject):
assert foo["foo"] == "test"
assert foo["bar"] == "test2"
assert foo["baz"] == "test3"

def test_serialize_put_class(self):
"""
Ensures that the JSONObject put_class ClassVar functions as expected.
"""

@dataclass
class SubStructOptions(JSONObject):
test1: Optional[str] = None

@dataclass
class SubStruct(JSONObject):
put_class = SubStructOptions

test1: str = ""
test2: int = 0

class Model(Base):
api_endpoint = "/foo/bar"

properties = {
"id": Property(identifier=True),
"substruct": Property(mutable=True, json_object=SubStruct),
}

mock_response = {
"id": 123,
"substruct": {
"test1": "abc",
"test2": 321,
},
}

with self.mock_get(mock_response) as mock:
obj = self.client.load(Model, 123)

assert mock.called

assert obj.id == 123
assert obj.substruct.test1 == "abc"
assert obj.substruct.test2 == 321

obj.substruct.test1 = "cba"

with self.mock_put(mock_response) as mock:
obj.save()

assert mock.called
assert mock.call_data == {
"substruct": {
"test1": "cba",
}
}
Loading