diff --git a/.gitignore b/.gitignore index 5547864f2b..84b41e0feb 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,6 @@ src/client/gui/flutter*.log packaging/windows/wix/obj/* packaging/windows/custom-actions/packages/* packaging/windows/custom-actions/x64/* + +# clangd cache path +.cache/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 77ac411879..fd86c2282f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -275,7 +275,7 @@ if(MSVC) add_definitions(-DLIBSSH_STATIC) # otherwise adds declspec specifiers to libssh apis add_definitions(-D_SILENCE_ALL_CXX17_DEPRECATION_WARNINGS) add_definitions(-DWIN32_LEAN_AND_MEAN) - set(MULTIPASS_BACKENDS hyperv virtualbox) + set(MULTIPASS_BACKENDS hyperv hyperv_api virtualbox) set(MULTIPASS_PLATFORM windows) else() add_compile_options(-Werror -Wall -pedantic -fPIC -Wno-error=deprecated-declarations) diff --git a/src/platform/backends/hyperv_api/CMakeLists.txt b/src/platform/backends/hyperv_api/CMakeLists.txt new file mode 100644 index 0000000000..ccfdf570df --- /dev/null +++ b/src/platform/backends/hyperv_api/CMakeLists.txt @@ -0,0 +1,57 @@ +# Copyright (C) Canonical, Ltd. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 3 as +# published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# + +if(WIN32) + + include(CheckCXXSourceRuns) + + macro(check_pragma_lib LIB_NAME HEADER_NAME OUT_VAR) + check_cxx_source_runs(" + #pragma comment(lib, \"${LIB_NAME}\") + #define WIN32_LEAN_AND_MEAN + #include + #include <${HEADER_NAME}> + int main(void){ return 0; } + " ${OUT_VAR}) + endmacro() + + check_pragma_lib("computecore.lib" "computecore.h" HAS_COMPUTECORE) + check_pragma_lib("computenetwork.lib" "computenetwork.h" HAS_COMPUTENETWORK) + check_pragma_lib("virtdisk.lib" "virtdisk.h" HAS_VIRTDISK) + + if(NOT (HAS_COMPUTECORE AND HAS_COMPUTENETWORK AND HAS_VIRTDISK)) + message(FATAL_ERROR + "[hyperv_api] One or more required libraries are missing:\n" + " HAS_COMPUTECORE_LIB=${HAS_COMPUTECORE_LIB}\n" + " HAS_COMPUTENETWORK_LIB=${HAS_COMPUTENETWORK_LIB}\n" + " HAS_VIRTDISK_LIB=${HAS_VIRTDISK_LIB}\n" + ) + endif() + + add_library(hyperv_api_backend STATIC + hyperv_api_common.cpp + hcn/hyperv_hcn_api_wrapper.cpp + hcs/hyperv_hcs_api_wrapper.cpp + virtdisk/virtdisk_api_wrapper.cpp + ) + + target_link_libraries(hyperv_api_backend PRIVATE + fmt::fmt-header-only + utils + computecore.lib + computenetwork.lib + virtdisk.lib + ) +endif() diff --git a/src/platform/backends/hyperv_api/hcn/hyperv_hcn_api_table.h b/src/platform/backends/hyperv_api/hcn/hyperv_hcn_api_table.h new file mode 100644 index 0000000000..171d5ec134 --- /dev/null +++ b/src/platform/backends/hyperv_api/hcn/hyperv_hcn_api_table.h @@ -0,0 +1,88 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#ifndef MULTIPASS_HYPERV_API_HCN_API_TABLE +#define MULTIPASS_HYPERV_API_HCN_API_TABLE + +// clang-format off +// (xmkg): clang-format is messing with the include order. +#include +#include +#include // for CoTaskMemFree +// clang-format on + +#include +#include + +namespace multipass::hyperv::hcn +{ + +/** + * API function table for the Host Compute Network API + */ +struct HCNAPITable +{ + // @ref https://learn.microsoft.com/en-us/virtualization/api/hcn/reference/hcncreatenetwork + std::function CreateNetwork = &HcnCreateNetwork; + // @ref https://learn.microsoft.com/en-us/virtualization/api/hcn/reference/hcnopennetwork + std::function OpenNetwork = &HcnOpenNetwork; + // @ref https://learn.microsoft.com/en-us/virtualization/api/hcn/reference/hcndeletenetwork + std::function DeleteNetwork = &HcnDeleteNetwork; + // @ref https://learn.microsoft.com/en-us/virtualization/api/hcn/reference/hcnclosenetwork + std::function CloseNetwork = &HcnCloseNetwork; + // @ref https://learn.microsoft.com/en-us/virtualization/api/hcn/reference/hcncreateendpoint + std::function CreateEndpoint = &HcnCreateEndpoint; + // @ref https://learn.microsoft.com/en-us/virtualization/api/hcn/reference/hcnopenendpoint + std::function OpenEndpoint = &HcnOpenEndpoint; + // @ref https://learn.microsoft.com/en-us/virtualization/api/hcn/reference/hcndeleteendpoint + std::function DeleteEndpoint = &HcnDeleteEndpoint; + // @ref https://learn.microsoft.com/en-us/virtualization/api/hcn/reference/hcndeleteendpoint + std::function CloseEndpoint = &HcnCloseEndpoint; + // @ref https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cotaskmemfree + std::function CoTaskMemFree = &::CoTaskMemFree; +}; + +} // namespace multipass::hyperv::hcn + +/** + * Formatter type specialization for HCNAPITable + */ +template +struct fmt::formatter +{ + constexpr auto parse(basic_format_parse_context& ctx) + { + return ctx.begin(); + } + + template + auto format(const multipass::hyperv::hcn::HCNAPITable& api, FormatContext& ctx) const + { + return format_to(ctx.out(), + "CreateNetwork: ({}) | OpenNetwork: ({}) | DeleteNetwork: ({}) | CreateEndpoint: ({}) | " + "OpenEndpoint: ({}) | DeleteEndpoint: ({}) | CoTaskMemFree: ({})", + static_cast(api.CreateNetwork), + static_cast(api.OpenNetwork), + static_cast(api.DeleteNetwork), + static_cast(api.CreateEndpoint), + static_cast(api.OpenEndpoint), + static_cast(api.DeleteEndpoint), + static_cast(api.CoTaskMemFree)); + } +}; + +#endif // MULTIPASS_HYPERV_API_HCN_API_TABLE diff --git a/src/platform/backends/hyperv_api/hcn/hyperv_hcn_api_wrapper.cpp b/src/platform/backends/hyperv_api/hcn/hyperv_hcn_api_wrapper.cpp new file mode 100644 index 0000000000..7a5394bff2 --- /dev/null +++ b/src/platform/backends/hyperv_api/hcn/hyperv_hcn_api_wrapper.cpp @@ -0,0 +1,260 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#include +#include +#include +#include +#include +#include + +#include +#include + +// clang-format off +#include +#include +#include +#include +#include +#include // HCN API uses CoTaskMem* functions to allocate memory. +// clang-format on + +#include + +#include +#include +#include + +namespace multipass::hyperv::hcn +{ + +namespace +{ + +using UniqueHcnNetwork = std::unique_ptr, decltype(HCNAPITable::CloseNetwork)>; +using UniqueHcnEndpoint = std::unique_ptr, decltype(HCNAPITable::CloseEndpoint)>; +using UniqueCotaskmemString = std::unique_ptr; + +namespace mpl = logging; +using lvl = mpl::Level; + +// --------------------------------------------------------- + +/** + * Category for the log messages. + */ +constexpr auto kLogCategory = "HyperV-HCN-Wrapper"; + +// --------------------------------------------------------- + +/** + * Perform a Host Compute Network API operation + * + * @param fn The API function pointer + * @param args The arguments to the function + * + * @return HCNOperationResult Result of the performed operation + */ +template +OperationResult perform_hcn_operation(const HCNAPITable& api, const FnType& fn, Args&&... args) +{ + + // Ensure that function to call is set. + if (nullptr == fn) + { + assert(0); + // E_POINTER means "invalid pointer", seems to be appropriate. + return {E_POINTER, L"Operation function is unbound!"}; + } + + // HCN functions will use CoTaskMemAlloc to allocate the error message buffer + // so use UniqueCotaskmemString to auto-release it with appropriate free + // function. + + wchar_t* result_msg_out{nullptr}; + + // Perform the operation. The last argument of the all HCN operations (except + // HcnClose*) is ErrorRecord, which is a JSON-formatted document emitted by + // the API describing the error happened. Therefore, we can streamline all API + // calls through perform_operation to perform co + const auto result = ResultCode{fn(std::forward(args)..., &result_msg_out)}; + + UniqueCotaskmemString result_msgbuf{result_msg_out, api.CoTaskMemFree}; + + mpl::debug(kLogCategory, + "perform_operation(...) > fn: {}, result: {}", + fmt::ptr(fn.template target()), + static_cast(result)); + + // Error message is only valid when the operation resulted in an error. + // Passing a nullptr is well-defined in "< C++23", but it's going to be + // forbidden afterwards. Going an extra mile just to be future-proof. + return {result, {result_msgbuf ? result_msgbuf.get() : L""}}; +} + +// --------------------------------------------------------- + +/** + * Open an existing Host Compute Network and return a handle to it. + * + * This function is used for altering network resources, e.g. adding a new + * endpoint. + * + * @param api The HCN API table + * @param network_guid GUID of the network to open + * + * @return UniqueHcnNetwork Unique handle to the network. Non-nullptr when successful. + */ +UniqueHcnNetwork open_network(const HCNAPITable& api, const std::string& network_guid) +{ + mpl::debug(kLogCategory, "open_network(...) > network_guid: {} ", network_guid); + HCN_NETWORK network{nullptr}; + + const auto result = perform_hcn_operation(api, api.OpenNetwork, guid_from_string(network_guid), &network); + if (!result) + { + mpl::error(kLogCategory, "open_network() > HcnOpenNetwork failed with {}!", result.code); + } + return UniqueHcnNetwork{network, api.CloseNetwork}; +} + +} // namespace + +// --------------------------------------------------------- + +HCNWrapper::HCNWrapper(const HCNAPITable& api_table) : api{api_table} +{ + mpl::debug(kLogCategory, "HCNWrapper::HCNWrapper(...): api_table: {}", api); +} + +// --------------------------------------------------------- + +OperationResult HCNWrapper::create_network(const CreateNetworkParameters& params) const +{ + mpl::debug(kLogCategory, "HCNWrapper::create_network(...) > params: {} ", params); + + /** + * HcnCreateNetwork settings JSON template + */ + constexpr auto network_settings_template = LR"""( + {{ + "Name": "{0}", + "Type": "ICS", + "Subnets" : [ + {{ + "GatewayAddress": "{2}", + "AddressPrefix" : "{1}", + "IpSubnets" : [ + {{ + "IpAddressPrefix": "{1}" + }} + ] + }} + ], + "IsolateSwitch": true, + "Flags" : 265 + }} + )"""; + + // Render the template + const auto network_settings = fmt::format(network_settings_template, + string_to_wstring(params.name), + string_to_wstring(params.subnet), + string_to_wstring(params.gateway)); + + HCN_NETWORK network{nullptr}; + const auto result = perform_hcn_operation(api, + api.CreateNetwork, + guid_from_string(params.guid), + network_settings.c_str(), + &network); + + if (!result) + { + // FIXME: Also include the result error message, if any. + mpl::error(kLogCategory, "HCNWrapper::create_network(...) > HcnCreateNetwork failed with {}!", result.code); + } + + [[maybe_unused]] UniqueHcnNetwork _{network, api.CloseNetwork}; + return result; +} + +// --------------------------------------------------------- + +OperationResult HCNWrapper::delete_network(const std::string& network_guid) const +{ + mpl::debug(kLogCategory, "HCNWrapper::delete_network(...) > network_guid: {}", network_guid); + return perform_hcn_operation(api, api.DeleteNetwork, guid_from_string(network_guid)); +} + +// --------------------------------------------------------- + +OperationResult HCNWrapper::create_endpoint(const CreateEndpointParameters& params) const +{ + mpl::debug(kLogCategory, "HCNWrapper::create_endpoint(...) > params: {} ", params); + + const auto network = open_network(api, params.network_guid); + + if (nullptr == network) + { + return {E_POINTER, L"Could not open the network!"}; + } + + /** + * HcnCreateEndpoint settings JSON template + */ + constexpr auto endpoint_settings_template = LR"( + {{ + "SchemaVersion": {{ + "Major": 2, + "Minor": 16 + }}, + "HostComputeNetwork": "{0}", + "Policies": [ + ], + "IpConfigurations": [ + {{ + "IpAddress": "{1}" + }} + ] + }})"; + + // Render the template + const auto endpoint_settings = fmt::format(endpoint_settings_template, + string_to_wstring(params.network_guid), + string_to_wstring(params.endpoint_ipvx_addr)); + HCN_ENDPOINT endpoint{nullptr}; + const auto result = perform_hcn_operation(api, + api.CreateEndpoint, + network.get(), + guid_from_string(params.endpoint_guid), + endpoint_settings.c_str(), + &endpoint); + [[maybe_unused]] UniqueHcnEndpoint _{endpoint, api.CloseEndpoint}; + return result; +} + +// --------------------------------------------------------- + +OperationResult HCNWrapper::delete_endpoint(const std::string& endpoint_guid) const +{ + mpl::debug(kLogCategory, "HCNWrapper::delete_endpoint(...) > endpoint_guid: {} ", endpoint_guid); + return perform_hcn_operation(api, api.DeleteEndpoint, guid_from_string(endpoint_guid)); +} + +} // namespace multipass::hyperv::hcn diff --git a/src/platform/backends/hyperv_api/hcn/hyperv_hcn_api_wrapper.h b/src/platform/backends/hyperv_api/hcn/hyperv_hcn_api_wrapper.h new file mode 100644 index 0000000000..8990f593e9 --- /dev/null +++ b/src/platform/backends/hyperv_api/hcn/hyperv_hcn_api_wrapper.h @@ -0,0 +1,94 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#ifndef MULTIPASS_HYPERV_API_HCN_WRAPPER +#define MULTIPASS_HYPERV_API_HCN_WRAPPER + +#include +#include + +namespace multipass::hyperv::hcn +{ + +/** + * A high-level wrapper class that defines + * the common operations that Host Compute Network + * API provide. + */ +struct HCNWrapper : public HCNWrapperInterface +{ + + /** + * Construct a new HCNWrapper + * + * @param api_table The HCN API table object (optional) + * + * The wrapper will use the real HCN API by default. + */ + HCNWrapper(const HCNAPITable& api_table = {}); + HCNWrapper(const HCNWrapper&) = default; + HCNWrapper(HCNWrapper&&) = default; + HCNWrapper& operator=(const HCNWrapper&) = delete; + HCNWrapper& operator=(HCNWrapper&&) = delete; + + /** + * Create a new Host Compute Network + * + * @param [in] params Parameters for the new network + * + * @return An object that evaluates to true on success, false otherwise. + * message() may contain details of failure when result is false. + */ + [[nodiscard]] OperationResult create_network(const CreateNetworkParameters& params) const override; + + /** + * Delete an existing Host Compute Network + * + * @param [in] network_guid Target network's GUID + * + * @return An object that evaluates to true on success, false otherwise. + * message() may contain details of failure when result is false. + */ + [[nodiscard]] OperationResult delete_network(const std::string& network_guid) const override; + + /** + * Create a new Host Compute Network Endpoint + * + * @param [in] params Parameters for the new endpoint + * + * @return An object that evaluates to true on success, false otherwise. + * message() may contain details of failure when result is false. + */ + [[nodiscard]] OperationResult create_endpoint(const CreateEndpointParameters& params) const override; + + /** + * Delete an existing Host Compute Network Endpoint + * + * @param [in] params Target endpoint's GUID + * + * @return An object that evaluates to true on success, false otherwise. + * message() may contain details of failure when result is false. + */ + [[nodiscard]] OperationResult delete_endpoint(const std::string& endpoint_guid) const override; + +private: + const HCNAPITable api{}; +}; + +} // namespace multipass::hyperv::hcn + +#endif // MULTIPASS_HYPERV_API_HCN_WRAPPER diff --git a/src/platform/backends/hyperv_api/hcn/hyperv_hcn_create_endpoint_params.h b/src/platform/backends/hyperv_api/hcn/hyperv_hcn_create_endpoint_params.h new file mode 100644 index 0000000000..478f401da0 --- /dev/null +++ b/src/platform/backends/hyperv_api/hcn/hyperv_hcn_create_endpoint_params.h @@ -0,0 +1,76 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#ifndef MULTIPASS_HYPERV_API_HCN_CREATE_ENDPOINT_PARAMETERS_H +#define MULTIPASS_HYPERV_API_HCN_CREATE_ENDPOINT_PARAMETERS_H + +#include +#include + +namespace multipass::hyperv::hcn +{ + +/** + * Parameters for creating a network endpoint. + */ +struct CreateEndpointParameters +{ + /** + * The GUID of the network that will own the endpoint. + * + * The network must already exist. + */ + std::string network_guid{}; + + /** + * GUID for the new endpoint. + * + * Must be unique. + */ + std::string endpoint_guid{}; + + /** + * The IPv[4-6] address to assign to the endpoint. + */ + std::string endpoint_ipvx_addr{}; +}; + +} // namespace multipass::hyperv::hcn + +/** + * Formatter type specialization for CreateEndpointParameters + */ +template +struct fmt::formatter +{ + constexpr auto parse(basic_format_parse_context& ctx) + { + return ctx.begin(); + } + + template + auto format(const multipass::hyperv::hcn::CreateEndpointParameters& params, FormatContext& ctx) const + { + return format_to(ctx.out(), + "Endpoint GUID: ({}) | Network GUID: ({}) | Endpoint IPvX Addr.: ({})", + params.endpoint_guid, + params.network_guid, + params.endpoint_ipvx_addr); + } +}; + +#endif // MULTIPASS_HYPERV_API_HCN_CREATE_ENDPOINT_PARAMETERS_H diff --git a/src/platform/backends/hyperv_api/hcn/hyperv_hcn_create_network_params.h b/src/platform/backends/hyperv_api/hcn/hyperv_hcn_create_network_params.h new file mode 100644 index 0000000000..feb0a32a86 --- /dev/null +++ b/src/platform/backends/hyperv_api/hcn/hyperv_hcn_create_network_params.h @@ -0,0 +1,79 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#ifndef MULTIPASS_HYPERV_API_HCN_CREATE_NETWORK_PARAMETERS_H +#define MULTIPASS_HYPERV_API_HCN_CREATE_NETWORK_PARAMETERS_H + +#include +#include + +namespace multipass::hyperv::hcn +{ + +/** + * Parameters for creating a new Host Compute Network + */ +struct CreateNetworkParameters +{ + /** + * Name for the network + */ + std::string name{}; + + /** + * RFC4122 unique identifier for the network. + */ + std::string guid{}; + + /** + * Subnet CIDR that defines the address space of + * the network. + */ + std::string subnet{}; + + /** + * The default gateway address for the network. + */ + std::string gateway{}; +}; + +} // namespace multipass::hyperv::hcn + +/** + * Formatter type specialization for CreateNetworkParameters + */ +template +struct fmt::formatter +{ + constexpr auto parse(basic_format_parse_context& ctx) + { + return ctx.begin(); + } + + template + auto format(const multipass::hyperv::hcn::CreateNetworkParameters& params, FormatContext& ctx) const + { + return format_to(ctx.out(), + "Network Name: ({}) | Network GUID: ({}) | Subnet CIDR: ({}) | Gateway Addr.: ({}) ", + params.name, + params.guid, + params.subnet, + params.gateway); + } +}; + +#endif // MULTIPASS_HYPERV_API_HCN_CREATE_NETWORK_PARAMETERS_H diff --git a/src/platform/backends/hyperv_api/hcn/hyperv_hcn_wrapper_interface.h b/src/platform/backends/hyperv_api/hcn/hyperv_hcn_wrapper_interface.h new file mode 100644 index 0000000000..14048cdf22 --- /dev/null +++ b/src/platform/backends/hyperv_api/hcn/hyperv_hcn_wrapper_interface.h @@ -0,0 +1,41 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#ifndef MULTIPASS_HYPERV_API_HCN_WRAPPER_INTERFACE_H +#define MULTIPASS_HYPERV_API_HCN_WRAPPER_INTERFACE_H + +#include + +#include + +namespace multipass::hyperv::hcn +{ + +/** + * Abstract interface for Host Compute Network API wrapper. + */ +struct HCNWrapperInterface +{ + [[nodiscard]] virtual OperationResult create_network(const struct CreateNetworkParameters& params) const = 0; + [[nodiscard]] virtual OperationResult delete_network(const std::string& network_guid) const = 0; + [[nodiscard]] virtual OperationResult create_endpoint(const struct CreateEndpointParameters& params) const = 0; + [[nodiscard]] virtual OperationResult delete_endpoint(const std::string& endpoint_guid) const = 0; + virtual ~HCNWrapperInterface() = default; +}; +} // namespace multipass::hyperv::hcn + +#endif // MULTIPASS_HYPERV_API_HCN_WRAPPER_INTERFACE_H diff --git a/src/platform/backends/hyperv_api/hcs/hyperv_hcs_add_endpoint_params.h b/src/platform/backends/hyperv_api/hcs/hyperv_hcs_add_endpoint_params.h new file mode 100644 index 0000000000..198eb39f8e --- /dev/null +++ b/src/platform/backends/hyperv_api/hcs/hyperv_hcs_add_endpoint_params.h @@ -0,0 +1,73 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#ifndef MULTIPASS_HYPERV_API_HCS_ADD_ENDPOINT_PARAMETERS_H +#define MULTIPASS_HYPERV_API_HCS_ADD_ENDPOINT_PARAMETERS_H + +#include +#include + +namespace multipass::hyperv::hcs +{ + +/** + * Parameters for adding a network endpoint to + * a Host Compute System.. + */ +struct AddEndpointParameters +{ + /** + * Name of the target host compute system + */ + std::string target_compute_system_name{}; + + /** + * GUID of the endpoint to add. + */ + std::string endpoint_guid{}; + + /** + * MAC address to assign to the NIC + */ + std::string nic_mac_address{}; +}; + +} // namespace multipass::hyperv::hcs + +/** + * Formatter type specialization for CreateComputeSystemParameters + */ +template +struct fmt::formatter +{ + constexpr auto parse(basic_format_parse_context& ctx) + { + return ctx.begin(); + } + + template + auto format(const multipass::hyperv::hcs::AddEndpointParameters& params, FormatContext& ctx) const + { + return format_to(ctx.out(), + "Host Compute System Name: ({}) | Endpoint GUID: ({}) | NIC MAC Address: ({})", + params.target_compute_system_name, + params.endpoint_guid, + params.nic_mac_address); + } +}; + +#endif // MULTIPASS_HYPERV_API_HCS_ADD_ENDPOINT_PARAMETERS_H diff --git a/src/platform/backends/hyperv_api/hcs/hyperv_hcs_api_table.h b/src/platform/backends/hyperv_api/hcs/hyperv_hcs_api_table.h new file mode 100644 index 0000000000..34bb23735e --- /dev/null +++ b/src/platform/backends/hyperv_api/hcs/hyperv_hcs_api_table.h @@ -0,0 +1,122 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#ifndef MULTIPASS_HYPERV_API_HCS_API_TABLE +#define MULTIPASS_HYPERV_API_HCS_API_TABLE + +// clang-format off +// (xmkg): clang-format is messing with the include order. +#include +#include +// clang-format on + +#include + +#include + +namespace multipass::hyperv::hcs +{ + +/** + * API function table for host compute system + * @ref https://learn.microsoft.com/en-us/virtualization/api/hcs/reference/apioverview + */ +struct HCSAPITable +{ + // @ref https://learn.microsoft.com/en-us/virtualization/api/hcs/reference/hcscreateoperation + std::function CreateOperation = &HcsCreateOperation; + // @ref https://learn.microsoft.com/en-us/virtualization/api/hcs/reference/hcswaitforoperationresult + std::function WaitForOperationResult = &HcsWaitForOperationResult; + // @ref https://learn.microsoft.com/en-us/virtualization/api/hcs/reference/hcscloseoperation + std::function CloseOperation = &HcsCloseOperation; + // @ref https://learn.microsoft.com/en-us/virtualization/api/hcs/reference/hcscreatecomputesystem + std::function CreateComputeSystem = &HcsCreateComputeSystem; + // @ref https://learn.microsoft.com/en-us/virtualization/api/hcs/reference/hcsopencomputesystem + std::function OpenComputeSystem = &HcsOpenComputeSystem; + // @ref https://learn.microsoft.com/en-us/virtualization/api/hcs/reference/hcsstartcomputesystem + std::function StartComputeSystem = &HcsStartComputeSystem; + // @ref https://learn.microsoft.com/en-us/virtualization/api/hcs/reference/hcsshutdowncomputesystem + std::function ShutDownComputeSystem = &HcsShutDownComputeSystem; + // @ref https://learn.microsoft.com/en-us/virtualization/api/hcs/reference/hcsterminatecomputesystem + std::function TerminateComputeSystem = &HcsTerminateComputeSystem; + // @ref https://learn.microsoft.com/en-us/virtualization/api/hcs/reference/hcsclosecomputesystem + std::function CloseComputeSystem = &HcsCloseComputeSystem; + // @ref https://learn.microsoft.com/en-us/virtualization/api/hcs/reference/hcspausecomputesystem + std::function PauseComputeSystem = &HcsPauseComputeSystem; + // @ref https://learn.microsoft.com/en-us/virtualization/api/hcs/reference/hcsresumecomputesystem + std::function ResumeComputeSystem = &HcsResumeComputeSystem; + // @ref https://learn.microsoft.com/en-us/virtualization/api/hcs/reference/hcsmodifycomputesystem + std::function ModifyComputeSystem = &HcsModifyComputeSystem; + // @ref https://learn.microsoft.com/en-us/virtualization/api/hcs/reference/hcsgetcomputesystemproperties + std::function GetComputeSystemProperties = &HcsGetComputeSystemProperties; + // @ref https://learn.microsoft.com/en-us/virtualization/api/hcs/reference/hcsgrantvmaccess + std::function GrantVmAccess = &HcsGrantVmAccess; + // @ref https://learn.microsoft.com/en-us/virtualization/api/hcs/reference/hcsrevokevmaccess + std::function RevokeVmAccess = &HcsRevokeVmAccess; + // @ref https://learn.microsoft.com/en-us/virtualization/api/hcs/reference/hcsenumeratecomputesystems + std::function EnumerateComputeSystems = &HcsEnumerateComputeSystems; + + /** + * @brief LocalAlloc/LocalFree is used by the HCS API to manage memory for the status/error + * messages. It's caller's responsibility to free the messages allocated by the API, that's + * why the LocalFree is part of the API table. + * + * @ref https://learn.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-localfree + */ + std::function LocalFree = &::LocalFree; +}; + +} // namespace multipass::hyperv::hcs + +/** + * Formatter type specialization for HCNAPITable + */ +template +struct fmt::formatter +{ + constexpr auto parse(basic_format_parse_context& ctx) + { + return ctx.begin(); + } + + template + auto format(const multipass::hyperv::hcs::HCSAPITable& api, FormatContext& ctx) const + { + return format_to(ctx.out(), + "CreateOperation: ({}) | WaitForOperationResult: ({}) | CreateComputeSystem: ({}) | " + "OpenComputeSystem: ({}) | StartComputeSystem: ({}) | ShutDownComputeSystem: ({}) | " + "PauseComputeSystem: ({}) | ResumeComputeSystem: ({}) | ModifyComputeSystem: ({}) | " + "GetComputeSystemProperties: ({}) | GrantVmAccess: ({}) | RevokeVmAccess: ({}) | " + "EnumerateComputeSystems: ({}) | LocalFree: ({})", + static_cast(api.CreateOperation), + static_cast(api.WaitForOperationResult), + static_cast(api.CreateComputeSystem), + static_cast(api.OpenComputeSystem), + static_cast(api.StartComputeSystem), + static_cast(api.ShutDownComputeSystem), + static_cast(api.PauseComputeSystem), + static_cast(api.ResumeComputeSystem), + static_cast(api.ModifyComputeSystem), + static_cast(api.GetComputeSystemProperties), + static_cast(api.GrantVmAccess), + static_cast(api.RevokeVmAccess), + static_cast(api.EnumerateComputeSystems), + static_cast(api.LocalFree)); + } +}; + +#endif // MULTIPASS_HYPERV_API_HCS_API_TABLE diff --git a/src/platform/backends/hyperv_api/hcs/hyperv_hcs_api_wrapper.cpp b/src/platform/backends/hyperv_api/hcs/hyperv_hcs_api_wrapper.cpp new file mode 100644 index 0000000000..34c6482e87 --- /dev/null +++ b/src/platform/backends/hyperv_api/hcs/hyperv_hcs_api_wrapper.cpp @@ -0,0 +1,526 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include + +#include +#include + +#include + +namespace multipass::hyperv::hcs +{ + +namespace +{ + +using UniqueHcsSystem = std::unique_ptr, decltype(HCSAPITable::CloseComputeSystem)>; +using UniqueHcsOperation = std::unique_ptr, decltype(HCSAPITable::CloseOperation)>; +using UniqueHlocalString = std::unique_ptr; + +namespace mpl = logging; +using lvl = mpl::Level; + +/** + * Category for the log messages. + */ +constexpr auto kLogCategory = "HyperV-HCS-Wrapper"; + +/** + * Default timeout value for HCS API operations + */ +constexpr auto kDefaultOperationTimeout = std::chrono::seconds{240}; + +// --------------------------------------------------------- + +/** + * Create a new HCS operation. + * + * @param api The HCS API table + * + * @return UniqueHcsOperation The new operation + */ +UniqueHcsOperation create_operation(const HCSAPITable& api) +{ + mpl::debug(kLogCategory, "create_operation(...)"); + return UniqueHcsOperation{api.CreateOperation(nullptr, nullptr), api.CloseOperation}; +} + +// --------------------------------------------------------- + +/** + * Wait until given operation completes, or the timeout has reached. + * + * @param api The HCS API table + * @param op Operation to wait for + * @param timeout Maximum amount of time to wait + * @return Operation result + */ +OperationResult wait_for_operation_result(const HCSAPITable& api, + UniqueHcsOperation op, + std::chrono::milliseconds timeout = kDefaultOperationTimeout) +{ + mpl::debug(kLogCategory, + "wait_for_operation_result(...) > ({}), timeout: {} ms", + fmt::ptr(op.get()), + timeout.count()); + + wchar_t* result_msg_out{nullptr}; + const auto result = api.WaitForOperationResult(op.get(), timeout.count(), &result_msg_out); + UniqueHlocalString result_msg{result_msg_out, api.LocalFree}; + + if (result_msg) + { + // TODO: Convert from wstring to ascii and log this + // mpl::debug(kLogCategory, + // "wait_for_operation_result(...): ({}), result: {}, result_msg: {}", fmt::ptr(op.get()), + // result, result_msg); + return OperationResult{result, result_msg.get()}; + } + return OperationResult{result, L""}; +} + +// --------------------------------------------------------- + +/** + * Open an existing Host Compute System + * + * @param api The HCS API table + * @param name Target Host Compute System's name + * + * @return auto UniqueHcsSystem non-nullptr on success. + */ +UniqueHcsSystem open_host_compute_system(const HCSAPITable& api, const std::string& name) +{ + mpl::debug(kLogCategory, "open_host_compute_system(...) > name: ({})", name); + + // Windows API uses wide strings. + const auto name_w = string_to_wstring(name); + constexpr auto kRequestedAccessLevel = GENERIC_ALL; + + HCS_SYSTEM system{nullptr}; + const auto result = ResultCode{api.OpenComputeSystem(name_w.c_str(), kRequestedAccessLevel, &system)}; + + if (!result) + { + mpl::error(kLogCategory, + "open_host_compute_system(...) > failed to open ({}), result code: ({})", + name, + result); + } + return UniqueHcsSystem{system, api.CloseComputeSystem}; +} + +// --------------------------------------------------------- + +/** + * Perform a Host Compute System API operation. + * + * Host Compute System operation functions have a common + * signature, where `system` and `operation` are always + * the first two arguments. This functions is a common + * shorthand for invoking any of those. + * + * @param [in] api The API function table + * @param [in] fn The API function pointer + * @param [in] target_hcs_system_name HCS system to operate on + * @param [in] args The arguments to the function + * + * @return HCNOperationResult Result of the performed operation + */ +template +OperationResult perform_hcs_operation(const HCSAPITable& api, + const FnType& fn, + const std::string& target_hcs_system_name, + Args&&... args) +{ + // Ensure that function to call is set. + if (nullptr == fn) + { + assert(0); + // E_POINTER means "invalid pointer", seems to be appropriate. + return {E_POINTER, L"Operation function is unbound!"}; + } + + const auto system = open_host_compute_system(api, target_hcs_system_name); + + if (nullptr == system) + { + mpl::error(kLogCategory, + "perform_hcs_operation(...) > HcsOpenComputeSystem failed! {}", + target_hcs_system_name); + return OperationResult{E_POINTER, L"HcsOpenComputeSystem failed!"}; + } + + auto operation = create_operation(api); + + if (nullptr == operation) + { + mpl::error(kLogCategory, "perform_hcs_operation(...) > HcsCreateOperation failed! {}", target_hcs_system_name); + return OperationResult{E_POINTER, L"HcsCreateOperation failed!"}; + } + + const auto result = ResultCode{fn(system.get(), operation.get(), std::forward(args)...)}; + + if (!result) + { + mpl::error(kLogCategory, + "perform_hcs_operation(...) > Operation failed! {} Result code {}", + target_hcs_system_name, + result); + return OperationResult{result, L"HCS operation failed!"}; + } + + mpl::debug(kLogCategory, + "perform_hcs_operation(...) > fn: {}, result: {}", + fmt::ptr(fn.template target()), + static_cast(result)); + + return wait_for_operation_result(api, std::move(operation)); +} + +} // namespace + +// --------------------------------------------------------- + +HCSWrapper::HCSWrapper(const HCSAPITable& api_table) : api{api_table} +{ + mpl::debug(kLogCategory, "HCSWrapper::HCSWrapper(...) > api_table: {}", api); +} + +// --------------------------------------------------------- + +OperationResult HCSWrapper::create_compute_system(const CreateComputeSystemParameters& params) const +{ + mpl::debug(kLogCategory, "HCSWrapper::create_compute_system(...) > params: {} ", params); + + // Fill the SCSI devices template depending on + // available drives. + const auto scsi_devices = [¶ms]() { + constexpr auto scsi_device_template = LR"( + "{0}": {{ + "Attachments": {{ + "0": {{ + "Type": "{1}", + "Path": "{2}", + "ReadOnly": {3} + }} + }} + }}, + )"; + std::wstring result = {}; + if (!params.cloudinit_iso_path.empty()) + { + result += fmt::format(scsi_device_template, + L"cloud-init iso file", + L"Iso", + string_to_wstring(params.cloudinit_iso_path), + true); + } + + if (!params.vhdx_path.empty()) + { + result += fmt::format(scsi_device_template, + L"Primary disk", + L"VirtualDisk", + string_to_wstring(params.vhdx_path), + false); + } + return result; + }(); + + // Ideally, we should codegen from the schema + // and use that. + // https://raw.githubusercontent.com/MicrosoftDocs/Virtualization-Documentation/refs/heads/main/hyperv-samples/hcs-samples/JSON_files/HCS_Schema%5BWindows_10_SDK_version_1809%5D.json + constexpr auto vm_settings_template = LR"( + {{ + "SchemaVersion": {{ + "Major": 2, + "Minor": 1 + }}, + "Owner": "Multipass", + "ShouldTerminateOnLastHandleClosed": false, + "VirtualMachine": {{ + "Chipset": {{ + "Uefi": {{ + "BootThis": {{ + "DevicePath": "Primary disk", + "DiskNumber": 0, + "DeviceType": "ScsiDrive" + }}, + "Console": "ComPort1" + }} + }}, + "ComputeTopology": {{ + "Memory": {{ + "Backing": "Virtual", + "SizeInMB": {1} + }}, + "Processor": {{ + "Count": {0} + }} + }}, + "Devices": {{ + "ComPorts": {{ + "0": {{ + "NamedPipe": "\\\\.\\pipe\\{2}" + }} + }}, + "Scsi": {{ + {3} + }} + }} + }} + }})"; + + // Render the template + const auto vm_settings = fmt::format(vm_settings_template, + params.processor_count, + params.memory_size_mb, + string_to_wstring(params.name), + scsi_devices); + HCS_SYSTEM system{nullptr}; + + auto operation = create_operation(api); + + if (nullptr == operation) + { + return OperationResult{E_POINTER, L"HcsCreateOperation failed."}; + } + + const auto name_w = string_to_wstring(params.name); + const auto result = + ResultCode{api.CreateComputeSystem(name_w.c_str(), vm_settings.c_str(), operation.get(), nullptr, &system)}; + + // Auto-release the system handle + [[maybe_unused]] UniqueHcsSystem _{system, api.CloseComputeSystem}; + + if (!result) + { + return OperationResult{result, L"HcsCreateComputeSystem failed."}; + } + + return wait_for_operation_result(api, std::move(operation), std::chrono::seconds{240}); +} + +// --------------------------------------------------------- + +OperationResult HCSWrapper::start_compute_system(const std::string& compute_system_name) const +{ + mpl::debug(kLogCategory, "start_compute_system(...) > name: ({})", compute_system_name); + return perform_hcs_operation(api, api.StartComputeSystem, compute_system_name, nullptr); +} + +// --------------------------------------------------------- + +OperationResult HCSWrapper::shutdown_compute_system(const std::string& compute_system_name) const +{ + mpl::debug(kLogCategory, "shutdown_compute_system(...) > name: ({})", compute_system_name); + return perform_hcs_operation(api, api.ShutDownComputeSystem, compute_system_name, nullptr); +} + +// --------------------------------------------------------- + +OperationResult HCSWrapper::terminate_compute_system(const std::string& compute_system_name) const +{ + mpl::debug(kLogCategory, "terminate_compute_system(...) > name: ({})", compute_system_name); + return perform_hcs_operation(api, api.TerminateComputeSystem, compute_system_name, nullptr); +} + +// --------------------------------------------------------- + +OperationResult HCSWrapper::pause_compute_system(const std::string& compute_system_name) const +{ + mpl::debug(kLogCategory, "pause_compute_system(...) > name: ({})", compute_system_name); + static constexpr wchar_t c_pauseOption[] = LR"( + { + "SuspensionLevel": "Suspend", + "HostedNotification": { + "Reason": "Save" + } + })"; + return perform_hcs_operation(api, api.PauseComputeSystem, compute_system_name, c_pauseOption); +} + +// --------------------------------------------------------- + +OperationResult HCSWrapper::resume_compute_system(const std::string& compute_system_name) const +{ + mpl::debug(kLogCategory, "resume_compute_system(...) > name: ({})", compute_system_name); + return perform_hcs_operation(api, api.ResumeComputeSystem, compute_system_name, nullptr); +} + +// --------------------------------------------------------- + +OperationResult HCSWrapper::add_endpoint(const AddEndpointParameters& params) const +{ + mpl::debug(kLogCategory, "add_endpoint(...) > params: {}", params); + constexpr auto add_endpoint_settings_template = LR"( + {{ + "ResourcePath": "VirtualMachine/Devices/NetworkAdapters/{{{0}}}", + "RequestType": "Add", + "Settings": {{ + "EndpointId": "{0}", + "MacAddress": "{1}", + "InstanceId": "{0}" + }} + }})"; + + const auto settings = fmt::format(add_endpoint_settings_template, + string_to_wstring(params.endpoint_guid), + string_to_wstring(params.nic_mac_address)); + + return perform_hcs_operation(api, + api.ModifyComputeSystem, + params.target_compute_system_name, + settings.c_str(), + nullptr); +} + +// --------------------------------------------------------- + +OperationResult HCSWrapper::remove_endpoint(const std::string& compute_system_name, + const std::string& endpoint_guid) const +{ + mpl::debug(kLogCategory, + "remove_endpoint(...) > name: ({}), endpoint_guid: ({})", + compute_system_name, + endpoint_guid); + + constexpr auto remove_endpoint_settings_template = LR"( + {{ + "ResourcePath": "VirtualMachine/Devices/NetworkAdapters/{{{0}}}", + "RequestType": "Remove" + }})"; + + const auto settings = fmt::format(remove_endpoint_settings_template, string_to_wstring(endpoint_guid)); + + return perform_hcs_operation(api, api.ModifyComputeSystem, compute_system_name, settings.c_str(), nullptr); +} + +// --------------------------------------------------------- + +OperationResult HCSWrapper::resize_memory(const std::string& compute_system_name, std::uint32_t new_size_mib) const +{ + // Machine must be booted up. + mpl::debug(kLogCategory, "resize_memory(...) > name: ({}), new_size_mb: ({})", compute_system_name, new_size_mib); + // https://learn.microsoft.com/en-us/virtualization/api/hcs/reference/hcsmodifycomputesystem#remarks + constexpr auto resize_memory_settings_template = LR"( + {{ + "ResourcePath": "VirtualMachine/ComputeTopology/Memory/SizeInMB", + "RequestType": "Update", + "Settings": {0} + }})"; + + const auto settings = fmt::format(resize_memory_settings_template, new_size_mib); + + return perform_hcs_operation(api, api.ModifyComputeSystem, compute_system_name, settings.c_str(), nullptr); +} + +// --------------------------------------------------------- + +OperationResult HCSWrapper::update_cpu_count(const std::string& compute_system_name, std::uint32_t new_vcpu_count) const +{ + return OperationResult{E_NOTIMPL, L"Not implemented yet!"}; +} + +// --------------------------------------------------------- + +OperationResult HCSWrapper::get_compute_system_properties(const std::string& compute_system_name) const +{ + + mpl::debug(kLogCategory, "get_compute_system_properties(...) > name: ({})", compute_system_name); + + // https://learn.microsoft.com/en-us/virtualization/api/hcs/schemareference#System_PropertyType + static constexpr wchar_t c_VmQuery[] = LR"( + { + "PropertyTypes":[] + })"; + + return perform_hcs_operation(api, api.GetComputeSystemProperties, compute_system_name, c_VmQuery); +} + +// --------------------------------------------------------- + +OperationResult HCSWrapper::grant_vm_access(const std::string& compute_system_name, + const std::filesystem::path& file_path) const +{ + mpl::debug(kLogCategory, + "grant_vm_access(...) > name: ({}), file_path: ({})", + compute_system_name, + file_path.string()); + + const auto path_as_wstring = file_path.wstring(); + const auto csname_as_wstring = string_to_wstring(compute_system_name); + const auto result = api.GrantVmAccess(csname_as_wstring.c_str(), path_as_wstring.c_str()); + return {result, FAILED(result) ? L"GrantVmAccess failed!" : L""}; +} + +// --------------------------------------------------------- + +OperationResult HCSWrapper::revoke_vm_access(const std::string& compute_system_name, + const std::filesystem::path& file_path) const +{ + mpl::debug(kLogCategory, + "revoke_vm_access(...) > name: ({}), file_path: ({}) ", + compute_system_name, + file_path.string()); + + const auto path_as_wstring = file_path.wstring(); + const auto csname_as_wstring = string_to_wstring(compute_system_name); + const auto result = api.RevokeVmAccess(csname_as_wstring.c_str(), path_as_wstring.c_str()); + return {result, FAILED(result) ? L"RevokeVmAccess failed!" : L""}; +} + +// --------------------------------------------------------- + +OperationResult HCSWrapper::get_compute_system_state(const std::string& compute_system_name) const +{ + mpl::debug(kLogCategory, "get_compute_system_state(...) > name: ({})", compute_system_name); + + const auto result = perform_hcs_operation(api, api.GetComputeSystemProperties, compute_system_name, nullptr); + if (!result) + { + return {result.code, L"Unknown"}; + } + + const QString qstr{QString::fromStdWString(result.status_msg)}; + const auto doc = QJsonDocument::fromJson(qstr.toUtf8()); + const auto obj = doc.object(); + if (obj.contains("State")) + { + const auto state = obj["State"]; + const auto state_str = state.toString(); + return {result.code, state_str.toStdWString()}; + } + + return {result.code, L"Unknown"}; +} + +} // namespace multipass::hyperv::hcs diff --git a/src/platform/backends/hyperv_api/hcs/hyperv_hcs_api_wrapper.h b/src/platform/backends/hyperv_api/hcs/hyperv_hcs_api_wrapper.h new file mode 100644 index 0000000000..d08f7a0753 --- /dev/null +++ b/src/platform/backends/hyperv_api/hcs/hyperv_hcs_api_wrapper.h @@ -0,0 +1,237 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#ifndef MULTIPASS_HYPERV_API_HCS_WRAPPER +#define MULTIPASS_HYPERV_API_HCS_WRAPPER + +#include +#include +#include + +namespace multipass::hyperv::hcs +{ + +/** + * A high-level wrapper class that defines + * the common operations that Host Compute System + * API provide. + */ +struct HCSWrapper : public HCSWrapperInterface +{ + + /** + * Construct a new HCNWrapper + * + * @param api_table The HCN API table object (optional) + * + * The wrapper will use the real HCN API by default. + */ + HCSWrapper(const HCSAPITable& api_table = {}); + HCSWrapper(const HCSWrapper&) = default; + HCSWrapper(HCSWrapper&&) = default; + HCSWrapper& operator=(const HCSWrapper&) = delete; + HCSWrapper& operator=(HCSWrapper&&) = delete; + + // --------------------------------------------------------- + + /** + * Create a new Host Compute System + * + * @param [in] params Parameters for the new compute system + * + * @return An object that evaluates to true on success, false otherwise. + * message() may contain details of failure when result is false. + */ + [[nodiscard]] OperationResult create_compute_system(const CreateComputeSystemParameters& params) const override; + + // --------------------------------------------------------- + + /** + * Start a compute system. + * + * @param [in] compute_system_name Target compute system's name + * + * @return An object that evaluates to true on success, false otherwise. + * message() may contain details of failure when result is false. + */ + [[nodiscard]] OperationResult start_compute_system(const std::string& compute_system_name) const override; + + // --------------------------------------------------------- + + /** + * Gracefully shutdown the compute system + * + * @param [in] compute_system_name Target compute system's name + * + * @return An object that evaluates to true on success, false otherwise. + * message() may contain details of failure when result is false. + */ + [[nodiscard]] OperationResult shutdown_compute_system(const std::string& compute_system_name) const override; + + // --------------------------------------------------------- + + /** + * Forcefully shutdown the compute system + * + * @param [in] compute_system_name Target compute system's name + * + * @return An object that evaluates to true on success, false otherwise. + * message() may contain details of failure when result is false. + */ + [[nodiscard]] OperationResult terminate_compute_system(const std::string& compute_system_name) const override; + + // --------------------------------------------------------- + + /** + * Pause the execution of a running compute system + * + * @param [in] compute_system_name Target compute system's name + * + * @return An object that evaluates to true on success, false otherwise. + * message() may contain details of failure when result is false. + */ + [[nodiscard]] OperationResult pause_compute_system(const std::string& compute_system_name) const override; + + // --------------------------------------------------------- + + /** + * Resume the execution of a previously paused compute system + * + * @param [in] compute_system_name Target compute system's name + * + * @return An object that evaluates to true on success, false otherwise. + * message() may contain details of failure when result is false. + */ + [[nodiscard]] OperationResult resume_compute_system(const std::string& compute_system_name) const override; + + // --------------------------------------------------------- + + /** + * Retrieve a Host Compute System's properties + * + * @param [in] compute_system_name Target compute system's name + * + * @return An object that evaluates to true on success, false otherwise. + * message() may contain details of failure when result is false. + */ + [[nodiscard]] OperationResult get_compute_system_properties(const std::string& compute_system_name) const override; + + // --------------------------------------------------------- + + /** + * Grant a compute system access to a file path. + * + * @param [in] compute_system_name Target compute system's name + * @param [in] file_path File path to grant access to + * + * @return An object that evaluates to true on success, false otherwise. + * message() may contain details of failure when result is false. + */ + [[nodiscard]] OperationResult grant_vm_access(const std::string& compute_system_name, + const std::filesystem::path& file_path) const override; + + // --------------------------------------------------------- + + /** + * Revoke a compute system's access to a file path. + * + * @param [in] compute_system_name Target compute system's name + * @param [in] file_path File path to revoke access to + * + * @return An object that evaluates to true on success, false otherwise. + * message() may contain details of failure when result is false. + */ + [[nodiscard]] OperationResult revoke_vm_access(const std::string& compute_system_name, + const std::filesystem::path& file_path) const override; + + // --------------------------------------------------------- + + /** + * Add a network endpoint to the host compute system. + * + * A new network interface card will be automatically created for + * the endpoint. The network interface card's name will be the + * endpoint's GUID for convenience. + * + * @param [in] params Endpoint parameters + * + * @return An object that evaluates to true on success, false otherwise. + * message() may contain details of failure when result is false. + */ + [[nodiscard]] OperationResult add_endpoint(const AddEndpointParameters& params) const override; + + // --------------------------------------------------------- + + /** + * Remove a network endpoint from the host compute system. + * + * @param [in] name Target compute system's name + * @param [in] endpoint_guid GUID of the endpoint to remove + * + * @return An object that evaluates to true on success, false otherwise. + * message() may contain details of failure when result is false. + */ + [[nodiscard]] OperationResult remove_endpoint(const std::string& compute_system_name, + const std::string& endpoint_guid) const override; + + // --------------------------------------------------------- + + /** + * Resize the amount of memory the compute system has. + * + * @param compute_system_name Target compute system name + * @param new_size_mib New memory size, in megabytes + * + * @return An object that evaluates to true on success, false otherwise. + * message() may contain details of failure when result is false. + */ + [[nodiscard]] OperationResult resize_memory(const std::string& compute_system_name, + std::uint32_t new_size_mib) const override; + + // --------------------------------------------------------- + + /** + * Change the amount of available vCPUs in the compute system + * + * @param compute_system_name Target compute system name + * @param new_size_mib New memory size, in megabytes + * + * @return An object that evaluates to true on success, false otherwise. + * message() may contain details of failure when result is false. + */ + [[nodiscard]] OperationResult update_cpu_count(const std::string& compute_system_name, + std::uint32_t new_vcpu_count) const override; + + // --------------------------------------------------------- + + /** + * Retrieve the current state of the compute system. + * + * @param [in] compute_system_name Target compute system's name + * + * @return An object that evaluates to true on success, false otherwise. + * message() may contain details of failure when result is false. + */ + [[nodiscard]] OperationResult get_compute_system_state(const std::string& compute_system_name) const override; + +private: + const HCSAPITable api{}; +}; + +} // namespace multipass::hyperv::hcs + +#endif // MULTIPASS_HYPERV_API_HCS_WRAPPER diff --git a/src/platform/backends/hyperv_api/hcs/hyperv_hcs_compute_system_state.h b/src/platform/backends/hyperv_api/hcs/hyperv_hcs_compute_system_state.h new file mode 100644 index 0000000000..a1dc83bb37 --- /dev/null +++ b/src/platform/backends/hyperv_api/hcs/hyperv_hcs_compute_system_state.h @@ -0,0 +1,72 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#ifndef MULTIPASS_HYPERV_API_HCS_CREATE_COMPUTE_SYSTEM_STATE_H +#define MULTIPASS_HYPERV_API_HCS_CREATE_COMPUTE_SYSTEM_STATE_H + +#include +#include +#include +#include +#include + +namespace multipass::hyperv::hcs +{ + +/** + * Enum values representing a compute system's possible state + * + * @ref https://learn.microsoft.com/en-us/virtualization/api/hcs/schemareference#State + */ +enum class ComputeSystemState : std::uint8_t +{ + created, + running, + paused, + stopped, + saved_as_template, + unknown, +}; + +/** + * Translate host compute system state string to enum + * + * @param str + * @return ComputeSystemState + */ +inline std::optional compute_system_state_from_string(std::string str) +{ + std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) { return std::tolower(c); }); + // std::unordered_map + static const std::unordered_map translation_map{ + {"created", ComputeSystemState::created}, + {"running", ComputeSystemState::running}, + {"paused", ComputeSystemState::paused}, + {"stopped", ComputeSystemState::stopped}, + {"savedastemplate", ComputeSystemState::saved_as_template}, + {"unknown", ComputeSystemState::unknown}, + }; + + if (const auto itr = translation_map.find(str); translation_map.end() != itr) + return itr->second; + + return std::nullopt; +} + +} // namespace multipass::hyperv::hcs + +#endif diff --git a/src/platform/backends/hyperv_api/hcs/hyperv_hcs_create_compute_system_params.h b/src/platform/backends/hyperv_api/hcs/hyperv_hcs_create_compute_system_params.h new file mode 100644 index 0000000000..365069f56d --- /dev/null +++ b/src/platform/backends/hyperv_api/hcs/hyperv_hcs_create_compute_system_params.h @@ -0,0 +1,85 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#ifndef MULTIPASS_HYPERV_API_HCS_CREATE_COMPUTE_SYSTEM_PARAMETERS_H +#define MULTIPASS_HYPERV_API_HCS_CREATE_COMPUTE_SYSTEM_PARAMETERS_H + +#include +#include + +namespace multipass::hyperv::hcs +{ + +/** + * Parameters for creating a network endpoint. + */ +struct CreateComputeSystemParameters +{ + /** + * Unique name for the compute system + */ + std::string name{}; + + /** + * Memory size, in megabytes + */ + std::uint32_t memory_size_mb{}; + + /** + * vCPU count + */ + std::uint32_t processor_count{}; + + /** + * Path to the cloud-init ISO file + */ + std::string cloudinit_iso_path{}; + + /** + * Path to the Primary (boot) VHDX file + */ + std::string vhdx_path{}; +}; + +} // namespace multipass::hyperv::hcs + +/** + * Formatter type specialization for CreateComputeSystemParameters + */ +template +struct fmt::formatter +{ + constexpr auto parse(basic_format_parse_context& ctx) + { + return ctx.begin(); + } + + template + auto format(const multipass::hyperv::hcs::CreateComputeSystemParameters& params, FormatContext& ctx) const + { + return format_to(ctx.out(), + "Compute System name: ({}) | vCPU count: ({}) | Memory size: ({} MiB) | cloud-init ISO path: " + "({}) | VHDX path: ({})", + params.name, + params.processor_count, + params.memory_size_mb, + params.cloudinit_iso_path, + params.vhdx_path); + } +}; + +#endif // MULTIPASS_HYPERV_API_HCS_CREATE_COMPUTE_SYSTEM_PARAMETERS_H diff --git a/src/platform/backends/hyperv_api/hcs/hyperv_hcs_wrapper_interface.h b/src/platform/backends/hyperv_api/hcs/hyperv_hcs_wrapper_interface.h new file mode 100644 index 0000000000..ea1f7348ef --- /dev/null +++ b/src/platform/backends/hyperv_api/hcs/hyperv_hcs_wrapper_interface.h @@ -0,0 +1,59 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#ifndef MULTIPASS_HYPERV_API_HCS_WRAPPER_INTERFACE_H +#define MULTIPASS_HYPERV_API_HCS_WRAPPER_INTERFACE_H + +#include +#include +#include + +#include +#include + +namespace multipass::hyperv::hcs +{ + +/** + * Abstract interface for the Host Compute System API wrapper. + */ +struct HCSWrapperInterface +{ + virtual OperationResult create_compute_system(const CreateComputeSystemParameters& params) const = 0; + virtual OperationResult start_compute_system(const std::string& compute_system_name) const = 0; + virtual OperationResult shutdown_compute_system(const std::string& compute_system_name) const = 0; + virtual OperationResult pause_compute_system(const std::string& compute_system_name) const = 0; + virtual OperationResult resume_compute_system(const std::string& compute_system_name) const = 0; + virtual OperationResult terminate_compute_system(const std::string& compute_system_name) const = 0; + virtual OperationResult get_compute_system_properties(const std::string& compute_system_name) const = 0; + virtual OperationResult grant_vm_access(const std::string& compute_system_name, + const std::filesystem::path& file_path) const = 0; + virtual OperationResult revoke_vm_access(const std::string& compute_system_name, + const std::filesystem::path& file_path) const = 0; + virtual OperationResult add_endpoint(const AddEndpointParameters& params) const = 0; + virtual OperationResult remove_endpoint(const std::string& compute_system_name, + const std::string& endpoint_guid) const = 0; + virtual OperationResult resize_memory(const std::string& compute_system_name, + const std::uint32_t new_size_mib) const = 0; + virtual OperationResult update_cpu_count(const std::string& compute_system_name, + const std::uint32_t new_core_count) const = 0; + virtual OperationResult get_compute_system_state(const std::string& compute_system_name) const = 0; + virtual ~HCSWrapperInterface() = default; +}; +} // namespace multipass::hyperv::hcs + +#endif // MULTIPASS_HYPERV_API_HCS_WRAPPER_INTERFACE_H diff --git a/src/platform/backends/hyperv_api/hyperv_api_common.cpp b/src/platform/backends/hyperv_api/hyperv_api_common.cpp new file mode 100644 index 0000000000..3c1c37b213 --- /dev/null +++ b/src/platform/backends/hyperv_api/hyperv_api_common.cpp @@ -0,0 +1,139 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#include +#include + +#include + +#include // for CLSIDFromString + +#include +#include + +/** + * Formatter for GUID type + */ +template +struct fmt::formatter<::GUID, Char> +{ + constexpr auto parse(basic_format_parse_context& ctx) + { + return ctx.begin(); + } + + template + auto format(const ::GUID& guid, FormatContext& ctx) const + { + // The format string is laid out char by char to allow it + // to be used for initializing variables with different character + // sizes. + static constexpr Char guid_f[] = {'{', ':', '0', '8', 'x', '}', '-', '{', ':', '0', '4', 'x', '}', '-', '{', + ':', '0', '4', 'x', '}', '-', '{', ':', '0', '2', 'x', '}', '{', ':', '0', + '2', 'x', '}', '-', '{', ':', '0', '2', 'x', '}', '{', ':', '0', '2', 'x', + '}', '{', ':', '0', '2', 'x', '}', '{', ':', '0', '2', 'x', '}', '{', ':', + '0', '2', 'x', '}', '{', ':', '0', '2', 'x', '}', 0}; + return format_to(ctx.out(), + guid_f, + guid.Data1, + guid.Data2, + guid.Data3, + guid.Data4[0], + guid.Data4[1], + guid.Data4[2], + guid.Data4[3], + guid.Data4[4], + guid.Data4[5], + guid.Data4[6], + guid.Data4[7]); + } +}; + +namespace multipass::hyperv +{ + +struct GuidParseError : FormattedExceptionBase<> +{ + using FormattedExceptionBase<>::FormattedExceptionBase; +}; + +auto guid_from_wstring(const std::wstring& guid_wstr) -> ::GUID +{ + constexpr static auto kGUIDLength = 36; + constexpr static auto kGUIDLengthWithBraces = kGUIDLength + 2; + + const auto input = [&guid_wstr]() { + switch (guid_wstr.length()) + { + case kGUIDLength: + // CLSIDFromString requires GUIDs to be wrapped with braces. + return fmt::format(L"{{{}}}", guid_wstr); + case kGUIDLengthWithBraces: + { + if (*guid_wstr.begin() != L'{' || *std::prev(guid_wstr.end()) != L'}') + { + throw GuidParseError{"GUID string either does not start or end with a brace."}; + } + return guid_wstr; + } + } + throw GuidParseError{"Invalid length for a GUID string ({}).", guid_wstr.length()}; + }(); + + ::GUID guid = {}; + + const auto result = CLSIDFromString(input.c_str(), &guid); + + if (FAILED(result)) + { + throw GuidParseError{"Failed to parse the GUID string ({}).", result}; + } + + return guid; +} + +// --------------------------------------------------------- + +auto string_to_wstring(const std::string& str) -> std::wstring +{ + return std::wstring_convert>().from_bytes(str); +} + +// --------------------------------------------------------- + +auto guid_from_string(const std::string& guid_str) -> GUID +{ + // Just use the wide string overload. + return guid_from_wstring(string_to_wstring(guid_str)); +} + +// --------------------------------------------------------- + +auto guid_to_string(const ::GUID& guid) -> std::string +{ + + return fmt::format("{}", guid); +} + +// --------------------------------------------------------- + +auto guid_to_wstring(const ::GUID& guid) -> std::wstring +{ + return fmt::format(L"{}", guid); +} + +} // namespace multipass::hyperv diff --git a/src/platform/backends/hyperv_api/hyperv_api_common.h b/src/platform/backends/hyperv_api/hyperv_api_common.h new file mode 100644 index 0000000000..8bc0b021f8 --- /dev/null +++ b/src/platform/backends/hyperv_api/hyperv_api_common.h @@ -0,0 +1,82 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#ifndef MULTIPASS_HYPERV_API_COMMON_H +#define MULTIPASS_HYPERV_API_COMMON_H + +#include + +#include + +namespace multipass::hyperv +{ + +// --------------------------------------------------------- + +/** + * Parse given GUID string into a GUID struct. + * + * @param guid_str GUID in string form, either 36 characters + * (without braces) or 38 characters (with braces.) + * + * @return GUID The parsed GUID + */ +[[nodiscard]] auto guid_from_string(const std::string& guid_str) -> GUID; + +/** + * Parse given GUID string into a GUID struct. + * + * @param guid_wstr GUID in string form, either 36 characters + * (without braces) or 38 characters (with braces.) + * + * @return GUID The parsed GUID + */ +[[nodiscard]] auto guid_from_wstring(const std::wstring& guid_wstr) -> GUID; + +// --------------------------------------------------------- + +/** + * @brief Convert a GUID to its string representation + * + * @param [in] guid GUID to convert + * @return std::string GUID in string form + */ +[[nodiscard]] auto guid_to_string(const ::GUID& guid) -> std::string; + +// --------------------------------------------------------- + +/** + * @brief Convert a guid to its wide string representation + * + * @param [in] guid GUID to convert + * @return std::wstring GUID in wstring form + */ +[[nodiscard]] auto guid_to_wstring(const ::GUID& guid) -> std::wstring; + +// --------------------------------------------------------- + +/** + * Convert a multi-byte string to a wide-character string. + * + * @param str Multi-byte string + * @return Wide-character equivalent of the given multi-byte string. + */ +[[nodiscard]] auto string_to_wstring(const std::string& str) -> std::wstring; + +} // namespace multipass::hyperv + +#endif // MULTIPASS_HYPERV_API_COMMON_H diff --git a/src/platform/backends/hyperv_api/hyperv_api_operation_result.h b/src/platform/backends/hyperv_api/hyperv_api_operation_result.h new file mode 100644 index 0000000000..0a906df477 --- /dev/null +++ b/src/platform/backends/hyperv_api/hyperv_api_operation_result.h @@ -0,0 +1,130 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#ifndef MULTIPASS_HYPERV_API_OPERATION_RESULT_H +#define MULTIPASS_HYPERV_API_OPERATION_RESULT_H + +#include + +#include + +#include + +namespace multipass::hyperv +{ + +/** + * A simple HRESULT wrapper which is boolable for + * convenience. + */ +struct ResultCode +{ + using unsigned_hresult_t = std::make_unsigned_t; + + ResultCode(HRESULT r) noexcept : result(r) + { + } + ResultCode& operator=(HRESULT r) noexcept + { + result = r; + return *this; + } + + [[nodiscard]] explicit operator bool() const noexcept + { + return !FAILED(result); + } + + [[nodiscard]] explicit operator HRESULT() const noexcept + { + return result; + } + + [[nodiscard]] explicit operator unsigned_hresult_t() const noexcept + { + return static_cast(result); + } + +private: + HRESULT result{}; +}; + +/** + * An object that describes the result of an HCN operation + * performed through HCNWrapper. + */ +struct OperationResult +{ + + /** + * Status code of the operation. Evaluates to + * true and greater or equal to 0 on success. + */ + const ResultCode code; + + /** + * A message that describes the result of the operation. + * It might contain an error message describing the error + * when the operation fails, or details regarding the status + * of a successful operation. + */ + const std::wstring status_msg; + + [[nodiscard]] explicit operator bool() const noexcept + { + return static_cast(code); + } +}; +} // namespace multipass::hyperv + +/** + * Formatter type specialization for ResultCode + */ +template +struct fmt::formatter +{ + constexpr auto parse(basic_format_parse_context& ctx) + { + return ctx.begin(); + } + + template + auto format(const multipass::hyperv::ResultCode& rc, FormatContext& ctx) const + { + return format_to(ctx.out(), "{:#x}", static_cast>(rc)); + } +}; + +/** + * Formatter type specialization for ResultCode + */ +template +struct fmt::formatter +{ + constexpr auto parse(basic_format_parse_context& ctx) + { + return ctx.begin(); + } + + template + auto format(const multipass::hyperv::OperationResult& opr, FormatContext& ctx) const + { + return format_to(ctx.out(), "{:#x}", opr.code); + } +}; + +#endif // MULTIPASS_HYPERV_API_OPERATION_RESULT_H diff --git a/src/platform/backends/hyperv_api/virtdisk/virtdisk_api_table.h b/src/platform/backends/hyperv_api/virtdisk/virtdisk_api_table.h new file mode 100644 index 0000000000..e77fe0d8cf --- /dev/null +++ b/src/platform/backends/hyperv_api/virtdisk/virtdisk_api_table.h @@ -0,0 +1,79 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#ifndef MULTIPASS_HYPERV_API_VIRTDISK_API_TABLE +#define MULTIPASS_HYPERV_API_VIRTDISK_API_TABLE + +// clang-format off +#include +#include +#include +// clang-format on + +#include + +#include + +namespace multipass::hyperv::virtdisk +{ + +/** + * API function table for the virtdisk API + * @ref https://learn.microsoft.com/en-us/windows/win32/api/virtdisk/ + */ +struct VirtDiskAPITable +{ + // @ref https://learn.microsoft.com/en-us/windows/win32/api/virtdisk/nf-virtdisk-createvirtualdisk + std::function CreateVirtualDisk = &::CreateVirtualDisk; + // @ref https://learn.microsoft.com/en-us/windows/win32/api/virtdisk/nf-virtdisk-openvirtualdisk + std::function OpenVirtualDisk = &::OpenVirtualDisk; + // @ref https://learn.microsoft.com/en-us/windows/win32/api/virtdisk/nf-virtdisk-resizevirtualdisk + std::function ResizeVirtualDisk = &::ResizeVirtualDisk; + // @ref https://learn.microsoft.com/en-us/windows/win32/api/virtdisk/nf-virtdisk-getvirtualdiskinformation + std::function GetVirtualDiskInformation = &::GetVirtualDiskInformation; + // @ref https://learn.microsoft.com/en-us/windows/win32/api/handleapi/nf-handleapi-closehandle + std::function CloseHandle = &::CloseHandle; +}; + +} // namespace multipass::hyperv::virtdisk + +/** + * Formatter type specialization for VirtDiskAPITable + */ +template +struct fmt::formatter +{ + constexpr auto parse(basic_format_parse_context& ctx) + { + return ctx.begin(); + } + + template + auto format(const multipass::hyperv::virtdisk::VirtDiskAPITable& api, FormatContext& ctx) const + { + return format_to(ctx.out(), + "CreateVirtualDisk: ({}) | OpenVirtualDisk ({}) | ResizeVirtualDisk: ({}) | " + "GetVirtualDiskInformation: ({}) | CloseHandle: ({})", + static_cast(api.CreateVirtualDisk), + static_cast(api.OpenVirtualDisk), + static_cast(api.ResizeVirtualDisk), + static_cast(api.GetVirtualDiskInformation), + static_cast(api.CloseHandle)); + } +}; + +#endif // MULTIPASS_HYPERV_API_VIRTDISK_API_TABLE diff --git a/src/platform/backends/hyperv_api/virtdisk/virtdisk_api_wrapper.cpp b/src/platform/backends/hyperv_api/virtdisk/virtdisk_api_wrapper.cpp new file mode 100644 index 0000000000..b419919fd8 --- /dev/null +++ b/src/platform/backends/hyperv_api/virtdisk/virtdisk_api_wrapper.cpp @@ -0,0 +1,307 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#include + +// clang-format off +#include +#include +#include +// clang-format on + +#include +#include + +namespace multipass::hyperv::virtdisk +{ + +namespace +{ + +using UniqueHandle = std::unique_ptr, decltype(VirtDiskAPITable::CloseHandle)>; + +namespace mpl = logging; +using lvl = mpl::Level; + +/** + * Category for the log messages. + */ +constexpr auto kLogCategory = "HyperV-VirtDisk-Wrapper"; + +UniqueHandle open_virtual_disk(const VirtDiskAPITable& api, const std::filesystem::path& vhdx_path) +{ + mpl::debug(kLogCategory, "open_virtual_disk(...) > vhdx_path: {}", vhdx_path.string()); + // + // Specify UNKNOWN for both device and vendor so the system will use the + // file extension to determine the correct VHD format. + // + VIRTUAL_STORAGE_TYPE type{}; + type.DeviceId = VIRTUAL_STORAGE_TYPE_DEVICE_UNKNOWN; + type.VendorId = VIRTUAL_STORAGE_TYPE_VENDOR_UNKNOWN; + + HANDLE handle{nullptr}; + const auto path_w = vhdx_path.wstring(); + + const auto result = api.OpenVirtualDisk( + // [in] PVIRTUAL_STORAGE_TYPE VirtualStorageType + &type, + // [in] PCWSTR Path + path_w.c_str(), + // [in] VIRTUAL_DISK_ACCESS_MASK VirtualDiskAccessMask + VIRTUAL_DISK_ACCESS_ALL, + // [in] OPEN_VIRTUAL_DISK_FLAG Flags + OPEN_VIRTUAL_DISK_FLAG_NONE, + // [in, optional] POPEN_VIRTUAL_DISK_PARAMETERS Parameters + nullptr, + // [out] PHANDLE Handle + &handle); + + if (!(result == ERROR_SUCCESS)) + { + mpl::error(kLogCategory, "open_virtual_disk(...) > OpenVirtualDisk failed with: {}", result); + return UniqueHandle{nullptr, api.CloseHandle}; + } + + return {handle, api.CloseHandle}; +} + +} // namespace + +// --------------------------------------------------------- + +VirtDiskWrapper::VirtDiskWrapper(const VirtDiskAPITable& api_table) : api{api_table} +{ + mpl::debug(kLogCategory, "VirtDiskWrapper::VirtDiskWrapper(...) > api_table: {}", api); +} + +// --------------------------------------------------------- + +OperationResult VirtDiskWrapper::create_virtual_disk(const CreateVirtualDiskParameters& params) const +{ + mpl::debug(kLogCategory, "create_virtual_disk(...) > params: {}", params); + // + // https://github.com/microsoft/Windows-classic-samples/blob/main/Samples/Hyper-V/Storage/cpp/CreateVirtualDisk.cpp + // + VIRTUAL_STORAGE_TYPE type{}; + + // + // Specify UNKNOWN for both device and vendor so the system will use the + // file extension to determine the correct VHD format. + // + type.DeviceId = VIRTUAL_STORAGE_TYPE_DEVICE_UNKNOWN; + type.VendorId = VIRTUAL_STORAGE_TYPE_VENDOR_UNKNOWN; + + CREATE_VIRTUAL_DISK_PARAMETERS parameters{}; + parameters.Version = CREATE_VIRTUAL_DISK_VERSION_2; + parameters.Version2 = {}; + parameters.Version2.MaximumSize = params.size_in_bytes; + + // + // Internal size of the virtual disk object blocks, in bytes. + // For VHDX this must be a multiple of 1 MB between 1 and 256 MB. + // For VHD 1 this must be set to one of the following values. + // parameters.Version2.BlockSizeInBytes + // + parameters.Version2.BlockSizeInBytes = 1048576; // 1024 KiB + + if (params.path.extension() == ".vhd") + { + parameters.Version2.BlockSizeInBytes = 524288; // 512 KiB + } + + const auto path_w = params.path.wstring(); + + HANDLE result_handle{nullptr}; + + const auto result = api.CreateVirtualDisk(&type, + // [in] PCWSTR Path + path_w.c_str(), + // [in] VIRTUAL_DISK_ACCESS_MASK VirtualDiskAccessMask, + VIRTUAL_DISK_ACCESS_NONE, + // [in, optional] PSECURITY_DESCRIPTOR SecurityDescriptor, + nullptr, + // [in] CREATE_VIRTUAL_DISK_FLAG Flags, + CREATE_VIRTUAL_DISK_FLAG_NONE, + // [in] ULONG ProviderSpecificFlags, + 0, + // [in] PCREATE_VIRTUAL_DISK_PARAMETERS Parameters, + ¶meters, + // [in, optional] LPOVERLAPPED Overlapped + nullptr, + // [out] PHANDLE Handle + &result_handle); + + if (result == ERROR_SUCCESS) + { + [[maybe_unused]] UniqueHandle _{result_handle, api.CloseHandle}; + return OperationResult{NOERROR, L""}; + } + + mpl::error(kLogCategory, "create_virtual_disk(...) > CreateVirtualDisk failed with {}!", result); + return OperationResult{E_FAIL, fmt::format(L"CreateVirtualDisk failed with {}!", result)}; +} + +// --------------------------------------------------------- + +OperationResult VirtDiskWrapper::resize_virtual_disk(const std::filesystem::path& vhdx_path, + std::uint64_t new_size_bytes) const +{ + mpl::debug(kLogCategory, + "resize_virtual_disk(...) > vhdx_path: {}, new_size_bytes: {}", + vhdx_path.string(), + new_size_bytes); + const auto disk_handle = open_virtual_disk(api, vhdx_path); + + if (nullptr == disk_handle) + { + return OperationResult{E_FAIL, L"open_virtual_disk failed!"}; + } + + RESIZE_VIRTUAL_DISK_PARAMETERS params{}; + params.Version = RESIZE_VIRTUAL_DISK_VERSION_1; + params.Version1 = {}; + params.Version1.NewSize = new_size_bytes; + + const auto resize_result = api.ResizeVirtualDisk( + // [in] HANDLE VirtualDiskHandle + disk_handle.get(), + // [in] RESIZE_VIRTUAL_DISK_FLAG Flags + RESIZE_VIRTUAL_DISK_FLAG_NONE, + // [in] PRESIZE_VIRTUAL_DISK_PARAMETERS Parameters + ¶ms, + // [in, optional] LPOVERLAPPED Overlapped + nullptr); + + if (ERROR_SUCCESS == resize_result) + { + return OperationResult{NOERROR, L""}; + } + + mpl::error(kLogCategory, "resize_virtual_disk(...) > ResizeVirtualDisk failed with {}!", resize_result); + + return OperationResult{E_FAIL, fmt::format(L"ResizeVirtualDisk failed with {}!", resize_result)}; +} + +// --------------------------------------------------------- + +OperationResult VirtDiskWrapper::get_virtual_disk_info(const std::filesystem::path& vhdx_path, + VirtualDiskInfo& vdinfo) const +{ + mpl::debug(kLogCategory, "get_virtual_disk_info(...) > vhdx_path: {}", vhdx_path.string()); + // + // https://github.com/microsoft/Windows-classic-samples/blob/main/Samples/Hyper-V/Storage/cpp/GetVirtualDiskInformation.cpp + // + + const auto disk_handle = open_virtual_disk(api, vhdx_path); + + if (nullptr == disk_handle) + { + return OperationResult{E_FAIL, L"open_virtual_disk failed!"}; + } + + constexpr GET_VIRTUAL_DISK_INFO_VERSION what_to_get[] = {GET_VIRTUAL_DISK_INFO_SIZE, + GET_VIRTUAL_DISK_INFO_VIRTUAL_STORAGE_TYPE, + GET_VIRTUAL_DISK_INFO_SMALLEST_SAFE_VIRTUAL_SIZE, + GET_VIRTUAL_DISK_INFO_PROVIDER_SUBTYPE}; + + for (const auto version : what_to_get) + { + GET_VIRTUAL_DISK_INFO disk_info{}; + disk_info.Version = version; + + ULONG sz = sizeof(disk_info); + + const auto result = api.GetVirtualDiskInformation(disk_handle.get(), &sz, &disk_info, nullptr); + + if (ERROR_SUCCESS == result) + { + switch (disk_info.Version) + { + case GET_VIRTUAL_DISK_INFO_SIZE: + vdinfo.size = std::make_optional(); + vdinfo.size->virtual_ = disk_info.Size.VirtualSize; + vdinfo.size->block = disk_info.Size.BlockSize; + vdinfo.size->physical = disk_info.Size.PhysicalSize; + vdinfo.size->sector = disk_info.Size.SectorSize; + break; + case GET_VIRTUAL_DISK_INFO_VIRTUAL_STORAGE_TYPE: + { + switch (disk_info.VirtualStorageType.DeviceId) + { + case VIRTUAL_STORAGE_TYPE_DEVICE_UNKNOWN: + vdinfo.virtual_storage_type = "unknown"; + break; + case VIRTUAL_STORAGE_TYPE_DEVICE_ISO: + vdinfo.virtual_storage_type = "iso"; + break; + case VIRTUAL_STORAGE_TYPE_DEVICE_VHD: + vdinfo.virtual_storage_type = "vhd"; + break; + case VIRTUAL_STORAGE_TYPE_DEVICE_VHDX: + vdinfo.virtual_storage_type = "vhdx"; + break; + case VIRTUAL_STORAGE_TYPE_DEVICE_VHDSET: + vdinfo.virtual_storage_type = "vhdset"; + break; + } + } + break; + case GET_VIRTUAL_DISK_INFO_SMALLEST_SAFE_VIRTUAL_SIZE: + vdinfo.smallest_safe_virtual_size = disk_info.SmallestSafeVirtualSize; + break; + case GET_VIRTUAL_DISK_INFO_PROVIDER_SUBTYPE: + { + enum class ProviderSubtype : std::uint8_t + { + fixed = 2, + dynamic = 3, + differencing = 4 + }; + + switch (static_cast(disk_info.ProviderSubtype)) + { + case ProviderSubtype::fixed: + vdinfo.provider_subtype = "fixed"; + break; + case ProviderSubtype::dynamic: + vdinfo.provider_subtype = "dynamic"; + + break; + case ProviderSubtype::differencing: + vdinfo.provider_subtype = "differencing"; + break; + default: + vdinfo.provider_subtype = "unknown"; + break; + } + } + break; + default: + assert(0); + break; + } + } + else + { + mpl::warn(kLogCategory, "get_virtual_disk_info(...) > failed to get {}", fmt::underlying(version)); + } + } + + return {NOERROR, L""}; +} + +} // namespace multipass::hyperv::virtdisk diff --git a/src/platform/backends/hyperv_api/virtdisk/virtdisk_api_wrapper.h b/src/platform/backends/hyperv_api/virtdisk/virtdisk_api_wrapper.h new file mode 100644 index 0000000000..aedfee658d --- /dev/null +++ b/src/platform/backends/hyperv_api/virtdisk/virtdisk_api_wrapper.h @@ -0,0 +1,95 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#ifndef MULTIPASS_HYPERV_API_VIRTDISK_WRAPPER_H +#define MULTIPASS_HYPERV_API_VIRTDISK_WRAPPER_H + +#include +#include +#include + +namespace multipass::hyperv::virtdisk +{ + +/** + * A high-level wrapper class that defines + * the common operations that Host Compute System + * API provide. + */ +struct VirtDiskWrapper : public VirtDiskWrapperInterface +{ + + /** + * Construct a new HCNWrapper + * + * @param api_table The HCN API table object (optional) + * + * The wrapper will use the real HCN API by default. + */ + VirtDiskWrapper(const VirtDiskAPITable& api_table = {}); + VirtDiskWrapper(const VirtDiskWrapper&) = default; + VirtDiskWrapper(VirtDiskWrapper&&) = default; + VirtDiskWrapper& operator=(const VirtDiskWrapper&) = delete; + VirtDiskWrapper& operator=(VirtDiskWrapper&&) = delete; + + // --------------------------------------------------------- + + /** + * Create a new Virtual Disk + * + * @param [in] params Parameters for the new virtual disk + * + * @return An object that evaluates to true on success, false otherwise. + * message() may contain details of failure when result is false. + */ + [[nodiscard]] OperationResult create_virtual_disk(const CreateVirtualDiskParameters& params) const override; + + // --------------------------------------------------------- + + /** + * Resize an existing Virtual Disk + * + * @param [in] vhdx_path Path to the virtual disk + * @param [in] new_size New disk size, in bytes + * + * @return An object that evaluates to true on success, false otherwise. + * message() may contain details of failure when result is false. + */ + virtual OperationResult resize_virtual_disk(const std::filesystem::path& vhdx_path, + std::uint64_t new_size_bytes) const override; + + // --------------------------------------------------------- + + /** + * Get information about an existing Virtual Disk + * + * @param [in] vhdx_path Path to the virtual disk + * @param [out] vdinfo Virtual disk info output object + * + * @return An object that evaluates to true on success, false otherwise. + * message() may contain details of failure when result is false. + */ + virtual OperationResult get_virtual_disk_info(const std::filesystem::path& vhdx_path, + VirtualDiskInfo& vdinfo) const override; + +private: + const VirtDiskAPITable api{}; +}; + +} // namespace multipass::hyperv::virtdisk + +#endif // MULTIPASS_HYPERV_API_VIRTDISK_WRAPPER_H diff --git a/src/platform/backends/hyperv_api/virtdisk/virtdisk_create_virtual_disk_params.h b/src/platform/backends/hyperv_api/virtdisk/virtdisk_create_virtual_disk_params.h new file mode 100644 index 0000000000..953157184c --- /dev/null +++ b/src/platform/backends/hyperv_api/virtdisk/virtdisk_create_virtual_disk_params.h @@ -0,0 +1,57 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#ifndef MULTIPASS_HYPERV_API_VIRTDISK_CREATE_VIRTUAL_DISK_PARAMETERS_H +#define MULTIPASS_HYPERV_API_VIRTDISK_CREATE_VIRTUAL_DISK_PARAMETERS_H + +#include + +#include + +namespace multipass::hyperv::virtdisk +{ + +/** + * Parameters for creating a new virtual disk drive. + */ +struct CreateVirtualDiskParameters +{ + std::uint64_t size_in_bytes{}; + std::filesystem::path path{}; +}; + +} // namespace multipass::hyperv::virtdisk + +/** + * Formatter type specialization for CreateComputeSystemParameters + */ +template +struct fmt::formatter +{ + constexpr auto parse(basic_format_parse_context& ctx) + { + return ctx.begin(); + } + + template + auto format(const multipass::hyperv::virtdisk::CreateVirtualDiskParameters& params, FormatContext& ctx) const + { + return format_to(ctx.out(), "Size (in bytes): ({}) | Path: ({}) ", params.size_in_bytes, params.path.string()); + } +}; + +#endif // MULTIPASS_HYPERV_API_VIRTDISK_CREATE_VIRTUAL_DISK_PARAMETERS_H diff --git a/src/platform/backends/hyperv_api/virtdisk/virtdisk_disk_info.h b/src/platform/backends/hyperv_api/virtdisk/virtdisk_disk_info.h new file mode 100644 index 0000000000..b967966a36 --- /dev/null +++ b/src/platform/backends/hyperv_api/virtdisk/virtdisk_disk_info.h @@ -0,0 +1,94 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#ifndef MULTIPASS_HYPERV_API_VIRTDISK_DISK_INFO_H +#define MULTIPASS_HYPERV_API_VIRTDISK_DISK_INFO_H + +#include +#include + +#include +#include + +namespace multipass::hyperv::virtdisk +{ + +struct VirtualDiskInfo +{ + + struct size_info + { + std::uint64_t virtual_{}; + std::uint64_t physical{}; + std::uint64_t block{}; + std::uint64_t sector{}; + }; + std::optional size{}; + std::optional smallest_safe_virtual_size{}; + std::optional virtual_storage_type{}; + std::optional provider_subtype{}; +}; + +} // namespace multipass::hyperv::virtdisk + +/** + * Formatter type specialization for CreateComputeSystemParameters + */ +template +struct fmt::formatter +{ + constexpr auto parse(basic_format_parse_context& ctx) + { + return ctx.begin(); + } + + template + auto format(const multipass::hyperv::virtdisk::VirtualDiskInfo::size_info& params, FormatContext& ctx) const + { + return format_to(ctx.out(), + "Virtual: ({}) | Physical: ({}) | Block: ({}) | Sector: ({})", + params.virtual_, + params.physical, + params.block, + params.sector); + } +}; + +/** + * Formatter type specialization for CreateComputeSystemParameters + */ +template +struct fmt::formatter +{ + constexpr auto parse(basic_format_parse_context& ctx) + { + return ctx.begin(); + } + + template + auto format(const multipass::hyperv::virtdisk::VirtualDiskInfo& params, FormatContext& ctx) const + { + return format_to(ctx.out(), + "Storage type: {} | Size: {} | Smallest safe size: {} | Provider subtype: {}", + params.virtual_storage_type, + params.size, + params.smallest_safe_virtual_size, + params.provider_subtype); + } +}; + +#endif // MULTIPASS_HYPERV_API_VIRTDISK_DISK_INFO_H diff --git a/src/platform/backends/hyperv_api/virtdisk/virtdisk_wrapper_interface.h b/src/platform/backends/hyperv_api/virtdisk/virtdisk_wrapper_interface.h new file mode 100644 index 0000000000..566f98a12b --- /dev/null +++ b/src/platform/backends/hyperv_api/virtdisk/virtdisk_wrapper_interface.h @@ -0,0 +1,44 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#ifndef MULTIPASS_HYPERV_API_VIRTDISK_WRAPPER_INTERFACE_H +#define MULTIPASS_HYPERV_API_VIRTDISK_WRAPPER_INTERFACE_H + +#include +#include +#include + +#include + +namespace multipass::hyperv::virtdisk +{ + +/** + * Abstract interface for the virtdisk API wrapper. + */ +struct VirtDiskWrapperInterface +{ + virtual OperationResult create_virtual_disk(const CreateVirtualDiskParameters& params) const = 0; + virtual OperationResult resize_virtual_disk(const std::filesystem::path& vhdx_path, + std::uint64_t new_size_bytes) const = 0; + virtual OperationResult get_virtual_disk_info(const std::filesystem::path& vhdx_path, + VirtualDiskInfo& vdinfo) const = 0; + virtual ~VirtDiskWrapperInterface() = default; +}; +} // namespace multipass::hyperv::virtdisk + +#endif diff --git a/tests/hyperv_api/CMakeLists.txt b/tests/hyperv_api/CMakeLists.txt new file mode 100644 index 0000000000..c55bc03c34 --- /dev/null +++ b/tests/hyperv_api/CMakeLists.txt @@ -0,0 +1,28 @@ +# Copyright (C) Canonical, Ltd. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 3 as +# published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# + + +if(WIN32) + target_sources(multipass_tests + PRIVATE + ${CMAKE_CURRENT_LIST_DIR}/test_it_hyperv_hcn_api.cpp + ${CMAKE_CURRENT_LIST_DIR}/test_ut_hyperv_hcn_api.cpp + ${CMAKE_CURRENT_LIST_DIR}/test_it_hyperv_hcs_api.cpp + ${CMAKE_CURRENT_LIST_DIR}/test_ut_hyperv_hcs_api.cpp + ${CMAKE_CURRENT_LIST_DIR}/test_it_hyperv_virtdisk.cpp + ${CMAKE_CURRENT_LIST_DIR}/test_ut_hyperv_virtdisk.cpp + ${CMAKE_CURRENT_LIST_DIR}/test_bb_cit_hyperv.cpp + ) +endif() diff --git a/tests/hyperv_api/hyperv_test_utils.h b/tests/hyperv_api/hyperv_test_utils.h new file mode 100644 index 0000000000..f301aba5e1 --- /dev/null +++ b/tests/hyperv_api/hyperv_test_utils.h @@ -0,0 +1,74 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#ifndef MULTIPASS_TESTS_HYPERV_API_HYPERV_TEST_UTILS_H +#define MULTIPASS_TESTS_HYPERV_API_HYPERV_TEST_UTILS_H + +#include +#include + +#define EXPECT_NO_CALL(mock) EXPECT_CALL(mock, Call).Times(0) + +namespace multipass::test +{ + +inline auto trim_whitespace(const wchar_t* input) +{ + std::wstring str{input}; + str.erase(std::remove_if(str.begin(), str.end(), ::iswspace), str.end()); + return str; +} + +inline auto make_tempfile_path(std::string extension) +{ + + struct auto_remove_path + { + + auto_remove_path(std::filesystem::path p) : path(p) + { + } + + ~auto_remove_path() noexcept + { + std::error_code ec{}; + // Use the noexcept overload + std::filesystem::remove(path, ec); + } + + operator const std::filesystem::path&() const& noexcept + { + return path; + } + + operator const std::filesystem::path&() const&& noexcept = delete; + + private: + const std::filesystem::path path; + }; + char pattern[] = "temp-XXXXXX"; + if (_mktemp_s(pattern) != 0) + { + throw std::runtime_error{"Incorrect format for _mktemp_s."}; + } + const auto filename = pattern + extension; + return auto_remove_path{std::filesystem::temp_directory_path() / filename}; +} + +} // namespace multipass::test + +#endif diff --git a/tests/hyperv_api/test_bb_cit_hyperv.cpp b/tests/hyperv_api/test_bb_cit_hyperv.cpp new file mode 100644 index 0000000000..1fa8740676 --- /dev/null +++ b/tests/hyperv_api/test_bb_cit_hyperv.cpp @@ -0,0 +1,154 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#include "hyperv_test_utils.h" +#include "tests/common.h" + +#include + +#include +#include +#include +#include +#include + +namespace multipass::test +{ + +using hcn_wrapper_t = hyperv::hcn::HCNWrapper; +using hcs_wrapper_t = hyperv::hcs::HCSWrapper; +using virtdisk_wrapper_t = multipass::hyperv::virtdisk::VirtDiskWrapper; + +// Component level big bang integration tests for Hyper-V HCN/HCS + virtdisk API's. +// These tests ensure that the API's working together as expected. +struct HyperV_ComponentIntegrationTests : public ::testing::Test +{ +}; + +TEST_F(HyperV_ComponentIntegrationTests, spawn_empty_test_vm) +{ + hcn_wrapper_t hcn{}; + hcs_wrapper_t hcs{}; + virtdisk_wrapper_t virtdisk{}; + // 10.0. 0.0 to 10.255. 255.255. + const auto network_parameters = []() { + hyperv::hcn::CreateNetworkParameters network_parameters{}; + network_parameters.name = "multipass-hyperv-cit"; + network_parameters.guid = "b4d77a0e-2507-45f0-99aa-c638f3e47486"; + network_parameters.subnet = "10.99.99.0/24"; + network_parameters.gateway = "10.99.99.1"; + return network_parameters; + }(); + + const auto endpoint_parameters = [&network_parameters]() { + hyperv::hcn::CreateEndpointParameters endpoint_parameters{}; + endpoint_parameters.network_guid = network_parameters.guid; + endpoint_parameters.endpoint_guid = "aee79cf9-54d1-4653-81fb-8110db97029f"; + endpoint_parameters.endpoint_ipvx_addr = "10.99.99.10"; + return endpoint_parameters; + }(); + + const auto temp_path = make_tempfile_path(".vhdx"); + + const auto create_disk_parameters = [&temp_path]() { + hyperv::virtdisk::CreateVirtualDiskParameters create_disk_parameters{}; + create_disk_parameters.path = temp_path; + create_disk_parameters.size_in_bytes = (1024 * 1024) * 512; // 512 MiB + return create_disk_parameters; + }(); + + const auto create_vm_parameters = []() { + hyperv::hcs::CreateComputeSystemParameters vm_parameters{}; + vm_parameters.name = "multipass-hyperv-cit-vm"; + vm_parameters.processor_count = 1; + vm_parameters.memory_size_mb = 512; + return vm_parameters; + }(); + + // Remove remnants from previous tests, if any. + { + if (hcn.delete_endpoint(endpoint_parameters.endpoint_guid)) + { + GTEST_LOG_(WARNING) << "The test endpoint was already present, deleted it."; + } + if (hcn.delete_network(network_parameters.guid)) + { + GTEST_LOG_(WARNING) << "The test network was already present, deleted it."; + } + + if (hcs.terminate_compute_system(create_vm_parameters.name)) + { + GTEST_LOG_(WARNING) << "The test system was already present, terminated it."; + } + } + + const auto add_endpoint_parameters = [&create_vm_parameters, &endpoint_parameters]() { + hyperv::hcs::AddEndpointParameters add_endpoint_parameters{}; + add_endpoint_parameters.endpoint_guid = endpoint_parameters.endpoint_guid; + add_endpoint_parameters.target_compute_system_name = create_vm_parameters.name; + add_endpoint_parameters.nic_mac_address = "00-15-5D-9D-CF-69"; + return add_endpoint_parameters; + }(); + + // Create the test network + { + const auto& [status, status_msg] = hcn.create_network(network_parameters); + ASSERT_TRUE(status); + ASSERT_TRUE(status_msg.empty()); + } + + // Create the test endpoint + { + const auto& [status, status_msg] = hcn.create_endpoint(endpoint_parameters); + ASSERT_TRUE(status); + ASSERT_TRUE(status_msg.empty()); + } + + // Create the test VHDX (empty) + { + const auto& [status, status_msg] = virtdisk.create_virtual_disk(create_disk_parameters); + ASSERT_TRUE(status); + ASSERT_TRUE(status_msg.empty()); + } + + // Create test VM + { + const auto& [status, status_msg] = hcs.create_compute_system(create_vm_parameters); + ASSERT_TRUE(status); + ASSERT_TRUE(status_msg.empty()); + } + + // Start test VM + { + const auto& [status, status_msg] = hcs.start_compute_system(create_vm_parameters.name); + ASSERT_TRUE(status); + ASSERT_TRUE(status_msg.empty()); + } + + // Add endpoint + { + const auto& [status, status_msg] = hcs.add_endpoint(add_endpoint_parameters); + ASSERT_TRUE(status); + ASSERT_TRUE(status_msg.empty()); + } + + EXPECT_TRUE(hcs.terminate_compute_system(create_vm_parameters.name)) << "Terminate system failed!"; + EXPECT_TRUE(hcn.delete_endpoint(endpoint_parameters.endpoint_guid)) << "Delete endpoint failed!"; + EXPECT_TRUE(hcn.delete_network(network_parameters.guid)) << "Delete network failed!"; +} + +} // namespace multipass::test diff --git a/tests/hyperv_api/test_it_hyperv_hcn_api.cpp b/tests/hyperv_api/test_it_hyperv_hcn_api.cpp new file mode 100644 index 0000000000..947d2469aa --- /dev/null +++ b/tests/hyperv_api/test_it_hyperv_hcn_api.cpp @@ -0,0 +1,100 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#include "tests/common.h" + +#include +#include +#include + +namespace multipass::test +{ + +using uut_t = hyperv::hcn::HCNWrapper; + +struct HyperVHCNAPI_IntegrationTests : public ::testing::Test +{ +}; + +TEST_F(HyperVHCNAPI_IntegrationTests, create_delete_network) +{ + uut_t uut; + hyperv::hcn::CreateNetworkParameters params{}; + params.name = "multipass-hyperv-api-hcn-create-delete-test"; + params.guid = "{b70c479d-f808-4053-aafa-705bc15b6d68}"; + params.subnet = "172.50.224.0/20"; + params.gateway = "172.50.224.1"; + + (void)uut.delete_network(params.guid); + + { + const auto& [success, error_msg] = uut.create_network(params); + ASSERT_TRUE(success); + ASSERT_TRUE(error_msg.empty()); + } + + { + const auto& [success, error_msg] = uut.delete_network(params.guid); + ASSERT_TRUE(success); + ASSERT_TRUE(error_msg.empty()); + } +} + +TEST_F(HyperVHCNAPI_IntegrationTests, create_delete_endpoint) +{ + uut_t uut; + hyperv::hcn::CreateNetworkParameters network_params{}; + network_params.name = "multipass-hyperv-api-hcn-create-delete-test"; + network_params.guid = "b70c479d-f808-4053-aafa-705bc15b6d68"; + network_params.subnet = "172.50.224.0/20"; + network_params.gateway = "172.50.224.1"; + + hyperv::hcn::CreateEndpointParameters endpoint_params{}; + + endpoint_params.network_guid = network_params.guid; + endpoint_params.endpoint_guid = "b70c479d-f808-4053-aafa-705bc15b6d70"; + endpoint_params.endpoint_ipvx_addr = "172.50.224.2"; + + (void)uut.delete_network(network_params.guid); + + { + const auto& [success, error_msg] = uut.create_network(network_params); + ASSERT_TRUE(success); + ASSERT_TRUE(error_msg.empty()); + } + + { + const auto& [success, error_msg] = uut.create_endpoint(endpoint_params); + std::wprintf(L"%s\n", error_msg.c_str()); + ASSERT_TRUE(success); + ASSERT_TRUE(error_msg.empty()); + } + + { + const auto& [success, error_msg] = uut.delete_endpoint(endpoint_params.endpoint_guid); + ASSERT_TRUE(success); + ASSERT_TRUE(error_msg.empty()); + } + + { + const auto& [success, error_msg] = uut.delete_network(network_params.guid); + ASSERT_TRUE(success); + ASSERT_TRUE(error_msg.empty()); + } +} + +} // namespace multipass::test diff --git a/tests/hyperv_api/test_it_hyperv_hcs_api.cpp b/tests/hyperv_api/test_it_hyperv_hcs_api.cpp new file mode 100644 index 0000000000..2a3753ab98 --- /dev/null +++ b/tests/hyperv_api/test_it_hyperv_hcs_api.cpp @@ -0,0 +1,132 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#include "tests/common.h" + +#include + +#include + +namespace multipass::test +{ + +using uut_t = hyperv::hcs::HCSWrapper; + +struct HyperVHCSAPI_IntegrationTests : public ::testing::Test +{ +}; + +TEST_F(HyperVHCSAPI_IntegrationTests, create_delete_compute_system) +{ + + uut_t uut{}; + + hyperv::hcs::CreateComputeSystemParameters params{}; + params.name = "test"; + params.memory_size_mb = 1024; + params.processor_count = 1; + params.cloudinit_iso_path = ""; + params.vhdx_path = ""; + + const auto c_result = uut.create_compute_system(params); + + ASSERT_TRUE(c_result); + ASSERT_TRUE(c_result.status_msg.empty()); + + const auto d_result = uut.terminate_compute_system(params.name); + ASSERT_TRUE(d_result); + std::wprintf(L"%s\n", d_result.status_msg.c_str()); + ASSERT_FALSE(d_result.status_msg.empty()); +} + +TEST_F(HyperVHCSAPI_IntegrationTests, enumerate_properties) +{ + + uut_t uut{}; + + hyperv::hcs::CreateComputeSystemParameters params{}; + params.name = "test"; + params.memory_size_mb = 1024; + params.processor_count = 1; + params.cloudinit_iso_path = ""; + params.vhdx_path = ""; + + const auto c_result = uut.create_compute_system(params); + + ASSERT_TRUE(c_result); + ASSERT_TRUE(c_result.status_msg.empty()); + + const auto s_result = uut.start_compute_system(params.name); + ASSERT_TRUE(s_result); + ASSERT_TRUE(s_result.status_msg.empty()); + + const auto p_result = uut.get_compute_system_properties(params.name); + EXPECT_TRUE(p_result); + std::wprintf(L"%s\n", p_result.status_msg.c_str()); + + // const auto e_result = uut.enumerate_all_compute_systems(); + // EXPECT_TRUE(e_result); + // std::wprintf(L"%s\n", e_result.status_msg.c_str()); + + const auto d_result = uut.terminate_compute_system(params.name); + ASSERT_TRUE(d_result); + std::wprintf(L"%s\n", d_result.status_msg.c_str()); + ASSERT_FALSE(d_result.status_msg.empty()); +} + +TEST_F(HyperVHCSAPI_IntegrationTests, DISABLED_update_cpu_count) +{ + + uut_t uut{}; + + hyperv::hcs::CreateComputeSystemParameters params{}; + params.name = "test"; + params.memory_size_mb = 1024; + params.processor_count = 1; + params.cloudinit_iso_path = ""; + params.vhdx_path = ""; + + const auto c_result = uut.create_compute_system(params); + + ASSERT_TRUE(c_result); + ASSERT_TRUE(c_result.status_msg.empty()); + + const auto s_result = uut.start_compute_system(params.name); + ASSERT_TRUE(s_result); + ASSERT_TRUE(s_result.status_msg.empty()); + + const auto p_result = uut.get_compute_system_properties(params.name); + EXPECT_TRUE(p_result); + std::wprintf(L"%s\n", p_result.status_msg.c_str()); + + const auto u_result = uut.update_cpu_count(params.name, 8); + EXPECT_TRUE(u_result); + auto v = fmt::format("{}", u_result.code); + std::wprintf(L"%s\n", u_result.status_msg.c_str()); + std::printf("%s \n", v.c_str()); + + // const auto e_result = uut.enumerate_all_compute_systems(); + // EXPECT_TRUE(e_result); + // std::wprintf(L"%s\n", e_result.status_msg.c_str()); + + const auto d_result = uut.terminate_compute_system(params.name); + ASSERT_TRUE(d_result); + std::wprintf(L"%s\n", d_result.status_msg.c_str()); + ASSERT_FALSE(d_result.status_msg.empty()); +} + +} // namespace multipass::test diff --git a/tests/hyperv_api/test_it_hyperv_virtdisk.cpp b/tests/hyperv_api/test_it_hyperv_virtdisk.cpp new file mode 100644 index 0000000000..a55e7256a4 --- /dev/null +++ b/tests/hyperv_api/test_it_hyperv_virtdisk.cpp @@ -0,0 +1,191 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#include "hyperv_test_utils.h" +#include "tests/common.h" + +#include + +#include + +#include +#include + +namespace multipass::test +{ + +using uut_t = hyperv::virtdisk::VirtDiskWrapper; + +struct HyperVVirtDisk_IntegrationTests : public ::testing::Test +{ +}; + +TEST_F(HyperVVirtDisk_IntegrationTests, create_virtual_disk_vhdx) +{ + auto temp_path = make_tempfile_path(".vhdx"); + std::wprintf(L"Path: %s\n", static_cast(temp_path).c_str()); + + uut_t uut{}; + hyperv::virtdisk::CreateVirtualDiskParameters params{}; + params.path = temp_path; + params.size_in_bytes = 1024 * 1024 * 1024; // 1 GiB + + const auto result = uut.create_virtual_disk(params); + ASSERT_TRUE(result); + ASSERT_TRUE(result.status_msg.empty()); +} + +TEST_F(HyperVVirtDisk_IntegrationTests, create_virtual_disk_vhd) +{ + auto temp_path = make_tempfile_path(".vhd"); + std::wprintf(L"Path: %s\n", static_cast(temp_path).c_str()); + + uut_t uut{}; + hyperv::virtdisk::CreateVirtualDiskParameters params{}; + params.path = temp_path; + params.size_in_bytes = 1024 * 1024 * 1024; // 1 GiB + + const auto result = uut.create_virtual_disk(params); + ASSERT_TRUE(result); + ASSERT_TRUE(result.status_msg.empty()); +} + +TEST_F(HyperVVirtDisk_IntegrationTests, get_virtual_disk_properties) +{ + auto temp_path = make_tempfile_path(".vhdx"); + std::wprintf(L"Path: %s\n", static_cast(temp_path).c_str()); + + uut_t uut{}; + hyperv::virtdisk::CreateVirtualDiskParameters params{}; + params.path = temp_path; + params.size_in_bytes = 1024 * 1024 * 1024; // 1 GiB + + const auto c_result = uut.create_virtual_disk(params); + ASSERT_TRUE(c_result); + ASSERT_TRUE(c_result.status_msg.empty()); + + hyperv::virtdisk::VirtualDiskInfo info{}; + const auto g_result = uut.get_virtual_disk_info(temp_path, info); + + ASSERT_TRUE(info.virtual_storage_type.has_value()); + ASSERT_TRUE(info.size.has_value()); + + ASSERT_STREQ(info.virtual_storage_type.value().c_str(), "vhdx"); + ASSERT_EQ(info.size->virtual_, 1024 * 1024 * 1024); + ASSERT_EQ(info.size->block, 1024 * 1024); + ASSERT_EQ(info.size->sector, 512); + + fmt::print("{}", info); +} + +TEST_F(HyperVVirtDisk_IntegrationTests, resize_grow) +{ + auto temp_path = make_tempfile_path(".vhdx"); + std::wprintf(L"Path: %s\n", static_cast(temp_path).c_str()); + + uut_t uut{}; + hyperv::virtdisk::CreateVirtualDiskParameters params{}; + params.path = temp_path; + params.size_in_bytes = 1024 * 1024 * 1024; // 1 GiB + + const auto c_result = uut.create_virtual_disk(params); + ASSERT_TRUE(c_result); + ASSERT_TRUE(c_result.status_msg.empty()); + + hyperv::virtdisk::VirtualDiskInfo info{}; + const auto g_result = uut.get_virtual_disk_info(temp_path, info); + + ASSERT_TRUE(g_result); + ASSERT_TRUE(info.virtual_storage_type.has_value()); + ASSERT_TRUE(info.size.has_value()); + + ASSERT_STREQ(info.virtual_storage_type.value().c_str(), "vhdx"); + ASSERT_EQ(info.size->virtual_, params.size_in_bytes); + ASSERT_EQ(info.size->block, 1024 * 1024); + ASSERT_EQ(info.size->sector, 512); + + fmt::print("{}", info); + + const auto r_result = uut.resize_virtual_disk(temp_path, params.size_in_bytes * 2); + ASSERT_TRUE(r_result); + + info = {}; + + const auto g2_result = uut.get_virtual_disk_info(temp_path, info); + + ASSERT_TRUE(g2_result); + ASSERT_TRUE(info.virtual_storage_type.has_value()); + ASSERT_TRUE(info.size.has_value()); + + ASSERT_STREQ(info.virtual_storage_type.value().c_str(), "vhdx"); + ASSERT_EQ(info.size->virtual_, params.size_in_bytes * 2); + ASSERT_EQ(info.size->block, 1024 * 1024); + ASSERT_EQ(info.size->sector, 512); + + fmt::print("{}", info); +} + +TEST_F(HyperVVirtDisk_IntegrationTests, DISABLED_resize_shrink) +{ + auto temp_path = make_tempfile_path(".vhdx"); + std::wprintf(L"Path: %s\n", static_cast(temp_path).c_str()); + + uut_t uut{}; + hyperv::virtdisk::CreateVirtualDiskParameters params{}; + params.path = temp_path; + params.size_in_bytes = 1024 * 1024 * 1024; // 1 GiB + + const auto c_result = uut.create_virtual_disk(params); + ASSERT_TRUE(c_result); + ASSERT_TRUE(c_result.status_msg.empty()); + + hyperv::virtdisk::VirtualDiskInfo info{}; + const auto g_result = uut.get_virtual_disk_info(temp_path, info); + + ASSERT_TRUE(g_result); + ASSERT_TRUE(info.virtual_storage_type.has_value()); + ASSERT_TRUE(info.size.has_value()); + + ASSERT_STREQ(info.virtual_storage_type.value().c_str(), "vhdx"); + ASSERT_EQ(info.size->virtual_, params.size_in_bytes); + ASSERT_EQ(info.size->block, 1024 * 1024); + ASSERT_EQ(info.size->sector, 512); + + fmt::print("{}", info); + + const auto r_result = uut.resize_virtual_disk(temp_path, params.size_in_bytes / 2); + ASSERT_TRUE(r_result); + + info = {}; + + // SmallestSafeVirtualSize + + const auto g2_result = uut.get_virtual_disk_info(temp_path, info); + + ASSERT_TRUE(g2_result); + ASSERT_TRUE(info.virtual_storage_type.has_value()); + ASSERT_TRUE(info.size.has_value()); + + ASSERT_STREQ(info.virtual_storage_type.value().c_str(), "vhdx"); + ASSERT_EQ(info.size->virtual_, params.size_in_bytes / 2); + ASSERT_EQ(info.size->block, 1024 * 1024); + ASSERT_EQ(info.size->sector, 512); + + fmt::print("{}", info); +} + +} // namespace multipass::test diff --git a/tests/hyperv_api/test_ut_hyperv_hcn_api.cpp b/tests/hyperv_api/test_ut_hyperv_hcn_api.cpp new file mode 100644 index 0000000000..fe44f67fbb --- /dev/null +++ b/tests/hyperv_api/test_ut_hyperv_hcn_api.cpp @@ -0,0 +1,755 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#include "hyperv_api/hcn/hyperv_hcn_api_table.h" +#include "hyperv_test_utils.h" +#include "tests/mock_logger.h" + +#include "gmock/gmock.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mpt = multipass::test; +namespace mpl = multipass::logging; + +using testing::DoAll; +using testing::Return; + +namespace multipass::test +{ + +using uut_t = hyperv::hcn::HCNWrapper; + +struct HyperVHCNAPI_UnitTests : public ::testing::Test +{ + + mpt::MockLogger::Scope logger_scope = mpt::MockLogger::inject(); + + void SetUp() override + { + + // Each of the unit tests are expected to have their own mock functions + // and override the mock_api_table with them. Hence, the stub mocks should + // not be called at all. + // If any of them do get called, then: + // + // a-) You have forgotten to mock something + // b-) The implementation is using a function that you didn't expect + // + // Either way, you should have a look. + + EXPECT_NO_CALL(stub_mock_create_network); + EXPECT_NO_CALL(stub_mock_open_network); + EXPECT_NO_CALL(stub_mock_delete_network); + EXPECT_NO_CALL(stub_mock_close_network); + EXPECT_NO_CALL(stub_mock_create_endpoint); + EXPECT_NO_CALL(stub_mock_open_endpoint); + EXPECT_NO_CALL(stub_mock_delete_endpoint); + EXPECT_NO_CALL(stub_mock_close_endpoint); + EXPECT_NO_CALL(stub_mock_cotaskmemfree); + } + + void TearDown() override + { + } + + // Set of placeholder mocks in order to catch *unexpected* calls. + ::testing::MockFunction stub_mock_create_network; + ::testing::MockFunction stub_mock_open_network; + ::testing::MockFunction stub_mock_delete_network; + ::testing::MockFunction stub_mock_close_network; + ::testing::MockFunction stub_mock_create_endpoint; + ::testing::MockFunction stub_mock_open_endpoint; + ::testing::MockFunction stub_mock_delete_endpoint; + ::testing::MockFunction stub_mock_close_endpoint; + ::testing::MockFunction stub_mock_cotaskmemfree; + + // Initialize the API table with stub functions, so if any of these fire without + // our will, we'll know. + hyperv::hcn::HCNAPITable mock_api_table{stub_mock_create_network.AsStdFunction(), + stub_mock_open_network.AsStdFunction(), + stub_mock_delete_network.AsStdFunction(), + stub_mock_close_network.AsStdFunction(), + stub_mock_create_endpoint.AsStdFunction(), + stub_mock_open_endpoint.AsStdFunction(), + stub_mock_delete_endpoint.AsStdFunction(), + stub_mock_close_endpoint.AsStdFunction(), + stub_mock_cotaskmemfree.AsStdFunction()}; + + // Sentinel values as mock API parameters. These handles are opaque handles and + // they're not being dereferenced in any way -- only address values are compared. + inline static auto mock_network_object = reinterpret_cast(0xbadf00d); + inline static auto mock_endpoint_object = reinterpret_cast(0xbadcafe); + + // Generic error message for all tests, intended to be used for API calls returning + // an "error_record". + inline static wchar_t mock_error_msg[16] = L"It's a failure."; +}; + +// --------------------------------------------------------- + +/** + * Success scenario: Everything goes as expected. + */ +TEST_F(HyperVHCNAPI_UnitTests, create_network_success) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_create_network; + ::testing::MockFunction mock_close_network; + + mock_api_table.CreateNetwork = mock_create_network.AsStdFunction(); + mock_api_table.CloseNetwork = mock_close_network.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_create_network, Call) + .WillOnce(DoAll( + [&](REFGUID id, PCWSTR settings, PHCN_NETWORK network, PWSTR* error_record) { + constexpr auto expected_network_settings = LR"""( + { + "Name": "multipass-hyperv-api-hcn-create-test", + "Type": "ICS", + "Subnets" : [ + { + "GatewayAddress": "172.50.224.1", + "AddressPrefix" : "172.50.224.0/20", + "IpSubnets" : [ + { + "IpAddressPrefix": "172.50.224.0/20" + } + ] + } + ], + "IsolateSwitch": true, + "Flags" : 265 + } + )"""; + ASSERT_NE(nullptr, network); + ASSERT_EQ(nullptr, *network); + ASSERT_NE(nullptr, error_record); + ASSERT_EQ(nullptr, *error_record); + const auto config_no_whitespace = trim_whitespace(settings); + const auto expected_no_whitespace = trim_whitespace(expected_network_settings); + ASSERT_STREQ(config_no_whitespace.c_str(), expected_no_whitespace.c_str()); + const auto guid_str = hyperv::guid_to_string(id); + ASSERT_EQ("b70c479d-f808-4053-aafa-705bc15b6d68", guid_str); + *network = mock_network_object; + }, + Return(NOERROR))); + + EXPECT_CALL(mock_close_network, Call) + .WillOnce(DoAll([&](HCN_NETWORK n) { ASSERT_EQ(n, mock_network_object); }, Return(NOERROR))); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + uut_t uut{mock_api_table}; + hyperv::hcn::CreateNetworkParameters params{}; + params.name = "multipass-hyperv-api-hcn-create-test"; + params.guid = "{b70c479d-f808-4053-aafa-705bc15b6d68}"; + params.subnet = "172.50.224.0/20"; + params.gateway = "172.50.224.1"; + + const auto& [status, status_msg] = uut.create_network(params); + ASSERT_TRUE(status); + ASSERT_TRUE(status_msg.empty()); + } +} + +// --------------------------------------------------------- + +/** + * Success scenario 2: HcnCloseNetwork returns an error. + */ +TEST_F(HyperVHCNAPI_UnitTests, create_network_close_network_failed) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_create_network; + ::testing::MockFunction mock_close_network; + + mock_api_table.CreateNetwork = mock_create_network.AsStdFunction(); + mock_api_table.CloseNetwork = mock_close_network.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_create_network, Call) + .WillOnce(DoAll( + [&](REFGUID id, PCWSTR settings, PHCN_NETWORK network, PWSTR* error_record) { + *network = mock_network_object; + }, + Return(NOERROR))); + + EXPECT_CALL(mock_close_network, Call) + .WillOnce(DoAll([&](HCN_NETWORK n) { ASSERT_EQ(n, mock_network_object); }, Return(E_POINTER))); + + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCNWrapper::HCNWrapper(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCNWrapper::create_network(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "perform_operation(...)"); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + hyperv::hcn::CreateNetworkParameters params{}; + params.name = "multipass-hyperv-api-hcn-create-test"; + params.guid = "{b70c479d-f808-4053-aafa-705bc15b6d68}"; + params.subnet = "172.50.224.0/20"; + params.gateway = "172.50.224.1"; + + uut_t uut{mock_api_table}; + const auto& [success, error_msg] = uut.create_network(params); + ASSERT_TRUE(success); + ASSERT_TRUE(error_msg.empty()); + } +} + +// --------------------------------------------------------- + +/** + * Failure scenario 1: HcnCreateNetwork returns an error. + */ +TEST_F(HyperVHCNAPI_UnitTests, create_network_failed) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_create_network; + ::testing::MockFunction mock_close_network; + ::testing::MockFunction mock_cotaskmemfree; + + mock_api_table.CreateNetwork = mock_create_network.AsStdFunction(); + mock_api_table.CloseNetwork = mock_close_network.AsStdFunction(); + mock_api_table.CoTaskMemFree = mock_cotaskmemfree.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_create_network, Call) + .WillOnce(DoAll( + [&](REFGUID id, PCWSTR settings, PHCN_NETWORK network, PWSTR* error_record) { + *network = mock_network_object; + *error_record = mock_error_msg; + }, + Return(E_POINTER))); + + EXPECT_CALL(mock_close_network, Call) + .WillOnce(DoAll([&](HCN_NETWORK n) { ASSERT_EQ(n, mock_network_object); }, Return(NOERROR))); + + EXPECT_CALL(mock_cotaskmemfree, Call).WillOnce([&](void* ptr) { EXPECT_EQ(ptr, mock_error_msg); }); + + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCNWrapper::HCNWrapper(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCNWrapper::create_network(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "perform_operation(...)"); + logger_scope.mock_logger->expect_log( + mpl::Level::error, + "HCNWrapper::create_network(...) > HcnCreateNetwork failed with 0x80004003!"); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + hyperv::hcn::CreateNetworkParameters params{}; + params.name = "multipass-hyperv-api-hcn-create-test"; + params.guid = "{b70c479d-f808-4053-aafa-705bc15b6d68}"; + params.subnet = "172.50.224.0/20"; + params.gateway = "172.50.224.1"; + + uut_t uut{mock_api_table}; + const auto& [success, error_msg] = uut.create_network(params); + ASSERT_FALSE(success); + ASSERT_EQ(static_cast(success), E_POINTER); + ASSERT_FALSE(error_msg.empty()); + ASSERT_STREQ(error_msg.c_str(), mock_error_msg); + } +} + +// --------------------------------------------------------- + +/** + * Success scenario: Everything goes as expected. + */ +TEST_F(HyperVHCNAPI_UnitTests, delete_network_success) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_delete_network; + + mock_api_table.DeleteNetwork = mock_delete_network.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_delete_network, Call) + .WillOnce(DoAll( + [&](REFGUID guid, PWSTR* error_record) { + const auto guid_str = hyperv::guid_to_string(guid); + ASSERT_EQ("af3fb745-2f23-463c-8ded-443f876d9e81", guid_str); + ASSERT_EQ(nullptr, *error_record); + ASSERT_NE(nullptr, error_record); + }, + Return(NOERROR))); + + // Expected logs + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCNWrapper::HCNWrapper(...)"); + logger_scope.mock_logger->expect_log( + mpl::Level::debug, + "HCNWrapper::delete_network(...) > network_guid: af3fb745-2f23-463c-8ded-443f876d9e81"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "perform_operation(...) > fn: 0x0, result: true"); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + uut_t uut{mock_api_table}; + const auto& [status, error_msg] = uut.delete_network("af3fb745-2f23-463c-8ded-443f876d9e81"); + ASSERT_TRUE(status); + ASSERT_TRUE(error_msg.empty()); + } +} + +// --------------------------------------------------------- + +/** + * Failure scenario: API call returns non-success + */ +TEST_F(HyperVHCNAPI_UnitTests, delete_network_failed) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_delete_network; + ::testing::MockFunction mock_cotaskmemfree; + + mock_api_table.DeleteNetwork = mock_delete_network.AsStdFunction(); + mock_api_table.CoTaskMemFree = mock_cotaskmemfree.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_delete_network, Call) + .WillOnce(DoAll( + [&](REFGUID, PWSTR* error_record) { + ASSERT_EQ(nullptr, *error_record); + ASSERT_NE(nullptr, error_record); + *error_record = mock_error_msg; + }, + Return(E_POINTER))); + + EXPECT_CALL(mock_cotaskmemfree, Call).WillOnce([&](void* ptr) { EXPECT_EQ(ptr, mock_error_msg); }); + // Expected logs + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCNWrapper::HCNWrapper(...)"); + logger_scope.mock_logger->expect_log( + mpl::Level::debug, + "HCNWrapper::delete_network(...) > network_guid: af3fb745-2f23-463c-8ded-443f876d9e81"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "perform_operation(...) > fn: 0x0, result: false"); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + uut_t uut{mock_api_table}; + const auto& [status, error_msg] = uut.delete_network("af3fb745-2f23-463c-8ded-443f876d9e81"); + ASSERT_FALSE(status); + ASSERT_FALSE(error_msg.empty()); + ASSERT_STREQ(error_msg.c_str(), mock_error_msg); + } +} + +// --------------------------------------------------------- + +/** + * Success scenario: Everything goes as expected. + */ +TEST_F(HyperVHCNAPI_UnitTests, create_endpoint_success) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_create_endpoint; + ::testing::MockFunction mock_close_endpoint; + ::testing::MockFunction mock_open_network; + ::testing::MockFunction mock_close_network; + + mock_api_table.CreateEndpoint = mock_create_endpoint.AsStdFunction(); + mock_api_table.CloseEndpoint = mock_close_endpoint.AsStdFunction(); + mock_api_table.OpenNetwork = mock_open_network.AsStdFunction(); + mock_api_table.CloseNetwork = mock_close_network.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_create_endpoint, Call) + .WillOnce(DoAll( + [&](HCN_NETWORK network, REFGUID id, PCWSTR settings, PHCN_ENDPOINT endpoint, PWSTR* error_record) { + constexpr auto expected_endpoint_settings = LR"""( + { + "SchemaVersion": { + "Major": 2, + "Minor": 16 + }, + "HostComputeNetwork": "b70c479d-f808-4053-aafa-705bc15b6d68", + "Policies": [ + ], + "IpConfigurations": [ + { + "IpAddress": "172.50.224.27" + } + ] + })"""; + + ASSERT_NE(nullptr, network); + ASSERT_EQ(mock_network_object, network); + ASSERT_NE(nullptr, error_record); + ASSERT_EQ(nullptr, *error_record); + ASSERT_NE(nullptr, endpoint); + ASSERT_EQ(nullptr, *endpoint); + const auto config_no_whitespace = trim_whitespace(settings); + const auto expected_no_whitespace = trim_whitespace(expected_endpoint_settings); + ASSERT_STREQ(config_no_whitespace.c_str(), expected_no_whitespace.c_str()); + const auto endpoint_guid_str = hyperv::guid_to_string(id); + ASSERT_EQ("77c27c1e-8204-437d-a7cc-fb4ce1614819", endpoint_guid_str); + *endpoint = mock_endpoint_object; + }, + Return(NOERROR))); + + EXPECT_CALL(mock_close_endpoint, Call) + .WillOnce(DoAll([&](HCN_ENDPOINT n) { ASSERT_EQ(n, mock_endpoint_object); }, Return(NOERROR))); + + EXPECT_CALL(mock_open_network, Call) + .WillOnce(DoAll( + [&](REFGUID id, PHCN_NETWORK network, PWSTR* error_record) { + const auto expected_network_guid_str = hyperv::guid_to_string(id); + ASSERT_EQ("b70c479d-f808-4053-aafa-705bc15b6d68", expected_network_guid_str); + ASSERT_NE(nullptr, network); + ASSERT_EQ(nullptr, *network); + ASSERT_NE(nullptr, error_record); + ASSERT_EQ(nullptr, *error_record); + *network = mock_network_object; + }, + Return(NOERROR))); + + EXPECT_CALL(mock_close_network, Call) + .WillOnce(DoAll([&](HCN_NETWORK n) { ASSERT_EQ(n, mock_network_object); }, Return(NOERROR))); + + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCNWrapper::HCNWrapper(...)"); + logger_scope.mock_logger->expect_log( + mpl::Level::debug, + "HCNWrapper::create_endpoint(...) > params: Endpoint GUID: (77c27c1e-8204-437d-a7cc-fb4ce1614819) | " + "Network GUID: (b70c479d-f808-4053-aafa-705bc15b6d68) | Endpoint IPvX Addr.: (172.50.224.27)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, + "open_network(...) > network_guid: b70c479d-f808-4053-aafa-705bc15b6d68"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, + "perform_operation(...) > fn: 0x0, result: true", + testing::Exactly(2)); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + uut_t uut{mock_api_table}; + hyperv::hcn::CreateEndpointParameters params{}; + params.endpoint_guid = "77c27c1e-8204-437d-a7cc-fb4ce1614819"; + params.network_guid = "b70c479d-f808-4053-aafa-705bc15b6d68"; + params.endpoint_ipvx_addr = "172.50.224.27"; + + const auto& [success, error_msg] = uut.create_endpoint(params); + ASSERT_TRUE(success); + ASSERT_TRUE(error_msg.empty()); + } +} + +// --------------------------------------------------------- + +/** + * Failure scenario: internal open_network call fails. + */ +TEST_F(HyperVHCNAPI_UnitTests, create_endpoint_open_network_failed) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_open_network; + + mock_api_table.OpenNetwork = mock_open_network.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_open_network, Call).WillOnce(Return(E_POINTER)); + + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCNWrapper::HCNWrapper(...)"); + logger_scope.mock_logger->expect_log( + mpl::Level::debug, + "HCNWrapper::create_endpoint(...) > params: Endpoint GUID: (77c27c1e-8204-437d-a7cc-fb4ce1614819) | " + "Network GUID: (b70c479d-f808-4053-aafa-705bc15b6d68) | Endpoint IPvX Addr.: (172.50.224.27)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, + "open_network(...) > network_guid: b70c479d-f808-4053-aafa-705bc15b6d68"); + logger_scope.mock_logger->expect_log(mpl::Level::error, + "open_network() > HcnOpenNetwork failed with 0x80004003!"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "perform_operation(...) > fn: 0x0, result: false"); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + uut_t uut{mock_api_table}; + hyperv::hcn::CreateEndpointParameters params{}; + params.endpoint_guid = "77c27c1e-8204-437d-a7cc-fb4ce1614819"; + params.network_guid = "b70c479d-f808-4053-aafa-705bc15b6d68"; + params.endpoint_ipvx_addr = "172.50.224.27"; + + const auto& [status, error_msg] = uut.create_endpoint(params); + ASSERT_FALSE(status); + ASSERT_EQ(E_POINTER, static_cast(status)); + ASSERT_FALSE(error_msg.empty()); + ASSERT_STREQ(error_msg.c_str(), L"Could not open the network!"); + } +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCNAPI_UnitTests, create_endpoint_failure) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + + ::testing::MockFunction mock_create_endpoint; + ::testing::MockFunction mock_close_endpoint; + ::testing::MockFunction mock_open_network; + ::testing::MockFunction mock_close_network; + ::testing::MockFunction mock_cotaskmemfree; + + mock_api_table.CreateEndpoint = mock_create_endpoint.AsStdFunction(); + mock_api_table.CloseEndpoint = mock_close_endpoint.AsStdFunction(); + mock_api_table.OpenNetwork = mock_open_network.AsStdFunction(); + mock_api_table.CloseNetwork = mock_close_network.AsStdFunction(); + mock_api_table.CoTaskMemFree = mock_cotaskmemfree.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_create_endpoint, Call) + .WillOnce(DoAll( + [&](HCN_NETWORK network, REFGUID id, PCWSTR settings, PHCN_ENDPOINT endpoint, PWSTR* error_record) { + constexpr auto expected_endpoint_settings = LR"""( + { + "SchemaVersion": { + "Major": 2, + "Minor": 16 + }, + "HostComputeNetwork": "b70c479d-f808-4053-aafa-705bc15b6d68", + "Policies": [ + ], + "IpConfigurations": [ + { + "IpAddress": "172.50.224.27" + } + ] + })"""; + + ASSERT_EQ(mock_network_object, network); + ASSERT_NE(nullptr, error_record); + const auto config_no_whitespace = trim_whitespace(settings); + const auto expected_no_whitespace = trim_whitespace(expected_endpoint_settings); + ASSERT_STREQ(config_no_whitespace.c_str(), expected_no_whitespace.c_str()); + const auto expected_endpoint_guid_str = hyperv::guid_to_string(id); + ASSERT_EQ("77c27c1e-8204-437d-a7cc-fb4ce1614819", expected_endpoint_guid_str); + *endpoint = mock_endpoint_object; + *error_record = mock_error_msg; + }, + Return(E_POINTER))); + + EXPECT_CALL(mock_close_endpoint, Call) + .WillOnce(DoAll([&](HCN_ENDPOINT n) { ASSERT_EQ(n, mock_endpoint_object); }, Return(NOERROR))); + + EXPECT_CALL(mock_open_network, Call) + .WillOnce(DoAll( + [&](REFGUID id, PHCN_NETWORK network, PWSTR* error_record) { + const auto expected_network_guid_str = hyperv::guid_to_string(id); + ASSERT_EQ("b70c479d-f808-4053-aafa-705bc15b6d68", expected_network_guid_str); + ASSERT_NE(nullptr, error_record); + ASSERT_EQ(nullptr, *error_record); + *network = mock_network_object; + }, + Return(NOERROR))); + + EXPECT_CALL(mock_close_network, Call) + .WillOnce(DoAll([&](HCN_NETWORK n) { ASSERT_EQ(n, mock_network_object); }, Return(NOERROR))); + + EXPECT_CALL(mock_cotaskmemfree, Call).WillOnce([](const void* ptr) { ASSERT_EQ(ptr, mock_error_msg); }); + + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCNWrapper::HCNWrapper(...)"); + logger_scope.mock_logger->expect_log( + mpl::Level::debug, + "HCNWrapper::create_endpoint(...) > params: Endpoint GUID: (77c27c1e-8204-437d-a7cc-fb4ce1614819) | " + "Network GUID: (b70c479d-f808-4053-aafa-705bc15b6d68) | Endpoint IPvX Addr.: (172.50.224.27)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, + "open_network(...) > network_guid: b70c479d-f808-4053-aafa-705bc15b6d68"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "perform_operation(...) > fn: 0x0, result: true"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "perform_operation(...) > fn: 0x0, result: false"); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + uut_t uut{mock_api_table}; + hyperv::hcn::CreateEndpointParameters params{}; + params.endpoint_guid = "77c27c1e-8204-437d-a7cc-fb4ce1614819"; + params.network_guid = "b70c479d-f808-4053-aafa-705bc15b6d68"; + params.endpoint_ipvx_addr = "172.50.224.27"; + + const auto& [success, error_msg] = uut.create_endpoint(params); + ASSERT_FALSE(success); + ASSERT_FALSE(error_msg.empty()); + ASSERT_STREQ(error_msg.c_str(), mock_error_msg); + } +} + +// --------------------------------------------------------- + +/** + * Success scenario: Everything goes as expected. + */ +TEST_F(HyperVHCNAPI_UnitTests, delete_endpoint_success) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_delete_endpoint; + + mock_api_table.DeleteEndpoint = mock_delete_endpoint.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_delete_endpoint, Call) + .WillOnce(DoAll( + [&](REFGUID guid, PWSTR* error_record) { + const auto guid_str = hyperv::guid_to_string(guid); + ASSERT_EQ("af3fb745-2f23-463c-8ded-443f876d9e81", guid_str); + ASSERT_EQ(nullptr, *error_record); + ASSERT_NE(nullptr, error_record); + }, + Return(NOERROR))); + + // Expected logs + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCNWrapper::HCNWrapper(...)"); + logger_scope.mock_logger->expect_log( + mpl::Level::debug, + "HCNWrapper::delete_endpoint(...) > endpoint_guid: af3fb745-2f23-463c-8ded-443f876d9e81"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "perform_operation(...) > fn: 0x0, result: true"); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + uut_t uut{mock_api_table}; + const auto& [status, error_msg] = uut.delete_endpoint("af3fb745-2f23-463c-8ded-443f876d9e81"); + ASSERT_TRUE(status); + ASSERT_TRUE(error_msg.empty()); + } +} + +// --------------------------------------------------------- + +/** + * Success scenario: Everything goes as expected. + */ +TEST_F(HyperVHCNAPI_UnitTests, delete_endpoint_failure) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_delete_endpoint; + ::testing::MockFunction mock_cotaskmemfree; + + mock_api_table.DeleteEndpoint = mock_delete_endpoint.AsStdFunction(); + mock_api_table.CoTaskMemFree = mock_cotaskmemfree.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_delete_endpoint, Call) + .WillOnce(DoAll([&](REFGUID, PWSTR* error_record) { *error_record = mock_error_msg; }, Return(E_POINTER))); + + EXPECT_CALL(mock_cotaskmemfree, Call).WillOnce([](const void* ptr) { ASSERT_EQ(ptr, mock_error_msg); }); + + // Expected logs + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCNWrapper::HCNWrapper(...)"); + logger_scope.mock_logger->expect_log( + mpl::Level::debug, + "HCNWrapper::delete_endpoint(...) > endpoint_guid: af3fb745-2f23-463c-8ded-443f876d9e81"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "perform_operation(...) > fn: 0x0, result: false"); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + uut_t uut{mock_api_table}; + const auto& [status, error_msg] = uut.delete_endpoint("af3fb745-2f23-463c-8ded-443f876d9e81"); + ASSERT_FALSE(status); + ASSERT_FALSE(error_msg.empty()); + ASSERT_STREQ(error_msg.c_str(), mock_error_msg); + } +} + +} // namespace multipass::test diff --git a/tests/hyperv_api/test_ut_hyperv_hcs_api.cpp b/tests/hyperv_api/test_ut_hyperv_hcs_api.cpp new file mode 100644 index 0000000000..b09ad775fb --- /dev/null +++ b/tests/hyperv_api/test_ut_hyperv_hcs_api.cpp @@ -0,0 +1,2708 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#include "hyperv_api/hcs/hyperv_hcs_add_endpoint_params.h" +#include "hyperv_api/hcs/hyperv_hcs_api_wrapper.h" +#include "hyperv_api/hcs/hyperv_hcs_create_compute_system_params.h" +#include "hyperv_test_utils.h" +#include "tests/mock_logger.h" +#include "gmock/gmock.h" + +#include +#include +#include +#include +#include + +namespace mpt = multipass::test; +namespace mpl = multipass::logging; + +using testing::DoAll; +using testing::Return; + +namespace multipass::test +{ +using uut_t = hyperv::hcs::HCSWrapper; + +struct HyperVHCSAPI_UnitTests : public ::testing::Test +{ + mpt::MockLogger::Scope logger_scope = mpt::MockLogger::inject(); + + void SetUp() override + { + + // Each of the unit tests are expected to have their own mock functions + // and override the mock_api_table with them. Hence, the stub mocks should + // not be called at all. + // If any of them do get called, then: + // + // a-) You have forgotten to mock something + // b-) The implementation is using a function that you didn't expect + // + // Either way, you should have a look. + EXPECT_NO_CALL(stub_mock_create_operation); + EXPECT_NO_CALL(stub_mock_wait_for_operation_result); + EXPECT_NO_CALL(stub_mock_close_operation); + EXPECT_NO_CALL(stub_mock_create_compute_system); + EXPECT_NO_CALL(stub_mock_open_compute_system); + EXPECT_NO_CALL(stub_mock_start_compute_system); + EXPECT_NO_CALL(stub_mock_shutdown_compute_system); + EXPECT_NO_CALL(stub_mock_terminate_compute_system); + EXPECT_NO_CALL(stub_mock_close_compute_system); + EXPECT_NO_CALL(stub_mock_pause_compute_system); + EXPECT_NO_CALL(stub_mock_resume_compute_system); + EXPECT_NO_CALL(stub_mock_modify_compute_system); + EXPECT_NO_CALL(stub_mock_get_compute_system_properties); + EXPECT_NO_CALL(stub_mock_grant_vm_access); + EXPECT_NO_CALL(stub_mock_revoke_vm_access); + EXPECT_NO_CALL(stub_mock_enumerate_compute_systems); + EXPECT_NO_CALL(stub_mock_local_free); + } + + void TearDown() override + { + } + + // Set of placeholder mocks in order to catch *unexpected* calls. + ::testing::MockFunction stub_mock_create_operation; + ::testing::MockFunction stub_mock_wait_for_operation_result; + ::testing::MockFunction stub_mock_close_operation; + ::testing::MockFunction stub_mock_create_compute_system; + ::testing::MockFunction stub_mock_open_compute_system; + ::testing::MockFunction stub_mock_start_compute_system; + ::testing::MockFunction stub_mock_shutdown_compute_system; + ::testing::MockFunction stub_mock_terminate_compute_system; + ::testing::MockFunction stub_mock_close_compute_system; + ::testing::MockFunction stub_mock_pause_compute_system; + ::testing::MockFunction stub_mock_resume_compute_system; + ::testing::MockFunction stub_mock_modify_compute_system; + ::testing::MockFunction stub_mock_get_compute_system_properties; + ::testing::MockFunction stub_mock_grant_vm_access; + ::testing::MockFunction stub_mock_revoke_vm_access; + ::testing::MockFunction stub_mock_enumerate_compute_systems; + ::testing::MockFunction stub_mock_local_free; + + // Initialize the API table with stub functions, so if any of these fire without + // our will, we'll know. + hyperv::hcs::HCSAPITable mock_api_table{stub_mock_create_operation.AsStdFunction(), + stub_mock_wait_for_operation_result.AsStdFunction(), + stub_mock_close_operation.AsStdFunction(), + stub_mock_create_compute_system.AsStdFunction(), + stub_mock_open_compute_system.AsStdFunction(), + stub_mock_start_compute_system.AsStdFunction(), + stub_mock_shutdown_compute_system.AsStdFunction(), + stub_mock_terminate_compute_system.AsStdFunction(), + stub_mock_close_compute_system.AsStdFunction(), + stub_mock_pause_compute_system.AsStdFunction(), + stub_mock_resume_compute_system.AsStdFunction(), + stub_mock_modify_compute_system.AsStdFunction(), + stub_mock_get_compute_system_properties.AsStdFunction(), + stub_mock_grant_vm_access.AsStdFunction(), + stub_mock_revoke_vm_access.AsStdFunction(), + stub_mock_enumerate_compute_systems.AsStdFunction(), + stub_mock_local_free.AsStdFunction()}; + + // Sentinel values as mock API parameters. These handles are opaque handles and + // they're not being dereferenced in any way -- only address values are compared. + inline static auto mock_operation_object = reinterpret_cast(0xbadf00d); + inline static auto mock_compute_system_object = reinterpret_cast(0xbadcafe); + + // Generic error message for all tests, intended to be used for API calls returning + // an "error_record". + inline static wchar_t mock_error_msg[16] = L"It's a failure."; + inline static wchar_t mock_success_msg[16] = L"Succeeded."; + inline static wchar_t operation_fail_msg[22] = L"HCS operation failed!"; + inline static wchar_t hcs_create_operation_fail_msg[27] = L"HcsCreateOperation failed!"; + inline static wchar_t hcs_open_compute_system_fail_msg[29] = L"HcsOpenComputeSystem failed!"; + + template + void generic_operation_happy_path(ApiFnT& target_api_function, + UutCallableT uut_callback, + MockCallableT mock_callback, + PWSTR operation_result_document = nullptr, + PWSTR expected_status_msg = nullptr); + + template + void generic_operation_fail(ApiFnT& target_api_function, + UutCallableT uut_callback, + MockCallableT mock_callback, + PWSTR expected_status_msg = operation_fail_msg); + + template + void generic_operation_wait_for_operation_fail(ApiFnT& target_api_function, + UutCallableT uut_callback, + MockCallableT mock_callback, + PWSTR operation_result_document = mock_error_msg, + PWSTR expected_status_msg = mock_error_msg); + + template + void generic_operation_hcs_open_fail(ApiFnT& target_api_function, + UutCallableT uut_callback, + PWSTR expected_status_msg = hcs_open_compute_system_fail_msg); + + template + void generic_operation_create_operation_fail(ApiFnT& target_api_function, + UutCallableT uut_callback, + PWSTR expected_status_msg = hcs_create_operation_fail_msg); +}; + +// --------------------------------------------------------- + +/** + * Success scenario: Everything goes as expected. + */ +TEST_F(HyperVHCSAPI_UnitTests, create_compute_system_happy_path) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_create_operation; + ::testing::MockFunction mock_close_operation; + ::testing::MockFunction mock_wait_for_operation_result; + ::testing::MockFunction mock_create_compute_system; + ::testing::MockFunction mock_close_compute_system; + ::testing::MockFunction mock_local_free; + + mock_api_table.CreateOperation = mock_create_operation.AsStdFunction(); + mock_api_table.CloseOperation = mock_close_operation.AsStdFunction(); + mock_api_table.WaitForOperationResult = mock_wait_for_operation_result.AsStdFunction(); + mock_api_table.CreateComputeSystem = mock_create_compute_system.AsStdFunction(); + mock_api_table.CloseComputeSystem = mock_close_compute_system.AsStdFunction(); + mock_api_table.LocalFree = mock_local_free.AsStdFunction(); + + constexpr static auto expected_vm_settings_json = LR"( + { + "SchemaVersion": { + "Major": 2, + "Minor": 1 + }, + "Owner": "Multipass", + "ShouldTerminateOnLastHandleClosed": false, + "VirtualMachine": { + "Chipset": { + "Uefi": { + "BootThis": { + "DevicePath": "Primary disk", + "DiskNumber": 0, + "DeviceType": "ScsiDrive" + }, + "Console": "ComPort1" + } + }, + "ComputeTopology": { + "Memory": { + "Backing": "Virtual", + "SizeInMB": 16384 + }, + "Processor": { + "Count": 8 + } + }, + "Devices": { + "ComPorts": { + "0": { + "NamedPipe": "\\\\.\\pipe\\test_vm" + } + }, + "Scsi": { + "cloud-init iso file": { + "Attachments": { + "0": { + "Type": "Iso", + "Path": "cloudinit iso path", + "ReadOnly": true + } + } + }, + "Primary disk": { + "Attachments": { + "0": { + "Type": "VirtualDisk", + "Path": "virtual disk path", + "ReadOnly": false + } + } + }, + } + } + } + })"; + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_create_operation, Call) + .WillOnce(DoAll( + [](const void* context, HCS_OPERATION_COMPLETION callback) { + ASSERT_EQ(nullptr, context); + ASSERT_EQ(nullptr, callback); + }, + Return(mock_operation_object))); + + EXPECT_CALL(mock_close_operation, Call).WillOnce([](HCS_OPERATION op) { + ASSERT_EQ(op, mock_operation_object); + }); + + EXPECT_CALL(mock_wait_for_operation_result, Call) + .WillOnce(DoAll( + [](HCS_OPERATION operation, DWORD timeoutMs, PWSTR* resultDocument) { + ASSERT_EQ(operation, mock_operation_object); + ASSERT_EQ(timeoutMs, 240000); + ASSERT_NE(nullptr, resultDocument); + ASSERT_EQ(nullptr, *resultDocument); + *resultDocument = mock_success_msg; + }, + Return(NOERROR))); + + EXPECT_CALL(mock_create_compute_system, Call) + .WillOnce(DoAll( + [](PCWSTR id, + PCWSTR configuration, + HCS_OPERATION operation, + const SECURITY_DESCRIPTOR* securityDescriptor, + HCS_SYSTEM* computeSystem) { + ASSERT_STREQ(L"test_vm", id); + + const auto config_no_whitespace = trim_whitespace(configuration); + const auto expected_no_whitespace = trim_whitespace(expected_vm_settings_json); + + ASSERT_STREQ(expected_no_whitespace.c_str(), config_no_whitespace.c_str()); + ASSERT_EQ(mock_operation_object, operation); + ASSERT_EQ(nullptr, securityDescriptor); + ASSERT_NE(nullptr, computeSystem); + ASSERT_EQ(nullptr, *computeSystem); + *computeSystem = mock_compute_system_object; + }, + Return(NOERROR))); + + EXPECT_CALL(mock_close_compute_system, Call).WillOnce([](HCS_SYSTEM computeSystem) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + }); + + EXPECT_CALL(mock_local_free, Call) + .WillOnce(DoAll([](HLOCAL ptr) { ASSERT_EQ(ptr, mock_success_msg); }, Return(nullptr))); + + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCSWrapper::HCSWrapper(...)"); + logger_scope.mock_logger->expect_log( + mpl::Level::debug, + "HCSWrapper::create_compute_system(...) > params: Compute System name: (test_vm) | vCPU count: (8) | " + "Memory size: (16384 MiB) | cloud-init ISO path: (cloudinit iso path) | VHDX path: (virtual disk path)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "create_operation(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "wait_for_operation_result(...)"); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + uut_t uut{mock_api_table}; + hyperv::hcs::CreateComputeSystemParameters params{}; + params.name = "test_vm"; + params.cloudinit_iso_path = "cloudinit iso path"; + params.vhdx_path = "virtual disk path"; + params.memory_size_mb = 16384; + params.processor_count = 8; + + const auto& [status, status_msg] = uut.create_compute_system(params); + ASSERT_TRUE(status); + ASSERT_FALSE(status_msg.empty()); + ASSERT_STREQ(status_msg.c_str(), mock_success_msg); + } +} + +// --------------------------------------------------------- + +/** + * Success scenario: Everything goes as expected. + */ +TEST_F(HyperVHCSAPI_UnitTests, create_compute_system_wo_cloudinit) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_create_operation; + ::testing::MockFunction mock_close_operation; + ::testing::MockFunction mock_wait_for_operation_result; + ::testing::MockFunction mock_create_compute_system; + ::testing::MockFunction mock_close_compute_system; + ::testing::MockFunction mock_local_free; + + mock_api_table.CreateOperation = mock_create_operation.AsStdFunction(); + mock_api_table.CloseOperation = mock_close_operation.AsStdFunction(); + mock_api_table.WaitForOperationResult = mock_wait_for_operation_result.AsStdFunction(); + mock_api_table.CreateComputeSystem = mock_create_compute_system.AsStdFunction(); + mock_api_table.CloseComputeSystem = mock_close_compute_system.AsStdFunction(); + mock_api_table.LocalFree = mock_local_free.AsStdFunction(); + + constexpr static auto expected_vm_settings_json = LR"( + { + "SchemaVersion": { + "Major": 2, + "Minor": 1 + }, + "Owner": "Multipass", + "ShouldTerminateOnLastHandleClosed": false, + "VirtualMachine": { + "Chipset": { + "Uefi": { + "BootThis": { + "DevicePath": "Primary disk", + "DiskNumber": 0, + "DeviceType": "ScsiDrive" + }, + "Console": "ComPort1" + } + }, + "ComputeTopology": { + "Memory": { + "Backing": "Virtual", + "SizeInMB": 16384 + }, + "Processor": { + "Count": 8 + } + }, + "Devices": { + "ComPorts": { + "0": { + "NamedPipe": "\\\\.\\pipe\\test_vm" + } + }, + "Scsi": { + "Primary disk": { + "Attachments": { + "0": { + "Type": "VirtualDisk", + "Path": "virtual disk path", + "ReadOnly": false + } + } + }, + } + } + } + })"; + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_create_operation, Call) + .WillOnce(DoAll( + [](const void* context, HCS_OPERATION_COMPLETION callback) { + ASSERT_EQ(nullptr, context); + ASSERT_EQ(nullptr, callback); + }, + Return(mock_operation_object))); + + EXPECT_CALL(mock_close_operation, Call).WillOnce([](HCS_OPERATION op) { + ASSERT_EQ(op, mock_operation_object); + }); + + EXPECT_CALL(mock_wait_for_operation_result, Call) + .WillOnce(DoAll( + [](HCS_OPERATION operation, DWORD timeoutMs, PWSTR* resultDocument) { + ASSERT_EQ(operation, mock_operation_object); + ASSERT_EQ(timeoutMs, 240000); + ASSERT_NE(nullptr, resultDocument); + ASSERT_EQ(nullptr, *resultDocument); + *resultDocument = mock_success_msg; + }, + Return(NOERROR))); + + EXPECT_CALL(mock_create_compute_system, Call) + .WillOnce(DoAll( + [](PCWSTR id, + PCWSTR configuration, + HCS_OPERATION operation, + const SECURITY_DESCRIPTOR* securityDescriptor, + HCS_SYSTEM* computeSystem) { + ASSERT_STREQ(L"test_vm", id); + + const auto config_no_whitespace = trim_whitespace(configuration); + const auto expected_no_whitespace = trim_whitespace(expected_vm_settings_json); + + ASSERT_STREQ(expected_no_whitespace.c_str(), config_no_whitespace.c_str()); + ASSERT_EQ(mock_operation_object, operation); + ASSERT_EQ(nullptr, securityDescriptor); + ASSERT_NE(nullptr, computeSystem); + ASSERT_EQ(nullptr, *computeSystem); + *computeSystem = mock_compute_system_object; + }, + Return(NOERROR))); + + EXPECT_CALL(mock_close_compute_system, Call).WillOnce([](HCS_SYSTEM computeSystem) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + }); + + EXPECT_CALL(mock_local_free, Call) + .WillOnce(DoAll([](HLOCAL ptr) { ASSERT_EQ(ptr, mock_success_msg); }, Return(nullptr))); + + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCSWrapper::HCSWrapper(...)"); + logger_scope.mock_logger->expect_log( + mpl::Level::debug, + "HCSWrapper::create_compute_system(...) > params: Compute System name: (test_vm) | vCPU count: (8) | " + "Memory size: (16384 MiB) | cloud-init ISO path: () | VHDX path: (virtual disk path)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "create_operation(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "wait_for_operation_result(...)"); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + uut_t uut{mock_api_table}; + hyperv::hcs::CreateComputeSystemParameters params{}; + params.name = "test_vm"; + params.cloudinit_iso_path = ""; + params.vhdx_path = "virtual disk path"; + params.memory_size_mb = 16384; + params.processor_count = 8; + + const auto& [status, status_msg] = uut.create_compute_system(params); + ASSERT_TRUE(status); + ASSERT_FALSE(status_msg.empty()); + ASSERT_STREQ(status_msg.c_str(), mock_success_msg); + } +} + +// --------------------------------------------------------- + +/** + * Success scenario: Everything goes as expected. + */ +TEST_F(HyperVHCSAPI_UnitTests, create_compute_system_wo_vhdx) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_create_operation; + ::testing::MockFunction mock_close_operation; + ::testing::MockFunction mock_wait_for_operation_result; + ::testing::MockFunction mock_create_compute_system; + ::testing::MockFunction mock_close_compute_system; + ::testing::MockFunction mock_local_free; + + mock_api_table.CreateOperation = mock_create_operation.AsStdFunction(); + mock_api_table.CloseOperation = mock_close_operation.AsStdFunction(); + mock_api_table.WaitForOperationResult = mock_wait_for_operation_result.AsStdFunction(); + mock_api_table.CreateComputeSystem = mock_create_compute_system.AsStdFunction(); + mock_api_table.CloseComputeSystem = mock_close_compute_system.AsStdFunction(); + mock_api_table.LocalFree = mock_local_free.AsStdFunction(); + + constexpr static auto expected_vm_settings_json = LR"( + { + "SchemaVersion": { + "Major": 2, + "Minor": 1 + }, + "Owner": "Multipass", + "ShouldTerminateOnLastHandleClosed": false, + "VirtualMachine": { + "Chipset": { + "Uefi": { + "BootThis": { + "DevicePath": "Primary disk", + "DiskNumber": 0, + "DeviceType": "ScsiDrive" + }, + "Console": "ComPort1" + } + }, + "ComputeTopology": { + "Memory": { + "Backing": "Virtual", + "SizeInMB": 16384 + }, + "Processor": { + "Count": 8 + } + }, + "Devices": { + "ComPorts": { + "0": { + "NamedPipe": "\\\\.\\pipe\\test_vm" + } + }, + "Scsi": { + "cloud-init iso file": { + "Attachments": { + "0": { + "Type": "Iso", + "Path": "cloudinit iso path", + "ReadOnly": true + } + } + }, + } + } + } + })"; + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_create_operation, Call) + .WillOnce(DoAll( + [](const void* context, HCS_OPERATION_COMPLETION callback) { + ASSERT_EQ(nullptr, context); + ASSERT_EQ(nullptr, callback); + }, + Return(mock_operation_object))); + + EXPECT_CALL(mock_close_operation, Call).WillOnce([](HCS_OPERATION op) { + ASSERT_EQ(op, mock_operation_object); + }); + + EXPECT_CALL(mock_wait_for_operation_result, Call) + .WillOnce(DoAll( + [](HCS_OPERATION operation, DWORD timeoutMs, PWSTR* resultDocument) { + ASSERT_EQ(operation, mock_operation_object); + ASSERT_EQ(timeoutMs, 240000); + ASSERT_NE(nullptr, resultDocument); + ASSERT_EQ(nullptr, *resultDocument); + *resultDocument = mock_success_msg; + }, + Return(NOERROR))); + + EXPECT_CALL(mock_create_compute_system, Call) + .WillOnce(DoAll( + [](PCWSTR id, + PCWSTR configuration, + HCS_OPERATION operation, + const SECURITY_DESCRIPTOR* securityDescriptor, + HCS_SYSTEM* computeSystem) { + ASSERT_STREQ(L"test_vm", id); + + const auto config_no_whitespace = trim_whitespace(configuration); + const auto expected_no_whitespace = trim_whitespace(expected_vm_settings_json); + + ASSERT_STREQ(expected_no_whitespace.c_str(), config_no_whitespace.c_str()); + ASSERT_EQ(mock_operation_object, operation); + ASSERT_EQ(nullptr, securityDescriptor); + ASSERT_NE(nullptr, computeSystem); + ASSERT_EQ(nullptr, *computeSystem); + *computeSystem = mock_compute_system_object; + }, + Return(NOERROR))); + + EXPECT_CALL(mock_close_compute_system, Call).WillOnce([](HCS_SYSTEM computeSystem) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + }); + + EXPECT_CALL(mock_local_free, Call) + .WillOnce(DoAll([](HLOCAL ptr) { ASSERT_EQ(ptr, mock_success_msg); }, Return(nullptr))); + + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCSWrapper::HCSWrapper(...)"); + logger_scope.mock_logger->expect_log( + mpl::Level::debug, + "HCSWrapper::create_compute_system(...) > params: Compute System name: (test_vm) | vCPU count: (8) | " + "Memory size: (16384 MiB) | cloud-init ISO path: (cloudinit iso path) | VHDX path: ()"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "create_operation(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "wait_for_operation_result(...)"); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + uut_t uut{mock_api_table}; + hyperv::hcs::CreateComputeSystemParameters params{}; + params.name = "test_vm"; + params.cloudinit_iso_path = "cloudinit iso path"; + params.vhdx_path = ""; + params.memory_size_mb = 16384; + params.processor_count = 8; + + const auto& [status, status_msg] = uut.create_compute_system(params); + ASSERT_TRUE(status); + ASSERT_FALSE(status_msg.empty()); + ASSERT_STREQ(status_msg.c_str(), mock_success_msg); + } +} + +// --------------------------------------------------------- + +/** + * Success scenario: Everything goes as expected. + */ +TEST_F(HyperVHCSAPI_UnitTests, create_compute_system_wo_cloudinit_and_vhdx) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_create_operation; + ::testing::MockFunction mock_close_operation; + ::testing::MockFunction mock_wait_for_operation_result; + ::testing::MockFunction mock_create_compute_system; + ::testing::MockFunction mock_close_compute_system; + ::testing::MockFunction mock_local_free; + + mock_api_table.CreateOperation = mock_create_operation.AsStdFunction(); + mock_api_table.CloseOperation = mock_close_operation.AsStdFunction(); + mock_api_table.WaitForOperationResult = mock_wait_for_operation_result.AsStdFunction(); + mock_api_table.CreateComputeSystem = mock_create_compute_system.AsStdFunction(); + mock_api_table.CloseComputeSystem = mock_close_compute_system.AsStdFunction(); + mock_api_table.LocalFree = mock_local_free.AsStdFunction(); + + constexpr static auto expected_vm_settings_json = LR"( + { + "SchemaVersion": { + "Major": 2, + "Minor": 1 + }, + "Owner": "Multipass", + "ShouldTerminateOnLastHandleClosed": false, + "VirtualMachine": { + "Chipset": { + "Uefi": { + "BootThis": { + "DevicePath": "Primary disk", + "DiskNumber": 0, + "DeviceType": "ScsiDrive" + }, + "Console": "ComPort1" + } + }, + "ComputeTopology": { + "Memory": { + "Backing": "Virtual", + "SizeInMB": 16384 + }, + "Processor": { + "Count": 8 + } + }, + "Devices": { + "ComPorts": { + "0": { + "NamedPipe": "\\\\.\\pipe\\test_vm" + } + }, + "Scsi": { + } + } + } + })"; + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_create_operation, Call) + .WillOnce(DoAll( + [](const void* context, HCS_OPERATION_COMPLETION callback) { + ASSERT_EQ(nullptr, context); + ASSERT_EQ(nullptr, callback); + }, + Return(mock_operation_object))); + + EXPECT_CALL(mock_close_operation, Call).WillOnce([](HCS_OPERATION op) { + ASSERT_EQ(op, mock_operation_object); + }); + + EXPECT_CALL(mock_wait_for_operation_result, Call) + .WillOnce(DoAll( + [](HCS_OPERATION operation, DWORD timeoutMs, PWSTR* resultDocument) { + ASSERT_EQ(operation, mock_operation_object); + ASSERT_EQ(timeoutMs, 240000); + ASSERT_NE(nullptr, resultDocument); + ASSERT_EQ(nullptr, *resultDocument); + *resultDocument = mock_success_msg; + }, + Return(NOERROR))); + + EXPECT_CALL(mock_create_compute_system, Call) + .WillOnce(DoAll( + [](PCWSTR id, + PCWSTR configuration, + HCS_OPERATION operation, + const SECURITY_DESCRIPTOR* securityDescriptor, + HCS_SYSTEM* computeSystem) { + ASSERT_STREQ(L"test_vm", id); + + const auto config_no_whitespace = trim_whitespace(configuration); + const auto expected_no_whitespace = trim_whitespace(expected_vm_settings_json); + + ASSERT_STREQ(expected_no_whitespace.c_str(), config_no_whitespace.c_str()); + ASSERT_EQ(mock_operation_object, operation); + ASSERT_EQ(nullptr, securityDescriptor); + ASSERT_NE(nullptr, computeSystem); + ASSERT_EQ(nullptr, *computeSystem); + *computeSystem = mock_compute_system_object; + }, + Return(NOERROR))); + + EXPECT_CALL(mock_close_compute_system, Call).WillOnce([](HCS_SYSTEM computeSystem) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + }); + + EXPECT_CALL(mock_local_free, Call) + .WillOnce(DoAll([](HLOCAL ptr) { ASSERT_EQ(ptr, mock_success_msg); }, Return(nullptr))); + + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCSWrapper::HCSWrapper(...)"); + logger_scope.mock_logger->expect_log( + mpl::Level::debug, + "HCSWrapper::create_compute_system(...) > params: Compute System name: (test_vm) | vCPU count: (8) | " + "Memory size: (16384 MiB) | cloud-init ISO path: () | VHDX path: ()"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "create_operation(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "wait_for_operation_result(...)"); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + uut_t uut{mock_api_table}; + hyperv::hcs::CreateComputeSystemParameters params{}; + params.name = "test_vm"; + params.cloudinit_iso_path = ""; + params.vhdx_path = ""; + params.memory_size_mb = 16384; + params.processor_count = 8; + + const auto& [status, status_msg] = uut.create_compute_system(params); + ASSERT_TRUE(status); + ASSERT_FALSE(status_msg.empty()); + ASSERT_STREQ(status_msg.c_str(), mock_success_msg); + } +} + +// --------------------------------------------------------- + +/** + * Success scenario: Everything goes as expected. + */ +TEST_F(HyperVHCSAPI_UnitTests, create_compute_system_create_operation_fail) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_create_operation; + + mock_api_table.CreateOperation = mock_create_operation.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_create_operation, Call) + .WillOnce(DoAll( + [](const void* context, HCS_OPERATION_COMPLETION callback) { + ASSERT_EQ(nullptr, context); + ASSERT_EQ(nullptr, callback); + }, + Return(nullptr))); + + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCSWrapper::HCSWrapper(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCSWrapper::create_compute_system(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "create_operation(...)"); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + uut_t uut{mock_api_table}; + hyperv::hcs::CreateComputeSystemParameters params{}; + params.name = "test_vm"; + params.cloudinit_iso_path = "cloudinit iso path"; + params.vhdx_path = "virtual disk path"; + params.memory_size_mb = 16384; + params.processor_count = 8; + + const auto& [status, status_msg] = uut.create_compute_system(params); + ASSERT_FALSE(status); + ASSERT_FALSE(status_msg.empty()); + ASSERT_STREQ(status_msg.c_str(), L"HcsCreateOperation failed."); + } +} + +// --------------------------------------------------------- + +/** + * Success scenario: Everything goes as expected. + */ +TEST_F(HyperVHCSAPI_UnitTests, create_compute_system_fail) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_create_operation; + ::testing::MockFunction mock_close_operation; + ::testing::MockFunction mock_create_compute_system; + + mock_api_table.CreateOperation = mock_create_operation.AsStdFunction(); + mock_api_table.CloseOperation = mock_close_operation.AsStdFunction(); + mock_api_table.CreateComputeSystem = mock_create_compute_system.AsStdFunction(); + + constexpr static auto expected_vm_settings_json = LR"( + { + "SchemaVersion": { + "Major": 2, + "Minor": 1 + }, + "Owner": "Multipass", + "ShouldTerminateOnLastHandleClosed": false, + "VirtualMachine": { + "Chipset": { + "Uefi": { + "BootThis": { + "DevicePath": "Primary disk", + "DiskNumber": 0, + "DeviceType": "ScsiDrive" + }, + "Console": "ComPort1" + } + }, + "ComputeTopology": { + "Memory": { + "Backing": "Virtual", + "SizeInMB": 16384 + }, + "Processor": { + "Count": 8 + } + }, + "Devices": { + "ComPorts": { + "0": { + "NamedPipe": "\\\\.\\pipe\\test_vm" + } + }, + "Scsi": { + "cloud-init iso file": { + "Attachments": { + "0": { + "Type": "Iso", + "Path": "cloudinit iso path", + "ReadOnly": true + } + } + }, + "Primary disk": { + "Attachments": { + "0": { + "Type": "VirtualDisk", + "Path": "virtual disk path", + "ReadOnly": false + } + } + }, + } + } + } + })"; + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_create_operation, Call) + .WillOnce(DoAll( + [](const void* context, HCS_OPERATION_COMPLETION callback) { + ASSERT_EQ(nullptr, context); + ASSERT_EQ(nullptr, callback); + }, + Return(mock_operation_object))); + + EXPECT_CALL(mock_close_operation, Call).WillOnce([](HCS_OPERATION op) { + ASSERT_EQ(op, mock_operation_object); + }); + + EXPECT_CALL(mock_create_compute_system, Call) + .WillOnce(DoAll( + [](PCWSTR id, + PCWSTR configuration, + HCS_OPERATION operation, + const SECURITY_DESCRIPTOR* securityDescriptor, + HCS_SYSTEM* computeSystem) { + ASSERT_STREQ(L"test_vm", id); + + const auto config_no_whitespace = trim_whitespace(configuration); + const auto expected_no_whitespace = trim_whitespace(expected_vm_settings_json); + + ASSERT_STREQ(expected_no_whitespace.c_str(), config_no_whitespace.c_str()); + ASSERT_EQ(mock_operation_object, operation); + ASSERT_EQ(nullptr, securityDescriptor); + ASSERT_NE(nullptr, computeSystem); + ASSERT_EQ(nullptr, *computeSystem); + }, + Return(E_POINTER))); + + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCSWrapper::HCSWrapper(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCSWrapper::create_compute_system(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "create_operation(...)"); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + uut_t uut{mock_api_table}; + hyperv::hcs::CreateComputeSystemParameters params{}; + params.name = "test_vm"; + params.cloudinit_iso_path = "cloudinit iso path"; + params.vhdx_path = "virtual disk path"; + params.memory_size_mb = 16384; + params.processor_count = 8; + + const auto& [status, status_msg] = uut.create_compute_system(params); + ASSERT_FALSE(status); + ASSERT_FALSE(status_msg.empty()); + ASSERT_STREQ(status_msg.c_str(), L"HcsCreateComputeSystem failed."); + } +} + +// --------------------------------------------------------- + +/** + * Success scenario: Everything goes as expected. + */ +TEST_F(HyperVHCSAPI_UnitTests, create_compute_system_wait_for_operation_fail) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_create_operation; + ::testing::MockFunction mock_close_operation; + ::testing::MockFunction mock_wait_for_operation_result; + ::testing::MockFunction mock_create_compute_system; + ::testing::MockFunction mock_close_compute_system; + ::testing::MockFunction mock_local_free; + + mock_api_table.CreateOperation = mock_create_operation.AsStdFunction(); + mock_api_table.CloseOperation = mock_close_operation.AsStdFunction(); + mock_api_table.WaitForOperationResult = mock_wait_for_operation_result.AsStdFunction(); + mock_api_table.CreateComputeSystem = mock_create_compute_system.AsStdFunction(); + mock_api_table.CloseComputeSystem = mock_close_compute_system.AsStdFunction(); + mock_api_table.LocalFree = mock_local_free.AsStdFunction(); + + constexpr static auto expected_vm_settings_json = LR"( + { + "SchemaVersion": { + "Major": 2, + "Minor": 1 + }, + "Owner": "Multipass", + "ShouldTerminateOnLastHandleClosed": false, + "VirtualMachine": { + "Chipset": { + "Uefi": { + "BootThis": { + "DevicePath": "Primary disk", + "DiskNumber": 0, + "DeviceType": "ScsiDrive" + }, + "Console": "ComPort1" + } + }, + "ComputeTopology": { + "Memory": { + "Backing": "Virtual", + "SizeInMB": 16384 + }, + "Processor": { + "Count": 8 + } + }, + "Devices": { + "ComPorts": { + "0": { + "NamedPipe": "\\\\.\\pipe\\test_vm" + } + }, + "Scsi": { + "cloud-init iso file": { + "Attachments": { + "0": { + "Type": "Iso", + "Path": "cloudinit iso path", + "ReadOnly": true + } + } + }, + "Primary disk": { + "Attachments": { + "0": { + "Type": "VirtualDisk", + "Path": "virtual disk path", + "ReadOnly": false + } + } + }, + } + } + } + })"; + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_create_operation, Call) + .WillOnce(DoAll( + [](const void* context, HCS_OPERATION_COMPLETION callback) { + ASSERT_EQ(nullptr, context); + ASSERT_EQ(nullptr, callback); + }, + Return(mock_operation_object))); + + EXPECT_CALL(mock_close_operation, Call).WillOnce([](HCS_OPERATION op) { + ASSERT_EQ(op, mock_operation_object); + }); + + EXPECT_CALL(mock_wait_for_operation_result, Call) + .WillOnce(DoAll( + [](HCS_OPERATION operation, DWORD timeoutMs, PWSTR* resultDocument) { + ASSERT_EQ(operation, mock_operation_object); + ASSERT_EQ(timeoutMs, 240000); + ASSERT_NE(nullptr, resultDocument); + ASSERT_EQ(nullptr, *resultDocument); + *resultDocument = mock_error_msg; + }, + Return(E_POINTER))); + + EXPECT_CALL(mock_create_compute_system, Call) + .WillOnce(DoAll( + [](PCWSTR id, + PCWSTR configuration, + HCS_OPERATION operation, + const SECURITY_DESCRIPTOR* securityDescriptor, + HCS_SYSTEM* computeSystem) { + ASSERT_STREQ(L"test_vm", id); + + const auto config_no_whitespace = trim_whitespace(configuration); + const auto expected_no_whitespace = trim_whitespace(expected_vm_settings_json); + + ASSERT_STREQ(expected_no_whitespace.c_str(), config_no_whitespace.c_str()); + ASSERT_EQ(mock_operation_object, operation); + ASSERT_EQ(nullptr, securityDescriptor); + ASSERT_NE(nullptr, computeSystem); + ASSERT_EQ(nullptr, *computeSystem); + *computeSystem = mock_compute_system_object; + }, + Return(NOERROR))); + + EXPECT_CALL(mock_close_compute_system, Call).WillOnce([](HCS_SYSTEM computeSystem) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + }); + + EXPECT_CALL(mock_local_free, Call) + .WillOnce(DoAll([](HLOCAL ptr) { ASSERT_EQ(ptr, mock_error_msg); }, Return(nullptr))); + + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCSWrapper::HCSWrapper(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCSWrapper::create_compute_system(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "create_operation(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "wait_for_operation_result(...)"); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + uut_t uut{mock_api_table}; + hyperv::hcs::CreateComputeSystemParameters params{}; + params.name = "test_vm"; + params.cloudinit_iso_path = "cloudinit iso path"; + params.vhdx_path = "virtual disk path"; + params.memory_size_mb = 16384; + params.processor_count = 8; + + const auto& [status, status_msg] = uut.create_compute_system(params); + ASSERT_FALSE(status); + ASSERT_FALSE(status_msg.empty()); + ASSERT_STREQ(status_msg.c_str(), mock_error_msg); + } +} + +// --------------------------------------------------------- + +/** + * Success scenario: Everything goes as expected. + */ +TEST_F(HyperVHCSAPI_UnitTests, grant_vm_access_success) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_grant_vm_access; + + mock_api_table.GrantVmAccess = mock_grant_vm_access.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_grant_vm_access, Call) + .WillOnce(DoAll( + [](PCWSTR vmId, PCWSTR filePath) { + ASSERT_NE(nullptr, vmId); + ASSERT_NE(nullptr, filePath); + ASSERT_STREQ(vmId, L"test_vm"); + ASSERT_STREQ(filePath, L"this is a path"); + }, + Return(NOERROR))); + + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCSWrapper::HCSWrapper(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, + "grant_vm_access(...) > name: (test_vm), file_path: (this is a path)"); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + uut_t uut{mock_api_table}; + + const auto& [status, status_msg] = uut.grant_vm_access("test_vm", "this is a path"); + ASSERT_TRUE(status); + ASSERT_TRUE(status_msg.empty()); + } +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, grant_vm_access_fail) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_grant_vm_access; + + mock_api_table.GrantVmAccess = mock_grant_vm_access.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_grant_vm_access, Call) + .WillOnce(DoAll( + [](PCWSTR vmId, PCWSTR filePath) { + ASSERT_NE(nullptr, vmId); + ASSERT_NE(nullptr, filePath); + ASSERT_STREQ(vmId, L"test_vm"); + ASSERT_STREQ(filePath, L"this is a path"); + }, + Return(E_POINTER))); + + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCSWrapper::HCSWrapper(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, + "grant_vm_access(...) > name: (test_vm), file_path: (this is a path)"); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + uut_t uut{mock_api_table}; + + const auto& [status, status_msg] = uut.grant_vm_access("test_vm", "this is a path"); + ASSERT_FALSE(status); + ASSERT_FALSE(status_msg.empty()); + ASSERT_STREQ(status_msg.c_str(), L"GrantVmAccess failed!"); + } +} + +// --------------------------------------------------------- + +/** + * Success scenario: Everything goes as expected. + */ +TEST_F(HyperVHCSAPI_UnitTests, revoke_vm_access_success) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_revoke_vm_access; + + mock_api_table.RevokeVmAccess = mock_revoke_vm_access.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_revoke_vm_access, Call) + .WillOnce(DoAll( + [](PCWSTR vmId, PCWSTR filePath) { + ASSERT_NE(nullptr, vmId); + ASSERT_NE(nullptr, filePath); + ASSERT_STREQ(vmId, L"test_vm"); + ASSERT_STREQ(filePath, L"this is a path"); + }, + Return(NOERROR))); + + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCSWrapper::HCSWrapper(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, + "revoke_vm_access(...) > name: (test_vm), file_path: (this is a path)"); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + uut_t uut{mock_api_table}; + + const auto& [status, status_msg] = uut.revoke_vm_access("test_vm", "this is a path"); + ASSERT_TRUE(status); + ASSERT_TRUE(status_msg.empty()); + } +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, revoke_vm_access_fail) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_revoke_vm_access; + + mock_api_table.RevokeVmAccess = mock_revoke_vm_access.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_revoke_vm_access, Call) + .WillOnce(DoAll( + [](PCWSTR vmId, PCWSTR filePath) { + ASSERT_NE(nullptr, vmId); + ASSERT_NE(nullptr, filePath); + ASSERT_STREQ(vmId, L"test_vm"); + ASSERT_STREQ(filePath, L"this is a path"); + }, + Return(E_POINTER))); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCSWrapper::HCSWrapper(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, + "revoke_vm_access(...) > name: (test_vm), file_path: (this is a path)"); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + uut_t uut{mock_api_table}; + + const auto& [status, status_msg] = uut.revoke_vm_access("test_vm", "this is a path"); + ASSERT_FALSE(status); + ASSERT_FALSE(status_msg.empty()); + ASSERT_STREQ(status_msg.c_str(), L"RevokeVmAccess failed!"); + } +} + +// --------------------------------------------------------- + +// +// Below are the skeleton test cases for the functions that are following +// the same pattern. +// + +template +void HyperVHCSAPI_UnitTests::generic_operation_happy_path(ApiFnT& target_api_function, + UutCallableT uut_callback, + MockCallableT mock_callback, + PWSTR operation_result_document, + PWSTR expected_status_msg) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_create_operation; + ::testing::MockFunction mock_close_operation; + ::testing::MockFunction mock_wait_for_operation_result; + ::testing::MockFunction mock_open_compute_system; + ::testing::MockFunction mock_close_compute_system; + ::testing::MockFunction mock_target_function; + ::testing::MockFunction mock_local_free; + + mock_api_table.CreateOperation = mock_create_operation.AsStdFunction(); + mock_api_table.CloseOperation = mock_close_operation.AsStdFunction(); + mock_api_table.WaitForOperationResult = mock_wait_for_operation_result.AsStdFunction(); + mock_api_table.OpenComputeSystem = mock_open_compute_system.AsStdFunction(); + mock_api_table.CloseComputeSystem = mock_close_compute_system.AsStdFunction(); + target_api_function = mock_target_function.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_create_operation, Call) + .WillOnce(DoAll( + [](const void* context, HCS_OPERATION_COMPLETION callback) { + ASSERT_EQ(nullptr, context); + ASSERT_EQ(nullptr, callback); + }, + Return(mock_operation_object))); + + EXPECT_CALL(mock_close_operation, Call).WillOnce([](HCS_OPERATION op) { + ASSERT_EQ(op, mock_operation_object); + }); + + EXPECT_CALL(mock_wait_for_operation_result, Call) + .WillOnce(DoAll( + [operation_result_document](HCS_OPERATION operation, DWORD timeoutMs, PWSTR* resultDocument) { + ASSERT_EQ(operation, mock_operation_object); + ASSERT_EQ(timeoutMs, 240000); + ASSERT_NE(nullptr, resultDocument); + ASSERT_EQ(nullptr, *resultDocument); + *resultDocument = operation_result_document; + }, + Return(NOERROR))); + + EXPECT_CALL(mock_open_compute_system, Call) + .WillOnce(DoAll( + [&](PCWSTR id, DWORD requestedAccess, HCS_SYSTEM* computeSystem) { + ASSERT_STREQ(id, L"test_vm"); + ASSERT_EQ(requestedAccess, GENERIC_ALL); + ASSERT_NE(nullptr, computeSystem); + ASSERT_EQ(nullptr, *computeSystem); + *computeSystem = mock_compute_system_object; + }, + Return(NOERROR))); + + EXPECT_CALL(mock_target_function, Call).WillOnce(DoAll(mock_callback, Return(NOERROR))); + + EXPECT_CALL(mock_close_compute_system, Call).WillOnce([](HCS_SYSTEM computeSystem) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + }); + + if (operation_result_document) + { + mock_api_table.LocalFree = mock_local_free.AsStdFunction(); + EXPECT_CALL(mock_local_free, Call) + .WillOnce(DoAll([operation_result_document](HLOCAL ptr) { ASSERT_EQ(operation_result_document, ptr); }, + Return(nullptr))); + } + + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCSWrapper::HCSWrapper(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "open_host_compute_system(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "create_operation(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "perform_hcs_operation(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "wait_for_operation_result(...)"); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + uut_t uut{mock_api_table}; + + const auto& [status, status_msg] = uut_callback(uut); + ASSERT_TRUE(status); + + if (nullptr == expected_status_msg) + { + ASSERT_TRUE(status_msg.empty()); + } + else + { + ASSERT_STREQ(status_msg.c_str(), expected_status_msg); + } + } +} + +template +void HyperVHCSAPI_UnitTests::generic_operation_hcs_open_fail(ApiFnT& target_api_function, + UutCallableT uut_callback, + PWSTR expected_status_msg) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + + ::testing::MockFunction mock_open_compute_system; + + mock_api_table.OpenComputeSystem = mock_open_compute_system.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + + EXPECT_CALL(mock_open_compute_system, Call) + .WillOnce(DoAll( + [&](PCWSTR id, DWORD requestedAccess, HCS_SYSTEM* computeSystem) { + ASSERT_STREQ(id, L"test_vm"); + ASSERT_EQ(requestedAccess, GENERIC_ALL); + ASSERT_NE(nullptr, computeSystem); + ASSERT_EQ(nullptr, *computeSystem); + }, + Return(E_POINTER))); + + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCSWrapper::HCSWrapper(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "open_host_compute_system(...) > name: (test_vm)"); + logger_scope.mock_logger->expect_log( + mpl::Level::error, + "open_host_compute_system(...) > failed to open (test_vm), result code: (0x80004003)"); + logger_scope.mock_logger->expect_log(mpl::Level::error, + "perform_hcs_operation(...) > HcsOpenComputeSystem failed!"); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + uut_t uut{mock_api_table}; + + const auto& [status, status_msg] = uut_callback(uut); + ASSERT_FALSE(status); + + if (nullptr == expected_status_msg) + { + ASSERT_TRUE(status_msg.empty()); + } + else + { + ASSERT_FALSE(status_msg.empty()); + ASSERT_STREQ(status_msg.c_str(), expected_status_msg); + } + } +} + +template +void HyperVHCSAPI_UnitTests::generic_operation_create_operation_fail(ApiFnT& target_api_function, + UutCallableT uut_callback, + PWSTR expected_status_msg) +{ + + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_create_operation; + ::testing::MockFunction mock_open_compute_system; + ::testing::MockFunction mock_close_compute_system; + + mock_api_table.OpenComputeSystem = mock_open_compute_system.AsStdFunction(); + mock_api_table.CreateOperation = mock_create_operation.AsStdFunction(); + mock_api_table.CloseComputeSystem = mock_close_compute_system.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + + EXPECT_CALL(mock_open_compute_system, Call) + .WillOnce(DoAll( + [&](PCWSTR id, DWORD requestedAccess, HCS_SYSTEM* computeSystem) { + ASSERT_STREQ(id, L"test_vm"); + ASSERT_EQ(requestedAccess, GENERIC_ALL); + ASSERT_NE(nullptr, computeSystem); + ASSERT_EQ(nullptr, *computeSystem); + *computeSystem = mock_compute_system_object; + }, + Return(NOERROR))); + + EXPECT_CALL(mock_create_operation, Call) + .WillOnce(DoAll( + [](const void* context, HCS_OPERATION_COMPLETION callback) { + ASSERT_EQ(nullptr, context); + ASSERT_EQ(nullptr, callback); + }, + Return(nullptr))); + + EXPECT_CALL(mock_close_compute_system, Call).WillOnce([](HCS_SYSTEM computeSystem) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + }); + + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCSWrapper::HCSWrapper(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "open_host_compute_system(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "create_operation(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::error, + "perform_hcs_operation(...) > HcsCreateOperation failed!"); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + uut_t uut{mock_api_table}; + + const auto& [status, status_msg] = uut_callback(uut); + ASSERT_FALSE(status); + + if (nullptr == expected_status_msg) + { + ASSERT_TRUE(status_msg.empty()); + } + else + { + ASSERT_FALSE(status_msg.empty()); + ASSERT_STREQ(status_msg.c_str(), expected_status_msg); + } + } +} + +template +void HyperVHCSAPI_UnitTests::generic_operation_fail(ApiFnT& target_api_function, + UutCallableT uut_callback, + MockCallableT mock_callback, + PWSTR expected_status_msg) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_create_operation; + ::testing::MockFunction mock_close_operation; + ::testing::MockFunction mock_open_compute_system; + ::testing::MockFunction mock_close_compute_system; + ::testing::MockFunction mock_target_function; + + mock_api_table.CreateOperation = mock_create_operation.AsStdFunction(); + mock_api_table.CloseOperation = mock_close_operation.AsStdFunction(); + mock_api_table.OpenComputeSystem = mock_open_compute_system.AsStdFunction(); + mock_api_table.CloseComputeSystem = mock_close_compute_system.AsStdFunction(); + target_api_function = mock_target_function.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_create_operation, Call) + .WillOnce(DoAll( + [](const void* context, HCS_OPERATION_COMPLETION callback) { + ASSERT_EQ(nullptr, context); + ASSERT_EQ(nullptr, callback); + }, + Return(mock_operation_object))); + + EXPECT_CALL(mock_close_operation, Call).WillOnce([](HCS_OPERATION op) { + ASSERT_EQ(op, mock_operation_object); + }); + + EXPECT_CALL(mock_open_compute_system, Call) + .WillOnce(DoAll( + [&](PCWSTR id, DWORD requestedAccess, HCS_SYSTEM* computeSystem) { + ASSERT_STREQ(id, L"test_vm"); + ASSERT_EQ(requestedAccess, GENERIC_ALL); + ASSERT_NE(nullptr, computeSystem); + ASSERT_EQ(nullptr, *computeSystem); + *computeSystem = mock_compute_system_object; + }, + Return(NOERROR))); + + EXPECT_CALL(mock_target_function, Call).WillOnce(DoAll(mock_callback, Return(E_POINTER))); + + EXPECT_CALL(mock_close_compute_system, Call).WillOnce([](HCS_SYSTEM computeSystem) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + }); + + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCSWrapper::HCSWrapper(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "open_host_compute_system(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "create_operation(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::error, "perform_hcs_operation(...) > Operation failed!"); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + uut_t uut{mock_api_table}; + + const auto& [status, status_msg] = uut_callback(uut); + ASSERT_FALSE(status); + if (nullptr == expected_status_msg) + { + ASSERT_TRUE(status_msg.empty()); + } + else + { + ASSERT_FALSE(status_msg.empty()); + ASSERT_STREQ(status_msg.c_str(), expected_status_msg); + } + } +} + +template +void HyperVHCSAPI_UnitTests::generic_operation_wait_for_operation_fail(ApiFnT& target_api_function, + UutCallableT uut_callback, + MockCallableT mock_callback, + PWSTR operation_result_document, + PWSTR expected_status_msg) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_create_operation; + ::testing::MockFunction mock_close_operation; + ::testing::MockFunction mock_wait_for_operation_result; + ::testing::MockFunction mock_open_compute_system; + ::testing::MockFunction mock_close_compute_system; + ::testing::MockFunction mock_target_function; + ::testing::MockFunction mock_local_free; + + mock_api_table.CreateOperation = mock_create_operation.AsStdFunction(); + mock_api_table.CloseOperation = mock_close_operation.AsStdFunction(); + mock_api_table.WaitForOperationResult = mock_wait_for_operation_result.AsStdFunction(); + mock_api_table.OpenComputeSystem = mock_open_compute_system.AsStdFunction(); + mock_api_table.CloseComputeSystem = mock_close_compute_system.AsStdFunction(); + target_api_function = mock_target_function.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_create_operation, Call) + .WillOnce(DoAll( + [](const void* context, HCS_OPERATION_COMPLETION callback) { + ASSERT_EQ(nullptr, context); + ASSERT_EQ(nullptr, callback); + }, + Return(mock_operation_object))); + + EXPECT_CALL(mock_close_operation, Call).WillOnce([](HCS_OPERATION op) { + ASSERT_EQ(op, mock_operation_object); + }); + + EXPECT_CALL(mock_wait_for_operation_result, Call) + .WillOnce(DoAll( + [operation_result_document](HCS_OPERATION operation, DWORD timeoutMs, PWSTR* resultDocument) { + ASSERT_EQ(operation, mock_operation_object); + ASSERT_EQ(timeoutMs, 240000); + ASSERT_NE(nullptr, resultDocument); + ASSERT_EQ(nullptr, *resultDocument); + *resultDocument = operation_result_document; + }, + Return(E_POINTER))); + + EXPECT_CALL(mock_open_compute_system, Call) + .WillOnce(DoAll( + [&](PCWSTR id, DWORD requestedAccess, HCS_SYSTEM* computeSystem) { + ASSERT_STREQ(id, L"test_vm"); + ASSERT_EQ(requestedAccess, GENERIC_ALL); + ASSERT_NE(nullptr, computeSystem); + ASSERT_EQ(nullptr, *computeSystem); + *computeSystem = mock_compute_system_object; + }, + Return(NOERROR))); + + EXPECT_CALL(mock_target_function, Call).WillOnce(DoAll(mock_callback, Return(NOERROR))); + + EXPECT_CALL(mock_close_compute_system, Call).WillOnce([](HCS_SYSTEM computeSystem) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + }); + + if (operation_result_document) + { + mock_api_table.LocalFree = mock_local_free.AsStdFunction(); + EXPECT_CALL(mock_local_free, Call) + .WillOnce(DoAll([operation_result_document](HLOCAL ptr) { ASSERT_EQ(operation_result_document, ptr); }, + Return(nullptr))); + } + + logger_scope.mock_logger->expect_log(mpl::Level::debug, "HCSWrapper::HCSWrapper(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "open_host_compute_system(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "create_operation(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "perform_hcs_operation(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "wait_for_operation_result(...)"); + } + + /****************************************************** + * Verify the expected outcome. + ******************************************************/ + { + uut_t uut{mock_api_table}; + + const auto& [status, status_msg] = uut_callback(uut); + ASSERT_FALSE(status); + + if (nullptr == expected_status_msg) + { + ASSERT_TRUE(status_msg.empty()); + } + else + { + ASSERT_FALSE(status_msg.empty()); + ASSERT_STREQ(status_msg.c_str(), expected_status_msg); + } + } +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, start_compute_system_happy_path) +{ + generic_operation_happy_path( + mock_api_table.StartComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "start_compute_system(...) > name: (test_vm)"); + return wrapper.start_compute_system("test_vm"); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR options) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + ASSERT_EQ(options, nullptr); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, start_compute_system_hcs_open_fail) +{ + generic_operation_hcs_open_fail( + mock_api_table.StartComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "start_compute_system(...)"); + return wrapper.start_compute_system("test_vm"); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, start_compute_system_create_operation_fail) +{ + + generic_operation_create_operation_fail( + mock_api_table.StartComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "start_compute_system(...)"); + return wrapper.start_compute_system("test_vm"); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, start_compute_system_fail) +{ + generic_operation_fail( + mock_api_table.StartComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "start_compute_system(...)"); + return wrapper.start_compute_system("test_vm"); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR options) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + ASSERT_EQ(options, nullptr); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, start_compute_system_wait_for_operation_result_fail) +{ + generic_operation_wait_for_operation_fail( + mock_api_table.StartComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "start_compute_system(...)"); + return wrapper.start_compute_system("test_vm"); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR options) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + ASSERT_EQ(options, nullptr); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, shutdown_compute_system_happy_path) +{ + generic_operation_happy_path( + mock_api_table.ShutDownComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "shutdown_compute_system(...) > name: (test_vm)"); + return wrapper.shutdown_compute_system("test_vm"); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR options) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + ASSERT_EQ(options, nullptr); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, shutdown_compute_system_hcs_open_fail) +{ + generic_operation_hcs_open_fail( + mock_api_table.ShutDownComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "shutdown_compute_system(...)"); + return wrapper.shutdown_compute_system("test_vm"); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, shutdown_compute_system_create_operation_fail) +{ + + generic_operation_create_operation_fail( + mock_api_table.ShutDownComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "shutdown_compute_system(...)"); + return wrapper.shutdown_compute_system("test_vm"); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, shutdown_compute_system_fail) +{ + generic_operation_fail( + mock_api_table.ShutDownComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "shutdown_compute_system(...)"); + return wrapper.shutdown_compute_system("test_vm"); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR options) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + ASSERT_EQ(options, nullptr); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, shutdown_compute_system_wait_for_operation_result_fail) +{ + generic_operation_wait_for_operation_fail( + mock_api_table.ShutDownComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "shutdown_compute_system(...)"); + return wrapper.shutdown_compute_system("test_vm"); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR options) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + ASSERT_EQ(options, nullptr); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, terminate_compute_system_happy_path) +{ + generic_operation_happy_path( + mock_api_table.TerminateComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "terminate_compute_system(...) > name: (test_vm)"); + return wrapper.terminate_compute_system("test_vm"); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR options) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + ASSERT_EQ(options, nullptr); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, terminate_compute_system_hcs_open_fail) +{ + generic_operation_hcs_open_fail( + mock_api_table.TerminateComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "terminate_compute_system(...)"); + return wrapper.terminate_compute_system("test_vm"); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, terminate_compute_system_create_operation_fail) +{ + + generic_operation_create_operation_fail( + mock_api_table.TerminateComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "terminate_compute_system(...)"); + return wrapper.terminate_compute_system("test_vm"); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, terminate_compute_system_fail) +{ + generic_operation_fail( + mock_api_table.TerminateComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "terminate_compute_system(...)"); + return wrapper.terminate_compute_system("test_vm"); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR options) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + ASSERT_EQ(options, nullptr); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, terminate_compute_system_wait_for_operation_result_fail) +{ + generic_operation_wait_for_operation_fail( + mock_api_table.TerminateComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "terminate_compute_system(...)"); + return wrapper.terminate_compute_system("test_vm"); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR options) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + ASSERT_EQ(options, nullptr); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, pause_compute_system_happy_path) +{ + static constexpr wchar_t expected_pause_option[] = LR"( + { + "SuspensionLevel": "Suspend", + "HostedNotification": { + "Reason": "Save" + } + })"; + + generic_operation_happy_path( + mock_api_table.PauseComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "pause_compute_system(...) > name: (test_vm)"); + return wrapper.pause_compute_system("test_vm"); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR options) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + const auto options_no_whitespace = trim_whitespace(options); + const auto expected_options_no_whitespace = trim_whitespace(expected_pause_option); + ASSERT_STREQ(options_no_whitespace.c_str(), expected_options_no_whitespace.c_str()); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, pause_compute_system_hcs_open_fail) +{ + generic_operation_hcs_open_fail( + mock_api_table.PauseComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "pause_compute_system(...)"); + return wrapper.pause_compute_system("test_vm"); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, pause_compute_system_create_operation_fail) +{ + + generic_operation_create_operation_fail( + mock_api_table.PauseComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "pause_compute_system(...)"); + return wrapper.pause_compute_system("test_vm"); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, pause_compute_system_fail) +{ + static constexpr wchar_t expected_pause_option[] = LR"( + { + "SuspensionLevel": "Suspend", + "HostedNotification": { + "Reason": "Save" + } + })"; + generic_operation_fail( + mock_api_table.PauseComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "pause_compute_system(...)"); + return wrapper.pause_compute_system("test_vm"); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR options) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + const auto options_no_whitespace = trim_whitespace(options); + const auto expected_options_no_whitespace = trim_whitespace(expected_pause_option); + ASSERT_STREQ(options_no_whitespace.c_str(), expected_options_no_whitespace.c_str()); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, pause_compute_system_wait_for_operation_result_fail) +{ + static constexpr wchar_t expected_pause_option[] = LR"( + { + "SuspensionLevel": "Suspend", + "HostedNotification": { + "Reason": "Save" + } + })"; + generic_operation_wait_for_operation_fail( + mock_api_table.PauseComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "pause_compute_system(...)"); + return wrapper.pause_compute_system("test_vm"); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR options) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + const auto options_no_whitespace = trim_whitespace(options); + const auto expected_options_no_whitespace = trim_whitespace(expected_pause_option); + ASSERT_STREQ(options_no_whitespace.c_str(), expected_options_no_whitespace.c_str()); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, resume_compute_system_happy_path) +{ + generic_operation_happy_path( + mock_api_table.ResumeComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "resume_compute_system(...) > name: (test_vm)"); + return wrapper.resume_compute_system("test_vm"); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR options) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + ASSERT_EQ(options, nullptr); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, resume_compute_system_hcs_open_fail) +{ + generic_operation_hcs_open_fail( + mock_api_table.ResumeComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "resume_compute_system(...)"); + return wrapper.resume_compute_system("test_vm"); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, resume_compute_system_create_operation_fail) +{ + + generic_operation_create_operation_fail( + mock_api_table.ResumeComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "resume_compute_system(...)"); + return wrapper.resume_compute_system("test_vm"); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, resume_compute_system_fail) +{ + generic_operation_fail( + mock_api_table.ResumeComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "resume_compute_system(...)"); + return wrapper.resume_compute_system("test_vm"); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR options) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + ASSERT_EQ(options, nullptr); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, resume_compute_system_wait_for_operation_result_fail) +{ + generic_operation_wait_for_operation_fail( + mock_api_table.ResumeComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "resume_compute_system(...)"); + return wrapper.resume_compute_system("test_vm"); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR options) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + ASSERT_EQ(options, nullptr); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, add_endpoint_to_compute_system_happy_path) +{ + constexpr static auto expected_modify_compute_system_configuration = LR"( + { + "ResourcePath": "VirtualMachine/Devices/NetworkAdapters/{288cc1ac-8f31-4a09-9e90-30ad0bcfdbca}", + "RequestType": "Add", + "Settings": { + "EndpointId": "288cc1ac-8f31-4a09-9e90-30ad0bcfdbca", + "MacAddress": "00:00:00:00:00:00", + "InstanceId": "288cc1ac-8f31-4a09-9e90-30ad0bcfdbca" + } + })"; + + generic_operation_happy_path( + mock_api_table.ModifyComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log( + mpl::Level::debug, + "add_endpoint(...) > params: Host Compute System Name: (test_vm) | Endpoint GUID: " + "(288cc1ac-8f31-4a09-9e90-30ad0bcfdbca) | NIC MAC Address: (00:00:00:00:00:00)"); + hyperv::hcs::AddEndpointParameters params{}; + params.endpoint_guid = "288cc1ac-8f31-4a09-9e90-30ad0bcfdbca"; + params.nic_mac_address = "00:00:00:00:00:00"; + params.target_compute_system_name = "test_vm"; + return wrapper.add_endpoint(params); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR configuration, HANDLE identity) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + const auto options_no_whitespace = trim_whitespace(configuration); + const auto expected_options_no_whitespace = trim_whitespace(expected_modify_compute_system_configuration); + ASSERT_STREQ(options_no_whitespace.c_str(), expected_options_no_whitespace.c_str()); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, add_endpoint_to_compute_system_hcs_open_fail) +{ + generic_operation_hcs_open_fail( + mock_api_table.ModifyComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "add_endpoint(...)"); + hyperv::hcs::AddEndpointParameters params{}; + params.target_compute_system_name = "test_vm"; + return wrapper.add_endpoint(params); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, add_endpoint_to_compute_system_create_operation_fail) +{ + generic_operation_create_operation_fail( + mock_api_table.ModifyComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "add_endpoint(...)"); + hyperv::hcs::AddEndpointParameters params{}; + params.target_compute_system_name = "test_vm"; + return wrapper.add_endpoint(params); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, add_endpoint_to_compute_system_fail) +{ + constexpr static auto expected_modify_compute_system_configuration = LR"( + { + "ResourcePath": "VirtualMachine/Devices/NetworkAdapters/{288cc1ac-8f31-4a09-9e90-30ad0bcfdbca}", + "RequestType": "Add", + "Settings": { + "EndpointId": "288cc1ac-8f31-4a09-9e90-30ad0bcfdbca", + "MacAddress": "00:00:00:00:00:00", + "InstanceId": "288cc1ac-8f31-4a09-9e90-30ad0bcfdbca" + } + })"; + + generic_operation_fail( + mock_api_table.ModifyComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "add_endpoint(...)"); + hyperv::hcs::AddEndpointParameters params{}; + params.endpoint_guid = "288cc1ac-8f31-4a09-9e90-30ad0bcfdbca"; + params.nic_mac_address = "00:00:00:00:00:00"; + params.target_compute_system_name = "test_vm"; + return wrapper.add_endpoint(params); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR configuration, HANDLE identity) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + const auto options_no_whitespace = trim_whitespace(configuration); + const auto expected_options_no_whitespace = trim_whitespace(expected_modify_compute_system_configuration); + ASSERT_STREQ(options_no_whitespace.c_str(), expected_options_no_whitespace.c_str()); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, add_endpoint_to_compute_system_wait_for_operation_result_fail) +{ + constexpr static auto expected_modify_compute_system_configuration = LR"( + { + "ResourcePath": "VirtualMachine/Devices/NetworkAdapters/{288cc1ac-8f31-4a09-9e90-30ad0bcfdbca}", + "RequestType": "Add", + "Settings": { + "EndpointId": "288cc1ac-8f31-4a09-9e90-30ad0bcfdbca", + "MacAddress": "00:00:00:00:00:00", + "InstanceId": "288cc1ac-8f31-4a09-9e90-30ad0bcfdbca" + } + })"; + + generic_operation_wait_for_operation_fail( + mock_api_table.ModifyComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "add_endpoint(...)"); + hyperv::hcs::AddEndpointParameters params{}; + params.endpoint_guid = "288cc1ac-8f31-4a09-9e90-30ad0bcfdbca"; + params.nic_mac_address = "00:00:00:00:00:00"; + params.target_compute_system_name = "test_vm"; + return wrapper.add_endpoint(params); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR configuration, HANDLE identity) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + const auto options_no_whitespace = trim_whitespace(configuration); + const auto expected_options_no_whitespace = trim_whitespace(expected_modify_compute_system_configuration); + ASSERT_STREQ(options_no_whitespace.c_str(), expected_options_no_whitespace.c_str()); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, remove_endpoint_from_compute_system_happy_path) +{ + constexpr static auto expected_modify_compute_system_configuration = LR"( + { + "ResourcePath": "VirtualMachine/Devices/NetworkAdapters/{288cc1ac-8f31-4a09-9e90-30ad0bcfdbca}", + "RequestType": "Remove" + })"; + + generic_operation_happy_path( + mock_api_table.ModifyComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log( + mpl::Level::debug, + "remove_endpoint(...) > name: (test_vm), endpoint_guid: (288cc1ac-8f31-4a09-9e90-30ad0bcfdbca)"); + return wrapper.remove_endpoint("test_vm", "288cc1ac-8f31-4a09-9e90-30ad0bcfdbca"); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR configuration, HANDLE identity) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + const auto options_no_whitespace = trim_whitespace(configuration); + const auto expected_options_no_whitespace = trim_whitespace(expected_modify_compute_system_configuration); + ASSERT_STREQ(options_no_whitespace.c_str(), expected_options_no_whitespace.c_str()); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, remove_endpoint_from_compute_system_hcs_open_fail) +{ + generic_operation_hcs_open_fail( + mock_api_table.ModifyComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "remove_endpoint(...)"); + return wrapper.remove_endpoint("test_vm", "288cc1ac-8f31-4a09-9e90-30ad0bcfdbca"); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, remove_endpoint_from_compute_system_create_operation_fail) +{ + generic_operation_create_operation_fail( + mock_api_table.ModifyComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "remove_endpoint(...)"); + return wrapper.remove_endpoint("test_vm", "288cc1ac-8f31-4a09-9e90-30ad0bcfdbca"); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, remove_endpoint_from_compute_system_fail) +{ + constexpr static auto expected_modify_compute_system_configuration = LR"( + { + "ResourcePath": "VirtualMachine/Devices/NetworkAdapters/{288cc1ac-8f31-4a09-9e90-30ad0bcfdbca}", + "RequestType": "Remove" + })"; + + generic_operation_fail( + mock_api_table.ModifyComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "remove_endpoint(...)"); + return wrapper.remove_endpoint("test_vm", "288cc1ac-8f31-4a09-9e90-30ad0bcfdbca"); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR configuration, HANDLE identity) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + const auto options_no_whitespace = trim_whitespace(configuration); + const auto expected_options_no_whitespace = trim_whitespace(expected_modify_compute_system_configuration); + ASSERT_STREQ(options_no_whitespace.c_str(), expected_options_no_whitespace.c_str()); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, remove_endpoint_from_compute_system_wait_for_operation_result_fail) +{ + constexpr static auto expected_modify_compute_system_configuration = LR"( + { + "ResourcePath": "VirtualMachine/Devices/NetworkAdapters/{288cc1ac-8f31-4a09-9e90-30ad0bcfdbca}", + "RequestType": "Remove" + })"; + + generic_operation_wait_for_operation_fail( + mock_api_table.ModifyComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "remove_endpoint(...)"); + return wrapper.remove_endpoint("test_vm", "288cc1ac-8f31-4a09-9e90-30ad0bcfdbca"); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR configuration, HANDLE identity) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + const auto options_no_whitespace = trim_whitespace(configuration); + const auto expected_options_no_whitespace = trim_whitespace(expected_modify_compute_system_configuration); + ASSERT_STREQ(options_no_whitespace.c_str(), expected_options_no_whitespace.c_str()); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, resize_memory_of_compute_system_happy_path) +{ + constexpr static auto expected_modify_compute_system_configuration = LR"( + { + "ResourcePath": "VirtualMachine/ComputeTopology/Memory/SizeInMB", + "RequestType": "Update", + "Settings": 16384 + })"; + + generic_operation_happy_path( + mock_api_table.ModifyComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, + "resize_memory(...) > name: (test_vm), new_size_mb: (16384)"); + return wrapper.resize_memory("test_vm", 16384); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR configuration, HANDLE identity) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + const auto options_no_whitespace = trim_whitespace(configuration); + const auto expected_options_no_whitespace = trim_whitespace(expected_modify_compute_system_configuration); + ASSERT_STREQ(options_no_whitespace.c_str(), expected_options_no_whitespace.c_str()); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, resize_memory_of_compute_system_hcs_open_fail) +{ + generic_operation_hcs_open_fail( + mock_api_table.ModifyComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "resize_memory(...)"); + return wrapper.resize_memory("test_vm", 16384); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, resize_memory_of_compute_system_create_operation_fail) +{ + generic_operation_create_operation_fail( + mock_api_table.ModifyComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "resize_memory(...)"); + return wrapper.resize_memory("test_vm", 16384); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, resize_memory_of_compute_system_fail) +{ + constexpr static auto expected_modify_compute_system_configuration = LR"( + { + "ResourcePath": "VirtualMachine/ComputeTopology/Memory/SizeInMB", + "RequestType": "Update", + "Settings": 16384 + })"; + + generic_operation_fail( + mock_api_table.ModifyComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "resize_memory(...)"); + return wrapper.resize_memory("test_vm", 16384); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR configuration, HANDLE identity) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + const auto options_no_whitespace = trim_whitespace(configuration); + const auto expected_options_no_whitespace = trim_whitespace(expected_modify_compute_system_configuration); + ASSERT_STREQ(options_no_whitespace.c_str(), expected_options_no_whitespace.c_str()); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, resize_memory_of_compute_system_wait_for_operation_result_fail) +{ + constexpr static auto expected_modify_compute_system_configuration = LR"( + { + "ResourcePath": "VirtualMachine/ComputeTopology/Memory/SizeInMB", + "RequestType": "Update", + "Settings": 16384 + })"; + + generic_operation_wait_for_operation_fail( + mock_api_table.ModifyComputeSystem, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "resize_memory(...)"); + return wrapper.resize_memory("test_vm", 16384); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR configuration, HANDLE identity) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + const auto options_no_whitespace = trim_whitespace(configuration); + const auto expected_options_no_whitespace = trim_whitespace(expected_modify_compute_system_configuration); + ASSERT_STREQ(options_no_whitespace.c_str(), expected_options_no_whitespace.c_str()); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, get_compute_system_properties_happy_path) +{ + constexpr static auto expected_vm_query = LR"( + { + "PropertyTypes":[] + })"; + + generic_operation_happy_path( + mock_api_table.GetComputeSystemProperties, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, + "get_compute_system_properties(...) > name: (test_vm)"); + return wrapper.get_compute_system_properties("test_vm"); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR propertyQuery) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + const auto options_no_whitespace = trim_whitespace(propertyQuery); + const auto expected_options_no_whitespace = trim_whitespace(expected_vm_query); + ASSERT_STREQ(options_no_whitespace.c_str(), expected_options_no_whitespace.c_str()); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, get_compute_system_properties_hcs_open_fail) +{ + generic_operation_hcs_open_fail( + mock_api_table.GetComputeSystemProperties, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "get_compute_system_properties(...)"); + return wrapper.get_compute_system_properties("test_vm"); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, get_compute_system_properties_create_operation_fail) +{ + generic_operation_create_operation_fail( + mock_api_table.GetComputeSystemProperties, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "get_compute_system_properties(...)"); + return wrapper.get_compute_system_properties("test_vm"); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, get_compute_system_properties_fail) +{ + constexpr static auto expected_vm_query = LR"( + { + "PropertyTypes":[] + })"; + + generic_operation_fail( + mock_api_table.GetComputeSystemProperties, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "get_compute_system_properties(...)"); + return wrapper.get_compute_system_properties("test_vm"); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR propertyQuery) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + const auto options_no_whitespace = trim_whitespace(propertyQuery); + const auto expected_options_no_whitespace = trim_whitespace(expected_vm_query); + ASSERT_STREQ(options_no_whitespace.c_str(), expected_options_no_whitespace.c_str()); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, get_compute_system_properties_wait_for_operation_result_fail) +{ + constexpr static auto expected_vm_query = LR"( + { + "PropertyTypes":[] + })"; + + generic_operation_wait_for_operation_fail( + mock_api_table.GetComputeSystemProperties, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "get_compute_system_properties(...)"); + return wrapper.get_compute_system_properties("test_vm"); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR propertyQuery) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + const auto options_no_whitespace = trim_whitespace(propertyQuery); + const auto expected_options_no_whitespace = trim_whitespace(expected_vm_query); + ASSERT_STREQ(options_no_whitespace.c_str(), expected_options_no_whitespace.c_str()); + }); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, get_compute_system_state_happy_path) +{ + static wchar_t result_doc[21] = L"{\"State\": \"Running\"}"; + static wchar_t expected_state[8] = L"Running"; + + generic_operation_happy_path( + mock_api_table.GetComputeSystemProperties, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "get_compute_system_state(...) > name: (test_vm)"); + return wrapper.get_compute_system_state("test_vm"); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR propertyQuery) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + ASSERT_EQ(propertyQuery, nullptr); + }, + result_doc, + expected_state); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, get_compute_system_state_no_state) +{ + static wchar_t result_doc[21] = L"{\"Frodo\": \"Baggins\"}"; + static wchar_t expected_state[8] = L"Unknown"; + + generic_operation_happy_path( + mock_api_table.GetComputeSystemProperties, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "get_compute_system_state(...)"); + return wrapper.get_compute_system_state("test_vm"); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR propertyQuery) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + ASSERT_EQ(propertyQuery, nullptr); + }, + result_doc, + expected_state); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, get_compute_system_state_hcs_open_fail) +{ + static wchar_t expected_status_msg[] = L"Unknown"; + generic_operation_hcs_open_fail( + mock_api_table.GetComputeSystemProperties, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "get_compute_system_state(...)"); + return wrapper.get_compute_system_state("test_vm"); + }, + expected_status_msg); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, get_compute_system_state_create_operation_fail) +{ + static wchar_t expected_status_msg[] = L"Unknown"; + generic_operation_create_operation_fail( + mock_api_table.GetComputeSystemProperties, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "get_compute_system_state(...)"); + return wrapper.get_compute_system_state("test_vm"); + }, + expected_status_msg); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, get_compute_system_state_fail) +{ + static wchar_t expected_status_msg[] = L"Unknown"; + + generic_operation_fail( + mock_api_table.GetComputeSystemProperties, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "get_compute_system_state(...)"); + return wrapper.get_compute_system_state("test_vm"); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR propertyQuery) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + ASSERT_EQ(nullptr, propertyQuery); + }, + expected_status_msg); +} + +// --------------------------------------------------------- + +TEST_F(HyperVHCSAPI_UnitTests, get_compute_system_state_wait_for_operation_result_fail) +{ + static wchar_t expected_status_msg[] = L"Unknown"; + + generic_operation_wait_for_operation_fail( + mock_api_table.GetComputeSystemProperties, + [&](hyperv::hcs::HCSWrapper& wrapper) { + logger_scope.mock_logger->expect_log(mpl::Level::debug, "get_compute_system_state(...)"); + return wrapper.get_compute_system_state("test_vm"); + }, + [](HCS_SYSTEM computeSystem, HCS_OPERATION operation, PCWSTR propertyQuery) { + ASSERT_EQ(mock_compute_system_object, computeSystem); + ASSERT_EQ(mock_operation_object, operation); + ASSERT_EQ(nullptr, propertyQuery); + }, + nullptr, + expected_status_msg); +} + +} // namespace multipass::test diff --git a/tests/hyperv_api/test_ut_hyperv_virtdisk.cpp b/tests/hyperv_api/test_ut_hyperv_virtdisk.cpp new file mode 100644 index 0000000000..0365eb6637 --- /dev/null +++ b/tests/hyperv_api/test_ut_hyperv_virtdisk.cpp @@ -0,0 +1,697 @@ +/* + * Copyright (C) Canonical, Ltd. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; version 3. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#include "hyperv_test_utils.h" +#include "tests/common.h" +#include "tests/mock_logger.h" + +#include + +#include +#include + +namespace mpt = multipass::test; +namespace mpl = multipass::logging; + +using namespace testing; + +namespace multipass::test +{ + +using uut_t = hyperv::virtdisk::VirtDiskWrapper; + +struct HyperVVirtDisk_UnitTests : public ::testing::Test +{ + mpt::MockLogger::Scope logger_scope = mpt::MockLogger::inject(); + + virtual void SetUp() override + { + + // Each of the unit tests are expected to have their own mock functions + // and override the mock_api_table with them. Hence, the stub mocks should + // not be called at all. + // If any of them do get called, then: + // + // a-) You have forgotten to mock something + // b-) The implementation is using a function that you didn't expect + // + // Either way, you should have a look. + EXPECT_NO_CALL(stub_mock_create_virtual_disk); + EXPECT_NO_CALL(stub_mock_open_virtual_disk); + EXPECT_NO_CALL(stub_mock_resize_virtual_disk); + EXPECT_NO_CALL(stub_mock_get_virtual_disk_information); + EXPECT_NO_CALL(stub_mock_close_handle); + } + + // Set of placeholder mocks in order to catch *unexpected* calls. + ::testing::MockFunction stub_mock_create_virtual_disk; + ::testing::MockFunction stub_mock_open_virtual_disk; + ::testing::MockFunction stub_mock_resize_virtual_disk; + ::testing::MockFunction stub_mock_get_virtual_disk_information; + ::testing::MockFunction stub_mock_close_handle; + + // Initialize the API table with stub functions, so if any of these fire without + // our will, we'll know. + hyperv::virtdisk::VirtDiskAPITable mock_api_table{ + stub_mock_create_virtual_disk.AsStdFunction(), + stub_mock_open_virtual_disk.AsStdFunction(), + stub_mock_resize_virtual_disk.AsStdFunction(), + stub_mock_get_virtual_disk_information.AsStdFunction(), + stub_mock_close_handle.AsStdFunction(), + }; + + // Sentinel values as mock API parameters. These handles are opaque handles and + // they're not being dereferenced in any way -- only address values are compared. + inline static auto mock_handle_object = reinterpret_cast(0xbadf00d); +}; + +// --------------------------------------------------------- + +TEST_F(HyperVVirtDisk_UnitTests, create_virtual_disk_vhdx_happy_path) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_create_virtual_disk; + ::testing::MockFunction mock_close_handle; + + mock_api_table.CreateVirtualDisk = mock_create_virtual_disk.AsStdFunction(); + mock_api_table.CloseHandle = mock_close_handle.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_create_virtual_disk, Call) + .WillOnce(DoAll( + [](PVIRTUAL_STORAGE_TYPE VirtualStorageType, + PCWSTR Path, + VIRTUAL_DISK_ACCESS_MASK VirtualDiskAccessMask, + PSECURITY_DESCRIPTOR SecurityDescriptor, + CREATE_VIRTUAL_DISK_FLAG Flags, + ULONG ProviderSpecificFlags, + PCREATE_VIRTUAL_DISK_PARAMETERS Parameters, + LPOVERLAPPED Overlapped, + PHANDLE Handle + + ) { + ASSERT_NE(nullptr, VirtualStorageType); + ASSERT_EQ(VirtualStorageType->DeviceId, VIRTUAL_STORAGE_TYPE_DEVICE_UNKNOWN); + ASSERT_EQ(VirtualStorageType->VendorId, VIRTUAL_STORAGE_TYPE_VENDOR_UNKNOWN); + ASSERT_NE(nullptr, Path); + ASSERT_STREQ(Path, L"test.vhdx"); + + ASSERT_EQ(VirtualDiskAccessMask, VIRTUAL_DISK_ACCESS_NONE); + ASSERT_EQ(nullptr, SecurityDescriptor); + ASSERT_EQ(CREATE_VIRTUAL_DISK_FLAG_NONE, Flags); + ASSERT_EQ(0, ProviderSpecificFlags); + ASSERT_NE(nullptr, Parameters); + ASSERT_EQ(Parameters->Version, CREATE_VIRTUAL_DISK_VERSION_2); + ASSERT_EQ(Parameters->Version2.MaximumSize, 2097152); + ASSERT_EQ(Parameters->Version2.BlockSizeInBytes, 1048576); + + ASSERT_EQ(nullptr, Overlapped); + ASSERT_NE(nullptr, Handle); + ASSERT_EQ(nullptr, *Handle); + + *Handle = mock_handle_object; + }, + Return(ERROR_SUCCESS))); + + EXPECT_CALL(mock_close_handle, Call(Eq(mock_handle_object))).WillOnce(Return(true)); + + logger_scope.mock_logger->expect_log(mpl::Level::debug, "VirtDiskWrapper::VirtDiskWrapper(...)"); + logger_scope.mock_logger->expect_log( + mpl::Level::debug, + "create_virtual_disk(...) > params: Size (in bytes): (2097152) | Path: (test.vhdx)"); + } + + hyperv::virtdisk::CreateVirtualDiskParameters params{}; + params.path = "test.vhdx"; + params.size_in_bytes = 2097152; + + { + uut_t uut{mock_api_table}; + const auto& [status, status_msg] = uut.create_virtual_disk(params); + EXPECT_TRUE(status); + EXPECT_TRUE(status_msg.empty()); + } +} + +// --------------------------------------------------------- + +TEST_F(HyperVVirtDisk_UnitTests, create_virtual_disk_vhd_happy_path) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_create_virtual_disk; + ::testing::MockFunction mock_close_handle; + + mock_api_table.CreateVirtualDisk = mock_create_virtual_disk.AsStdFunction(); + mock_api_table.CloseHandle = mock_close_handle.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_create_virtual_disk, Call) + .WillOnce(DoAll( + [](PVIRTUAL_STORAGE_TYPE VirtualStorageType, + PCWSTR Path, + VIRTUAL_DISK_ACCESS_MASK VirtualDiskAccessMask, + PSECURITY_DESCRIPTOR SecurityDescriptor, + CREATE_VIRTUAL_DISK_FLAG Flags, + ULONG ProviderSpecificFlags, + PCREATE_VIRTUAL_DISK_PARAMETERS Parameters, + LPOVERLAPPED Overlapped, + PHANDLE Handle + + ) { + ASSERT_NE(nullptr, VirtualStorageType); + ASSERT_EQ(VirtualStorageType->DeviceId, VIRTUAL_STORAGE_TYPE_DEVICE_UNKNOWN); + ASSERT_EQ(VirtualStorageType->VendorId, VIRTUAL_STORAGE_TYPE_VENDOR_UNKNOWN); + ASSERT_NE(nullptr, Path); + ASSERT_STREQ(Path, L"test.vhd"); + ASSERT_EQ(VirtualDiskAccessMask, VIRTUAL_DISK_ACCESS_NONE); + ASSERT_EQ(nullptr, SecurityDescriptor); + ASSERT_EQ(CREATE_VIRTUAL_DISK_FLAG_NONE, Flags); + ASSERT_EQ(0, ProviderSpecificFlags); + ASSERT_NE(nullptr, Parameters); + ASSERT_EQ(Parameters->Version, CREATE_VIRTUAL_DISK_VERSION_2); + ASSERT_EQ(Parameters->Version2.MaximumSize, 2097152); + ASSERT_EQ(Parameters->Version2.BlockSizeInBytes, 524288); + ASSERT_EQ(nullptr, Overlapped); + ASSERT_NE(nullptr, Handle); + ASSERT_EQ(nullptr, *Handle); + + *Handle = mock_handle_object; + }, + Return(ERROR_SUCCESS))); + + EXPECT_CALL(mock_close_handle, Call(Eq(mock_handle_object))).WillOnce(Return(true)); + + logger_scope.mock_logger->expect_log(mpl::Level::debug, "VirtDiskWrapper::VirtDiskWrapper(...)"); + logger_scope.mock_logger->expect_log( + mpl::Level::debug, + "create_virtual_disk(...) > params: Size (in bytes): (2097152) | Path: (test.vhd)"); + } + + hyperv::virtdisk::CreateVirtualDiskParameters params{}; + params.path = "test.vhd"; + params.size_in_bytes = 2097152; + + { + uut_t uut{mock_api_table}; + const auto& [status, status_msg] = uut.create_virtual_disk(params); + EXPECT_TRUE(status); + EXPECT_TRUE(status_msg.empty()); + } +} + +// --------------------------------------------------------- + +TEST_F(HyperVVirtDisk_UnitTests, create_virtual_disk_failed) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_create_virtual_disk; + + mock_api_table.CreateVirtualDisk = mock_create_virtual_disk.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_create_virtual_disk, Call) + .WillOnce(DoAll([](PVIRTUAL_STORAGE_TYPE VirtualStorageType, + PCWSTR Path, + VIRTUAL_DISK_ACCESS_MASK VirtualDiskAccessMask, + PSECURITY_DESCRIPTOR SecurityDescriptor, + CREATE_VIRTUAL_DISK_FLAG Flags, + ULONG ProviderSpecificFlags, + PCREATE_VIRTUAL_DISK_PARAMETERS Parameters, + LPOVERLAPPED Overlapped, + PHANDLE Handle + + ) {}, + Return(ERROR_PATH_NOT_FOUND))); + + logger_scope.mock_logger->expect_log(mpl::Level::debug, "VirtDiskWrapper::VirtDiskWrapper(...)"); + logger_scope.mock_logger->expect_log( + mpl::Level::debug, + "create_virtual_disk(...) > params: Size (in bytes): (2097152) | Path: (test.vhd)"); + logger_scope.mock_logger->expect_log(mpl::Level::error, + "create_virtual_disk(...) > CreateVirtualDisk failed with 3!"); + } + + hyperv::virtdisk::CreateVirtualDiskParameters params{}; + params.path = "test.vhd"; + params.size_in_bytes = 2097152; + + { + uut_t uut{mock_api_table}; + const auto& [status, status_msg] = uut.create_virtual_disk(params); + EXPECT_FALSE(status); + ASSERT_FALSE(status_msg.empty()); + ASSERT_STREQ(status_msg.c_str(), L"CreateVirtualDisk failed with 3!"); + } +} + +// --------------------------------------------------------- + +TEST_F(HyperVVirtDisk_UnitTests, resize_virtual_disk_happy_path) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_open_virtual_disk; + ::testing::MockFunction mock_resize_virtual_disk; + ::testing::MockFunction mock_close_handle; + + mock_api_table.OpenVirtualDisk = mock_open_virtual_disk.AsStdFunction(); + mock_api_table.ResizeVirtualDisk = mock_resize_virtual_disk.AsStdFunction(); + mock_api_table.CloseHandle = mock_close_handle.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_open_virtual_disk, Call) + .WillOnce(DoAll( + [](PVIRTUAL_STORAGE_TYPE VirtualStorageType, + PCWSTR Path, + VIRTUAL_DISK_ACCESS_MASK VirtualDiskAccessMask, + OPEN_VIRTUAL_DISK_FLAG Flags, + POPEN_VIRTUAL_DISK_PARAMETERS Parameters, + PHANDLE Handle) { + ASSERT_NE(nullptr, VirtualStorageType); + ASSERT_EQ(VirtualStorageType->DeviceId, VIRTUAL_STORAGE_TYPE_DEVICE_UNKNOWN); + ASSERT_EQ(VirtualStorageType->VendorId, VIRTUAL_STORAGE_TYPE_VENDOR_UNKNOWN); + ASSERT_NE(nullptr, Path); + ASSERT_STREQ(Path, L"test.vhdx"); + ASSERT_EQ(VIRTUAL_DISK_ACCESS_ALL, VirtualDiskAccessMask); + ASSERT_EQ(OPEN_VIRTUAL_DISK_FLAG_NONE, Flags); + ASSERT_EQ(nullptr, Parameters); + ASSERT_NE(nullptr, Handle); + ASSERT_EQ(nullptr, *Handle); + + *Handle = mock_handle_object; + }, + Return(ERROR_SUCCESS))); + + EXPECT_CALL(mock_resize_virtual_disk, Call) + .WillOnce(DoAll( + [](HANDLE VirtualDiskHandle, + RESIZE_VIRTUAL_DISK_FLAG Flags, + PRESIZE_VIRTUAL_DISK_PARAMETERS Parameters, + LPOVERLAPPED Overlapped) { + ASSERT_EQ(mock_handle_object, VirtualDiskHandle); + ASSERT_EQ(RESIZE_VIRTUAL_DISK_FLAG_NONE, Flags); + ASSERT_NE(nullptr, Parameters); + ASSERT_EQ(Parameters->Version, RESIZE_VIRTUAL_DISK_VERSION_1); + ASSERT_EQ(Parameters->Version1.NewSize, 1234567); + ASSERT_EQ(nullptr, Overlapped); + }, + Return(ERROR_SUCCESS))); + + EXPECT_CALL(mock_close_handle, Call(Eq(mock_handle_object))).WillOnce(Return(true)); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "VirtDiskWrapper::VirtDiskWrapper(...)"); + logger_scope.mock_logger->expect_log( + mpl::Level::debug, + "resize_virtual_disk(...) > vhdx_path: test.vhdx, new_size_bytes: 1234567"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "open_virtual_disk(...) > vhdx_path: test.vhdx"); + } + + { + uut_t uut{mock_api_table}; + const auto& [status, status_msg] = uut.resize_virtual_disk("test.vhdx", 1234567); + EXPECT_TRUE(status); + EXPECT_TRUE(status_msg.empty()); + } +} + +// --------------------------------------------------------- + +TEST_F(HyperVVirtDisk_UnitTests, resize_virtual_disk_open_failed) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_open_virtual_disk; + + mock_api_table.OpenVirtualDisk = mock_open_virtual_disk.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_open_virtual_disk, Call) + .WillOnce(DoAll( + [](PVIRTUAL_STORAGE_TYPE VirtualStorageType, + PCWSTR Path, + VIRTUAL_DISK_ACCESS_MASK VirtualDiskAccessMask, + OPEN_VIRTUAL_DISK_FLAG Flags, + POPEN_VIRTUAL_DISK_PARAMETERS Parameters, + PHANDLE Handle) { + + }, + Return(ERROR_PATH_NOT_FOUND))); + + logger_scope.mock_logger->expect_log(mpl::Level::debug, "VirtDiskWrapper::VirtDiskWrapper(...)"); + logger_scope.mock_logger->expect_log( + mpl::Level::debug, + "resize_virtual_disk(...) > vhdx_path: test.vhdx, new_size_bytes: 1234567"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "open_virtual_disk(...) > vhdx_path: test.vhdx"); + logger_scope.mock_logger->expect_log(mpl::Level::error, + "open_virtual_disk(...) > OpenVirtualDisk failed with: 3"); + } + + { + uut_t uut{mock_api_table}; + const auto& [status, status_msg] = uut.resize_virtual_disk("test.vhdx", 1234567); + EXPECT_FALSE(status); + ASSERT_FALSE(status_msg.empty()); + ASSERT_STREQ(status_msg.c_str(), L"open_virtual_disk failed!"); + } +} + +// --------------------------------------------------------- + +TEST_F(HyperVVirtDisk_UnitTests, resize_virtual_disk_resize_failed) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_open_virtual_disk; + ::testing::MockFunction mock_resize_virtual_disk; + ::testing::MockFunction mock_close_handle; + + mock_api_table.OpenVirtualDisk = mock_open_virtual_disk.AsStdFunction(); + mock_api_table.ResizeVirtualDisk = mock_resize_virtual_disk.AsStdFunction(); + mock_api_table.CloseHandle = mock_close_handle.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_open_virtual_disk, Call) + .WillOnce(DoAll([](PVIRTUAL_STORAGE_TYPE VirtualStorageType, + PCWSTR Path, + VIRTUAL_DISK_ACCESS_MASK VirtualDiskAccessMask, + OPEN_VIRTUAL_DISK_FLAG Flags, + POPEN_VIRTUAL_DISK_PARAMETERS Parameters, + PHANDLE Handle) { *Handle = mock_handle_object; }, + Return(ERROR_SUCCESS))); + + EXPECT_CALL(mock_resize_virtual_disk, Call) + .WillOnce(DoAll([](HANDLE VirtualDiskHandle, + RESIZE_VIRTUAL_DISK_FLAG Flags, + PRESIZE_VIRTUAL_DISK_PARAMETERS Parameters, + LPOVERLAPPED Overlapped) {}, + Return(ERROR_INVALID_PARAMETER))); + + EXPECT_CALL(mock_close_handle, Call(Eq(mock_handle_object))).WillOnce(Return(true)); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "VirtDiskWrapper::VirtDiskWrapper(...)"); + logger_scope.mock_logger->expect_log( + mpl::Level::debug, + "resize_virtual_disk(...) > vhdx_path: test.vhdx, new_size_bytes: 1234567"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "open_virtual_disk(...) > vhdx_path: test.vhdx"); + logger_scope.mock_logger->expect_log(mpl::Level::error, + "resize_virtual_disk(...) > ResizeVirtualDisk failed with 87!"); + } + + { + uut_t uut{mock_api_table}; + const auto& [status, status_msg] = uut.resize_virtual_disk("test.vhdx", 1234567); + EXPECT_FALSE(status); + ASSERT_FALSE(status_msg.empty()); + ASSERT_STREQ(status_msg.c_str(), L"ResizeVirtualDisk failed with 87!"); + } +} + +// --------------------------------------------------------- + +TEST_F(HyperVVirtDisk_UnitTests, get_virtual_disk_info_happy_path) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_open_virtual_disk; + ::testing::MockFunction mock_get_virtual_disk_information; + ::testing::MockFunction mock_close_handle; + + mock_api_table.OpenVirtualDisk = mock_open_virtual_disk.AsStdFunction(); + mock_api_table.GetVirtualDiskInformation = mock_get_virtual_disk_information.AsStdFunction(); + mock_api_table.CloseHandle = mock_close_handle.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_open_virtual_disk, Call) + .WillOnce(DoAll( + [](PVIRTUAL_STORAGE_TYPE VirtualStorageType, + PCWSTR Path, + VIRTUAL_DISK_ACCESS_MASK VirtualDiskAccessMask, + OPEN_VIRTUAL_DISK_FLAG Flags, + POPEN_VIRTUAL_DISK_PARAMETERS Parameters, + PHANDLE Handle) { + ASSERT_NE(nullptr, VirtualStorageType); + ASSERT_EQ(VirtualStorageType->DeviceId, VIRTUAL_STORAGE_TYPE_DEVICE_UNKNOWN); + ASSERT_EQ(VirtualStorageType->VendorId, VIRTUAL_STORAGE_TYPE_VENDOR_UNKNOWN); + ASSERT_NE(nullptr, Path); + ASSERT_STREQ(Path, L"test.vhdx"); + ASSERT_EQ(VIRTUAL_DISK_ACCESS_ALL, VirtualDiskAccessMask); + ASSERT_EQ(OPEN_VIRTUAL_DISK_FLAG_NONE, Flags); + ASSERT_EQ(nullptr, Parameters); + ASSERT_NE(nullptr, Handle); + ASSERT_EQ(nullptr, *Handle); + + *Handle = mock_handle_object; + }, + Return(ERROR_SUCCESS))); + + // The API will be called for several times. + EXPECT_CALL(mock_get_virtual_disk_information, Call) + .WillOnce(DoAll( + [](HANDLE VirtualDiskHandle, + PULONG VirtualDiskInfoSize, + PGET_VIRTUAL_DISK_INFO VirtualDiskInfo, + PULONG SizeUsed) { + ASSERT_EQ(mock_handle_object, VirtualDiskHandle); + ASSERT_NE(nullptr, VirtualDiskInfoSize); + ASSERT_EQ(sizeof(GET_VIRTUAL_DISK_INFO), *VirtualDiskInfoSize); + ASSERT_NE(nullptr, VirtualDiskInfo); + ASSERT_EQ(nullptr, SizeUsed); + ASSERT_EQ(GET_VIRTUAL_DISK_INFO_SIZE, VirtualDiskInfo->Version); + VirtualDiskInfo->Size.VirtualSize = 1111111; + VirtualDiskInfo->Size.BlockSize = 2222222; + VirtualDiskInfo->Size.PhysicalSize = 3333333; + VirtualDiskInfo->Size.SectorSize = 4444444; + }, + Return(ERROR_SUCCESS))) + .WillOnce(DoAll( + [](HANDLE VirtualDiskHandle, + PULONG VirtualDiskInfoSize, + PGET_VIRTUAL_DISK_INFO VirtualDiskInfo, + PULONG SizeUsed) { + ASSERT_EQ(mock_handle_object, VirtualDiskHandle); + ASSERT_NE(nullptr, VirtualDiskInfoSize); + ASSERT_EQ(sizeof(GET_VIRTUAL_DISK_INFO), *VirtualDiskInfoSize); + ASSERT_NE(nullptr, VirtualDiskInfo); + ASSERT_EQ(nullptr, SizeUsed); + ASSERT_EQ(GET_VIRTUAL_DISK_INFO_VIRTUAL_STORAGE_TYPE, VirtualDiskInfo->Version); + VirtualDiskInfo->VirtualStorageType.DeviceId = VIRTUAL_STORAGE_TYPE_DEVICE_VHDX; + VirtualDiskInfo->VirtualStorageType.VendorId = VIRTUAL_STORAGE_TYPE_VENDOR_UNKNOWN; + }, + Return(ERROR_SUCCESS))) + .WillOnce(DoAll( + [](HANDLE VirtualDiskHandle, + PULONG VirtualDiskInfoSize, + PGET_VIRTUAL_DISK_INFO VirtualDiskInfo, + PULONG SizeUsed) { + ASSERT_EQ(mock_handle_object, VirtualDiskHandle); + ASSERT_NE(nullptr, VirtualDiskInfoSize); + ASSERT_EQ(sizeof(GET_VIRTUAL_DISK_INFO), *VirtualDiskInfoSize); + ASSERT_NE(nullptr, VirtualDiskInfo); + ASSERT_EQ(nullptr, SizeUsed); + ASSERT_EQ(GET_VIRTUAL_DISK_INFO_SMALLEST_SAFE_VIRTUAL_SIZE, VirtualDiskInfo->Version); + VirtualDiskInfo->SmallestSafeVirtualSize = 123456; + }, + Return(ERROR_SUCCESS))) + .WillOnce(DoAll( + [](HANDLE VirtualDiskHandle, + PULONG VirtualDiskInfoSize, + PGET_VIRTUAL_DISK_INFO VirtualDiskInfo, + PULONG SizeUsed) { + ASSERT_EQ(mock_handle_object, VirtualDiskHandle); + ASSERT_NE(nullptr, VirtualDiskInfoSize); + ASSERT_EQ(sizeof(GET_VIRTUAL_DISK_INFO), *VirtualDiskInfoSize); + ASSERT_NE(nullptr, VirtualDiskInfo); + ASSERT_EQ(nullptr, SizeUsed); + ASSERT_EQ(GET_VIRTUAL_DISK_INFO_PROVIDER_SUBTYPE, VirtualDiskInfo->Version); + VirtualDiskInfo->ProviderSubtype = 3; // dynamic + }, + Return(ERROR_SUCCESS))); + + EXPECT_CALL(mock_close_handle, Call(Eq(mock_handle_object))).WillOnce(Return(true)); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "VirtDiskWrapper::VirtDiskWrapper(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "get_virtual_disk_info(...) > vhdx_path: test.vhdx"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "open_virtual_disk(...) > vhdx_path: test.vhdx"); + } + + { + uut_t uut{mock_api_table}; + hyperv::virtdisk::VirtualDiskInfo info{}; + const auto& [status, status_msg] = uut.get_virtual_disk_info("test.vhdx", info); + ASSERT_TRUE(status); + ASSERT_TRUE(status_msg.empty()); + + ASSERT_TRUE(info.size.has_value()); + ASSERT_TRUE(info.smallest_safe_virtual_size.has_value()); + ASSERT_TRUE(info.provider_subtype.has_value()); + ASSERT_TRUE(info.virtual_storage_type.has_value()); + + ASSERT_EQ(info.size->virtual_, 1111111); + ASSERT_EQ(info.size->block, 2222222); + ASSERT_EQ(info.size->physical, 3333333); + ASSERT_EQ(info.size->sector, 4444444); + + ASSERT_STREQ(info.virtual_storage_type.value().c_str(), "vhdx"); + ASSERT_EQ(info.smallest_safe_virtual_size.value(), 123456); + ASSERT_STREQ(info.provider_subtype.value().c_str(), "dynamic"); + } +} + +// --------------------------------------------------------- + +TEST_F(HyperVVirtDisk_UnitTests, get_virtual_disk_info_fail_some) +{ + /****************************************************** + * Override the default mock functions. + ******************************************************/ + ::testing::MockFunction mock_open_virtual_disk; + ::testing::MockFunction mock_get_virtual_disk_information; + ::testing::MockFunction mock_close_handle; + + mock_api_table.OpenVirtualDisk = mock_open_virtual_disk.AsStdFunction(); + mock_api_table.GetVirtualDiskInformation = mock_get_virtual_disk_information.AsStdFunction(); + mock_api_table.CloseHandle = mock_close_handle.AsStdFunction(); + + /****************************************************** + * Verify that the dependencies are called with right + * data. + ******************************************************/ + { + EXPECT_CALL(mock_open_virtual_disk, Call) + .WillOnce(DoAll( + [](PVIRTUAL_STORAGE_TYPE VirtualStorageType, + PCWSTR Path, + VIRTUAL_DISK_ACCESS_MASK VirtualDiskAccessMask, + OPEN_VIRTUAL_DISK_FLAG Flags, + POPEN_VIRTUAL_DISK_PARAMETERS Parameters, + PHANDLE Handle) { + ASSERT_NE(nullptr, VirtualStorageType); + ASSERT_EQ(VirtualStorageType->DeviceId, VIRTUAL_STORAGE_TYPE_DEVICE_UNKNOWN); + ASSERT_EQ(VirtualStorageType->VendorId, VIRTUAL_STORAGE_TYPE_VENDOR_UNKNOWN); + ASSERT_NE(nullptr, Path); + ASSERT_STREQ(Path, L"test.vhdx"); + ASSERT_EQ(VIRTUAL_DISK_ACCESS_ALL, VirtualDiskAccessMask); + ASSERT_EQ(OPEN_VIRTUAL_DISK_FLAG_NONE, Flags); + ASSERT_EQ(nullptr, Parameters); + ASSERT_NE(nullptr, Handle); + ASSERT_EQ(nullptr, *Handle); + + *Handle = mock_handle_object; + }, + Return(ERROR_SUCCESS))); + + // The API will be called for several times. + EXPECT_CALL(mock_get_virtual_disk_information, Call) + .WillOnce(DoAll( + [](HANDLE VirtualDiskHandle, + PULONG VirtualDiskInfoSize, + PGET_VIRTUAL_DISK_INFO VirtualDiskInfo, + PULONG SizeUsed) { + ASSERT_EQ(mock_handle_object, VirtualDiskHandle); + ASSERT_NE(nullptr, VirtualDiskInfoSize); + ASSERT_EQ(sizeof(GET_VIRTUAL_DISK_INFO), *VirtualDiskInfoSize); + ASSERT_NE(nullptr, VirtualDiskInfo); + ASSERT_EQ(nullptr, SizeUsed); + ASSERT_EQ(GET_VIRTUAL_DISK_INFO_SIZE, VirtualDiskInfo->Version); + VirtualDiskInfo->Size.VirtualSize = 1111111; + VirtualDiskInfo->Size.BlockSize = 2222222; + VirtualDiskInfo->Size.PhysicalSize = 3333333; + VirtualDiskInfo->Size.SectorSize = 4444444; + }, + Return(ERROR_SUCCESS))) + .WillOnce(Return(ERROR_INVALID_PARAMETER)) + .WillOnce(DoAll( + [](HANDLE VirtualDiskHandle, + PULONG VirtualDiskInfoSize, + PGET_VIRTUAL_DISK_INFO VirtualDiskInfo, + PULONG SizeUsed) { + ASSERT_EQ(mock_handle_object, VirtualDiskHandle); + ASSERT_NE(nullptr, VirtualDiskInfoSize); + ASSERT_EQ(sizeof(GET_VIRTUAL_DISK_INFO), *VirtualDiskInfoSize); + ASSERT_NE(nullptr, VirtualDiskInfo); + ASSERT_EQ(nullptr, SizeUsed); + ASSERT_EQ(GET_VIRTUAL_DISK_INFO_SMALLEST_SAFE_VIRTUAL_SIZE, VirtualDiskInfo->Version); + VirtualDiskInfo->SmallestSafeVirtualSize = 123456; + }, + Return(ERROR_SUCCESS))) + .WillOnce(Return(ERROR_INVALID_PARAMETER)); + + EXPECT_CALL(mock_close_handle, Call(Eq(mock_handle_object))).WillOnce(Return(true)); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "VirtDiskWrapper::VirtDiskWrapper(...)"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "get_virtual_disk_info(...) > vhdx_path: test.vhdx"); + logger_scope.mock_logger->expect_log(mpl::Level::debug, "open_virtual_disk(...) > vhdx_path: test.vhdx"); + logger_scope.mock_logger->expect_log(mpl::Level::warning, "get_virtual_disk_info(...) > failed to get 6"); + logger_scope.mock_logger->expect_log(mpl::Level::warning, "get_virtual_disk_info(...) > failed to get 7"); + } + + { + uut_t uut{mock_api_table}; + hyperv::virtdisk::VirtualDiskInfo info{}; + const auto& [status, status_msg] = uut.get_virtual_disk_info("test.vhdx", info); + ASSERT_TRUE(status); + ASSERT_TRUE(status_msg.empty()); + + ASSERT_TRUE(info.size.has_value()); + ASSERT_FALSE(info.virtual_storage_type.has_value()); + ASSERT_TRUE(info.smallest_safe_virtual_size.has_value()); + ASSERT_FALSE(info.provider_subtype.has_value()); + + ASSERT_EQ(info.size->virtual_, 1111111); + ASSERT_EQ(info.size->block, 2222222); + ASSERT_EQ(info.size->physical, 3333333); + ASSERT_EQ(info.size->sector, 4444444); + + ASSERT_EQ(info.smallest_safe_virtual_size.value(), 123456); + } +} + +} // namespace multipass::test