Skip to content

Commit 1c42af1

Browse files
pb8ozulinx86
authored andcommitted
test: revamp snapshot support to Microvm
Add some abstractions and functions around the snapshotting process. This is something that otherwise gets repeated over and over in integration tests, and makes debugging tests hard. Cleaning up existing tests is left over to the next changes in the series. Signed-off-by: Pablo Barbáchano <pablob@amazon.com>
1 parent 9515f0e commit 1c42af1

File tree

8 files changed

+280
-355
lines changed

8 files changed

+280
-355
lines changed

tests/framework/microvm.py

Lines changed: 174 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
3+
34
"""Classes for working with microVMs.
45
56
This module defines `Microvm`, which can be used to create, test drive, and
@@ -8,6 +9,8 @@
89
- Use the Firecracker Open API spec to populate Microvm API resource URLs.
910
"""
1011

12+
# pylint:disable=too-many-lines
13+
1114
import json
1215
import logging
1316
import os
@@ -18,6 +21,8 @@
1821
import uuid
1922
import weakref
2023
from collections import namedtuple
24+
from dataclasses import dataclass
25+
from enum import Enum
2126
from functools import lru_cache
2227
from pathlib import Path
2328
from threading import Lock
@@ -59,6 +64,96 @@
5964
data_lock = Lock()
6065

6166

67+
class SnapshotType(Enum):
68+
"""Supported snapshot types."""
69+
70+
FULL = "FULL"
71+
DIFF = "DIFF"
72+
73+
def __repr__(self):
74+
cls_name = self.__class__.__name__
75+
return f"{cls_name}.{self.name}"
76+
77+
78+
def hardlink_or_copy(src, dst):
79+
"""If src and dst are in the same device, hardlink. Otherwise, copy."""
80+
dst.touch(exist_ok=False)
81+
if dst.stat().st_dev == src.stat().st_dev:
82+
dst.unlink()
83+
dst.hardlink_to(src)
84+
else:
85+
shutil.copyfile(src, dst)
86+
87+
88+
@dataclass(frozen=True, repr=True)
89+
class Snapshot:
90+
"""A Firecracker snapshot"""
91+
92+
vmstate: Path
93+
mem: Path
94+
net_ifaces: list
95+
disks: dict
96+
ssh_key: Path
97+
snapshot_type: str
98+
99+
@property
100+
def is_diff(self) -> bool:
101+
"""Is this a DIFF snapshot?"""
102+
return self.snapshot_type == SnapshotType.DIFF
103+
104+
def rebase_snapshot(self, base):
105+
"""Rebases current incremental snapshot onto a specified base layer."""
106+
if not self.is_diff:
107+
raise ValueError("Can only rebase DIFF snapshots")
108+
build_tools.run_rebase_snap_bin(base.mem, self.mem)
109+
new_args = self.__dict__ | {"mem": base.mem}
110+
return Snapshot(**new_args)
111+
112+
@classmethod
113+
# TBD when Python 3.11: -> Self
114+
def load_from(cls, src: Path) -> "Snapshot":
115+
"""Load a snapshot saved with `save_to`"""
116+
snap_json = src / "snapshot.json"
117+
obj = json.loads(snap_json.read_text())
118+
return cls(
119+
vmstate=src / obj["vmstate"],
120+
mem=src / obj["mem"],
121+
net_ifaces=[NetIfaceConfig(**d) for d in obj["net_ifaces"]],
122+
disks={dsk: src / p for dsk, p in obj["disks"].items()},
123+
ssh_key=src / obj["ssh_key"],
124+
snapshot_type=obj["snapshot_type"],
125+
)
126+
127+
def save_to(self, dst: Path):
128+
"""Serialize snapshot details to `dst`
129+
130+
Deserialize the snapshot with `load_from`
131+
"""
132+
for path in [self.vmstate, self.mem, self.ssh_key]:
133+
new_path = dst / path.name
134+
hardlink_or_copy(path, new_path)
135+
new_disks = {}
136+
for disk_id, path in self.disks.items():
137+
new_path = dst / path.name
138+
hardlink_or_copy(path, new_path)
139+
new_disks[disk_id] = new_path.name
140+
obj = {
141+
"vmstate": self.vmstate.name,
142+
"mem": self.mem.name,
143+
"net_ifaces": [x.__dict__ for x in self.net_ifaces],
144+
"disks": new_disks,
145+
"ssh_key": self.ssh_key.name,
146+
"snapshot_type": self.snapshot_type,
147+
}
148+
snap_json = dst / "snapshot.json"
149+
snap_json.write_text(json.dumps(obj))
150+
151+
def delete(self):
152+
"""Delete the backing files from disk."""
153+
self.mem.unlink()
154+
self.vmstate.unlink()
155+
156+
62157
# pylint: disable=R0904
63158
class Microvm:
64159
"""Class to represent a Firecracker microvm.
@@ -82,7 +177,6 @@ def __init__(
82177
monitor_memory=True,
83178
bin_cloner_path=None,
84179
):
85-
# pylint: disable=too-many-statements
86180
"""Set up microVM attributes, paths, and data structures."""
87181
# pylint: disable=too-many-statements
88182
# Unique identifier for this machine.
@@ -750,60 +844,114 @@ def start(self, check=True):
750844
except KeyError:
751845
assert self.started is True
752846

847+
def pause(self):
848+
"""Pauses the microVM"""
849+
response = self.vm.patch(state="Paused")
850+
assert self.api_session.is_status_no_content(response.status_code)
851+
852+
def resume(self):
853+
"""Resume the microVM"""
854+
response = self.vm.patch(state="Resumed")
855+
assert self.api_session.is_status_no_content(response.status_code)
856+
753857
def pause_to_snapshot(
754-
self, mem_file_path=None, snapshot_path=None, diff=False, version=None
858+
self,
859+
mem_file_path,
860+
snapshot_path,
861+
diff=False,
862+
version=None,
755863
):
756864
"""Pauses the microVM, and creates snapshot.
757865
758866
This function validates that the microVM pauses successfully and
759867
creates a snapshot.
760868
"""
761-
assert mem_file_path is not None, "Please specify mem_file_path."
762-
assert snapshot_path is not None, "Please specify snapshot_path."
763-
764-
response = self.vm.patch(state="Paused")
765-
assert self.api_session.is_status_no_content(response.status_code)
869+
self.pause()
766870

767871
response = self.snapshot.create(
768-
mem_file_path=mem_file_path,
769-
snapshot_path=snapshot_path,
872+
mem_file_path=str(mem_file_path),
873+
snapshot_path=str(snapshot_path),
770874
diff=diff,
771875
version=version,
772876
)
773877
assert self.api_session.is_status_no_content(
774878
response.status_code
775879
), response.text
776880

881+
def make_snapshot(self, snapshot_type: str, target_version: str = None):
882+
"""Create a Snapshot object from a microvm."""
883+
vmstate_path = "vmstate"
884+
mem_path = "mem"
885+
self.pause_to_snapshot(
886+
mem_file_path=mem_path,
887+
snapshot_path=vmstate_path,
888+
diff=snapshot_type == "DIFF",
889+
version=target_version,
890+
)
891+
root = Path(self.chroot())
892+
return Snapshot(
893+
vmstate=root / vmstate_path,
894+
mem=root / mem_path,
895+
disks=self.disks,
896+
net_ifaces=[x["iface"] for ifname, x in self.iface.items()],
897+
ssh_key=self.ssh_key,
898+
snapshot_type=snapshot_type,
899+
)
900+
901+
def snapshot_diff(self, target_version: str = None):
902+
"""Make a DIFF snapshot"""
903+
return self.make_snapshot("DIFF", target_version)
904+
905+
def snapshot_full(self, target_version: str = None):
906+
"""Make a FULL snapshot"""
907+
return self.make_snapshot("FULL", target_version)
908+
777909
def restore_from_snapshot(
778910
self,
779-
*,
780-
snapshot_mem: Path,
781-
snapshot_vmstate: Path,
782-
snapshot_disks: list[Path],
783-
snapshot_is_diff: bool = False,
911+
snapshot: Snapshot,
912+
resume: bool = False,
913+
uffd_path: Path = None,
784914
):
785-
"""
786-
Restores a snapshot, and resumes the microvm
787-
"""
788-
789-
# Hardlink all the snapshot files into the microvm jail.
790-
jailed_mem = self.create_jailed_resource(snapshot_mem)
791-
jailed_vmstate = self.create_jailed_resource(snapshot_vmstate)
792-
915+
"""Restore a snapshot"""
916+
# Move all the snapshot files into the microvm jail.
917+
# Use different names so a snapshot doesn't overwrite our original snapshot.
918+
chroot = Path(self.chroot())
919+
mem_src = chroot / snapshot.mem.with_suffix(".src").name
920+
hardlink_or_copy(snapshot.mem, mem_src)
921+
vmstate_src = chroot / snapshot.vmstate.with_suffix(".src").name
922+
hardlink_or_copy(snapshot.vmstate, vmstate_src)
923+
jailed_mem = Path("/") / mem_src.name
924+
jailed_vmstate = Path("/") / vmstate_src.name
925+
926+
snapshot_disks = [v for k, v in snapshot.disks.items()]
793927
assert len(snapshot_disks) > 0, "Snapshot requires at least one disk."
794928
jailed_disks = []
795929
for disk in snapshot_disks:
796930
jailed_disks.append(self.create_jailed_resource(disk))
931+
self.disks = snapshot.disks
932+
self.ssh_key = snapshot.ssh_key
933+
934+
# Create network interfaces.
935+
for iface in snapshot.net_ifaces:
936+
self.add_net_iface(iface, api=False)
937+
938+
mem_backend = {"type": "File", "path": str(jailed_mem)}
939+
if uffd_path is not None:
940+
mem_backend = {"type": "Uffd", "path": str(uffd_path)}
797941

798942
response = self.snapshot.load(
799-
mem_file_path=jailed_mem,
800-
snapshot_path=jailed_vmstate,
801-
diff=snapshot_is_diff,
802-
resume=True,
943+
mem_backend=mem_backend,
944+
snapshot_path=str(jailed_vmstate),
945+
diff=snapshot.is_diff,
946+
resume=resume,
803947
)
804948
assert response.ok, response.content
805949
return True
806950

951+
def restore_from_path(self, snap_dir: Path, **kwargs):
952+
"""Restore snapshot from a path"""
953+
return self.restore_from_snapshot(Snapshot.load_from(snap_dir), **kwargs)
954+
807955
@lru_cache
808956
def ssh_iface(self, iface_idx=0):
809957
"""Return a cached SSH connection on a given interface id."""

tests/framework/resources.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,9 +432,12 @@ def create_json(
432432
"snapshot_path": snapshot_path,
433433
}
434434
else:
435+
backend_type = mem_backend["type"]
436+
if not isinstance(backend_type, str):
437+
backend_type = backend_type.value
435438
datax = {
436439
"mem_backend": {
437-
"backend_type": str(mem_backend["type"].value),
440+
"backend_type": backend_type,
438441
"backend_path": mem_backend["path"],
439442
},
440443
"snapshot_path": snapshot_path,

tests/integration_tests/functional/test_balloon.py

Lines changed: 8 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import pytest
99
from retry import retry
1010

11-
from framework.builder import MicrovmBuilder, SnapshotBuilder, SnapshotType
1211
from framework.utils import get_free_mem_ssh, run_cmd
1312

1413
STATS_POLLING_INTERVAL_S = 1
@@ -431,22 +430,17 @@ def test_stats_update(test_microvm_with_api):
431430
assert next_stats["available_memory"] != final_stats["available_memory"]
432431

433432

434-
def test_balloon_snapshot(bin_cloner_path, microvm_factory, guest_kernel, rootfs):
433+
def test_balloon_snapshot(microvm_factory, guest_kernel, rootfs):
435434
"""
436435
Test that the balloon works after pause/resume.
437436
"""
438-
logger = logging.getLogger("snapshot_sequence")
439-
snapshot_type = SnapshotType.FULL
440-
diff_snapshots = snapshot_type == SnapshotType.DIFF
441-
442437
vm = microvm_factory.build(guest_kernel, rootfs)
443438
vm.spawn()
444439
vm.basic_config(
445440
vcpu_count=2,
446441
mem_size_mib=256,
447-
track_dirty_pages=diff_snapshots,
448442
)
449-
iface = vm.add_net_iface()
443+
vm.add_net_iface()
450444

451445
# Add a memory balloon with stats enabled.
452446
response = vm.balloon.put(
@@ -479,21 +473,12 @@ def test_balloon_snapshot(bin_cloner_path, microvm_factory, guest_kernel, rootfs
479473
# We only test that the reduction happens.
480474
assert first_reading > second_reading
481475

482-
logger.info("Create %s #0.", snapshot_type)
483-
# Create a snapshot builder from a microvm.
484-
snapshot_builder = SnapshotBuilder(vm)
485-
disks = [vm.rootfs_file]
486-
# Create base snapshot.
487-
snapshot = snapshot_builder.create(
488-
disks, rootfs.ssh_key(), snapshot_type, net_ifaces=[iface]
489-
)
490-
vm.kill()
476+
snapshot = vm.snapshot_full()
477+
microvm = microvm_factory.build()
478+
microvm.spawn()
479+
microvm.restore_from_snapshot(snapshot)
480+
microvm.resume()
491481

492-
logger.info("Load snapshot #%d, mem %s", 1, snapshot.mem)
493-
vm_builder = MicrovmBuilder(bin_cloner_path)
494-
microvm, _ = vm_builder.build_from_snapshot(
495-
snapshot, resume=True, diff_snapshots=diff_snapshots
496-
)
497482
# Attempt to connect to resumed microvm.
498483
microvm.ssh.run("true")
499484

@@ -538,16 +523,11 @@ def test_snapshot_compatibility(microvm_factory, guest_kernel, rootfs):
538523
"""
539524
Test that the balloon serializes correctly.
540525
"""
541-
logger = logging.getLogger("snapshot_compatibility")
542-
snapshot_type = SnapshotType.FULL
543-
diff_snapshots = snapshot_type == SnapshotType.DIFF
544-
545526
vm = microvm_factory.build(guest_kernel, rootfs)
546527
vm.spawn()
547528
vm.basic_config(
548529
vcpu_count=2,
549530
mem_size_mib=256,
550-
track_dirty_pages=diff_snapshots,
551531
)
552532

553533
# Add a memory balloon with stats enabled.
@@ -557,18 +537,7 @@ def test_snapshot_compatibility(microvm_factory, guest_kernel, rootfs):
557537
assert vm.api_session.is_status_no_content(response.status_code)
558538

559539
vm.start()
560-
561-
logger.info("Create %s #0.", snapshot_type)
562-
563-
# Pause the microVM in order to allow snapshots
564-
response = vm.vm.patch(state="Paused")
565-
assert vm.api_session.is_status_no_content(response.status_code)
566-
567-
# Create a snapshot builder from a microvm.
568-
snapshot_builder = SnapshotBuilder(vm)
569-
570-
# Check we can create a snapshot with a balloon on current version.
571-
snapshot_builder.create([rootfs.local_path()], rootfs.ssh_key(), snapshot_type)
540+
vm.snapshot_full()
572541

573542

574543
def test_memory_scrub(microvm_factory, guest_kernel, rootfs):

0 commit comments

Comments
 (0)