diff --git a/.clang-tidy b/.clang-tidy index b733b7a..4aca86f 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -1,4 +1,4 @@ -Checks: '-*,bugprone-*,performance-*,readability-*,google-global-names-in-headers,cert-dcl59-cpp,-bugprone-easily-swappable-parameters,-readability-identifier-length,-readability-magic-numbers,-readability-function-cognitive-complexity,-readability-function-size' +Checks: '-*,bugprone-*,performance-*,readability-*,google-global-names-in-headers,cert-dcl59-cpp,-bugprone-easily-swappable-parameters,-readability-identifier-length,-readability-magic-numbers,-readability-function-cognitive-complexity,-readability-function-size,-readability-convert-member-functions-to-static' CheckOptions: - key: performance-unnecessary-value-param.AllowedTypes @@ -22,6 +22,9 @@ CheckOptions: # readability-magic-numbers: # Too strict. # +# readability-convert-member-functions-to-static +# Too many false positives. +# # readability-function-cognitive-complexity # Catch2. # diff --git a/.github/workflows/Build.yml b/.github/workflows/Build.yml index a478e9d..383e892 100644 --- a/.github/workflows/Build.yml +++ b/.github/workflows/Build.yml @@ -7,9 +7,9 @@ on: - '!master' jobs: - build: - name: Release - runs-on: ubuntu-latest + release: + name: release + runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v2 @@ -17,6 +17,9 @@ jobs: - name: Build shell: bash run: | - cmake . -B build + cmake . -B build -DSCL_BUILD_DOCUMENTATION=OFF cmake --build build + - name: Test + shell: bash + run: cmake --build build --target test diff --git a/.github/workflows/Checks.yml b/.github/workflows/Checks.yml index 93089d7..a9422d2 100644 --- a/.github/workflows/Checks.yml +++ b/.github/workflows/Checks.yml @@ -7,22 +7,9 @@ on: - '!master' jobs: - documentation: - name: Documentation - runs-on: ubuntu-20.04 - steps: - - uses: actions/checkout@v2 - - - name: Setup - run: sudo apt-get install -y doxygen - - - name: Documentation - shell: bash - run: ./scripts/build_documentation.sh - headers: - name: Header files - runs-on: ubuntu-latest + name: headers + runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v2 @@ -33,17 +20,14 @@ jobs: run: ./scripts/check_header_guards.py style: - name: Code style - runs-on: ubuntu-latest + name: style + runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v2 - name: Setup - run: sudo apt-get install -y clang-format + run: sudo apt-get install -y clang-format-15 - name: Check shell: bash - run: | - find . -type f \( -iname "*.h" -o -iname "*.cc" \) -exec clang-format -n {} \; &> checks.txt - cat checks.txt - test ! -s checks.txt + run: ./scripts/check_formatting.sh diff --git a/.github/workflows/Coverage.yml b/.github/workflows/Coverage.yml new file mode 100644 index 0000000..6121963 --- /dev/null +++ b/.github/workflows/Coverage.yml @@ -0,0 +1,36 @@ +name: Coverage + +on: + push: + branches: + - '*' + - '!master' + +jobs: + coverage: + name: coverage + runs-on: ubuntu-24.04 + env: + COV_THRESHOLD_LINES: 95 + COV_THRESHOLD_FUNCS: 90 + + steps: + - uses: actions/checkout@v2 + + - name: Setup + shell: bash + run: sudo apt-get install -y lcov + + - name: Build + shell: bash + run: | + cmake . -B build -DSCL_BUILD_TEST_WITH_COVERAGE=ON -DSCL_BUILD_DOCUMENTATION=OFF + cmake --build build + + - name: Compute coverage + shell: bash + run: cmake --build build --target coverage | tee cov.txt + + - name: Check coverage + shell: bash + run: ./scripts/check_coverage.sh cov.txt diff --git a/.github/workflows/Documentation.yml b/.github/workflows/Documentation.yml new file mode 100644 index 0000000..15a7627 --- /dev/null +++ b/.github/workflows/Documentation.yml @@ -0,0 +1,35 @@ +name: Documentation + +on: + push: + branches: + - '*' + - '!master' + +jobs: + documentation: + name: documentation + runs-on: ubuntu-22.04 + + steps: + - uses: actions/checkout@v2 + + - name: Cache doxygen + id: cache-doxygen + uses: actions/cache@v4 + with: + path: ~/doxygen + key: ${{ runner.os }}-doxygen + + - name: Install doxygen + if: ${{ steps.cache-doxygen.outputs.cache-hit != 'true' }} + run: | + curl https://www.doxygen.nl/files/doxygen-1.10.0.linux.bin.tar.gz -o doxygen.tar.gz + mkdir ~/doxygen + tar xf doxygen.tar.gz -C ~/doxygen --strip-components=1 + + - name: Generate documentation + shell: bash + run: | + cmake . -B build -DSCL_BUILD_TESTS=OFF -DSCL_DOXYGEN_BIN=$HOME/doxygen/bin/doxygen + cmake --build build --target documentation diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml deleted file mode 100644 index abb6c87..0000000 --- a/.github/workflows/Test.yml +++ /dev/null @@ -1,52 +0,0 @@ -name: Test - -on: - push: - branches: - - '*' - - '!master' - -env: - BUILD_TYPE: Debug - -jobs: - build: - name: Coverage and Linting - runs-on: ubuntu-20.04 - - steps: - - uses: actions/checkout@v2 - - - name: Setup - run: | - sudo apt-get install -y lcov bear - curl -L https://github.com/catchorg/Catch2/archive/v2.13.0.tar.gz -o c.tar.gz - tar xvf c.tar.gz - cd Catch2-2.13.0/ - cmake -B catch -DBUILD_TESTING=OFF - cmake --build catch - sudo cmake --install catch - - - name: CMake - run: cmake -B ${{runner.workspace}}/build -DCMAKE_BUILD_TYPE=$BUILD_TYPE . - - - name: Build - working-directory: ${{runner.workspace}}/build - shell: bash - run: bear make -s -j4 - - - name: Coverage - shell: bash - run: | - cmake --build ${{runner.workspace}}/build --target coverage - lcov --summary ${{runner.workspace}}/build/coverage.info >> ${{runner.workspace}}/summary.txt - ./scripts/check_coverage.py ${{runner.workspace}}/summary.txt - - - name: Lint - shell: bash - run: | - find include/ src/ test/ -type f \( -iname "*.h" -o -iname "*.cc" \) \ - -exec clang-tidy -p ${{runner.workspace}}/build/compile_commands.json --quiet {} \; 1>> lint.txt 2>/dev/null - cat lint.txt - test ! -s lint.txt - diff --git a/.gitignore b/.gitignore index 075d1d2..2f2d264 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,11 @@ .cache/ .clangd build/ +buildDebug/ +buildRelease/ +buildOMP/ compile_commands.json # do not track compiled documentation files doc/html + diff --git a/CMakeLists.txt b/CMakeLists.txt index b8866c6..c9ce8d4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,5 @@ # SCL --- Secure Computation Library -# Copyright (C) 2022 Anders Dalskov +# Copyright (C) 2024 Anders Dalskov # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,72 +14,117 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -cmake_minimum_required( VERSION 3.14 ) +cmake_minimum_required(VERSION 3.5) -project( scl VERSION 0.7.0 DESCRIPTION "Secure Computation Library" ) +project(scl VERSION 0.1.0 DESCRIPTION "Secure Computation Library" LANGUAGES CXX) -if(NOT CMAKE_BUILD_TYPE) - set(CMAKE_BUILD_TYPE Release) -endif() +option(SCL_BUILD_TESTS + "Build tests for SCL" + ON) + +option( + SCL_BUILD_TEST_WITH_COVERAGE + "Build test with coverage. Implies SCL_BUILD_TESTS=ON" + OFF) + +option( + SCL_BUILD_WITH_ADDRESS_SANITIZATION + "Buid SCL with -fsanitize=address. Implies SCL_BUILD_TESTS=ON" + OFF) -option(WITH_EC "Include support for elliptic curves (requires GMP)" ON) +option( + SCL_BUILD_DOCUMENTATION + "Build documentation for SCL" + ON) -message(STATUS "CMAKE_BUILD_TYPE=" ${CMAKE_BUILD_TYPE}) -message(STATUS "WITH_EC=" ${WITH_EC}) +# This option is a only really here because of a github action. It might be +# better to remove it and fix the action. +option( + SCL_DOXYGEN_BIN + "Optional location of doxygen binary" + "") -if(WITH_EC MATCHES ON) - find_library(GMP gmp libgmp REQUIRED) +option( + SCL_EXPORT_COMPILE_COMMANDS + "Generate compile_commands.json" + ON) + +if(SCL_BUILD_TEST_WITH_COVERAGE) + set(SCL_BUILD_TESTS ON) + set(CMAKE_BUILD_TYPE RelWithDebInfo) endif() -set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -Wall -Wextra -pedantic -Werror -std=gnu++17") +find_library(GMP gmp libgmp REQUIRED) +set(SCL_HEADERS "${CMAKE_SOURCE_DIR}/include") set(SCL_SOURCE_FILES src/scl/util/str.cc src/scl/util/prg.cc src/scl/util/sha3.cc src/scl/util/sha256.cc src/scl/util/cmdline.cc + src/scl/util/measurement.cc + + src/scl/math/fields/ff_ops_gmp.cc + src/scl/math/fields/mersenne61.cc + src/scl/math/fields/mersenne127.cc + src/scl/math/fields/secp256k1_field.cc + src/scl/math/fields/secp256k1_scalar.cc + src/scl/math/curves/secp256k1_curve.cc + src/scl/math/number.cc - src/scl/math/mersenne61.cc - src/scl/math/mersenne127.cc + src/scl/coro/runtime.cc src/scl/net/config.cc - src/scl/net/channel.cc - src/scl/net/mem_channel.cc - src/scl/net/threaded_sender.cc src/scl/net/network.cc - src/scl/simulation/simulator.cc - src/scl/simulation/simulate_recv_time.cc - src/scl/simulation/config.cc src/scl/simulation/event.cc - src/scl/simulation/measurement.cc - src/scl/simulation/result.cc + src/scl/simulation/context.cc + src/scl/simulation/config.cc + src/scl/simulation/transport.cc src/scl/simulation/channel.cc - src/scl/simulation/context.cc) - -if(WITH_EC MATCHES ON) - set(SCL_SOURCE_FILES ${SCL_SOURCE_FILES} - src/scl/math/ops_gmp_ff.cc - src/scl/math/secp256k1_field.cc - src/scl/math/secp256k1_curve.cc - src/scl/math/secp256k1_scalar.cc - src/scl/math/number.cc) -endif() + src/scl/simulation/simulator.cc + src/scl/simulation/runtime.cc +) -set(SCL_HEADERS "${CMAKE_SOURCE_DIR}/include") +set(CMAKE_EXPORT_COMPILE_COMMANDS ${SCL_EXPORT_COMPILE_COMMANDS}) + +set(CMAKE_CXX_STANDARD 20) +set(CXX_STANDARD 20) + +add_library(scl STATIC ${SCL_SOURCE_FILES}) +target_include_directories(scl PUBLIC "${SCL_HEADERS}") +target_compile_options(scl PUBLIC "-march=native") +target_compile_options(scl PUBLIC "-Wall") +target_compile_options(scl PUBLIC "-Wextra") +target_compile_options(scl PUBLIC "-pedantic") -include_directories(${SCL_HEADERS}) +## indicates that SCL is being built with some extra flags that will +## produce a non-optimal build. +set(SCL_SPECIAL_BUILD OFF) -if(CMAKE_BUILD_TYPE MATCHES "Release") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2") - add_library(scl SHARED ${SCL_SOURCE_FILES}) +if(SCL_BUILD_WITH_ADDRESS_SANITIZATION) + target_compile_options(scl PUBLIC "-fsanitize=address") + target_link_libraries(scl PUBLIC "-fsanitize=address") + set(SCL_SPECIAL_BUILD ON) +endif() + +if(SCL_BUILD_TEST_WITH_COVERAGE) + target_compile_options(scl PUBLIC "-O0") + target_compile_options(scl PUBLIC "-g") + target_compile_options(scl PUBLIC "--coverage") + target_compile_options(scl PUBLIC "-fno-inline") + target_link_libraries(scl PUBLIC gcov) + set(SCL_SPECIAL_BUILD ON) +endif() - set_target_properties(scl PROPERTIES VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}) +if (SCL_SPECIAL_BUILD MATCHES OFF) + target_compile_options(scl PRIVATE "-O3") - message(STATUS "CMAKE_INSTALL_PREFIX=" ${CMAKE_INSTALL_PREFIX}) - + set_target_properties(scl + PROPERTIES VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}) + + ## only make installation possible if build is not special. install(TARGETS scl ARCHIVE DESTINATION lib LIBRARY DESTINATION lib) @@ -90,94 +135,11 @@ if(CMAKE_BUILD_TYPE MATCHES "Release") endif() -if(CMAKE_BUILD_TYPE MATCHES "Debug") - +if(SCL_BUILD_TESTS) enable_testing() - - set(SCL_TEST_SOURCE_FILES - test/scl/main.cc - - test/scl/util/test_prg.cc - test/scl/util/test_sha3.cc - test/scl/util/test_sha256.cc - test/scl/util/test_ecdsa.cc - test/scl/util/test_cmdline.cc - test/scl/util/test_merkle.cc - - test/scl/gf7.cc - test/scl/math/test_mersenne61.cc - test/scl/math/test_mersenne127.cc - test/scl/math/test_vec.cc - test/scl/math/test_mat.cc - test/scl/math/test_la.cc - test/scl/math/test_ff.cc - test/scl/math/test_z2k.cc - test/scl/math/test_poly.cc - - test/scl/ss/test_additive.cc - test/scl/ss/test_shamir.cc - test/scl/ss/test_feldman.cc - - test/scl/net/util.cc - test/scl/net/test_config.cc - test/scl/net/test_mem_channel.cc - test/scl/net/test_tcp.cc - test/scl/net/test_tcp_channel.cc - test/scl/net/test_threaded_sender.cc - test/scl/net/test_network.cc - test/scl/net/test_shared_deque.cc - test/scl/net/test_channel.cc - test/scl/net/test_packet.cc - - test/scl/protocol/test_protocol.cc - - test/scl/simulation/test_simulator.cc - test/scl/simulation/test_config.cc - test/scl/simulation/test_event.cc - test/scl/simulation/test_context.cc - test/scl/simulation/test_result.cc - test/scl/simulation/test_measurement.cc - test/scl/simulation/test_mem_channel_buffer.cc - test/scl/simulation/test_channel.cc - test/scl/simulation/test_env.cc - test/scl/simulation/test_manager.cc - - test/scl/serialization/test_serializer.cc) - - if(WITH_EC MATCHES ON) - set(SCL_TEST_SOURCE_FILES ${SCL_TEST_SOURCE_FILES} - test/scl/math/test_secp256k1.cc - test/scl/math/test_number.cc) - add_compile_definitions(SCL_ENABLE_EC_TESTS) - endif() - - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -fsanitize=address") - find_package(Catch2 REQUIRED) - include(CTest) - include(Catch) - include(${CMAKE_SOURCE_DIR}/cmake/CodeCoverage.cmake) - - add_compile_definitions(SCL_TEST_DATA_DIR="${CMAKE_SOURCE_DIR}/test/data/") - add_compile_definitions(SCL_UTIL_NO_EXIT_ON_ERROR) - - add_executable(scl_test ${SCL_SOURCE_FILES} ${SCL_TEST_SOURCE_FILES}) - target_link_libraries(scl_test Catch2::Catch2 pthread) - - if(WITH_EC MATCHES ON) - target_link_libraries(scl_test ${GMP}) - endif() - - catch_discover_tests(scl_test) - - append_coverage_compiler_flags() - - # Tell lcov to ignore system STL headers in order to make the coverage - # output more precise. - setup_target_for_coverage_lcov( - NAME coverage - EXECUTABLE scl_test - EXCLUDE "/usr/include/*" "test/*" "/usr/lib/*" "/usr/local/*") - + add_subdirectory(test) endif() -message(STATUS "CXX_FLAGS=" ${CMAKE_CXX_FLAGS}) +if(SCL_BUILD_DOCUMENTATION) + add_subdirectory(doc) +endif() diff --git a/README.md b/README.md index b110473..d6a034f 100644 --- a/README.md +++ b/README.md @@ -3,15 +3,18 @@ SCL is a utilities library for prototyping Secure Multiparty Computation (_MPC_ for short) protocols. The focus of SCL is usability, not necessarily speed. What this means is that SCL strives to provide an intuitive, easy to use and -understand and well documented interface that helps the programmer prototype an +understand, and well documented interface that helps the programmer prototype an MPC protocol faster (and nicer) than if they had to write everything themselves. SCL provides high level interfaces and functionality for working with * Secret sharing, additive and Shamir. * Finite fields. -* Networking. * Primitives, such as hash functions and PRGs. +SCL in addition provides methods for running protocols on both a real +network, where each party is connected via TCP, as well as a +*simulated* network. + ### Disclaimer SCL is distributed under the GNU Affero General Public License, for details, @@ -58,9 +61,9 @@ see examples of how the different functionality works. # Documentation -SCL uses Doxygen for documentation. Run `./scripts/build_documentation.sh` to -generate the documentation. This is placed in the `doc/` folder. Documentation -uses `doxygen`, so make sure that's installed. +SCL uses Doxygen for documentation, which can be generated by running +`make documentation` from within the build folder. The generated +documentation is placed in the `doc` folder. # Citing diff --git a/RELEASE.txt b/RELEASE.txt index c3928d3..7809c25 100644 --- a/RELEASE.txt +++ b/RELEASE.txt @@ -1,129 +1 @@ -0.7.0: -- Exponentiation for field elements -- Various bug fixes. Especially in the simulation code -- Change versioning. Make all releases start with 0 (to mark them as pre-release). -- Merkle tree hashing. -- Make it possible to hash anything which has a Serializer specialization. -- Vec::ScalarMultiply now allows multiplying a Vec of curve points with a - scalar. Same for Mat. -- Make it possible to prematurely terminate a party in a simulation. -- Introduce a "Manager" class that contains the parameters of a simulation. -- Rename EC::Order to EC::ScalarField. -- Introduce a function for acquiring the order of a field. -- Make utility functions in ECDSA public. -- Various optimizations for the elliptic curve code. -- Simplify the measurement class. - -0.6.2: More functionality for Number -- Add modulo operator to Number. -- Add some mathematical functions that operate on numbers. -- Make Number serializable; add Serializer specialization. -- Add a simple command-line argument parser. - -0.6.1: Extend serialization functionality -- Make Write methods return the number of bytes written. -- Make it possible to serialize vectors with arbitrary content. - -0.6.0: Improvements to serialization and Channels. -- Added a Serializer type that can be specialized in order to specify how - various objects are converted to bytes. -- Added a Packet type that allows reading and writing almost arbitrary objects, - but stores them internally in a serialized format. -- Modified the net::Channel interface to allow sending and receving - Packets. Remove old Send/Recv overloads. -- Remove proto::ProtocolEnvironment. - -0.5.3: ECDSA -- Added functionality for creating ECDSA signatures. - -0.5.2: Protocol environment extensions -- Make it possible to create "checkpoints" through the protocol environment - clock. -- fix a bug that prevented the documentation from being buildt -- Rename ProtocolEnvironment to Env, and introduce a typedef for backwards - compatability. - -0.5.1: Style changes -- Change naming style of private field members. -- Simplifed the NextToRun logic because a greedy strategy too often results in - rollbacks. -- Fixed a bug in the Rollback logic where WriteOps weren't rolled back - correctly. -- Add a Vec Mat to Vec multiplication function to Mat -- Minor refactoring of test_mat.cc - -0.5.0: Simulation -- Added a new module for simulating protocol executions under different network - conditions. -- Refactored layout with respect to namespaces. details no longer exists, and - the different modules have gotten their own namespace. -- Up test coverage to 100%. Minor refactoring to the actions. - -0.4.0: Shamir, Feldman, SHA-256 -- Refactor Shamir to allow caching of Lagrange coefficients -- Add support for Feldman Secret Sharing -- Add support for SHA-256 -- Add bibtex blob for citing SCL -- Refactor interface for hash functions -- Refactor interface for Shamir -- bugs: - - Fix negation of 0 in Secp256k1::Field and Secp256k1::Order - - Make serialization and deserialization of curve points behave more sanely - -0.3.0: More features, build changes -- Add method for returning a point as a pair of affine coordinates -- Add method to check if a channel has data available -- Allow sending and receiving STL vectors without specifying the size -- Extend Vec with a SubVector, operator== and operator!= methods -- Begin Shamir code refactor and move all of it into details namespace -- bugs: - - fix scalar multiplication for secp256k1_order - - fix compilation error on g++12 -- build: - - build tests with -fsanitize=address - - disable actions for master branch - - add clang-tidy action - -0.2.1: More Finite Fields -- Provide a FF implementation for computations modulo the order of Secp256k1 -- Extend EC with support for scalar multiplications with scalars from a finite - field of size the order of a subgroup. - -0.2.0: Elliptic curves and finite field refactoring -- Make it simpler to define new finite fields -- Include optional (but enabled by default) support for elliptic curves - - Implement secp256k1 -- Include optional (but enabled by default) support for multi-precision integers -- Significantly increase test coverage -- Make header guards standard compliant -- Rename FF to Fp. -- Move class FF into scl namespace. - -0.1.1: Refactoring of finite field internals -- Finite field operations are now defined by individual specializations of - templated functions -- Remove DEFINE_FINITE_FIELD macro -- Move Mersenne61 and Mersenne127 definitions into ff.h - -0.1.0: Initial public version of SCL. -- Features: - - Math: - - Finite Field class with two instantiations based on Mersenne primes - - Ring modulo 2^K, with support for any K less than or equal to 128 - - Vectors. This is a thinish wrapper around an STL vector - - Matrices - - All math types support serialization - - Primitives: - - PRG based on AES-CTR - - IUF Hash based on SHA3 - - Networking: - - Basic support for peer-to-peer communication via TCP. Channels are - designed in such a way that writing custom ones (or decorators for - existing channels) is easy - - Peer discovery functionality to make it easier to setup working networks - - Secret Sharing: - - Shamir secret sharing with support for both error detection and correction - - Additive secret sharing -- Development: - - Decent code coverage - - Documentation +0.1.0: Initial version of SCL for C++20 diff --git a/cmake/CodeCoverage.cmake b/cmake/CodeCoverage.cmake deleted file mode 100644 index 27e7d3d..0000000 --- a/cmake/CodeCoverage.cmake +++ /dev/null @@ -1,436 +0,0 @@ -# Copyright (c) 2012 - 2017, Lars Bilke -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without modification, -# are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its contributors -# may be used to endorse or promote products derived from this software without -# specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR -# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON -# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -# CHANGES: -# -# 2012-01-31, Lars Bilke -# - Enable Code Coverage -# -# 2013-09-17, Joakim Söderberg -# - Added support for Clang. -# - Some additional usage instructions. -# -# 2016-02-03, Lars Bilke -# - Refactored functions to use named parameters -# -# 2017-06-02, Lars Bilke -# - Merged with modified version from github.com/ufz/ogs -# -# 2019-05-06, Anatolii Kurotych -# - Remove unnecessary --coverage flag -# -# 2019-12-13, FeRD (Frank Dana) -# - Deprecate COVERAGE_LCOVR_EXCLUDES and COVERAGE_GCOVR_EXCLUDES lists in favor -# of tool-agnostic COVERAGE_EXCLUDES variable, or EXCLUDE setup arguments. -# - CMake 3.4+: All excludes can be specified relative to BASE_DIRECTORY -# - All setup functions: accept BASE_DIRECTORY, EXCLUDE list -# - Set lcov basedir with -b argument -# - Add automatic --demangle-cpp in lcovr, if 'c++filt' is available (can be -# overridden with NO_DEMANGLE option in setup_target_for_coverage_lcovr().) -# - Delete output dir, .info file on 'make clean' -# - Remove Python detection, since version mismatches will break gcovr -# - Minor cleanup (lowercase function names, update examples...) -# -# 2019-12-19, FeRD (Frank Dana) -# - Rename Lcov outputs, make filtered file canonical, fix cleanup for targets -# -# 2020-01-19, Bob Apthorpe -# - Added gfortran support -# -# 2020-02-17, FeRD (Frank Dana) -# - Make all add_custom_target()s VERBATIM to auto-escape wildcard characters -# in EXCLUDEs, and remove manual escaping from gcovr targets -# -# USAGE: -# -# 1. Copy this file into your cmake modules path. -# -# 2. Add the following line to your CMakeLists.txt (best inside an if-condition -# using a CMake option() to enable it just optionally): -# include(CodeCoverage) -# -# 3. Append necessary compiler flags: -# append_coverage_compiler_flags() -# -# 3.a (OPTIONAL) Set appropriate optimization flags, e.g. -O0, -O1 or -Og -# -# 4. If you need to exclude additional directories from the report, specify them -# using full paths in the COVERAGE_EXCLUDES variable before calling -# setup_target_for_coverage_*(). -# Example: -# set(COVERAGE_EXCLUDES -# '${PROJECT_SOURCE_DIR}/src/dir1/*' -# '/path/to/my/src/dir2/*') -# Or, use the EXCLUDE argument to setup_target_for_coverage_*(). -# Example: -# setup_target_for_coverage_lcov( -# NAME coverage -# EXECUTABLE testrunner -# EXCLUDE "${PROJECT_SOURCE_DIR}/src/dir1/*" "/path/to/my/src/dir2/*") -# -# 4.a NOTE: With CMake 3.4+, COVERAGE_EXCLUDES or EXCLUDE can also be set -# relative to the BASE_DIRECTORY (default: PROJECT_SOURCE_DIR) -# Example: -# set(COVERAGE_EXCLUDES "dir1/*") -# setup_target_for_coverage_gcovr_html( -# NAME coverage -# EXECUTABLE testrunner -# BASE_DIRECTORY "${PROJECT_SOURCE_DIR}/src" -# EXCLUDE "dir2/*") -# -# 5. Use the functions described below to create a custom make target which -# runs your test executable and produces a code coverage report. -# -# 6. Build a Debug build: -# cmake -DCMAKE_BUILD_TYPE=Debug .. -# make -# make my_coverage_target -# - -include(CMakeParseArguments) - -# Check prereqs -find_program( GCOV_PATH gcov ) -find_program( LCOV_PATH NAMES lcov lcov.bat lcov.exe lcov.perl) -find_program( GENHTML_PATH NAMES genhtml genhtml.perl genhtml.bat ) -find_program( GCOVR_PATH gcovr PATHS ${CMAKE_SOURCE_DIR}/scripts/test) -find_program( CPPFILT_PATH NAMES c++filt ) - -if(NOT GCOV_PATH) - message(FATAL_ERROR "gcov not found! Aborting...") -endif() # NOT GCOV_PATH - -if("${CMAKE_CXX_COMPILER_ID}" MATCHES "(Apple)?[Cc]lang") - if("${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 3) - message(FATAL_ERROR "Clang version must be 3.0.0 or greater! Aborting...") - endif() -elseif(NOT CMAKE_COMPILER_IS_GNUCXX) - if("${CMAKE_Fortran_COMPILER_ID}" MATCHES "[Ff]lang") - # Do nothing; exit conditional without error if true - elseif("${CMAKE_Fortran_COMPILER_ID}" MATCHES "GNU") - # Do nothing; exit conditional without error if true - else() - message(FATAL_ERROR "Compiler is not GNU gcc! Aborting...") - endif() -endif() - -set(COVERAGE_COMPILER_FLAGS "-g -fprofile-arcs -ftest-coverage" - CACHE INTERNAL "") - -set(CMAKE_Fortran_FLAGS_COVERAGE - ${COVERAGE_COMPILER_FLAGS} - CACHE STRING "Flags used by the Fortran compiler during coverage builds." - FORCE ) -set(CMAKE_CXX_FLAGS_COVERAGE - ${COVERAGE_COMPILER_FLAGS} - CACHE STRING "Flags used by the C++ compiler during coverage builds." - FORCE ) -set(CMAKE_C_FLAGS_COVERAGE - ${COVERAGE_COMPILER_FLAGS} - CACHE STRING "Flags used by the C compiler during coverage builds." - FORCE ) -set(CMAKE_EXE_LINKER_FLAGS_COVERAGE - "" - CACHE STRING "Flags used for linking binaries during coverage builds." - FORCE ) -set(CMAKE_SHARED_LINKER_FLAGS_COVERAGE - "" - CACHE STRING "Flags used by the shared libraries linker during coverage builds." - FORCE ) -mark_as_advanced( - CMAKE_Fortran_FLAGS_COVERAGE - CMAKE_CXX_FLAGS_COVERAGE - CMAKE_C_FLAGS_COVERAGE - CMAKE_EXE_LINKER_FLAGS_COVERAGE - CMAKE_SHARED_LINKER_FLAGS_COVERAGE ) - -if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") - message(WARNING "Code coverage results with an optimised (non-Debug) build may be misleading") -endif() # NOT CMAKE_BUILD_TYPE STREQUAL "Debug" - -if(CMAKE_C_COMPILER_ID STREQUAL "GNU" OR CMAKE_Fortran_COMPILER_ID STREQUAL "GNU") - link_libraries(gcov) -endif() - -# Defines a target for running and collection code coverage information -# Builds dependencies, runs the given executable and outputs reports. -# NOTE! The executable should always have a ZERO as exit code otherwise -# the coverage generation will not complete. -# -# setup_target_for_coverage_lcov( -# NAME testrunner_coverage # New target name -# EXECUTABLE testrunner -j ${PROCESSOR_COUNT} # Executable in PROJECT_BINARY_DIR -# DEPENDENCIES testrunner # Dependencies to build first -# BASE_DIRECTORY "../" # Base directory for report -# # (defaults to PROJECT_SOURCE_DIR) -# EXCLUDE "src/dir1/*" "src/dir2/*" # Patterns to exclude (can be relative -# # to BASE_DIRECTORY, with CMake 3.4+) -# NO_DEMANGLE # Don't demangle C++ symbols -# # even if c++filt is found -# ) -function(setup_target_for_coverage_lcov) - - set(options NO_DEMANGLE) - set(oneValueArgs BASE_DIRECTORY NAME) - set(multiValueArgs EXCLUDE EXECUTABLE EXECUTABLE_ARGS DEPENDENCIES LCOV_ARGS GENHTML_ARGS) - cmake_parse_arguments(Coverage "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - - if(NOT LCOV_PATH) - message(FATAL_ERROR "lcov not found! Aborting...") - endif() # NOT LCOV_PATH - - if(NOT GENHTML_PATH) - message(FATAL_ERROR "genhtml not found! Aborting...") - endif() # NOT GENHTML_PATH - - # Set base directory (as absolute path), or default to PROJECT_SOURCE_DIR - if(${Coverage_BASE_DIRECTORY}) - get_filename_component(BASEDIR ${Coverage_BASE_DIRECTORY} ABSOLUTE) - else() - set(BASEDIR ${PROJECT_SOURCE_DIR}) - endif() - - # Collect excludes (CMake 3.4+: Also compute absolute paths) - set(LCOV_EXCLUDES "") - foreach(EXCLUDE ${Coverage_EXCLUDE} ${COVERAGE_EXCLUDES} ${COVERAGE_LCOV_EXCLUDES}) - if(CMAKE_VERSION VERSION_GREATER 3.4) - get_filename_component(EXCLUDE ${EXCLUDE} ABSOLUTE BASE_DIR ${BASEDIR}) - endif() - list(APPEND LCOV_EXCLUDES "${EXCLUDE}") - endforeach() - list(REMOVE_DUPLICATES LCOV_EXCLUDES) - - # Conditional arguments - if(CPPFILT_PATH AND NOT ${Coverage_NO_DEMANGLE}) - set(GENHTML_EXTRA_ARGS "--demangle-cpp") - endif() - - # Setup target - add_custom_target(${Coverage_NAME} - - # Cleanup lcov - COMMAND ${LCOV_PATH} ${Coverage_LCOV_ARGS} --gcov-tool ${GCOV_PATH} -directory . -b ${BASEDIR} --zerocounters - # Create baseline to make sure untouched files show up in the report - COMMAND ${LCOV_PATH} ${Coverage_LCOV_ARGS} --gcov-tool ${GCOV_PATH} -c -i -d . -b ${BASEDIR} -o ${Coverage_NAME}.base - - # Run tests - COMMAND ${Coverage_EXECUTABLE} ${Coverage_EXECUTABLE_ARGS} - - # Capturing lcov counters and generating report - COMMAND ${LCOV_PATH} ${Coverage_LCOV_ARGS} --gcov-tool ${GCOV_PATH} --directory . -b ${BASEDIR} --capture --output-file ${Coverage_NAME}.capture - # add baseline counters - COMMAND ${LCOV_PATH} ${Coverage_LCOV_ARGS} --gcov-tool ${GCOV_PATH} -a ${Coverage_NAME}.base -a ${Coverage_NAME}.capture --output-file ${Coverage_NAME}.total - # filter collected data to final coverage report - COMMAND ${LCOV_PATH} ${Coverage_LCOV_ARGS} --gcov-tool ${GCOV_PATH} --remove ${Coverage_NAME}.total ${LCOV_EXCLUDES} --output-file ${Coverage_NAME}.info - - # Generate HTML output - COMMAND ${GENHTML_PATH} ${GENHTML_EXTRA_ARGS} ${Coverage_GENHTML_ARGS} -o ${Coverage_NAME} ${Coverage_NAME}.info - - # Set output files as GENERATED (will be removed on 'make clean') - BYPRODUCTS - ${Coverage_NAME}.base - ${Coverage_NAME}.capture - ${Coverage_NAME}.total - ${Coverage_NAME}.info - ${Coverage_NAME} # report directory - - WORKING_DIRECTORY ${PROJECT_BINARY_DIR} - DEPENDS ${Coverage_DEPENDENCIES} - VERBATIM # Protect arguments to commands - COMMENT "Resetting code coverage counters to zero.\nProcessing code coverage counters and generating report." - ) - - # Show where to find the lcov info report - add_custom_command(TARGET ${Coverage_NAME} POST_BUILD - COMMAND ; - COMMENT "Lcov code coverage info report saved in ${Coverage_NAME}.info." - ) - - # Show info where to find the report - add_custom_command(TARGET ${Coverage_NAME} POST_BUILD - COMMAND ; - COMMENT "Open ./${Coverage_NAME}/index.html in your browser to view the coverage report." - ) - -endfunction() # setup_target_for_coverage_lcov - -# Defines a target for running and collection code coverage information -# Builds dependencies, runs the given executable and outputs reports. -# NOTE! The executable should always have a ZERO as exit code otherwise -# the coverage generation will not complete. -# -# setup_target_for_coverage_gcovr_xml( -# NAME ctest_coverage # New target name -# EXECUTABLE ctest -j ${PROCESSOR_COUNT} # Executable in PROJECT_BINARY_DIR -# DEPENDENCIES executable_target # Dependencies to build first -# BASE_DIRECTORY "../" # Base directory for report -# # (defaults to PROJECT_SOURCE_DIR) -# EXCLUDE "src/dir1/*" "src/dir2/*" # Patterns to exclude (can be relative -# # to BASE_DIRECTORY, with CMake 3.4+) -# ) -function(setup_target_for_coverage_gcovr_xml) - - set(options NONE) - set(oneValueArgs BASE_DIRECTORY NAME) - set(multiValueArgs EXCLUDE EXECUTABLE EXECUTABLE_ARGS DEPENDENCIES) - cmake_parse_arguments(Coverage "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - - if(NOT GCOVR_PATH) - message(FATAL_ERROR "gcovr not found! Aborting...") - endif() # NOT GCOVR_PATH - - # Set base directory (as absolute path), or default to PROJECT_SOURCE_DIR - if(${Coverage_BASE_DIRECTORY}) - get_filename_component(BASEDIR ${Coverage_BASE_DIRECTORY} ABSOLUTE) - else() - set(BASEDIR ${PROJECT_SOURCE_DIR}) - endif() - - # Collect excludes (CMake 3.4+: Also compute absolute paths) - set(GCOVR_EXCLUDES "") - foreach(EXCLUDE ${Coverage_EXCLUDE} ${COVERAGE_EXCLUDES} ${COVERAGE_GCOVR_EXCLUDES}) - if(CMAKE_VERSION VERSION_GREATER 3.4) - get_filename_component(EXCLUDE ${EXCLUDE} ABSOLUTE BASE_DIR ${BASEDIR}) - endif() - list(APPEND GCOVR_EXCLUDES "${EXCLUDE}") - endforeach() - list(REMOVE_DUPLICATES GCOVR_EXCLUDES) - - # Combine excludes to several -e arguments - set(GCOVR_EXCLUDE_ARGS "") - foreach(EXCLUDE ${GCOVR_EXCLUDES}) - list(APPEND GCOVR_EXCLUDE_ARGS "-e") - list(APPEND GCOVR_EXCLUDE_ARGS "${EXCLUDE}") - endforeach() - - add_custom_target(${Coverage_NAME} - # Run tests - ${Coverage_EXECUTABLE} ${Coverage_EXECUTABLE_ARGS} - - # Running gcovr - COMMAND ${GCOVR_PATH} --xml - -r ${BASEDIR} ${GCOVR_EXCLUDE_ARGS} - --object-directory=${PROJECT_BINARY_DIR} - -o ${Coverage_NAME}.xml - BYPRODUCTS ${Coverage_NAME}.xml - WORKING_DIRECTORY ${PROJECT_BINARY_DIR} - DEPENDS ${Coverage_DEPENDENCIES} - VERBATIM # Protect arguments to commands - COMMENT "Running gcovr to produce Cobertura code coverage report." - ) - - # Show info where to find the report - add_custom_command(TARGET ${Coverage_NAME} POST_BUILD - COMMAND ; - COMMENT "Cobertura code coverage report saved in ${Coverage_NAME}.xml." - ) -endfunction() # setup_target_for_coverage_gcovr_xml - -# Defines a target for running and collection code coverage information -# Builds dependencies, runs the given executable and outputs reports. -# NOTE! The executable should always have a ZERO as exit code otherwise -# the coverage generation will not complete. -# -# setup_target_for_coverage_gcovr_html( -# NAME ctest_coverage # New target name -# EXECUTABLE ctest -j ${PROCESSOR_COUNT} # Executable in PROJECT_BINARY_DIR -# DEPENDENCIES executable_target # Dependencies to build first -# BASE_DIRECTORY "../" # Base directory for report -# # (defaults to PROJECT_SOURCE_DIR) -# EXCLUDE "src/dir1/*" "src/dir2/*" # Patterns to exclude (can be relative -# # to BASE_DIRECTORY, with CMake 3.4+) -# ) -function(setup_target_for_coverage_gcovr_html) - - set(options NONE) - set(oneValueArgs BASE_DIRECTORY NAME) - set(multiValueArgs EXCLUDE EXECUTABLE EXECUTABLE_ARGS DEPENDENCIES) - cmake_parse_arguments(Coverage "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - - if(NOT GCOVR_PATH) - message(FATAL_ERROR "gcovr not found! Aborting...") - endif() # NOT GCOVR_PATH - - # Set base directory (as absolute path), or default to PROJECT_SOURCE_DIR - if(${Coverage_BASE_DIRECTORY}) - get_filename_component(BASEDIR ${Coverage_BASE_DIRECTORY} ABSOLUTE) - else() - set(BASEDIR ${PROJECT_SOURCE_DIR}) - endif() - - # Collect excludes (CMake 3.4+: Also compute absolute paths) - set(GCOVR_EXCLUDES "") - foreach(EXCLUDE ${Coverage_EXCLUDE} ${COVERAGE_EXCLUDES} ${COVERAGE_GCOVR_EXCLUDES}) - if(CMAKE_VERSION VERSION_GREATER 3.4) - get_filename_component(EXCLUDE ${EXCLUDE} ABSOLUTE BASE_DIR ${BASEDIR}) - endif() - list(APPEND GCOVR_EXCLUDES "${EXCLUDE}") - endforeach() - list(REMOVE_DUPLICATES GCOVR_EXCLUDES) - - # Combine excludes to several -e arguments - set(GCOVR_EXCLUDE_ARGS "") - foreach(EXCLUDE ${GCOVR_EXCLUDES}) - list(APPEND GCOVR_EXCLUDE_ARGS "-e") - list(APPEND GCOVR_EXCLUDE_ARGS "${EXCLUDE}") - endforeach() - - add_custom_target(${Coverage_NAME} - # Run tests - ${Coverage_EXECUTABLE} ${Coverage_EXECUTABLE_ARGS} - - # Create folder - COMMAND ${CMAKE_COMMAND} -E make_directory ${PROJECT_BINARY_DIR}/${Coverage_NAME} - - # Running gcovr - COMMAND ${GCOVR_PATH} --html --html-details - -r ${BASEDIR} ${GCOVR_EXCLUDE_ARGS} - --object-directory=${PROJECT_BINARY_DIR} - -o ${Coverage_NAME}/index.html - - BYPRODUCTS ${PROJECT_BINARY_DIR}/${Coverage_NAME} # report directory - WORKING_DIRECTORY ${PROJECT_BINARY_DIR} - DEPENDS ${Coverage_DEPENDENCIES} - VERBATIM # Protect arguments to commands - COMMENT "Running gcovr to produce HTML code coverage report." - ) - - # Show info where to find the report - add_custom_command(TARGET ${Coverage_NAME} POST_BUILD - COMMAND ; - COMMENT "Open ./${Coverage_NAME}/index.html in your browser to view the coverage report." - ) - -endfunction() # setup_target_for_coverage_gcovr_html - -function(append_coverage_compiler_flags) - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${COVERAGE_COMPILER_FLAGS}" PARENT_SCOPE) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${COVERAGE_COMPILER_FLAGS}" PARENT_SCOPE) - set(CMAKE_Fortran_FLAGS "${CMAKE_Fortran_FLAGS} ${COVERAGE_COMPILER_FLAGS}" PARENT_SCOPE) - message(STATUS "Appending code coverage compiler flags: ${COVERAGE_COMPILER_FLAGS}") -endfunction() # append_coverage_compiler_flags diff --git a/doc/CMakeLists.txt b/doc/CMakeLists.txt new file mode 100644 index 0000000..a6beed4 --- /dev/null +++ b/doc/CMakeLists.txt @@ -0,0 +1,34 @@ +# SCL --- Secure Computation Library +# Copyright (C) 2024 Anders Dalskov +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +cmake_minimum_required(VERSION 3.5) + +if (NOT SCL_DOXYGEN_BIN) + find_program(doxygen doxygen REQUIRED) +else() + if (EXISTS ${SCL_DOXYGEN_BIN} AND (NOT IS_DIRECTORY ${SCL_DOXYGEN_BIN})) + set(doxygen ${SCL_DOXYGEN_BIN}) + else() + message(FATAL_ERROR "\"${SCL_DOXYGEN_BIN}\" not a valid argument for doxygen binary.") + endif() +endif() + +add_custom_target( + documentation + COMMAND ${doxygen} "${CMAKE_SOURCE_DIR}/doc/DoxyConf" + WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} + COMMENT "Generate documentation for SCL" + VERBATIM) diff --git a/doc/DoxyConf b/doc/DoxyConf index abd0f43..a6e9e5e 100644 --- a/doc/DoxyConf +++ b/doc/DoxyConf @@ -17,7 +17,9 @@ INPUT = doc/mainpage.md \ include/scl/net \ include/scl/simulation \ include/scl/protocol \ + include/scl/coro \ include/scl/serialization + FILE_PATTERNS = *.h EXCLUDE_SYMBOLS = SCL_* @@ -34,7 +36,6 @@ USE_MDFILE_AS_MAINPAGE = doc/mainpage.md HTML_EXTRA_STYLESHEET = doc/styling.css GENERATE_TREEVIEW = YES -CLASS_DIAGRAMS = NO USE_MATHJAX = YES diff --git a/doc/mainpage.md b/doc/mainpage.md index b39ccfc..5543be6 100644 --- a/doc/mainpage.md +++ b/doc/mainpage.md @@ -1,3 +1,19 @@ # Introduction - +**SCL** (short for *Secure Computation Library*) is a C++-20 library which aims +at removing a lot of the typical boilerplate code that one usually has to write +when developing a proof-of-concept for a new Secure Multiparty Computation (or +*MPC* for short) protocol. + +Everything SCL is placed in the \ref scl namespace, and different +functionalities are placed in different sub namespaces, a short description of +each as well as their purpose is given here: + +- \ref scl::coro coroutines. +- \ref scl::math math related stuff. +- \ref scl::net networking. +- \ref scl::proto protocol interfaces. +- \ref scl::sim protocol execution simulation. +- \ref scl::seri serialization. +- \ref scl::ss secret-sharing. +- \ref scl::util other utilities. diff --git a/include/scl/coro/batch.h b/include/scl/coro/batch.h new file mode 100644 index 0000000..261907a --- /dev/null +++ b/include/scl/coro/batch.h @@ -0,0 +1,246 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_CORO_BATCH_H +#define SCL_CORO_BATCH_H + +#include +#include +#include +#include +#include + +namespace scl::coro { + +template +class Task; + +class Runtime; + +namespace details { + +/** + * @brief A Batch of coroutine tasks. + * @tparam RESULT the result type of the tasks. + * + * The primary task of this class is to enable manual scheduling of a list of + * tasks. This allows the running the tasks that the batch is constructed with + * to run in parallel. + */ +template +class Batch final { + public: + /** + * @brief Construct a new batch. + * @param tasks the tasks in this batch. + */ + explicit Batch(std::vector>&& tasks) + : m_tasks(std::move(tasks)) {} + + /** + * @brief Check if this batch is ready. + * @return true if all tasks in the batch have finished, false otherwise. + */ + bool await_ready() const noexcept { + return !std::any_of(m_tasks.begin(), m_tasks.end(), [](const auto& task) { + return !task.ready(); + }); + } + + /** + * @brief Run when the coroutine waiting for this batch suspends. + * @param coroutine the coroutine co_await'ing this batch. + * @return coroutine to run instead. + */ + std::coroutine_handle<> await_suspend(std::coroutine_handle<> coroutine); + + /** + * @brief Run when the batch has finished. + */ + template + auto await_resume() -> std::enable_if_t, void> { + for (const auto& task : m_tasks) { + task.result(); + } + } + + /** + * @brief Run when the batch has finished. + */ + template + auto await_resume() -> std::enable_if_t, std::vector> { + std::vector results; + for (const auto& task : m_tasks) { + results.emplace_back(task.result()); + } + return results; + } + + /** + * @brief Assign the runtime to use for this batch. + * @param runtime the runtime. + * + * The runtime passed here is set as the runtime for all Tasks in the batch. + */ + void setRuntime(Runtime* runtime) noexcept { + m_runtime = runtime; + } + + private: + std::vector> m_tasks; + Runtime* m_runtime = nullptr; +}; + +/** + * @brief A batch of coroutine tasks. + * @tparam RESULT the result type of the tasks. + * + * Similar to Batch, except with a weaker requirement on the number of tasks + * that must complete. Tasks that do not complete are destroyed automatically + * when the PartialBatch is destroyed. + */ +template +class PartialBatch final { + public: + /** + * @brief Construct a new PartialBatch. + * @param tasks the tasks in the batch. + * @param min_complete the minimum number of tasks that should complete before + * the batch is done. + */ + explicit PartialBatch(std::vector>&& tasks, + std::size_t min_complete) + : m_tasks(std::move(tasks)), m_min_complete(min_complete) {} + + /** + * @brief Check if the batch is done. + * @return true if the number of ready tasks exceed min_complete. + */ + bool await_ready() const noexcept { + std::size_t count = 0; + for (const auto& task : m_tasks) { + if (task.ready()) { + count++; + } + } + return count >= m_min_complete; + } + + /** + * @brief Suspend this batch. + */ + std::coroutine_handle<> await_suspend(std::coroutine_handle<> coroutine); + + /** + * @brief Resume this batch. + */ + template + auto await_resume() -> std::enable_if_t, void> { + for (const auto& task : m_tasks) { + if (task.ready()) { + task.result(); + } + } + } + + /** + * @brief Resume this batch + */ + template + auto await_resume() + -> std::enable_if_t, std::vector>> { + std::vector> results; + for (const auto& task : m_tasks) { + if (task.ready()) { + results.emplace_back(task.result()); + } else { + results.emplace_back(std::optional{}); + } + } + return results; + } + + /** + * @brief Set the runtime to use for running this batch. + * @param runtime the runtime. + * + * \p runtime is assigned as the runtime for all Tasks in this batch. + */ + void setRuntime(Runtime* runtime) noexcept { + m_runtime = runtime; + } + + private: + std::vector> m_tasks; + std::size_t m_min_complete; + + Runtime* m_runtime = nullptr; +}; + +} // namespace details + +/** + * @brief Create a new batch task. + * @param tasks the tasks in the batch. + * + * Creates a batch of coroutine tasks that will be run concurrently when + * co_await'ed. This can therefore be used when, for example, we + * would like one task to resume when another is suspended (e.g., because the + * suspended task is waiting for data). + * + * @code + * std::vector> tasks; + * tasks.emplace_back(intTask()); + * tasks.emplace_back(anoterIntTask()); + * + * std::vector results = co_await batch(std::move(tasks)); + * @endcode + */ +template +auto batch(std::vector>&& tasks) { + return details::Batch{std::move(tasks)}; +} + +/** + * @brief Create a new batch task. + * @param tasks the tasks in the batch. + * @param min_complete the minimum number of tasks to complete for the batch to + * be considered complete. + * + * Creates a batch of coroutine tasks, where we are just interested in some of + * the finishing. + * + * @code + * std::vector> tasks; + * tasks.emplace_back(intTask()); + * tasks.emplace_back(intTaskThatRunsForever()); + * + * std::vector> results = co_await batch(std::move(tasks), + * 1); + * + * results[0].has_value(); // returns true + * results[1].has_value(); // returns false + * @endcode + */ +template +auto batch(std::vector&& tasks, std::size_t min_complete) { + return details::PartialBatch{std::move(tasks), min_complete}; +} + +} // namespace scl::coro + +#endif // SCL_CORO_BATCH_H diff --git a/test/scl/main.cc b/include/scl/coro/coroutine.h similarity index 72% rename from test/scl/main.cc rename to include/scl/coro/coroutine.h index 5b14fe7..a294f1d 100644 --- a/test/scl/main.cc +++ b/include/scl/coro/coroutine.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,5 +15,15 @@ * along with this program. If not, see . */ -#define CATCH_CONFIG_MAIN -#include +#ifndef SCL_CORO_COROUTINE_H +#define SCL_CORO_COROUTINE_H + +#include "scl/coro/runtime.h" +#include "scl/coro/task.h" + +/** + * @brief Coroutine utilities. + */ +namespace scl::coro {} // namespace scl::coro + +#endif // SCL_CORO_COROUTINE_H diff --git a/include/scl/coro/future.h b/include/scl/coro/future.h new file mode 100644 index 0000000..603a054 --- /dev/null +++ b/include/scl/coro/future.h @@ -0,0 +1,88 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_CORO_FUTURE_H +#define SCL_CORO_FUTURE_H + +#include +#include +#include +#include + +namespace scl::coro { + +class Runtime; + +namespace details { + +/** + * @brief Concept that a future type must satisfy. + */ +template +concept FutureAwaitableType = requires(FUTURE future) { + { future() } -> std::convertible_to; + }; + +/** + * @brief The awaiter for future events. + * @tparam FUTURE_EVENT the type of the future event. + * + * \p FUTURE_EVENT must be a subclass of FutureEvent. + */ +template +class FutureAwaiter final { + public: + /** + * @brief Construct a new awaiter from a future. + * @param future the future. + */ + FutureAwaiter(FUTURE&& future) : m_future(std::forward(future)){}; + + /** + * @brief Futures are by design not ready immediately. + */ + bool await_ready() const noexcept { + return false; + } + + /** + * @brief Schedule the coroutine for later execution, pending some condition. + * @return the next coroutine to execute. + */ + std::coroutine_handle<> await_suspend(std::coroutine_handle<> handle); + + /** + * @brief Does nothing. + */ + void await_resume() const noexcept {} + + /** + * @brief Sets the runtime for this future. + */ + void setRuntime(Runtime* runtime) { + m_runtime = runtime; + } + + private: + Runtime* m_runtime; + FUTURE m_future; +}; + +} // namespace details +} // namespace scl::coro + +#endif // SCL_CORO_FUTURE_H diff --git a/include/scl/coro/promise.h b/include/scl/coro/promise.h new file mode 100644 index 0000000..4b68c1e --- /dev/null +++ b/include/scl/coro/promise.h @@ -0,0 +1,231 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_CORO_PROMISE_H +#define SCL_CORO_PROMISE_H + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "scl/coro/batch.h" +#include "scl/coro/future.h" +#include "scl/coro/sleep_awaiter.h" +#include "scl/util/time.h" + +namespace scl::coro { + +template +class Task; + +class Runtime; + +namespace details { + +/** + * @brief Base type for the promise_type of Tasks. + */ +class TaskPromiseBase { + /** + * @brief Awaiter returned on final_suspend. + * + * The job of this awaiter is to suspend execution to prevent the coroutine + * result from simply dissapearing, and to resume any tasks that are waiting + * for this coroutine to finish. + */ + struct FinalAwaiter { + bool await_ready() noexcept { + return false; + } + + template + std::coroutine_handle<> await_suspend( + std::coroutine_handle handle) noexcept { + if (handle.promise().m_next) { + return handle.promise().m_next; + } + return std::noop_coroutine(); + } + + void await_resume() noexcept {} + }; + + public: + /** + * @brief Cold-starts the coroutine. + */ + auto initial_suspend() noexcept { + return std::suspend_always{}; + } + + /** + * @brief Called when the Task finishes running. + */ + auto final_suspend() noexcept { + return FinalAwaiter{}; + } + + /** + * @brief Transform an awaitable into an awaiter. + * @param awaitable the awaitable. + * + * The assumption made here is that the awaitable is also the awaiter. In + * particular, the type has the required functions for it to be an awaiter. In + * addition, the type should possess a setRuntime(coro::Runtime*) + * function which is used to specify which runtime should be used when + * suspending the awaitable. + */ + template + AWAITABLE await_transform(AWAITABLE&& awaitable); + + /** + * @brief Transform called on future types. + * + * A "future" is any callable which returns either true or false. The callable + * describes when the waiting coroutine can be resumed. + */ + template + auto await_transform(FUTURE&& future) { + return await_transform(FutureAwaiter(std::forward(future))); + } + + /** + * @brief Transform called on a std::chrono::duration. + */ + template + auto await_transform(std::chrono::duration duration) { + return await_transform(details::SleepAwaiter(duration)); + } + + /** + * @brief Set the coroutine to run when this task finishes. + */ + void setNext(std::coroutine_handle<> next) { + m_next = next; + } + + /** + * @brief Set the Runtime to use for executing this task. + */ + void setRuntime(Runtime* runtime) { + m_runtime = runtime; + } + + /** + * @brief Get the current Runtime. + */ + Runtime* getRuntime() const { + return m_runtime; + } + + private: + std::coroutine_handle<> m_next; + Runtime* m_runtime = nullptr; +}; + +/** + * @brief Task promise type for general non-void return types. + */ +template +class TaskPromise final : public TaskPromiseBase { + public: + /** + * @brief Create a Task object from this promise. + */ + Task get_return_object(); + + /** + * @brief Set the return value of this Task. + * @param result the value to return when this Task completes. + */ + void return_value(RESULT result) { + m_result.template emplace<1>(std::move(result)); + } + + /** + * @brief Called if the Task throws an exception. + */ + void unhandled_exception() noexcept { + m_result.template emplace<2>(std::current_exception()); + } + + /** + * @brief Get the return value of this Task. + */ + RESULT result() { + if (m_result.index() == 0) { + throw std::logic_error("result() called on unfinished coroutine"); + } + if (m_result.index() == 2) { + std::rethrow_exception(std::get<2>(m_result)); + } + return std::get<1>(std::move(m_result)); + } + + private: + std::variant m_result; +}; + +/** + * @brief Task promise specialization for Tasks returning void. + */ +template <> +class TaskPromise final : public TaskPromiseBase { + public: + /** + * @brief Create a Task object from this promise. + */ + Task get_return_object(); + + /** + * @brief Indicates that the Task has finished executing. + */ + void return_void() noexcept {} + + /** + * @brief Called if the Task throws an exception. + */ + void unhandled_exception() noexcept { + m_exception = std::current_exception(); + } + + /** + * @brief Get the result of this Task. + * + * Calling this function will rethrow any exception that the Task throw while + * executing. If the Task executed without error, then this function does + * nothing. + */ + void result() { + if (m_exception) { + std::rethrow_exception(m_exception); + } + } + + private: + std::exception_ptr m_exception = nullptr; +}; + +} // namespace details +} // namespace scl::coro + +#endif // SCL_CORO_PROMISE_H diff --git a/include/scl/coro/runtime.h b/include/scl/coro/runtime.h new file mode 100644 index 0000000..ec3d906 --- /dev/null +++ b/include/scl/coro/runtime.h @@ -0,0 +1,215 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_CORO_RUNTIME_H +#define SCL_CORO_RUNTIME_H + +#include +#include +#include +#include +#include + +#include "scl/coro/batch.h" +#include "scl/coro/future.h" +#include "scl/coro/promise.h" +#include "scl/coro/sleep_awaiter.h" +#include "scl/coro/task.h" +#include "scl/util/time.h" + +namespace scl::coro { + +/** + * @brief Interface for a coroutine runtime. + * + *

A coroutine runtime should be able to handle scheduling and descheduling + * of coroutines, as well as determination of which coroutines gets to run next. + * + *

Scheduled coroutines can roughly be divided into three categories: + * Coroutines that can be executed as soon as possible; coroutines which can be + * executed when some predicate becomes true, and coroutines that can be + * executed after some time has passed. This interface contains functions for + * each case, and instantiations handle each slightly differently. + */ +class Runtime { + public: + virtual ~Runtime() {} + + /** + * @brief Schedule a coroutine for execution when some predicate is true. + * @param coroutine the coroutine. + * @param predicate a predicate indicating when the coroutine can be resumed. + */ + virtual void schedule(std::coroutine_handle<> coroutine, + std::function&& predicate) = 0; + + /** + * @brief Schedule a coroutine for execution after some delay. + * @param coroutine the coroutine. + * @param delay a delay. + */ + virtual void schedule(std::coroutine_handle<> coroutine, + util::Time::Duration delay) = 0; + + /** + * @brief Schedule a coroutine for execution immediately. + * @param coroutine the coroutine. + */ + void schedule(std::coroutine_handle<> coroutine) { + return this->schedule(coroutine, []() { return true; }); + } + + /** + * @brief Deschedule a coroutine. + * @param coroutine the coroutine. + */ + virtual void deschedule(std::coroutine_handle<> coroutine) = 0; + + /** + * @brief Check if there are coroutines that still need to be executed. + */ + virtual bool taskQueueEmpty() const = 0; + + /** + * @brief Get the next coroutine to execute. + */ + virtual std::coroutine_handle<> next() = 0; + + /** + * @brief Assigns this runtime to an awaitable. + * @param awaitable the awaitable, a type with a "setRuntime" function. + */ + template + void assignTo(AWAITABLE& awaitable) { + awaitable.setRuntime(this); + } + + /** + * @brief Run a task to completion. + * @param task the task. + * @return the result of running \p task. + */ + template + RESULT run(Task&& task) { + task.setRuntime(this); + schedule(task.m_handle); + + while (!taskQueueEmpty()) { + next().resume(); + } + + if constexpr (std::is_void_v) { + task.result(); + } else { + return task.result(); + } + } +}; + +/** + * @brief A Default implementation for a coroutine runtime. + */ +class DefaultRuntime final : public Runtime { + using Pair = std::pair, std::function>; + + public: + /** + * @brief Create a default runtime. + */ + static std::unique_ptr create() { + return std::make_unique(); + } + + ~DefaultRuntime() {} + + void schedule(std::coroutine_handle<> coroutine, + std::function&& predicate) override { + m_tq.emplace_back(coroutine, std::move(predicate)); + } + + void schedule(std::coroutine_handle<> coroutine, + util::Time::Duration delay) override { + const auto start = util::Time::now(); + schedule(coroutine, [start, delay]() { + const auto now = util::Time::now(); + return now - start >= delay; + }); + } + + void deschedule(std::coroutine_handle<> coroutine) override; + + bool taskQueueEmpty() const override { + return m_tq.empty(); + } + + std::coroutine_handle<> next() override; + + private: + std::list m_tq; +}; + +namespace details { + +template +AWAITABLE TaskPromiseBase::await_transform(AWAITABLE&& awaitable) { + m_runtime->assignTo(awaitable); + return std::forward(awaitable); +} + +template +std::coroutine_handle<> FutureAwaiter::await_suspend( + std::coroutine_handle<> handle) { + m_runtime->schedule(handle, std::move(m_future)); + return m_runtime->next(); +} + +template +std::coroutine_handle<> Batch::await_suspend( + std::coroutine_handle<> coroutine) { + for (auto& task : m_tasks) { + task.setRuntime(m_runtime); + m_runtime->schedule(task.m_handle); + } + + m_runtime->schedule(coroutine, [this]() { return await_ready(); }); + + return m_runtime->next(); +} + +template +std::coroutine_handle<> PartialBatch::await_suspend( + std::coroutine_handle<> coroutine) { + for (auto& task : m_tasks) { + task.setRuntime(m_runtime); + m_runtime->schedule(task.m_handle); + } + + m_runtime->schedule(coroutine, [this]() { return await_ready(); }); + + return m_runtime->next(); +} + +inline std::coroutine_handle<> SleepAwaiter::await_suspend( + std::coroutine_handle<> handle) { + m_runtime->schedule(handle, m_duration); + return m_runtime->next(); +} + +} // namespace details +} // namespace scl::coro + +#endif // SCL_CORO_RUNTIME_H diff --git a/include/scl/coro/sleep_awaiter.h b/include/scl/coro/sleep_awaiter.h new file mode 100644 index 0000000..ab9f1c7 --- /dev/null +++ b/include/scl/coro/sleep_awaiter.h @@ -0,0 +1,78 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_CORO_SLEEP_AWAITER_H +#define SCL_CORO_SLEEP_AWAITER_H + +#include +#include + +#include "scl/util/time.h" + +namespace scl::coro { + +class Runtime; + +namespace details { + +/** + * @brief Awaiter interface for suspending coroutines for some amount of time. + */ +class SleepAwaiter { + public: + /** + * @brief Create a new sleep awaiter. + */ + SleepAwaiter(util::Time::Duration duration) : m_duration(duration) {} + + /** + * @brief Check if the sleep awaiter is ready. + * + * The assumption made is that duration > 0, and so this function + * always returns false. + */ + bool await_ready() const noexcept { + return false; + } + + /** + * @brief Suspend the coroutine that is being put to sleep. + */ + std::coroutine_handle<> await_suspend(std::coroutine_handle<> handle); + + /** + * @brief Resume the coroutine. Does nothing. + */ + void await_resume() const noexcept {}; + + /** + * @brief Set the runtime for this coroutine. + */ + void setRuntime(Runtime* runtime) { + m_runtime = runtime; + } + + private: + util::Time::Duration m_duration; + + Runtime* m_runtime = nullptr; +}; + +} // namespace details +} // namespace scl::coro + +#endif // SCL_CORO_SLEEP_AWAITER_H diff --git a/include/scl/coro/task.h b/include/scl/coro/task.h new file mode 100644 index 0000000..eeedbac --- /dev/null +++ b/include/scl/coro/task.h @@ -0,0 +1,193 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_CORO_TASK_H +#define SCL_CORO_TASK_H + +#include +#include +#include +#include +#include + +#include "scl/coro/batch.h" +#include "scl/coro/promise.h" + +namespace scl::coro { + +class Runtime; + +namespace details { + +/** + * @brief Remove a handle from a runtime. + * @param runtime the runtime. + * @param handle handle for the coroutine. + * + * This function is used when Task destroys the coroutine state. Part of this + * teardown involves telling the current runtime that the handle should not be + * considered for resumption anymore. + */ +void removeHandle(Runtime* runtime, std::coroutine_handle<> handle); + +} // namespace details + +/** + * @brief A coroutine task. + * @tparam RESULT the type of the return value of the coroutine. + * + * coro::Task specifies a coroutine which returns a value of type \p RESULT. + * Tasks are cold start, i.e., they wont start executing until they are awaited. + * + * Tasks are move only types and considered the unique owner of the + * std::coroutine_handle which is associated with the coroutine. + */ +template +class Task { + public: + /** + * @brief Promise type of Task. + */ + using promise_type = details::TaskPromise; + + /** + * @brief Destructor. + * + * The destructor of Task will first tell the current runtime to stop tracking + * the coroutine handle that this task manages. Afterwards the handle is + * destroyed. + */ + ~Task() { + destroy(); + } + + /** + * @brief Move constructor. + */ + Task(Task&& other) noexcept + : m_handle(std::exchange(other.m_handle, nullptr)) {} + + /** + * @brief Construction from copy not allowed. + */ + Task(const Task&) = delete; + + /** + * @brief Move assignment. + */ + Task& operator=(Task&& other) noexcept { + destroy(); + m_handle = std::exchange(other.m_handle, nullptr); + return *this; + } + + /** + * @brief Assignment from copy not allowed. + */ + Task& operator=(const Task&) = delete; + + /** + * @brief Allows a task to be co_await'ed. + */ + auto operator co_await() { + struct Awaiter { + std::coroutine_handle coroutine; + + // the awaiting corouting can resume immediately if the task it is waiting + // for already finished. + bool await_ready() noexcept { + return coroutine.done(); + } + + // the awaiting coroutine is suspended. So we will resume the task it is + // waiting for, and designate the awaiter as the coroutine that should be + // run when the task finishes. + std::coroutine_handle await_suspend( + std::coroutine_handle<> awaiter) { + coroutine.promise().setNext(awaiter); + return coroutine; + } + + // get the result of the task being co_await'ed. + auto await_resume() { + return coroutine.promise().result(); + } + }; + + return Awaiter{m_handle}; + } + + /** + * @brief Set the Runtime for executing this Task. + */ + void setRuntime(Runtime* runtime) { + m_handle.promise().setRuntime(runtime); + } + + /** + * @brief Destroy this task. + */ + void destroy() { + if (m_handle) { + details::removeHandle(m_handle.promise().getRuntime(), m_handle); + m_handle.destroy(); + } + } + + /** + * @brief Check if a result is ready. + */ + bool ready() const { + return m_handle.done(); + } + + /** + * @brief Get the return value of this task. + */ + RESULT result() const { + return m_handle.promise().result(); + } + + /** + * @brief The coroutine handle associated with this task. + */ + std::coroutine_handle m_handle; + + private: + friend promise_type; + + explicit Task(std::coroutine_handle handle) + : m_handle(handle) {} +}; + +namespace details { + +template +Task TaskPromise::get_return_object() { + return Task( + std::coroutine_handle>::from_promise(*this)); +} + +inline Task TaskPromise::get_return_object() { + return Task( + std::coroutine_handle>::from_promise(*this)); +} + +} // namespace details +} // namespace scl::coro + +#endif // SCL_CORO_TASK_H diff --git a/include/scl/math/array.h b/include/scl/math/array.h new file mode 100644 index 0000000..00dbbbe --- /dev/null +++ b/include/scl/math/array.h @@ -0,0 +1,461 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_MATH_ARRAY_H +#define SCL_MATH_ARRAY_H + +#include +#include +#include +#include + +#include "scl/serialization/serializer.h" +#include "scl/util/prg.h" + +namespace scl { +namespace math { + +/// @cond + +template +concept PrefixIncrementable = requires(T a) { ++a; }; + +template +concept PrefixDecrementable = requires(T a) { --a; }; + +template +concept Multipliable = requires(T a, V b) { (a) * (b); }; + +template +concept Divisble = requires(T a, T b) { (a) / (b); }; + +template +struct MultiplyResultTypeTrait { + using Type = decltype(std::declval() * std::declval()); +}; + +template +using MultiplyResultType = typename MultiplyResultTypeTrait::Type; + +template +concept Invertible = requires(T& a) { a.invert(); }; + +/// @endcond + +/** + * @brief Array of values, e.g., group elements. + * @tparam T the array element type. + * @tparam N the number of elements. + * + * Array is effectively a wrapper around std::array with + * added functionality that allows operating on Array objects as if they where + * group or ring elements. As such, Array behaves sort of like a direct product + * of N copies of the same group. + */ +template +class Array final { + public: + /** + * @brief Binary size of this Array. + */ + constexpr static std::size_t byteSize() { + return T::byteSize() * N; + } + + /** + * @brief Read an array from a buffer. + */ + static Array read(const unsigned char* src) { + Array p; + for (std::size_t i = 0; i < N; ++i) { + p.m_values[i] = T::read(src + i * T::byteSize()); + } + return p; + } // LCOV_EXCL_LINE + + /** + * @brief Create an array filled with random elements. + */ + static Array random(util::PRG& prg) + requires requires() { T::random(prg); } + { + Array p; + for (std::size_t i = 0; i < N; ++i) { + p.m_values[i] = T::random(prg); + } + return p; + } // LCOV_EXCL_LINE + + /** + * @brief Get an Array filled with the multiplicative identity + */ + static Array one() + requires requires() { T::one(); } + { + return Array(T::one()); + } + + /** + * @brief Get an Array filled with the additive identity. + */ + static Array zero() { + return Array(T::zero()); + } + + /** + * @brief Default constructor. + */ + Array() {} + + /** + * @brief Construct an Array filled with copies of the same element. + * @param element the element. + */ + Array(const T& element) { + m_values.fill(element); + } + + /** + * @brief Construct an Array filled with copies of the same element. + * + * This function will attempt to construct a \p T element using \p value, and + * then fill all slots with this value. + */ + explicit Array(int value) : Array(T{value}){}; + + /** + * @brief Copy construct an Array from another array. + * @param arr the array. + */ + Array(const std::array& arr) : m_values{arr} {} + + /** + * @brief Move construct an Array from another array. + * @param arr the array. + */ + Array(std::array&& arr) : m_values{std::move(arr)} {} + + /** + * @brief Add another Array to this Array. + */ + Array& operator+=(const Array& other) { + for (std::size_t i = 0; i < N; ++i) { + m_values[i] += other.m_values[i]; + } + return *this; + } + + /** + * @brief Add two arrays. + */ + friend Array operator+(const Array& lhs, const Array& rhs) { + Array tmp(lhs); + return tmp += rhs; + } + + /** + * @brief Prefix increment operator. + */ + Array& operator++() + requires PrefixIncrementable + { + for (auto& v : m_values) { + ++v; + } + return *this; + } + + /** + * @brief Postfix increment operator. + */ + friend Array operator++(Array& arr, int) + requires PrefixIncrementable + { + Array tmp(arr); + for (auto& v : arr.m_values) { + ++v; + } + return tmp; + } + + /** + * @brief Subtract an Array from this. + */ + Array& operator-=(const Array& other) { + for (std::size_t i = 0; i < N; ++i) { + m_values[i] -= other.m_values[i]; + } + return *this; + } + + /** + * @brief Subtract two Arrays. + */ + friend Array operator-(const Array& lhs, const Array& rhs) { + Array tmp(lhs); + return tmp -= rhs; + } + + /** + * @brief Prefix decrement operator. + */ + Array& operator--() + requires PrefixDecrementable + { + for (auto& v : m_values) { + --v; + } + return *this; + } + + /** + * @brief Postfix decrement operator. + */ + friend Array operator--(Array& arr, int) + requires PrefixDecrementable + { + Array tmp(arr); + for (auto& v : arr.m_values) { + --v; + } + return tmp; + } + + /** + * @brief Negate this Array. + */ + Array& negate() { + for (std::size_t i = 0; i < N; ++i) { + m_values[i].Negate(); + } + return *this; + } + + /** + * @brief Multiply this Array with a scalar. + */ + template + Array& operator*=(const S& scalar) + requires Multipliable + { + for (std::size_t i = 0; i < N; i++) { + m_values[i] *= scalar; + } + return *this; + } + + /** + * @brief Multiply two Arrays entry-wise. + */ + template + Array& operator*=(const Array& other) + requires Multipliable + { + for (std::size_t i = 0; i < N; ++i) { + m_values[i] *= other[i]; + } + return *this; + } + + /** + * @brief Multiply a scalar with this Array. + */ + template + Array operator*(const S& scalar) const + requires Multipliable + { + auto copy = *this; + return copy *= scalar; + } + + /** + * @brief Multiply two Arrays entry-wise. + */ + template + friend Array, N> operator*(const Array& lhs, + const Array& rhs) + requires Multipliable + { + Array, N> tmp; + for (std::size_t i = 0; i < N; i++) { + tmp[i] = lhs[i] * rhs[i]; + } + return tmp; + } + + /** + * @brief Invert all entries in this Array. + */ + Array& invert() + requires Invertible + { + for (std::size_t i = 0; i < N; ++i) { + m_values[i].invert(); + } + return *this; + } + + /** + * @brief Compute the inverse of each entry in this Array. + */ + Array Inverse() const + requires Invertible + { + Array p = *this; + return p.Invert(); + } + + /** + * @brief Divide this Array by another Array entry-wise. + */ + Array operator/=(const Array& other) + requires Divisble + { + for (std::size_t i = 0; i < N; ++i) { + m_values[i] /= other[i]; + } + return *this; + } + + /** + * @brief Divide this Array by another Array entry-wise. + */ + Array operator/(const Array& other) const + requires Divisble + { + Array r = *this; + r /= other; + return r; + } // LCOV_EXCL_LINE + + /** + * @brief Get the value at a particular entry in the product element. + */ + T& operator[](std::size_t index) { + return m_values[index]; + } + + /** + * @brief Get the value at a particular entry in the product element. + */ + T operator[](std::size_t index) const { + return m_values[index]; + } + + /** + * @brief Compare two product elements. + */ + bool equal(const Array& other) const { + bool eq = true; + for (std::size_t i = 0; i < N; ++i) { + eq = eq && (m_values[i] == other.m_values[i]); + } + return eq; + } + + /** + * @brief Equality operator for Array. + */ + friend bool operator==(const Array& lhs, const Array& rhs) { + return lhs.equal(rhs); + } + + /** + * @brief In-equality operator for Array. + */ + friend bool operator!=(const Array& lhs, const Array& rhs) { + return !(lhs == rhs); + } + + /** + * @brief Get a string representation of this product element. + */ + std::string toString() const { + std::stringstream ss; + ss << "P{"; + for (std::size_t i = 0; i < N - 1; ++i) { + ss << m_values[i] << ", "; + } + ss << m_values[N - 1] << "}"; + return ss.str(); + } + + /** + * @brief Stream printing operator for Array. + */ + friend std::ostream& operator<<(std::ostream& os, const Array& array) { + return os << array.toString(); + } + + /** + * @brief Write this product element to a buffer. + */ + void write(unsigned char* dest) const { + for (std::size_t i = 0; i < N; ++i) { + m_values[i].write(dest + i * T::byteSize()); + } + } + + private: + std::array m_values; +}; + +} // namespace math + +namespace seri { + +/** + * @brief Serializer specialization for product types. + */ +template +struct Serializer> { + /** + * @brief Get the serialized size of a product type. + */ + static constexpr std::size_t sizeOf( + const math::Array& /* ignored */) { + return math::Array::byteSize(); + } + + /** + * @brief Write a product element to a buffer. + * @param prod the element. + * @param buf the buffer. + */ + static std::size_t write(const math::Array& prod, + unsigned char* buf) { + prod.write(buf); + return sizeOf(prod); + } + + /** + * @brief Read a product element from a buffer. + * @param prod the output. + * @param buf the buffer. + */ + static std::size_t read(math::Array& prod, + const unsigned char* buf) { + prod = math::Array::read(buf); + return sizeOf(prod); + } +}; + +} // namespace seri + +} // namespace scl + +#endif // SCL_MATH_ARRAY_H diff --git a/include/scl/math/ec_ops.h b/include/scl/math/curves/ec_ops.h similarity index 60% rename from include/scl/math/ec_ops.h rename to include/scl/math/curves/ec_ops.h index 6adabb0..d96bfaf 100644 --- a/include/scl/math/ec_ops.h +++ b/include/scl/math/curves/ec_ops.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,35 +15,38 @@ * along with this program. If not, see . */ -#ifndef SCL_MATH_EC_OPS_H -#define SCL_MATH_EC_OPS_H +#ifndef SCL_MATH_CURVES_EC_OPS_H +#define SCL_MATH_CURVES_EC_OPS_H + +#include +#include #include "scl/math/ff.h" #include "scl/math/number.h" -namespace scl::math { +namespace scl::math::ec { /** * @brief Set a point equal to the point-at-infinity. * @param out the point to set to the point-at-infinity. */ -template -void CurveSetPointAtInfinity(typename C::ValueType& out); +template +void setPointAtInfinity(typename CURVE::ValueType& out); /** * @brief Check if a point is equal to the point-at-infinity. * @param point the point * @return true if \p point is equal to the point-at-infinity, otherwise false. */ -template -bool CurveIsPointAtInfinity(const typename C::ValueType& point); +template +bool isPointAtInfinity(const typename CURVE::ValueType& point); /** * @brief Set a point equal to the generator of this curve. * @param out the point to set equal to the generator of this curve. */ -template -void CurveSetGenerator(typename C::ValueType& out); +template +void setGenerator(typename CURVE::ValueType& out); /** * @brief Set a point equal to an affine point. @@ -51,83 +54,84 @@ void CurveSetGenerator(typename C::ValueType& out); * @param x the x coordinate * @param y the y coordinate */ -template -void CurveSetAffine(typename C::ValueType& out, - const FF& x, - const FF& y); +template +void setAffine(typename CURVE::ValueType& out, + const FF& x, + const FF& y); /** * @brief Convert a point to a pair of affine coordinates. * @param point the point to convert. * @return a set of affine coordinates. */ -template -std::array, 2> CurveToAffine( - const typename C::ValueType& point); +template +std::array, 2> toAffine( + const typename CURVE::ValueType& point); /** * @brief Add two elliptic curve points in-place. * @param out the first point and output * @param in the second point */ -template -void CurveAdd(typename C::ValueType& out, const typename C::ValueType& in); +template +void add(typename CURVE::ValueType& out, const typename CURVE::ValueType& in); /** * @brief Double an elliptic curve point in-place. * @param out the point to double */ -template -void CurveDouble(typename C::ValueType& out); +template +void dbl(typename CURVE::ValueType& out); /** * @brief Subtract two elliptic curve points in-place. * @param out the first point and output * @param in the second point */ -template -void CurveSubtract(typename C::ValueType& out, const typename C::ValueType& in); +template +void subtract(typename CURVE::ValueType& out, + const typename CURVE::ValueType& in); /** * @brief Negate an elliptic curve point. * @param out the point to negate */ -template -void CurveNegate(typename C::ValueType& out); +template +void negate(typename CURVE::ValueType& out); /** * @brief Scalar multiply an elliptic curve point in-place. * @param out the point * @param scalar the scalar */ -template -void CurveScalarMultiply(typename C::ValueType& out, const Number& scalar); +template +void scalarMultiply(typename CURVE::ValueType& out, const Number& scalar); /** * @brief Scalar multiply an elliptic curve point in-place. * @param out the point * @param scalar the scalar */ -template -void CurveScalarMultiply(typename C::ValueType& out, - const FF& scalar); +template +void scalarMultiply(typename CURVE::ValueType& out, + const FF& scalar); /** * @brief Check if two elliptic curve points are equal. * @param in1 the first point * @param in2 the second point */ -template -bool CurveEqual(const typename C::ValueType& in1, - const typename C::ValueType& in2); +template +bool equal(const typename CURVE::ValueType& in1, + const typename CURVE::ValueType& in2); /** * @brief Read an elliptic curve from a byte buffer. * @param out where to store the point * @param src the buffer */ -template -void CurveFromBytes(typename C::ValueType& out, const unsigned char* src); +template +void fromBytes(typename CURVE::ValueType& out, const unsigned char* src); /** * @brief Write an elliptic curve point to a buffer. @@ -135,19 +139,19 @@ void CurveFromBytes(typename C::ValueType& out, const unsigned char* src); * @param in the elliptic curve point to write * @param compress whether to compress the point */ -template -void CurveToBytes(unsigned char* dest, - const typename C::ValueType& in, - bool compress); +template +void toBytes(unsigned char* dest, + const typename CURVE::ValueType& in, + bool compress); /** * @brief Convert an elliptic curve point to a string * @param point the point * @return an STL string representation of \p in. */ -template -std::string CurveToString(const typename C::ValueType& point); +template +std::string toString(const typename CURVE::ValueType& point); -} // namespace scl::math +} // namespace scl::math::ec -#endif // SCL_MATH_EC_OPS_H +#endif // SCL_MATH_CURVES_EC_OPS_H diff --git a/include/scl/math/curves/secp256k1.h b/include/scl/math/curves/secp256k1.h index 085c944..ed97335 100644 --- a/include/scl/math/curves/secp256k1.h +++ b/include/scl/math/curves/secp256k1.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -20,64 +20,28 @@ #include -#include "scl/math/ec.h" +#include "scl/math/curves/ec_ops.h" #include "scl/math/ff.h" +#include "scl/math/fields/secp256k1_field.h" +#include "scl/math/fields/secp256k1_scalar.h" -namespace scl::math { +namespace scl::math::ec { /** * @brief Elliptic curve definition for secp256k1. + * @see http://www.secg.org/sec2-v2.pdf */ struct Secp256k1 { /** - * @brief The Field over which secp256k1 is defined. + * @brief The finite field defined by + * \f$p=2^{256}-2^{32}-2^{9}-2^{8}-2^{7}-2^{6}-2^{4}-1\f$ */ - struct Field { - /** - * @brief Field elements are stored as 4 limb numbers internally. - */ - using ValueType = std::array; - - /** - * @brief Name of the secp256k1 field. - */ - constexpr static const char* NAME = "secp256k1_field"; - - /** - * @brief Byte size of a secp256k1 field element. - */ - constexpr static const std::size_t BYTE_SIZE = 4 * sizeof(mp_limb_t); - - /** - * @brief Bit size of a secp256k1 field element. - */ - constexpr static const std::size_t BIT_SIZE = 8 * BYTE_SIZE; - }; + using Field = ff::Secp256k1Field; /** - * @brief Finite field modulo a Secp256k1 prime order sub-group. + * @brief The finite field defined by a large prime order subgroup. */ - struct Scalar { - /** - * @brief Internal type of elements. - */ - using ValueType = std::array; - - /** - * @brief Name of the field. - */ - constexpr static const char* NAME = "secp256k1_order"; - - /** - * @brief Size of an element in bytes. - */ - constexpr static const std::size_t BYTE_SIZE = 4 * sizeof(mp_limb_t); - - /** - * @brief Size of an element in bits. - */ - constexpr static const std::size_t BIT_SIZE = 8 * BYTE_SIZE; - }; + using Scalar = ff::Secp256k1Scalar; /** * @brief Secp256k1 curve elements are stored in projective coordinates. @@ -90,6 +54,6 @@ struct Secp256k1 { constexpr static const char* NAME = "secp256k1"; }; -} // namespace scl::math +} // namespace scl::math::ec #endif // SCL_MATH_CURVES_SECP256K1_H diff --git a/include/scl/math/ec.h b/include/scl/math/ec.h index 9306872..ade2134 100644 --- a/include/scl/math/ec.h +++ b/include/scl/math/ec.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -18,93 +18,100 @@ #ifndef SCL_MATH_EC_H #define SCL_MATH_EC_H +#include +#include #include +#include -#include "scl/math/ec_ops.h" +#include "scl/math/array.h" +#include "scl/math/curves/ec_ops.h" #include "scl/math/ff.h" #include "scl/math/number.h" -#include "scl/math/ops.h" -namespace scl::math { +namespace scl { +namespace math { /** * @brief Elliptic Curve interface. - * @tparam Curve elliptic curve definition + * @tparam CURVE elliptic curve definition * - * TODO. - * - * @see FF - * @see Secp256k1 + * EC defines a point \f$P\f$ on some Elliptic Curve \f$E(K)\f$. The + * curve parameters is defined through the \p CURVE template parameter and + * appropriate overloads of the functions in the \ref ec namespace. */ -template -class EC final : Add>, Eq>, Print> { +template +class EC final { public: /** - * @brief The field that this curve is defined over. + * @brief Field that this curve is defined over. */ - using Field = FF; + using Field = FF; /** - * @brief A large sub-group of this curve. + * @brief Large subgroup of this curve. */ - using ScalarField = FF; + using ScalarField = FF; /** - * @brief The size of a curve point in bytes. - * @param compressed + * @brief Size of a curve point in bytes. */ - constexpr static std::size_t ByteSize(bool compressed = true) { - return 1 + (compressed ? 0 : Field::ByteSize()) + Field::ByteSize(); + constexpr static std::size_t byteSize(bool compressed) { + return 1 + (compressed ? 0 : Field::byteSize()) + Field::byteSize(); } /** - * @brief The size of a curve point in bits. + * @brief Size of a curve point in bits. */ - constexpr static std::size_t BitSize(bool compressed = true) { - return ByteSize(compressed) * 8; + constexpr static std::size_t bitSize(bool compressed) { + return byteSize(compressed) * 8; } /** - * @brief A string indicating which curve this is. + * @brief String indicating which curve this is. */ - constexpr static const char* Name() { - return Curve::NAME; + constexpr static const char* name() { + return CURVE::NAME; } /** - * @brief Get the generator of this curve. + * @brief A generator of this curve. */ - constexpr static EC Generator() { + constexpr static EC generator() { EC g; - CurveSetGenerator(g.m_value); + ec::setGenerator(g.m_value); return g; - } // LCOV_EXCL_LINE + } /** - * @brief Read an elliptic curve point from bytes. - * @param src the bytes - * @return an elliptic curve point. + * @brief Reads an elliptic curve point from bytes. */ - static EC Read(const unsigned char* src) { + static EC read(const unsigned char* src) { EC e; - CurveFromBytes(e.m_value, src); + ec::fromBytes(e.m_value, src); return e; - } // LCOV_EXCL_LINE + } /** - * @brief Create a point from an pair of affine coordinates. + * @brief Creates a point from a pair of affine coordinates. */ - static EC FromAffine(const Field& x, const Field& y) { + static EC fromAffine(const Field& x, const Field& y) { EC e; - CurveSetAffine(e.m_value, x, y); + ec::setAffine(e.m_value, x, y); return e; } + /** + * @brief Get the additive identity of this curve. + */ + static EC zero() { + return EC{}; + } + /** * @brief Create a new point equal to the point at infinity. */ - explicit constexpr EC() { - CurveSetPointAtInfinity(m_value); + EC() { + ec::setPointAtInfinity(m_value); } /** @@ -114,67 +121,70 @@ class EC final : Add>, Eq>, Print> { /** * @brief Add another EC point to this. - * @param other the other point - * @return this */ EC& operator+=(const EC& other) { - CurveAdd(m_value, other.m_value); + ec::add(m_value, other.m_value); return *this; } + /** + * @brief Add two curve points. + */ + friend EC operator+(const EC& lhs, const EC& rhs) { + EC tmp(lhs); + return tmp += rhs; + } + /** * @brief Double this point. - * @return this after doubling. */ - EC& DoubleInPlace() { - CurveDouble(m_value); + EC& doublePointInPlace() { + ec::dbl(m_value); return *this; } /** * @brief Double this point. - * @return \p this + \p this. */ - EC Double() const { + EC doublePoint() const { EC copy(*this); - return copy.DoubleInPlace(); + return copy.doublePointInPlace(); } /** * @brief Subtract another point from this. - * @param other the other point - * @return this. */ EC& operator-=(const EC& other) { - CurveSubtract(m_value, other.m_value); + ec::subtract(m_value, other.m_value); return *this; } + /** + * @brief Subtract two curve points. + */ + friend EC operator-(const EC& lhs, const EC& rhs) { + EC tmp(lhs); + return tmp -= rhs; + } + /** * @brief Perform a scalar multiplication. - * @param scalar the scalar - * @return this. */ EC& operator*=(const Number& scalar) { - CurveScalarMultiply(m_value, scalar); + ec::scalarMultiply(m_value, scalar); return *this; } /** * @brief Perform a scalar multiplication. - * @param scalar the scalar - * @return this. */ EC& operator*=(const ScalarField& scalar) { - CurveScalarMultiply(m_value, scalar); + ec::scalarMultiply(m_value, scalar); return *this; } /** * @brief Multiply a point with a scalar from the right. - * @param point the point - * @param scalar the scalar - * @return the point multiplied with the scalar. */ friend EC operator*(const EC& point, const Number& scalar) { EC copy(point); @@ -183,9 +193,6 @@ class EC final : Add>, Eq>, Print> { /** * @brief Multiply a point with a scalar from the right. - * @param point the point - * @param scalar the scalar - * @return the point multiplied with the scalar. */ friend EC operator*(const EC& point, const ScalarField& scalar) { EC copy(point); @@ -194,9 +201,6 @@ class EC final : Add>, Eq>, Print> { /** * @brief Multiply a point with a scalar from the left. - * @param point the point - * @param scalar the scalar - * @return the point multiplied with the scalar. */ friend EC operator*(const Number& scalar, const EC& point) { return point * scalar; @@ -204,9 +208,6 @@ class EC final : Add>, Eq>, Print> { /** * @brief Multiply a point with a scalar from the left. - * @param point the point - * @param scalar the scalar - * @return the point multiplied with the scalar. */ friend EC operator*(const ScalarField& scalar, const EC& point) { return point * scalar; @@ -214,68 +215,131 @@ class EC final : Add>, Eq>, Print> { /** * @brief Negate this point. - * @return this. */ - EC& Negate() { - CurveNegate(m_value); + EC& negate() { + ec::negate(m_value); return *this; } + /** + * @brief Negate a curve point. + */ + friend EC operator-(const EC& point) { + EC tmp(point); + return tmp.negate(); + } + /** * @brief Check if this EC point is equal to another EC point. - * @param other the other EC point - * @return true if the two points are equal and false otherwise. */ - bool Equal(const EC& other) const { - return CurveEqual(m_value, other.m_value); + bool equal(const EC& other) const { + return ec::equal(m_value, other.m_value); } // LCOV_EXCL_LINE + /** + * @brief Operator == for curve points. + */ + friend bool operator==(const EC& lhs, const EC& rhs) { + return lhs.equal(rhs); + } + + /** + * @brief Operator != for curve points. + */ + friend bool operator!=(const EC& lhs, const EC& rhs) { + return !(lhs == rhs); + } + /** * @brief Check if this point is equal to the point at inifity. - * @return true if this point is equal to the point at inifity. */ - bool PointAtInfinity() const { - return CurveIsPointAtInfinity(m_value); + bool isPointAtInfinity() const { + return ec::isPointAtInfinity(m_value); } // LCOV_EXCL_LINE /** * @brief Return this point as a pair of affine coordinates. - * @return this point as a pair of affine coordinates. + * + * Only well-defined if the point is not the point at infinity. */ - std::array ToAffine() const { - return CurveToAffine(m_value); + std::array toAffine() const { + return ec::toAffine(m_value); } // LCOV_EXCL_LINE + /** + * @brief Normalize this point. + */ + void normalize() { + if (isPointAtInfinity()) { + ec::setPointAtInfinity(m_value); + } else { + const auto afp = toAffine(); + ec::setAffine(m_value, afp[0], afp[1]); + } + } + /** * @brief Output this point as a string. */ - std::string ToString() const { - return CurveToString(m_value); + std::string toString() const { + return ec::toString(m_value); } // LCOV_EXCL_LINE + /** + * @brief Operator << for printing a curve point. + */ + friend std::ostream& operator<<(std::ostream& os, const EC& e) { + return os << e.toString(); + } + /** * @brief Write this point to a buffer. - * @param dest the destination - * @param compress whether to compress the point */ - void Write(unsigned char* dest, bool compress = true) const { - CurveToBytes(dest, m_value, compress); + void write(unsigned char* dest, bool compress) const { + ec::toBytes(dest, m_value, compress); } // LCOV_EXCL_LINE private: - typename Curve::ValueType m_value; + typename CURVE::ValueType m_value; }; +} // namespace math + +namespace seri { + /** - * @brief Helper class for working with finite field elements. + * @brief Serializer for EC types. * - * This class is a private friend of FF. Specialization of this class therefore - * allows direct access to the internal representation of a finite field - * element. + * Elliptic curve points are serialized uncompressed and in affine form. */ -template -class FFAccess {}; +template +struct Serializer> { + /** + * @brief Get the size of a serialized EC point. + */ + static constexpr std::size_t sizeOf(const math::EC& /* ignored */) { + return math::EC::byteSize(false); + } + + /** + * @brief Write an EC point to a buffer. + */ + static std::size_t write(const math::EC& point, unsigned char* buf) { + point.write(buf, false); + return sizeOf(point); + } + + /** + * @brief Read an EC point from a buffer. + */ + static std::size_t read(math::EC& point, const unsigned char* buf) { + point = math::EC::read(buf); + return sizeOf(point); + } +}; + +} // namespace seri -} // namespace scl::math +} // namespace scl #endif // SCL_MATH_EC_H diff --git a/include/scl/math/ff.h b/include/scl/math/ff.h index c7ed9f6..617249e 100644 --- a/include/scl/math/ff.h +++ b/include/scl/math/ff.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -22,84 +22,72 @@ #include #include -#include "scl/math/ff_ops.h" -#include "scl/math/ops.h" +#include "scl/math/fields/ff_ops.h" +#include "scl/serialization/serializer.h" #include "scl/util/prg.h" -namespace scl::math { +namespace scl { +namespace math { /** * @brief Finite Field interface. - * @tparam Field finite field definition - * @see Mersenne61 - * @see Mersenne127 + * @tparam FIELD finite field definition. */ -template -class FF final : Add>, - Mul>, - Eq>, - Print> { +template +class FF final { public: /** * @brief Size in bytes of a field element. */ - constexpr static std::size_t ByteSize() { - return Field::BYTE_SIZE; + constexpr static std::size_t byteSize() { + return FIELD::BYTE_SIZE; } /** * @brief Actual bit size of an element. */ - constexpr static std::size_t BitSize() { - return Field::BIT_SIZE; + constexpr static std::size_t bitSize() { + return FIELD::BIT_SIZE; } /** * @brief A short string representation of this field. */ - constexpr static const char* Name() { - return Field::NAME; + constexpr static const char* name() { + return FIELD::NAME; } /** * @brief Read a field element from a buffer. - * @param src the buffer - * @return a field element. - * @see scl::FieldFromBytes */ - static FF Read(const unsigned char* src) { + static FF read(const unsigned char* src) { FF e; - FieldFromBytes(e.m_value, src); + ff::fromBytes(e.m_value, src); return e; } // LCOV_EXCL_LINE /** * @brief Create a random element, using a supplied PRG. - * @param prg the PRG - * @return a random element. */ - static FF Random(util::PRG& prg) { - unsigned char buffer[FF::ByteSize()]; - prg.Next(buffer, FF::ByteSize()); - return FF::Read(buffer); + static FF random(util::PRG& prg) { + unsigned char buffer[FF::byteSize()]; + prg.next(buffer, FF::byteSize()); + return FF::read(buffer); } /** - * @brief Create a field element from a string. - * @param str the string - * @return a finite field element. - * @see scl::FieldFromString + * @brief Create a field element from a hex-string. */ - static FF FromString(const std::string& str) { + static FF fromString(const std::string& hexstr) { FF e; - FieldFromString(e.m_value, str); + ff::convertTo(e.m_value, hexstr); return e; } /** * @brief Get the additive identity of this field. */ - static FF Zero() { + static FF zero() { static FF zero; return zero; } @@ -107,24 +95,22 @@ class FF final : Add>, /** * @brief Get the multiplicative identity of this field. */ - static FF One() { + static FF one() { static FF one(1); return one; } /** * @brief Create a new element from an int. - * @param value the value to interpret as a field element - * @see scl::FieldConvertIn */ explicit constexpr FF(int value) { - FieldConvertIn(m_value, value); + ff::convertTo(m_value, value); } /** * @brief Create a new element equal to 0 in the field. */ - explicit constexpr FF() : FF(0) {} + constexpr FF() : FF(0) {} /** * @brief Destrutor. Does nothing. @@ -133,133 +119,206 @@ class FF final : Add>, /** * @brief Add another field element to this. - * @param other the other element - * @return this set to this + \p other. - * @see scl::FieldAdd */ FF& operator+=(const FF& other) { - FieldAdd(m_value, other.m_value); + ff::add(m_value, other.m_value); return *this; } + /** + * @brief Add two finite field elements. + */ + friend FF operator+(const FF& lhs, const FF& rhs) { + FF tmp(lhs); + return tmp += rhs; + } + + /** + * @brief Pre-increment this finite-field element. + */ + FF& operator++() { + return *this += one(); + } + + /** + * @brief Post-increment this finite-field element. + */ + friend FF operator++(FF& e, int) { + FF tmp(e); + ++e; + return tmp; + } + /** * @brief Subtract another field element to this. - * @param other the other element - * @return this set to this - \p other. - * @see scl::FieldSubtract */ FF& operator-=(const FF& other) { - FieldSubtract(m_value, other.m_value); + ff::subtract(m_value, other.m_value); return *this; } + /** + * @brief Subtract two finite field elements. + */ + friend FF operator-(const FF& lhs, const FF& rhs) { + FF tmp(lhs); + return tmp -= rhs; + } + + /** + * @brief Pre-decrement this finite field element. + */ + FF& operator--() { + return *this -= one(); + } + + /** + * @brief Post-decrement this finite field element. + */ + friend FF operator--(FF& e, int) { + FF tmp(e); + --e; + return tmp; + } + /** * @brief Multiply another field element to this. - * @param other the other element - * @return this set to this * \p other. - * @see scl::FieldMultiply */ FF& operator*=(const FF& other) { - FieldMultiply(m_value, other.m_value); + ff::multiply(m_value, other.m_value); return *this; } /** - * @brief Multiplies this with the inverse of another elemenet. - * - * This is not an integer division. Rather, it computes \f$x\cdot y^{-1}\f$ - * where \f$x\f$ is this element and \f$y\f$ is \p other. - * - * @param other the other element - * @return this set to this * other.Inverse(). + * @brief Multiply two finite field elements. + */ + friend FF operator*(const FF& lhs, const FF& rhs) { + FF tmp(lhs); + return tmp *= rhs; + } + + /** + * @brief Divide and assign this finite field element with another element. */ FF& operator/=(const FF& other) { - return operator*=(other.Inverse()); + return operator*=(other.inverse()); + } + + /** + * @brief Divide two finite field elements. + */ + friend FF operator/(const FF& lhs, const FF& rhs) { + FF tmp(lhs); + return tmp /= rhs; } /** * @brief Negates this element. - * @return this set to -this. - * @see scl::FieldNegate */ - FF& Negate() { - FieldNegate(m_value); + FF& negate() { + ff::negate(m_value); return *this; } /** * @brief Computes the additive inverse of this element. - * @return the additive inverse of this. - * @see FF::Negate */ - FF Negated() const { + FF negated() const { auto copy = m_value; FF r; - FieldNegate(copy); + ff::negate(copy); r.m_value = copy; return r; } // LCOV_EXCL_LINE + /** + * @brief Negate a finite field element. + */ + friend FF operator-(const FF& e) { + return e.negated(); + } + /** * @brief Inverts this element. - * @return this set to its inverse. - * @see scl::FieldInvert */ - FF& Invert() { - FieldInvert(m_value); + FF& invert() { + ff::invert(m_value); return *this; } /** * @brief Computes the inverse of this element. - * @return the inverse of this element. - * @see FF::Invert */ - FF Inverse() const { + FF inverse() const { FF copy = *this; - return copy.Invert(); + return copy.invert(); } /** * @brief Checks if this element is equal to another. - * @param other the other element - * @return true if this is equal to \p other. - * @see scl::FieldEqual */ - bool Equal(const FF& other) const { - return FieldEqual(m_value, other.m_value); + bool equal(const FF& other) const { + return ff::equal(m_value, other.m_value); + } + + /** + * @brief Equality operator for finite field elements. + */ + friend bool operator==(const FF& lhs, const FF& rhs) { + return lhs.equal(rhs); + } + + /** + * @brief In-equality operator for finite field elements. + */ + friend bool operator!=(const FF& lhs, const FF& rhs) { + return !(lhs == rhs); } /** * @brief Returns a string representation of this element. - * @return a string representation of this field element. - * @see scl::FieldToString */ - std::string ToString() const { - return FieldToString(m_value); + std::string toString() const { + return ff::toString(m_value); + } + + /** + * @brief Write a string representation of a finite field element to a stream. + */ + friend std::ostream& operator<<(std::ostream& os, const FF& e) { + return os << e.toString(); } /** * @brief Write this element to a byte buffer. - * @param dest the buffer. Must have space for \ref ByteSize() bytes. - * @see scl::FieldToBytes */ - void Write(unsigned char* dest) const { - FieldToBytes(dest, m_value); + void write(unsigned char* dest) const { + ff::toBytes(dest, m_value); } - private: - typename Field::ValueType m_value; + /** + * @brief Get the internal value of this field element. + */ + typename FIELD::ValueType value() const { + return m_value; + } + + /** + * @brief Get the internal value of this field element. + */ + typename FIELD::ValueType& value() { + return m_value; + } - template - friend class FFAccess; + private: + typename FIELD::ValueType m_value; }; /** * @brief Returns the order of a finite field. */ template -Number Order(); +Number order(); /** * @brief Raise an element to a power. @@ -268,13 +327,14 @@ Number Order(); * @return \p base raised to the \p exp th power. */ template -FF Exp(const FF& base, std::size_t exp) { +FF exp(const FF& base, std::size_t exp) { + FF r = FF::one(); + if (exp == 0) { - return FF::One(); + return r; } const auto n = sizeof(std::size_t) * 8 - __builtin_clzll(exp); - FF r = FF::One(); for (std::size_t i = n; i-- > 0;) { r *= r; if (((exp >> i) & 1) == 1) { @@ -285,6 +345,52 @@ FF Exp(const FF& base, std::size_t exp) { return r; } -} // namespace scl::math +} // namespace math + +namespace seri { + +/** + * @brief Serializer specialization for math::FF types. + */ +template +struct Serializer> { + /** + * @brief Determine the size of an math::FF value. + * + * The size of an math::FF element can be determined from its type alone, so + * the argument is ignored. + */ + static constexpr std::size_t sizeOf(const math::FF& /* ignored */) { + return math::FF::byteSize(); + } + + /** + * @brief Write an math::FF element to a buffer. + * @param elem the element. + * @param buf the buffer. + * + * Calls math::FF::write(). + */ + static std::size_t write(const math::FF& elem, unsigned char* buf) { + elem.write(buf); + return sizeOf(elem); + } + + /** + * @brief Read an math::FF element from a buffer. + * @param elem output variable holding the read element after reading. + * @param buf the buffer. + * @return the number of bytes read. + * + * Calls math::FF::read() and returns math::FF::byteSize(); + */ + static std::size_t read(math::FF& elem, const unsigned char* buf) { + elem = math::FF::read(buf); + return sizeOf(elem); + } +}; + +} // namespace seri +} // namespace scl #endif // SCL_MATH_FF_H diff --git a/include/scl/math/ff_ops.h b/include/scl/math/fields/ff_ops.h similarity index 64% rename from include/scl/math/ff_ops.h rename to include/scl/math/fields/ff_ops.h index 8b81ee0..4c70c2a 100644 --- a/include/scl/math/ff_ops.h +++ b/include/scl/math/fields/ff_ops.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,8 +15,8 @@ * along with this program. If not, see . */ -#ifndef SCL_MATH_FF_OPS_H -#define SCL_MATH_FF_OPS_H +#ifndef SCL_MATH_FIELDS_FF_OPS_H +#define SCL_MATH_FIELDS_FF_OPS_H #include #include @@ -25,53 +25,63 @@ #include "scl/math/number.h" -namespace scl::math { +namespace scl::math::ff { /** * @brief Convert an int into a field element. * @param out where to store the converted element * @param value the integer to convert */ -template -void FieldConvertIn(typename F::ValueType& out, int value); +template +void convertTo(typename FIELD::ValueType& out, int value); + +/** + * @brief Read a field element from a string. + * @param out where to store the resulting element + * @param src the string + */ +template +void convertTo(typename FIELD::ValueType& out, const std::string& src); /** * @brief Add two field elements in-place. * @param out the first operand and output * @param op the second operand */ -template -void FieldAdd(typename F::ValueType& out, const typename F::ValueType& op); +template +void add(typename FIELD::ValueType& out, const typename FIELD::ValueType& op); /** * @brief Subtract two field elements in-place. * @param out the first operand and output * @param op the second operand */ -template -void FieldSubtract(typename F::ValueType& out, const typename F::ValueType& op); +template +void subtract(typename FIELD::ValueType& out, + const typename FIELD::ValueType& op); /** * @brief Multiply two field elements in-place. * @param out the first operand and output * @param op the second operand */ -template -void FieldMultiply(typename F::ValueType& out, const typename F::ValueType& op); +template +void multiply(typename FIELD::ValueType& out, + const typename FIELD::ValueType& op); /** * @brief Negate a field element in-place. * @param out the element to negate */ -template -void FieldNegate(typename F::ValueType& out); +template +void negate(typename FIELD::ValueType& out); /** * @brief Invert a field element in-place. * @param out the element to invert */ -template -void FieldInvert(typename F::ValueType& out); +template +void invert(typename FIELD::ValueType& out); /** * @brief Check if two field elements are the same. @@ -79,42 +89,34 @@ void FieldInvert(typename F::ValueType& out); * @param in2 the second element * @return true if \p in1 and \p in2 are the same and false otherwise */ -template -bool FieldEqual(const typename F::ValueType& in1, - const typename F::ValueType& in2); +template +bool equal(const typename FIELD::ValueType& in1, + const typename FIELD::ValueType& in2); /** * @brief Convert a field element to bytes. * @param dest the field element to convert * @param src where to store the converted element */ -template -void FieldToBytes(unsigned char* dest, const typename F::ValueType& src); +template +void toBytes(unsigned char* dest, const typename FIELD::ValueType& src); /** * @brief Convert the content of a buffer to a field element. * @param dest where to store the converted element * @param src the buffer */ -template -void FieldFromBytes(typename F::ValueType& dest, const unsigned char* src); +template +void fromBytes(typename FIELD::ValueType& dest, const unsigned char* src); /** * @brief Convert a field element to a string. * @param in the field element to convert * @return an STL string representation of \p in. */ -template -std::string FieldToString(const typename F::ValueType& in); - -/** - * @brief Read a field element from a string. - * @param out where to store the resulting element - * @param src the string - */ -template -void FieldFromString(typename F::ValueType& out, const std::string& src); +template +std::string toString(const typename FIELD::ValueType& in); -} // namespace scl::math +} // namespace scl::math::ff -#endif // SCL_MATH_FF_OPS_H +#endif // SCL_MATH_FIELDS_FF_OPS_H diff --git a/include/scl/math/ops_gmp_ff.h b/include/scl/math/fields/ff_ops_gmp.h similarity index 85% rename from include/scl/math/ops_gmp_ff.h rename to include/scl/math/fields/ff_ops_gmp.h index 11aebde..2be0aaa 100644 --- a/include/scl/math/ops_gmp_ff.h +++ b/include/scl/math/fields/ff_ops_gmp.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,8 +15,8 @@ * along with this program. If not, see . */ -#ifndef SCL_MATH_OPS_GMP_FF_H -#define SCL_MATH_OPS_GMP_FF_H +#ifndef SCL_MATH_FIELDS_FF_OPS_GMP_H +#define SCL_MATH_FIELDS_FF_OPS_GMP_H #include #include @@ -29,7 +29,7 @@ #include "scl/util/str.h" -namespace scl::math { +namespace scl::math::ff { #define SCL_BITS_PER_LIMB static_cast(mp_bits_per_limb) #define SCL_BYTES_PER_LIMB sizeof(mp_limb_t) @@ -64,7 +64,7 @@ struct RedParams { * @param rp reduction parameters. */ template -void MontyIn(mp_limb_t* out, const RedParams rp) { +void montyIn(mp_limb_t* out, const RedParams rp) { mp_limb_t qp[N + 1]; mp_limb_t shift[2 * N] = {0}; // multiply val by 2^{w * N} @@ -80,7 +80,7 @@ void MontyIn(mp_limb_t* out, const RedParams rp) { * @param rp reduction parameters. */ template -void MontyRedc(mp_limb_t* out, const RedParams rp) { +void montyRedc(mp_limb_t* out, const RedParams rp) { // q = val * rp.mc // TODO: This can be optimized a bit since q is reduced modulo 2^N below mp_limb_t q[2 * N]; @@ -110,12 +110,12 @@ void MontyRedc(mp_limb_t* out, const RedParams rp) { * modulo \p mod. The function assumes that \p out has been zeroed. */ template -void MontyInFromInt(mp_limb_t* out, const int value, const RedParams rp) { +void montyInFromInt(mp_limb_t* out, const int value, const RedParams rp) { out[0] = std::abs(value); if (value < 0) { mpn_sub_n(out, rp.prime, out, N); } - MontyIn(out, rp); + montyIn(out, rp); } /** @@ -126,7 +126,7 @@ void MontyInFromInt(mp_limb_t* out, const int value, const RedParams rp) { * @param rp reduction parameters. */ template -void MontyModAdd(mp_limb_t* out, const mp_limb_t* op, const RedParams rp) { +void montyModAdd(mp_limb_t* out, const mp_limb_t* op, const RedParams rp) { auto carry = mpn_add_n(out, out, op, N); if (carry || mpn_cmp(out, rp.prime, N) >= 0) { mpn_sub_n(out, out, rp.prime, N); @@ -141,7 +141,7 @@ void MontyModAdd(mp_limb_t* out, const mp_limb_t* op, const RedParams rp) { * @param rp reduction parameters. */ template -void MontyModSub(mp_limb_t* out, const mp_limb_t* op, const RedParams rp) { +void montyModSub(mp_limb_t* out, const mp_limb_t* op, const RedParams rp) { auto carry = mpn_sub_n(out, out, op, N); if (carry) { mpn_add_n(out, out, rp.prime, N); @@ -155,9 +155,9 @@ void MontyModSub(mp_limb_t* out, const mp_limb_t* op, const RedParams rp) { * @param rp reduction parameters. */ template -void MontyModNeg(mp_limb_t* out, const RedParams rp) { +void montyModNeg(mp_limb_t* out, const RedParams rp) { mp_limb_t t[N] = {0}; - MontyModSub(t, out, rp); + montyModSub(t, out, rp); std::copy(t, t + N, out); } @@ -172,7 +172,7 @@ void MontyModNeg(mp_limb_t* out, const RedParams rp) { * multiplication. */ template -void MontyModMul(mp_limb_t* out, const mp_limb_t* op, const RedParams rp) { +void montyModMul(mp_limb_t* out, const mp_limb_t* op, const RedParams rp) { mp_limb_t u[N + 1] = {0}; for (std::size_t i = 0; i < N; ++i) { @@ -198,17 +198,17 @@ void MontyModMul(mp_limb_t* out, const mp_limb_t* op, const RedParams rp) { * @param rp reduction parameters. */ template -void MontyModSqr(mp_limb_t* out, const mp_limb_t* op, const RedParams rp) { +void montyModSqr(mp_limb_t* out, const mp_limb_t* op, const RedParams rp) { mp_limb_t res[2 * N]; mpn_sqr(res, op, N); - MontyRedc(res, rp); + montyRedc(res, rp); std::copy(res, res + N, out); } /** * @brief Checks if a bit has been set. */ -inline bool TestBit(const mp_limb_t* v, std::size_t pos) { +inline bool testBit(const mp_limb_t* v, std::size_t pos) { auto limb = pos / SCL_BITS_PER_LIMB; auto bit = pos % SCL_BITS_PER_LIMB; return ((v[limb] >> bit) & 1) == 1; @@ -223,15 +223,15 @@ inline bool TestBit(const mp_limb_t* v, std::size_t pos) { * @param rp reduction parameters. */ template -void MontyModExp(mp_limb_t* out, +void montyModExp(mp_limb_t* out, const mp_limb_t* base, const mp_limb_t* exp, const RedParams rp) { auto n = mpn_sizeinbase(exp, N, 2); for (std::size_t i = n; i-- > 0;) { - MontyModSqr(out, out, rp); - if (TestBit(exp, i)) { - MontyModMul(out, base, rp); + montyModSqr(out, out, rp); + if (testBit(exp, i)) { + montyModMul(out, base, rp); } } } @@ -248,7 +248,7 @@ void MontyModExp(mp_limb_t* out, * prime_minus_2 argument is assumed to be \f$rp.prime - 2\f$. */ template -void MontyModInv(mp_limb_t* out, +void montyModInv(mp_limb_t* out, const mp_limb_t* op, const mp_limb_t* prime_minus_2, const RedParams rp) { @@ -256,7 +256,7 @@ void MontyModInv(mp_limb_t* out, throw std::invalid_argument("0 not invertible modulo prime"); } - MontyModExp(out, op, prime_minus_2, rp); + montyModExp(out, op, prime_minus_2, rp); } /** @@ -265,7 +265,7 @@ void MontyModInv(mp_limb_t* out, * @return a value x such that R(x, 0) <==> R(lhs, rhs). */ template -int CompareValues(const mp_limb_t* lhs, const mp_limb_t* rhs) { +int compareValues(const mp_limb_t* lhs, const mp_limb_t* rhs) { return mpn_cmp(lhs, rhs, N); } @@ -277,7 +277,7 @@ int CompareValues(const mp_limb_t* lhs, const mp_limb_t* rhs) { * @param rp reduction parameters. */ template -void MontyFromBytes(mp_limb_t* out, +void montyFromBytes(mp_limb_t* out, const unsigned char* src, const RedParams rp) { for (int i = N - 1; i >= 0; --i) { @@ -286,7 +286,7 @@ void MontyFromBytes(mp_limb_t* out, } } - MontyIn(out, rp); + montyIn(out, rp); } /** @@ -297,12 +297,12 @@ void MontyFromBytes(mp_limb_t* out, * @param rp reduction parameters. */ template -void MontyToBytes(unsigned char* dest, +void montyToBytes(unsigned char* dest, const mp_limb_t* src, const RedParams rp) { mp_limb_t padded[2 * N] = {0}; std::copy(src, src + N, padded); - MontyRedc(padded, rp); + montyRedc(padded, rp); std::size_t c = 0; for (int i = N - 1; i >= 0; --i) { @@ -320,7 +320,7 @@ void MontyToBytes(unsigned char* dest, * This method is used handle a string representation of a number with leading * zeros. */ -std::size_t FindFirstNonZero(const std::string& s); +std::size_t findFirstNonZero(const std::string& s); /** * @brief Convert a value in Montgomery representation to a string. @@ -330,25 +330,25 @@ std::size_t FindFirstNonZero(const std::string& s); * @return \p val as a string. */ template -std::string MontyToString(const mp_limb_t* val, const RedParams rp) { +std::string montyToString(const mp_limb_t* val, const RedParams rp) { mp_limb_t padded[2 * N] = {0}; std::copy(val, val + N, padded); - MontyRedc(padded, rp); + montyRedc(padded, rp); - static const char* kHexChars = "0123456789abcdef"; + static const char* hex_chars = "0123456789abcdef"; std::stringstream ss; for (int i = N - 1; i >= 0; --i) { const auto v = padded[i]; for (int j = SCL_BYTES_PER_LIMB - 1; j >= 0; --j) { const auto vv = v >> (j * 8); - ss << kHexChars[(vv & 0xF0) >> 4]; - ss << kHexChars[vv & 0x0F]; + ss << hex_chars[(vv & 0xF0) >> 4]; + ss << hex_chars[vv & 0x0F]; } } auto s = ss.str(); // trim leading 0s - auto n = FindFirstNonZero(s); + auto n = findFirstNonZero(s); if (n > 0) { s = s.substr(n, s.length() - 1); } @@ -368,7 +368,7 @@ std::string MontyToString(const mp_limb_t* val, const RedParams rp) { * @param rp reduction parameters used to convert out into Montgomery form. */ template -void MontyFromString(mp_limb_t* out, +void montyFromString(mp_limb_t* out, const std::string& str, const RedParams rp) { if (str.length()) { @@ -391,15 +391,15 @@ void MontyFromString(mp_limb_t* out, for (int i = 0; i < n && c >= 0; i += m) { auto end = std::min(n, i + m); out[c--] = - util::FromHexString(std::string(beg + i, beg + end)); + util::fromHexString(std::string(beg + i, beg + end)); } - MontyIn(out, rp); + montyIn(out, rp); } } #undef SCL_BITS_PER_LIMB #undef SCL_BYTES_PER_LIMB -} // namespace scl::math +} // namespace scl::math::ff -#endif // SCL_MATH_OPS_GMP_FF_H +#endif // SCL_MATH_FIELDS_FF_OPS_GMP_H diff --git a/include/scl/math/fields/mersenne127.h b/include/scl/math/fields/mersenne127.h index f37d5ee..2b7d4aa 100644 --- a/include/scl/math/fields/mersenne127.h +++ b/include/scl/math/fields/mersenne127.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -21,7 +21,7 @@ #include #include -namespace scl::math { +namespace scl::math::ff { /** * @brief The field \f$\mathbb{F}_p\f$ with \f$p=2^{127}-1\f$. @@ -48,6 +48,6 @@ struct Mersenne127 { constexpr static const std::size_t BIT_SIZE = 127; }; -} // namespace scl::math +} // namespace scl::math::ff #endif // SCL_MATH_FIELDS_MERSENNE127_H diff --git a/include/scl/math/fields/mersenne61.h b/include/scl/math/fields/mersenne61.h index 5a7523b..b66a39a 100644 --- a/include/scl/math/fields/mersenne61.h +++ b/include/scl/math/fields/mersenne61.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -21,7 +21,7 @@ #include #include -namespace scl::math { +namespace scl::math::ff { /** * @brief The field \f$\mathbb{F}_p\f$ with \f$p=2^{61}-1\f$. @@ -48,6 +48,6 @@ struct Mersenne61 { constexpr static const std::size_t BIT_SIZE = 61; }; -} // namespace scl::math +} // namespace scl::math::ff #endif // SCL_MATH_FIELDS_MERSENNE61_H diff --git a/include/scl/math/fields/secp256k1_field.h b/include/scl/math/fields/secp256k1_field.h new file mode 100644 index 0000000..1799994 --- /dev/null +++ b/include/scl/math/fields/secp256k1_field.h @@ -0,0 +1,55 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_MATH_FIELDS_SECP256K1_FIELD_H +#define SCL_MATH_FIELDS_SECP256K1_FIELD_H + +#include +#include + +#include + +namespace scl::math::ff { + +/** + * @brief The Field over which secp256k1 is defined. + */ +struct Secp256k1Field { + /** + * @brief Field elements are stored as 4 limb numbers internally. + */ + using ValueType = std::array; + + /** + * @brief Name of the secp256k1 field. + */ + constexpr static const char* NAME = "secp256k1_field"; + + /** + * @brief Byte size of a secp256k1 field element. + */ + constexpr static const std::size_t BYTE_SIZE = 4 * sizeof(mp_limb_t); + + /** + * @brief Bit size of a secp256k1 field element. + */ + constexpr static const std::size_t BIT_SIZE = 8 * BYTE_SIZE; +}; + +} // namespace scl::math::ff + +#endif // SCL_MATH_FIELDS_SECP256K1_FIELD_H diff --git a/include/scl/math/fields/secp256k1_scalar.h b/include/scl/math/fields/secp256k1_scalar.h new file mode 100644 index 0000000..d0bcc62 --- /dev/null +++ b/include/scl/math/fields/secp256k1_scalar.h @@ -0,0 +1,55 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_MATH_FIELDS_SECP256K1_SCALAR_H +#define SCL_MATH_FIELDS_SECP256K1_SCALAR_H + +#include +#include + +#include + +namespace scl::math::ff { + +/** + * @brief Finite field modulo a Secp256k1 prime order sub-group. + */ +struct Secp256k1Scalar { + /** + * @brief Internal type of elements. + */ + using ValueType = std::array; + + /** + * @brief Name of the field. + */ + constexpr static const char* NAME = "secp256k1_order"; + + /** + * @brief Size of an element in bytes. + */ + constexpr static const std::size_t BYTE_SIZE = 4 * sizeof(mp_limb_t); + + /** + * @brief Size of an element in bits. + */ + constexpr static const std::size_t BIT_SIZE = 8 * BYTE_SIZE; +}; + +} // namespace scl::math::ff + +#endif // SCL_MATH_FIELDS_SECP256K1_SCALAR_H diff --git a/include/scl/math/fp.h b/include/scl/math/fp.h index 363081c..fff237b 100644 --- a/include/scl/math/fp.h +++ b/include/scl/math/fp.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -31,15 +31,15 @@ namespace scl::math { /** * @brief Select a suitable Finite Field based on a provided bitlevel. */ -template +template struct FieldSelector { - static_assert(Bits > 0 && Bits < 128, "Bits not in range [1, 127]"); + static_assert(BITS > 0 && BITS < 128, "Bits not in range [1, 127]"); /** * @brief The field. */ - using Field = - std::conditional_t; + using Field = std:: + conditional_t; }; #undef SCL_IN_RANGE @@ -60,8 +60,8 @@ struct FieldSelector { * @see Mersenne61 * @see Mersenne127 */ -template -using Fp = FF::Field>; +template +using Fp = FF::Field>; } // namespace scl::math diff --git a/include/scl/math/la.h b/include/scl/math/la.h deleted file mode 100644 index 63b2eba..0000000 --- a/include/scl/math/la.h +++ /dev/null @@ -1,305 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#ifndef SCL_MATH_LA_H -#define SCL_MATH_LA_H - -#include "scl/math/mat.h" -#include "scl/math/vec.h" - -namespace scl::math { - -/** - * @brief Swap two rows of a matrix in-place. - * @param A the matrix - * @param k the first row - * @param h the second row - */ -template -void SwapRows(Mat& A, std::size_t k, std::size_t h) { - if (k != h) { - T temp; - for (std::size_t i = 0; i < A.Cols(); ++i) { - temp = A(h, i); - A(h, i) = A(k, i); - A(k, i) = temp; - } - } -} - -/** - * @brief Multiply a row in a matrix by a constant. - * @param A the matrix - * @param row the row - * @param m the constant - */ -template -void MultiplyRow(Mat& A, std::size_t row, const T& m) { - for (std::size_t j = 0; j < A.Cols(); ++j) { - A(row, j) *= m; - } -} - -/** - * @brief Add a mutliple of one row to another in a matrix. - * @param A the matrix - * @param dst the row that is mutated - * @param op the row that is added to \p dst - * @param m the multiple - */ -template -void AddRows(Mat& A, std::size_t dst, std::size_t op, const T& m) { - for (std::size_t j = 0; j < A.Cols(); ++j) { - A(dst, j) += A(op, j) * m; - } -} - -/** - * @brief Bring a matrix into reduced row echelon form in-place. - * @param A the matrix to bring into RREF - */ -template -void RowReduceInPlace(Mat& A) { - std::size_t n = A.Rows(); - std::size_t m = A.Cols(); - std::size_t r = 0; - std::size_t c = 0; - const T zero; - - while (r < n && c < m) { - // find pivot in current column - auto pivot = r; - while (pivot < n && A(pivot, c) == zero) { - pivot++; - } - - if (pivot == n) { - // this column was all 0, so go to next one - c++; - } else { - SwapRows(A, pivot, r); - - // make leading coefficient of this row 1. - auto pv = A(r, c).Inverse(); - MultiplyRow(A, r, pv); - - // finally, for each row that is not r, subtract a multiple of row r. - for (std::size_t k = 0; k < n; ++k) { - if (k == r) { - continue; - } - // skip row if leading coefficient of that row is 0. - auto t = A(k, c); - if (t != zero) { - AddRows(A, k, r, -t); - } - } - r++; - c++; - } - } -} - -/** - * @brief Finds the position of a pivot in a column, if any. - * @param A a RREF matrix - * @param col the column - * @return The index of a pivot in \p col, or -1 if non exists - * @note No validation is performed on any of the arguments. - */ -template -int GetPivotInColumn(const Mat& A, int col) { - T zero; - int i = A.Rows(); - while (i-- > 0) { - if (A(i, col) != zero) { - for (int k = 0; k < col - 1; ++k) { - if (A(i, k) != zero) { - return -1; - } - } - return i; - } - } - return -1; -} - -/** - * @brief Finds the first non-zero row, starting from the "bottom" of a matrix - * - * Starting from the bottom, finds the first row that is non-zero in a matrix. - * This function is used to determine rows that can be skipped when performing - * back substitution. - * - * @param A the matrix - * @return The first non-zero row - */ -template -std::size_t FindFirstNonZeroRow(const Mat& A) { - std::size_t nzr = A.Rows(); - const auto m = A.Cols(); - while (nzr-- > 0) { - bool non_zero = false; - for (std::size_t j = 0; j < m; ++j) { - if (A(nzr, j) != T{}) { - non_zero = true; - break; - } - } - if (non_zero) { - break; - } - } - return nzr; -} - -/** - * @brief Extract a solution from a matrix in RREF. - * @param A the matrix - * @return the solution. - * - * Given a matrix \p A in RREF that represents a system of equations, this - * function extracts a solution for said system. The system is assumed to be - * consistent, but may contain free variables. In those cases, the free - * variables are given the value 1 (in the field). - */ -template -Vec ExtractSolution(const Mat& A) { - const auto n = A.Rows(); - const auto m = A.Cols(); - - Vec x(m - 1); - auto i = FindFirstNonZeroRow(A); - // we remove (n - i) rows, which means setting the corresponding variables to - // 0. - int c = m - 2 - (n - i - 1); - for (; c >= 0; c--) { - const auto p = GetPivotInColumn(A, c); - if (p == -1) { - // a free variable just gets set to 1. - x[c] = T{1}; - } else { - T sum; - for (std::size_t j = p + 1; j < n; ++j) { - sum += A(i, j) * x[j]; - } - x[c] = A(i, m - 1) - sum; - i--; - } - } - return x; -} // LCOV_EXCL_LINE - -/** - * @brief Check if a linear system has a solution. - * @param A an augmented matrix in RREF - * @param unique_only indicates if only unique solutions should be considered - * @return true if \f$Ax = b\f$ has a solution and false otherwise. - * - * Determines if a linear system given by an augmented matrix in RREF has a - * solution, aka. is consistent. if called with \p unique_only set to - * true, then only systems with a unique solution are considered. - */ -template -bool HasSolution(const Mat& A, bool unique_only) { - auto n = A.Rows(); - auto m = A.Cols(); - T zero; - for (std::size_t i = 0; i < n; ++i) { - bool all_zero = true; - for (std::size_t j = 0; j < m - 1; ++j) { - all_zero &= A(i, j) == zero; - } - // No solution is the case when Rank(A) != Rank(A') where A' is A without - // the last column (the augmentation). I.e., when row(A', i) == 0, but - // row(A, i) != 0. - if (unique_only) { - if (all_zero) { - return false; - } - } else { - if (all_zero && A(i, m - 1) != zero) { - return false; - } - } - } - return true; -} - -/** - * @brief Creates an augmented matrix from two matrices. - * @param A the first matrix - * @param B the second matrix - * @return A matrix aug [A | B]. - */ -template -Mat CreateAugmentedMatrix(const Mat& A, const Mat& B) { - auto n = A.Rows(); - auto m = A.Cols(); - auto k = B.Cols(); - Mat aug(n, m + k); - for (std::size_t i = 0; i < n; ++i) { - for (std::size_t j = 0; j < m; ++j) { - aug(i, j) = A(i, j); - } - for (std::size_t j = m; j < m + k; ++j) { - aug(i, j) = B(i, j - m); - } - } - return aug; -} - -/** - * @brief Create an augmented matrix from a matrix and a vector - * @param A a matrix - * @param b a vector - * @return A matrix aug [A | b]. - */ -template -Mat CreateAugmentedMatrix(const Mat& A, const Vec& b) { - return CreateAugmentedMatrix(A, b.ToColumnMatrix()); -} - -/** - * @brief Solves a linear system of equations \f$Ax = b\f$. - * @param x where to store the solution - * @param A the matrix of coefficients - * @param b the equation values - * @return true if a unique solution was found and false otherwise - * @throws std::invalid_argument if the number of rows in \p A does not - * match the size of \p b. - */ -template -bool SolveLinearSystem(Vec& x, const Mat& A, const Vec& b) { - if (A.Rows() != b.Size()) { - throw std::invalid_argument("malformed system of equations"); - } - - auto aug = CreateAugmentedMatrix(A, b); - - RowReduceInPlace(aug); - if (!HasSolution(aug, true)) { - return false; - } - - x = ExtractSolution(aug); - return true; -} - -} // namespace scl::math - -#endif // SCL_MATH_LA_H diff --git a/include/scl/math/lagrange.h b/include/scl/math/lagrange.h index ec8c394..f225827 100644 --- a/include/scl/math/lagrange.h +++ b/include/scl/math/lagrange.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -21,7 +21,7 @@ #include #include -#include "scl/math/vec.h" +#include "scl/math/vector.h" namespace scl::math { @@ -40,9 +40,9 @@ namespace scl::math { * auto nodes = ... // 0, 1, 2, 3, 4, 5 * auto ys = ... // f(0), f(2), f(3), f(4), f(5) * - * auto basis = ComputeLagrangeBasis(nodes, 7); + * auto basis = computeLagrangeBasis(nodes, 7); * - * auto f7 = ys.Dot(basis); // f(7) + * auto f7 = ys.dot(basis); // f(7) * @endcode * *

The nodes provided must be pairwise invertible. That is, for every @@ -52,12 +52,12 @@ namespace scl::math { * @see https://en.wikipedia.org/wiki/Lagrange_polynomial */ template -Vec ComputeLagrangeBasis(const math::Vec& nodes, const T& x) { - const auto n = nodes.Size(); +Vector computeLagrangeBasis(const math::Vector& nodes, const T& x) { + const auto n = nodes.size(); std::vector b; b.reserve(n); for (std::size_t i = 0; i < n; ++i) { - auto ell = T::One(); + auto ell = T::one(); const auto xi = nodes[i]; for (std::size_t j = 0; j < n; ++j) { if (i != j) { @@ -74,11 +74,11 @@ Vec ComputeLagrangeBasis(const math::Vec& nodes, const T& x) { * @brief Computes a lagrange basis for a set of nodes. * @param nodes the set of nodes. * @param x the evaluation point x. - * @see ComputeLagrangeBasis + * @see computeLagrangeBasis */ template -Vec ComputeLagrangeBasis(const math::Vec& nodes, int x) { - return ComputeLagrangeBasis(nodes, T{x}); +Vector computeLagrangeBasis(const math::Vector& nodes, int x) { + return computeLagrangeBasis(nodes, T{x}); } } // namespace scl::math diff --git a/include/scl/math/mat.h b/include/scl/math/mat.h deleted file mode 100644 index e46f4dc..0000000 --- a/include/scl/math/mat.h +++ /dev/null @@ -1,627 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#ifndef SCL_MATH_MAT_H -#define SCL_MATH_MAT_H - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "scl/math/ff.h" -#include "scl/math/lagrange.h" -#include "scl/math/ops.h" -#include "scl/math/vec.h" -#include "scl/util/prg.h" -#include "scl/util/traits.h" - -namespace scl::math { - -template -class Vec; - -/** - * @brief Matrix. - */ -template -class Mat : Print> { - public: - /** - * @brief The type of the matrix elements. - */ - using ValueType = Elem; - - /** - * @brief The type of a dimension (row or column count). - */ - using SizeType = std::uint32_t; - - /** - * @brief Read a matrix from a stream of bytes. - * @param n the number of rows - * @param m the number of columns - * @param src the bytes - * @return a matrix. - */ - static Mat Read(std::size_t n, std::size_t m, const unsigned char* src); - - /** - * @brief Create a Matrix and populate it with random elements. - * @param n the number of rows - * @param m the number of columns - * @param prg the prg used to generate random elements - * @return a Matrix with random elements. - */ - static Mat Random(std::size_t n, std::size_t m, util::PRG& prg); - - /** - * @brief Create an N-by-M Vandermonde matrix. - * @param n the number of rows. - * @param m the number of columns. - * @param xs vector containing the x values to use. - * @return a Vandermonde matrix. - * - * Let \p xs be a list \f$(x_1, x_2, \dots, x_n)\f$ where \f$x_i \neq x_j\f$ - * for all \f$i\neq j\f$. A Vandermonde matrix, is the \f$n\times m\f$ - * matrix - * - * \f$ - * V = - * \begin{bmatrix} - * 1 & x_1 & x_1^2 & \dots & x_1^{m-1} \\ - * 1 & x_2 & x_2^2 & \dots & x_2^{m-1} \\ - * \ldots \\ - * 1 & x_n & x_n^2 & \dots & x_n^{m-1} - * \end{bmatrix} - * \f$ - * - * @see https://en.wikipedia.org/wiki/Vandermonde_matrix - */ - static Mat Vandermonde(std::size_t n, - std::size_t m, - const Vec& xs); - - /** - * @brief Create an N-by-M Vandermonde matrix. - * @param n the number of rows. - * @param m the number of columns. - * @return a Vandermonde matrix. - * - * This function returns a Vandermonde matrix using \f$(1, 2, \dots, n + 1)\f$ - * as the set of x values. - * - * @see Mat::Vandermonde. - */ - static Mat Vandermonde(std::size_t n, std::size_t m) { - return Mat::Vandermonde(n, m, Vec::Range(1, n + 1)); - } - - /** - * @brief Create an N-by-M hyper-invertible matrix. - * @param n the number of rows. - * @param m the number of columns. - * @return a Hyperinvertible matrix. - * - * A hyper-invertible matrix is a matrix where every square sub-matrix is - * invertible. - */ - static Mat HyperInvertible(std::size_t n, std::size_t m); - - /** - * @brief Create a matrix from a vector. - * - * The idea of this method is to allow one to write - * - * - * Mat::FromVector(2, 2, {1, 2, 3, 4}) - * - * - * in order to create the matrix [[1, 2], [3, 4]]. - * - * @param n the number of rows - * @param m the number of columns - * @param vec the elements of the matrix - * @return a Matrix. - */ - static Mat FromVector(std::size_t n, - std::size_t m, - const std::vector& vec) { - if (vec.size() != n * m) { - throw std::invalid_argument("invalid dimensions"); - } - return Mat(n, m, vec); - } - - /** - * @brief Construct an n-by-n identity matrix. - */ - static Mat Identity(std::size_t n) { - Mat I(n); - for (std::size_t i = 0; i < n; ++i) { - I(i, i) = Elem(1); - } - return I; - } // LCOV_EXCL_LINE - - /** - * @brief Construct an empty 0-by-0 matrix. - */ - Mat() : m_rows(0), m_cols(0) {} - - /** - * @brief Create an N-by-M matrix with default initialized values. - * @param n the number of rows - * @param m the number of columns - */ - explicit Mat(std::size_t n, std::size_t m) { - if (n == 0 || m == 0) { - throw std::invalid_argument("n or m cannot be 0"); - } - std::vector v(n * m); - m_rows = n; - m_cols = m; - m_values = v; - } - - /** - * @brief Create a square matrix with default initialized values. - * @param n the dimensions of the matrix - */ - explicit Mat(std::size_t n) : Mat(n, n){}; - - /** - * @brief The number of rows of this matrix. - */ - SizeType Rows() const { - return m_rows; - } - - /** - * @brief The number of columns of this matrix. - */ - SizeType Cols() const { - return m_cols; - } - - /** - * @brief Provides mutable access to a matrix element. - * @param row the row of the element being queried - * @param column the column of the element being queried - * @return an element of the matrix. - */ - Elem& operator()(std::size_t row, std::size_t column) { - return m_values[m_cols * row + column]; - } - - /** - * @brief Provides read-only access to a matrix element. - * @param row the row of the element being queried - * @param column the column of the element being queried - * @return an element of the matrix. - */ - Elem operator()(std::size_t row, std::size_t column) const { - return m_values[m_cols * row + column]; - } - - /** - * @brief Add this matrix with another matrix of the same dimensions. - * @param other the other matrix - * @return The entry-wise sum of this matrix and \p other. - * @throws std::illegal_argument if the dimensions of this and \p other are - * not equal. - */ - Mat Add(const Mat& other) const { - Mat copy(m_rows, m_cols, m_values); - return copy.AddInPlace(other); - } - - /** - * @brief Add this matrix with another matrix of the same dimensions in-place. - * @param other the other matrix - * @return The entry-wise sum of this matrix and \p other. - * @throws std::illegal_argument if the dimensions of this and \p other are - * not equal. - */ - Mat& AddInPlace(const Mat& other) { - EnsureCompatible(other); - auto n = m_values.size(); - for (std::size_t i = 0; i < n; ++i) { - m_values[i] += other.m_values[i]; - } - return *this; - } - - /** - * @brief Subtract this matrix with another matrix of the same dimensions. - * @param other the other matrix - * @return the difference of this and \p other - * @throws std::illegal_argument if the dimensions of this and \p do not - * match. - */ - Mat Subtract(const Mat& other) const { - Mat copy(m_rows, m_cols, m_values); - return copy.SubtractInPlace(other); - } - - /** - * @brief Subtract this matrix with another matrix of the same dimensions - * in-place. - * @param other the other matrix - * @return the difference of this and \p other - * @throws std::illegal_argument if the dimensions of this and \p do not - * match. - */ - Mat& SubtractInPlace(const Mat& other) { - EnsureCompatible(other); - auto n = m_values.size(); - for (std::size_t i = 0; i < n; i++) { - m_values[i] -= other.m_values[i]; - } - return *this; - } - - /** - * @brief Multiply this matrix with another matrix of the same dimensions. - * @param other the other matrix - * @return the entry-wise product of this and \p other - * @throws std::illegal_argument if the dimensions of this and \p do not - * match. - */ - Mat MultiplyEntryWise(const Mat& other) const { - Mat copy(m_rows, m_cols, m_values); - return copy.MultiplyEntryWiseInPlace(other); - } - - /** - * @brief Multiply this matrix with another matrix of the same dimensions. - * @param other the other matrix - * @return the product of this and \p other - * @throws std::illegal_argument if the dimensions of this and \p do not - * match. - */ - Mat& MultiplyEntryWiseInPlace(const Mat& other) { - EnsureCompatible(other); - auto n = m_values.size(); - for (std::size_t i = 0; i < n; i++) { - m_values[i] *= other.m_values[i]; - } - return *this; - } - - /** - * @brief Performs a matrix multiplication. - * @param other the other matrix - * @return the matrix product of this and \p other. - * @throws std::illegal_argument if the dimensions of the inputs are - * incompatible. - */ - Mat Multiply(const Mat& other) const; - - /** - * @brief Performs a matrix vector product. - * @param vector the vector. - * @return the result vector. - * - * This computes \f$A\cdot x\f$ where \f$A\f$ is a \f$n\times m\f$ matrix and - * \f$x\f$ a length \f$m\f$ vector. The return value is a vector \f$y\f$ of - * length \f$n\f$. - */ - Vec Multiply(const Vec& vector) const; - - /** - * @brief Multiply this matrix with a scalar - * @param scalar the scalar - * @return this scaled by \p scalar. - */ - template < - typename Scalar, - std::enable_if_t::value, bool> = true> - Mat ScalarMultiply(const Scalar& scalar) const { - Mat copy(m_rows, m_cols, m_values); - return copy.ScalarMultiplyInPlace(scalar); - } - - /** - * @brief Multiply this matrix with a scalar in-place. - * @param scalar the scalar - * @return this scaled by \p scalar. - */ - template < - typename Scalar, - std::enable_if_t::value, bool> = true> - Mat& ScalarMultiplyInPlace(const Scalar& scalar) { - for (auto& v : m_values) { - v *= scalar; - } - return *this; - } - - /** - * @brief Check if this matrix is square. - */ - bool IsSquare() const { - return Rows() == Cols(); - } - - /** - * @brief Transpose this matrix. - * @return the transpose of this. - */ - Mat Transpose() const; - - /** - * @brief Resize this matrix. - * @param rows the new row count - * @param cols the new column count - */ - Mat& Resize(std::size_t rows, std::size_t cols) { - if (rows * cols != Rows() * Cols()) { - throw std::invalid_argument("cannot resize matrix"); - } - m_rows = rows; - m_cols = cols; - return *this; - } - - /** - * @brief Returns true if this matrix is the identity matrix. - */ - bool IsIdentity() const; - - /** - * @brief Return a string representation of this matrix. - */ - std::string ToString() const; - - /** - * @brief Test if this matrix is equal to another. - */ - bool Equals(const Mat& other) const { - if (Rows() != other.Rows() || Cols() != other.Cols()) { - return false; - } - - bool equal = true; - for (std::size_t i = 0; i < m_values.size(); i++) { - equal &= m_values[i] == other.m_values[i]; - } - return equal; - } - - /** - * @brief Write this matrix to a buffer. - * @param dest where to write the matrix. - * - * This function just writes the content of the matrix. - */ - void Write(unsigned char* dest) const; - - /** - * @brief The size of a matrix when serialized in bytes. - */ - std::size_t ByteSize() const { - return Cols() * Rows() * Elem::ByteSize(); - } - - private: - Mat(std::size_t r, std::size_t c, std::vector v) - : m_rows(r), m_cols(c), m_values(v){}; - - void EnsureCompatible(const Mat& other) { - if (m_rows != other.m_rows || m_cols != other.m_cols) { - throw std::invalid_argument("incompatible matrices"); - } - } - - SizeType m_rows; - SizeType m_cols; - std::vector m_values; - - friend class Vec; -}; - -template -Mat Mat::Read(std::size_t n, - std::size_t m, - const unsigned char* src) { - const auto* ptr = src; - auto total = n * m; - - // write all elements now that we know we'll not exceed the maximum read size. - std::vector elements; - elements.reserve(total); - for (std::size_t i = 0; i < total; i++) { - elements.emplace_back(Elem::Read(ptr)); - ptr += Elem::ByteSize(); - } - return Mat(n, m, elements); -} - -template -void Mat::Write(unsigned char* dest) const { - for (const auto& v : m_values) { - v.Write(dest); - dest += Elem::ByteSize(); - } -} - -template -Mat Mat::Random(std::size_t n, std::size_t m, util::PRG& prg) { - std::size_t nelements = n * m; - return Mat(n, m, Vec::Random(nelements, prg).ToStlVector()); -} - -template -Mat Mat::Vandermonde(std::size_t n, - std::size_t m, - const Vec& xs) { - if (xs.Size() != n) { - throw std::invalid_argument("|xs| != number of rows"); - } - - Mat v(n, m); - for (std::size_t i = 0; i < n; ++i) { - v(i, 0) = Elem(1); - for (std::size_t j = 1; j < m; ++j) { - v(i, j) = v(i, j - 1) * xs[i]; - } - } - return v; -} // LCOV_EXCL_LINE - -template -Mat Mat::HyperInvertible(std::size_t n, std::size_t m) { - Mat him(n, m); - - const auto vs = Vec::Range(1, m + 1); - - for (std::size_t i = 0; i < n; ++i) { - const auto r = ComputeLagrangeBasis(vs, -i); - for (std::size_t j = 0; j < m; ++j) { - him(i, j) = r[j]; - } - } - return him; -} - -template -Mat Mat::Multiply(const Mat& other) const { - if (Cols() != other.Rows()) { - throw std::invalid_argument("matmul: this->Cols() != that->Rows()"); - } - const auto n = Rows(); - const auto p = Cols(); - const auto m = other.Cols(); - - Mat result(n, m); - for (std::size_t i = 0; i < n; i++) { - for (std::size_t k = 0; k < p; k++) { - for (std::size_t j = 0; j < m; j++) { - result(i, j) += operator()(i, k) * other(k, j); - } - } - } - return result; -} // LCOV_EXCL_LINE - -template -Vec Mat::Multiply(const Vec& vector) const { - if (Cols() != vector.Size()) { - throw std::invalid_argument("matmul: this->Cols() != vec.Size()"); - } - - std::vector result; - result.reserve(Rows()); - - for (std::size_t i = 0; i < Rows(); ++i) { - auto b = m_values.begin() + i * Cols(); - auto e = m_values.begin() + (i + 1) * Cols(); - result.emplace_back(UncheckedInnerProd(b, e, vector.begin())); - } - - return result; -} - -template -Mat Mat::Transpose() const { - Mat t(Cols(), Rows()); - for (std::size_t i = 0; i < Rows(); i++) { - for (std::size_t j = 0; j < Cols(); j++) { - t(j, i) = operator()(i, j); - } - } - return t; -} - -template -bool Mat::IsIdentity() const { - if (!IsSquare()) { - return false; - } - - bool is_ident = true; - for (std::size_t i = 0; i < Rows(); ++i) { - for (std::size_t j = 0; j < Cols(); ++j) { - if (i == j) { - is_ident &= operator()(i, j) == Elem{1}; - } else { - is_ident &= operator()(i, j) == Elem{0}; - } - } - } - return is_ident; -} - -template -std::string Mat::ToString() const { - // this method converts a matrix to something like - // [ a0 a1 a2 ... ] - // [ b0 b1 .. ] - // Each column is aligned according to the largest element in that column. - - const auto n = Rows(); - const auto m = Cols(); - - if (!(n && m)) { - return "[ EMPTY MATRIX ]"; - } - - // convert all elements to strings and find the widest string in each column - // since that will be used to properly align the final output. - std::vector elements; - elements.reserve(n * m); - std::vector fills; - fills.reserve(m); - for (std::size_t j = 0; j < m; j++) { - auto first = operator()(0, j).ToString(); - auto fill = first.size(); - elements.emplace_back(first); - for (std::size_t i = 1; i < n; i++) { - auto next = operator()(i, j).ToString(); - auto next_fill = next.size(); - if (next_fill > fill) { - fill = next_fill; - } - elements.push_back(next); - } - fills.push_back(fill + 1); - } - - std::stringstream ss; - ss << "\n"; - for (std::size_t i = 0; i < n; i++) { - ss << "["; - for (std::size_t j = 0; j < m; j++) { - ss << std::setfill(' ') << std::setw(fills[j]) << elements[j * n + i] - << " "; - } - ss << "]"; - if (i < n - 1) { - ss << "\n"; - } - } - return ss.str(); -} - -} // namespace scl::math - -#endif // SCL_MATH_MAT_H diff --git a/include/scl/math/math.h b/include/scl/math/math.h index 0210713..5badf70 100644 --- a/include/scl/math/math.h +++ b/include/scl/math/math.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -21,14 +21,26 @@ #include "scl/math/ec.h" #include "scl/math/ff.h" #include "scl/math/fp.h" -#include "scl/math/mat.h" +#include "scl/math/matrix.h" #include "scl/math/number.h" -#include "scl/math/vec.h" +#include "scl/math/vector.h" #include "scl/math/z2k.h" /** - * @brief Utilities related to math. + * @brief Maths! */ -namespace scl::math {} // namespace scl::math +namespace scl::math { + +/** + * @brief Functionality related to finite fields. + */ +namespace ff {} + +/** + * @brief Functionality related to elliptic curves. + */ +namespace ec {} + +} // namespace scl::math #endif // SCL_MATH_MATH_H diff --git a/include/scl/math/matrix.h b/include/scl/math/matrix.h new file mode 100644 index 0000000..4efc1a8 --- /dev/null +++ b/include/scl/math/matrix.h @@ -0,0 +1,968 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_MATH_MATRIX_H +#define SCL_MATH_MATRIX_H + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "scl/math/ff.h" +#include "scl/math/lagrange.h" +#include "scl/math/vector.h" +#include "scl/serialization/serializer.h" +#include "scl/util/prg.h" + +namespace scl { +namespace math { + +template +class Vector; + +/** + * @brief Matrix. + */ +template +class Matrix final { + public: + friend struct seri::Serializer>; + + /** + * @brief The type of the matrix elements. + */ + using ValueType = ELEMENT; + + /** + * @brief Create a Matrix and populate it with random elements. + * @param n the number of rows + * @param m the number of columns + * @param prg the prg used to generate random elements + * @return a Matrix with random elements. + */ + static Matrix random(std::size_t n, std::size_t m, util::PRG& prg); + + /** + * @brief Create an N-by-M Vandermonde matrix. + * @param n the number of rows. + * @param m the number of columns. + * @param xs vector containing the x values to use. + * @return a Vandermonde matrix. + * + * Let \p xs be a list \f$(x_1, x_2, \dots, x_n)\f$ where \f$x_i \neq x_j\f$ + * for all \f$i\neq j\f$. A Vandermonde matrix, is the \f$n\times m\f$ + * matrix + * + * \f$ + * V = + * \begin{bmatrix} + * 1 & x_1 & x_1^2 & \dots & x_1^{m-1} \\ + * 1 & x_2 & x_2^2 & \dots & x_2^{m-1} \\ + * \ldots \\ + * 1 & x_n & x_n^2 & \dots & x_n^{m-1} + * \end{bmatrix} + * \f$ + * + * @see https://en.wikipedia.org/wiki/Vandermonde_matrix + */ + static Matrix vandermonde(std::size_t n, + std::size_t m, + const Vector& xs); + + /** + * @brief Create an N-by-M Vandermonde matrix. + * @param n the number of rows. + * @param m the number of columns. + * @return a Vandermonde matrix. + * + * This function returns a Vandermonde matrix using \f$(1, 2, \dots, n + 1)\f$ + * as the set of x values. + * + * @see Mat::Vandermonde. + */ + static Matrix vandermonde(std::size_t n, std::size_t m) { + return Matrix::vandermonde(n, m, Vector::range(1, n + 1)); + } + + /** + * @brief Create an N-by-M hyper-invertible matrix. + * @param n the number of rows. + * @param m the number of columns. + * @return a Hyperinvertible matrix. + * + * A hyper-invertible matrix is a matrix where every square sub-matrix is + * invertible. + */ + static Matrix hyperInvertible(std::size_t n, std::size_t m); + + /** + * @brief Create a matrix from a vector. + * + * The idea of this method is to allow one to write + * + * + * Mat::FromVector(2, 2, {1, 2, 3, 4}) + * + * + * in order to create the matrix [[1, 2], [3, 4]]. + * + * @param n the number of rows + * @param m the number of columns + * @param vec the elements of the matrix + * @return a Matrix. + */ + static Matrix fromVector(std::size_t n, + std::size_t m, + const std::vector& vec) { + if (vec.size() != n * m) { + throw std::invalid_argument("invalid dimensions"); + } + return Matrix(n, m, vec); + } + + /** + * @brief Construct an n-by-n identity matrix. + */ + static Matrix identity(std::size_t n) { + Matrix I(n); + for (std::size_t i = 0; i < n; ++i) { + I(i, i) = ELEMENT(1); + } + return I; + } // LCOV_EXCL_LINE + + /** + * @brief Construct an empty 0-by-0 matrix. + */ + Matrix() : m_rows(0), m_cols(0) {} + + /** + * @brief Create an N-by-M matrix with default initialized values. + * @param n the number of rows + * @param m the number of columns + */ + explicit Matrix(std::size_t n, std::size_t m) { + if (n == 0 || m == 0) { + throw std::invalid_argument("n or m cannot be 0"); + } + std::vector v(n * m); + m_rows = n; + m_cols = m; + m_values = v; + } + + /** + * @brief Create a square matrix with default initialized values. + * @param n the dimensions of the matrix + */ + explicit Matrix(std::size_t n) : Matrix(n, n){}; + + /** + * @brief The number of rows of this matrix. + */ + std::size_t rows() const { + return m_rows; + } + + /** + * @brief The number of columns of this matrix. + */ + std::size_t cols() const { + return m_cols; + } + + /** + * @brief Provides mutable access to a matrix element. + * @param row the row of the element being queried + * @param column the column of the element being queried + * @return an element of the matrix. + */ + ELEMENT& operator()(std::size_t row, std::size_t column) { + return m_values[m_cols * row + column]; + } + + /** + * @brief Provides read-only access to a matrix element. + * @param row the row of the element being queried + * @param column the column of the element being queried + * @return an element of the matrix. + */ + ELEMENT operator()(std::size_t row, std::size_t column) const { + return m_values[m_cols * row + column]; + } + + /** + * @brief Add this matrix with another matrix of the same dimensions. + * @param other the other matrix + * @return The entry-wise sum of this matrix and \p other. + * @throws std::illegal_argument if the dimensions of this and \p other are + * not equal. + */ + Matrix add(const Matrix& other) const { + Matrix copy(m_rows, m_cols, m_values); + return copy.addInPlace(other); + } + + /** + * @brief Add this matrix with another matrix of the same dimensions in-place. + * @param other the other matrix + * @return The entry-wise sum of this matrix and \p other. + * @throws std::illegal_argument if the dimensions of this and \p other are + * not equal. + */ + Matrix& addInPlace(const Matrix& other) { + ensureCompatible(other); + auto n = m_values.size(); + for (std::size_t i = 0; i < n; ++i) { + m_values[i] += other.m_values[i]; + } + return *this; + } + + /** + * @brief Subtract this matrix with another matrix of the same dimensions. + * @param other the other matrix + * @return the difference of this and \p other + * @throws std::illegal_argument if the dimensions of this and \p do not + * match. + */ + Matrix subtract(const Matrix& other) const { + Matrix copy(m_rows, m_cols, m_values); + return copy.subtractInPlace(other); + } + + /** + * @brief Subtract this matrix with another matrix of the same dimensions + * in-place. + * @param other the other matrix + * @return the difference of this and \p other + * @throws std::illegal_argument if the dimensions of this and \p do not + * match. + */ + Matrix& subtractInPlace(const Matrix& other) { + ensureCompatible(other); + auto n = m_values.size(); + for (std::size_t i = 0; i < n; i++) { + m_values[i] -= other.m_values[i]; + } + return *this; + } + + /** + * @brief Multiply this matrix with another matrix of the same dimensions. + * @param other the other matrix + * @return the entry-wise product of this and \p other + * @throws std::illegal_argument if the dimensions of this and \p do not + * match. + */ + Matrix multiplyEntryWise(const Matrix& other) const { + Matrix copy(m_rows, m_cols, m_values); + return copy.multiplyEntryWiseInPlace(other); + } + + /** + * @brief Multiply this matrix with another matrix of the same dimensions. + * @param other the other matrix + * @return the product of this and \p other + * @throws std::illegal_argument if the dimensions of this and \p do not + * match. + */ + Matrix& multiplyEntryWiseInPlace(const Matrix& other) { + ensureCompatible(other); + auto n = m_values.size(); + for (std::size_t i = 0; i < n; i++) { + m_values[i] *= other.m_values[i]; + } + return *this; + } + + /** + * @brief Performs a matrix multiplication. + * @param other the other matrix + * @return the matrix product of this and \p other. + * @throws std::illegal_argument if the dimensions of the inputs are + * incompatible. + */ + Matrix multiply(const Matrix& other) const; + + /** + * @brief Performs a matrix vector product. + * @param vector the vector. + * @return the result vector. + * + * This computes \f$A\cdot x\f$ where \f$A\f$ is a \f$n\times m\f$ matrix and + * \f$x\f$ a length \f$m\f$ vector. The return value is a vector \f$y\f$ of + * length \f$n\f$. + */ + Vector multiply(const Vector& vector) const; + + /** + * @brief Multiply this matrix with a scalar + * @param scalar the scalar + * @return this scaled by \p scalar. + */ + template + requires requires(ELEMENT a, ELEMENT b) { (a) * (b); } + Matrix scalarMultiply(const SCALAR& scalar) const { + Matrix copy(m_rows, m_cols, m_values); + return copy.scalarMultiplyInPlace(scalar); + } + + /** + * @brief Multiply this matrix with a scalar in-place. + * @param scalar the scalar + * @return this scaled by \p scalar. + */ + template + requires requires(ELEMENT a, ELEMENT b) { (a) * (b); } + Matrix& scalarMultiplyInPlace(const SCALAR& scalar) { + for (auto& v : m_values) { + v *= scalar; + } + return *this; + } + + /** + * @brief Check if this matrix is square. + */ + bool isSquare() const { + return rows() == cols(); + } + + /** + * @brief Transpose this matrix. + * @return the transpose of this. + */ + Matrix transpose() const; + + /** + * @brief Resize this matrix. + * @param new_rows the new row count + * @param new_cols the new column count + */ + Matrix& resize(std::size_t new_rows, std::size_t new_cols) { + if (new_rows * new_cols != rows() * cols()) { + throw std::invalid_argument("cannot resize matrix"); + } + m_rows = new_rows; + m_cols = new_cols; + return *this; + } + + /** + * @brief Returns true if this matrix is the identity matrix. + */ + bool isIdentity() const; + + /** + * @brief Compute the inverse of this matrix, if possible. + * @throws std::invalid_argument if the matrix is not square. + * + * This function computes the inverse of a matrix using Guassian + * elimination. No checks are done as to whether such an inverse exists. + */ + Matrix invert() const; + + /** + * @brief Return a string representation of this matrix. + */ + std::string toString() const; + + /** + * @brief Write a string representation of this matrix to a stream. + */ + friend std::ostream& operator<<(std::ostream& os, const Matrix& matrix) { + return os << matrix.toString(); + } + + /** + * @brief Test if this matrix is equal to another. + */ + bool equals(const Matrix& other) const { + if (rows() != other.rows() || cols() != other.cols()) { + return false; + } + + bool equal = true; + for (std::size_t i = 0; i < m_values.size(); i++) { + equal &= m_values[i] == other.m_values[i]; + } + return equal; + } + + /** + * @brief The size of a matrix when serialized in bytes. + */ + std::size_t byteSize() const { + return cols() * rows() * ELEMENT::byteSize(); + } + + private: + Matrix(std::size_t r, std::size_t c, std::vector v) + : m_rows(r), m_cols(c), m_values(v){}; + + void ensureCompatible(const Matrix& other) { + if (m_rows != other.m_rows || m_cols != other.m_cols) { + throw std::invalid_argument("incompatible matrices"); + } + } + + std::size_t m_rows; + std::size_t m_cols; + std::vector m_values; + + friend class Vector; +}; + +template +Matrix Matrix::random(std::size_t n, + std::size_t m, + util::PRG& prg) { + std::size_t nelements = n * m; + return Matrix(n, m, Vector::random(nelements, prg).toStlVector()); +} + +template +Matrix Matrix::vandermonde(std::size_t n, + std::size_t m, + const Vector& xs) { + if (xs.size() != n) { + throw std::invalid_argument("|xs| != number of rows"); + } + + Matrix v(n, m); + for (std::size_t i = 0; i < n; ++i) { + v(i, 0) = ELEMENT(1); + for (std::size_t j = 1; j < m; ++j) { + v(i, j) = v(i, j - 1) * xs[i]; + } + } + return v; +} // LCOV_EXCL_LINE + +template +Matrix Matrix::hyperInvertible(std::size_t n, std::size_t m) { + Matrix him(n, m); + + const auto vs = Vector::range(1, m + 1); + + for (std::size_t i = 0; i < n; ++i) { + const auto r = computeLagrangeBasis(vs, -i); + for (std::size_t j = 0; j < m; ++j) { + him(i, j) = r[j]; + } + } + return him; +} + +template +Matrix Matrix::multiply(const Matrix& other) const { + if (cols() != other.rows()) { + throw std::invalid_argument("matmul: this->cols() != that->rows()"); + } + const auto n = rows(); + const auto p = cols(); + const auto m = other.cols(); + + Matrix result(n, m); + for (std::size_t i = 0; i < n; i++) { + for (std::size_t k = 0; k < p; k++) { + for (std::size_t j = 0; j < m; j++) { + result(i, j) += operator()(i, k) * other(k, j); + } + } + } + return result; +} // LCOV_EXCL_LINE + +template +Vector Matrix::multiply(const Vector& vector) const { + if (cols() != vector.size()) { + throw std::invalid_argument("matmul: this->cols() != vec.size()"); + } + + std::vector result; + result.reserve(rows()); + + for (std::size_t i = 0; i < rows(); ++i) { + auto b = m_values.begin() + i * cols(); + auto e = m_values.begin() + (i + 1) * cols(); + result.emplace_back(innerProd(b, e, vector.begin())); + } + + return result; +} + +template +Matrix Matrix::transpose() const { + Matrix t(cols(), rows()); + for (std::size_t i = 0; i < rows(); i++) { + for (std::size_t j = 0; j < cols(); j++) { + t(j, i) = operator()(i, j); + } + } + return t; +} + +template +bool Matrix::isIdentity() const { + if (!isSquare()) { + return false; + } + + bool is_ident = true; + for (std::size_t i = 0; i < rows(); ++i) { + for (std::size_t j = 0; j < cols(); ++j) { + if (i == j) { + is_ident &= operator()(i, j) == ELEMENT{1}; + } else { + is_ident &= operator()(i, j) == ELEMENT{0}; + } + } + } + return is_ident; +} + +/** + * @brief Swap two rows of a matrix in-place. + * @param A the matrix + * @param k the first row + * @param h the second row + */ +template +void swapRows(Matrix& A, std::size_t k, std::size_t h) { + if (k != h) { + ELEMENT temp; + for (std::size_t i = 0; i < A.cols(); ++i) { + temp = A(h, i); + A(h, i) = A(k, i); + A(k, i) = temp; + } + } +} + +/** + * @brief Multiply a row in a matrix by a constant. + * @param A the matrix + * @param row the row + * @param m the constant + */ +template +void multiplyRow(Matrix& A, std::size_t row, const ELEMENT& m) { + for (std::size_t j = 0; j < A.cols(); ++j) { + A(row, j) *= m; + } +} + +/** + * @brief Add a mutliple of one row to another in a matrix. + * @param A the matrix + * @param dst the row that is mutated + * @param op the row that is added to \p dst + * @param m the multiple + */ +template +void addRows(Matrix& A, + std::size_t dst, + std::size_t op, + const ELEMENT& m) { + for (std::size_t j = 0; j < A.cols(); ++j) { + A(dst, j) += A(op, j) * m; + } +} + +/** + * @brief Bring a matrix into reduced row echelon form in-place. + * @param A the matrix to bring into RREF + */ +template +void rowReduceInPlace(Matrix& A) { + std::size_t n = A.rows(); + std::size_t m = A.cols(); + std::size_t r = 0; + std::size_t c = 0; + const ELEMENT zero; + + while (r < n && c < m) { + // find pivot in current column + auto pivot = r; + while (pivot < n && A(pivot, c) == zero) { + pivot++; + } + + if (pivot == n) { + // this column was all 0, so go to next one + c++; + } else { + swapRows(A, pivot, r); + + // make leading coefficient of this row 1. + auto pv = A(r, c).inverse(); + multiplyRow(A, r, pv); + + // finally, for each row that is not r, subtract a multiple of row r. + for (std::size_t k = 0; k < n; ++k) { + if (k == r) { + continue; + } + // skip row if leading coefficient of that row is 0. + auto t = A(k, c); + if (t != zero) { + addRows(A, k, r, -t); + } + } + r++; + c++; + } + } +} + +/** + * @brief Finds the position of a pivot in a column, if any. + * @param A a RREF matrix + * @param col the column + * @return The index of a pivot in \p col, or -1 if non exists + * @note No validation is performed on any of the arguments. + */ +template +int getPivotInColumn(const Matrix& A, int col) { + const auto zero = ELEMENT::zero(); + int i = A.rows(); + while (i-- > 0) { + if (A(i, col) != zero) { + for (int k = 0; k < col - 1; ++k) { + if (A(i, k) != zero) { + return -1; + } + } + return i; + } + } + return -1; +} + +/** + * @brief Finds the first non-zero row, starting from the "bottom" of a matrix + * + * Starting from the bottom, finds the first row that is non-zero in a matrix. + * This function is used to determine rows that can be skipped when performing + * back substitution. + * + * @param A the matrix + * @return The first non-zero row + */ +template +std::size_t findFirstNonZeroRow(const Matrix& A) { + const auto zero = ELEMENT::zero(); + std::size_t nzr = A.rows(); + const auto m = A.cols(); + while (nzr-- > 0) { + bool non_zero = false; + for (std::size_t j = 0; j < m; ++j) { + if (A(nzr, j) != zero) { + non_zero = true; + break; + } + } + if (non_zero) { + break; + } + } + return nzr; +} + +/** + * @brief Extract a solution from a matrix in RREF. + * @param A the matrix + * @return the solution. + * + * Given a matrix \p A in RREF that represents a system of equations, this + * function extracts a solution for said system. The system is assumed to be + * consistent, but may contain free variables. In those cases, the free + * variables are given the value 1 (in the field). + */ +template +Vector extractSolution(const Matrix& A) { + const auto n = A.rows(); + const auto m = A.cols(); + + Vector x(m - 1); + auto i = findFirstNonZeroRow(A); + // we remove (n - i) rows, which means setting the corresponding variables to + // 0. + int c = m - 2 - (n - i - 1); + for (; c >= 0; c--) { + const auto p = getPivotInColumn(A, c); + if (p == -1) { + // a free variable just gets set to 1. + x[c] = ELEMENT{1}; + } else { + auto sum = ELEMENT::zero(); + for (std::size_t j = p + 1; j < n; ++j) { + sum += A(i, j) * x[j]; + } + x[c] = A(i, m - 1) - sum; + i--; + } + } + return x; +} // LCOV_EXCL_LINE + +/** + * @brief Check if a linear system has a solution. + * @param A an augmented matrix in RREF + * @param unique_only indicates if only unique solutions should be considered + * @return true if \f$Ax = b\f$ has a solution and false otherwise. + * + * Determines if a linear system given by an augmented matrix in RREF has a + * solution, aka. is consistent. if called with \p unique_only set to + * true, then only systems with a unique solution are considered. + */ +template +bool hasSolution(const Matrix& A, bool unique_only) { + auto n = A.rows(); + auto m = A.cols(); + ELEMENT zero; + for (std::size_t i = 0; i < n; ++i) { + bool all_zero = true; + for (std::size_t j = 0; j < m - 1; ++j) { + all_zero &= A(i, j) == zero; + } + // No solution is the case when Rank(A) != Rank(A') where A' is A without + // the last column (the augmentation). I.e., when row(A', i) == 0, but + // row(A, i) != 0. + if (unique_only) { + if (all_zero) { + return false; + } + } else { + if (all_zero && A(i, m - 1) != zero) { + return false; + } + } + } + return true; +} + +/** + * @brief Creates an augmented matrix from two matrices. + * @param A the first matrix + * @param B the second matrix + * @return A matrix aug [A | B]. + */ +template +Matrix createAugmentedMatrix(const Matrix& A, + const Matrix& B) { + auto n = A.rows(); + auto m = A.cols(); + auto k = B.cols(); + Matrix aug(n, m + k); + for (std::size_t i = 0; i < n; ++i) { + for (std::size_t j = 0; j < m; ++j) { + aug(i, j) = A(i, j); + } + for (std::size_t j = m; j < m + k; ++j) { + aug(i, j) = B(i, j - m); + } + } + return aug; +} + +/** + * @brief Create an augmented matrix from a matrix and a vector + * @param A a matrix + * @param b a vector + * @return A matrix aug [A | b]. + */ +template +Matrix createAugmentedMatrix(const Matrix& A, + const Vector& b) { + return createAugmentedMatrix(A, b.toColumnMatrix()); +} + +/** + * @brief Solves a linear system of equations \f$Ax = b\f$. + * @param x where to store the solution + * @param A the matrix of coefficients + * @param b the equation values + * @return true if a unique solution was found and false otherwise + * @throws std::invalid_argument if the number of rows in \p A does not + * match the size of \p b. + */ +template +bool solveLinearSystem(Vector& x, + const Matrix& A, + const Vector& b) { + if (A.rows() != b.size()) { + throw std::invalid_argument("malformed system of equations"); + } + + auto aug = createAugmentedMatrix(A, b); + + rowReduceInPlace(aug); + if (!hasSolution(aug, true)) { + return false; + } + + x = extractSolution(aug); + return true; +} + +template +Matrix Matrix::invert() const { + if (!isSquare()) { + throw std::invalid_argument("cannot invert non-square matrix"); + } + + const std::size_t n = cols(); + const Matrix id = identity(n); + + auto aug = createAugmentedMatrix(*this, id); + rowReduceInPlace(aug); + + Matrix inv(n, n); + for (std::size_t i = 0; i < n; ++i) { + for (std::size_t j = 0; j < n; ++j) { + inv(i, j) = aug(i, n + j); + } + } + + return inv; +} + +template +std::string Matrix::toString() const { + // this method converts a matrix to something like + // [ a0 a1 a2 ... ] + // [ b0 b1 .. ] + // Each column is aligned according to the largest element in that column. + + const auto n = rows(); + const auto m = cols(); + + if (!(n && m)) { + return "[ EMPTY MATRIX ]"; + } + + // convert all elements to strings and find the widest string in each column + // since that will be used to properly align the final output. + std::vector elements; + elements.reserve(n * m); + std::vector fills; + fills.reserve(m); + for (std::size_t j = 0; j < m; j++) { + auto first = operator()(0, j).toString(); + auto fill = first.size(); + elements.emplace_back(first); + for (std::size_t i = 1; i < n; i++) { + auto next = operator()(i, j).toString(); + auto next_fill = next.size(); + if (next_fill > fill) { + fill = next_fill; + } + elements.push_back(next); + } + fills.push_back(fill + 1); + } + + std::stringstream ss; + ss << "\n"; + for (std::size_t i = 0; i < n; i++) { + ss << "["; + for (std::size_t j = 0; j < m; j++) { + ss << std::setfill(' ') << std::setw(fills[j]) << elements[j * n + i] + << " "; + } + ss << "]"; + if (i < n - 1) { + ss << "\n"; + } + } + return ss.str(); +} + +} // namespace math + +namespace seri { + +/** + * @brief Serializer specialization for a math::Mat. + */ +template +struct Serializer> { + private: + // type used to denote a dimension + using DimType = std::uint32_t; + + // serializer for vector + using S_vec = Serializer>; + + // serializer for the dimension + using S_dim = Serializer; + + public: + /** + * @brief Size of a matrix. + * @param mat the matrix. + */ + static std::size_t sizeOf(const math::Matrix& mat) { + return S_vec::sizeOf(mat.m_values) + 2 * sizeof(DimType); + } + + /** + * @brief Write a matrix to a buffer. + * @param mat the matrix. + * @param buf the buffer. + */ + static std::size_t write(const math::Matrix& mat, + unsigned char* buf) { + std::size_t offset = 0; + offset = S_dim::write(static_cast(mat.rows()), buf); + offset += S_dim::write(static_cast(mat.cols()), buf + offset); + S_vec::write(mat.m_values, buf + offset); + return sizeOf(mat); + } + + /** + * @brief Read a matrix from a buffer. + * @param mat where to store the matrix after reading. + * @param buf the buffer. + * @return the number of bytes read. + */ + static std::size_t read(math::Matrix& mat, + const unsigned char* buf) { + DimType rows; + DimType cols; + std::size_t offset = 0; + offset = S_dim::read(rows, buf); + offset += S_dim::read(cols, buf + offset); + std::vector elements; + S_vec::read(elements, buf + offset); + mat = math::Matrix(rows, cols, std::move(elements)); + return sizeOf(mat); + } +}; + +} // namespace seri +} // namespace scl + +#endif // SCL_MATH_MATRIX_H diff --git a/include/scl/math/number.h b/include/scl/math/number.h index 8f1d752..9c9e4b0 100644 --- a/include/scl/math/number.h +++ b/include/scl/math/number.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -23,52 +23,44 @@ #include -#include "scl/math/ops.h" +#include "scl/serialization/serializer.h" #include "scl/util/prg.h" -namespace scl::math { +namespace scl { +namespace math { class Number; /** * @brief Compute the least common multiple of two numbers. - * @param a the first number. - * @param b the second number. * @return \f$lcm(a, b)\f$. */ -Number LCM(const Number& a, const Number& b); +Number lcm(const Number& a, const Number& b); /** * @brief Compute the greatest common divisor of two numbers. - * @param a the first number. - * @param b the second number. * @return \f$gcd(a, b)\f$. */ -Number GCD(const Number& a, const Number& b); +Number gcd(const Number& a, const Number& b); /** * @brief Compute the modular inverse of a number. - * @param val the value to invert. - * @param mod the modulus. * @return \f$val^{-1} \mod mod \f$. * @throws std::logic_error if \p val is not invertible. * @throws std::invalid_argument if \p mod is 0. */ -Number ModInverse(const Number& val, const Number& mod); +Number modInverse(const Number& val, const Number& mod); /** * @brief Compute a modular exponentiation. - * @param base the base. - * @param exp the exponent. - * @param mod the modulus. * @return \f$base^{exp} \mod mod\f$. */ -Number ModExp(const Number& base, const Number& exp, const Number& mod); +Number modExp(const Number& base, const Number& exp, const Number& mod); /** * @brief Arbitrary precision integer. */ -class Number final : Print { +class Number final { public: /** * @brief Generate a random Number. @@ -76,7 +68,7 @@ class Number final : Print { * @param prg a prg for generating the random number. * @return a random Number. */ - static Number Random(std::size_t bits, util::PRG& prg); + static Number random(std::size_t bits, util::PRG& prg); /** * @brief Generate a random prime. @@ -84,21 +76,21 @@ class Number final : Print { * @param prg a prg for generating the random prime. * @return a random prime. */ - static Number RandomPrime(std::size_t bits, util::PRG& prg); + static Number randomPrime(std::size_t bits, util::PRG& prg); /** * @brief Read a Number from a string * @param str the string * @return a Number. */ - static Number FromString(const std::string& str); + static Number fromString(const std::string& str); /** * @brief Read a number from a buffer. * @param buf the buffer. * @return a Number. */ - static Number Read(const unsigned char* buf); + static Number read(const unsigned char* buf); /** * @brief Construct a Number from an int. @@ -292,9 +284,9 @@ class Number final : Print { Number operator^(const Number& number) const; /** - * @brief operator |= - * @param number - * @return + * @brief operator |= for Number. + * @param number the other number. + * @return this OR'ed with \p number. */ Number& operator|=(const Number& number) { *this = *this | number; @@ -302,16 +294,16 @@ class Number final : Print { } /** - * @brief operator | - * @param number - * @return + * @brief operator | for Number. + * @param number the other Number. + * @return a number equal to *this OR'ed with \p number. */ Number operator|(const Number& number) const; /** - * @brief operator &= - * @param number - * @return + * @brief operator &= for Number. + * @param number the other Number. + * @return this AND'ed with \p number. */ Number& operator&=(const Number& number) { *this = *this & number; @@ -319,15 +311,15 @@ class Number final : Print { } /** - * @brief operator & - * @param number - * @return + * @brief operator & for Number. + * @param number the other Number. + * @return a number equal to this AND'ed with \p number. */ Number operator&(const Number& number) const; /** - * @brief operator ~ - * @return + * @brief operator ~ for Number + * @return the bitwise negation of this Number. */ Number operator~() const; @@ -342,59 +334,59 @@ class Number final : Print { * @param number the other number * @return a int indicating the relationship between this and \p number. */ - int Compare(const Number& number) const; + int compare(const Number& number) const; /** * @brief Equality of two numbers. */ friend bool operator==(const Number& lhs, const Number& rhs) { - return lhs.Compare(rhs) == 0; + return lhs.compare(rhs) == 0; } /** * @brief In-equality of two numbers. */ friend bool operator!=(const Number& lhs, const Number& rhs) { - return lhs.Compare(rhs) != 0; + return lhs.compare(rhs) != 0; } /** * @brief Strictly less-than of two numbers. */ friend bool operator<(const Number& lhs, const Number& rhs) { - return lhs.Compare(rhs) < 0; + return lhs.compare(rhs) < 0; } /** * @brief Less-than-or-equal of two numbers. */ friend bool operator<=(const Number& lhs, const Number& rhs) { - return lhs.Compare(rhs) <= 0; + return lhs.compare(rhs) <= 0; } /** * @brief Strictly greater-than of two numbers. */ friend bool operator>(const Number& lhs, const Number& rhs) { - return lhs.Compare(rhs) > 0; + return lhs.compare(rhs) > 0; } /** * @brief Greater-than-or-equal of two numbers. */ friend bool operator>=(const Number& lhs, const Number& rhs) { - return lhs.Compare(rhs) >= 0; + return lhs.compare(rhs) >= 0; } /** * @brief Get the size of this number in bytes. */ - std::size_t ByteSize() const; + std::size_t byteSize() const; /** * @brief Get the size of this Number in bits. */ - std::size_t BitSize() const; + std::size_t bitSize() const; /** * @brief Test whether a particular bit of this Number is set. @@ -406,35 +398,42 @@ class Number final : Print { * @param index the index of the bit * @return true if the bit at \p index is set and false otherwise. */ - bool TestBit(std::size_t index) const; + bool testBit(std::size_t index) const; /** * @brief Test if this Number is odd. * @return true if this Number is odd. */ - bool Odd() const { - return TestBit(0); + bool odd() const { + return testBit(0); } /** * @brief Test if this Number is even. * @return true if this Number is even. */ - bool Even() const { - return !Odd(); + bool even() const { + return !odd(); } /** * @brief Write this number to a buffer. * @param buf the buffer. */ - void Write(unsigned char* buf) const; + void write(unsigned char* buf) const; /** * @brief Return a string representation of this Number. * @return a string. */ - std::string ToString() const; + std::string toString() const; + + /** + * @brief Write a string representation of this Number to a stream. + */ + friend std::ostream& operator<<(std::ostream& os, const Number& number) { + return os << number.toString(); + } /** * @brief STL swap implementation for Number. @@ -447,14 +446,60 @@ class Number final : Print { private: mpz_t m_value; - friend Number LCM(const Number& a, const Number& b); - friend Number GCD(const Number& a, const Number& b); - friend Number ModInverse(const Number& val, const Number& mod); - friend Number ModExp(const Number& base, + friend Number lcm(const Number& a, const Number& b); + friend Number gcd(const Number& a, const Number& b); + friend Number modInverse(const Number& val, const Number& mod); + friend Number modExp(const Number& base, const Number& exp, const Number& mod); }; -} // namespace scl::math +} // namespace math + +namespace seri { + +/** + * @brief Serializer specialization for math::Number. + */ +template <> +struct Serializer { + /** + * @brief Get the serialized size of a math::Number. + * @param number the number. + * @return the serialized size of a math::Number. + * + * A math::Number is writte as size_and_sign | number where + * size_and_sign is a 4 byte value containing the byte size of + * the number and its sign. + */ + static std::size_t sizeOf(const math::Number& number) { + return number.byteSize() + sizeof(std::uint32_t); + } + + /** + * @brief Write a number to a buffer. + * @param number the number. + * @param buf the buffer. + * @return the number of bytes written. + */ + static std::size_t write(const math::Number& number, unsigned char* buf) { + number.write(buf); + return sizeOf(number); + } + + /** + * @brief Read a math::Number from a buffer. + * @param number the number. + * @param buf the buffer. + * @return the number of bytes read. + */ + static std::size_t read(math::Number& number, const unsigned char* buf) { + number = math::Number::read(buf); + return sizeOf(number); + } +}; + +} // namespace seri +} // namespace scl #endif // SCL_MATH_NUMBER_H diff --git a/include/scl/math/ops.h b/include/scl/math/ops.h deleted file mode 100644 index 5e92227..0000000 --- a/include/scl/math/ops.h +++ /dev/null @@ -1,123 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#ifndef SCL_MATH_OPS_H -#define SCL_MATH_OPS_H - -#include - -namespace scl::math { - -/** - * @brief Provides binary + and -, and unary - * - operators. - * - * Requires that the type \p T implements the += and - * -= operators, and a Negate function. - */ -template -struct Add { - /** - * @brief Add two elements and return their sum. - */ - friend T operator+(const T& lhs, const T& rhs) { - T temp(lhs); - return temp += rhs; - } - - /** - * @brief Subtract two elements and return their difference. - */ - friend T operator-(const T& lhs, const T& rhs) { - T temp(lhs); - return temp -= rhs; - } - - /** - * @brief Return the negation of an element. - */ - friend T operator-(const T& elem) { - T temp(elem); - return temp.Negate(); - } -}; - -/** - * @brief Provides * and / operators. - * - * Requires that \p T implements the *= and /= - * operators. - */ -template -struct Mul { - /** - * @brief Multiply two elements and return their product. - */ - friend T operator*(const T& lhs, const T& rhs) { - T temp(lhs); - return temp *= rhs; - } - - /** - * @brief Divide two elements and return their quotient. - */ - friend T operator/(const T& lhs, const T& rhs) { - T temp(lhs); - return temp /= rhs; - } -}; - -/** - * @brief Provides == and != operators. - * - * Requires that \p implements an Equal(T) function. - */ -template -struct Eq { - /** - * @brief Compare two elements for equality. - */ - friend bool operator==(const T& lhs, const T& rhs) { - return lhs.Equal(rhs); - } - - /** - * @brief Compare two elements for inequality. - */ - friend bool operator!=(const T& lhs, const T& rhs) { - return !(lhs == rhs); - } -}; - -/** - * @brief Provides << syntax for printing to a string. - * - * Requires that \p implements a ToString() function. - */ -template -struct Print { - /** - * @brief Write a string representation of an element to a stream. - */ - friend std::ostream& operator<<(std::ostream& os, const T& r) { - return os << r.ToString(); - } -}; - -} // namespace scl::math - -#endif // SCL_MATH_OPS_H diff --git a/include/scl/math/poly.h b/include/scl/math/poly.h index a36418c..64f11c3 100644 --- a/include/scl/math/poly.h +++ b/include/scl/math/poly.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -20,14 +20,14 @@ #include -#include "scl/math/vec.h" +#include "scl/math/vector.h" namespace scl::math { /** - * @brief Polynomials over finite fields. + * @brief Polynomials over rings. */ -template +template class Polynomial { public: /** @@ -35,7 +35,7 @@ class Polynomial { * @param coefficients the coefficients * */ - static Polynomial Create(const Vec& coefficients); + static Polynomial create(const Vector& coefficients); /** * @brief Construct a constant polynomial with constant term 0. @@ -46,14 +46,14 @@ class Polynomial { * @brief Construct a constant polynomial. * @param constant the constant term of the polynomial */ - Polynomial(const T& constant) : m_coefficients({constant}) {} + Polynomial(const RING& constant) : m_coefficients({constant}) {} /** * @brief Evaluate this polynomial on a supplied point. * @param x the point to evaluate this polynomial on * @return f(x) where \p x is the supplied point and f this polynomial. */ - T Evaluate(const T& x) const { + RING evaluate(const RING& x) const { auto it = m_coefficients.rbegin(); auto end = m_coefficients.rend(); auto y = *it++; @@ -66,14 +66,14 @@ class Polynomial { /** * @brief Access coefficients, with the constant term at position 0. */ - T& operator[](std::size_t idx) { + RING& operator[](std::size_t idx) { return m_coefficients[idx]; } /** * @brief Access coefficients, with the constant term at position 0. */ - T operator[](std::size_t idx) const { + RING operator[](std::size_t idx) const { return m_coefficients[idx]; } @@ -81,57 +81,57 @@ class Polynomial { * @brief Get the coefficients of this polynomial. * @return the coefficients. */ - Vec Coefficients() const { + Vector coefficients() const { return m_coefficients; } /** * @brief Add two polynomials. */ - Polynomial Add(const Polynomial& q) const; + Polynomial add(const Polynomial& q) const; /** * @brief Subtraction two polynomials. */ - Polynomial Subtract(const Polynomial& q) const; + Polynomial subtract(const Polynomial& q) const; /** * @brief Multiply two polynomials. */ - Polynomial Multiply(const Polynomial& q) const; + Polynomial multiply(const Polynomial& q) const; /** * @brief Divide two polynomials. * @return A pair \f$(q, r)\f$ such that \f$\mathtt{this} = p * q + r\f$. */ - std::array Divide(const Polynomial& q) const; + std::array divide(const Polynomial& q) const; /** * @brief Returns true if this is the 0 polynomial. */ - bool IsZero() const { - return Degree() == 0 && ConstantTerm() == T(0); + bool isZero() const { + return degree() == 0 && constantTerm() == RING(0); } /** * @brief Get the constant term of this polynomial. */ - T ConstantTerm() const { + RING constantTerm() const { return operator[](0); } /** * @brief Get the leading term of this polynomial. */ - T LeadingTerm() const { - return operator[](Degree()); + RING leadingTerm() const { + return operator[](degree()); } /** * @brief Degree of this polynomial. */ - std::size_t Degree() const { - return m_coefficients.Size() - 1; + std::size_t degree() const { + return m_coefficients.size() - 1; } /** @@ -151,49 +151,50 @@ class Polynomial { * @param polynomial_name the string to use for the name of the polynomial * @param variable_name the string to use for the variable name */ - std::string ToString(const char* polynomial_name, + std::string toString(const char* polynomial_name, const char* variable_name) const; /** * @brief Get a string representation of this polynomial. * @note Equivalent to ToString("f", "x"). */ - std::string ToString() const { - return ToString("f", "x"); + std::string toString() const { + return toString("f", "x"); } /** * @brief Write a string representation of this polynomial to a stream. */ friend std::ostream& operator<<(std::ostream& os, const Polynomial& p) { - return os << p.ToString(); + return os << p.toString(); } private: - Polynomial(const Vec& coefficients) : m_coefficients(coefficients){}; + Polynomial(const Vector& coefficients) : m_coefficients(coefficients){}; - Vec m_coefficients; + Vector m_coefficients; }; -template -Polynomial Polynomial::Create(const Vec& coefficients) { +template +Polynomial Polynomial::create(const Vector& coefficients) { auto it = coefficients.rbegin(); auto end = coefficients.rend(); - auto cutoff = coefficients.Size(); - T zero; + auto cutoff = coefficients.size(); + RING zero; for (; it != end; ++it) { if (*it != zero) { break; } --cutoff; } - const auto c = Vec(coefficients.begin(), coefficients.begin() + cutoff); + const auto c = + Vector(coefficients.begin(), coefficients.begin() + cutoff); - if (c.Empty()) { - return Polynomial{}; + if (c.empty()) { + return Polynomial{}; } - return Polynomial{c}; + return Polynomial{c}; } /** @@ -202,65 +203,65 @@ Polynomial Polynomial::Create(const Vec& coefficients) { * @param n the size of the final Vec * @return A scl::Vec of length \p n with coefficients of p and zeros. */ -template -Vec PadCoefficients(const Polynomial& p, std::size_t n) { - Vec c(n); +template +Vector padCoefficients(const Polynomial& p, std::size_t n) { + Vector c(n); for (std::size_t i = 0; i < n; ++i) { - if (i <= p.Degree()) { + if (i <= p.degree()) { c[i] = p[i]; } } return c; } // LCOV_EXCL_LINE -template -Polynomial Polynomial::Add(const Polynomial& q) const { - const auto this_larger = Degree() > q.Degree(); - const auto n = (this_larger ? Degree() : q.Degree()) + 1; - const auto pp = PadCoefficients(*this, n); - const auto qp = PadCoefficients(q, n); - const auto c = pp.Add(qp); - return Polynomial::Create(c); +template +Polynomial Polynomial::add(const Polynomial& q) const { + const auto this_larger = degree() > q.degree(); + const auto n = (this_larger ? degree() : q.degree()) + 1; + const auto pp = padCoefficients(*this, n); + const auto qp = padCoefficients(q, n); + const auto c = pp.add(qp); + return Polynomial::create(c); } -template -Polynomial Polynomial::Subtract(const Polynomial& q) const { - const auto this_larger = Degree() > q.Degree(); - const auto n = (this_larger ? Degree() : q.Degree()) + 1; - const auto pp = PadCoefficients(*this, n); - const auto qp = PadCoefficients(q, n); - const auto c = pp.Subtract(qp); - return Polynomial::Create(c); +template +Polynomial Polynomial::subtract(const Polynomial& q) const { + const auto this_larger = degree() > q.degree(); + const auto n = (this_larger ? degree() : q.degree()) + 1; + const auto pp = padCoefficients(*this, n); + const auto qp = padCoefficients(q, n); + const auto c = pp.subtract(qp); + return Polynomial::create(c); } -template -Polynomial Polynomial::Multiply(const Polynomial& q) const { - Vec c(Degree() + q.Degree() + 1); - for (std::size_t i = 0; i <= Degree(); ++i) { - for (std::size_t j = 0; j <= q.Degree(); ++j) { +template +Polynomial Polynomial::multiply(const Polynomial& q) const { + Vector c(degree() + q.degree() + 1); + for (std::size_t i = 0; i <= degree(); ++i) { + for (std::size_t j = 0; j <= q.degree(); ++j) { c[i + j] += operator[](i) * q[j]; } } - return Polynomial::Create(c); + return Polynomial::create(c); } /** * @brief Divide the leading terms of two polynomials. * @note assumes that deg(p) >= deg(q). */ -template -Polynomial DivideLeadingTerms(const Polynomial& p, - const Polynomial& q) { - const auto deg_out = p.Degree() - q.Degree(); - Vec c(deg_out + 1); - c[deg_out] = p.LeadingTerm() / q.LeadingTerm(); - return Polynomial::Create(c); +template +Polynomial divideLeadingTerms(const Polynomial& p, + const Polynomial& q) { + const auto deg_out = p.degree() - q.degree(); + Vector c(deg_out + 1); + c[deg_out] = p.leadingTerm() / q.leadingTerm(); + return Polynomial::create(c); } -template -std::array, 2> Polynomial::Divide( - const Polynomial& q) const { - if (q.IsZero()) { +template +std::array, 2> Polynomial::divide( + const Polynomial& q) const { + if (q.isZero()) { throw std::invalid_argument("division by 0"); } @@ -268,20 +269,20 @@ std::array, 2> Polynomial::Divide( Polynomial p; Polynomial r = *this; - while (!r.IsZero() && r.Degree() >= q.Degree()) { - const auto t = DivideLeadingTerms(r, q); - p = p.Add(t); - r = r.Subtract(t.Multiply(q)); + while (!r.isZero() && r.degree() >= q.degree()) { + const auto t = divideLeadingTerms(r, q); + p = p.add(t); + r = r.subtract(t.multiply(q)); } return {p, r}; } -template -std::string Polynomial::ToString(const char* polynomial_name, - const char* variable_name) const { +template +std::string Polynomial::toString(const char* polynomial_name, + const char* variable_name) const { std::stringstream ss; ss << polynomial_name << "(" << variable_name << ") = " << m_coefficients[0]; - for (std::size_t i = 1; i < m_coefficients.Size(); i++) { + for (std::size_t i = 1; i < m_coefficients.size(); i++) { ss << " + " << m_coefficients[i] << variable_name; if (i > 1) { ss << "^" << i; diff --git a/include/scl/math/vec.h b/include/scl/math/vector.h similarity index 58% rename from include/scl/math/vec.h rename to include/scl/math/vector.h index 033abfd..342ba1e 100644 --- a/include/scl/math/vec.h +++ b/include/scl/math/vector.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,8 +15,8 @@ * along with this program. If not, see . */ -#ifndef SCL_MATH_VEC_H -#define SCL_MATH_VEC_H +#ifndef SCL_MATH_VECTOR_H +#define SCL_MATH_VECTOR_H #include #include @@ -26,22 +26,24 @@ #include #include +#include "scl/serialization/serializer.h" #include "scl/util/prg.h" -#include "scl/util/traits.h" -namespace scl::math { +namespace scl { +namespace math { -template -class Mat; +template +class Matrix; /** - * @brief Computes an unchecked inner product between two iterators. - * @param xb start of the first iterator - * @param xe end of the first iterator - * @param yb start of the second iterator + * @brief Computes an inner product between two iterators. + * @param xb start of the first iterator. + * @param xe end of the first iterator. + * @param yb start of the second iterator. + * @return the inner product. */ -template -T UncheckedInnerProd(I0 xb, I0 xe, I1 yb) { +template +T innerProd(IT0 xb, IT0 xe, IT1 yb) { T v; while (xb != xe) { v += *xb++ * *yb++; @@ -55,13 +57,15 @@ T UncheckedInnerProd(I0 xb, I0 xe, I1 yb) { * This class is a thin wrapper around std::vector meant only to provide some * functionality that makes it behave like other classes present in SCUtil. */ -template -class Vec { +template +class Vector final { public: + friend struct seri::Serializer>; + /** * @brief The type of vector elements. */ - using ValueType = Elem; + using ValueType = ELEMENT; /** * @brief The type of a vector size. @@ -71,30 +75,23 @@ class Vec { /** * @brief Iterator type. */ - using iterator = typename std::vector::iterator; + using iterator = typename std::vector::iterator; /** * @brief Const iterator type. */ - using const_iterator = typename std::vector::const_iterator; + using const_iterator = typename std::vector::const_iterator; /** * @brief Reverse iterator type. */ - using reverse_iterator = typename std::vector::reverse_iterator; + using reverse_iterator = typename std::vector::reverse_iterator; /** * @brief Reverse const iterator type. */ using const_reverse_iterator = - typename std::vector::const_reverse_iterator; - - /** - * @brief Read a vec from a stream of bytes. - * @param n the number of elements to read - * @param src the buffer to read from - */ - static Vec Read(std::size_t n, const unsigned char* src); + typename std::vector::const_reverse_iterator; /** * @brief Create a Vec and populate it with random elements. @@ -102,7 +99,7 @@ class Vec { * @param prg a PRG used to generate random elements * @return a Vec with random elements. */ - static Vec Random(std::size_t n, util::PRG& prg); + static Vector random(std::size_t n, util::PRG& prg); /** * @brief Create a vector with values in a range. @@ -110,45 +107,45 @@ class Vec { * @param end the end value, exclusive * @return a vector with values [start, start + 1, ..., end - 1]. */ - static Vec Range(std::size_t start, std::size_t end); + static Vector range(std::size_t start, std::size_t end); /** * @brief Create a vector with values in a range. * @param end the end value, exclusive. * @return a vector with values [0, ..., end - 1]. */ - static Vec Range(std::size_t end) { - return Range(0, end); + static Vector range(std::size_t end) { + return range(0, end); } /** * @brief Default constructor that creates an empty Vec. */ - Vec() {} + Vector() {} /** * @brief Construct a new Vec of explicit size. * @param n the size */ - explicit Vec(std::size_t n) : m_values(n) {} + explicit Vector(std::size_t n) : m_values(n) {} /** * @brief Construct a vector from an initializer_list. * @param values an initializer_list */ - Vec(std::initializer_list values) : m_values(values) {} + Vector(std::initializer_list values) : m_values(values) {} /** * @brief Construct a vector from an STL vector. * @param values an STL vector */ - Vec(const std::vector& values) : m_values(values) {} + Vector(const std::vector& values) : m_values(values) {} /** * @brief Move construct a vector from an STL vector. * @param values an STL vector */ - Vec(std::vector&& values) : m_values(std::move(values)) {} + Vector(std::vector&& values) : m_values(std::move(values)) {} /** * @brief Construct a Vec from a pair of iterators. @@ -156,34 +153,34 @@ class Vec { * @param last iterator pointing to the one past last element * @tparam It iterator type */ - template - explicit Vec(It first, It last) : m_values(first, last) {} + template + explicit Vector(IT first, IT last) : m_values(first, last) {} /** * @brief The size of the Vec. */ - SizeType Size() const { + SizeType size() const { return m_values.size(); } /** * @brief Check if this Vec is empty. */ - bool Empty() const { - return Size() == 0; + bool empty() const { + return size() == 0; } /** * @brief Mutable access to vector elements. */ - Elem& operator[](std::size_t idx) { + ELEMENT& operator[](std::size_t idx) { return m_values[idx]; } /** * @brief Read only access to vector elements. */ - Elem operator[](std::size_t idx) const { + ELEMENT operator[](std::size_t idx) const { return m_values[idx]; } @@ -192,16 +189,16 @@ class Vec { * @param other the other vector * @return the sum of this and \p other. */ - Vec Add(const Vec& other) const; + Vector add(const Vector& other) const; /** * @brief Add two Vec objects entry-wise in-place. * @param other the other vector * @return the sum of this and \p other, assigned to this. */ - Vec& AddInPlace(const Vec& other) { - EnsureCompatible(other); - for (std::size_t i = 0; i < Size(); i++) { + Vector& addInPlace(const Vector& other) { + ensureCompatible(other); + for (std::size_t i = 0; i < size(); i++) { m_values[i] += other.m_values[i]; } return *this; @@ -212,16 +209,16 @@ class Vec { * @param other the other vector * @return the difference of this and \p other. */ - Vec Subtract(const Vec& other) const; + Vector subtract(const Vector& other) const; /** * @brief Subtract two Vec objects entry-wise in-place. * @param other the other vector * @return the difference of this and \p other, assigned to this. */ - Vec& SubtractInPlace(const Vec& other) { - EnsureCompatible(other); - for (std::size_t i = 0; i < Size(); i++) { + Vector& subtractInPlace(const Vector& other) { + ensureCompatible(other); + for (std::size_t i = 0; i < size(); i++) { m_values[i] -= other.m_values[i]; } return *this; @@ -232,16 +229,16 @@ class Vec { * @param other the other vector * @return the product of this and \p other. */ - Vec MultiplyEntryWise(const Vec& other) const; + Vector multiplyEntryWise(const Vector& other) const; /** * @brief Multiply two Vec objects entry-wise in-place. * @param other the other vector * @return the product of this and \p other, assigned to this. */ - Vec& MultiplyEntryWiseInPlace(const Vec& other) { - EnsureCompatible(other); - for (std::size_t i = 0; i < Size(); i++) { + Vector& multiplyEntryWiseInPlace(const Vector& other) { + ensureCompatible(other); + for (std::size_t i = 0; i < size(); i++) { m_values[i] *= other.m_values[i]; } return *this; @@ -252,17 +249,17 @@ class Vec { * @param other the other vector * @return the dot (or inner) product of this and \p other. */ - Elem Dot(const Vec& other) const { - EnsureCompatible(other); - return UncheckedInnerProd(begin(), end(), other.begin()); + ELEMENT dot(const Vector& other) const { + ensureCompatible(other); + return innerProd(begin(), end(), other.begin()); } /** * @brief Compute the sum over entries of this vector. * @return the sum of the entries of this vector. */ - Elem Sum() const { - Elem sum; + ELEMENT sum() const { + ELEMENT sum; for (const auto& v : m_values) { sum += v; } @@ -274,16 +271,17 @@ class Vec { * @param scalar the scalar * @return a scaled version of this vector. */ - template < - typename Scalar, - std::enable_if_t::value, bool> = true> - Vec ScalarMultiply(const Scalar& scalar) const { - std::vector r; - r.reserve(Size()); + template + requires requires(const ELEMENT& e, const SCALAR& s) { + { (e) * (s) } -> std::convertible_to; + } + Vector scalarMultiply(const SCALAR& scalar) const { + std::vector r; + r.reserve(size()); for (const auto& v : m_values) { r.emplace_back(scalar * v); } - return Vec(r); + return Vector(r); } /** @@ -291,10 +289,11 @@ class Vec { * @param scalar the scalar * @return a scaled version of this vector. */ - template < - typename Scalar, - std::enable_if_t::value, bool> = true> - Vec& ScalarMultiplyInPlace(const Scalar& scalar) { + template + requires requires(ELEMENT& e, const SCALAR& s) { + { e *= s } -> std::convertible_to; + } + Vector& scalarMultiplyInPlace(const SCALAR& scalar) { for (auto& v : m_values) { v *= scalar; } @@ -306,47 +305,47 @@ class Vec { * @param other the other vector * @return true if this vector is equal to \p other and false otherwise. */ - bool Equals(const Vec& other) const; + bool equals(const Vector& other) const; /** * @brief Operator == overload for Vec. */ - friend bool operator==(const Vec& left, const Vec& right) { - return left.Equals(right); + friend bool operator==(const Vector& left, const Vector& right) { + return left.equals(right); } /** * @brief Operator != overload for Vec. */ - friend bool operator!=(const Vec& left, const Vec& right) { + friend bool operator!=(const Vector& left, const Vector& right) { return !(left == right); } /** * @brief Convert this vector into a 1-by-N row matrix. */ - Mat ToRowMatrix() const { - return Mat{1, Size(), m_values}; + Matrix toRowMatrix() const { + return Matrix{1, size(), m_values}; } /** * @brief Convert this vector into a N-by-1 column matrix. */ - Mat ToColumnMatrix() const { - return Mat{Size(), 1, m_values}; + Matrix toColumnMatrix() const { + return Matrix{size(), 1, m_values}; } /** * @brief Convert this Vec object to an std::vector. */ - std::vector& ToStlVector() { + std::vector& toStlVector() { return m_values; } /** * @brief Convert this Vec object to a const std::vector. */ - const std::vector& ToStlVector() const { + const std::vector& toStlVector() const { return m_values; } @@ -356,48 +355,42 @@ class Vec { * @param end the end index, exclusive * @return a sub-vector. */ - Vec SubVector(std::size_t start, std::size_t end) const { + Vector subVector(std::size_t start, std::size_t end) const { if (start > end) { throw std::logic_error("invalid range"); } - return Vec(begin() + start, begin() + end); + return Vector(begin() + start, begin() + end); } /** * @brief Extract a sub-vector. * - * This method is equivalent to Vec#SubVector(0, end). + * This method is equivalent to subVector(0, end). * * @param end the end index, exclusive * @return a sub-vector. */ - Vec SubVector(std::size_t end) const { - return SubVector(0, end); + Vector subVector(std::size_t end) const { + return subVector(0, end); } /** * @brief Return a string representation of this vector. */ - std::string ToString() const; + std::string toString() const; /** * @brief Write a string representation of this vector to a stream. */ - friend std::ostream& operator<<(std::ostream& os, const Vec& v) { - return os << v.ToString(); + friend std::ostream& operator<<(std::ostream& os, const Vector& v) { + return os << v.toString(); } - /** - * @brief Write this Vec to a buffer. - * @param dest the buffer to write this Vec to - */ - void Write(unsigned char* dest) const; - /** * @brief Returns the number of bytes that Write writes. */ - std::size_t ByteSize() const { - return Elem::ByteSize(); + std::size_t byteSize() const { + return size() * ELEMENT::byteSize(); } /** @@ -485,132 +478,157 @@ class Vec { } private: - void EnsureCompatible(const Vec& other) const { - if (Size() != other.Size()) { + void ensureCompatible(const Vector& other) const { + if (size() != other.size()) { throw std::invalid_argument("Vec sizes mismatch"); } } - std::vector m_values; + std::vector m_values; }; -template -Vec Vec::Read(std::size_t n, const unsigned char* src) { - std::vector r; - r.reserve(n); - std::size_t offset = 0; - while (n-- > 0) { - r.emplace_back(Elem::Read(src + offset)); - offset += Elem::ByteSize(); - } - return Vec(r); -} - -template -Vec Vec::Range(std::size_t start, std::size_t end) { +template +Vector Vector::range(std::size_t start, std::size_t end) { if (start > end) { throw std::invalid_argument("invalid range"); } if (start == end) { - return Vec{}; + return Vector{}; } - std::vector v; + std::vector v; v.reserve(end - start); for (std::size_t i = start; i < end; ++i) { - v.emplace_back(Elem{(int)i}); + v.emplace_back(ELEMENT{(int)i}); } - return Vec(v); + return Vector(v); } -template -Vec Vec::Random(std::size_t n, util::PRG& prg) { - auto buf = std::make_unique(n * Elem::ByteSize()); - prg.Next(buf.get(), n * Elem::ByteSize()); +template +Vector Vector::random(std::size_t n, util::PRG& prg) { + auto buf = std::make_unique(n * ELEMENT::byteSize()); + prg.next(buf.get(), n * ELEMENT::byteSize()); - std::vector elements; + std::vector elements; elements.reserve(n); for (std::size_t i = 0; i < n; ++i) { - elements.emplace_back(Elem::Read(buf.get() + i * Elem::ByteSize())); + elements.emplace_back(ELEMENT::read(buf.get() + i * ELEMENT::byteSize())); } - return Vec(elements); + return Vector(elements); } -template -Vec Vec::Add(const Vec& other) const { - EnsureCompatible(other); - std::vector r; - auto n = Size(); +template +Vector Vector::add(const Vector& other) const { + ensureCompatible(other); + std::vector r; + auto n = size(); r.reserve(n); for (std::size_t i = 0; i < n; i++) { r.emplace_back(m_values[i] + other.m_values[i]); } - return Vec(r); + return Vector(r); } -template -Vec Vec::Subtract(const Vec& other) const { - EnsureCompatible(other); - std::vector r; - auto n = Size(); +template +Vector Vector::subtract(const Vector& other) const { + ensureCompatible(other); + std::vector r; + auto n = size(); r.reserve(n); for (std::size_t i = 0; i < n; i++) { r.emplace_back(m_values[i] - other.m_values[i]); } - return Vec(r); + return Vector(r); } -template -Vec Vec::MultiplyEntryWise(const Vec& other) const { - EnsureCompatible(other); - std::vector r; - auto n = Size(); +template +Vector Vector::multiplyEntryWise( + const Vector& other) const { + ensureCompatible(other); + std::vector r; + auto n = size(); r.reserve(n); for (std::size_t i = 0; i < n; i++) { r.emplace_back(m_values[i] * other.m_values[i]); } - return Vec(r); + return Vector(r); } -template -bool Vec::Equals(const Vec& other) const { - if (Size() != other.Size()) { +template +bool Vector::equals(const Vector& other) const { + if (size() != other.size()) { return false; } bool equal = true; - for (std::size_t i = 0; i < Size(); i++) { + for (std::size_t i = 0; i < size(); i++) { equal &= m_values[i] == other.m_values[i]; } return equal; } -template -std::string Vec::ToString() const { - if (Empty()) { +template +std::string Vector::toString() const { + if (empty()) { return "[ EMPTY VECTOR ]"; } std::stringstream ss; ss << "["; std::size_t i = 0; - for (; i < Size() - 1; i++) { + for (; i < size() - 1; i++) { ss << m_values[i] << ", "; } ss << m_values[i] << "]"; return ss.str(); } -template -void Vec::Write(unsigned char* dest) const { - for (const auto& v : m_values) { - v.Write(dest); - dest += Elem::ByteSize(); +} // namespace math + +namespace seri { // namespace seri + +/** + * @brief Serializer specialization for math::Vec. + */ +template +struct Serializer> { + private: + using S_vec = Serializer>; + + public: + /** + * @brief Size of a vector. + * @param vec the vector. + */ + static std::size_t sizeOf(const math::Vector& vec) { + return S_vec::sizeOf(vec.m_values); } -} -} // namespace scl::math + /** + * @brief Write a math::Vec to a buffer. + * @param vec the vector. + * @param buf the buffer. + */ + static std::size_t write(const math::Vector& vec, + unsigned char* buf) { + return S_vec::write(vec.m_values, buf); + } + + /** + * @brief Read a math::Vec from a buf. + * @param vec the vector. + * @param buf the buffer. + * @return the number of bytes read. + */ + static std::size_t read(math::Vector& vec, + const unsigned char* buf) { + return S_vec::read(vec.m_values, buf); + } +}; + +} // namespace seri +} // namespace scl -#endif // SCL_MATH_VEC_H +#endif // SCL_MATH_VECTOR_H diff --git a/include/scl/math/z2k.h b/include/scl/math/z2k.h index 30c2012..728da9b 100644 --- a/include/scl/math/z2k.h +++ b/include/scl/math/z2k.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -18,10 +18,10 @@ #ifndef SCL_MATH_Z2K_H #define SCL_MATH_Z2K_H +#include #include -#include "scl/math/ops.h" -#include "scl/math/z2k_ops.h" +#include "scl/math/z2k/z2k_ops.h" #include "scl/util/prg.h" namespace scl::math { @@ -35,84 +35,67 @@ namespace scl::math { * used. When elements of Z2k are serialized, they are padded to the nearest * byte (so Z2k<6> and Z2k<8> take up the same amount of space). */ -template -class Z2k final : Add>, - Mul>, - Eq>, - Print> { +template +class Z2k final { public: /** * @brief The raw type of a Z2k element. */ using ValueType = - std::conditional_t<(Bits <= 64), std::uint64_t, __uint128_t>; - - /** - * @brief The bit size of the ring. Identical to BitSize(). - */ - constexpr static std::size_t SpecifiedBitSize() { - return Bits; - } + std::conditional_t<(BITS <= 64), std::uint64_t, __uint128_t>; /** * @brief The number of bytes needed to store a ring element. */ - constexpr static std::size_t ByteSize() { - return (Bits - 1) / 8 + 1; + constexpr static std::size_t byteSize() { + return (BITS - 1) / 8 + 1; } /** - * @brief The bit size of the ring. Identical to SpecifiedBitSize(). + * @brief The bit size of the ring. */ - constexpr static std::size_t BitSize() { - return SpecifiedBitSize(); + constexpr static std::size_t bitSize() { + return BITS; } /** * @brief A short string representation of this ring. */ - constexpr static const char* Name() { + constexpr static const char* name() { return "Z2k"; } /** - * @brief Read a ring from a buffer. - * @param src the buffer - * @return a ring element. - * @note This method reads exactly ByteSize() bytes of \p src. + * @brief Read a Z2k element from a buffer. */ - static Z2k Read(const unsigned char* src) { + static Z2k read(const unsigned char* src) { Z2k e; - Z2kFromBytes(e.m_value, src); + z2k::fromBytes(e.m_value, src); return e; } /** * @brief Create a random element. - * @param prg a prg used to generate the random element - * @return a random element. */ - static Z2k Random(util::PRG& prg) { - unsigned char buffer[ByteSize()]; - prg.Next(buffer, ByteSize()); - return Z2k::Read(buffer); + static Z2k random(util::PRG& prg) { + unsigned char buffer[byteSize()]; + prg.next(buffer, byteSize()); + return Z2k::read(buffer); } /** * @brief Create a ring element from a string. - * @param str the string - * @return a ring element. */ - static Z2k FromString(const std::string& str) { + static Z2k fromString(const std::string& str) { Z2k e; - Z2kFromString(e.m_value, str); + z2k::convertIn(e.m_value, str); return e; } // LCOV_EXCL_LINE /** * @brief Get the additive identity of this ring. */ - static Z2k Zero() { + static Z2k zero() { static Z2k zero; return zero; } @@ -120,21 +103,20 @@ class Z2k final : Add>, /** * @brief Get the multiplicative identity of this ring. */ - static Z2k One() { + static Z2k one() { static Z2k one(1); return one; } /** * @brief Create a new ring element from a value. - * @param value the value */ explicit constexpr Z2k(const ValueType& value) : m_value(value) {} /** * @brief Create a new ring element equal to 0. */ - explicit constexpr Z2k() : m_value(0) {} + constexpr Z2k() : m_value(0) {} /** * @brief Destructor. Does nothing. @@ -143,108 +125,190 @@ class Z2k final : Add>, /** * @brief Add another element to this. - * @param other the other element - * @return this incremented by \p other. */ Z2k& operator+=(const Z2k& other) { - Z2kAdd(m_value, other.m_value); + z2k::add(m_value, other.m_value); return *this; } + /** + * @brief Add two Z2k elements together. + */ + friend Z2k operator+(const Z2k& lhs, const Z2k& rhs) { + Z2k tmp(lhs); + return tmp += rhs; + } + + /** + * @brief Pre-increment this element. + */ + Z2k& operator++() { + return *this += one(); + } + + /** + * @brief Post-increment this element. + */ + friend Z2k operator++(Z2k& e, int) { + Z2k tmp(e); + ++e; + return tmp; + } + /** * @brief Subtract another element from this. - * @param other the other element - * @return this decremented by \p other. */ Z2k& operator-=(const Z2k& other) { - Z2kSubtract(m_value, other.m_value); + z2k::subtract(m_value, other.m_value); return *this; } + /** + * @brief Subtract two Z2k elements from each other. + */ + friend Z2k operator-(const Z2k& lhs, const Z2k& rhs) { + Z2k tmp(lhs); + return tmp -= rhs; + } + + /** + * @brief Pre-decrement this element. + */ + Z2k& operator--() { + return *this -= one(); + } + + /** + * @brief Post-decrement this element. + */ + friend Z2k operator--(Z2k& e, int) { + Z2k tmp(e); + --e; + return tmp; + } + /** * @brief Multiply another element to this. - * @param other the other element - * @return this scaled by \p other. */ Z2k& operator*=(const Z2k& other) { - Z2kMultiply(m_value, other.m_value); + z2k::multiply(m_value, other.m_value); return *this; } + /** + * @brief Multiply two Z2k elements together. + */ + friend Z2k operator*(const Z2k& lhs, const Z2k& rhs) { + Z2k tmp(lhs); + return tmp *= rhs; + } + /** * @brief Divide this element by another. - * @param other the other element - * @return this element set to this / \p other. * @throws std::invalid_argument if \p other is not invertible. */ Z2k& operator/=(const Z2k& other) { - Z2kMultiply(m_value, other.Inverse().m_value); + z2k::multiply(m_value, other.inverse().m_value); return *this; } + /** + * @brief Divide two Z2k elements. + * @throws std::invalid_argument if \p other is not invertible. + */ + friend Z2k operator/(const Z2k& lhs, const Z2k& rhs) { + Z2k tmp(lhs); + return tmp /= rhs; + } + /** * @brief Negates this element. */ - Z2k& Negate() { - Z2kNegate(m_value); + Z2k& negate() { + z2k::negate(m_value); return *this; } /** * @brief Compute the negation of this element. */ - Z2k Negated() const { + Z2k negated() const { Z2k copy(m_value); - return copy.Negate(); + return copy.negate(); + } + + /** + * @brief Negate a Z2k element. + */ + friend Z2k operator-(const Z2k& e) { + return e.negated(); } /** * @brief Inverts this element. * @throws std::invalid_argument if this element is not invertible. */ - Z2k& Invert() { - Z2kInvert(m_value); + Z2k& invert() { + z2k::invert(m_value); return *this; } /** * @brief Compute the inverse of this element. + * @throws std::invalid_argument if this element is not invertible. */ - Z2k Inverse() const { + Z2k inverse() const { Z2k copy(m_value); - return copy.Invert(); + return copy.invert(); } /** * @brief Return the least significant bit of this element. - * - * This value can be used to determine if an element is invertible or not. In - * particular, an element x is invertible if x.Lsb() == - * 1. That is, if it is odd. */ - unsigned Lsb() const { - return Z2kLsb(m_value); + unsigned lsb() const { + return z2k::lsb(m_value); } /** * @brief Check if this element is equal to another element. */ - bool Equal(const Z2k& other) const { - return Z2kEqual(m_value, other.m_value); + bool equal(const Z2k& other) const { + return z2k::equal(m_value, other.m_value); + } + + /** + * @brief Equality operator for Z2k elements. + */ + friend bool operator==(const Z2k& lhs, const Z2k& rhs) { + return lhs.equal(rhs); + } + + /** + * @brief In-equality operator for Z2k elements. + */ + friend bool operator!=(const Z2k& lhs, const Z2k& rhs) { + return !(lhs == rhs); } /** * @brief Return a string representation of this element. */ - std::string ToString() const { - return Z2kToString(m_value); + std::string toString() const { + return z2k::toString(m_value); + } + + /** + * @brief Write a string representation of this element to stream. + */ + friend std::ostream& operator<<(std::ostream& os, const Z2k& e) { + return os << e.toString(); } /** * @brief Write this element to a buffer. */ - void Write(unsigned char* dest) const { - Z2kToBytes(m_value, dest); + void write(unsigned char* dest) const { + z2k::toBytes(m_value, dest); } private: diff --git a/include/scl/math/z2k_ops.h b/include/scl/math/z2k/z2k_ops.h similarity index 82% rename from include/scl/math/z2k_ops.h rename to include/scl/math/z2k/z2k_ops.h index e552270..6dfbbb9 100644 --- a/include/scl/math/z2k_ops.h +++ b/include/scl/math/z2k/z2k_ops.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,8 +15,8 @@ * along with this program. If not, see . */ -#ifndef SCL_MATH_Z2K_OPS_H -#define SCL_MATH_Z2K_OPS_H +#ifndef SCL_MATH_Z2K_Z2K_OPS_H +#define SCL_MATH_Z2K_Z2K_OPS_H #include #include @@ -25,13 +25,13 @@ #include "scl/util/str.h" -namespace scl::math { +namespace scl::math::z2k { /** * @brief Add two values modulo a power of 2 without normalization. */ template -void Z2kAdd(T& dst, const T& op) { +void add(T& dst, const T& op) { dst += op; } @@ -39,7 +39,7 @@ void Z2kAdd(T& dst, const T& op) { * @brief Subtract two values modulo a power of 2 without normalization. */ template -void Z2kSubtract(T& dst, const T& op) { +void subtract(T& dst, const T& op) { dst -= op; } @@ -47,7 +47,7 @@ void Z2kSubtract(T& dst, const T& op) { * @brief Multiply two values modulo a power of 2 without normalization. */ template -void Z2kMultiply(T& dst, const T& op) { +void multiply(T& dst, const T& op) { dst *= op; } @@ -55,7 +55,7 @@ void Z2kMultiply(T& dst, const T& op) { * @brief Negate a value modulo a power of 2 without normalization. */ template -void Z2kNegate(T& v) { +void negate(T& v) { v = -v; } @@ -63,7 +63,7 @@ void Z2kNegate(T& v) { * @brief Get the least significant bit of a value. */ template -unsigned Z2kLsb(T& v) { +unsigned lsb(T& v) { return v & 1; } @@ -77,8 +77,8 @@ unsigned Z2kLsb(T& v) { * @param v the value to invert */ template = true> -void Z2kInvert(T& v) { - if (!Z2kLsb(v)) { +void invert(T& v) { + if (!lsb(v)) { throw std::invalid_argument("value not invertible modulo 2^K"); } @@ -98,7 +98,7 @@ void Z2kInvert(T& v) { * @brief Compute equality modulo a power of 2. */ template = true> -bool Z2kEqual(const T& a, const T& b) { +bool equal(const T& a, const T& b) { return (a & SCL_MASK(T, K)) == (b & SCL_MASK(T, K)); } @@ -106,7 +106,7 @@ bool Z2kEqual(const T& a, const T& b) { * @brief Read a value from a buffer and truncate it to a power of 2. */ template = true> -void Z2kFromBytes(T& v, const unsigned char* src) { +void fromBytes(T& v, const unsigned char* src) { v = *(const T*)src; v &= SCL_MASK(T, K); } @@ -115,8 +115,8 @@ void Z2kFromBytes(T& v, const unsigned char* src) { * @brief Write a value modulo a power of 2 to a buffer. */ template = true> -void Z2kToBytes(const T& v, unsigned char* dest) { - // normalization is deferred until elements are written somewhere, so we v +void toBytes(const T& v, unsigned char* dest) { + // normalization is deferred until elements are written somewhere, so v // needs to be normalized before we can write it. auto w = v & SCL_MASK(T, K); std::memcpy(dest, (unsigned char*)&w, (K - 1) / 8 + 1); @@ -128,8 +128,8 @@ void Z2kToBytes(const T& v, unsigned char* dest) { * @param str the string */ template = true> -void Z2kFromString(T& v, const std::string& str) { - v = util::FromHexString(str); +void convertIn(T& v, const std::string& str) { + v = util::fromHexString(str); v &= SCL_MASK(T, K); } @@ -138,11 +138,13 @@ void Z2kFromString(T& v, const std::string& str) { * @param v the value */ template = true> -std::string Z2kToString(const T& v) { +std::string toString(const T& v) { auto w = v & SCL_MASK(T, K); - return util::ToHexString(w); + return util::toHexString(w); } -} // namespace scl::math +#undef SCL_MASK -#endif // SCL_MATH_Z2K_OPS_H +} // namespace scl::math::z2k + +#endif // SCL_MATH_Z2K_Z2K_OPS_H diff --git a/include/scl/net/channel.h b/include/scl/net/channel.h index 9d606f4..507f4c9 100644 --- a/include/scl/net/channel.h +++ b/include/scl/net/channel.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -18,79 +18,46 @@ #ifndef SCL_NET_CHANNEL_H #define SCL_NET_CHANNEL_H -#include -#include -#include -#include -#include - -#include "scl/net/config.h" +#include "scl/coro/task.h" #include "scl/net/packet.h" -#include "scl/util/traits.h" namespace scl::net { /** - * @brief Channel interface. - * - * Channel defines the interface for a channel between two peers, as well - * as a number of convenience methods for sending and receiving different kinds - * of data. To implement an actual channel, subclass Channel and implement - * the four virtual methods. - * - * @see InMemoryChannel - * @see TcpChannel + * @brief Peer-to-peer communication channel interface. */ class Channel { public: - virtual ~Channel(){}; + virtual ~Channel() {} /** * @brief Close connection to remote. */ - virtual void Close() = 0; - - /** - * @brief Send data to the remote party. - * @param src the data to send - * @param n the number of bytes to send - */ - virtual void Send(const unsigned char* src, std::size_t n) = 0; + virtual void close() = 0; /** - * @brief Receive data from the remote party. - * @param dst where to store the received data - * @param n how much data to receive - * @return how many bytes were received. + * @brief Send a data packet on the channel. + * @param packet the packet to send. */ - virtual std::size_t Recv(unsigned char* dst, std::size_t n) = 0; + virtual coro::Task send(Packet&& packet) = 0; /** - * @brief Check if there is something to receive on this channel. - * @return true if this channel has data and false otherwise. + * @brief Send a data packet on the channel. + * @param packet the packet to send. */ - virtual bool HasData() = 0; + virtual coro::Task send(const Packet& packet) = 0; /** - * @brief Send a Packet on this channel. - * @param packet the packet. - * - * The default implementation of this function makes two calls to Send. One - * for sending the size of the packet, and one for sendin the content. + * @brief Receive a data packet from on the channel. + * @return the received packet. */ - virtual void Send(const Packet& packet); + virtual coro::Task recv() = 0; /** - * @brief Receive a Packet from this channel. - * @param block whether to block until the packet has been received. - * @return a packet. May return nothing when \p block is true. - * - * The default implementation of this function will call Recv and HasData in - * case \p block is true. When receiving a packet in blocking mode, Recv is - * called immidiately. Otherwise, HasData will be called first to determine if - * there's any data available. + * @brief Check if there is something to receive on this channel. + * @return true if this channel has data and false otherwise. */ - virtual std::optional Recv(bool block = true); + virtual coro::Task hasData() = 0; }; } // namespace scl::net diff --git a/include/scl/net/config.h b/include/scl/net/config.h index 7a7afc6..244fcb9 100644 --- a/include/scl/net/config.h +++ b/include/scl/net/config.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -48,7 +48,7 @@ struct Party { std::string hostname; /** - * @brief The base port. + * @brief The port. */ std::size_t port; }; @@ -66,7 +66,7 @@ class NetworkConfig { * @param id the identity of this party * @param filename the filename */ - static NetworkConfig Load(std::size_t id, const std::string& filename); + static NetworkConfig load(std::size_t id, const std::string& filename); /** * @brief Create a network config where all parties are running locally. @@ -80,7 +80,7 @@ class NetworkConfig { * @param size the size of the network * @param port_base the base port */ - static NetworkConfig Localhost(std::size_t id, + static NetworkConfig localhost(std::size_t id, std::size_t size, std::size_t port_base); @@ -89,8 +89,8 @@ class NetworkConfig { * @param id the identity of this party * @param size the size of the network */ - static NetworkConfig Localhost(std::size_t id, std::size_t size) { - return NetworkConfig::Localhost(id, size, DEFAULT_PORT_OFFSET); + static NetworkConfig localhost(std::size_t id, std::size_t size) { + return NetworkConfig::localhost(id, size, DEFAULT_PORT_OFFSET); }; /** @@ -100,7 +100,7 @@ class NetworkConfig { */ NetworkConfig(std::size_t id, const std::vector& parties) : m_id(id), m_parties(parties) { - Validate(); + validate(); }; /** @@ -112,41 +112,36 @@ class NetworkConfig { /** * @brief Gets the identity of this party. */ - std::size_t Id() const { + std::size_t id() const { return m_id; }; /** * @brief Gets the size of the network. */ - std::size_t NetworkSize() const { + std::size_t networkSize() const { return m_parties.size(); }; /** * @brief Get a list of connection information for parties in this network. */ - std::vector Parties() const { + std::vector parties() const { return m_parties; }; /** * @brief Get information about a party. */ - Party GetParty(unsigned id) const { + Party party(unsigned id) const { return m_parties[id]; }; - /** - * @brief Return a string representation of this network config. - */ - std::string ToString() const; - private: - void Validate(); - std::size_t m_id; std::vector m_parties; + + void validate(); }; } // namespace scl::net diff --git a/include/scl/net/loopback.h b/include/scl/net/loopback.h new file mode 100644 index 0000000..9129890 --- /dev/null +++ b/include/scl/net/loopback.h @@ -0,0 +1,140 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_NET_LOOPBACK_H +#define SCL_NET_LOOPBACK_H + +#include +#include +#include + +#include "scl/coro/task.h" +#include "scl/net/channel.h" + +namespace scl::net { + +/** + * @brief A loopback channel. + * + * This Channel implementation defines a channel which connects to in-memory + * buffers. This channel is useful as a channel used by a party that talks with + * itself, or as a channel between two parties that are simply run in-memory. + */ +class LoopbackChannel final : public Channel { + public: + /** + * @brief The type of the internal buffer. + */ + using Buffer = std::deque; + + /** + * @brief Create a pair of connected channels. + * + * This creates two channels ch0 and ch1 such that + * any packet sent on ch0 can be received on ch1, + * and vice versa. + */ + static std::array, 2> createPaired() { + auto buf0 = std::make_shared(); + auto buf1 = std::make_shared(); + return {std::make_shared(buf0, buf1), + std::make_shared(buf1, buf0)}; + } + + /** + * @brief Create a loopback channel that connects to itself. + * + * A channel created with this function will receive anything that it sends. + */ + static std::shared_ptr create() { + auto buf = std::make_shared(); + return std::make_shared(buf, buf); + } + + /** + * @brief Construct a new loopback channel that is not connected to anything. + */ + LoopbackChannel() {} + + /** + * @brief Construct a new loopback channel from a pair of buffers. + */ + LoopbackChannel(std::shared_ptr in, std::shared_ptr out) + : m_in(in), m_out(out) {} + + /** + * @brief Close the channel. Does nothing. + */ + void close() override {} + + /** + * @brief Send a packet on this channel. + * + * This function will complete immediately once awaited. + */ + coro::Task send(Packet&& packet) override { + m_out->emplace_back(packet); + co_return; + } + + /** + * @brief Send a packet on this channel. + * + * This function will complete immediately once awaited. + */ + coro::Task send(const Packet& packet) override { + m_out->push_back(packet); + co_return; + } + + /** + * @brief Receive a packet on this channel. + * + * This function may suspend if there is no data yet. + */ + coro::Task recv() override { + // suspend in case there is no packets yet + co_await [this]() { return !this->m_in->empty(); }; + auto packet = m_in->front(); + m_in->pop_front(); + co_return packet; + } + + /** + * @brief Check if there are data available for receiving. + */ + coro::Task hasData() override { + co_return !m_in->empty(); + } + + /** + * @brief Get the size of the next packet. + * + * This function does not perform any checks on the incoming buffer. + */ + std::size_t getNextPacketSize() const { + return m_in->front().size(); + } + + private: + std::shared_ptr m_in; + std::shared_ptr m_out; +}; + +} // namespace scl::net + +#endif // SCL_NET_LOOPBACK_H diff --git a/include/scl/net/mem_channel.h b/include/scl/net/mem_channel.h deleted file mode 100644 index 8e06127..0000000 --- a/include/scl/net/mem_channel.h +++ /dev/null @@ -1,87 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#ifndef SCL_NET_MEM_CHANNEL_H -#define SCL_NET_MEM_CHANNEL_H - -#include -#include - -#include "scl/net/channel.h" -#include "scl/net/shared_deque.h" - -namespace scl::net { - -/** - * @brief Channel that communicates through in-memory buffers. - */ -class MemoryBackedChannel final : public Channel { - private: - using Buffer = SharedDeque>; - - public: - /** - * @brief Create a pair of paired channels. - * - * This method returns a pair of channels that shared their buffers such that - * what is sent on one can be retrieved on the other. - */ - static std::array, 2> CreatePaired() { - auto buf0 = std::make_shared(); - auto buf1 = std::make_shared(); - auto chl0 = std::make_shared(buf0, buf1); - auto chl1 = std::make_shared(buf1, buf0); - return {chl0, chl1}; - }; - - /** - * @brief Create a channel that sends to itself. - */ - static std::shared_ptr CreateLoopback() { - auto buf = std::make_shared(); - return std::make_shared(buf, buf); - } - - /** - * @brief Create a new channel that sends and receives on in-memory buffers. - * @param in_buffer the buffer to read incoming messages from - * @param out_buffer the buffer to put outgoing messages - */ - MemoryBackedChannel(std::shared_ptr in_buffer, - std::shared_ptr out_buffer) - : m_in(std::move(in_buffer)), m_out(std::move(out_buffer)){}; - - ~MemoryBackedChannel(){}; - - void Send(const unsigned char* src, std::size_t n) override; - std::size_t Recv(unsigned char* dst, std::size_t n) override; - - bool HasData() override { - return m_in->Size() > 0 || !m_overflow.empty(); - }; - - void Close() override{}; - - private: - std::shared_ptr m_in; - std::shared_ptr m_out; - std::vector m_overflow; -}; - -} // namespace scl::net - -#endif // SCL_NET_MEM_CHANNEL_H diff --git a/include/scl/net/net.h b/include/scl/net/net.h index a6d1e14..f9ea0e9 100644 --- a/include/scl/net/net.h +++ b/include/scl/net/net.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -19,10 +19,9 @@ #define SCL_NET_NET_H #include "scl/net/config.h" -#include "scl/net/mem_channel.h" +#include "scl/net/loopback.h" #include "scl/net/network.h" #include "scl/net/tcp_channel.h" -#include "scl/net/threaded_sender.h" /** * @brief %Network utilities. diff --git a/include/scl/net/network.h b/include/scl/net/network.h index 098232a..7f7ac04 100644 --- a/include/scl/net/network.h +++ b/include/scl/net/network.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -29,192 +29,188 @@ #include "scl/net/channel.h" #include "scl/net/config.h" -#include "scl/net/mem_channel.h" +#include "scl/net/loopback.h" #include "scl/net/sys_iface.h" #include "scl/net/tcp_channel.h" #include "scl/net/tcp_utils.h" namespace scl::net { +struct MockNetwork; + /** * @brief A Network. + * + *

A Network is effectively a list of Channel's with a bunch of helper + * functions and is the main interface that an MPC protocol will use to + * communicate with other parties. + * + *

Below is shown how a typical Beaver multiplication might be carried out + * using a Network object to communicate: + * + * @code + * Network nw = ... // initialize in some way + * + * for (int i = 0; i < nw.size(); i++) { + * Packet pkt = getDataToSend(); + * co_await nw.party(i)->send(pkt); + * } + * + * for (int i = 0; i < nw.size(); i++) { + * auto recvd = co_await nw.party(i)->recv(); + * processReceivedData(recvd); + * } + * @endcode */ class Network { public: /** * @brief Create a network using a network config. - * @param config the network configuration to use - * @tparam ChannelT the channel type + * @param config the network configuration to use. * - * This creates a network where parties are connected using private - * peer-to-peer channels. + * Creates a new network where the connection information about the parties of + * the network is read from a provided config. In the resulting network, the + * local party is connected to itself with a LoopbackChannel, and to everyone + * else with a TcpChannel. */ - template - static Network Create(const NetworkConfig& config); + static coro::Task create(const NetworkConfig& config); /** * @brief Create a new network. - * @param channels a list of peer-to-peer channels - * @param my_id the ID of the local party + * @param channels the list of channels in the network. + * @param id the ID of the local party */ - Network(const std::vector>& channels, - std::size_t my_id) - : m_channels(channels), m_id(my_id){}; + Network(const std::vector>& channels, std::size_t id) + : m_channels(channels), m_id(id){}; Network() = default; /** - * @brief Get a channel to a particular party. - * @param id the id of the party + * @brief Get a communication channel to some party. + * @param id the id of the party. + * @return the channel to party \p id. */ - Channel* Party(unsigned id) { + Channel* party(unsigned id) { return m_channels[id].get(); } /** * @brief Get the next party according to its ID. + * @return channel to the party with ID (myId() + 1) % size(). */ - Channel* Next() { - const auto next_id = m_id == Size() - 1 ? 0 : m_id + 1; + Channel* next() { + const auto next_id = m_id == size() - 1 ? 0 : m_id + 1; return m_channels[next_id].get(); } /** * @brief Get the previous party according to its ID. + * @return channel to the party with ID (myId() - 1) % size(). */ - Channel* Previous() { - const auto prev_id = m_id == 0 ? Size() - 1 : m_id - 1; + Channel* previous() { + const auto prev_id = m_id == 0 ? size() - 1 : m_id - 1; return m_channels[prev_id].get(); } /** * @brief Get the other party in the network. + * @return channel to the party with ID 1 - myId(). + * @throws std::logic error if the network contains more than two parties. * - * If the network has more than two parties then this method throws an - * std::logic_error as the concept of "other" party is ambigious in that case. + * This function is only meaningful for a two-party network. */ - Channel* Other() { - if (Size() != 2) { + Channel* other() { + if (size() != 2) { throw std::logic_error("other party ambiguous for more than 2 parties"); } return m_channels[1 - m_id].get(); } /** - * @brief The size of the network. + * @brief Get the channel to the local party. + * @return the channel to the party with ID equal to myId(). */ - std::size_t Size() const { - return m_channels.size(); - }; + Channel* me() { + return party(myId()); + } /** - * @brief The ID of the local party. + * @brief Send a packet to all parties on this network. + * @param packet the packet. + * + * This function is equivalent to + * @code + * for (int i = 0; i < size(); i++) { + * co_await party(i)->send(packet); + * } + * @endcode */ - std::size_t MyId() const { - return m_id; - }; + coro::Task send(const Packet& packet) { + for (std::size_t i = 0; i < size(); i++) { + co_await party(i)->send(packet); + } + } /** - * @brief Closes all channels in the network. + * @brief Receive data from a subset of parties. + * @param t the minimum number of parties to receive data from. + * @return list of received packets. + * + * Attempts to receive data from all parties, but stops when a Packet has been + * received from at least t parties. The return value is a std::vector of + * size() std::optional elements. Positions with no values correspond to + * parties that did not send anything. Thus, the return value will have at + * least \p t positions with values. */ - void Close() { - for (auto& c : m_channels) { - c->Close(); + coro::Task>> recv(std::size_t t) { + std::vector> recvs; + recvs.reserve(size()); + for (std::size_t i = 0; i < size(); i++) { + recvs.emplace_back(party(i)->recv()); } - }; - - private: - std::vector> m_channels; - std::size_t m_id; -}; + co_return co_await coro::batch(std::move(recvs), t); + } -/** - * @brief A fake network. Useful for testing. - */ -struct FakeNetwork { /** - * @brief Create a fake network of some size for a specific party. - * @param id the ID of the party owning the fake network - * @param n the size of the network - * @return a FakeNetwork. + * @brief Receive data from all parties on the network. + * @return list of received packets. */ - static FakeNetwork Create(unsigned id, std::size_t n); + coro::Task> recv() { + std::vector> recvs; + recvs.reserve(size()); + for (std::size_t i = 0; i < size(); i++) { + recvs.emplace_back(party(i)->recv()); + } + co_return co_await coro::batch(std::move(recvs)); + } /** - * @brief The ID of the party owning this fake network. + * @brief The number of parties in this network. */ - unsigned id; + std::size_t size() const { + return m_channels.size(); + }; /** - * @brief The network object held by the local party. + * @brief The ID of the local party. */ - Network my_network; + std::size_t myId() const { + return m_id; + }; /** - * @brief Channels that send data to the local party. - * - * The channel on index i != id of this list can be used to send - * data to the local party. The channel on index id is a - * nullptr. + * @brief Closes all channels in the network. */ - std::vector> incoming; -}; - -/** - * @brief Create a fully connected network that resides in memory. - * @param n the size of the network - * @return a fully connected network. - * - * This function creates a list of networks of \p n parties where each pair of - * parties are connected to eachother by a InMemoryChannel. - */ -std::vector CreateMemoryBackedNetwork(std::size_t n); - -template -Network Network::Create(const NetworkConfig& config) { - std::vector> channels(config.NetworkSize()); - - // connect to ourselves. - channels[config.Id()] = MemoryBackedChannel::CreateLoopback(); - - // This thread runs a server which accepts connections from all parties with - // an ID strictly greater than ours. - std::thread connector([&channels, &config]() { - const auto id = config.Id(); - - // the number of connections we should listen for. - const auto m = config.NetworkSize() - id - 1; - - if (m > 0) { - auto port = config.GetParty(id).port; - auto server_socket = CreateServerSocket<>((int)port, (int)m); - - for (std::size_t i = id + 1; i < config.NetworkSize(); ++i) { - auto conn = AcceptConnection(server_socket); - std::shared_ptr channel = - std::make_shared(conn.socket); - - auto p = channel->Recv().value(); - channels[p.Read()] = channel; - } - SysIFace::Close(server_socket); + void close() { + for (auto& c : m_channels) { + c->close(); } - }); - - Packet p; - for (std::size_t i = 0; i < config.Id(); ++i) { - const auto party = config.GetParty(i); - auto socket = ConnectAsClient<>(party.hostname, (int)party.port); - std::shared_ptr channel = std::make_shared(socket); - p << (unsigned)config.Id(); - channel->Send(p); - channels[i] = channel; - p.ResetWritePtr(); - } + }; - connector.join(); - return Network{channels, config.Id()}; -} + private: + std::vector> m_channels; + std::size_t m_id; +}; } // namespace scl::net diff --git a/include/scl/net/packet.h b/include/scl/net/packet.h index 55be7e5..fd5005c 100644 --- a/include/scl/net/packet.h +++ b/include/scl/net/packet.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -20,10 +20,11 @@ #include #include +#include #include +#include "scl/serialization/serializable.h" #include "scl/serialization/serializer.h" -#include "scl/serialization/serializers.h" namespace scl::net { @@ -37,17 +38,17 @@ namespace scl::net { * @code * net::Packet p; * p << (int)1234; // write an int to the packet. - * // p.Write((int)1234); // same as above. + * // p.write((int)1234); // same as above. * p << math::FF{42}; // write a FF element. * - * std::cout << p.Size() << std::endl; // size of a packet. This will be the + * std::cout << p.size() << std::endl; // size of a packet. This will be the * // size in bytes of the content written * // to it, so with the values written * // above, it will be something like * // sizeof(int) + math::FF::ByteSize() * - * auto v = p.Read(); // read the integer written above - * auto f = p.Read>(); // read the FF element. + * auto v = p.read(); // read the integer written above + * auto f = p.read>(); // read the FF element. * @endcode * *

Packet is essentially a struct with the following format @@ -63,8 +64,8 @@ namespace scl::net { */ class Packet { private: - // the internal buffer is resized using realloc(3) so it has to be allocated - // using malloc(3) and deallocated using free(3), as opposed to new/delete. + // the internal buffer is resized using realloc so it has to be allocated + // using malloc and deallocated using free, as opposed to new/delete. struct FreeDeleter { void operator()(void* p) const { std::free(p); @@ -88,12 +89,34 @@ class Packet { * @brief Construct a new packet. * @param initial_size the initial amount of bytes to allocate. */ - Packet(std::size_t initial_size = 2048) + Packet(std::size_t initial_size = 1024) : m_buffer(static_cast(std::malloc(initial_size))), m_cap(initial_size), m_read_ptr(0), m_write_ptr(0) {} + /** + * @brief Copy constructor. + * @param packet the packet to copy. + */ + Packet(const Packet& packet) + : m_buffer(static_cast(std::malloc(packet.m_cap))), + m_cap(packet.m_cap), + m_read_ptr(packet.m_read_ptr), + m_write_ptr(packet.m_write_ptr) { + std::memcpy(m_buffer.get(), packet.m_buffer.get(), packet.m_write_ptr); + } + + /** + * @brief Copy assignment. + * @param packet the packet to copy into this. + * @return this. + */ + Packet& operator=(Packet packet) { + swap(*this, packet); + return *this; + }; + /** * @brief Read an object from the packet. * @tparam T the type of the object to read. @@ -102,10 +125,10 @@ class Packet { * util::Serializer. A specialization of util::Serialization for \p T must * therefore exist. */ - template - T Read() { + template + T read() { T v; - const auto sz = seri::Serializer::Read(v, Get() + m_read_ptr); + const auto sz = seri::Serializer::read(v, get() + m_read_ptr); m_read_ptr += sz; return v; } // LCOV_EXCL_LINE @@ -113,16 +136,31 @@ class Packet { /** * @brief Write an object to this packet. * @param obj the object to read. + * @return the number of bytes written. * * This function writes \p obj using an util::Serializer. Calling this * function may also result in the internal buffer being resized. */ - template - void Write(const T& obj) { - const auto sz = seri::Serializer::SizeOf(obj); - ReserveSpace(sz); - seri::Serializer::Write(obj, Get() + m_write_ptr); + template + std::size_t write(const T& obj) { + const auto sz = seri::Serializer::sizeOf(obj); + reserveSpace(sz); + seri::Serializer::write(obj, get() + m_write_ptr); m_write_ptr += sz; + return sz; + } + + /** + * @brief Append the content of another packet to this. + * @param obj the other packet. + * @return the number of bytes written to this packet. + */ + std::size_t write(const Packet& obj) { + const auto sz = obj.size(); + reserveSpace(sz); + std::copy(obj.get(), obj.get() + sz, get() + m_write_ptr); + m_write_ptr += sz; + return sz; } /** @@ -130,28 +168,35 @@ class Packet { */ template friend Packet& operator<<(Packet& packet, const T& thing) { - packet.Write(thing); + packet.write(thing); return packet; } /** * @brief The size of a packet. */ - SizeType Size() const { + SizeType size() const { return m_write_ptr; } + /** + * @brief Get the number of unread bytes of this packet. + */ + SizeType remaining() const { + return size() - m_read_ptr; + } + /** * @brief Get a raw const pointer to the content of this packet. */ - const unsigned char* Get() const { + const unsigned char* get() const { return m_buffer.get(); } /** * @brief Get a raw pointer to the conte of this packet. */ - unsigned char* Get() { + unsigned char* get() { return m_buffer.get(); } @@ -163,7 +208,7 @@ class Packet { * new_write_ptr > Size() may result in reading outside the * internal buffer, so don't do this. */ - void SetWritePtr(std::ptrdiff_t new_write_ptr) { + void setWritePtr(std::ptrdiff_t new_write_ptr) { m_write_ptr = new_write_ptr; m_read_ptr = std::min(m_write_ptr, m_read_ptr); } @@ -174,8 +219,8 @@ class Packet { * Since calls to Read are stateful, this function can be used to re-read the * content of a packet. */ - void ResetWritePtr() { - SetWritePtr(0); + void resetWritePtr() { + setWritePtr(0); } /** @@ -185,7 +230,7 @@ class Packet { * used skipping objects, or re-reading objects. Only valid for * new_read_ptr < Size(). */ - void SetReadPtr(std::ptrdiff_t new_read_ptr) { + void setReadPtr(std::ptrdiff_t new_read_ptr) { m_read_ptr = new_read_ptr; } @@ -195,12 +240,50 @@ class Packet { * Like reads, writes are also stateful. This function allows reusing an * existing packet. */ - void ResetReadPtr() { - SetReadPtr(0); + void resetReadPtr() { + setReadPtr(0); + } + + /** + * @brief Compare two packets with respect to their content. + * @return true if the two packets contain the same data, false otherwise. + */ + friend bool operator==(const Packet& lhs, const Packet& rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + + const std::size_t n = lhs.size(); + for (std::size_t i = 0; i < n; i++) { + if (lhs.m_buffer[i] != rhs.m_buffer[i]) { + return false; + } + } + return true; + } + + /** + * @brief Swap the content of two packets. + * @param first the first packet. + * @param second the second packet. + */ + friend void swap(Packet& first, Packet& second) { + using std::swap; + swap(first.m_buffer, second.m_buffer); + swap(first.m_cap, second.m_cap); + swap(first.m_read_ptr, second.m_read_ptr); + swap(first.m_write_ptr, second.m_write_ptr); } private: - void ResizeBuffer(std::size_t new_size) { + Buffer m_buffer; + std::size_t m_cap; + std::ptrdiff_t m_read_ptr; + std::ptrdiff_t m_write_ptr; + + // Resize the internal buffer to some new size. This may move the entire + // packet to somewhere else in memory. + void resizeBuffer(std::size_t new_size) { auto* buf = m_buffer.release(); auto* buf_new = std::realloc(buf, new_size); if (buf_new == nullptr) { @@ -213,18 +296,16 @@ class Packet { m_cap = new_size; } - void ReserveSpace(std::size_t obj_size) { + // reserves enough space for some object. This function calculates the minimum + // amount of space needed to store an object of some size, and then tries to + // resize the internal buffer to twice that size. + void reserveSpace(std::size_t obj_size) { const auto min_size = obj_size + m_write_ptr; if (min_size > m_cap) { const auto new_size = std::max(min_size, 2 * m_cap); - ResizeBuffer(new_size); + resizeBuffer(new_size); } } - - Buffer m_buffer; - std::size_t m_cap; - std::ptrdiff_t m_read_ptr; - std::ptrdiff_t m_write_ptr; }; } // namespace scl::net diff --git a/include/scl/net/shared_deque.h b/include/scl/net/shared_deque.h deleted file mode 100644 index 57dbe30..0000000 --- a/include/scl/net/shared_deque.h +++ /dev/null @@ -1,126 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#ifndef SCL_NET_SHARED_DEQUE_H -#define SCL_NET_SHARED_DEQUE_H - -#include -#include -#include - -namespace scl::net { - -/** - * @brief A simple thread safe double-ended queue. - * - * Based on https://codereview.stackexchange.com/q/238347 - */ -template > -class SharedDeque { - public: - /** - * @brief Remove the top element from the queue. - */ - void PopFront(); - - /** - * @brief Read the top element from the queue. - */ - T& Peek(); - - /** - * @brief Remove and return the top element from the queue. - */ - T Pop(); - - /** - * @brief Insert an item to the back of the queue. - */ - void PushBack(const T& item); - - /** - * @brief Move an item to the back of the queue. - */ - void PushBack(T&& item); - - /** - * @brief Number of elements currently in the queue. - */ - std::size_t Size(); - - private: - std::deque m_deck; - std::mutex m_mutex; - std::condition_variable m_cond; -}; - -template -void SharedDeque::PopFront() { - std::unique_lock lock(m_mutex); - while (m_deck.empty()) { - m_cond.wait(lock); - } - m_deck.pop_front(); -} - -template -T& SharedDeque::Peek() { - std::unique_lock lock(m_mutex); - while (m_deck.empty()) { - m_cond.wait(lock); - } - return m_deck.front(); -} - -template -T SharedDeque::Pop() { - std::unique_lock lock(m_mutex); - while (m_deck.empty()) { - m_cond.wait(lock); - } - auto x = m_deck.front(); - m_deck.pop_front(); - return x; -} - -template -void SharedDeque::PushBack(const T& item) { - std::unique_lock lock(m_mutex); - m_deck.push_back(item); - lock.unlock(); - m_cond.notify_one(); -} - -template -void SharedDeque::PushBack(T&& item) { - std::unique_lock lock(m_mutex); - m_deck.push_back(std::move(item)); - lock.unlock(); - m_cond.notify_one(); -} - -template -std::size_t SharedDeque::Size() { - std::unique_lock lock(m_mutex); - auto size = m_deck.size(); - lock.unlock(); - return size; -} - -} // namespace scl::net - -#endif // SCL_NET_SHARED_DEQUE_H diff --git a/include/scl/net/sys_iface.h b/include/scl/net/sys_iface.h index aca715a..328934f 100644 --- a/include/scl/net/sys_iface.h +++ b/include/scl/net/sys_iface.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -21,12 +21,13 @@ #include #include +#include #include #include #include #include -namespace scl::net { +namespace scl::net::details { /** * @brief System call wrapper. @@ -38,21 +39,28 @@ struct SysIFace { /** * @brief See man 3 errno. */ - static auto GetError() { + static auto getError() { return errno; } /** * @brief See man 2 socket. */ - static auto Socket(int domain, int type, int protocol) { + static auto socket(int domain, int type, int protocol) { return ::socket(domain, type, protocol); } + /** + * @brief See man 2 fcntl. + */ + static auto fcntl(int fd, int cmd, int flags) { + return ::fcntl(fd, cmd, flags); + } + /** * @brief See man 2 setsockopt. */ - static auto SetSockOpt(int sockfd, + static auto setSockOpt(int sockfd, int level, int optname, const void* optval, @@ -63,49 +71,49 @@ struct SysIFace { /** * @brief See man htons. */ - static auto HostToNet(short hostshort) { + static auto hostToNet(short hostshort) { return ::htons(hostshort); } /** * @brief See man 2 bind. */ - static auto Bind(int sockfd, const struct sockaddr* addr, socklen_t addrlen) { + static auto bind(int sockfd, const struct sockaddr* addr, socklen_t addrlen) { return ::bind(sockfd, addr, addrlen); } /** * @brief See man 2 listen. */ - static auto Listen(int sockfd, int backlog) { + static auto listen(int sockfd, int backlog) { return ::listen(sockfd, backlog); } /** * @brief See man 2 accept. */ - static auto Accept(int sockfd, struct sockaddr* addr, socklen_t* addrlen) { + static auto accept(int sockfd, struct sockaddr* addr, socklen_t* addrlen) { return ::accept(sockfd, addr, addrlen); } /** * @brief See man 3 inet_pton. */ - static auto AddrToBin(int af, const char* src, void* dst) { + static auto addrToBin(int af, const char* src, void* dst) { return ::inet_pton(af, src, dst); } /** * @brief See man 3 inet_ntoa. */ - static auto NetToAddr(struct in_addr inp) { + static auto netToAddr(struct in_addr inp) { return ::inet_ntoa(inp); } /** * @brief See man 2 connect. */ - static auto Connect(int sockfd, + static auto connect(int sockfd, const struct sockaddr* addr, socklen_t addrlen) { return ::connect(sockfd, addr, addrlen); @@ -114,32 +122,32 @@ struct SysIFace { /** * @brief See man 2 poll. */ - static auto Poll(struct pollfd* fds, nfds_t nfds, int timeout) { + static auto poll(struct pollfd* fds, nfds_t nfds, int timeout) { return ::poll(fds, nfds, timeout); } /** * @brief See man 2 close. */ - static auto Close(int fd) { + static auto close(int fd) { return ::close(fd); } /** * @brief See man 2 read. */ - static auto Read(int fd, void* buf, size_t count) { + static auto read(int fd, void* buf, size_t count) { return ::read(fd, buf, count); } /** * @brief See man 2 write. */ - static auto Write(int fd, const void* buf, size_t count) { + static auto write(int fd, const void* buf, size_t count) { return ::write(fd, buf, count); } }; -} // namespace scl::net +} // namespace scl::net::details #endif // SCL_NET_SYS_IFACE_H diff --git a/include/scl/net/tcp_channel.h b/include/scl/net/tcp_channel.h index f509404..6183a11 100644 --- a/include/scl/net/tcp_channel.h +++ b/include/scl/net/tcp_channel.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -19,120 +19,195 @@ #define SCL_NET_TCP_CHANNEL_H #include -#include -#include +#include #include +#include "scl/coro/future.h" #include "scl/net/channel.h" #include "scl/net/config.h" #include "scl/net/sys_iface.h" +#include "scl/net/tcp_utils.h" namespace scl::net { /** - * @brief A channel between two peers utilizing TCP. + * @brief A channel implementation using TCP. */ -template +template class TcpChannel final : public Channel { public: /** - * @brief Wrap a socket in a TCP channel. + * @brief Create a new TcpChannel. * @param socket the socket. */ - TcpChannel(int socket) : m_alive(true), m_socket(socket){}; + TcpChannel(SocketType socket) : m_alive(true), m_socket(socket) {} /** - * @brief Tells whether this channel is alive or not. + * @brief Check if this channel is alive. + * + * The channel is considered alive upon construction, and dead after the first + * call to close(). */ - bool Alive() const { + bool alive() const { return m_alive; - }; + } + + /** + * @brief Close this channel. + * + * Calling close() will result in all future calls to alive() returning + * false. This function can be called multiple times. + */ + void close() override; + + /** + * @brief Send a packet on this channel. + * @param packet the packet to send. + * + * This function will attempt to send \p packet on the underlying socket. If + * this would result in the call blocking, then the function is suspended and + * scheduled to run later through the supplied scheduler. + */ + coro::Task send(Packet&& packet) override; - void Send(const unsigned char* src, std::size_t n) override; - std::size_t Recv(unsigned char* dst, std::size_t n) override; - bool HasData() override; - void Close() override; + /** + * @brief Send a packet on this channel. + * @param packet the packet to send. + * + * This function will attempt to send \p packet on the underlying socket. If + * this would result in the call blocking, then the function is suspended and + * scheduled to run later through the supplied scheduler. + */ + coro::Task send(const Packet& packet) override; + + /** + * @brief Recv a packet on this channel. + * @return the received packet. + * + * This function will suspend execution if not enough data is ready yet. To + * check if it's possible to receive something on the channel, use hasData(). + */ + coro::Task recv() override; + + /** + * @brief Check if this channel has data ready for recovering. + * @return true if there's data to receive and false otherwise. + */ + coro::Task hasData() override; private: bool m_alive; - int m_socket; + SocketType m_socket; }; -template -void TcpChannel::Send(const unsigned char* src, std::size_t n) { - std::size_t rem = n; - std::size_t offset = 0; +template +void TcpChannel::close() { + if (m_alive) { + // ensures that we only attempt to close the socket once, even if closing + // the somehow socket fails. + m_alive = false; - while (rem > 0) { - auto sent = Sys::Write(m_socket, src + offset, rem); - - if (sent < 0) { - throw std::system_error(Sys::GetError(), + if (SYS::close(m_socket) < 0) { + throw std::system_error(SYS::getError(), std::generic_category(), - "write failed"); + "close failed"); } - - rem -= sent; - offset += sent; } } -template -std::size_t TcpChannel::Recv(unsigned char* dst, std::size_t n) { - std::size_t rem = n; - std::size_t offset = 0; - - while (rem > 0) { - auto recv = Sys::Read(m_socket, dst + offset, rem); +template +coro::Task TcpChannel::send(Packet&& packet) { + co_await send(packet); +} - if (recv == 0) { - break; - } +template +coro::Task TcpChannel::send(const Packet& packet) { + // Write the packet size to a buffer. + const Packet::SizeType packet_size = packet.size(); + const auto packet_size_size = sizeof(Packet::SizeType); + unsigned char packet_size_buf[packet_size_size] = {0}; + std::memcpy(packet_size_buf, &packet_size, packet_size_size); + + // assume writing the packet size won't block. It probably wont. + if (SYS::write(m_socket, packet_size_buf, packet_size_size) < 0) { + throw std::system_error(SYS::getError(), + std::generic_category(), + "writing packet size failed"); + } - if (recv < 0) { - throw std::system_error(Sys::GetError(), - std::generic_category(), - "read failed"); + // Write content of packet. This may block, in which case a RetrySend + // awaitable is created and this coroutine is suspended. + std::size_t rem = packet.size(); + const unsigned char* data = packet.get(); + while (rem > 0) { + const auto written = SYS::write(m_socket, data, rem); + if (written < 0) { + const auto err = SYS::getError(); + if (err == EAGAIN || err == EWOULDBLOCK) { + co_await [socket = m_socket]() { + return details::pollSocket(socket, POLLOUT); + }; + } else { + throw std::system_error(err, std::generic_category(), "send failed"); + } } - rem -= recv; - offset += recv; + rem -= written; + data += written; } - - return n - rem; } -template -bool TcpChannel::HasData() { - struct pollfd fds { - m_socket, POLLIN, 0 - }; +namespace details { - auto r = Sys::Poll(&fds, 1, 0); +// Helper coroutine for reading some amount of bytes from a socket into a +// buffer. If the read would block, then the call is suspended using the +// provided scheduler. +template +coro::Task recvInto(SocketType socket, + unsigned char* dst, + std::size_t nbytes) { + std::size_t rem = nbytes; + while (rem > 0) { + const auto read = SYS::read(socket, dst, rem); + if (read < 0) { + const auto err = SYS::getError(); + if (err == EAGAIN || err == EWOULDBLOCK) { + co_await + [socket = socket]() { return pollSocket(socket, POLLIN); }; + } else { + throw std::system_error(err, std::generic_category(), "recv failed"); + } + } - if (r < 0) { - throw std::system_error(Sys::GetError(), - std::generic_category(), - "poll failed"); + rem -= read; + dst += read; } - - return r > 0 && fds.revents == POLLIN; } -template -void TcpChannel::Close() { - if (!m_alive) { - return; - } +} // namespace details - m_alive = false; +template +coro::Task TcpChannel::recv() { + unsigned char packet_size_buf[sizeof(Packet::SizeType)] = {0}; - if (Sys::Close(m_socket) < 0) { - throw std::system_error(Sys::GetError(), - std::generic_category(), - "close failed"); - } + // read size of the packet. + co_await details::recvInto(m_socket, + packet_size_buf, + sizeof(Packet::SizeType)); + Packet::SizeType packet_size; + std::memcpy(&packet_size, packet_size_buf, sizeof(Packet::SizeType)); + + Packet packet(packet_size); + co_await details::recvInto(m_socket, packet.get(), packet_size); + packet.setWritePtr(packet_size); + + co_return packet; +} + +template +coro::Task TcpChannel::hasData() { + co_return details::pollSocket(m_socket, POLLIN); } } // namespace scl::net diff --git a/include/scl/net/tcp_utils.h b/include/scl/net/tcp_utils.h index c4119cb..d9ae1e8 100644 --- a/include/scl/net/tcp_utils.h +++ b/include/scl/net/tcp_utils.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -18,11 +18,13 @@ #ifndef SCL_NET_TCP_UTILS_H #define SCL_NET_TCP_UTILS_H +#include #include #include #include #include +#include #include #include @@ -35,6 +37,8 @@ namespace scl::net { */ using SocketType = int; +namespace details { + /** * @brief A connection. */ @@ -58,11 +62,11 @@ struct Connection { * @return A socket. */ template -SocketType CreateServerSocket(int port, int backlog) { - SocketType ssock = Sys::Socket(AF_INET, SOCK_STREAM, 0); +SocketType createServerSocket(int port, int backlog) { + SocketType ssock = Sys::socket(AF_INET, SOCK_STREAM, 0); if (ssock < 0) { - throw std::system_error(Sys::GetError(), + throw std::system_error(Sys::getError(), std::generic_category(), "could not acquire server socket"); } @@ -70,27 +74,27 @@ SocketType CreateServerSocket(int port, int backlog) { int opt = 1; auto options = SO_REUSEADDR | SO_REUSEPORT; - if (Sys::SetSockOpt(ssock, SOL_SOCKET, options, &opt, sizeof(opt)) < 0) { - throw std::system_error(Sys::GetError(), + if (Sys::setSockOpt(ssock, SOL_SOCKET, options, &opt, sizeof(opt)) < 0) { + throw std::system_error(Sys::getError(), std::generic_category(), "could not set socket options"); } struct sockaddr_in addr; addr.sin_family = AF_INET; - addr.sin_addr.s_addr = Sys::HostToNet(INADDR_ANY); - addr.sin_port = Sys::HostToNet(port); + addr.sin_addr.s_addr = Sys::hostToNet(INADDR_ANY); + addr.sin_port = Sys::hostToNet(port); struct sockaddr* addr_ptr = (struct sockaddr*)&addr; - if (Sys::Bind(ssock, addr_ptr, sizeof(addr)) < 0) { - throw std::system_error(Sys::GetError(), + if (Sys::bind(ssock, addr_ptr, sizeof(addr)) < 0) { + throw std::system_error(Sys::getError(), std::generic_category(), "could not bind socket"); } - if (Sys::Listen(ssock, backlog)) { - throw std::system_error(Sys::GetError(), + if (Sys::listen(ssock, backlog)) { + throw std::system_error(Sys::getError(), std::generic_category(), "could not listen on socket"); } @@ -105,19 +109,19 @@ SocketType CreateServerSocket(int port, int backlog) { * @return An accepted connection */ template -Connection AcceptConnection(SocketType server_socket) { - auto sa = std::make_unique(); +Connection acceptConnection(SocketType server_socket) { + struct sockaddr sa; auto addrsize = sizeof(struct sockaddr_in); - SocketType sock = Sys::Accept(server_socket, sa.get(), (socklen_t*)&addrsize); + SocketType sock = Sys::accept(server_socket, &sa, (socklen_t*)&addrsize); if (sock < 0) { - throw std::system_error(Sys::GetError(), + throw std::system_error(Sys::getError(), std::generic_category(), "could not accept connection"); } - const auto* p = (struct sockaddr_in*)sa.get(); - std::string hostname = Sys::NetToAddr(p->sin_addr); + const auto* p = (struct sockaddr_in*)&sa; + std::string hostname = Sys::netToAddr(p->sin_addr); return {sock, hostname}; } @@ -129,35 +133,71 @@ Connection AcceptConnection(SocketType server_socket) { * @param port the port of the remote peer * @return A socket. */ -template -SocketType ConnectAsClient(const std::string& hostname, int port) { +template +SocketType connectAsClient(const std::string& hostname, int port) { using namespace std::chrono_literals; - SocketType sock = Sys::Socket(AF_INET, SOCK_STREAM, 0); + SocketType sock = SYS::socket(AF_INET, SOCK_STREAM, 0); if (sock < 0) { - throw std::system_error(Sys::GetError(), + throw std::system_error(SYS::getError(), std::generic_category(), "could not acquire socket"); } struct sockaddr_in addr; addr.sin_family = AF_INET; - addr.sin_port = Sys::HostToNet(port); + addr.sin_port = SYS::hostToNet(port); - int err = Sys::AddrToBin(AF_INET, hostname.c_str(), &(addr.sin_addr)); + int err = SYS::addrToBin(AF_INET, hostname.c_str(), &(addr.sin_addr)); if (err == 0) { throw std::runtime_error("invalid hostname"); } - while (Sys::Connect(sock, (struct sockaddr*)&addr, sizeof(addr)) < 0) { - std::this_thread::sleep_for(300ms); + if (SYS::connect(sock, (struct sockaddr*)&addr, sizeof(addr)) < 0) { + throw std::system_error(SYS::getError(), + std::generic_category(), + "could not connect"); } return sock; } +template +void markSocketNonBlocking(SocketType socket) { + auto flags = SYS::fcntl(socket, F_GETFL, 0); + if (flags == -1) { + throw std::system_error(SYS::getError(), + std::generic_category(), + "could not read current flags of socket"); + } + + if (SYS::fcntl(socket, F_SETFL, flags | O_NONBLOCK) == -1) { + throw std::system_error(SYS::getError(), + std::generic_category(), + "could not set O_NONBLOCK on socket"); + } +} + +template +bool pollSocket(SocketType socket, short event) { + struct pollfd fds { + socket, POLLIN, 0 + }; + + auto r = SYS::poll(&fds, 1, 0); + + if (r < 0) { + throw std::system_error(SYS::getError(), + std::generic_category(), + "poll failed"); + } + + return r > 0 && fds.revents == event; +} + +} // namespace details } // namespace scl::net #endif // SCL_NET_TCP_UTILS_H diff --git a/include/scl/net/threaded_sender.h b/include/scl/net/threaded_sender.h deleted file mode 100644 index d3f50f9..0000000 --- a/include/scl/net/threaded_sender.h +++ /dev/null @@ -1,74 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#ifndef SCL_NET_THREADED_SENDER_H -#define SCL_NET_THREADED_SENDER_H - -#include -#include - -#include "scl/net/channel.h" -#include "scl/net/shared_deque.h" -#include "scl/net/tcp_channel.h" - -namespace scl::net { - -/** - * @brief A decorator for TcpChannel which does Send calls in a separate thread - * - * The purpose of this class is to avoid situations where calls to Send may - * block, for example if we're trying to send more that what can fit in the TCP - * window. - */ -class ThreadedSenderChannel final : public Channel { - public: - /** - * @brief Create a new threaded sender channel. - * @param socket an open socket used to construct scl::TcpChannel - */ - ThreadedSenderChannel(int socket); - - /** - * @brief Destroying a ThreadedSenderChannel closes the connection. - */ - ~ThreadedSenderChannel() { - Close(); - }; - - void Close() override; - - void Send(const unsigned char* src, std::size_t n) override { - m_send_buffer.PushBack({src, src + n}); - }; - - std::size_t Recv(unsigned char* dst, std::size_t n) override { - return m_channel.Recv(dst, n); - }; - - bool HasData() override { - return m_channel.HasData(); - }; - - private: - TcpChannel<> m_channel; - SharedDeque> m_send_buffer; - std::future m_sender; -}; - -} // namespace scl::net - -#endif // SCL_NET_THREADED_SENDER_H diff --git a/include/scl/protocol/base.h b/include/scl/protocol/base.h index 721b8cb..a19545f 100644 --- a/include/scl/protocol/base.h +++ b/include/scl/protocol/base.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -25,109 +25,123 @@ #include "scl/net/network.h" #include "scl/protocol/env.h" +#include "scl/protocol/result.h" namespace scl::proto { /** - * @brief Base class for protocols. + * @brief Interface for protocols. * - * Protocol provides a simple interface for any protocol that produce - * output. Protocol defines only two methods. + * A class implementing this interface defines the code that a party runs in an + * interactive protocol. An example is given below, showing how we might + * implement a classical secure multiplication protocol using a multiplication + * triple. * - *

Protocol::Run, which takes a Network (so that the protocol can - * communicate). The output of Protocol::Run is either another protocol or - * nullptr. in the former case, the return value represents the - * next logical step of the protocol, while the latter case indicates that the - * protocol has terminated. + * @code + * template + * class BeaverMul final : public proto::Protocol { + * public: + * BeaverMul(SHARE x, SHARE y, Triple triple) + * : m_x(x), m_y(y), m_triple(triple) {} * - *

Protocol::Output which returns the output of the protocol. The output is - * something of type std::any and so can essentially be any - * object. It is the user's job to know what the concrete type of the output is. + * coro::Task run(proto::Env& env) const override { + * net::Packet packet; + * + * packet << m_x - m_triple.a; // [e] = [x] - [a] + * packet << m_y - m_triple.b; // [d] = [y] - [b] + * + * co_await env.network.party(0)->send(packet); + * co_await env.network.party(1)->send(packet); + * + * net::Packet packet0 = co_await env.network.party(0)->recv(); + * net::Packet packet1 = co_await env.network.party(1)->recv(); + * + * const SHARE e0 = packet0.read(); + * const SHARE d0 = packet0.read(); + * const SHARE e1 = packet1.read(); + * const SHARE d1 = packet1.read(); + * + * const SHARE e = e0 + e1; + * const SHARE d = d0 + d1; + * + * // [z] = ed + e[b] + d[a] + [c]. Only party 0 adds constants. + * SHARE z = e * m_triple.b + d * m_triple.a + m_triple.c; + * if (env.network.myId() == 0) { + * z += e * d; + * } + * + * co_return proto::ProtocolResult::done(z); + * } + * + * private: + * SHARE m_x; + * SHARE m_y; + * Triple m_triple; + * }; + * @endcode + * + * It is possible to chain multiple protocols together by returning a pointer to + * the next protocol. It is also possible to compose protocol objects, e.g., as + * shown below. However, care has to be taken in handling cases where two + * protocols both attempt to read from the same channel. + * + * @code + * struct SimpleProtocol final : public proto::Protocol { + * coro::Task run(proto::Env& env) const override { + * // ... do stuff + * co_return proto::ProtocolResult::done(some_value); + * } + * }; + * + * class Composed final : public proto::Protocol { + * public: + * Composed(SimpleProtocol&& protocol1, SimpleProtocol&& protocol2) + * : m_protocol1(std::move(protocol1)), m_protocol2(std::move(protocol2)) {} + * + * coro::Task run(proto::Env& env) const override { + * // batch the two protocols. + * std::vector> protocols; + * protocols.emplace_back(m_protocol1.run()); + * protocols.emplace_back(m_protocol2.run()); + * + * // ask the coroutine runtime to run them in any random order. + * std::vector results = + * co_await coro::batch(std::move(protocols)); + * + * co_return results; + * } + * + * private: + * SimpleProtocol m_protocol1; + * SimpleProtocol m_protocol2; + * }; + * @endcode + * + * Each Protocol is associated with a name, defaulting to the value of + * Protocol::DEFAULT_NAME. The name is used only in the simulator to group + * measurements when generating a result. */ struct Protocol { + virtual ~Protocol() {} + /** * @brief Default protocol name. */ constexpr static const char* DEFAULT_NAME = "UNNAMED"; - virtual ~Protocol(){}; /** * @brief Run the protocol. - * @param env the protocol environment. - * @return next protocol to run, or nullptr if we're done. */ - virtual std::unique_ptr Run(Env& env) = 0; + virtual coro::Task run(Env& env) const = 0; /** - * @brief A name for this protocol. - * @return the protocol name. - * - * Override this method to provide a unique name for a protocol. The name - * serves as a way to distinguish two Protocol implementations from each - * other. The default value is Protocol::kDefaultName. + * @brief The protocol's name. */ - virtual std::string Name() const { + virtual std::string name() const { return Protocol::DEFAULT_NAME; } - - /** - * @brief Output produced by running the protocol. - * @return the output. - */ - virtual std::any Output() const { - return {}; - } }; -/** - * @brief Evaluate a protocol. - * @param protocol the protocol. - * @param output_cb a callback for consuming protocol output. - * @param env the protocol environment. - */ -template -void Evaluate(std::unique_ptr protocol, - Callback output_cb, - Env& env) { - std::shared_ptr next = std::move(protocol); - std::shared_ptr prev = next; - - while (next != nullptr) { - next = next->Run(env); - if (prev->Output().has_value()) { - output_cb(prev->Output()); - } - prev = next; - } -} - -/** - * @brief Evaluate a protocol. - * @param protocol the protocol. - * @param network the network to evaluate the protocol with. - * @param output_cb a callback for consuming protocol output. - */ -template -void Evaluate(std::unique_ptr protocol, - net::Network& network, - Callback output_cb) { - Env ctx{network, - std::make_unique(), - std::make_unique()}; - Evaluate(std::move(protocol), output_cb, ctx); -} - -/** - * @brief Evalate a protocol, discarding all outputs generated. - * @param protocol the protocol to evaluate. - * @param network the network to use. - */ -inline void Evaluate(std::unique_ptr protocol, - net::Network& network) { - const auto sink = [](auto output) { (void)output; }; - Evaluate(std::move(protocol), network, sink); -} - } // namespace scl::proto #endif // SCL_PROTOCOL_BASE_H diff --git a/include/scl/protocol/clock.h b/include/scl/protocol/clock.h new file mode 100644 index 0000000..ba8e700 --- /dev/null +++ b/include/scl/protocol/clock.h @@ -0,0 +1,61 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_PROTOCOL_CLOCK_H +#define SCL_PROTOCOL_CLOCK_H + +#include "scl/util/time.h" + +namespace scl::proto { + +/** + * @brief A clock interface. + */ +struct Clock { + virtual ~Clock() {} + + /** + * @brief Read the current value of the clock. + */ + virtual util::Time::Duration read() const = 0; +}; + +/** + * @brief A clock implementation based on real time. + */ +class RealtimeClock final : public Clock { + public: + /** + * @brief Create a new RealtimeClock. + */ + RealtimeClock() : m_clock_start(util::Time::now()) {} + + /** + * @brief Read the current value of the clock. + * @return the amount of time elapsed since this clock was created. + */ + util::Time::Duration read() const override { + return util::Time::now() - m_clock_start; + } + + private: + util::Time::TimePoint m_clock_start; +}; + +} // namespace scl::proto + +#endif // SCL_PROTOCOL_CLOCK_H diff --git a/include/scl/protocol/env.h b/include/scl/protocol/env.h index 8755fbe..e67507d 100644 --- a/include/scl/protocol/env.h +++ b/include/scl/protocol/env.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -22,8 +22,10 @@ #include #include #include +#include #include "scl/net/network.h" +#include "scl/protocol/clock.h" #include "scl/util/time.h" namespace scl::proto { @@ -36,101 +38,29 @@ namespace scl::proto { * its total running time, as well as the ability to work with threads. */ struct Env { - /** - * @brief Interface for the environment's threading context. - * - * At the moment, the only threading related functionality that the - * environment supports is the ability to sleep the current thread for a - * provided amount of time. - */ - struct Thread { - virtual ~Thread(){}; - - /** - * @brief Put this thread to sleep. - * @param ms the time to sleep this thread for, in milliseconds - */ - virtual void Sleep(std::size_t ms) = 0; - }; - - /** - * @brief Interface for the environment's clock context. - * - * This interface essentially models a "stopwatch" of sorts. The idea is that - * it will start ticking when a protocol starts. The protocol can check the - * current elapsed time at any point, and mark checkpoints. - */ - struct Clock { - virtual ~Clock(){}; - - /** - * @brief Read the current value of the clock. - */ - virtual util::Time::Duration Read() const = 0; - - /** - * @brief Record a checkpoint with an associated message. - */ - virtual void Checkpoint(const std::string& message) = 0; - }; - /** * @brief The network. */ net::Network network; /** - * @brief Clock. + * @brief Clock used to tell for how long the protocol has been running. */ std::unique_ptr clock; - - /** - * @brief Threading context. - */ - std::unique_ptr thread_ctx; -}; - -/** - * @brief A protocol clock which operates with real time. - */ -class RealTimeClock final : public Env::Clock { - public: - RealTimeClock() : m_init_time(util::Time::Now()) {} - ~RealTimeClock() {} - - /** - * @brief Get the current time. - */ - util::Time::Duration Read() const { - return util::Time::Now() - m_init_time; - } - - /** - * @brief Print the current time to stdout. - */ - void Checkpoint(const std::string& message) { - auto ms = std::chrono::duration(Read()).count(); - std::cout << message << " @ " << ms << " ms\n"; - } - - private: - util::Time::TimePoint m_init_time; }; /** - * @brief A protocol thread context which uses STL thread. + * @brief Create an environment from a network. + * @param network the network. + * @return an Env object. + * + * The returned environemnt uses the RealTimeClock and StlThreadContext for the + * environment's clock and thread context, respectively. */ -class StlThreadContext final : public Env::Thread { - public: - ~StlThreadContext() {} - - /** - * @brief Sleep the current thread using std::this_thread::sleep_for. - */ - void Sleep(std::size_t ms) override { - std::this_thread::sleep_for(std::chrono::milliseconds(ms)); - } -}; +inline Env createDefaultEnv(net::Network network) { + return Env{.network = std::move(network), + .clock = std::make_unique()}; +} } // namespace scl::proto diff --git a/include/scl/protocol/eval.h b/include/scl/protocol/eval.h new file mode 100644 index 0000000..ba376dc --- /dev/null +++ b/include/scl/protocol/eval.h @@ -0,0 +1,103 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_PROTOCOL_EVAL_H +#define SCL_PROTOCOL_EVAL_H + +#include +#include + +#include "scl/coro/task.h" +#include "scl/protocol/base.h" +#include "scl/protocol/env.h" +#include "scl/protocol/result.h" + +namespace scl::proto { + +/** + * @brief Evaluate a protocol. + * @param protocol the protocol to evaluate. + * @param env the environment to use. + * @param output_callback a callable that receives protocol outputs. + * + * This function will evaluate a protocol, passing any outputs that + * the protocol produces to @p output_callback. + */ +template +coro::Task evaluate(std::unique_ptr protocol, + Env& env, + CALLBACK output_callback) { + while (protocol) { + ProtocolResult result = co_await protocol->run(env); + + if (result.next_protocol) { + protocol = std::move(result.next_protocol); + } + + if (result.result.has_value()) { + output_callback(result.result); + } + } + + co_return; +} + +/** + * @brief Evaluate a protocol. + * @tparam RESULT the type of the protocol's output. + * @param protocol the protocol to evaluate. + * @param env the environment to use. + * + * This function evaluates a protocol and returns the result that it + * produces. If the protocol terminates, but produces no output, then + * an std::logic_error is thrown. Similarly, if the result produced + * cannot be std::any_cast to something of type + * RESULT, then an error is thrown. + */ +template +coro::Task evaluate(std::unique_ptr protocol, Env& env) { + while (protocol) { + ProtocolResult result = co_await protocol->run(env); + + if (result.next_protocol) { + protocol = std::move(result.next_protocol); + } else { + if (result.result.has_value()) { + co_return std::any_cast(result.result); + } else { + throw std::logic_error("Protocol did not produce any result"); + } + } + } +} + +/** + * @brief Evaluate a protocol that produces no result. + * @param protocol the protocol to evaluate. + * @param env the environment to use. + */ +template <> +inline coro::Task evaluate(std::unique_ptr protocol, Env& env) { + while (protocol) { + ProtocolResult result = co_await protocol->run(env); + protocol = std::move(result.next_protocol); + } +} + +} // namespace scl::proto + +#endif // SCL_PROTOCOL_EVAL_H diff --git a/include/scl/protocol/protocol.h b/include/scl/protocol/protocol.h index ddb547c..2112160 100644 --- a/include/scl/protocol/protocol.h +++ b/include/scl/protocol/protocol.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -19,6 +19,9 @@ #define SCL_PROTOCOL_PROTOCOL_H #include "scl/protocol/base.h" +#include "scl/protocol/env.h" +#include "scl/protocol/eval.h" +#include "scl/protocol/result.h" /** * @brief %Protocol utilities. diff --git a/include/scl/protocol/result.h b/include/scl/protocol/result.h new file mode 100644 index 0000000..8c880f2 --- /dev/null +++ b/include/scl/protocol/result.h @@ -0,0 +1,87 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_PROTOCOL_RESULT_H +#define SCL_PROTOCOL_RESULT_H + +#include +#include + +namespace scl::proto { + +struct Protocol; + +/** + * @brief The Result of running a Protocol. + * + * All Protocols must return a ProtocolResult object indicating (1) the next + * protocol to run, and (2) the output produced by the protocol. Either can be + * empty, which gives rise to the four constructions present on this class. + */ +struct ProtocolResult { + /** + * @brief Create a protocol result without any next steps or output. + * + * Result returned by a final protocol that produced no output. + */ + static ProtocolResult done() { + return {.next_protocol = nullptr, .result = {}}; + } + + /** + * @brief Create a protocol result without any next steps and an output. + * + * Result returned by a final protocol that produced some output. + */ + static ProtocolResult done(std::any output) { + return ProtocolResult{.next_protocol = nullptr, + .result = std::move(output)}; + } + + /** + * @brief Create a protocol result with a next step. + * + * Result returned by an intermediary protocol that produced no output. + */ + static ProtocolResult next(std::unique_ptr next) { + return ProtocolResult{.next_protocol = std::move(next), .result = {}}; + } + + /** + * @brief Create a protocol result with a next step and output. + * + * Result returned by an intermediary protocol that produced some output. + */ + static ProtocolResult next(std::unique_ptr next, std::any output) { + return ProtocolResult{.next_protocol = std::move(next), + .result = std::move(output)}; + } + + /** + * @brief The next protocol to run. A nullptr value indicates no next step. + */ + std::unique_ptr next_protocol; + + /** + * @brief The output of the protocol. + */ + std::any result; +}; + +} // namespace scl::proto + +#endif // SCL_PROTOCOL_RESULT_H diff --git a/include/scl/scl.h b/include/scl/scl.h index 7484cf1..e46e6f6 100644 --- a/include/scl/scl.h +++ b/include/scl/scl.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -18,6 +18,15 @@ #ifndef SCL_SCL_H #define SCL_SCL_H +#include "scl/coro/coroutine.h" +#include "scl/math/math.h" +#include "scl/net/net.h" +#include "scl/protocol/protocol.h" +#include "scl/serialization/serialization.h" +#include "scl/simulation/simulation.h" +#include "scl/ss/ss.h" +#include "scl/util/util.h" + /** * @brief Main namespace. */ diff --git a/include/scl/serialization/math_serializers.h b/include/scl/serialization/math_serializers.h deleted file mode 100644 index 444b0cc..0000000 --- a/include/scl/serialization/math_serializers.h +++ /dev/null @@ -1,210 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#ifndef SCL_SERIALIZATION_MATH_SERIALIZERS_H -#define SCL_SERIALIZATION_MATH_SERIALIZERS_H - -#include "scl/math/ff.h" -#include "scl/math/mat.h" -#include "scl/math/number.h" -#include "scl/math/vec.h" -#include "scl/serialization/serializer.h" - -namespace scl::seri { - -/** - * @brief Serializer specialization for math::FF types. - */ -template -struct Serializer> { - /** - * @brief Determine the size of an math::FF value. - * @param ignored an element, which is ignored. - * - * The size of an math::FF element can be determined from its type alone, so - * the argument is ignored. - */ - static constexpr std::size_t SizeOf(const math::FF& ignored) { - (void)ignored; - return math::FF::ByteSize(); - } - - /** - * @brief Write an math::FF element to a buffer. - * @param elem the element. - * @param buf the buffer. - * - * Calls math::FF::Write. - */ - static std::size_t Write(const math::FF& elem, unsigned char* buf) { - elem.Write(buf); - return SizeOf(elem); - } - - /** - * @brief Read an math::FF element from a buffer. - * @param elem output variable holding the read element after reading. - * @param buf the buffer. - * @return the number of bytes read. - * - * Calls math::FF::Read() and returns math::FF::ByteSize(); - */ - static std::size_t Read(math::FF& elem, const unsigned char* buf) { - elem = math::FF::Read(buf); - return math::FF::ByteSize(); - } -}; - -/** - * @brief Serializer specialization for math::Vec. - */ -template -struct Serializer> { - private: - using SizeType = typename math::Vec::SizeType; - static constexpr auto SIZE_TYPE_SIZE = sizeof(SizeType); - - public: - /** - * @brief Size of a vector. - * @param vec the vector. - */ - static std::size_t SizeOf(const math::Vec& vec) { - return vec.ByteSize() + SIZE_TYPE_SIZE; - } - - /** - * @brief Write a math::Vec to a buffer. - * @param vec the vector. - * @param buf the buffer. - */ - static std::size_t Write(const math::Vec& vec, unsigned char* buf) { - const auto sz = vec.Size(); - std::memcpy(buf, &sz, SIZE_TYPE_SIZE); - vec.Write(buf + SIZE_TYPE_SIZE); - return SizeOf(vec); - } - - /** - * @brief Read a math::Vec from a buf. - * @param vec the vector. - * @param buf the buffer. - * @return the number of bytes read. - */ - static std::size_t Read(math::Vec& vec, const unsigned char* buf) { - SizeType sz; - std::memcpy(&sz, buf, SIZE_TYPE_SIZE); - vec = math::Vec::Read(sz, buf + SIZE_TYPE_SIZE); - return SizeOf(vec); - } -}; - -/** - * @brief Serializer specialization for a math::Mat. - */ -template -struct Serializer> { - private: - using SizeType = typename math::Mat::SizeType; - static constexpr auto SIZE_TYPE_SIZE = sizeof(SizeType); - - public: - /** - * @brief Size of a matrix. - * @param mat the matrix. - * - * The size of a matrix is the determined as the size of the content plus two - * times SIZE_TYPE_SIZE. - */ - static std::size_t SizeOf(const math::Mat& mat) { - return mat.ByteSize() + 2 * SIZE_TYPE_SIZE; - } - - /** - * @brief Write a matrix to a buffer. - * @param mat the matrix. - * @param buf the buffer. - */ - static std::size_t Write(const math::Mat& mat, unsigned char* buf) { - const auto c = mat.Cols(); - std::memcpy(buf, &c, SIZE_TYPE_SIZE); - const auto r = mat.Rows(); - std::memcpy(buf + SIZE_TYPE_SIZE, &r, SIZE_TYPE_SIZE); - mat.Write(buf + 2 * SIZE_TYPE_SIZE); - return SizeOf(mat); - } - - /** - * @brief Read a matrix from a buffer. - * @param mat where to store the matrix after reading. - * @param buf the buffer. - * @return the number of bytes read. - */ - static std::size_t Read(math::Mat& mat, const unsigned char* buf) { - SizeType r; - SizeType c; - std::memcpy(&c, buf, SIZE_TYPE_SIZE); - std::memcpy(&r, buf + SIZE_TYPE_SIZE, SIZE_TYPE_SIZE); - mat = math::Mat::Read(r, c, buf + 2 * SIZE_TYPE_SIZE); - return SizeOf(mat); - } -}; - -/** - * @brief Serializer specialization for math::Number. - */ -template <> -struct Serializer { - /** - * @brief Get the serialized size of a math::Number. - * @param number the number. - * @return the serialized size of a math::Number. - * - * A math::Number is writte as size_and_sign | number where - * size_and_sign is a 4 byte value containing the byte size of - * the number and its sign. - */ - static std::size_t SizeOf(const math::Number& number) { - return number.ByteSize() + sizeof(std::uint32_t); - } - - /** - * @brief Write a number to a buffer. - * @param number the number. - * @param buf the buffer. - * @return the number of bytes written. - */ - static std::size_t Write(const math::Number& number, unsigned char* buf) { - number.Write(buf); - return SizeOf(number); - } - - /** - * @brief Read a math::Number from a buffer. - * @param number the number. - * @param buf the buffer. - * @return the number of bytes read. - */ - static std::size_t Read(math::Number& number, const unsigned char* buf) { - number = math::Number::Read(buf); - return SizeOf(number); - } -}; - -} // namespace scl::seri - -#endif // SCL_SERIALIZATION_MATH_SERIALIZERS_H diff --git a/include/scl/serialization/serializable.h b/include/scl/serialization/serializable.h new file mode 100644 index 0000000..0c4ca00 --- /dev/null +++ b/include/scl/serialization/serializable.h @@ -0,0 +1,38 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_SERIALIZATION_SERIALIZABLE_H +#define SCL_SERIALIZATION_SERIALIZABLE_H + +#include "scl/serialization/serializer.h" + +namespace scl::seri { + +/** + * @brief Requirements for a type to be serializable in SCL. + */ +template +concept Serializable = + requires(T v, const unsigned char* in, unsigned char* out) { + { Serializer::sizeOf(v) } -> std::same_as; + { Serializer::write(v, out) } -> std::same_as; + { Serializer::read(v, in) } -> std::same_as; + }; + +} // namespace scl::seri + +#endif // SCL_SERIALIZATION_SERIALIZABLE_H diff --git a/include/scl/serialization/serializers.h b/include/scl/serialization/serialization.h similarity index 68% rename from include/scl/serialization/serializers.h rename to include/scl/serialization/serialization.h index 3987067..14cfb5a 100644 --- a/include/scl/serialization/serializers.h +++ b/include/scl/serialization/serialization.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,14 +15,15 @@ * along with this program. If not, see . */ -#ifndef SCL_SERIALIZATION_SERIALIZERS_H -#define SCL_SERIALIZATION_SERIALIZERS_H +#ifndef SCL_SERIALIZATION_SERIALIZATION_H +#define SCL_SERIALIZATION_SERIALIZATION_H -#include "scl/serialization/math_serializers.h" +#include "scl/serialization/serializable.h" +#include "scl/serialization/serializer.h" /** - * @brief Serialization functionality. + * @brief Code for serializing and deserializing things. */ -namespace scl::seri {} // namespace scl::seri +namespace scl::seri {} -#endif // SCL_SERIALIZATION_SERIALIZERS_H +#endif // SCL_SERIALIZATION_SERIALIZATION_H diff --git a/include/scl/serialization/serializer.h b/include/scl/serialization/serializer.h index 58a4a95..9109520 100644 --- a/include/scl/serialization/serializer.h +++ b/include/scl/serialization/serializer.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -19,65 +19,67 @@ #define SCL_SERIALIZATION_SERIALIZER_H #include +#include #include #include #include -#include "scl/util/traits.h" - namespace scl::seri { /** - * @brief Serializer. + * @brief Serializer type. + * + *

Serializer's are used throughout SCL whenever data has to be converted + * to/from a binary format. Deciding how this conversion is done is Serializer's + * job. * - *

A Serializer is meant to provide functionality for writing to, and reading - * from, a binary format, and is used for example when sending data across a - * network. To make the type \p T serializable, a specialization of Serializer - * must be provided which defines the following three methods: + *

SCL contains Serializer implementations for most of its types, as well as + * couple of a "standard" types (such as trivially copyable objects). Adding a + * new Serializer is easy, and can be done by providing a specialization of the + * Serializer type, which must include a couple of methods. An example is shown + * below: * - *

    + * @code + * template <> + * struct scl::seri::Serializer { + * // should return the binary size of an object + * static std::size_t sizeOf(const MyType& obj) { + * return binarySize(obj); + * } * - *
  • static std::size_t SizeOf(const T& obj). This function - * should return the binary size of obj. More precisely, it - * should return the size of the memory needed to store an object of type \p - * T.
  • + * // writes the object to a buffer. The buffer can be assumed to have space + * // for at least sizeOf(obj) bytes. The function must return the number of + * // bytes written to buf. + * static std::size_t write(const MyType& obj, unsigned char* buf) { + * writeMyTypeToBuf(obj, buf); + * return sizeOf(obj); + * } * - *
  • static std::size_t Write(const T& obj, unsigned char* buf) - * this function should write obj in whatever binary format is - * appropriate to the buffer buf. The function can assume that - * buf points to at least SizeOf(obj) bytes of memory. - * The return value of this function should be the number of bytes written to - * buf.
  • + * // reads an object from a buffer and assigns it to the first argument. + * // The return value must be the number of bytes read from buf. + * static std::size_t read(MyType& obj, const unsigned char* buf) { + * obj = readMyTypeFromBuf(buf); + * return sizeOf(obj); + * } + * }; + * @endcode * - *
  • static std::size_t Read(T& obj, const unsigned char* buf) - * this function should read an object of type \p T from buf and - * assign it to obj. The return value should be the number of bytes - * read from buf.
  • - *
+ * @see Serializable. */ template struct Serializer; /** * @brief Serializer specialization for trivially copyable types. - * - * This Serializer reads and writes types that are trivially copyable. In a - * nutshell, this includes all types that can be constructed via a call to - * std::memcpy. - * * @see https://en.cppreference.com/w/cpp/named_req/TriviallyCopyable */ template struct Serializer::value>> { /** * @brief Determine the size of an object. - * @param ignored the object, which is ignored. - * - * The size of a trivially copyable object is assumed to be decidable from the - * type itself. This function returns sizeof(T). + * @return the size of something of the trivially copyable type. */ - static constexpr std::size_t SizeOf(const T& ignored) { - (void)ignored; + static constexpr std::size_t sizeOf(const T& /* ignored */) { return sizeof(T); } @@ -86,9 +88,9 @@ struct Serializer::value>> { * @param obj the object to write. * @param buf the buffer to write the object to. */ - static constexpr std::size_t Write(const T& obj, unsigned char* buf) { + static constexpr std::size_t write(const T& obj, unsigned char* buf) { std::memcpy(buf, &obj, sizeof(T)); - return sizeof(T); + return sizeOf(obj); } /** @@ -97,32 +99,76 @@ struct Serializer::value>> { * @param buf the buffer to read from. * @return the number of bytes read. Equal to SizeOf(T). */ - static constexpr std::size_t Read(T& obj, const unsigned char* buf) { + static constexpr std::size_t read(T& obj, const unsigned char* buf) { std::memcpy(&obj, buf, sizeof(T)); - return SizeOf(obj); + return sizeOf(obj); + } +}; + +/** + * @brief Type used to carry information about the size of an STL vector. + */ +using StlVecSizeType = std::uint32_t; + +/** + * @brief Serializer specialization for STL vector of bytes. + */ +template <> +struct Serializer> { + public: + /** + * @brief Determine the size of the byte vector. + * @param data the vector. + * @return the size of \p data in bytes. + */ + static std::size_t sizeOf(const std::vector& data) { + return Serializer::sizeOf(data.size()) + data.size(); + } + + /** + * @brief Write a byte vector to a buffer. + * @param data the vector. + * @param buf the buffer. + * @return the number of bytes written to \p buf. + */ + static std::size_t write(const std::vector& data, + unsigned char* buf) { + const auto offset = Serializer::write(data.size(), buf); + std::memcpy(buf + offset, data.data(), data.size()); + return sizeOf(data); + } + + /** + * @brief Read a byte vector from a buffer. + * @param data where to store the byte vector read. + * @param buf the buffer to read from. + * @return the number of bytes read. + */ + static std::size_t read(std::vector& data, + const unsigned char* buf) { + StlVecSizeType size = 0; + const auto offset = Serializer::read(size, buf); + data.resize(size); + std::memcpy(data.data(), buf + offset, size); + return sizeOf(data); } }; /** - * @brief Serializer specialization for one dimensional std::vector - * types. + * @brief Serializer specialization for generic std::vector types. */ template struct Serializer> { - private: - using VectorSizeType = typename std::vector::size_type; - constexpr static auto SIZE_SIZE = sizeof(VectorSizeType); - public: /** * @brief Determine the byte size of a vector. * @param vec the vector. * @return the size of \p vec when written using this Serializer. */ - static std::size_t SizeOf(const std::vector& vec) { - auto size = SIZE_SIZE; + static std::size_t sizeOf(const std::vector& vec) { + auto size = Serializer::sizeOf(vec.size()); for (const auto& v : vec) { - size += Serializer::SizeOf(v); + size += Serializer::sizeOf(v); } return size; } @@ -133,11 +179,10 @@ struct Serializer> { * @param buf the buffer where \p vec is written to. * @return the number of bytes written to buf. */ - static std::size_t Write(const std::vector& vec, unsigned char* buf) { - Serializer::Write(vec.size(), buf); - auto offset = SIZE_SIZE; + static std::size_t write(const std::vector& vec, unsigned char* buf) { + auto offset = Serializer::write(vec.size(), buf); for (const auto& v : vec) { - offset += Serializer::Write(v, buf + offset); + offset += Serializer::write(v, buf + offset); } return offset; } @@ -151,15 +196,14 @@ struct Serializer> { * This function reads a size from \p buf and uses it to reserve * space in \p vec. Elements are then read one by one from \p buf. */ - static std::size_t Read(std::vector& vec, const unsigned char* buf) { - typename std::vector::size_type size; - Serializer::Read(size, buf); - vec.reserve(size); - auto offset = SIZE_SIZE; + static std::size_t read(std::vector& vec, const unsigned char* buf) { + StlVecSizeType size = 0; + auto offset = Serializer::read(size, buf); + vec.resize(size); for (std::size_t i = 0; i < size; ++i) { T v; - offset += Serializer::Read(v, buf + offset); - vec.emplace_back(v); + offset += Serializer::read(v, buf + offset); + vec[i] = std::move(v); } return offset; } diff --git a/include/scl/simulation/buffer.h b/include/scl/simulation/buffer.h deleted file mode 100644 index 50ee758..0000000 --- a/include/scl/simulation/buffer.h +++ /dev/null @@ -1,74 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#ifndef SCL_SIMULATION_BUFFER_H -#define SCL_SIMULATION_BUFFER_H - -#include - -namespace scl::sim { - -/** - * @brief Peer-to-peer channel interface for passing data in simulations. - * - * ChannelBuffer is interface of any communication channel between two peers in - * a simulation. Besides allowing reads and writes, a ChannelBuffer must also - * support Prepare-Commit-Rollback logic. This is needed because in case of - * simulation failures, it is necessary to undo any reads and writes performed - * on a channel. - */ -struct ChannelBuffer { - virtual ~ChannelBuffer(){}; - - /** - * @brief Read data from the channel. - * @param data the data to write. - * @param n the number of bytes to read. - */ - virtual void Read(unsigned char* data, std::size_t n) = 0; - - /** - * @brief Write data to the channel. - * @param data the data to write. - * @param n the number of bytes to write. - */ - virtual void Write(const unsigned char* data, std::size_t n) = 0; - - /** - * @brief Get the amount of bytes that can be read from this channel. - */ - virtual std::size_t Size() = 0; - - /** - * @brief Prepare reads and writes. - */ - virtual void Prepare() = 0; - - /** - * @brief Commit reads and writes. - */ - virtual void Commit() = 0; - - /** - * @brief Rollback reads and writes. - */ - virtual void Rollback() = 0; -}; - -} // namespace scl::sim - -#endif // SCL_SIMULATION_BUFFER_H diff --git a/src/scl/net/threaded_sender.cc b/include/scl/simulation/cancellation.h similarity index 54% rename from src/scl/net/threaded_sender.cc rename to include/scl/simulation/cancellation.h index 37bf0ab..7c840b8 100644 --- a/src/scl/net/threaded_sender.cc +++ b/include/scl/simulation/cancellation.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,27 +15,24 @@ * along with this program. If not, see . */ -#include "scl/net/threaded_sender.h" +#ifndef SCL_SIMULATION_CANCELLATION_H +#define SCL_SIMULATION_CANCELLATION_H -#include +#include +#include +#include -scl::net::ThreadedSenderChannel::ThreadedSenderChannel(int socket) - : m_channel(TcpChannel(socket)) { - m_sender = std::async(std::launch::async, [&]() { - while (true) { - auto data = m_send_buffer.Peek(); - if (!m_channel.Alive()) { - break; - } - m_channel.Send(data.data(), data.size()); - m_send_buffer.PopFront(); - } - }); -} +#include "scl/util/bitmap.h" -void scl::net::ThreadedSenderChannel::Close() { - m_channel.Close(); - unsigned char stop_signal = 1; - Send(&stop_signal, 1); - m_sender.wait(); -} +namespace scl::sim::details { + +/** + * @brief Exception used to signal that a coroutine has been cancelled. + */ +struct CancellationException final : public std::runtime_error { + CancellationException() : std::runtime_error("cancelled") {} +}; + +} // namespace scl::sim::details + +#endif // SCL_SIMULATION_CANCELLATION_H diff --git a/include/scl/simulation/channel.h b/include/scl/simulation/channel.h index 372a287..afac22b 100644 --- a/include/scl/simulation/channel.h +++ b/include/scl/simulation/channel.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -20,132 +20,71 @@ #include +#include "scl/coro/task.h" #include "scl/net/channel.h" #include "scl/simulation/channel_id.h" #include "scl/simulation/context.h" +#include "scl/simulation/transport.h" -namespace scl::sim { +namespace scl::sim::details { /** - * @brief Simulate a net::Channel::Close call on a channel. - * @param ctx a simulation context. - * @param id the ID of the channel making the call. - * @return the event generated. - * - * This function simply generates a CLOSE event for the current - * time of the running party. - */ -std::shared_ptr SimulateClose(std::shared_ptr ctx, - ChannelId id); - -/** - * @brief Simulate a net::Channel::Send call on a channel. - * @param ctx a simulation context. - * @param id the ID of the channel making the call. - * @param src a pointer to the data being sent. - * @param n the number of bytes being sent. - * @return the event generated. - * - * This function aims to simulate Sending data over a network, under the - * assumption that this operation is instant. This involves generating a - * SEND event with the current time of the party, minus the time it - * took to write the data in \p src unto the underlying ChannelBuffer. In order - * to determine when the \p n bytes are going to be received, this function also - * records a write operation on the context with the time in the - * SEND event for the number \p n of bytes sent. - */ -std::shared_ptr SimulateSend(std::shared_ptr ctx, - ChannelId id, - const unsigned char* src, - std::size_t n); - -/** - * @brief Simulate a net::Channel::Recv call on a channel. - * @param ctx a simulation context. - * @param id the ID of the channel making the call. - * @param dst destination for the received data. - * @param n the number of bytes to receive. - * @return the event generated. - * @throws scl::SimulationFailure in case the call could not be simulated - * - *

This function fails in the case when less than \p n bytes are available on - * the underlying ChannelBuffer. Otherwise, the function reads the requested - * number of bytes, and computes the time this data would be received. Finally, - * the function creates a RECV event with the adjusted time. - * - *

The time in the RECV event is adjusted by going through the - * recorded write operations for the sending channel - */ -std::shared_ptr SimulateRecv(std::shared_ptr ctx, - ChannelId id, - unsigned char* dst, - std::size_t n); - -/** - * @brief Simulate a net::Channel::HasData call on a channel. - * @param ctx a simulation context - * @param id the ID of the channel making the call - * @return the event generated and the whether there was data available. - * @throws scl::SimulationFailure in case the call could not be simulated - * - *

This function simulates the case where id.local checks if - * id.remote sent it data. This is done by checking if there are - * unhandled write operations by the remote party that took place before this - * function was called. - * - *

This function may fail if the last time recorded by the remote party is - * earlier than the time when this function was called. In this case, it is not - * possible to determine if there are data available. + * @brief Channel implementation used during simulations. */ -std::pair> SimulateHasData( - std::shared_ptr ctx, - ChannelId id); - -/** - * @brief Channel implementation used in simulations. - * - * SimulatedChannel wraps a SimulationContext and a ChannelId and calls out to - * sim::SimulateClose, sim::SimulateSend, sim::SimulateRecv or - * sim::SimulateHasData, which performs the actual simulation of the methods in - * the Channel interface. - */ -class Channel final : public net::Channel { +class SimulatedChannel final : public net::Channel { public: /** - * @brief Construct a new Channel for simulations. - * @param id the ID of the channel - * @param ctx a simulation context object + * @brief Construct a SimulatedChannel. + * @param cid the ID of this channel. + * @param context a context object for this channel. + * @param transport the transport to use for moving data. */ - Channel(ChannelId id, std::shared_ptr ctx) : m_id(id), m_ctx(ctx){}; + SimulatedChannel(ChannelId cid, + GlobalContext::LocalContext context, + std::shared_ptr transport) + : m_cid(cid), m_context(context), m_transport(transport) {} - void Close() override { - m_ctx->AddEvent(m_id.local, SimulateClose(m_ctx, m_id)); - } - - void Send(const unsigned char* src, std::size_t n) override { - m_ctx->AddEvent(m_id.local, SimulateSend(m_ctx, m_id, src, n)); - } + /** + * @brief Closes the channel. + * + * Creates a EventType::CLOSE event. + */ + void close() override; - std::size_t Recv(unsigned char* dst, std::size_t n) override { - m_ctx->AddEvent(m_id.local, SimulateRecv(m_ctx, m_id, dst, n)); - return n; - } + /** + * @brief Sends data on the channel. + * + * Creates a EventType::SEND event. + */ + coro::Task send(net::Packet&& packet) override; - bool HasData() override { - const auto r = SimulateHasData(m_ctx, m_id); - m_ctx->AddEvent(m_id.local, std::get<1>(r)); - return std::get<0>(r); - } + /** + * @brief Sends data on the channel. + * + * Creates a EventType::SEND event. + */ + coro::Task send(const net::Packet& packet) override; - void Send(const net::Packet& packet) override; + /** + * @brief Receives data on the channel. + * + * Creates a EventType::RECV event. + */ + coro::Task recv() override; - std::optional Recv(bool block = true) override; + /** + * @brief Checks if there is data available on this channel. + * + * Creates a EventType::HAS_DATA event. + */ + coro::Task hasData() override; private: - ChannelId m_id; - std::shared_ptr m_ctx; + ChannelId m_cid; + GlobalContext::LocalContext m_context; + std::shared_ptr m_transport; }; -} // namespace scl::sim +} // namespace scl::sim::details #endif // SCL_SIMULATION_CHANNEL_H diff --git a/include/scl/simulation/channel_id.h b/include/scl/simulation/channel_id.h index 85bf449..ccbf2f2 100644 --- a/include/scl/simulation/channel_id.h +++ b/include/scl/simulation/channel_id.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -36,14 +36,6 @@ namespace scl::sim { * in a std::map. */ struct ChannelId { - /** - * @brief Construct a new channel ID. - * @param local ID of the this (i.e., the local) party - * @param remote ID of the remote party - */ - ChannelId(std::size_t local, std::size_t remote) - : local(local), remote(remote){}; - /** * @brief ID of this party. */ @@ -55,14 +47,10 @@ struct ChannelId { std::size_t remote; /** - * @brief Flip the direction of the channel. - * - * Flip turns a channel ID {i, j} into a channel ID {j, - * i}. This is used when a party i needs the ID of the - * channel it should read from when receving data from j. + * @brief Flips the view of the this ID. */ - ChannelId Flip() const { - return ChannelId{remote, local}; + ChannelId flip() const { + return {remote, local}; } /** @@ -84,7 +72,7 @@ struct ChannelId { * @brief Print operator for ChannelId. */ friend std::ostream& operator<<(std::ostream& os, const ChannelId& cid) { - return os << "ChannelId{" << cid.local << ", " << cid.remote << "}"; + return os << "{local=" << cid.local << ", remote=" << cid.remote << "}"; } }; diff --git a/include/scl/simulation/config.h b/include/scl/simulation/config.h index 203c4f4..062172d 100644 --- a/include/scl/simulation/config.h +++ b/include/scl/simulation/config.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -86,24 +86,24 @@ class ChannelConfig { /** * @brief Create a simulation config with default values. */ - static ChannelConfig Default(); + static ChannelConfig defaultConfig(); /** * @brief Create a simulation config for a loopback connection. */ - static ChannelConfig Loopback(); + static ChannelConfig loopback(); /** * @brief The network type of the channel. */ - NetworkType Type() const { + NetworkType type() const { return m_type; } /** * @brief Bandwidth in Bits/s. */ - std::size_t Bandwidth() const { + std::size_t bandwidth() const { return m_bandwidth; } @@ -124,14 +124,14 @@ class ChannelConfig { /** * @brief Package loss in percentage. */ - double PackageLoss() const { + double packetLoss() const { return m_package_loss; } /** * @brief TCP window size. */ - std::size_t WindowSize() const { + std::size_t windowSize() const { return m_window_size; }; @@ -175,8 +175,8 @@ class ChannelConfig::Builder { /** * @brief Build the simulation config. */ - ChannelConfig Build() const { - Validate(); + ChannelConfig build() const { + validate(); return ChannelConfig{ m_type.value_or(ChannelConfig::DEFAULT_NETWORK_TYPE), m_bandwidth.value_or(ChannelConfig::DEFAULT_BANDWIDTH), @@ -191,7 +191,7 @@ class ChannelConfig::Builder { * @param type the network type. * @return the builder. */ - Builder& Type(NetworkType type) { + Builder& type(NetworkType type) { m_type = type; return *this; } @@ -201,7 +201,7 @@ class ChannelConfig::Builder { * @param bandwidth bandwidth in bits/s. * @return the builder. */ - Builder& Bandwidth(std::size_t bandwidth) { + Builder& bandwidth(std::size_t bandwidth) { m_bandwidth = bandwidth; return *this; } @@ -231,7 +231,7 @@ class ChannelConfig::Builder { * @param percentage the percent of packages being lost * @return the builder. */ - Builder& PackageLoss(double percentage) { + Builder& packetLoss(double percentage) { m_package_loss = percentage; return *this; } @@ -241,7 +241,7 @@ class ChannelConfig::Builder { * @param window_size of the TCP window in bytes * @return the builder. */ - Builder& WindowSize(std::size_t window_size) { + Builder& windowSize(std::size_t window_size) { m_window_size = window_size; return *this; } @@ -255,7 +255,7 @@ class ChannelConfig::Builder { std::optional m_window_size; // Validate the config settings before creating the actual SimulationConfig. - void Validate() const; + void validate() const; }; /** @@ -270,7 +270,7 @@ struct NetworkConfig { /** * @brief Returns the configuration of a particular channel. */ - virtual ChannelConfig Get(ChannelId channel_id) = 0; + virtual ChannelConfig get(ChannelId channel_id) = 0; }; /** @@ -282,9 +282,9 @@ struct NetworkConfig { * channels are configured according to ChannelConfig::Loopback. */ struct SimpleNetworkConfig final : public NetworkConfig { - ChannelConfig Get(ChannelId channel_id) override { - static auto config = ChannelConfig::Default(); - static auto lo = ChannelConfig::Loopback(); + ChannelConfig get(ChannelId channel_id) override { + static auto config = ChannelConfig::defaultConfig(); + static auto lo = ChannelConfig::loopback(); return (channel_id.local == channel_id.remote) ? lo : config; } diff --git a/include/scl/simulation/context.h b/include/scl/simulation/context.h index 785c7fe..d7b5514 100644 --- a/include/scl/simulation/context.h +++ b/include/scl/simulation/context.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -18,322 +18,277 @@ #ifndef SCL_SIMULATION_CONTEXT_H #define SCL_SIMULATION_CONTEXT_H +#include #include #include -#include #include #include -#include "scl/simulation/buffer.h" +#include "scl/simulation/cancellation.h" #include "scl/simulation/channel_id.h" #include "scl/simulation/config.h" #include "scl/simulation/event.h" -#include "scl/simulation/mem_channel_buffer.h" +#include "scl/simulation/hook.h" +#include "scl/util/bitmap.h" namespace scl::sim { +class SimulationContext; + +namespace details { + /** - * @brief Context for simulations. + * @brief Global context object for a simulation. + * + * GlobalContext keeps track of the events that the parties in the simulation + * generates, the timestamps of when a party sends data on a channel, and the + * local clocks of each party. */ -class Context { - private: - enum class State { PREPARE, COMMIT, ROLLBACK }; - - public: +struct GlobalContext { /** - * @brief Provides a read-only view of a Context. + * @brief Create a new global context for a simulation. + * @param number_of_parties the number of parties in the simulation. + * @param network_config the network configuration to use. + * @param hooks the hooks that should be run when an event is created. */ - class View; + static GlobalContext create(std::size_t number_of_parties, + std::unique_ptr network_config, + std::vector hooks); /** - * @brief A write operation on the channel. + * @brief The number of parties. */ - struct WriteOp { - /** - * @brief Construct a new WriteOp. - * @param amount the amount of data in the write operation. - * @param time the time of the write operation. - */ - WriteOp(std::size_t amount, util::Time::Duration time) - : amount(amount), time(time) {} - /** - * @brief The amount of data written. - */ - std::size_t amount; - - /** - * @brief When the data was written. - */ - util::Time::Duration time; - }; + std::size_t number_of_parties; /** - * @brief Create a new SimulationContext. - * @tparam ChannelBufferT the type of the channel buffer to use - * @param number_of_parties the number of parties in the simulation - * @param config config for the simulated network - * @return a pointer to a SimulationConfig. - * - * This factory method handles the non-trivial setup related to the network - * buffers. A specialization exists for each of the ChannelBuffer - * implementations that currently exist in SCL. + * @brief The network configuration for the simulation. */ - template - static std::shared_ptr Create(std::size_t number_of_parties, - std::shared_ptr config); + std::unique_ptr network_config; /** - * @brief Construct a new simulation context. - * @param config a config describing the simulated network - * - * This constructor simply sets the network config for the context but - * otherwise performs no initialization whatsoever. Use Create instead. + * @brief The simulation traces. */ - Context(std::shared_ptr config) - : m_network_config(config), m_nparties(0) {} + std::vector traces; /** - * @brief Get the config for a channel. - * @param channel_id the ID of the channel - * @return a SimulatedNetworkConfig for the channel. - */ - ChannelConfig ChannelConfiguration(ChannelId channel_id) const { - return m_network_config->Get(channel_id); - } - - /** - * @brief Get the number of parties in the simulation. + * @brief Current unhandled packets in the network. + * + * This is a mapping from a channel to timestamps of calls to send on the + * channel that have not yet been received. */ - std::size_t NumberOfParties() const { - return m_nparties; - } + std::unordered_map> sends; /** - * @brief Get the channel buffer for a particular channel. - * @param id the ID of the channel - * @return the channel buffer. + * @brief The local clocks for each party. */ - std::shared_ptr Buffer(ChannelId id) { - return m_buffers[id]; - } + std::vector clocks; /** - * @brief Add a write operation. - * @param id the identifier of the channel that the write occured on. - * @param n the number of bytes written. - * @param time the time the write happened. + * @brief Map of parties currently in the process of receiving data. */ - void AddWrite(ChannelId id, std::size_t n, util::Time::Duration time) { - m_writes[id].emplace(n, time); - } + std::vector recv_map; /** - * @brief Check if a channel has any unprocessed writes on it. - * @param id the identifier for the channel. - * @return true if the channel has unprocessed writes. False otherwise. + * @brief Map used to indicate which parties have been stopped. */ - bool HasWrite(ChannelId id) const { - return !(m_writes.find(id) == m_writes.end() || m_writes.at(id).empty()); - } + mutable util::Bitmap cancellation_map; /** - * @brief Get the next write on a channel. - * @param id the identifier of the channel. - * @return a write operation. - * - * This method does not check if there are any writes. + * @brief Hooks. */ - WriteOp& NextWrite(ChannelId id) { - return m_writes[id].front(); - } + std::vector hooks; /** - * @brief Delete a write operation. - * @param id the identifier of the channel. + * @brief A local version of a GlobalContext. * - * This method is meant to be called after a write operation has had all its - * data processed. In a nutshell, when op.amount == 0. + * LocalContext provides a local mutable "view" of the GlobalContext for a + * particular party. */ - void DeleteWrite(ChannelId id) { - m_writes[id].pop(); - } + class LocalContext { + public: + /** + * @brief Add an event to this party's simulation trace. + * @param event the event. + */ + void recordEvent(std::shared_ptr event); - /** - * @brief Add an event. - * @param id the ID of the party adding the event - * @param event the event - */ - void AddEvent(std::size_t id, std::shared_ptr event) { - m_traces[id].emplace_back(event); - } + /** + * @brief Indicate that this party is sending data to another party. + * @param receiver the ID of the receiving party. + * @param timestamp when the data was sent. + */ + void send(std::size_t receiver, util::Time::Duration timestamp) { + const ChannelId id{.local = m_id, .remote = receiver}; + m_gctx.sends[id].push_back(timestamp); + } - /** - * @brief Get all simulation traces. - */ - std::vector Trace() const { - return m_traces; - } + /** + * @brief Receives an amount of bytes. + * @param sender the ID of sending party. + * @param nbytes the amount of bytes that this party wishes to receive. + * @param timestamp the local time of the receiving party. + * @return \p timestamp adjusted with an appropriate delay. + * + *

The return value is \p timestamp adjusted to account for any delay + * that this party would incur in receiving \p nbytes. + */ + util::Time::Duration recv(std::size_t sender, + std::size_t nbytes, + util::Time::Duration timestamp); - /** - * @brief Get the simulation trace of a particular party. - */ - SimulationTrace Trace(std::size_t id) const { - return m_traces[id]; - } + /** + * @brief Indicate that this party has started receiving data. + */ + void recvStart(std::size_t id); - /** - * @brief Check if a party has terminated. - * @param id the ID of the party. - * @return true if the party has terminated, and otherwise false. - */ - bool HasTerminated(std::size_t id) const { - if (Trace(id).empty()) { - return false; - } - const auto t = Trace(id).back()->EventType(); - return t == sim::Event::Type::STOP || t == sim::Event::Type::KILLED; - } + /** + * @brief Indicate that this party has stopped receiving data. + */ + void recvDone(std::size_t id); - /** - * @brief Remove and return the last event added by a party. - */ - std::shared_ptr PopLastEvent(std::size_t id) { - auto evt = m_traces[id].back(); - m_traces[id].pop_back(); - return evt; - } + /** + * @brief Check of a party is in the process of receiving from us. + */ + bool receiving(std::size_t receiver) const; - /** - * @brief Get the latest timestamp of a particular party. - */ - util::Time::Duration LatestTimestamp(std::size_t id) const { - return Trace(id).back()->Timestamp(); - } + /** + * @brief Check if a party has terminated. + */ + bool dead(std::size_t id) const; - /** - * @brief Find the ID of a suitable next party to run in the simulation. - * @param current the last party to run - * @return the ID of the next party to run, or none if the simulation is done. - */ - std::optional NextToRun(std::optional current = {}); + /** + * @brief Returns the amount of elapsed so far. + * + * The amount of elapsed time is defined as the current running time + * (defined as the timestamp on the last event produced by this party) plus + * the time elapsed since the startClock was called. + */ + util::Time::Duration elapsedTime() const; - /** - * @brief Add a candidate party to run next. - */ - void AddCandidateToRun(std::size_t id) { - m_next_party_cand.emplace_back(id); - }; + /** + * @brief Get the current time of some other party in the protocol. + * @param other_party the ID of the other party. + */ + util::Time::Duration currentTimeOf(std::size_t other_party) const; - /** - * @brief Update the checkpoint value to the current time. - */ - void UpdateCheckpoint() { - m_checkpoint = util::Time::Now(); - } + /** + * @brief Start the clock for this party. + * + * This internally sets the timestamp used to compute the elapsed time. + * Thus, this function should be called whenever the party starts doing + * "real work". E.g., just before a send or receive call on a simulated + * channel returns. + */ + void startClock(); - /** - * @brief Compute the time since the last time Checkpoint was called. - */ - util::Time::Duration Checkpoint(std::size_t id); + /** + * @brief Get the timestamp of the most recent event. + * + * Does not check if an event exists. + */ + util::Time::Duration lastEventTimestamp() const; - /** - * @brief Get the value of the current checkpoint. - */ - util::Time::TimePoint ReadCurrentCheckpoint() const { - return m_checkpoint; - } + /** + * @brief Get a limited version of this context object. + */ + SimulationContext getContext() const; - /** - * @brief Prepare a party for running. - */ - void Prepare(std::size_t id); + private: + friend struct GlobalContext; - /** - * @brief Commit all the events and network data that a party generated. - */ - void Commit(std::size_t id); + std::size_t m_id; + GlobalContext& m_gctx; - /** - * @brief Rollback changes that a party made. - */ - void Rollback(std::size_t id); + LocalContext(std::size_t id, GlobalContext& global) + : m_id(id), m_gctx(global) {} + }; /** - * @brief Obtain a View of this context. + * @brief Get a local party's view of this context. + * @param party_id the ID of the party. + * @return a view of this context for party \p party_id. */ - View GetView(); - - private: - std::shared_ptr m_network_config; - - std::size_t m_nparties; - - std::vector m_traces; - std::size_t m_trace_index; - - std::unordered_map> m_buffers; - - State m_state = State::COMMIT; - - std::unordered_map> m_writes; - std::unordered_map> m_writes_backup; - - util::Time::TimePoint m_checkpoint; - - std::vector m_next_party_cand; + LocalContext view(std::size_t party_id) { + return LocalContext(party_id, *this); + } }; /** - * @brief Create a simulation context with in-memory channels. + * @brief Output a global context object to a stream. */ -template <> -std::shared_ptr Context::Create( - std::size_t number_of_parties, - std::shared_ptr config); +std::ostream& operator<<(std::ostream& os, const GlobalContext& global_ctx); + +} // namespace details /** - * @brief View of a context. + * @brief A view of the current context object of the simulation. * - * View provides a read-only view of certain parts of the current Context. + * SimulationContext provides a view of the current simulation context with + * minor options for mutability. This object is passed to a hook and allows + * reacting when different events are produced. */ -class Context::View { +class SimulationContext { public: /** - * @brief Get the trace of a party. - * @param id the ID of the party. + * @brief Get the trace of a particular party. + */ + SimulationTrace trace(std::size_t party_id) const { + return m_gctx.traces[party_id]; + } + + /** + * @brief Get the running time of a party. + */ + util::Time::Duration currentTimeOf(std::size_t party_id) const { + return m_gctx.view(m_id).currentTimeOf(party_id); + } + + /** + * @brief Check if a party is still running, or if it is dead. */ - SimulationTrace Trace(std::size_t id) const { - return m_ctx.Trace(id); + bool dead(std::size_t party_id) const { + return m_gctx.view(m_id).dead(party_id); } /** - * @brief Check if a party has terminated. - * @param id the ID of the party. - * @return true if the party has terminated, and otherwise false. + * @brief Stop a party. */ - bool HasTerminated(std::size_t id) const { - return m_ctx.HasTerminated(id); + void cancel(std::size_t party_id) const { + if (party_id != m_id) { + m_gctx.cancellation_map.set(party_id, true); + } else { + throw details::CancellationException(); + } } /** - * @brief Get the total number of parties in the simulation. + * @brief Stop the simulation. */ - std::size_t NumberOfParties() const { - return m_ctx.NumberOfParties(); + void cancelSimulation() const { + for (std::size_t i = 0; i < m_gctx.number_of_parties; i++) { + m_gctx.cancellation_map.set(i, true); + } + cancel(m_id); } private: - friend Context; + friend class details::GlobalContext::LocalContext; - View(const Context& ctx) : m_ctx(ctx) {} + std::size_t m_id; + details::GlobalContext& m_gctx; - const Context& m_ctx; + SimulationContext(std::size_t id, details::GlobalContext& context) + : m_id(id), m_gctx(context) {} }; -inline Context::View Context::GetView() { - return Context::View(*this); +namespace details { + +inline SimulationContext GlobalContext::LocalContext::getContext() const { + return SimulationContext(m_id, m_gctx); } +} // namespace details + } // namespace scl::sim #endif // SCL_SIMULATION_CONTEXT_H diff --git a/include/scl/simulation/env.h b/include/scl/simulation/env.h deleted file mode 100644 index ecb87de..0000000 --- a/include/scl/simulation/env.h +++ /dev/null @@ -1,102 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#ifndef SCL_SIMULATION_ENV_H -#define SCL_SIMULATION_ENV_H - -#include - -#include "scl/protocol/env.h" -#include "scl/simulation/context.h" -#include "scl/simulation/event.h" - -namespace scl::sim { - -/** - * @brief A ProtocolEnvironment::Clock implementation for simulated protocols. - */ -class Clock final : public proto::Env::Clock { - public: - /** - * @brief Create a new clock for simulations. - * @param ctx a simulation context. Used to read the current time of the party - * @param id the ID of the party - */ - Clock(std::shared_ptr ctx, std::size_t id) : m_ctx(ctx), m_id(id){}; - - /** - * @brief Get the total elapsed time of this party. - * - * This method will compute a running time based on the current checkpoint in - * the simulation context, and offset this with the timestamp from the last - * event that the party generated. - */ - util::Time::Duration Read() const override { - const auto now = util::Time::Now(); - const auto ts = m_ctx->LatestTimestamp(m_id); - return now - m_ctx->ReadCurrentCheckpoint() + ts; - } - - /** - * @brief Mark a checkpoint. - */ - void Checkpoint(const std::string& message) override { - m_ctx->AddEvent(m_id, std::make_shared(Read(), message)); - } - - private: - std::shared_ptr m_ctx; - std::size_t m_id; -}; - -/** - * @brief A ProtocolEnvironment::Thread implementation for simulated protocols. - */ -class ThreadCtx final : public proto::Env::Thread { - public: - /** - * @brief Create a new thread context for simulations. - * @param ctx a simulation context - * @param id the ID of the party - */ - ThreadCtx(std::shared_ptr ctx, std::size_t id) - : m_ctx(ctx), m_id(id){}; - - /** - * @brief Simulate a sleep for this party. - * @param ms the time to sleep this party in milliseconds - * - * The main thread of this party is put to sleep by generating a - * SLEEP event with the time t0 + ms, where - * t0 is the current time of the party. - */ - void Sleep(std::size_t ms) override { - const auto now = m_ctx->Checkpoint(m_id); - const auto event = std::make_shared(Event::Type::SLEEP, - now, - std::chrono::milliseconds(ms)); - m_ctx->AddEvent(m_id, event); - } - - private: - std::shared_ptr m_ctx; - std::size_t m_id; -}; - -} // namespace scl::sim - -#endif // SCL_SIMULATION_ENV_H diff --git a/include/scl/simulation/event.h b/include/scl/simulation/event.h index 183704e..7d88618 100644 --- a/include/scl/simulation/event.h +++ b/include/scl/simulation/event.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -33,346 +33,305 @@ namespace scl::sim { /** - * @brief An event generated during simulation. + * @brief Event types. */ -class Event { - public: +enum class EventType { /** - * @brief Type of the event. - */ - enum class Type { - /** - * @brief Event indicating that a party started running. - */ - START, - - /** - * @brief Event indicating that a party stopped running. - */ - STOP, + * @brief Event generated when a party starts running. + */ + START, - /** - * @brief Event indicating a network Send operation. - */ - SEND, + /** + * @brief Event generated when a party stops running. + */ + STOP, - /** - * @brief Event indicating a network Recv operation. - */ - RECV, + /** + * @brief Event generated when a party is forcibly stopped. + */ + KILLED, - /** - * @brief Event indicating that a party queried a channel for data. - */ - HAS_DATA, + /** + * @brief Event generated when a party was cancelled by the manager. + */ + CANCELLED, - /** - * @brief Event indicating a party closed a connection. - */ - CLOSE, + /** + * @brief Event generated when a channel is closed. + */ + CLOSE, - /** - * @brief Event indicating a party put its thread to sleep. - */ - SLEEP, + /** + * @brief Event generated when data is sent on a channel. + */ + SEND, - /** - * @brief Event indicating that a party produces output. - */ - OUTPUT, + /** + * @brief Event generated when data is received on a channel. + */ + RECV, - /** - * @brief Event made at the start of a protocol segment. - */ - SEGMENT_BEGIN, + /** + * @brief Event generated when a channel is queried for the presence of data. + */ + HAS_DATA, - /** - * @brief Event made at the end of a protocol segment. - */ - SEGMENT_END, + /** + * @brief Event generated when a party sleeps. + */ + SLEEP, - /** - * @brief A checkpoint recorded by the protocol. - */ - CHECKPOINT, + /** + * @brief Event generated when a party produces output. + */ + OUTPUT, - /** - * @brief Event made when a party sends a net::Packet. - */ - PACKET_SEND, + /** + * @brief Event generated at the start of a protocol. + */ + PROTOCOL_BEGIN, - /** - * @brief Event made when a party receives a net::Packet. - */ - PACKET_RECV, + /** + * @brief Event generated at the end of a protocol. + */ + PROTOCOL_END, - /** - * @brief Event made when a party is stopped prematurely. - */ - KILLED - }; +}; +/** + * @brief An event in a simulation. + */ +struct Event { /** - * @brief Construct a new measurement. - * @param type the type of the measurement - * @param timestamp the timepoint for this measurement - * @param offset an offset for the timestamp + * @brief Create an event indicating the party started running. */ - Event(Type type, util::Time::Duration timestamp, util::Time::Duration offset) - : m_type(type), m_timestamp(timestamp), m_offset(offset){}; + static std::shared_ptr start(); /** - * @brief Construct a new measurement. - * @param type the type of the measurement - * @param timestamp the timepoint for this measurement. + * @brief Create an event indicating the party stopped running. + * @param timestamp the time the party stopped running at. */ - Event(Type type, util::Time::Duration timestamp) - : Event(type, timestamp, util::Time::Duration::zero()) {} - - virtual ~Event(){}; + static std::shared_ptr stop(util::Time::Duration timestamp); /** - * @brief Get the adjusted time of this event. - * - * This will return the adjusted (i.e., "real-time") time of the event. The - * un-adjusted timestamp is Time() - Offset(). + * @brief Create an event indicating the party was killed by an exception. + * @param timestamp the time the party was stopped. + * @param reason a message describing the reason for the kill. */ - util::Time::Duration Timestamp() const { - return m_timestamp + m_offset; - } + static std::shared_ptr killed(util::Time::Duration timestamp, + const std::string& reason); /** - * @brief Get the type of this event. + * @brief Create an event indicating the party was stopped. + * @param timestamp the time the party was stopped. */ - Type EventType() const { - return m_type; - } + static std::shared_ptr cancelled(util::Time::Duration timestamp); /** - * @brief Get the offset of the timestamp of this event. + * @brief Create an event indicating that a channel was closed. + * @param timestamp the time the channel was closed. + * @param channel_id the ID of the channel. */ - util::Time::Duration Offset() const { - return m_offset; - } + static std::shared_ptr closeChannel(util::Time::Duration timestamp, + ChannelId channel_id); - private: - Type m_type; - util::Time::Duration m_timestamp; - util::Time::Duration m_offset; -}; + /** + * @brief Create an event indicating that some data was sent on a channel. + * @param timestamp the time the data was sent. + * @param channel_id the ID of the channel. + * @param amount the amount of bytes sent. + */ + static std::shared_ptr sendData(util::Time::Duration timestamp, + ChannelId channel_id, + std::size_t amount); -/** - * @brief Events related to a network channel. - */ -class NetworkEvent : public Event { - public: /** - * @brief Construct a new network measurement. - * @param type the type of the measurement - * @param timestamp the time of the measurement - * @param id the ID of the channel + * @brief Create an event indicating that some data was received on a channel. + * @param timestamp the time the data was received. + * @param channel_id the ID of the channel. + * @param amount the amount of bytes received. */ - NetworkEvent(Type type, util::Time::Duration timestamp, ChannelId id) - : Event(type, timestamp), m_id(id) {} + static std::shared_ptr recvData(util::Time::Duration timestamp, + ChannelId channel_id, + std::size_t amount); /** - * @brief Construct a new network measurement with an offset. - * @param type the type of the measurement - * @param timestamp the time of the measurement - * @param offset an offset - * @param id the ID of the channel + * @brief Create an event indicating that a channel was queried for the + * presence of data. + * @param timestamp the time of the query. + * @param channel_id the ID of the channel. */ - NetworkEvent(Type type, - util::Time::Duration timestamp, - util::Time::Duration offset, - ChannelId id) - : Event(type, timestamp, offset), m_id(id) {} + static std::shared_ptr hasData(util::Time::Duration timestamp, + ChannelId channel_id); /** - * @brief Get the ID of the local party in this network event. + * @brief Create an event indicating that the party slept. + * @param timestamp the time the party went to sleep. + * @param sleep_duration the duration of the sleep. */ - std::size_t LocalParty() const { - return m_id.local; - } + static std::shared_ptr sleep(util::Time::Duration timestamp, + util::Time::Duration sleep_duration); /** - * @brief Get the ID of the remote party in this network event. + * @brief Create an event indicating that the party produced an output. */ - std::size_t RemoteParty() const { - return m_id.remote; - } + static std::shared_ptr output(util::Time::Duration timestamp); - private: - ChannelId m_id; -}; + /** + * @brief Create an event indicating that a protocol began. + * @param timestamp the starting time of the protocol. + * @param protocol_name the name of the protocol. + */ + static std::shared_ptr protocolBegin(util::Time::Duration timestamp, + const std::string& protocol_name); -/** - * @brief Events related to data transfers on the network. - */ -class NetworkDataEvent : public NetworkEvent { - public: /** - * @brief Create a new network data event. - * @param type the type of the event. - * @param timestamp when the event took place. - * @param id the ID of the channel. - * @param amount the amount of data sent or received. + * @brief Create an event indicating that a protocol ended. + * @param timestamp the finishing time of the protocol. + * @param protocol_name the name of the protocol. */ - NetworkDataEvent(Type type, - util::Time::Duration timestamp, - ChannelId id, - std::size_t amount) - : NetworkEvent(type, timestamp, id), m_amount(amount) {} + static std::shared_ptr protocolEnd(util::Time::Duration timestamp, + const std::string& protocol_name); /** - * @brief Create a new network data event. - * @param type the type of the event. - * @param timestamp when the event took place. - * @param offset an offset to \p timestamp. - * @param id the ID of the channel. - * @param amount the amount of data sent or received. + * @brief Constructor. */ - NetworkDataEvent(Type type, - util::Time::Duration timestamp, - util::Time::Duration offset, - ChannelId id, - std::size_t amount) - : NetworkEvent(type, timestamp, offset, id), m_amount(amount) {} + Event(EventType type, util::Time::Duration timestamp) + : type(type), timestamp(timestamp) {} + + virtual ~Event() {} /** - * @brief Get the amount of data sent or received. + * @brief The event type. */ - std::size_t DataAmount() const { - return m_amount; - } + EventType type; - private: - std::size_t m_amount; + /** + * @brief The event timestamp. + */ + util::Time::Duration timestamp; }; /** - * @brief An event created when a party calls the packet recv function on a - * channel. + * @brief An event relating to a channel. */ -class PacketRecvEvent final : public NetworkDataEvent { - public: +struct ChannelEvent : public Event { /** - * @brief Create a new network data event. - * @param timestamp when the event took place. - * @param offset an offset to \p timestamp. - * @param id the ID of the channel. - * @param amount the amount of data sent or received. - * @param blocking whether the Recv call was blocking. + * @brief Constructor */ - PacketRecvEvent(util::Time::Duration timestamp, - util::Time::Duration offset, - ChannelId id, - std::size_t amount, - bool blocking) - : NetworkDataEvent(Type::PACKET_RECV, timestamp, offset, id, amount), - m_blocking(blocking) {} + ChannelEvent(EventType type, + util::Time::Duration timestamp, + ChannelId channel_id) + : Event(type, timestamp), channel_id(channel_id) {} + ~ChannelEvent() {} /** - * @brief True if the call was blocking and false otherwise. + * @brief The ID of the channel this event was created for. */ - bool Blocking() const { - return m_blocking; - } - - private: - bool m_blocking; + ChannelId channel_id; }; /** - * @brief Event created when a party calls HasData on a channel. + * @brief An event relating to a channel send or receive action. */ -class HasDataEvent final : public NetworkEvent { - public: +struct ChannelDataEvent final : public ChannelEvent { /** - * @brief Construct a new HasDataEvent. - * @param timestamp the time the event happened. - * @param id the ID of the channel. - * @param had_data whether data was available. + * @brief Constructor. */ - HasDataEvent(util::Time::Duration timestamp, ChannelId id, bool had_data) - : NetworkEvent(Type::HAS_DATA, timestamp, id), m_had_data(had_data) {} + ChannelDataEvent(EventType type, + util::Time::Duration timestamp, + ChannelId channel_id, + std::size_t amount) + : ChannelEvent(type, timestamp, channel_id), amount(amount) {} /** - * @brief Whether the call that generated this event had data. + * @brief The amount of data in this event. */ - bool HadData() const { - return m_had_data; - } - - private: - bool m_had_data; + std::size_t amount; }; /** - * @brief An event taken at the start or end of Protocol::Run. + * @brief An event relating to a sleep. */ -class SegmentEvent final : public Event { - public: +struct SleepEvent final : public Event { /** - * @brief Construct a new segment event. - * @param type the type. Either SEGMENT_BEGIN or SEGMENT_END - * @param timestamp the time of the event - * @param name the name of the segment. + * @brief Constructor. */ - SegmentEvent(Type type, util::Time::Duration timestamp, std::string name) - : Event(type, timestamp), m_name(std::move(name)){}; - + SleepEvent(EventType type, + util::Time::Duration timestamp, + util::Time::Duration sleep_duration) + : Event(type, timestamp + sleep_duration), + sleep_duration(sleep_duration) {} /** - * @brief Get the name of this segment. + * @brief The sleep duration. */ - std::string Name() const { - return m_name; - } - - private: - std::string m_name; + util::Time::Duration sleep_duration; }; /** - * @brief An event created when a protocol calls - * env.clock.Checkpoint(). + * @brief A protocol event. */ -class CheckpointEvent final : public Event { - public: +struct ProtocolEvent final : public Event { + /** + * @brief Constructor. + */ + ProtocolEvent(EventType type, + util::Time::Duration timestamp, + const std::string& protocol_name) + : Event(type, timestamp), protocol_name(protocol_name) {} /** - * @brief Create a new checkpoint event. - * @param timestamp the time of the event. - * @param id the id of the checkpoint. + * @brief The name of the protocol. */ - CheckpointEvent(util::Time::Duration timestamp, const std::string& id) - : Event(Event::Type::CHECKPOINT, timestamp), m_id(id) {} + std::string protocol_name; +}; +/** + * @brief A kill event. + */ +struct KillEvent final : public Event { /** - * @brief Get the checkpoint id. + * @brief Constructor. */ - std::string Id() const { - return m_id; - } + KillEvent(util::Time::Duration timestamp, const std::string& reason) + : Event(EventType::KILLED, timestamp), reason(reason) {} - private: - std::string m_id; + /** + * @brief The message giving a reason for the kill. + */ + std::string reason; }; /** * @brief Pretty print an event type. */ -std::ostream& operator<<(std::ostream& os, Event::Type type); +std::ostream& operator<<(std::ostream& stream, EventType type); + +/** + * @brief Pretty print an event. + */ +std::ostream& operator<<(std::ostream& stream, const Event* event); /** - * @brief Pretty print a measurement to a stream. + * @brief Pretty print an event. */ -std::ostream& operator<<(std::ostream& os, const Event* m); +inline std::ostream& operator<<(std::ostream& stream, + std::shared_ptr event) { + return stream << event.get(); +} /** - * @brief A simulation trace is simply a vector of measurements. + * @brief The execution trace of a simulation is a list of the events it + * generated. */ using SimulationTrace = std::vector>; +/** + * @brief Write a trace to an output stream. + * @param stream the stream. + * @param trace the trace. + */ +void writeTrace(std::ostream& stream, const SimulationTrace& trace); + } // namespace scl::sim #endif // SCL_SIMULATION_EVENT_H diff --git a/include/scl/simulation/hook.h b/include/scl/simulation/hook.h new file mode 100644 index 0000000..ca88ff0 --- /dev/null +++ b/include/scl/simulation/hook.h @@ -0,0 +1,112 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_SIMULATION_HOOK_H +#define SCL_SIMULATION_HOOK_H + +#include + +#include "scl/simulation/event.h" + +namespace scl::sim { + +class SimulationContext; + +/** + * @brief Interface for hooks. + * + * A hook is a piece of a code that is run in response to an event, and can + * therefore be used to add custom logging or simulation termination. + * + * @code + * struct MyHook final : public Hook { + * void run(std::size_t party_id, const ReadOnlyGlobalContext& ctx) override + * { + * auto event = (ProtocolEvent*)ctx.trace(party_id)->back(); + * std::cout << "Party " << party_id + * << " finished running " event->protocol_name; + * << std::endl; + * } + * }; + * + * // elsewhere + * + * Manager* manager = // create a Manager object + * manager->addHook(sim::EventType::PROTOCOL_END); + * @endcode + * + * The hooks are run right after the triggering event has been added to the + * party's event trace. It is therefore safe to assume that + * ctx.traces[party_id] is not empty. + * + * A party, or the simulation as a whole, can be stopped through the + * SimulationContext object that the hook receives. This is useful + * to e.g., terminate the simulation when a particular party finishes. + * + * @code + * struct MyHook final : public Hook { + * void run(std::size_t party_id, const ReadOnlyGlobalContext& ctx) override { + * // stop the other party + * ctx.cancel(1 - party); + * } + * }; + * + * // elsewhere + * + * Manager* manager = // ... + * // call the hook when a party finishes the simulation. The hook will then + * // cancel the other party, which must still be running. + * manager->addHook(sim::EventType::STOP); + * @endcode + * + * Terminating the calling party (the party indicated by the + * party_id argument) on any of the following events + *

    + *
  • sim::EventType::STOP + *
  • sim::EventType::KILLED + *
  • sim::EventType::CANCELLED + *
+ * is undefined behaviour. + * + * @see sim::Manager::addHook + */ +struct Hook { + virtual ~Hook() {} + + /** + * @brief Function to run. + */ + virtual void run(std::size_t party_id, const SimulationContext& ctx) = 0; +}; + +/** + * @brief A hook and trigger event. + */ +struct TriggerAndHook { + /** + * @brief The event to trigger the hook on. + */ + std::optional trigger; + /** + * @brief The hook. + */ + std::unique_ptr hook; +}; + +} // namespace scl::sim + +#endif // SCL_SIMULATION_HOOK_H diff --git a/include/scl/simulation/manager.h b/include/scl/simulation/manager.h index 00b1df4..b02c447 100644 --- a/include/scl/simulation/manager.h +++ b/include/scl/simulation/manager.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -21,40 +21,68 @@ #include #include #include +#include #include "scl/protocol/base.h" #include "scl/simulation/config.h" -#include "scl/simulation/context.h" +#include "scl/simulation/event.h" +#include "scl/simulation/hook.h" namespace scl::sim { /** * @brief Manager for a simulation. * - * The role of a Manager object is to describe the different parameters that - * goes into simulation, such as how the network behaves, how to handle outputs - * and for how many replications to run. + * A Manager manages certain aspects of a protocol simulation: + *
    + *
  • The number of replications in the simulation. + *
  • The protocol to simulate. + *
  • What to do with the protocol(s) output. + *
  • What network to use. + *
  • When to terminate protocol(s). + *
  • What to do when a protocol finishes. + *
+ * + * Manager only requires implementing the Manager::protocol function that is + * responsible for constructing the protocols to be simulated. Everything else + * has sensible defaults. + * + *

The Manager::protocol function

+ * + * This is one of two required function and specifies which protocol to + * simulate. The return value is an STL vector of proto::Protocol objects to + * simulate. The length of this vector is assumed by the simulator to define the + * number of parties present in the protocol. The vector is allowed to contain + * nullptr values. (These will simply correspond to parties that + * are not running any code.) + * + *

Handling simulation outputs

+ * + * The other of the required functions. Each run of the simulator produces a + * list of traces (one per party). The Manager::handleSimulatorOutput function + * decides what to do with said traces. + * + *

Hooks

+ * + * Manager::addHook makes it possible to specify "hooks" that the simulator will + * run before and after a protocols' proto::Protocol::run function is + * called. Each hook is called with the ID of the protocol, corresponding to the + * protocol's index in the vector that Manager::protocol returned, as well as a + * "read-only" view of the simulators context. + * + *

Handling protocol outputs

+ * + * Any output produced by a protocol will be passed to Manager::handleOutput, + * and customizing this function therefore allows us to e.g., check correctness + * of a protocol. */ class Manager { public: - /** - * @brief Construct a new manager. - * @param replications the number of replications to simulate. - */ - Manager(std::size_t replications) : m_replications(replications) {} - /** * @brief Destructor. */ virtual ~Manager() {} - /** - * @brief Get the number of replications. - */ - std::size_t Replications() const { - return m_replications; - } - /** * @brief Return a fresh instance of the protocol to simulate. * @@ -64,20 +92,25 @@ class Manager { * important that objects returned by this function are independent of objects * previously returned by calling this function. */ - virtual std::vector> Protocol() = 0; + virtual std::vector> protocol() = 0; + + /** + * @brief Handle the output of a simulation. + * @param party_id the ID of the party that ran in the simulation. + * @param trace the simulation trace produced by the simulator. + */ + virtual void handleSimulatorOutput(std::size_t party_id, + const SimulationTrace& trace) = 0; /** * @brief Handle the output produced by some party. - * @param replication the replication that the output was produced in. * @param party_id the ID of the party who produced the output. * @param output the output. * * The default implementation simply discards the output. */ - virtual void HandleOutput(std::size_t replication, - std::size_t party_id, - const std::any& output) { - (void)replication; + virtual void handleProtocolOutput(std::size_t party_id, + const std::any& output) { (void)party_id; (void)output; } @@ -87,61 +120,87 @@ class Manager { * * The default is to return a SimpleNetworkConfig instance. */ - virtual std::shared_ptr NetworkConfiguration() { - return std::make_shared(); + virtual std::unique_ptr networkConfiguration() const { + return std::make_unique(); } /** - * @brief Decide whether to terminate a party. - * @param party_id the ID of the party. - * @param view a view of the simulation context. + * @brief Add a new hook. + * @tparam HOOK the hook. + * @tparam HOOK_ARGS argument pack for the arguments passed to the hook. + * @param trigger the event type to trigger the hook on. + * @param args arguments to pass to the constructor of the hook. * - *

Under normal circumstances, a party is terminated when its Run function - * returns nullptr. This function can be used to terminate a - * party prematurely, e.g., after it has been running for a certain amount of - * time. + * Use this function to add sim::Hooks to the simulation. The + * hook to add is specified by the \p HOOK template argument, and the hook is + * constructed by the addHook function in a manner similar to how + * std::make_unique works. The added hook will be run every time an event of + * type \p trigger is generated. + */ + template + void addHook(EventType trigger, HOOK_ARGS&&... args) { + static_assert(std::is_base_of_v); + m_hooks.emplace_back(TriggerAndHook{ + trigger, + std::make_unique(std::forward(args)...)}); + } + + /** + * @brief Add a new hook. + * @tparam HOOK the hook. + * @tparam HOOK_ARGS argument pack for the arguments passed to the hook. + * @param args arguments to pass to the constructor of the hook. * - *

The default implementation never terminates parties prematurely. + * Use this function to add sim::Hooks to the simulation. The + * hook to add is specified by the \p HOOK template argument, and the hook is + * constructed by the addHook function in a manner similar to how + * std::make_unique works. The added hook will be run for all events. */ - virtual bool Terminate(std::size_t party_id, const Context::View& view) { - (void)party_id; - (void)view; - return false; + template + void addHook(HOOK_ARGS&&... args) { + static_assert(std::is_base_of_v); + m_hooks.emplace_back(TriggerAndHook{ + {}, + std::make_unique(std::forward(args)...)}); } private: - std::size_t m_replications; + friend void simulate(std::unique_ptr manager); + std::vector m_hooks; }; /** - * @brief A simple simulation manager which allows running a protocol once. + * @brief Manager that outputs traces to a stream. + * + * Writes simulation traces to a provided stream as a json object of the form: + * + * @code + * { + * "replication": , + * "party_id": , + * "trace": + * } + * @endcode */ -class SingleReplicationManager final : public Manager { +class ManagerWithOutputToStream : public Manager { public: /** - * @brief Construct a new SingleReplicationManager. - * @param protocol the protocol to run - */ - SingleReplicationManager( - std::vector> protocol) - : Manager(1), m_protocol(std::move(protocol)), m_used(false) {} - - /** - * @brief Get the protocol to simulate. - * @throws std::logic_error if this function is called more than once. + * @brief Create a new ManagerWithOutputToStream. + * @param stream the stream to write the output to. */ - std::vector> Protocol() { - if (m_used) { - throw std::logic_error( - "Protocol called twice on SingleReplicationManager"); - } - m_used = true; - return std::move(m_protocol); + ManagerWithOutputToStream(std::ostream& stream) : m_stream(stream) {} + + void handleSimulatorOutput(std::size_t party_id, + const SimulationTrace& trace) override { + m_stream << "{"; + m_stream << "\"party_id\":" << party_id << ","; + m_stream << "\"trace\":"; + writeTrace(m_stream, trace); + m_stream << "}" << std::endl; } private: - std::vector> m_protocol; - bool m_used; + std::ostream& m_stream; }; } // namespace scl::sim diff --git a/include/scl/simulation/mem_channel_buffer.h b/include/scl/simulation/mem_channel_buffer.h deleted file mode 100644 index 6dc256c..0000000 --- a/include/scl/simulation/mem_channel_buffer.h +++ /dev/null @@ -1,120 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#ifndef SCL_SIMULATION_MEM_CHANNEL_BUFFER_H -#define SCL_SIMULATION_MEM_CHANNEL_BUFFER_H - -#include -#include -#include -#include - -#include "scl/simulation/buffer.h" - -namespace scl::sim { - -/** - * @brief A channel buffer backed by in-memory vectors. - * - * MemoryBackedChannelBuffer works much the same as MemoryBackedChannel, in that - * it internally just holds two vectors. One for reading, and one for - * writing. The difference is that MemoryBackedChannelBuffer allows for writes - * and reads to be rolled back. - */ -class MemoryBackedChannelBuffer final : public ChannelBuffer { - // type of the internal buffer - using BufferT = std::vector; - - public: - /** - * @brief Create a channel buffer connected to itself. - */ - static std::shared_ptr CreateLoopback() { - auto buf = std::make_shared(); - return std::make_shared(buf, buf); - } - - /** - * @brief Create a pair of paired channels. - */ - static std::array, 2> CreatePaired() { - auto buf0 = std::make_shared(); - auto buf1 = std::make_shared(); - return {std::make_shared(buf0, buf1), - std::make_shared(buf1, buf0)}; - } - - /** - * @brief Create a Memory backed ChannelBuffer. - * @param write_buffer buffer for storing writes - * @param read_buffer buffer for storing reads - */ - MemoryBackedChannelBuffer(std::shared_ptr write_buffer, - std::shared_ptr read_buffer) - : m_write_buf(write_buffer), - m_read_buf(read_buffer), - m_write_ptr(0), - m_read_ptr(0){}; - - ~MemoryBackedChannelBuffer() {} - - std::size_t Size() override { - return m_read_buf->size() - m_read_ptr; - } - - void Read(unsigned char* data, std::size_t n) override { - const auto m = (BufferT::difference_type)m_read_ptr; - const auto n_ = (BufferT::difference_type)n; - std::copy(m_read_buf->begin() + m, m_read_buf->begin() + m + n_, data); - m_read_ptr += n; - } - - void Write(const unsigned char* data, std::size_t n) override { - m_write_buf->insert(m_write_buf->end(), data, data + n); - } - - void Prepare() override { - m_write_ptr = m_write_buf->size(); - m_read_ptr = 0; - } - - void Commit() override { - // erase the data that was read since Prepare and reset write/read ptr. - auto m = (BufferT::difference_type)m_read_ptr; - m_read_buf->erase(m_read_buf->begin(), m_read_buf->begin() + m); - - m_read_ptr = 0; - m_write_ptr = m_write_buf->size(); - } - - void Rollback() override { - // erase data written since Prepare and reset the read ptr. - m_write_buf->resize(m_write_ptr); - m_read_ptr = 0; - } - - private: - std::shared_ptr m_write_buf; - std::shared_ptr m_read_buf; - - std::size_t m_write_ptr; - std::size_t m_read_ptr; -}; - -} // namespace scl::sim - -#endif // SCL_SIMULATION_MEM_CHANNEL_BUFFER_H diff --git a/include/scl/simulation/result.h b/include/scl/simulation/result.h deleted file mode 100644 index 4a538ee..0000000 --- a/include/scl/simulation/result.h +++ /dev/null @@ -1,208 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#ifndef SCL_SIMULATION_RESULT_H -#define SCL_SIMULATION_RESULT_H - -#include -#include -#include -#include -#include - -#include "scl/simulation/channel_id.h" -#include "scl/simulation/event.h" -#include "scl/simulation/measurement.h" - -namespace scl::sim { - -/** - * @brief The simulation result of a party. - * - *

Result holds measurements related to the execution of a single party from - * a simulation. A Result holds three types of information: Measurements related - * to execution time, measurements relating to the amount of data sent and - * received, and the original simulation trace(s). - * - *

The main API of Result consists of the functions ExecutionTime(), which - * provides measurements for the exeuction time of a party, and - * TransferAmounts(), which provide measurements for the amount of data sent and - * received by the party. - * - *

For both ExecutionTime() and TransferAmounts(), it is possible to specify - * a "segment" when querying for measurements, by supplying a - * std::string with the name of the segment. The name supplied must - * match the name of a proto::Protocol from the protocol being simulated. - * - *

For TransferAmounts(), it is also possible to query for data sent or - * received on a particular channel. - */ -class Result { - public: - /** - * @brief Type of a segment name. - */ - using SegmentName = std::optional; - - /** - * @brief Struct containing a measurement for a particular protocol segment. - */ - struct SegmentMeasurement { - /** - * @brief Measurement related to execution time. - */ - TimeMeasurement duration_m; - /** - * @brief Measurement relating to data sent/received. - */ - SendRecvMeasurement send_recv_m; - /** - * @brief Measurements related to individual channels. - */ - std::unordered_map channels_m; - }; - - /** - * @brief Create a simulation result from a list of simulation traces. - * @param traces the simulation traces. - * @return a list of Results; one per party. - * - *

This function is used by Simulate() to create its return value after - * running a simulation. The input to this function is a list of traces - * traces where traces[i][j] is trace from i'th - * replication of party j. - * - *

Internally, this function will collect and aggregate all traces created - * when simulation a party, and output a Result object for each party. - */ - static std::vector Create( - const std::vector>& traces); - - /** - * @brief Get the execution time. - * @param name the segment. None if the total time should be returned. - * @return a sim::TimeMeasurement. - */ - TimeMeasurement ExecutionTime(const SegmentName& name = {}) const { - return m_measurements.at(name).duration_m; - } - - /** - * @brief Get the amount of data transferred. - * @param name the segment. None if the total amount should be returned. - * @return a SendRecvMeasurement. - */ - SendRecvMeasurement TransferAmounts(const SegmentName& name = {}) const { - return m_measurements.at(name).send_recv_m; - } - - /** - * @brief Get the amount of data transferred on a particular channel. - * @param id the ID of the channel. - * @param name the segment. None if the total amount should be returned. - * @return a SendRecvMeasurement. - */ - SendRecvMeasurement TransferAmounts(std::size_t id, - const SegmentName& name = {}) const { - return m_measurements.at(name).channels_m.at(id); - } - - /** - * @brief Get the list of remote parties that this party interacted with. - * @param name the segment. None if all interactions should be returned. - * @return A list of party IDs. - * - * This returns a list of the IDs of parties that this party either sent data - * to, or received data from. - */ - std::vector Interactions(const SegmentName& name = {}) const; - - /** - * @brief Get the segment names of the protocol simulation. - * @return The names of all segments in the simulated protocol. - * - * If any of the simulated proto::Protocol segments did not specify a name, - * then the return value of this function will include - * proto::Protocol::kDefaultName. - */ - std::vector SegmentNames() const { - return m_segment_names; - } - - /** - * @brief Write a trace to a stream. - * @param stream the stream to write the trace to. - * @param replication the simulation replication - * @param name the segment. None if the entire trace should be written. - */ - void WriteTrace(std::ostream& stream, - std::size_t replication, - const SegmentName& name = {}) const; - - /** - * @brief Write the simulation result to a stream. - * @param stream the stream to write the result to. - */ - void Write(std::ostream& stream) const; - - /** - * @brief Get the simulation trace from a particular replication. - * @param replication the replication. - * @return the simulation trace from a replication. - */ - SimulationTrace Trace(std::size_t replication) const { - return m_traces[replication]; - } - - /** - * @brief Get the measurement associated with a checkpoint. - * @param key the string identifying the checkpoint. - * @return the time measurement. - */ - TimeMeasurement Checkpoint(const std::string& key) const { - return m_checkpoints.at(key); - } - - private: - static Result Create(const std::vector& traces); - - Result( - const std::vector& traces, - const std::unordered_map& measurements, - const std::unordered_map& checkpoints, - const std::vector& segment_names) - : m_traces(traces), - m_measurements(measurements), - m_checkpoints(checkpoints), - m_segment_names(segment_names){}; - - // The raw simulation trace - std::vector m_traces; - - // per-segment measurements - std::unordered_map m_measurements; - - // user made checkpoints - std::unordered_map m_checkpoints; - - // segment names - std::vector m_segment_names; -}; - -} // namespace scl::sim - -#endif // SCL_SIMULATION_RESULT_H diff --git a/include/scl/simulation/runtime.h b/include/scl/simulation/runtime.h new file mode 100644 index 0000000..2059083 --- /dev/null +++ b/include/scl/simulation/runtime.h @@ -0,0 +1,95 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_SIMULATION_RUNTIME_H +#define SCL_SIMULATION_RUNTIME_H + +#include +#include + +#include "scl/coro/runtime.h" +#include "scl/simulation/context.h" +#include "scl/simulation/event.h" + +namespace scl::sim::details { + +/** + * @brief Runtime implementation used in the simulator. + */ +class SimulatorRuntime final : public coro::Runtime { + constexpr static std::size_t MANAGER_PID = -1; + + struct Coro { + std::coroutine_handle<> coroutine; + std::function predicate; + std::size_t party_id; + }; + + public: + /** + * @brief Construct a new simulator runtime. + */ + SimulatorRuntime(GlobalContext& ctx) + : m_ctx(ctx), m_current_pid(MANAGER_PID) {} + + ~SimulatorRuntime() {} + + using coro::Runtime::schedule; + + /** + * @brief Schedule a coroutine to run for a particular party. + * @param coroutine the coroutine. + * @param id the id of the party. + * + * This function is used when scheduling the initial batch of protocols. Each + * protocol run gets assigned a party id using this function, and the ID is + * then used throughout the execution in order correctly manipulate the + * context. + */ + void scheduleWithId(std::coroutine_handle<> coroutine, std::size_t id) { + m_tq.emplace_back( + coroutine, + []() { return true; }, + id); + } + + void schedule(std::coroutine_handle<> coroutine, + std::function&& predicate) override; + + void schedule(std::coroutine_handle<> coroutine, + util::Time::Duration delay) override; + + void deschedule(std::coroutine_handle<> coroutine) override; + + std::coroutine_handle<> next() override; + + bool taskQueueEmpty() const override { + return m_tq.empty(); + } + + private: + GlobalContext& m_ctx; + + std::size_t m_current_pid; + std::list m_tq; + + void removeCancelledCoros(); +}; + +} // namespace scl::sim::details + +#endif // SCL_SIMULATION_RUNTIME_H diff --git a/include/scl/simulation/simulation.h b/include/scl/simulation/simulation.h index 3943091..21232a4 100644 --- a/include/scl/simulation/simulation.h +++ b/include/scl/simulation/simulation.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -19,20 +19,10 @@ #define SCL_SIMULATION_SIMULATION_H #include "scl/simulation/config.h" -#include "scl/simulation/result.h" #include "scl/simulation/simulator.h" /** - * @brief Utilities for simulating protocol executions. - * - *

SCL supports running a protocol, as defined via. the proto::Protocol - * interface, to be run using both a real network, as well as a simulated one. - * This will allow the user of SCL to implement a protocol once and then - * run it under either a real network, where all the parties are connected - * through pair-wise TCP channels, or a "fake" one, where parties are simply - * emulated. - * - *

Let's consider a simple example. + * @brief Protocol simulation. */ namespace scl::sim {} // namespace scl::sim diff --git a/include/scl/simulation/simulator.h b/include/scl/simulation/simulator.h index f0c610e..bddc37b 100644 --- a/include/scl/simulation/simulator.h +++ b/include/scl/simulation/simulator.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -18,12 +18,7 @@ #ifndef SCL_SIMULATION_SIMULATOR_H #define SCL_SIMULATION_SIMULATOR_H -#include -#include -#include #include -#include -#include #include #include @@ -31,56 +26,14 @@ #include "scl/simulation/config.h" #include "scl/simulation/event.h" #include "scl/simulation/manager.h" -#include "scl/simulation/result.h" namespace scl::sim { -/** - * @brief Exception used to signal that a simulation failed. - * - * In certain cases it is not possible to determine whether data is ready on a - * channel (e.g., if the receiver is chronologically ahead of the sender). This - * exception is used in these cases to "gracefully" interrupt the running party. - */ -struct SimulationFailure final : public std::runtime_error { - /** - * @brief Construct a new simulation failure exception. - */ - SimulationFailure(const char* msg) : std::runtime_error(msg){}; - SimulationFailure() : SimulationFailure("simulation failed") {} -}; - -/** - * @brief Compute the expected time that some bytes would be received. - * @param config a simulation config, detailing the network conditions - * @param n the number of bytes to receive - * @return the time it took to send \p n bytes. - * - * This function is used throughout the simulation to compute how long it takes - * for a number of bytes to arrive over the network used in the simulation. The - * number of bytes is specified by the second argument \p n while the network - * conditions (bandwidth, latency, overhead, etc...) is specified by \p config. - */ -util::Time::Duration ComputeRecvTime(const ChannelConfig& config, - std::size_t n); - /** * @brief Simulate the execution of a protocol. * @param manager a simulation manager. - * @return the simulation result. - */ -std::vector Simulate(std::unique_ptr manager); - -/** - * @brief Simulate a protocol for a single replication. - * @param protocol the protocol. - * @return the simulation result. */ -inline std::vector Simulate( - std::vector> protocol) { - return Simulate( - std::make_unique(std::move(protocol))); -} +void simulate(std::unique_ptr manager); } // namespace scl::sim diff --git a/include/scl/simulation/transport.h b/include/scl/simulation/transport.h new file mode 100644 index 0000000..1bb3ea7 --- /dev/null +++ b/include/scl/simulation/transport.h @@ -0,0 +1,116 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_SIMULATION_TRANSPORT_H +#define SCL_SIMULATION_TRANSPORT_H + +#include +#include +#include +#include + +#include "scl/coro/task.h" +#include "scl/net/packet.h" +#include "scl/simulation/channel_id.h" +#include "scl/simulation/context.h" + +namespace scl::sim::details { + +/** + * @brief Transport layer for a simulated network. + * + * Transport provides the functionality used when a simulated channel + * sends or receives data. A Transport is shared between all parties + * on the network, which allows it to e.g., only store one copy of a + * packet even if it sent to multiple parties. + */ +class Transport final { + private: + // represents either an actual packet or an index to a packet. If + // the variant is a packet, then it's because the packet was move'ed + // to the receiver, whereas an index indicates that it was copied. + using PktOrIdx = std::variant; + + // An indirect packet transfer. The count indicates how many other + // parties are waiting to receive the packet, and is incremented + // when the packet is sent and decremented when the packet is + // received. It is essentially a reference counter. + struct PktAndCount { + net::Packet packet; + std::size_t count; + }; + + public: + /** + * @brief Send a packet on the transport. + * @param cid the channel ID of the sending channel. + * @param packet the packet to send. + * + * This function will attempt to directly move the packet to the + * receiver. + */ + void send(ChannelId cid, net::Packet&& packet); + + /** + * @brief Send a packet on the transport. + * @param cid the channel ID of the sending channel. + * @param packet the packet. + * + * This function will attempt to only store one copy of the packet, + * even if it is being sent to multiple parties. A copy of the + * packet will happen when it is initially sent, and then once per + * subsequent receive of the packet. + */ + void send(ChannelId cid, const net::Packet& packet); + + /** + * @brief Check if there's data for a channel on this transport. + */ + bool hasData(ChannelId cid) const; + + /** + * @brief Receive data on a channel. + * @param cid the ID of the receiving channel. + * + * This function should only be called if there is data to be had on + * the channel. Calling it in other cases is undefined behavior. + */ + net::Packet recv(ChannelId cid); + + /** + * @brief Performs some clean-up on the transport. + * + * This function will trim the internal lists of sent packets if no + * more receivers are expected. Clean-up is performed as an explicit + * separate step, because it might invalidate existing pointers (and + * thus might not be "free" in terms of required computing). + */ + void cleanUp(GlobalContext& ctx); + + private: + // tracks p2p channels between parties + std::unordered_map> m_channels; + + // tracks packets that are potentially sent to more than one + // party. Each entry is a packet and a list of channels that the + // packet is sent on. + std::vector m_packets; +}; + +} // namespace scl::sim::details + +#endif // SCL_SIMULATION_TRANSPORT_H diff --git a/include/scl/ss/additive.h b/include/scl/ss/additive.h index 6ebb744..23e9248 100644 --- a/include/scl/ss/additive.h +++ b/include/scl/ss/additive.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -20,7 +20,7 @@ #include -#include "scl/math/vec.h" +#include "scl/math/vector.h" #include "scl/util/prg.h" namespace scl::ss { @@ -36,15 +36,15 @@ namespace scl::ss { * \f$(x_1,x_2,\dots,x_n)\f$ of values such that \f$x=\sum_i x_i\f$. * *

An additive secret-sharing output by this function is a math::Vec object, - * and so reconstructing the secret is simply shares.Sum(). + * and so reconstructing the secret is simply shares.sum(). */ template -math::Vec AdditiveShare(const T& secret, std::size_t n, util::PRG& prg) { +math::Vector additiveShare(const T& secret, std::size_t n, util::PRG& prg) { std::vector shares; shares.reserve(n); - auto sum = T::Zero(); + auto sum = T::zero(); for (std::size_t i = 0; i < n - 1; ++i) { - const auto s = T::Random(prg); + const auto s = T::random(prg); shares.emplace_back(s); sum += s; } diff --git a/include/scl/ss/feldman.h b/include/scl/ss/feldman.h index 8ae9a41..897dbe7 100644 --- a/include/scl/ss/feldman.h +++ b/include/scl/ss/feldman.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -24,26 +24,76 @@ #include #include "scl/math/lagrange.h" -#include "scl/math/vec.h" +#include "scl/math/vector.h" #include "scl/ss/shamir.h" #include "scl/util/prg.h" namespace scl::ss { /** - * @brief A Feldman secret-sharing. + * @brief A verifiable secret share for Feldman VSSS. */ -template +template +struct FeldmanShare { + /** + * @brief The group that commitments live in. + */ + using Group = GROUP; + + /** + * @brief The field that shares live in. + */ + using Field = typename GROUP::ScalarField; + + /** + * @brief The share. + */ + Field share; + + /** + * @brief The commitments. + */ + math::Vector commitments; +}; + +/** + * @brief A verifiable secret-sharing suitable for Feldman VSSS. + * + * This struct captures a set of secret shares produced by the Feldman + * verifiable secret sharing schemes. In this scheme, a secret is shared into + * \f$n\f$ shares and \f$t+1\f$ commitments. The share held by a party is one of + * the \f$n\f$ shares, and all \f$t+1\f$ commitments. + */ +template struct FeldmanSharing { + /** + * @brief The group that commitments live in. + */ + using Group = GROUP; + + /** + * @brief The field that shares live in. + */ + using Field = typename GROUP::ScalarField; + /** * @brief The shares. */ - math::Vec shares; + math::Vector shares; /** * @brief The commitments. */ - math::Vec commitments; + math::Vector commitments; + + /** + * @brief Get a particular party's share. + * @param party_id the ID of the party. + * @return \p party_id's share. + */ + FeldmanShare getShare(std::size_t party_id) const { + return {shares[party_id], commitments}; + } }; /** @@ -54,58 +104,62 @@ struct FeldmanSharing { * @param prg a PRG for creating randomness. * @return a Feldman secret-sharing. */ -template -FeldmanSharing FeldmanShare(const typename G::ScalarField& secret, - std::size_t t, - std::size_t n, - util::PRG& prg) { - const auto shares = ShamirShare(secret, t, n, prg); - - std::vector comm; +template +FeldmanSharing feldmanSecretShare( + const typename FeldmanSharing::Field& secret, + std::size_t t, + std::size_t n, + util::PRG& prg) { + const auto shares = shamirSecretShare(secret, t, n, prg); + + std::vector comm; comm.reserve(t + 1); - const auto gen = G::Generator(); - for (std::size_t i = 0; i < t + 1; ++i) { + const auto gen = GROUP::generator(); + comm.emplace_back(secret * gen); + for (std::size_t i = 0; i < t; ++i) { comm.emplace_back(shares[i] * gen); } - return {shares, math::Vec{comm}}; + return {shares, math::Vector{comm}}; } /** - * @brief A Feldman secret-share and the owner's index. + * @brief Verify a share given a set of commitments. + * @param share the share to verify. + * @param share_index the index (e.g., party ID) of the share. + * @return true if the provided share is valid for that index, and false + * otherwise. + * + * This function checks if a provided share is consistent with a set of + * commitments. */ -template -struct ShareAndIndex { - /** - * @brief The index. - */ - std::size_t index; - - /** - * @brief The share. - */ - typename G::ScalarField share; -}; +template +bool feldmanVerify(const FeldmanShare& share, std::size_t share_index) { + using F = typename GROUP::ScalarField; + const auto ns = math::Vector::range(share.commitments.size()); + const auto lb = math::computeLagrangeBasis(ns, share_index); + const auto v = + math::innerProd(lb.begin(), lb.end(), share.commitments.begin()); + return v == GROUP::generator() * share.share; +} /** * @brief Verify a share given a set of commitments. - * @param share_and_index the secret-share and its index. - * @param commits a set of commitments. + * @param share the share to verify. + * @param commitments the commitments to verify against. + * @param share_index the index (e.g., party ID) of the share. * @return true if the provided share is valid for that index, and false * otherwise. * * This function checks if a provided share is consistent with a set of * commitments. */ -template -bool FeldmanVerify(const ShareAndIndex& share_and_index, - const math::Vec& commits) { - const auto ns = - math::Vec::Range(1, commits.Size() + 1); - const auto lb = math::ComputeLagrangeBasis(ns, share_and_index.index); - const auto v = - math::UncheckedInnerProd(lb.begin(), lb.end(), commits.begin()); - return v == G::Generator() * share_and_index.share; +template +bool feldmanVerify( + const typename FeldmanShare::Field& share, + const math::Vector::Group>& commitments, + std::size_t share_index) { + return feldmanVerify({share, commitments}, share_index); } } // namespace scl::ss diff --git a/include/scl/ss/pedersen.h b/include/scl/ss/pedersen.h new file mode 100644 index 0000000..be398ac --- /dev/null +++ b/include/scl/ss/pedersen.h @@ -0,0 +1,291 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_SS_PEDERSEN_H +#define SCL_SS_PEDERSEN_H + +#include + +#include "scl/math/array.h" +#include "scl/math/lagrange.h" +#include "scl/math/poly.h" +#include "scl/math/vector.h" +#include "scl/ss/shamir.h" +#include "scl/util/prg.h" + +namespace scl::ss { + +/** + * @brief A secret share in the Pedersen VSS scheme. + * + * A PedersenShare for party \f$i\in\{0,\dots,n-1\}\f$ is a tuple + * \f$(a,r,\mathbf{A})\f$ where \f$\mathbf{A}\f$ is a vector of Pedersen + * commitments over a group \f$\mathbb{G}\f$, and \f$(a,r)\f$ are elements of + * \f$\mathbb{Z}_{ord(\mathbb{G})}\f$ corresponding to the \f$i\f$'th opening. + * That is, \f$aG+rH=\mathbf{A}[i]\f$, for suitable values \f$G,H\f$. The vector + * of commitments only explicitly lists the \f$t+1\f$ first commitments (where + * \f$t\f$ is the privacy threshold), but the rest can be computed easily, e.g., + * via \ref computeCommitmentForIndex + */ +template +struct PedersenShare { + /** + * @brief The group that commitments live in. + */ + using Group = GROUP; + + /** + * @brief The field that shares live in. + */ + using Field = typename GROUP::ScalarField; + + /** + * @brief The secret share and randomness. + */ + math::Array share; + + /** + * @brief The commitments. + */ + math::Vector commitments; + + /** + * @brief Get the commitment randomness of this share. + */ + Field getRand() const { + return share[1]; + } + + /** + * @brief Get the share part of this share. + */ + Field getShare() const { + return share[0]; + } +}; + +/** + * @brief A secret sharing for the Pedersen VSS scheme. + */ +template +struct PedersenSharing { + /** + * @brief The group that commitments live in. + */ + using Group = GROUP; + + /** + * @brief The field that shares live in. + */ + using Field = typename GROUP::ScalarField; + + /** + * @brief The shares. + */ + math::Vector> shares; + + /** + * @brief The share commitments. + */ + math::Vector commitments; + + /** + * @brief Get the share of a particular party. + * @param party_id the ID of the party. + * @return \p party_id's Pedersen share. + */ + PedersenShare getShare(std::size_t party_id) const { + return {shares[party_id], commitments}; + } +}; + +/** + * @brief Verifiably secret share a value using Pedersen VSS scheme. + * @param secret the secret. + * @param t the privacy threshold. + * @param n the number of shares to create. + * @param prg a PRG to use for creating randomness. + * @param h a curve point used in the commitments. + * @param randomness the random value to use for the secret. + * @return a PedersenSharing of \p secret. + */ +template +PedersenSharing pedersenSecretShare( + const typename PedersenSharing::Field& secret, + std::size_t t, + std::size_t n, + util::PRG& prg, + const typename PedersenSharing::Group& h, + const typename PedersenSharing::Field& randomness) { + using F = typename PedersenSharing::Field; + using G = typename PedersenSharing::Group; + + const math::Array s = {{secret, randomness}}; + const auto shares = shamirSecretShare(s, t, n, prg); + + std::vector comm; + comm.reserve(t + 1); + const auto gen = G::generator(); + comm.emplace_back(secret * gen + randomness * h); + for (std::size_t i = 0; i < t; ++i) { + comm.emplace_back(shares[i][0] * gen + shares[i][1] * h); + } + + return {shares, comm}; +} + +/** + * @brief Verifiably secret share a value using Pedersen VSS scheme. + * @param secret the secret. + * @param t the privacy threshold. + * @param n the number of shares to create. + * @param prg a PRG to use for creating randomness. + * @param h a curve point used in the commitments. + * @return a PedersenSharing of \p secret. + */ +template +PedersenSharing pedersenSecretShare( + const typename PedersenSharing::Field& secret, + std::size_t t, + std::size_t n, + util::PRG& prg, + const typename PedersenSharing::Group& h) { + using F = typename PedersenSharing::Field; + const auto rand = F::random(prg); + return pedersenSecretShare(secret, t, n, prg, h, rand); +} + +/** + * @brief Compute the commitment for a particular index. + * @param commitments the commitments of a Pedersen secret share. + * @param share_index the index of the share. + * @return the commitment of the share at \p share_index. + */ +template +GROUP computeCommitmentForIndex(const math::Vector& commitments, + std::size_t share_index) { + if (share_index < commitments.size()) { + return commitments[share_index]; + } + + using Field = typename PedersenShare::Field; + using Group = typename PedersenShare::Group; + + const auto ns = math::Vector::range(commitments.size()); + const auto lb = math::computeLagrangeBasis(ns, share_index); + return math::innerProd(lb.begin(), lb.end(), commitments.begin()); +} + +/** + * @brief Verify a Pedersen secret share. + * @param share the share to verify. + * @param share_index the evaluation index of the share. + * @param h the curve point used in the commitments. + * @return true if the share is valid and false otherwise. + */ +template +bool pedersenVerify(const PedersenShare share, + std::size_t share_index, + const typename PedersenShare::Group& h) { + using Group = typename PedersenShare::Group; + return computeCommitmentForIndex(share.commitments, share_index) == + share.getShare() * Group::generator() + share.getRand() * h; +} + +/** + * @brief Verify a Pedersen secret share. + * @param share the share and randomness to verify. + * @param commitments the share commitments. + * @param share_index the evaluation index of the share. + * @param h the curve point used in the commitments. + * @return true if the share is valid and false otherwise. + */ +template +bool pedersenVerify( + const math::Array::Field, 2>& share, + const math::Vector::Group>& commitments, + std::size_t share_index, + const typename PedersenShare::Group& h) { + return pedersenVerify({share, commitments}, share_index, h); +} + +/** + * @brief Apply a matrix to a vector of shares. + * @param begin a beginning iterator to a list of shares. + * @param end an end iterator to a list of shares. + * @param matrix the matrix. + * @return \p shares after multiplying with \p matrix. + * + * This function is useful if one wishes to randomize a vector of shares using + * e.g., a Vandermonde matrix, as in DN07. + */ +template +std::vector> apply( + const IT begin, + const IT end, + const math::Matrix::Field>& matrix) { + // stupid case + if (begin == end) { + return {}; + } + + using Group = typename PedersenShare::Group; + + const std::size_t n = matrix.rows(); + const std::size_t p = matrix.cols(); + const std::size_t m = begin->commitments.size(); + + // multiply matrix from left + std::vector> shares_out(n); + for (auto& share_out : shares_out) { + share_out.commitments = math::Vector(m); + } + + std::size_t i; + std::size_t k; + std::size_t j; + + for (i = 0; i < n; i++) { + auto b = begin; + for (k = 0; k < p; k++) { + shares_out[i].share += b->share * matrix(i, k); + for (j = 0; j < m; j++) { + shares_out[i].commitments[j] += matrix(i, k) * b->commitments[j]; + } + b++; + } + } + + return shares_out; +} + +/** + * @brief Apply a matrix to a vector of shares. + * @param shares the shares. + * @param matrix the matrix. + * @return \p shares after multiplying with \p matrix. + */ +template +std::vector> apply( + const std::vector>& shares, + const math::Matrix::Field>& matrix) { + return apply(shares.begin(), shares.end(), matrix); +} + +} // namespace scl::ss + +#endif // SCL_SS_PEDERSEN_H diff --git a/include/scl/ss/shamir.h b/include/scl/ss/shamir.h index c69b214..560a390 100644 --- a/include/scl/ss/shamir.h +++ b/include/scl/ss/shamir.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -26,10 +26,10 @@ #include #include -#include "scl/math/la.h" #include "scl/math/lagrange.h" +#include "scl/math/matrix.h" #include "scl/math/poly.h" -#include "scl/math/vec.h" +#include "scl/math/vector.h" #include "scl/util/prg.h" namespace scl::ss { @@ -49,21 +49,22 @@ namespace scl::ss { * points in which \f$f\f$ is evaluated is called the alphas. */ template -math::Vec ShamirShare(const T& secret, - std::size_t t, - std::size_t n, - util::PRG& prg) { - auto c = math::Vec::Random(t + 1, prg); +math::Vector shamirSecretShare(const T& secret, + std::size_t t, + std::size_t n, + util::PRG& prg) { + auto c = math::Vector::random(t + 1, prg); c[0] = secret; - const auto p = math::Polynomial::Create(c); + const auto p = math::Polynomial::create(c); std::vector shares; shares.reserve(n); + auto x = T::one(); for (std::size_t i = 1; i <= n; ++i) { - shares.emplace_back(p.Evaluate(T{(int)i})); + shares.emplace_back(p.evaluate(x++)); } - return math::Vec(shares); + return math::Vector(shares); } /** @@ -78,11 +79,11 @@ math::Vec ShamirShare(const T& secret, * \f$\alpha_i=\mathtt{alphas}[i]\f$ and returns \f$f(x)\f$. */ template -T ShamirRecoverP(const math::Vec& shares, - const math::Vec& alphas, +T shamirRecoverP(const math::Vector& shares, + const math::Vector& alphas, const T& x) { - const auto lb = math::ComputeLagrangeBasis(alphas, x); - return math::UncheckedInnerProd(shares.begin(), shares.end(), lb.begin()); + const auto lb = math::computeLagrangeBasis(alphas, x); + return math::innerProd(shares.begin(), shares.end(), lb.begin()); } /** @@ -96,67 +97,61 @@ T ShamirRecoverP(const math::Vec& shares, * obtained from ss::ShamirShare. */ template -T ShamirRecoverP(const math::Vec& shares) { - return ShamirRecoverP(shares, - math::Vec::Range(1, shares.Size() + 1), - T::Zero()); +T shamirRecoverP(const math::Vector& shares) { + return shamirRecoverP(shares, + math::Vector::range(1, shares.size() + 1), + T{}); } /** * @brief Recover a Shamir secret-shared secret with error detection. * @param shares the shares. * @param alphas the alphas. + * @param t the number of shares that might contain errors. + * @param d the degree of the sharing. * @param x the evaluation point. * @return a value. * @throws std::logic_error if the provided shares are not consistent. - * - * Let \f$n=\mathtt{shares.size()}\f$ and \f$t=(n-1)/2\f$. This function - * interpolates a polynomial \f$f\f$ running through \f$(s_i,\alpha_i)\f$ where - * \f$s_i=\mathtt{shares}[i]\f$, \f$\alpha_i=\mathtt{alphas}[i]\f$ for - * \f$i=1,\dots,t\f$. Note that this implies that \f$f\f$ has degree - * \f$t\f$. The interpolated polynomial must be consistent with the remaining - * shares and alphas, that is \f$f(\alpha_i)=s_i\f$ for \f$i=t+1,\dots,n\f$. If - * this is the case, then \f$f(x)\f$ is returned, otherwise an - * std::logic_error is thrown. */ template -T ShamirRecoverD(const math::Vec& shares, - const math::Vec& alphas, +T shamirRecoverD(const math::Vector& shares, + const math::Vector& alphas, + std::size_t t, + std::size_t d, const T& x) { - const std::size_t t = (shares.Size() - 1) / 2; - const std::size_t n = 2 * t + 1; - const auto ns = alphas.SubVector(t + 1); + if (shares.size() < d + t || alphas.size() < d + t) { + throw std::logic_error("not enough shares provided to detect errors"); + } + + const std::size_t m = d + 1; + const auto ns = alphas.subVector(d + 1); - for (std::size_t i = t + 1; i < n; ++i) { - // Shares are indexed starting from 1. - auto lb = math::ComputeLagrangeBasis(ns, alphas[i]); - auto yi = math::UncheckedInnerProd(shares.begin(), - shares.begin() + t + 1, - lb.begin()); + for (std::size_t i = m; i < d + t; ++i) { + auto lb = math::computeLagrangeBasis(ns, alphas[i]); + auto yi = + math::innerProd(shares.begin(), shares.begin() + m, lb.begin()); if (yi != shares[i]) { throw std::logic_error("error detected during recovery"); } } - auto lb = math::ComputeLagrangeBasis(ns, x); - return math::UncheckedInnerProd(shares.begin(), - shares.begin() + t + 1, - lb.begin()); + auto lb = math::computeLagrangeBasis(ns, x); + return math::innerProd(shares.begin(), shares.begin() + m, lb.begin()); } /** * @brief Recover a Shamir secret-shared secret with error detection. * @param shares the shares. + * @param t the degree of the sharing. * @return a value. * * This function is identical to ss::ShamirRecoverD with * \f$\mathtt{alphas}=(1,\dots,\mathtt{shares.size()}+1)\f$ and \f$x=0\f$. */ template -T ShamirRecoverD(const math::Vec& shares) { - const std::size_t t = (shares.Size() - 1) / 2; +T shamirRecoverD(const math::Vector& shares, std::size_t t) { const std::size_t n = 2 * t + 1; - return ShamirRecoverD(shares, math::Vec::Range(1, n + 1), T::Zero()); + return shamirRecoverD(shares, math::Vector::range(1, n + 1), t, t, T{}); } /** @@ -205,14 +200,14 @@ struct ErrorCorrectedSecret { *

This function can correct up to \f$t\f$ errors in the supplied shares. */ template -ErrorCorrectedSecret ShamirRecoverC(const math::Vec& shares, - const math::Vec& alphas) { - const std::size_t t = (shares.Size() - 1) / 3; +ErrorCorrectedSecret shamirRecoverC(const math::Vector& shares, + const math::Vector& alphas) { + const std::size_t t = (shares.size() - 1) / 3; const std::size_t n = 3 * t + 1; - math::Mat A(n); - math::Vec b(n); - math::Vec x(n); + math::Matrix A(n); + math::Vector b(n); + math::Vector x(n); int e; for (std::size_t k = 0; k <= t; ++k) { e = t - k; // NOLINT @@ -231,19 +226,19 @@ ErrorCorrectedSecret ShamirRecoverC(const math::Vec& shares, } } - if (SolveLinearSystem(x, A, b)) { + if (solveLinearSystem(x, A, b)) { break; } } - math::Vec cE{x.begin(), x.begin() + e + 1}; + math::Vector cE{x.begin(), x.begin() + e + 1}; cE[e] = T(1); - auto E = math::Polynomial::Create(cE); - auto Q = math::Polynomial::Create(math::Vec{x.begin() + e, x.end()}); - auto qr = Q.Divide(E); + auto E = math::Polynomial::create(cE); + auto Q = math::Polynomial::create(math::Vector{x.begin() + e, x.end()}); + auto qr = Q.divide(E); - if (!qr[1].IsZero()) { + if (!qr[1].isZero()) { throw std::logic_error("could not correct shares"); } @@ -259,8 +254,8 @@ ErrorCorrectedSecret ShamirRecoverC(const math::Vec& shares, * \f$\mathtt{alphas}=(1,\dots,\mathtt{shares.size()}+1)\f$. */ template -ErrorCorrectedSecret ShamirRecoverC(const math::Vec& shares) { - return ShamirRecoverC(shares, math::Vec::Range(1, shares.Size() + 1)); +ErrorCorrectedSecret shamirRecoverC(const math::Vector& shares) { + return shamirRecoverC(shares, math::Vector::range(1, shares.size() + 1)); } } // namespace scl::ss diff --git a/include/scl/ss/ss.h b/include/scl/ss/ss.h index 0a28c5d..0b6cce6 100644 --- a/include/scl/ss/ss.h +++ b/include/scl/ss/ss.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -20,14 +20,14 @@ #include "scl/ss/additive.h" #include "scl/ss/feldman.h" +#include "scl/ss/pedersen.h" #include "scl/ss/shamir.h" /** * @brief Secret sharing utilities. * - *

The scl::ss namespace contains a small collection of functionalities - * related to secret-sharing. Currently, SCL only provides support for Shamir, - * Feldman and Additive secret-sharing. + * The scl::ss namespace contains a small collection of functionalities + * related to secret-sharing. */ namespace scl::ss {} // namespace scl::ss diff --git a/include/scl/util/bitmap.h b/include/scl/util/bitmap.h new file mode 100644 index 0000000..53f3ad0 --- /dev/null +++ b/include/scl/util/bitmap.h @@ -0,0 +1,277 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_UTIL_BITMAP_H +#define SCL_UTIL_BITMAP_H + +#include +#include +#include +#include +#include + +#include "scl/serialization/serializer.h" + +namespace scl { +namespace util { + +/** + * @brief A simple bitmap. + * + * The Bitmap class holds bits. It serves some of the same functionality as + * std::vector. The implementation of Bitmap stores bits + * packed in objects of type Bitmap::BlockType, current unsigned + * char. As a consequence, Bitmap always stores a multiple of + * sizeof(Bitmap::BlockType) * 8 bits. Any unset bits are + * guaranteed to be 0. + */ +class Bitmap { + public: + /** + * @brief The internal block type. + */ + using BlockType = unsigned char; + + /** + * @brief Number of bits that each block stores. + */ + constexpr static std::size_t BITS_PER_BLOCK = sizeof(BlockType) * 8; + + private: + using ContainerType = std::vector; + + public: + /** + * @brief Create a Bitmap from an std::vector. + * @param bool_vec the std::vector. + * @return a Bitmap. + */ + static Bitmap fromStdVecBool(const std::vector& bool_vec) { + Bitmap bm(bool_vec.size()); + for (std::size_t i = 0; i < bool_vec.size(); ++i) { + bm.set(i, bool_vec[i]); + } + return bm; + } // LCOV_EXCL_LINE + + /** + * @brief Construct a Bitmap with some initial size. + * @param initial_size the initial size. + */ + Bitmap(std::size_t initial_size) + : m_bits(ContainerType(bytesRequired(initial_size), 0)) {} + + /** + * @brief Construct an empty Bitmap. + */ + Bitmap() : Bitmap(0) {} + + /** + * @brief Check the bit at some position. + * @param index the bit position. + * @return true if the bit at position \p index is set and 0 otherwise. + */ + bool at(std::size_t index) const { + const std::size_t block = index / BITS_PER_BLOCK; + const std::size_t block_index = index & (BITS_PER_BLOCK - 1); + return ((m_bits[block] >> block_index) & 1) == 1; + } + + /** + * @brief Set the bit at some position. + * @param index the position of the bit to set. + * @param b the value to set. + */ + void set(std::size_t index, bool b) { + const std::size_t block = index / BITS_PER_BLOCK; + const std::size_t block_index = index & (BITS_PER_BLOCK - 1); + m_bits[block] ^= (-(b ? 1 : 0) ^ m_bits[block]) & (1 << block_index); + } + + /** + * @brief Count the number of bits set in this Bitmap. + * @return the population count of this Bitmap. + */ + std::size_t count() const { + // https://stackoverflow.com/a/698108 + static const unsigned char lut[] = + {0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4}; + std::size_t count = 0; + for (const auto& b : m_bits) { + count += lut[b & 0x0f] + lut[b >> 4]; + } + return count; + } + + /** + * @brief Get the number of blocks this Bitmap uses. + * @return the number of BlockType elements used by this Bitmap. + */ + std::size_t numberOfBlocks() const { + return m_bits.size(); + } + + /** + * @brief Check if two bitmaps contain the same content. + * @param bm0 the first Bitmap. + * @param bm1 the second Bitmap. + * @return true if \p bm0 and \p bm1 are equal, false otherwise. + */ + friend bool operator==(const Bitmap& bm0, const Bitmap& bm1) { + return bm0.m_bits == bm1.m_bits; + } + + /** + * @brief Check if two bitmaps are different. + * @param bm0 the first Bitmap. + * @param bm1 the second Bitmap. + * @return false if \p bm0 and \p bm1 are equal, true otherwise. + */ + friend bool operator!=(const Bitmap& bm0, const Bitmap& bm1) { + return !(bm0 == bm1); + } + + /** + * @brief Write this bitmap to a stream. + * @param os the stream. + * @param m the bitmap. + */ + friend std::ostream& operator<<(std::ostream& os, const Bitmap& m) { + for (const BlockType& block : m.m_bits) { + os << std::bitset(block); + } + return os; + } + + /** + * @brief Compute the XOR of two bitmaps. + * @param bm0 the first bitmap. + * @param bm1 the other bitmap. + */ + friend Bitmap operator^(const Bitmap& bm0, const Bitmap& bm1) { + validateSizes(bm0, bm1); + Bitmap bm; + bm.m_bits.resize(bm0.numberOfBlocks()); + for (std::size_t i = 0; i < bm.m_bits.size(); i++) { + bm.m_bits[i] = bm0.m_bits[i] ^ bm1.m_bits[i]; + } + return bm; + } + + /** + * @brief Compute the AND of two bitmaps. + * @param bm0 the first bitmap. + * @param bm1 the other bitmap. + */ + friend Bitmap operator&(const Bitmap& bm0, const Bitmap& bm1) { + validateSizes(bm0, bm1); + Bitmap bm; + bm.m_bits.resize(bm0.numberOfBlocks()); + for (std::size_t i = 0; i < bm.m_bits.size(); i++) { + bm.m_bits[i] = bm0.m_bits[i] & bm1.m_bits[i]; + } + return bm; + } + + /** + * @brief Compute the OR of two bitmaps. + * @param bm0 the first bitmap. + * @param bm1 the other bitmap. + */ + friend Bitmap operator|(const Bitmap& bm0, const Bitmap& bm1) { + validateSizes(bm0, bm1); + Bitmap bm; + bm.m_bits.resize(bm0.numberOfBlocks()); + for (std::size_t i = 0; i < bm.m_bits.size(); i++) { + bm.m_bits[i] = bm0.m_bits[i] | bm1.m_bits[i]; + } + return bm; + } + + /** + * @brief Compute the negation of a bitmap. + * @param bm0 the bitmap. + */ + friend Bitmap operator~(const Bitmap& bm0) { + Bitmap bm; + bm.m_bits.resize(bm0.numberOfBlocks()); + for (std::size_t i = 0; i < bm.m_bits.size(); i++) { + bm.m_bits[i] = ~bm0.m_bits[i]; + } + return bm; + } + + private: + ContainerType m_bits; + + static constexpr std::size_t bytesRequired(std::size_t bits) { + return bits == 0 ? 1 : (bits - 1) / (BITS_PER_BLOCK) + 1; + } + + static void validateSizes(const util::Bitmap& bm0, const util::Bitmap& bm1) { + if (bm0.numberOfBlocks() != bm1.numberOfBlocks()) { + throw std::logic_error("bitmaps are different sizes"); + } + } + + friend scl::seri::Serializer; +}; + +} // namespace util + +namespace seri { + +/** + * @brief Serializer for util::Bitmap types. + */ +template <> +struct Serializer { + /** + * @brief Get serialized size of a util::Bitmap. + * @param bm the util::Bitmap. + * @return the size in bytes of the \p bm. + */ + static std::size_t sizeOf(const util::Bitmap& bm) { + return Serializer::sizeOf(bm.m_bits); + } + + /** + * @brief Write a util::Bitmap to a buffer. + * @param bm the util::Bitmap. + * @param buf the buffer. + * @return the number of bytes written. + */ + static std::size_t write(const util::Bitmap& bm, unsigned char* buf) { + return Serializer::write(bm.m_bits, buf); + } + + /** + * @brief Read a util::Bitmap from a buffer. + * @param bm the util::Bitmap that will store the result. + * @param buf the buffer to read the util::Bitmap from. + * @return the number of bytes read from \p buf. + */ + static std::size_t read(util::Bitmap& bm, const unsigned char* buf) { + return Serializer::read(bm.m_bits, buf); + } +}; + +} // namespace seri + +} // namespace scl + +#endif // SCL_UTIL_BITMAP_H diff --git a/include/scl/util/cmdline.h b/include/scl/util/cmdline.h index b3183d9..0ebed58 100644 --- a/include/scl/util/cmdline.h +++ b/include/scl/util/cmdline.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -19,14 +19,15 @@ #define SCL_UTIL_CMDLINE_H #include +#include #include -#include #include -#include #include #include #include #include +#include +#include namespace scl::util { @@ -38,10 +39,10 @@ namespace scl::util { * * @code * auto p = ProgramOptions::Parser("some description") - * .Add(ProgramArg::Required("foo", "int", "foo description")) - * .Add(ProgramArg::Optional("bar", "bool", "123")) - * .Add(ProgramFlag("flag")) - * .Parse(argc, argv); + * .add(ProgramArg::required("foo", "int", "foo description")) + * .add(ProgramArg::optional("bar", "bool", "123")) + * .add(ProgramFlag("flag")) + * .parse(argc, argv); * @endcode * * The above snippet will parse the argv argument vector passed to @@ -58,8 +59,8 @@ class ProgramOptions { * @param name the name of the argument. * @return true if the argument was set, false otherwise. */ - bool Has(std::string_view name) const { - return mArgs.find(name) != mArgs.end(); + bool has(std::string_view name) const { + return m_args.find(name) != m_args.end(); } /** @@ -67,8 +68,8 @@ class ProgramOptions { * @param name the name of the flag. * @return true if the flag was set, false otherwise. */ - bool FlagSet(std::string_view name) const { - return mFlags.find(name) != mFlags.end(); + bool flagSet(std::string_view name) const { + return m_flags.find(name) != m_flags.end(); } /** @@ -76,8 +77,8 @@ class ProgramOptions { * @param name the name of the argument. * @return the value of the argument, as is. */ - std::string_view Get(std::string_view name) const { - return mArgs.at(name); + std::string_view get(std::string_view name) const { + return m_args.at(name); } /** @@ -92,24 +93,24 @@ class ProgramOptions { * object. */ template - T Get(std::string_view name) const; + T get(std::string_view name) const; private: ProgramOptions( const std::unordered_map& args, const std::unordered_map& flags) - : mArgs(args), mFlags(flags){}; + : m_args(args), m_flags(flags){}; - std::unordered_map mArgs; - std::unordered_map mFlags; + std::unordered_map m_args; + std::unordered_map m_flags; }; /** * @brief Specialization of CmdArgs::Get for bool. */ template <> -inline bool ProgramOptions::Get(std::string_view name) const { - const auto v = mArgs.at(name); +inline bool ProgramOptions::get(std::string_view name) const { + const auto v = m_args.at(name); return v == "1" || v == "true"; } @@ -117,17 +118,17 @@ inline bool ProgramOptions::Get(std::string_view name) const { * @brief Specialization for CmdArgs::Get for int. */ template <> -inline int ProgramOptions::Get(std::string_view name) const { - return std::stoi(mArgs.at(name).data()); +inline int ProgramOptions::get(std::string_view name) const { + return std::stoi(m_args.at(name).data()); } /** * @brief Specialization of CmdArgs::Get for std::size_t. */ template <> -inline std::size_t ProgramOptions::Get( +inline std::size_t ProgramOptions::get( std::string_view name) const { - return std::stoul(mArgs.at(name).data()); + return std::stoul(m_args.at(name).data()); } /** @@ -140,7 +141,7 @@ struct ProgramArg { * @param type_hint a string describing the expected type. E.g., "int". * @param description a short description. */ - static ProgramArg Required(std::string_view name, + static ProgramArg required(std::string_view name, std::string_view type_hint, std::string_view description = "") { return ProgramArg{true, name, type_hint, description, {}}; @@ -153,7 +154,7 @@ struct ProgramArg { * @param default_value an optional default value. * @param description a short description. */ - static ProgramArg Optional(std::string_view name, + static ProgramArg optional(std::string_view name, std::string_view type_hint, std::optional default_value, std::string_view description = "") { @@ -163,7 +164,7 @@ struct ProgramArg { /** * @brief Whether this argument is required. */ - bool required; + bool is_required; /** * @brief The name of this argument. @@ -221,17 +222,14 @@ class ProgramOptions::Parser { * @brief Create a command-line argument parser. * @param description a short description of the program. */ - Parser(std::string_view description = "") : mDescription(description) {} + Parser(std::string_view description = "") : m_description(description) {} /** * @brief Define an argument. * @param def an argument definition. */ - Parser& Add(const ProgramArg& def) { - if (Exists(def)) { - PrintHelp("duplicate argument definition"); - } - mArgs.emplace_back(def); + Parser& add(const ProgramArg& def) { + m_args.emplace_back(def); return *this; } @@ -239,11 +237,8 @@ class ProgramOptions::Parser { * @brief Define a flag argument. * @param flag a flag definition. */ - Parser& Add(const ProgramFlag& flag) { - if (Exists(flag)) { - PrintHelp("duplicate argument definition"); - } - mFlags.emplace_back(flag); + Parser& add(const ProgramFlag& flag) { + m_flags.emplace_back(flag); return *this; } @@ -251,67 +246,91 @@ class ProgramOptions::Parser { * @brief Parse arguments. * @param argc the number of arguments. * @param argv the arguments. + * @return the program options, or an error message. * * The \p argc and \p argv are assumed to be the inputs to a programs main * function. */ - ProgramOptions Parse(int argc, char* argv[]); + std::variant parseArguments(int argc, + char* argv[]); + + /** + * @brief Parse arguments. + * @param argc the number of arguments. + * @param argv the arguments. + * @param exit_on_error whether to std::exit when parsing fails + * @return a set of program options. + */ + ProgramOptions parse(int argc, char* argv[], bool exit_on_error = true) { + auto opts = parseArguments(argc, argv); + if (opts.index() == 0) { + return std::get(opts); + } + auto error_msg = std::get(opts); + printHelp(error_msg); + if (exit_on_error) { + std::exit(error_msg.empty() ? 0 : 1); + } else { + throw std::runtime_error(error_msg.empty() ? "no error" : "error"); + } + } /** * @brief Print a help string to stdout. */ - void Help() const { - ArgListLong(std::cout); + void help() const { + argListLong(std::cout); } private: + std::string_view m_description; + std::string_view m_program_name; + + std::vector m_args; + std::vector m_flags; + template - bool Exists(const T& arg_or_flag) const; - void ArgListShort(std::ostream& stream, std::string_view program_name) const; - void ArgListLong(std::ostream& stream) const; + bool exists(const T& arg_or_flag) const; + void argListShort(std::ostream& stream, std::string_view program_name) const; + void argListLong(std::ostream& stream) const; - bool IsArg(std::string_view name) const; - bool IsFlag(std::string_view name) const; + bool isArg(std::string_view name) const; + bool isFlag(std::string_view name) const; template - void ForEachOptional(const std::list& list, P pred) const { + void forEachOptional(const std::vector& list, P pred) const { std::for_each(list.begin(), list.end(), [&](const auto e) { - if (!e.required) { + if (!e.is_required) { pred(e); } }); } template - void ForEachRequired(const std::list& list, P pred) const { + void forEachRequired(const std::vector& list, P pred) const { std::for_each(list.begin(), list.end(), [&](const auto e) { - if (e.required) { + if (e.is_required) { pred(e); } }); } - void PrintHelp(std::string_view error_msg = ""); - - std::string_view mDescription; - std::string_view mProgramName; - - std::list mArgs; - std::list mFlags; + void printHelp(std::string_view error_msg = ""); }; template -bool ProgramOptions::Parser::Exists(const T& arg_or_flag) const { - const auto exists_a = std::any_of(mArgs.begin(), mArgs.end(), [&](auto a) { +bool ProgramOptions::Parser::exists(const T& arg_or_flag) const { + const auto exists_a = std::any_of(m_args.begin(), m_args.end(), [&](auto a) { return a.name == arg_or_flag.name; }); if (exists_a) { return true; } - const auto exists_f = std::any_of(mFlags.begin(), mFlags.end(), [&](auto a) { - return a.name == arg_or_flag.name; - }); + const auto exists_f = + std::any_of(m_flags.begin(), m_flags.end(), [&](auto a) { + return a.name == arg_or_flag.name; + }); return exists_f; } diff --git a/include/scl/util/digest.h b/include/scl/util/digest.h index 0c06356..e440831 100644 --- a/include/scl/util/digest.h +++ b/include/scl/util/digest.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -32,17 +32,17 @@ namespace scl::util { * * This type is effectively std::array. */ -template -using Digest = std::array; +template +using Digest = std::array; /** * @brief Convert a digest to a string. * @param digest the digest * @return a hex representation of the digest. */ -template -std::string DigestToString(const D& digest) { - return ToHexString(digest.begin(), digest.end()); +template +std::string digestToString(const DIGEST& digest) { + return toHexString(digest.begin(), digest.end()); } } // namespace scl::util diff --git a/include/scl/util/hash.h b/include/scl/util/hash.h index 1c147f7..dc2153e 100644 --- a/include/scl/util/hash.h +++ b/include/scl/util/hash.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -31,8 +31,8 @@ namespace scl::util { * This type defults to one of the three instantiations of SHA3 that SCL * provides. */ -template -using Hash = Sha3; +template +using Hash = Sha3; } // namespace scl::util diff --git a/include/scl/util/iuf_hash.h b/include/scl/util/iuf_hash.h index e1fbef5..7b016ea 100644 --- a/include/scl/util/iuf_hash.h +++ b/include/scl/util/iuf_hash.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -25,17 +25,20 @@ #include #include -#include "scl/serialization/serializers.h" +#include "scl/serialization/serializer.h" namespace scl::util { /** - * @brief IUF interface for hash functions. - * @tparam HashImpl hash implementation + * @brief IUF (Init-Update-Finalize) interface for hash functions. + * @tparam HASH hash implementation. + * + * IUFHash provides a CRTP style interface for a hash function implementation + * * @see Sha3 * @see Sha256 */ -template +template struct IUFHash { /** * @brief Update the hash function with a set of bytes. @@ -43,8 +46,8 @@ struct IUFHash { * @param n the number of bytes. * @return the updated Hash object. */ - IUFHash& Update(const unsigned char* bytes, std::size_t n) { - static_cast(this)->Hash(bytes, n); + IUFHash& update(const unsigned char* bytes, std::size_t n) { + static_cast(this)->hash(bytes, n); return *this; }; @@ -53,8 +56,8 @@ struct IUFHash { * @param data a vector of bytes. * @return the updated Hash object. */ - IUFHash& Update(const std::vector& data) { - return Update(data.data(), data.size()); + IUFHash& update(const std::vector& data) { + return update(data.data(), data.size()); }; /** @@ -63,8 +66,8 @@ struct IUFHash { * @return the updated Hash object. */ template - IUFHash& Update(const std::array& data) { - return Update(data.data(), N); + IUFHash& update(const std::array& data) { + return update(data.data(), N); } /** @@ -72,8 +75,8 @@ struct IUFHash { * @param string the string. * @return the updated Hash object. */ - IUFHash& Update(std::string_view string) { - return Update(reinterpret_cast(string.data()), + IUFHash& update(std::string_view string) { + return update(reinterpret_cast(string.data()), string.size()); } @@ -83,20 +86,20 @@ struct IUFHash { * @return the updated Hash object. */ template - IUFHash& Update(const T& data) { + IUFHash& update(const T& data) { using Sr = seri::Serializer; - const auto size = Sr::SizeOf(data); + const auto size = Sr::sizeOf(data); const auto buf = std::make_unique(size); - Sr::Write(data, buf.get()); - return Update(buf.get(), size); + Sr::write(data, buf.get()); + return update(buf.get(), size); } /** * @brief Finalize and return the digest. * @return a digest. */ - auto Finalize() { - auto digest = static_cast(this)->Write(); + auto finalize() { + auto digest = static_cast(this)->write(); return digest; }; }; diff --git a/include/scl/simulation/measurement.h b/include/scl/util/measurement.h similarity index 59% rename from include/scl/simulation/measurement.h rename to include/scl/util/measurement.h index 2aba9d6..d9ab171 100644 --- a/include/scl/simulation/measurement.h +++ b/include/scl/util/measurement.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,8 +15,8 @@ * along with this program. If not, see . */ -#ifndef SCL_SIMULATION_MEASUREMENT_H -#define SCL_SIMULATION_MEASUREMENT_H +#ifndef SCL_UTIL_MEASUREMENT_H +#define SCL_UTIL_MEASUREMENT_H #include #include @@ -25,7 +25,7 @@ #include "scl/util/time.h" -namespace scl::sim { +namespace scl::util { /** * @brief Measurement from a simulation. @@ -37,7 +37,7 @@ class Measurement { * @brief Add a sample to this measurement. * @param sample the sample. */ - void AddSample(const T& sample) { + void addSample(const T& sample) { m_samples.emplace_back(sample); } @@ -45,15 +45,81 @@ class Measurement { * @brief Read-only access to the samples in this measurement. * @return the samples. */ - std::vector Samples() const { + std::vector samples() const { return m_samples; } + /** + * @brief Begin iterator to the samples. + */ + auto begin() const { + return m_samples.begin(); + } + + /** + * @brief End iterator to the samples. + */ + auto end() const { + return m_samples.end(); + } + + /** + * @brief Get the mean of the measurements. + */ + T mean() const { + T sum = zero(); + for (const auto& v : m_samples) { + sum += v; + } + return sum / size(); + } + + /** + * @brief Get the variance of the measurement. + */ + T var() const { + // exit early to avoid a division by 0 later. + if (size() <= 1) { + return zero(); + } + + const T mu = mean(); + T sum = zero(); + for (const auto& v : m_samples) { + sum += square(v - mu); + } + return sum / (size() - 1); + } + + /** + * @brief Get the median of the measurement. + */ + T median() const { + if (empty()) { + return zero(); + } + + const std::size_t half = size() / 2; + + if (size() % 2 == 1) { + return m_samples[half]; + } + + return (m_samples[half] + m_samples[half - 1]) / 2; + } + + /** + * @brief Get the sample standard deviation of the measurements. + */ + T stddev() const { + return sqrt(var()); + } + /** * @brief The size of this measurement, defined as the number of samples. * @return number of samples. */ - std::size_t Size() const { + std::size_t size() const { return m_samples.size(); } @@ -61,21 +127,21 @@ class Measurement { * @brief Check whether this measurement is empty. * @return true if this measurement has zero samples, and false otherwise. */ - bool Empty() const { + bool empty() const { return m_samples.empty(); } private: std::vector m_samples; + + // helpers for the stddev() function. + T square(T val) const; + T sqrt(T val) const; + T zero() const; }; /** * @brief A measurement for time related observations. - * - * This type holds measurements related to time. In particular, measurements - * concerning the execution time of protocols and protocol segments. The data - * type is util::Time::Duration, which is essentially - * std::chrono::duration. */ using TimeMeasurement = Measurement; @@ -86,9 +152,6 @@ std::ostream& operator<<(std::ostream& os, const TimeMeasurement& m); /** * @brief A measurement for data related observations. - * - * This type holds measurements related to data transfer amounts. That is, the - * amount of data that is being sent and received in some context. */ using DataMeasurement = Measurement; @@ -99,10 +162,6 @@ std::ostream& operator<<(std::ostream& os, const DataMeasurement& m); /** * @brief A measurement for data sent and received. - * - * This wraps two DataMeasurements: One for the data being sent, and one for - * data being received. This struct thus models e.g., the data that a particular - * party sends in a segment, or the data being sent on a channel. */ struct SendRecvMeasurement { /** @@ -116,6 +175,6 @@ struct SendRecvMeasurement { DataMeasurement recv; }; -} // namespace scl::sim +} // namespace scl::util -#endif // SCL_SIMULATION_MEASUREMENT_H +#endif // SCL_UTIL_MEASUREMENT_H diff --git a/include/scl/util/merkle.h b/include/scl/util/merkle.h index 01afffd..ab304f8 100644 --- a/include/scl/util/merkle.h +++ b/include/scl/util/merkle.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -20,7 +20,9 @@ #include +#include "scl/util/bitmap.h" #include "scl/util/digest.h" +#include "scl/util/merkle_proof.h" namespace scl::util { @@ -29,66 +31,55 @@ namespace scl::util { * @tparam H a hash function. * @tparam T the leaf data type. */ -template +template struct MerkleTree { /** * @brief The digest type nodes. */ - using DigestType = typename H::DigestType; + using DigestType = typename HASH::DigestType; /** - * @brief Compute a Merkle tree hash. - * @param data the date to hash. - * @return the root hash. + * @brief The proof type. */ - static DigestType Hash(const std::vector& data); + using Proof = MerkleProof; /** - * @brief A Merkle tree proof. + * @brief Compute a Merkle tree hash. + * @param data the date to hash. + * @return the root hash. */ - struct Proof { - /** - * @brief The path from a particular leaf to the root. - */ - std::vector path; - - /** - * @brief A vector describing whether at the left or right element for each - * element in a path. - */ - std::vector direction; - }; + static DigestType hash(const std::vector& data); /** * @brief Create a proof that a particular index is part of a Merkle tree. */ - static Proof Prove(const std::vector& data, std::size_t index); + static Proof prove(const std::vector& data, std::size_t index); /** * @brief Verify a Merkle tree proof. - * @param value the statement. + * @param leaf the statement. * @param root the tree root. * @param proof the proof * @return true if the */ - static bool Verify(const T& value, + static bool verify(const LEAF& leaf, const DigestType& root, const Proof& proof); private: - static std::vector HashLeafs(const std::vector& data); + static std::vector hashLeafs(const std::vector& data); }; -template -auto MerkleTree::HashLeafs(const std::vector& data) +template +auto MerkleTree::hashLeafs(const std::vector& data) -> std::vector { std::vector digests; auto sz = data.size(); digests.reserve(sz); for (const auto& d : data) { - H hash; - digests.emplace_back(hash.Update(d).Finalize()); + HASH hash; + digests.emplace_back(hash.update(d).finalize()); } // duplicate the last hash in case there's an odd number of leafs. @@ -100,9 +91,9 @@ auto MerkleTree::HashLeafs(const std::vector& data) return digests; } // LCOV_EXCL_LINE -template -auto MerkleTree::Hash(const std::vector& data) -> DigestType { - std::vector digests = HashLeafs(data); +template +auto MerkleTree::hash(const std::vector& data) -> DigestType { + std::vector digests = hashLeafs(data); auto sz = digests.size(); @@ -111,8 +102,8 @@ auto MerkleTree::Hash(const std::vector& data) -> DigestType { for (std::size_t i = 0; i < sz; i += 2) { const auto left = digests[i]; const auto right = digests[i + 1]; - H hash; - digests[j] = hash.Update(left).Update(right).Finalize(); + HASH hash; + digests[j] = hash.update(left).update(right).finalize(); j++; } @@ -128,10 +119,10 @@ auto MerkleTree::Hash(const std::vector& data) -> DigestType { return digests[0]; } -template -auto MerkleTree::Prove(const std::vector& data, std::size_t index) - -> Proof { - std::vector digests = HashLeafs(data); +template +auto MerkleTree::prove(const std::vector& data, + std::size_t index) -> Proof { + std::vector digests = hashLeafs(data); std::vector path; std::vector direction; @@ -143,8 +134,8 @@ auto MerkleTree::Prove(const std::vector& data, std::size_t index) const auto left = digests[i]; const auto right = digests[i + 1]; - H hash; - digests[j] = hash.Update(left).Update(right).Finalize(); + HASH hash; + digests[j] = hash.update(left).update(right).finalize(); if (i == index) { path.emplace_back(right); @@ -167,22 +158,22 @@ auto MerkleTree::Prove(const std::vector& data, std::size_t index) } } - return {path, direction}; + return {path, Bitmap::fromStdVecBool(direction)}; } -template -bool MerkleTree::Verify(const T& value, - const DigestType& root, - const Proof& proof) { +template +bool MerkleTree::verify(const LEAF& leaf, + const DigestType& root, + const Proof& proof) { const auto [h, d] = proof; - auto digest = H{}.Update(value).Finalize(); + auto digest = HASH{}.update(leaf).finalize(); for (std::size_t i = 0; i < h.size(); ++i) { - H hash; - if (d[i]) { - digest = hash.Update(h[i]).Update(digest).Finalize(); + HASH hash; + if (d.at(i)) { + digest = hash.update(h[i]).update(digest).finalize(); } else { - digest = hash.Update(digest).Update(h[i]).Finalize(); + digest = hash.update(digest).update(h[i]).finalize(); } } diff --git a/include/scl/util/merkle_proof.h b/include/scl/util/merkle_proof.h new file mode 100644 index 0000000..70a86f2 --- /dev/null +++ b/include/scl/util/merkle_proof.h @@ -0,0 +1,95 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_UTIL_MERKLE_PROOF_H +#define SCL_UTIL_MERKLE_PROOF_H + +#include + +#include "scl/serialization/serializer.h" +#include "scl/util/bitmap.h" + +namespace scl { +namespace util { + +/** + * @brief A Merkle tree proof. + */ +template +struct MerkleProof { + /** + * @brief The path from a particular leaf to the root. + */ + std::vector path; + + /** + * @brief A vector describing whether at the left or right element for each + * element in a path. + */ + Bitmap direction; +}; + +} // namespace util + +namespace seri { + +/** + * @brief Serializer for MerkleProof. + */ +template +struct Serializer> { + /** + * @brief Determines the size in bytes of a merkle proof. + * @param proof the proof. + * @return the size of \p proof in bytes. + */ + static std::size_t sizeOf(const util::MerkleProof& proof) { + return Serializer>::sizeOf(proof.path) + + Serializer::sizeOf(proof.direction); + } + + /** + * @brief Write a merkle proof to a buffer. + * @param proof the proof. + * @param buf the buffer. + * @return the number of bytes written to \p buf. + */ + static std::size_t write(const util::MerkleProof& proof, + unsigned char* buf) { + buf += Serializer>::write(proof.path, buf); + buf += Serializer::write(proof.direction, buf); + return sizeOf(proof); + } + + /** + * @brief Read a merkle proof from a buffer. + * @param proof the proof. + * @param buf the buffer. + * @return the number of bytes read from \p buf. + */ + static std::size_t read(util::MerkleProof& proof, + const unsigned char* buf) { + buf += Serializer>::read(proof.path, buf); + buf += Serializer::read(proof.direction, buf); + return sizeOf(proof); + } +}; + +} // namespace seri +} // namespace scl + +#endif // SCL_UTIL_MERKLE_PROOF_H diff --git a/include/scl/util/prg.h b/include/scl/util/prg.h index c19c0cb..07cb060 100644 --- a/include/scl/util/prg.h +++ b/include/scl/util/prg.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -19,13 +19,14 @@ #define SCL_UTIL_PRG_H #include +#include #include #include -#include +#include +#include #include #include -#include /** * @brief 64 bit nonce which is prepended to the counter in the PRG. @@ -69,41 +70,41 @@ class PRG { /** * @brief Size of the seed. */ - static constexpr std::size_t SeedSize() { + static constexpr std::size_t seedSize() { return BLOCK_SIZE; - }; + } /** * @brief Create a new PRG with seed 0. */ - static PRG Create(); + static PRG create(); /** * @brief Create a new PRG with a provided seed. * @param seed the seed. * @param seed_len length of the seed */ - static PRG Create(const unsigned char* seed, std::size_t seed_len); + static PRG create(const unsigned char* seed, std::size_t seed_len); /** * @brief Create a new PRG from a provided seed. * @param seed the seed. */ - static PRG Create(const std::string& seed); + static PRG create(const std::string& seed); /** * @brief Reset the PRG. * * This method allows resetting a PRG object to its initial state. */ - void Reset(); + void reset(); /** * @brief Generate random data and store it in a supplied buffer. * @param buffer the buffer * @param n how many bytes of random data to generate */ - void Next(unsigned char* buffer, std::size_t n); + void next(unsigned char* buffer, std::size_t n); /** * @brief Generate random data and store it in a supplied buffer. @@ -112,9 +113,18 @@ class PRG { * How many bytes of random data to generate is decided based on the output of * buffer.size(). */ - void Next(std::vector& buffer) { - Next(buffer.data(), buffer.size()); - }; + void next(std::vector& buffer) { + next(buffer.data(), buffer.size()); + } + + /** + * @brief Generate random data and store it in a supplied buffer. + * @param buffer the buffer. + */ + template + void next(std::array& buffer) { + Next(buffer.data(), N); + } /** * @brief Generate random data and store in in a supplied buffer. @@ -126,30 +136,30 @@ class PRG { * The capacity of \p buffer is not affected in any way by this method and it * requires that it has room for at least \p n elements. */ - void Next(std::vector& buffer, std::size_t n) { + void next(std::vector& buffer, std::size_t n) { if (buffer.size() < n) { throw std::invalid_argument("n exceeds buffer.size()"); } - Next(buffer.data(), n); - }; + next(buffer.data(), n); + } /** * @brief Generate and return random data. * @param n the number of random bytes to generate * @return the random bytes. */ - std::vector Next(std::size_t n) { + std::vector next(std::size_t n) { auto buffer = std::make_unique(n); - Next(buffer.get(), n); + next(buffer.get(), n); return std::vector(buffer.get(), buffer.get() + n); - }; + } /** * @brief The seed. */ std::array Seed() const { return m_seed; - }; + } private: PRG(std::array seed) : m_seed(seed){}; @@ -158,8 +168,8 @@ class PRG { long m_counter = PRG_INITIAL_COUNTER; BlockType m_state[11]; - void Update(); - void Init(); + void update(); + void init(); }; } // namespace scl::util diff --git a/include/scl/util/sha256.h b/include/scl/util/sha256.h index da29caf..d9a7c5b 100644 --- a/include/scl/util/sha256.h +++ b/include/scl/util/sha256.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -21,7 +21,6 @@ #include #include #include -#include #include "scl/util/digest.h" #include "scl/util/iuf_hash.h" @@ -43,12 +42,12 @@ class Sha256 final : public IUFHash { * @param bytes a pointer to a number of bytes. * @param nbytes the number of bytes. */ - void Hash(const unsigned char* bytes, std::size_t nbytes); + void hash(const unsigned char* bytes, std::size_t nbytes); /** * @brief Finalize and return the digest. */ - DigestType Write(); + DigestType write(); private: std::array m_chunk; @@ -63,9 +62,9 @@ class Sha256 final : public IUFHash { 0x1f83d9ab, 0x5be0cd19}; - void Transform(); - void Pad(); - DigestType WriteDigest(); + void transform(); + void pad(); + DigestType writeDigest(); }; } // namespace scl::util diff --git a/include/scl/util/sha3.h b/include/scl/util/sha3.h index 7919cfd..4e598db 100644 --- a/include/scl/util/sha3.h +++ b/include/scl/util/sha3.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -18,10 +18,8 @@ #ifndef SCL_UTIL_SHA3_H #define SCL_UTIL_SHA3_H -#include #include #include -#include #include "scl/util/digest.h" #include "scl/util/iuf_hash.h" @@ -32,32 +30,32 @@ namespace scl::util { * @brief SHA3 hash function. * @tparam DigestSize the output size in bits. Must be either 256, 384 or 512 */ -template -class Sha3 final : public IUFHash> { - static_assert(DigestSize == 256 || DigestSize == 384 || DigestSize == 512, +template +class Sha3 final : public IUFHash> { + static_assert(BITS == 256 || BITS == 384 || BITS == 512, "Invalid SHA3 digest size. Must be 256, 384 or 512"); public: /** * @brief The type of a SHA3 digest. */ - using DigestType = Digest; + using DigestType = Digest; /** * @brief Update the hash function with a set of bytes. * @param bytes a pointer to a number of bytes. * @param nbytes the number of bytes. */ - void Hash(const unsigned char* bytes, std::size_t nbytes); + void hash(const unsigned char* bytes, std::size_t nbytes); /** * @brief Finalize and return the digest. */ - DigestType Write(); + DigestType write(); private: static const std::size_t STATE_SIZE = 25; - static const std::size_t CAPACITY = 2 * DigestSize / (8 * sizeof(uint64_t)); + static const std::size_t CAPACITY = 2 * BITS / (8 * sizeof(uint64_t)); static const std::size_t CUTTOFF = STATE_SIZE - (CAPACITY & (~0x80000000)); uint64_t m_state[STATE_SIZE] = {0}; @@ -71,10 +69,10 @@ class Sha3 final : public IUFHash> { * @brief Keccak function. * @param state the current state */ -void Keccakf(uint64_t state[25]); +void keccakf(uint64_t state[25]); -template -void Sha3::Hash(const unsigned char* bytes, std::size_t nbytes) { +template +void Sha3::hash(const unsigned char* bytes, std::size_t nbytes) { unsigned int old_tail = (8 - m_byte_index) & 7; const unsigned char* p = bytes; @@ -96,7 +94,7 @@ void Sha3::Hash(const unsigned char* bytes, std::size_t nbytes) { m_saved = 0; if (++m_word_index == CUTTOFF) { - Keccakf(m_state); + keccakf(m_state); m_word_index = 0; } } @@ -114,7 +112,7 @@ void Sha3::Hash(const unsigned char* bytes, std::size_t nbytes) { m_state[m_word_index] ^= t; if (++m_word_index == CUTTOFF) { - Keccakf(m_state); + keccakf(m_state); m_word_index = 0; } p += sizeof(uint64_t); @@ -125,12 +123,12 @@ void Sha3::Hash(const unsigned char* bytes, std::size_t nbytes) { } } -template -auto Sha3::Write() -> Sha3::DigestType { +template +auto Sha3::write() -> Sha3::DigestType { uint64_t t = (uint64_t)(((uint64_t)(0x02 | (1 << 2))) << ((m_byte_index)*8)); m_state[m_word_index] ^= m_saved ^ t; m_state[CUTTOFF - 1] ^= 0x8000000000000000ULL; - Keccakf(m_state); + keccakf(m_state); for (std::size_t i = 0; i < STATE_SIZE; ++i) { const unsigned int t1 = (uint32_t)m_state[i]; diff --git a/include/scl/util/sign.h b/include/scl/util/sign.h index 05751d1..13a87c6 100644 --- a/include/scl/util/sign.h +++ b/include/scl/util/sign.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -21,15 +21,16 @@ #include #include "scl/math/curves/secp256k1.h" +#include "scl/math/ec.h" #include "scl/math/ff.h" namespace scl::util { /** * @brief A signature for some signature scheme. - * @tparam SignatureScheme the signature scheme. + * @tparam SIGNATURE_SCHEME the signature scheme. */ -template +template struct Signature; class ECDSA; @@ -40,14 +41,14 @@ class ECDSA; template <> struct Signature { private: - using ElementType = math::FF; + using ElementType = math::FF; public: /** * @brief The size of an ECDSA signature in bytes. */ - constexpr static std::size_t ByteSize() { - return ElementType::ByteSize() * 2; + constexpr static std::size_t byteSize() { + return ElementType::byteSize() * 2; } /** @@ -55,18 +56,18 @@ struct Signature { * @param buf the buffer. * @return an ECDSA signature. */ - static Signature Read(const unsigned char* buf) { - return {ElementType::Read(buf), - ElementType::Read(buf + ElementType::ByteSize())}; + static Signature read(const unsigned char* buf) { + return {ElementType::read(buf), + ElementType::read(buf + ElementType::byteSize())}; } /** * @brief Write an ECDSA signature to a stream. * @param buf the buffer. */ - void Write(unsigned char* buf) const { - r.Write(buf); - s.Write(buf + ElementType::ByteSize()); + void write(unsigned char* buf) const { + r.write(buf); + s.write(buf + ElementType::byteSize()); } /** @@ -88,7 +89,7 @@ class ECDSA { /** * @brief Public key type. A curve point. */ - using PublicKey = math::EC; + using PublicKey = math::EC; /** * @brief Secret key type. An element modulo the order of the curve. @@ -100,28 +101,28 @@ class ECDSA { * @param secret_key the secret key. * @return A public key. */ - static PublicKey Derive(const SecretKey& secret_key) { - return secret_key * PublicKey::Generator(); + static PublicKey derive(const SecretKey& secret_key) { + return secret_key * PublicKey::generator(); } /** * @brief Sign a message. - * @tparam D a digest type. + * @tparam DIGEST a digest type. * @param secret_key the secret key for signing. * @param digest the digest to sign. * @param prg a PRG used to select the nonce in the signature. * @return an ECDSA signature. */ - template + template static Signature Sign(const SecretKey& secret_key, - const D& digest, + const DIGEST& digest, PRG& prg) { - const auto k = SecretKey::Random(prg); - const auto R = k * PublicKey::Generator(); - const auto rx = ConversionFunc(R); - const auto h = DigestToElement(digest); + const auto k = SecretKey::random(prg); + const auto R = k * PublicKey::generator(); + const auto rx = conversionFunc(R); + const auto h = digestToElement(digest); - return {rx, k.Inverse() * (h + secret_key * rx)}; + return {rx, k.inverse() * (h + secret_key * rx)}; } /** @@ -131,17 +132,17 @@ class ECDSA { * @param digest the digest that was signed. * @return true if the signature is valid and false otherwise. */ - template - static bool Verify(const PublicKey& public_key, + template + static bool verify(const PublicKey& public_key, const Signature& signature, - const D& digest) { - const auto h = DigestToElement(digest); + const DIGEST& digest) { + const auto h = digestToElement(digest); const auto [r, s] = signature; - const auto si = s.Inverse(); - const auto R1 = (h * si) * PublicKey::Generator(); + const auto si = s.inverse(); + const auto R1 = (h * si) * PublicKey::generator(); const auto R2 = (r * si) * public_key; const auto R = R1 + R2; - return !R.PointAtInfinity() && ConversionFunc(R) == r; + return !R.isPointAtInfinity() && conversionFunc(R) == r; } /** @@ -153,11 +154,11 @@ class ECDSA { * \f$R=(r_x, r_y)\f$ and outputs a scalar as \f$r_x \mod p\f$ where \f$p\f$ * is order of a subgroup. */ - static SecretKey ConversionFunc(const PublicKey& R) { - const auto rx_f = R.ToAffine()[0]; - unsigned char rx_bytes[SecretKey::ByteSize()]; - rx_f.Write(rx_bytes); - return SecretKey::Read(rx_bytes); + static SecretKey conversionFunc(const PublicKey& R) { + const auto rx_f = R.toAffine()[0]; + unsigned char rx_bytes[SecretKey::byteSize()]; + rx_f.write(rx_bytes); + return SecretKey::read(rx_bytes); } /** @@ -165,14 +166,14 @@ class ECDSA { * @param digest the digest. * @return a scalar. */ - template - static SecretKey DigestToElement(const D& digest) { - if (digest.size() < SecretKey::ByteSize()) { - unsigned char buf[SecretKey::ByteSize()] = {0}; + template + static SecretKey digestToElement(const DIGEST& digest) { + if (digest.size() < SecretKey::byteSize()) { + unsigned char buf[SecretKey::byteSize()] = {0}; std::copy(digest.begin(), digest.end(), buf); - return SecretKey::Read(buf); + return SecretKey::read(buf); } - return SecretKey::Read(digest.data()); + return SecretKey::read(digest.data()); } }; diff --git a/include/scl/util/str.h b/include/scl/util/str.h index 66554f2..7323f40 100644 --- a/include/scl/util/str.h +++ b/include/scl/util/str.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -18,9 +18,11 @@ #ifndef SCL_UTIL_STR_H #define SCL_UTIL_STR_H +#include #include #include #include +#include namespace scl::util { @@ -44,7 +46,7 @@ namespace scl::util { * is not permitted. The input is assumed to encode an integer in big endian. */ template -T FromHexString(const std::string& s) { +T fromHexString(const std::string& s) { auto n = s.size(); if (n % 2) { throw std::invalid_argument("odd-length hex string"); @@ -67,7 +69,7 @@ T FromHexString(const std::string& s) { * @brief Convert value into a string. */ template -std::string ToHexString(const T& v) { +std::string toHexString(const T& v) { std::stringstream ss; ss << std::hex << v; return ss.str(); @@ -80,7 +82,7 @@ std::string ToHexString(const T& v) { * @return a hex representation of the digest. */ template -std::string ToHexString(It begin, It end) { +std::string toHexString(It begin, It end) { std::stringstream ss; ss << std::setfill('0') << std::hex; while (begin != end) { @@ -93,7 +95,7 @@ std::string ToHexString(It begin, It end) { * @brief ToHexString specialization for __uint128_t. */ template <> -std::string ToHexString(const __uint128_t& v); +std::string toHexString(const __uint128_t& v); } // namespace scl::util diff --git a/include/scl/util/time.h b/include/scl/util/time.h index b07ac99..6054324 100644 --- a/include/scl/util/time.h +++ b/include/scl/util/time.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -19,6 +19,7 @@ #define SCL_UTIL_TIME_H #include +#include namespace scl::util { @@ -39,11 +40,19 @@ struct Time { /** * @brief Get the current time as a TimePoint. */ - static TimePoint Now() { + static TimePoint now() { return std::chrono::steady_clock::now(); }; }; +/** + * @brief Convert a timestamp to milliseconds. + */ +inline long double timeToMillis(Time::Duration time) { + using namespace std::chrono; + return duration(time).count(); +} + } // namespace scl::util #endif // SCL_UTIL_TIME_H diff --git a/include/scl/util/traits.h b/include/scl/util/traits.h deleted file mode 100644 index 5805195..0000000 --- a/include/scl/util/traits.h +++ /dev/null @@ -1,62 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#ifndef SCL_UTIL_TRAITS_H -#define SCL_UTIL_TRAITS_H - -#include -#include -#include - -namespace scl::util { - -/// @cond - -template -struct IsStdVectorImpl : std::false_type {}; - -template -struct IsStdVectorImpl> : std::true_type {}; - -// https://stackoverflow.com/a/35207812 -template -struct HasOperatorMulImpl { - template - static auto Test(TT*) -> decltype(std::declval() * std::declval()); - - template - static auto Test(...) -> std::false_type; - - using Type = typename std::is_same(0))>::type; -}; - -/// @endcond - -/** - * @brief Trait for determining if two types can be multipled. - * @tparam T the first type. - * @tparam V the second type. - * - * This trait evalutes to an std::true_type if T operator*(V) is - * defined. - */ -template -struct HasOperatorMul : HasOperatorMulImpl::Type {}; - -} // namespace scl::util - -#endif // SCL_UTIL_TRAITS_H diff --git a/include/scl/util/util.h b/include/scl/util/util.h index 86a3e0a..61781b3 100644 --- a/include/scl/util/util.h +++ b/include/scl/util/util.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -18,6 +18,7 @@ #ifndef SCL_UTIL_UTIL_H #define SCL_UTIL_UTIL_H +#include "scl/util/cmdline.h" #include "scl/util/hash.h" #include "scl/util/prg.h" diff --git a/scripts/build_compile_commands_json.sh b/scripts/build_compile_commands_json.sh deleted file mode 100755 index 05432a2..0000000 --- a/scripts/build_compile_commands_json.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/bash - -if ! command -v bear &>/dev/null; then - echo "Cannot generate compile_commands.json because bear is missing" - exit 1 -fi - -if ! [ -f CMakeLists.txt ]; then - echo "No CMakeLists.txt found" - echo "Run this script from the project root" - exit 1 -fi - -if [ $# -eq "0" ] || ! [ -d $1 ]; then - echo "Usage: $0 [path_to_debug_build_directory]" - exit 1 -fi - -build_dir="${1}" -project_root=$(pwd) - -build_compile_commands () { - cd "${build_dir}" - make -s clean - bear -- make -s -j4 - cd $project_root -} - -build_compile_commands -cp "${build_dir}compile_commands.json" . diff --git a/scripts/build_documentation.sh b/scripts/build_documentation.sh deleted file mode 100755 index f963ef9..0000000 --- a/scripts/build_documentation.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/bash - -CONF="doc/DoxyConf" - -if ! command -v doxygen &>/dev/null; then - echo "Cannot build documentation because doxygen is missing" - exit 1 -fi - -if ! [ -f "${CONF}" ]; then - echo "DoxyConf file missing" - echo "Run this script from the project root" - exit 1 -fi - -doxygen ${CONF} . diff --git a/scripts/check_copyright_headers.py b/scripts/check_copyright_headers.py index 4c59201..d4ad6cd 100755 --- a/scripts/check_copyright_headers.py +++ b/scripts/check_copyright_headers.py @@ -2,9 +2,12 @@ import os +YEAR = "2024" +AUTHOR = "Anders Dalskov" + header = """\ /* SCL --- Secure Computation Library - * ---- THIS LINE IS IGNORED ---- + * Copyright (C) {year} {author} * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -21,21 +24,14 @@ */ """ -copyright_line = " * Copyright (C) 2023" - def check_file(filename, path): - expected_header = header.rstrip().split("\n") + expected_header = header.format(year=YEAR, author=AUTHOR).rstrip().split("\n") with open(path, 'r') as f: lines = f.readlines()[:len(expected_header) + 1] n = 0 good = True for a, b in zip(expected_header, lines): - ## copyright line - if n == 1: - if not b.startswith(copyright_line): - good = False - break - elif a.rstrip() != b.rstrip(): + if a.rstrip() != b.rstrip(): good = False break n += 1 diff --git a/scripts/check_coverage.py b/scripts/check_coverage.py deleted file mode 100755 index 0e4b016..0000000 --- a/scripts/check_coverage.py +++ /dev/null @@ -1,21 +0,0 @@ -#!/usr/bin/env python3 - -import sys -import re - -line_threshold = 100.0 - -print(f"Line coverage threshold {line_threshold}%") - -summary = open(sys.argv[1]).read().split('\n') -lines = summary[2] -pl = float(re.findall("[0-9]?[0-9][0-9].[0-9]%", lines)[0][:-1]) - -print(f"coverage: {pl}%") - -coverage_met = pl >= line_threshold - -if not coverage_met: - print("Coverage not met :(") - -exit(0 if coverage_met else 1) diff --git a/scripts/check_coverage.sh b/scripts/check_coverage.sh new file mode 100755 index 0000000..5cad43b --- /dev/null +++ b/scripts/check_coverage.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash + +set -o pipefail + +readarray -t COV <<< $(grep 'Overall coverage rate:' $1 -A 2) + +LINES=$(echo ${COV[1]} | grep -oE '[0-9]+\.[0-9]+') +FUNCS=$(echo ${COV[2]} | grep -oE '[0-9]+\.[0-9]+') + +echo "Line coverage: ${LINES}% (target: ${COV_THRESHOLD_LINES}%)" +echo "Function coverage: ${FUNCS}% (target: ${COV_THRESHOLD_FUNCS}%)" + +awk "BEGIN { if (${LINES} < ${COV_THRESHOLD_LINES}) { print \"line coverage not met\"; exit 1 } }" +awk "BEGIN { if (${FUNCS} < ${COV_THRESHOLD_FUNCS}) { print \"function coverage not met\"; exit 1 } }" diff --git a/scripts/check_formatting.sh b/scripts/check_formatting.sh index 0647985..b6cddd9 100755 --- a/scripts/check_formatting.sh +++ b/scripts/check_formatting.sh @@ -1,3 +1,7 @@ -#!/usr/bin/bash +#!/usr/bin/env bash -find . -type f \( -iname "*.h" -o -iname "*.cc" \) -exec clang-format -n {} \; +set -eo pipefail + +find include/ src/ test/ -type f \( -iname "*.h" -o -iname "*.cc" \) -exec clang-format-15 -n {} \; &> /tmp/checks.txt +cat /tmp/checks.txt +[[ ! -s /tmp/checks.txt ]] diff --git a/src/scl/coro/runtime.cc b/src/scl/coro/runtime.cc new file mode 100644 index 0000000..edc2cf2 --- /dev/null +++ b/src/scl/coro/runtime.cc @@ -0,0 +1,51 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#include "scl/coro/runtime.h" + +#include + +#include "scl/coro/task.h" + +using namespace scl; + +void coro::details::removeHandle(Runtime* runtime, + std::coroutine_handle<> handle) { + if (runtime != nullptr) { + runtime->deschedule(handle); + } +} + +std::coroutine_handle<> coro::DefaultRuntime::next() { + auto b = m_tq.begin(); + const auto e = m_tq.end(); + while (b != e) { + const auto [coro, pred] = *b; + if (pred()) { + m_tq.erase(b); + return coro; + } + b++; + } + return std::noop_coroutine(); +} + +void coro::DefaultRuntime::deschedule(std::coroutine_handle<> coroutine) { + m_tq.remove_if([&coroutine](const Pair& pair) { + return std::get<0>(pair) == coroutine; + }); +} diff --git a/src/scl/math/secp256k1_curve.cc b/src/scl/math/curves/secp256k1_curve.cc similarity index 54% rename from src/scl/math/secp256k1_curve.cc rename to src/scl/math/curves/secp256k1_curve.cc index 3c310d7..9352141 100644 --- a/src/scl/math/secp256k1_curve.cc +++ b/src/scl/math/curves/secp256k1_curve.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -19,15 +19,14 @@ #include #include -#include "./secp256k1_helpers.h" +#include "./../fields/secp256k1_helpers.h" +#include "scl/math/curves/ec_ops.h" #include "scl/math/curves/secp256k1.h" #include "scl/math/ec.h" -#include "scl/math/ec_ops.h" -#include "scl/math/fp.h" using namespace scl; -using Curve = math::Secp256k1; +using Curve = math::ec::Secp256k1; using Field = math::FF; using Point = Curve::ValueType; @@ -39,40 +38,44 @@ using Point = Curve::ValueType; #define GET_Y(point) (point)[1] #define GET_Z(point) (point)[2] -static const Field kCurveB(7); - template <> -void math::CurveSetPointAtInfinity(Point& out) { +void math::ec::setPointAtInfinity(Point& out) { out = POINT_AT_INFINITY; } namespace { -bool Valid(const Field& x, const Field& y) { +bool valid(const Field& x, const Field& y) { + static const Field b(7); + + // valid if y^2 == x^3 + b auto lhs = y * y; - auto rhs = x * x * x + kCurveB; + auto rhs = x * x * x + b; return lhs == rhs; } } // namespace template <> -void math::CurveSetAffine(Point& out, const Field& x, const Field& y) { - if (Valid(x, y)) { - out = {x, y, Field::One()}; +void math::ec::setAffine(Point& out, const Field& x, const Field& y) { + if (valid(x, y)) { + out = {x, y, Field::one()}; } else { throw std::invalid_argument("provided (x, y) not on curve"); } } template <> -std::array math::CurveToAffine(const Point& point) { - const auto Z = GET_Z(point).Inverse(); +std::array math::ec::toAffine(const Point& point) { + if (GET_Z(point) == Field::one()) { + return {GET_X(point), GET_Y(point)}; + } + const auto Z = GET_Z(point).inverse(); return {GET_X(point) * Z, GET_Y(point) * Z}; } template <> -bool math::CurveEqual(const Point& in1, const Point& in2) { +bool math::ec::equal(const Point& in1, const Point& in2) { const auto& Z1 = GET_Z(in1); const auto& Z2 = GET_Z(in2); // (X1, Y1, Z1) eqv (X2, Y2, Z2) <==> (X1 * Z2, Y1 * Z2) == (X2 * Z1, Y2 * Z2) @@ -81,17 +84,17 @@ bool math::CurveEqual(const Point& in1, const Point& in2) { } template <> -bool math::CurveIsPointAtInfinity(const Point& point) { - return GET_Z(point) == Field::Zero(); +bool math::ec::isPointAtInfinity(const Point& point) { + return GET_Z(point) == Field::zero(); } template <> -std::string math::CurveToString(const Point& point) { +std::string math::ec::toString(const Point& point) { std::string str; - if (CurveIsPointAtInfinity(point)) { + if (isPointAtInfinity(point)) { str = "EC{POINT_AT_INFINITY}"; } else { - auto ap = CurveToAffine(point); + auto ap = toAffine(point); std::stringstream ss; ss << "EC{" << ap[0] << ", " << ap[1] << "}"; str = ss.str(); @@ -100,77 +103,46 @@ std::string math::CurveToString(const Point& point) { } // LCOV_EXCL_LINE template <> -void math::CurveSetGenerator(Point& out) { +void math::ec::setGenerator(Point& out) { static const Point gen = { - Field::FromString( + Field::fromString( "79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798"), - Field::FromString( + Field::fromString( "483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8"), Field{1}}; out = gen; } -template <> -void math::CurveDouble(Point& out) { - // https://eprint.iacr.org/2015/1060.pdf algorithm 9. - - static const auto b3 = Field(3 * 7); - - auto t0 = GET_Y(out) * GET_Y(out); - auto z3 = t0 + t0; - z3 = z3 + z3; - - z3 = z3 + z3; - auto t1 = GET_Y(out) * GET_Z(out); - auto t2 = GET_Z(out) * GET_Z(out); - - t2 = b3 * t2; - auto x3 = t2 * z3; - auto y3 = t0 + t2; - - z3 = t1 * z3; - t1 = t2 + t2; - t2 = t1 + t2; - - t0 = t0 - t2; - y3 = t0 * y3; - y3 = x3 + y3; - - t1 = GET_X(out) * GET_Y(out); - x3 = t0 * t1; - x3 = x3 + x3; - - out[0] = x3; - out[1] = y3; - out[2] = z3; -} +namespace { -template <> -void math::CurveAdd(Point& out, const Point& in) { - // https://eprint.iacr.org/2015/1060.pdf algorithm 7 +void addProj(Field& x1, + Field& y1, + Field& z1, + const Field& x2, + const Field& y2, + const Field& z2) { + static const Field b3(3 * 7); - static const auto b3 = Field(3 * 7); + auto t0 = x1 * x2; + auto t1 = y1 * y2; + auto t2 = z1 * z2; - auto t0 = GET_X(out) * GET_X(in); - auto t1 = GET_Y(out) * GET_Y(in); - auto t2 = GET_Z(out) * GET_Z(in); - - auto t3 = GET_X(out) + GET_Y(out); - auto t4 = GET_X(in) + GET_Y(in); + auto t3 = x1 + y1; + auto t4 = x2 + y2; t3 = t3 * t4; t4 = t0 + t1; t3 = t3 - t4; - t4 = GET_Y(out) + GET_Z(out); + t4 = y1 + z1; - auto x3 = GET_Y(in) + GET_Z(in); + auto x3 = y2 + z2; t4 = t4 * x3; x3 = t1 + t2; t4 = t4 - x3; - x3 = GET_X(out) + GET_Z(out); - auto y3 = GET_X(in) + GET_Z(in); + x3 = x1 + z1; + auto y3 = x2 + z2; x3 = x3 * y3; y3 = t0 + t2; @@ -196,38 +168,138 @@ void math::CurveAdd(Point& out, const Point& in) { z3 = z3 * t4; z3 = z3 + t0; + x1 = x3; + y1 = y3; + z1 = z3; +} + +void addMixed(Field& x1, + Field& y1, + Field& z1, + const Field& x2, + const Field& y2) { + static const auto b3 = Field(3 * 7); + + auto t0 = x1 * x2; + auto t1 = y1 * y2; + auto t3 = x2 + y2; + + auto t4 = x1 + y1; + t3 = t3 * t4; + t4 = t0 + t1; + + t3 = t3 - t4; + t4 = y2 * z1; + t4 = t4 + y1; + + auto y3 = x2 * z1; + y3 = y3 + x1; + auto x3 = t0 + t0; + + t0 = x3 + t0; + auto t2 = b3 * z1; + auto z3 = t1 + t2; + + t1 = t1 - t2; + y3 = b3 * y3; + x3 = t4 * y3; + + t2 = t3 * t1; + x3 = t2 - x3; + y3 = y3 * t0; + + t1 = t1 * z3; + y3 = t1 + y3; + t0 = t0 * t3; + + z3 = z3 * t4; + z3 = z3 + t0; + + x1 = x3; + y1 = y3; + z1 = z3; +} + +} // namespace + +template <> +void math::ec::dbl(Point& out) { + // https://eprint.iacr.org/2015/1060.pdf algorithm 9. + + static const Field b3(3 * 7); + + auto t0 = GET_Y(out) * GET_Y(out); + auto z3 = t0 + t0; + z3 = z3 + z3; + + z3 = z3 + z3; + auto t1 = GET_Y(out) * GET_Z(out); + auto t2 = GET_Z(out) * GET_Z(out); + + t2 = b3 * t2; + auto x3 = t2 * z3; + auto y3 = t0 + t2; + + z3 = t1 * z3; + t1 = t2 + t2; + t2 = t1 + t2; + + t0 = t0 - t2; + y3 = t0 * y3; + y3 = x3 + y3; + + t1 = GET_X(out) * GET_Y(out); + x3 = t0 * t1; + x3 = x3 + x3; + out[0] = x3; out[1] = y3; out[2] = z3; } template <> -void math::CurveNegate(Point& out) { - if (GET_Y(out) == Field::Zero()) { - CurveSetPointAtInfinity(out); +void math::ec::add(Point& out, const Point& in) { + // https://eprint.iacr.org/2015/1060.pdf algorithm 7, 8 + + if (GET_Z(in) == Field::one()) { + addMixed(GET_X(out), GET_Y(out), GET_Z(out), GET_X(in), GET_Y(in)); } else { - GET_Y(out).Negate(); + addProj(GET_X(out), + GET_Y(out), + GET_Z(out), + GET_X(in), + GET_Y(in), + GET_Z(in)); } } template <> -void math::CurveSubtract(Point& out, const Point& in) { +void math::ec::negate(Point& out) { + if (GET_Y(out) == Field::zero()) { + setPointAtInfinity(out); + } else { + GET_Y(out).negate(); + } +} + +template <> +void math::ec::subtract(Point& out, const Point& in) { Point copy(in); - CurveNegate(copy); - CurveAdd(out, copy); + ec::negate(copy); + ec::add(out, copy); } template <> -void math::CurveScalarMultiply(Point& out, const Number& scalar) { - if (!CurveIsPointAtInfinity(out)) { - const auto n = scalar.BitSize(); +void math::ec::scalarMultiply(Point& out, const Number& scalar) { + if (!isPointAtInfinity(out)) { + const auto n = scalar.bitSize(); Point res; - CurveSetPointAtInfinity(res); + setPointAtInfinity(res); // equivalent to for (int i = n - 1; i >= 0; i--) for (auto i = n; i-- > 0;) { - CurveDouble(res); - if (scalar.TestBit(i)) { - CurveAdd(res, out); + dbl(res); + if (scalar.testBit(i)) { + ec::add(res, out); } } out = res; @@ -235,17 +307,18 @@ void math::CurveScalarMultiply(Point& out, const Number& scalar) { } template <> -void math::CurveScalarMultiply(Point& out, - const FF& scalar) { - if (!CurveIsPointAtInfinity(out)) { - auto x = FFAccess::FromMonty(scalar); - const auto n = FFAccess::HigestSetBit(x); +void math::ec::scalarMultiply(Point& out, + const FF& scalar) { + if (!isPointAtInfinity(out)) { Point res; - CurveSetPointAtInfinity(res); - for (auto i = n; i-- > 0;) { - CurveDouble(res); - if (FFAccess::TestBit(x, i)) { - CurveAdd(res, out); + setPointAtInfinity(res); + const auto naf = details::toNaf(scalar); + for (auto i = naf.size; i-- > 0;) { + ec::dbl(res); + if (naf.values[i].pos()) { + ec::add(res, out); + } else if (naf.values[i].neg()) { + ec::subtract(res, out); } } out = res; @@ -267,42 +340,44 @@ void math::CurveScalarMultiply(Point& out, namespace { -Field ComputeOtherCoordinate(const Field& x) { - auto y_sqr = x * x * x + kCurveB; - auto z = math::FFAccess::ComputeSqrt(y_sqr); +Field computeOtherCoordinate(const Field& x) { + static const Field CURVE_B(7); + + auto y_sqr = x * x * x + CURVE_B; + auto z = math::details::sqrt(y_sqr); return z; } -bool IsSmaller(const Field& y, const Field& y_neg) { - return math::FFAccess::IsSmaller(y, y_neg); +bool isSmaller(const Field& y, const Field& y_neg) { + return math::details::isSmaller(y, y_neg); } } // namespace template <> -void math::CurveFromBytes(Point& out, const unsigned char* src) { +void math::ec::fromBytes(Point& out, const unsigned char* src) { const auto flags = *src; if (IS_POINT_AT_INFINITY(flags)) { // we opt to not validate the rest of the buffer here. This technically // allows an implementation to only send a single byte in case it wishes to // send the point-at-infinity. - CurveSetPointAtInfinity(out); + setPointAtInfinity(out); } else { if (IS_FULL_POINT(flags)) { - out[0] = Field::Read(src + 1); - out[1] = Field::Read(src + 1 + Field::ByteSize()); - out[2] = Field::One(); + out[0] = Field::read(src + 1); + out[1] = Field::read(src + 1 + Field::byteSize()); + out[2] = Field::one(); } else { - Field x = Field::Read(src + 1); + Field x = Field::read(src + 1); out[0] = x; - out[2] = Field::One(); + out[2] = Field::one(); - Field y = ComputeOtherCoordinate(x); - Field yn = y.Negated(); + Field y = computeOtherCoordinate(x); + Field yn = y.negated(); - auto smaller = IsSmaller(y, yn); + auto smaller = isSmaller(y, yn); auto select_smaller = SELECT_SMALLER(flags); if (smaller) { out[1] = select_smaller == 0 ? yn : y; @@ -318,9 +393,9 @@ void math::CurveFromBytes(Point& out, const unsigned char* src) { #define MARK_SELECT_SMALLER(buf) (*(buf) |= SELECT_SMALLER_FLAG) template <> -void math::CurveToBytes(unsigned char* dest, - const Point& in, - bool compress) { +void math::ec::toBytes(unsigned char* dest, + const Point& in, + bool compress) { // Make sure flag byte is zeroed. *dest = 0; @@ -329,28 +404,28 @@ void math::CurveToBytes(unsigned char* dest, MARK_FULL_POINT(dest); } - if (CurveIsPointAtInfinity(in)) { + if (isPointAtInfinity(in)) { MARK_POINT_AT_INFINITY(dest); // zero rest of the buffer to ensure we can always safely send the right // amount of bytes. std::memset(dest + 1, 0, compress ? 32 : 64); } else { - const auto ap = CurveToAffine(in); + const auto ap = toAffine(in); // if compression is used, we indicate a bit indicating which of {y, -y} is // the smaller, and the only write the x coordinate. Otherwise we write both // x and y. if (compress) { // include a flag which indicates which of {y, -y} is the smaller. const auto& y = ap[1]; - const auto yn = y.Negated(); + const auto yn = y.negated(); - if (IsSmaller(y, yn)) { + if (isSmaller(y, yn)) { MARK_SELECT_SMALLER(dest); } - ap[0].Write(dest + 1); + ap[0].write(dest + 1); } else { - ap[0].Write(dest + 1); - ap[1].Write(dest + 1 + Field::ByteSize()); + ap[0].write(dest + 1); + ap[1].write(dest + 1 + Field::byteSize()); } } } diff --git a/src/scl/math/ops_gmp_ff.cc b/src/scl/math/fields/ff_ops_gmp.cc similarity index 83% rename from src/scl/math/ops_gmp_ff.cc rename to src/scl/math/fields/ff_ops_gmp.cc index 1529e49..fe057d2 100644 --- a/src/scl/math/ops_gmp_ff.cc +++ b/src/scl/math/fields/ff_ops_gmp.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,9 +15,11 @@ * along with this program. If not, see . */ -#include "scl/math/ops_gmp_ff.h" +#include "scl/math/fields/ff_ops_gmp.h" -std::size_t scl::math::FindFirstNonZero(const std::string& s) { +using namespace scl; + +std::size_t math::ff::findFirstNonZero(const std::string& s) { int n = 0; for (const auto c : s) { if (c != '0') { diff --git a/src/scl/math/mersenne127.cc b/src/scl/math/fields/mersenne127.cc similarity index 59% rename from src/scl/math/mersenne127.cc rename to src/scl/math/fields/mersenne127.cc index dc2e05a..ab40340 100644 --- a/src/scl/math/mersenne127.cc +++ b/src/scl/math/fields/mersenne127.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -20,31 +20,39 @@ #include #include -#include "scl/math/ff_ops.h" -#include "scl/math/ops_small_fp.h" +#include "./small_ff.h" +#include "scl/math/fields/ff_ops.h" #include "scl/util/str.h" +using namespace scl; + using u64 = std::uint64_t; using u128 = __uint128_t; // The prime p = 2^127 - 1. static const u128 p = (((u128)0x7FFFFFFFFFFFFFFF) << 64) | 0xFFFFFFFFFFFFFFFF; -using Mersenne127 = scl::math::Mersenne127; +using Mersenne127 = scl::math::ff::Mersenne127; template <> -void scl::math::FieldConvertIn(u128& out, const int value) { +void math::ff::convertTo(u128& out, const int value) { out = value < 0 ? value + p : value; } template <> -void scl::math::FieldAdd(u128& out, const u128& op) { - ModAdd(out, op, p); +void math::ff::convertTo(u128& out, const std::string& src) { + out = util::fromHexString(src); + out = out % p; +} + +template <> +void math::ff::add(u128& out, const u128& op) { + details::modAdd(out, op, p); } template <> -void scl::math::FieldSubtract(u128& out, const u128& op) { - ModSub(out, op, p); +void math::ff::subtract(u128& out, const u128& op) { + details::modSub(out, op, p); } namespace { @@ -55,7 +63,7 @@ struct u256 { }; // https://cp-algorithms.com/algebra/montgomery_multiplication.html -u256 MultiplyFull(const u128 x, const u128 y) { +u256 multiplyFull(const u128 x, const u128 y) { u64 a = x >> 64; u64 b = x; u64 c = y >> 64; @@ -77,53 +85,44 @@ u256 MultiplyFull(const u128 x, const u128 y) { } // namespace template <> -void scl::math::FieldMultiply(u128& out, const u128& op) { - u256 z = MultiplyFull(out, op); +void math::ff::multiply(u128& out, const u128& op) { + u256 z = multiplyFull(out, op); out = z.high << 1; u128 b = z.low; out |= b >> 127; b &= p; - ModAdd(out, b, p); + details::modAdd(out, b, p); } template <> -void scl::math::FieldNegate(u128& out) { - ModNeg(out, p); +void math::ff::negate(u128& out) { + details::modNeg(out, p); } template <> -void scl::math::FieldInvert(u128& out) { - ModInv(out, out, p); +void math::ff::invert(u128& out) { + details::modInv(out, out, p); } template <> -bool scl::math::FieldEqual(const u128& in1, const u128& in2) { +bool math::ff::equal(const u128& in1, const u128& in2) { return in1 == in2; } template <> -void scl::math::FieldFromBytes(u128& dest, - const unsigned char* src) { +void math::ff::fromBytes(u128& dest, const unsigned char* src) { dest = *(const u128*)src; dest = dest % p; } template <> -void scl::math::FieldToBytes(unsigned char* dest, - const u128& src) { +void math::ff::toBytes(unsigned char* dest, const u128& src) { std::memcpy(dest, &src, sizeof(u128)); } template <> -std::string scl::math::FieldToString(const u128& in) { - return util::ToHexString(in); -} - -template <> -void scl::math::FieldFromString(u128& out, - const std::string& src) { - out = util::FromHexString(src); - out = out % p; +std::string math::ff::toString(const u128& in) { + return util::toHexString(in); } diff --git a/src/scl/math/mersenne61.cc b/src/scl/math/fields/mersenne61.cc similarity index 53% rename from src/scl/math/mersenne61.cc rename to src/scl/math/fields/mersenne61.cc index e0b56e4..9d95c7a 100644 --- a/src/scl/math/mersenne61.cc +++ b/src/scl/math/fields/mersenne61.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -18,38 +18,45 @@ #include "scl/math/fields/mersenne61.h" #include -#include -#include +#include -#include "scl/math/ff_ops.h" -#include "scl/math/ops_small_fp.h" +#include "./small_ff.h" +#include "scl/math/fields/ff_ops.h" #include "scl/util/str.h" +using namespace scl; + using u64 = std::uint64_t; using u128 = __uint128_t; // The prime p = 2^61 - 1 static const u64 p = 0x1FFFFFFFFFFFFFFF; -using Mersenne61 = scl::math::Mersenne61; +using Mersenne61 = scl::math::ff::Mersenne61; template <> -void scl::math::FieldConvertIn(u64& out, const int value) { +void math::ff::convertTo(u64& out, const int value) { out = value < 0 ? value + p : value; } template <> -void scl::math::FieldAdd(u64& out, const u64& op) { - ModAdd(out, op, p); +void math::ff::convertTo(u64& out, const std::string& src) { + out = util::fromHexString(src); + out = out % p; +} + +template <> +void math::ff::add(u64& out, const u64& op) { + details::modAdd(out, op, p); } template <> -void scl::math::FieldSubtract(u64& out, const u64& op) { - ModSub(out, op, p); +void math::ff::subtract(u64& out, const u64& op) { + details::modSub(out, op, p); } template <> -void scl::math::FieldMultiply(u64& out, const u64& op) { +void math::ff::multiply(u64& out, const u64& op) { u128 z = (u128)out * op; u64 a = z >> 61; u64 b = (u64)z; @@ -57,44 +64,37 @@ void scl::math::FieldMultiply(u64& out, const u64& op) { a |= b >> 61; b &= p; - ModAdd(a, b, p); + details::modAdd(a, b, p); out = a; } template <> -void scl::math::FieldNegate(u64& out) { - ModNeg(out, p); +void math::ff::negate(u64& out) { + details::modNeg(out, p); } template <> -void scl::math::FieldInvert(u64& out) { - ModInv(out, out, p); +void math::ff::invert(u64& out) { + details::modInv(out, out, p); } template <> -bool scl::math::FieldEqual(const u64& in1, const u64& in2) { +bool math::ff::equal(const u64& in1, const u64& in2) { return in1 == in2; } template <> -void scl::math::FieldFromBytes(u64& dest, - const unsigned char* src) { +void math::ff::fromBytes(u64& dest, const unsigned char* src) { dest = *(const u64*)src; dest = dest % p; } template <> -void scl::math::FieldToBytes(unsigned char* dest, const u64& src) { +void math::ff::toBytes(unsigned char* dest, const u64& src) { std::memcpy(dest, &src, sizeof(u64)); } template <> -std::string scl::math::FieldToString(const u64& in) { - return util::ToHexString(in); -} - -template <> -void scl::math::FieldFromString(u64& out, const std::string& src) { - out = util::FromHexString(src); - out = out % p; +std::string math::ff::toString(const u64& in) { + return util::toHexString(in); } diff --git a/src/scl/math/fields/naf.h b/src/scl/math/fields/naf.h new file mode 100644 index 0000000..b9cc1fa --- /dev/null +++ b/src/scl/math/fields/naf.h @@ -0,0 +1,107 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_MATH_FIELDS_NAF_H +#define SCL_MATH_FIELDS_NAF_H + +#include +#include + +#include "scl/math/ff.h" + +namespace scl::math::details { + +/** + * @brief Non-adjacent Form encoding of a field element. + * @tparam T a finite field type. + */ +template +struct NafEncoding { + public: + /** + * @brief A type indicating the sign in the encoding. + */ + class Value { + public: + constexpr Value() : Value(0) {} + + /** + * @brief Create a value representing +1. + */ + static constexpr Value createPos() { + return Value{1}; + } + + /** + * @brief Create a value representing -1. + */ + static constexpr Value createNeg() { + return Value{2}; + } + + /** + * @brief Create a value represeting 0. + */ + static constexpr Value createZero() { + return Value{0}; + } + + /** + * @brief Check if this value is +1. + */ + constexpr bool pos() const { + return m_v == 1; + } + + /** + * @brief Check if this value is -1. + */ + constexpr bool neg() const { + return m_v == 2; + } + + /** + * @brief Check if this value is 0. + */ + constexpr bool zero() const { + return m_v == 0; + } + + private: + constexpr Value(unsigned char v) : m_v(v) {} + unsigned char m_v; + }; + + /** + * @brief Maximum size of the encoding. + */ + constexpr static std::size_t MAX_SIZE = T::BIT_SIZE + 1; + + /** + * @brief The trits. + */ + std::array values; + + /** + * @brief The number of meaningful entries in values. + */ + std::size_t size; +}; + +} // namespace scl::math::details + +#endif // SCL_MATH_FIELDS_NAF_H diff --git a/src/scl/math/secp256k1_field.cc b/src/scl/math/fields/secp256k1_field.cc similarity index 53% rename from src/scl/math/secp256k1_field.cc rename to src/scl/math/fields/secp256k1_field.cc index 877bde8..0a6598d 100644 --- a/src/scl/math/secp256k1_field.cc +++ b/src/scl/math/fields/secp256k1_field.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,16 +15,20 @@ * along with this program. If not, see . */ +#include "scl/math/fields/secp256k1_field.h" + #include #include #include "./secp256k1_helpers.h" -#include "scl/math/curves/secp256k1.h" -#include "scl/math/ff_ops.h" +#include "scl/math/ff.h" +#include "scl/math/fields/ff_ops.h" +#include "scl/math/fields/ff_ops_gmp.h" #include "scl/math/number.h" -#include "scl/math/ops_gmp_ff.h" -using Field = scl::math::Secp256k1::Field; +using namespace scl; + +using Field = math::ff::Secp256k1Field; using Elem = Field::ValueType; #define NUM_LIMBS 4 @@ -36,7 +40,7 @@ using Elem = Field::ValueType; } \ } while (0) -static const scl::math::RedParams RD = { +static const math::ff::RedParams RD = { // Prime { 0xFFFFFFFEFFFFFC2F, // @@ -53,8 +57,8 @@ static const scl::math::RedParams RD = { }}; template <> -scl::math::Number scl::math::Order>() { - return Number::FromString( +math::Number math::order>() { + return Number::fromString( "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F"); } @@ -62,37 +66,43 @@ scl::math::Number scl::math::Order>() { #define PTR(X) (X).data() template <> -void scl::math::FieldConvertIn(Elem& out, const int value) { +void math::ff::convertTo(Elem& out, const int value) { + out = {0}; + montyInFromInt(PTR(out), value, RD); +} + +template <> +void math::ff::convertTo(Elem& out, const std::string& src) { out = {0}; - MontyInFromInt(PTR(out), value, RD); + montyFromString(PTR(out), src, RD); } template <> -void scl::math::FieldAdd(Elem& out, const Elem& op) { - MontyModAdd(PTR(out), PTR(op), RD); +void math::ff::add(Elem& out, const Elem& op) { + montyModAdd(PTR(out), PTR(op), RD); } template <> -void scl::math::FieldSubtract(Elem& out, const Elem& op) { - MontyModSub(PTR(out), PTR(op), RD); +void math::ff::subtract(Elem& out, const Elem& op) { + montyModSub(PTR(out), PTR(op), RD); } template <> -void scl::math::FieldNegate(Elem& out) { - MontyModNeg(PTR(out), RD); +void math::ff::negate(Elem& out) { + montyModNeg(PTR(out), RD); } template <> -void scl::math::FieldMultiply(Elem& out, const Elem& op) { - MontyModMul(PTR(out), PTR(op), RD); +void math::ff::multiply(Elem& out, const Elem& op) { + montyModMul(PTR(out), PTR(op), RD); } #define ONE \ { 0x1000003D1, 0, 0, 0 } template <> -void scl::math::FieldInvert(Elem& out) { - static const mp_limb_t kPrimeMinus2[NUM_LIMBS] = { +void math::ff::invert(Elem& out) { + static const mp_limb_t PRIME_MINUS_2[NUM_LIMBS] = { 0xFFFFFFFEFFFFFC2D, // 0xFFFFFFFFFFFFFFFF, // 0xFFFFFFFFFFFFFFFF, // @@ -100,45 +110,36 @@ void scl::math::FieldInvert(Elem& out) { }; Elem res = ONE; - MontyModInv(PTR(res), PTR(out), kPrimeMinus2, RD); + montyModInv(PTR(res), PTR(out), PRIME_MINUS_2, RD); out = res; } template <> -bool scl::math::FieldEqual(const Elem& in1, const Elem& in2) { - return CompareValues(PTR(in1), PTR(in2)) == 0; +bool math::ff::equal(const Elem& in1, const Elem& in2) { + return compareValues(PTR(in1), PTR(in2)) == 0; } template <> -void scl::math::FieldFromBytes(Elem& dest, const unsigned char* src) { - MontyFromBytes(PTR(dest), src, RD); +void math::ff::fromBytes(Elem& dest, const unsigned char* src) { + montyFromBytes(PTR(dest), src, RD); } template <> -void scl::math::FieldToBytes(unsigned char* dest, const Elem& src) { - MontyToBytes(dest, PTR(src), RD); +void math::ff::toBytes(unsigned char* dest, const Elem& src) { + montyToBytes(dest, PTR(src), RD); } template <> -void scl::math::FieldFromString(Elem& out, const std::string& src) { - out = {0}; - MontyFromString(PTR(out), src, RD); +std::string math::ff::toString(const Elem& in) { + return montyToString(PTR(in), RD); } -template <> -std::string scl::math::FieldToString(const Elem& in) { - return MontyToString(PTR(in), RD); -} - -bool scl::math::FFAccess::IsSmaller( - const scl::math::FF& lhs, - const scl::math::FF& rhs) { - auto c = CompareValues(PTR(lhs.m_value), PTR(rhs.m_value)); +bool math::details::isSmaller(const FF& lhs, const FF& rhs) { + auto c = ff::compareValues(PTR(lhs.value()), PTR(rhs.value())); return c <= 0; } -scl::math::FF scl::math::FFAccess::ComputeSqrt( - const scl::math::FF& x) { +math::FF math::details::sqrt(const FF& x) { // (p + 1) / 4. We assume the input is a square mod p, so x^{e} gives a square // root of x. static const mp_limb_t e[NUM_LIMBS] = { @@ -150,9 +151,9 @@ scl::math::FF scl::math::FFAccess::ComputeSqrt( FF out; Elem res = ONE; - MontyModExp(PTR(res), PTR(x.m_value), e, RD); - out.m_value = res; + montyModExp(PTR(res), PTR(x.value()), e, RD); + out.value() = res; return out; -} // LCOV_EXCL_LINE +} #undef ONE diff --git a/src/scl/math/fields/secp256k1_helpers.h b/src/scl/math/fields/secp256k1_helpers.h new file mode 100644 index 0000000..fde8948 --- /dev/null +++ b/src/scl/math/fields/secp256k1_helpers.h @@ -0,0 +1,60 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#ifndef SCL_MATH_FIELDS_SECP256K1_HELPERS_H +#define SCL_MATH_FIELDS_SECP256K1_HELPERS_H + +#include "naf.h" +#include "scl/math/curves/secp256k1.h" +#include "scl/math/ec.h" +#include "scl/math/fields/secp256k1_field.h" +#include "scl/math/fields/secp256k1_scalar.h" + +namespace scl::math::details { + +/** + * @brief Check which of two field elements is smaller. + * + * Used in serialization. + */ +bool isSmaller(const FF& lhs, + const FF& rhs); + +/** + * @brief Compute the square root of an element. + * + * Used in serialization. + */ +FF sqrt(const FF& x); + +/** + * @brief Convert a field element out of montgomery representation. + * + * Used in scalar multiplications. + */ +FF fromMonty(const FF& x); + +/** + * @brief Convert a field element into a NAF encoding. + * + * Used in scalar multiplication. + */ +NafEncoding toNaf(const FF& x); + +} // namespace scl::math::details + +#endif // SCL_MATH_FIELDS_SECP256K1_HELPERS_H diff --git a/src/scl/math/fields/secp256k1_scalar.cc b/src/scl/math/fields/secp256k1_scalar.cc new file mode 100644 index 0000000..6c570a9 --- /dev/null +++ b/src/scl/math/fields/secp256k1_scalar.cc @@ -0,0 +1,205 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#include "scl/math/fields/secp256k1_scalar.h" + +#include +#include + +#include + +#include "./secp256k1_helpers.h" +#include "scl/math/ff.h" +#include "scl/math/fields/ff_ops.h" +#include "scl/math/fields/ff_ops_gmp.h" + +using namespace scl; + +using Field = math::ff::Secp256k1Scalar; +using Elem = Field::ValueType; + +constexpr static std::size_t NUM_LIMBS = std::tuple_size{}; + +#define SCL_COPY(out, in, size) \ + do { \ + for (std::size_t i = 0; i < (size); ++i) { \ + *((out) + i) = *((in) + i); \ + } \ + } while (0) + +template <> +math::Number math::order>() { + return Number::fromString( + "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141"); +} + +static const math::ff::RedParams RD = { + // Prime + { + 0xBFD25E8CD0364141, // + 0xBAAEDCE6AF48A03B, // + 0xFFFFFFFFFFFFFFFE, // + 0xFFFFFFFFFFFFFFFF // + }, + // Montgomery constant + { + 0x4B0DFF665588B13F, // + 0x50A51AC834B9EC24, // + 0x897F30C127CFAB5E, // + 0xD9E8890D6494EF93 // + }}; + +#define PTR(X) (X).data() + +template <> +void math::ff::convertTo(Elem& out, const int value) { + out = {0}; + montyInFromInt(PTR(out), value, RD); +} + +template <> +void math::ff::convertTo(Elem& out, const std::string& src) { + out = {0}; + montyFromString(PTR(out), src, RD); +} + +template <> +void math::ff::add(Elem& out, const Elem& op) { + montyModAdd(PTR(out), PTR(op), RD); +} + +template <> +void math::ff::subtract(Elem& out, const Elem& op) { + montyModSub(PTR(out), PTR(op), RD); +} + +template <> +void math::ff::negate(Elem& out) { + montyModNeg(PTR(out), RD); +} + +template <> +void math::ff::multiply(Elem& out, const Elem& op) { + montyModMul(PTR(out), PTR(op), RD); +} + +#define ONE \ + { 0x402DA1732FC9BEBF, 0x4551231950B75FC4, 0x1, 0 } + +template <> +void math::ff::invert(Elem& out) { + static const mp_limb_t PRIME_MINUS_2[NUM_LIMBS] = { + 0xBFD25E8CD036413F, // + 0xBAAEDCE6AF48A03B, // + 0xFFFFFFFFFFFFFFFE, // + 0xFFFFFFFFFFFFFFFF // + }; + + Elem res = ONE; + montyModInv(PTR(res), PTR(out), PRIME_MINUS_2, RD); + out = res; +} + +template <> +bool math::ff::equal(const Elem& in1, const Elem& in2) { + return compareValues(PTR(in1), PTR(in2)) == 0; +} + +template <> +void math::ff::fromBytes(Elem& dest, const unsigned char* src) { + montyFromBytes(PTR(dest), src, RD); +} + +template <> +void math::ff::toBytes(unsigned char* dest, const Elem& src) { + montyToBytes(dest, PTR(src), RD); +} + +template <> +std::string math::ff::toString(const Elem& in) { + return montyToString(PTR(in), RD); +} + +namespace { + +bool testBit(const mp_limb_t* in, std::size_t pos) { + const auto bits_per_limb = static_cast(mp_bits_per_limb); + const auto limb = pos / bits_per_limb; + const auto limb_pos = pos % bits_per_limb; + return ((in[limb] >> limb_pos) & 1) == 1; +} + +} // namespace + +math::FF math::details::fromMonty(const FF& x) { + mp_limb_t padded[2 * NUM_LIMBS] = {0}; + SCL_COPY(padded, PTR(x.value()), NUM_LIMBS); + montyRedc(padded, RD); + + FF r; + SCL_COPY(PTR(r.value()), padded, NUM_LIMBS); + + return r; +} + +namespace { + +void add1(mp_limb_t* out) { + static const mp_limb_t one[NUM_LIMBS] = {1, 0, 0, 0}; + mpn_add_n(out, out, one, NUM_LIMBS); +} + +void sub1(mp_limb_t* out) { + static const mp_limb_t one[NUM_LIMBS] = {1, 0, 0, 0}; + mpn_sub_n(out, out, one, NUM_LIMBS); +} + +} // namespace + +// Compute a NAF encoding of a field element using the simpel algorithm provided +// here: https://en.wikipedia.org/wiki/Non-adjacent_form#Converting_to_NAF +math::details::NafEncoding math::details::toNaf(const FF& x) { + using NafEnc = math::details::NafEncoding; + + auto val = fromMonty(x).value(); + + std::array z; + std::size_t i = 0; + + while (!mpn_zero_p(PTR(val), NUM_LIMBS)) { + // check if val is odd + if (::testBit(PTR(val), 0)) { + // check if val is 1 or 3 mod 4 + if (::testBit(PTR(val), 1)) { + z[i] = NafEnc::Value::createNeg(); + add1(PTR(val)); + } else { + z[i] = NafEnc::Value::createPos(); + sub1(PTR(val)); + } + } else { + z[i] = NafEnc::Value::createZero(); + } + + i++; + mpn_rshift(PTR(val), PTR(val), NUM_LIMBS, 1); + } + + return {z, i}; +} + +#undef ONE diff --git a/include/scl/math/ops_small_fp.h b/src/scl/math/fields/small_ff.h similarity index 83% rename from include/scl/math/ops_small_fp.h rename to src/scl/math/fields/small_ff.h index 5a669cc..1e711b5 100644 --- a/include/scl/math/ops_small_fp.h +++ b/src/scl/math/fields/small_ff.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,18 +15,18 @@ * along with this program. If not, see . */ -#ifndef SCL_MATH_OPS_SMALL_FP_H -#define SCL_MATH_OPS_SMALL_FP_H +#ifndef SCL_MATH_FIELDS_SMALL_FF_H +#define SCL_MATH_FIELDS_SMALL_FF_H #include -namespace scl::math { +namespace scl::math::details { /** * @brief Compute a modular addition on two simple types. */ template -void ModAdd(T& t, const T& v, const T& m) { +void modAdd(T& t, const T& v, const T& m) { t = t + v; if (t >= m) { t = t - m; @@ -37,7 +37,7 @@ void ModAdd(T& t, const T& v, const T& m) { * @brief Compute a modular subtraction on two simple types. */ template -void ModSub(T& t, const T& v, const T& m) { +void modSub(T& t, const T& v, const T& m) { if (v > t) { t = t + m - v; } else { @@ -49,7 +49,7 @@ void ModSub(T& t, const T& v, const T& m) { * @brief Compute the additive inverse of a simple type. */ template -void ModNeg(T& t, const T& m) { +void modNeg(T& t, const T& m) { if (t) { t = m - t; } @@ -59,7 +59,7 @@ void ModNeg(T& t, const T& m) { * @brief Compute a modular inverse of a simple type. */ template -void ModInv(T& t, const T& v, const T& m) { +void modInv(T& t, const T& v, const T& m) { #define SCL_PARALLEL_ASSIGN(v1, v2, q) \ do { \ const auto __temp = v2; \ @@ -91,6 +91,6 @@ void ModInv(T& t, const T& v, const T& m) { t = static_cast(k); } -} // namespace scl::math +} // namespace scl::math::details -#endif // SCL_MATH_OPS_SMALL_FP_H +#endif // SCL_MATH_FIELDS_SMALL_FF_H diff --git a/src/scl/math/number.cc b/src/scl/math/number.cc index 485f668..dc6d747 100644 --- a/src/scl/math/number.cc +++ b/src/scl/math/number.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -25,31 +25,33 @@ #include -scl::math::Number::Number() { +using namespace scl; + +math::Number::Number() { mpz_init(m_value); } -scl::math::Number::Number(const Number& number) : Number() { +math::Number::Number(const Number& number) : Number() { mpz_set(m_value, number.m_value); } -scl::math::Number::Number(Number&& number) noexcept : Number() { +math::Number::Number(Number&& number) noexcept : Number() { mpz_set(m_value, number.m_value); } -scl::math::Number::~Number() { +math::Number::~Number() { mpz_clear(m_value); } -scl::math::Number scl::math::Number::Random(std::size_t bits, util::PRG& prg) { +math::Number math::Number::random(std::size_t bits, util::PRG& prg) { auto len = (bits - 1) / 8 + 2; auto data = std::make_unique(len); - prg.Next(data.get(), len); + prg.next(data.get(), len); // trim trailing bits to ensure the resulting number is atmost bits large data[1] &= (1 << (bits % 8)) - 1; - scl::math::Number r; + math::Number r; mpz_import(r.m_value, len - 1, 1, 1, 0, 0, data.get() + 1); if ((data[0] & 1) != 0) { mpz_neg(r.m_value, r.m_value); @@ -57,21 +59,20 @@ scl::math::Number scl::math::Number::Random(std::size_t bits, util::PRG& prg) { return r; } -scl::math::Number scl::math::Number::RandomPrime(std::size_t bits, - util::PRG& prg) { - auto r = Random(bits, prg); +math::Number math::Number::randomPrime(std::size_t bits, util::PRG& prg) { + auto r = random(bits, prg); Number prime; mpz_nextprime(prime.m_value, r.m_value); return prime; } -scl::math::Number scl::math::Number::FromString(const std::string& str) { - scl::math::Number num; +math::Number math::Number::fromString(const std::string& str) { + math::Number num; mpz_set_str(num.m_value, str.c_str(), 16); return num; } // LCOV_EXCL_LINE -scl::math::Number scl::math::Number::Read(const unsigned char* buf) { +math::Number math::Number::read(const unsigned char* buf) { std::uint32_t size_and_sign; std::memcpy(&size_and_sign, buf, sizeof(std::uint32_t)); @@ -86,51 +87,51 @@ scl::math::Number scl::math::Number::Read(const unsigned char* buf) { return r; } // LCOV_EXCL_LINE -scl::math::Number::Number(int value) : Number() { +math::Number::Number(int value) : Number() { mpz_set_si(m_value, value); } -scl::math::Number scl::math::Number::operator+(const Number& number) const { - scl::math::Number sum; +math::Number math::Number::operator+(const Number& number) const { + math::Number sum; mpz_add(sum.m_value, m_value, number.m_value); return sum; } // LCOV_EXCL_LINE -scl::math::Number scl::math::Number::operator-(const Number& number) const { - scl::math::Number diff; +math::Number math::Number::operator-(const Number& number) const { + math::Number diff; mpz_sub(diff.m_value, m_value, number.m_value); return diff; } // LCOV_EXCL_LINE -scl::math::Number scl::math::Number::operator-() const { - scl::math::Number neg; +math::Number math::Number::operator-() const { + math::Number neg; mpz_neg(neg.m_value, m_value); return neg; } // LCOV_EXCL_LINE -scl::math::Number scl::math::Number::operator*(const Number& number) const { - scl::math::Number prod; +math::Number math::Number::operator*(const Number& number) const { + math::Number prod; mpz_mul(prod.m_value, m_value, number.m_value); return prod; } // LCOV_EXCL_LINE -scl::math::Number scl::math::Number::operator/(const Number& number) const { +math::Number math::Number::operator/(const Number& number) const { if (mpz_sgn(number.m_value) == 0) { throw std::logic_error("division by 0"); } - scl::math::Number frac; + math::Number frac; mpz_div(frac.m_value, m_value, number.m_value); return frac; } // LCOV_EXCL_LINE -scl::math::Number scl::math::Number::operator%(const Number& mod) const { - scl::math::Number res; +math::Number math::Number::operator%(const Number& mod) const { + math::Number res; mpz_mod(res.m_value, m_value, mod.m_value); return res; } // LCOV_EXCL_LINE -scl::math::Number scl::math::Number::operator<<(int shift) const { - scl::math::Number shifted; +math::Number math::Number::operator<<(int shift) const { + math::Number shifted; if (shift < 0) { shifted = operator>>(-shift); } else { @@ -139,8 +140,8 @@ scl::math::Number scl::math::Number::operator<<(int shift) const { return shifted; } // LCOV_EXCL_LINE -scl::math::Number scl::math::Number::operator>>(int shift) const { - scl::math::Number shifted; +math::Number math::Number::operator>>(int shift) const { + math::Number shifted; if (shift < 0) { shifted = operator<<(-shift); } else { @@ -149,47 +150,47 @@ scl::math::Number scl::math::Number::operator>>(int shift) const { return shifted; } // LCOV_EXCL_LINE -scl::math::Number scl::math::Number::operator^(const Number& number) const { - scl::math::Number xord; +math::Number math::Number::operator^(const Number& number) const { + math::Number xord; mpz_xor(xord.m_value, m_value, number.m_value); return xord; } // LCOV_EXCL_LINE -scl::math::Number scl::math::Number::operator|(const Number& number) const { - scl::math::Number ord; +math::Number math::Number::operator|(const Number& number) const { + math::Number ord; mpz_ior(ord.m_value, m_value, number.m_value); return ord; } // LCOV_EXCL_LINE -scl::math::Number scl::math::Number::operator&(const Number& number) const { - scl::math::Number andd; +math::Number math::Number::operator&(const Number& number) const { + math::Number andd; mpz_and(andd.m_value, m_value, number.m_value); return andd; } // LCOV_EXCL_LINE -scl::math::Number scl::math::Number::operator~() const { - scl::math::Number com; +math::Number math::Number::operator~() const { + math::Number com; mpz_com(com.m_value, m_value); return com; } // LCOV_EXCL_LINE -int scl::math::Number::Compare(const Number& number) const { +int math::Number::compare(const Number& number) const { return mpz_cmp(m_value, number.m_value); } -std::size_t scl::math::Number::ByteSize() const { - return (BitSize() - 1) / 8 + 1; +std::size_t math::Number::byteSize() const { + return (bitSize() - 1) / 8 + 1; } -std::size_t scl::math::Number::BitSize() const { +std::size_t math::Number::bitSize() const { return mpz_sizeinbase(m_value, 2); } -bool scl::math::Number::TestBit(std::size_t index) const { +bool math::Number::testBit(std::size_t index) const { return mpz_tstbit(m_value, index); } -std::string scl::math::Number::ToString() const { +std::string math::Number::toString() const { char* cstr; cstr = mpz_get_str(nullptr, 16, m_value); std::stringstream ss; @@ -198,8 +199,8 @@ std::string scl::math::Number::ToString() const { return ss.str(); } -void scl::math::Number::Write(unsigned char* buf) const { - std::uint32_t size_and_sign = ByteSize(); +void math::Number::write(unsigned char* buf) const { + std::uint32_t size_and_sign = byteSize(); if (mpz_sgn(m_value) < 0) { size_and_sign |= (1 << 31); @@ -209,19 +210,19 @@ void scl::math::Number::Write(unsigned char* buf) const { mpz_export(buf + sizeof(std::uint32_t), NULL, 1, 1, 0, 0, m_value); } -scl::math::Number scl::math::LCM(const Number& a, const Number& b) { +math::Number math::lcm(const Number& a, const Number& b) { Number lcm; mpz_lcm(lcm.m_value, a.m_value, b.m_value); return lcm; } // LCOV_EXCL_LINE -scl::math::Number scl::math::GCD(const Number& a, const Number& b) { +math::Number math::gcd(const Number& a, const Number& b) { Number gcd; mpz_gcd(gcd.m_value, a.m_value, b.m_value); return gcd; } // LCOV_EXCL_LINE -scl::math::Number scl::math::ModInverse(const Number& val, const Number& mod) { +math::Number math::modInverse(const Number& val, const Number& mod) { if (mpz_sgn(mod.m_value) == 0) { throw std::invalid_argument("modulus cannot be 0"); } @@ -235,9 +236,9 @@ scl::math::Number scl::math::ModInverse(const Number& val, const Number& mod) { return inv; } // LCOV_EXCL_LINE -scl::math::Number scl::math::ModExp(const Number& base, - const Number& exp, - const Number& mod) { +math::Number math::modExp(const Number& base, + const Number& exp, + const Number& mod) { Number r; mpz_powm(r.m_value, base.m_value, exp.m_value, mod.m_value); return r; diff --git a/src/scl/math/secp256k1_helpers.h b/src/scl/math/secp256k1_helpers.h deleted file mode 100644 index 3053e04..0000000 --- a/src/scl/math/secp256k1_helpers.h +++ /dev/null @@ -1,68 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#ifndef SCL_MATH_SECP256K1_HELPERS_H -#define SCL_MATH_SECP256K1_HELPERS_H - -#include "scl/math/curves/secp256k1.h" -#include "scl/math/ec.h" - -namespace scl::math { - -/** - * @brief Helper class for Secp256k1::Field. - */ -template <> -struct FFAccess { - /** - * @brief Compare two field elements lexicographical. - */ - static bool IsSmaller(const FF& lhs, - const FF& rhs); - - /** - * @brief Compute the square root of an element - */ - static FF ComputeSqrt(const FF& x); -}; - -/** - * @brief Helper class for Secp256k1::Order. - */ -template <> -struct FFAccess { - /** - * @brief Convert a field element out of montgomery representation. - */ - static FF FromMonty(const FF& element); - - /** - * @brief Find the position of the highest set bit. - */ - static std::size_t HigestSetBit(const FF& element); - - /** - * @brief Check if a particular bit is set. - * - * \p pos is assumed to be at or below HighestSetBit(\p element). - */ - static bool TestBit(const FF& element, std::size_t pos); -}; - -} // namespace scl::math - -#endif // SCL_MATH_SECP256K1_HELPERS_H diff --git a/src/scl/math/secp256k1_scalar.cc b/src/scl/math/secp256k1_scalar.cc deleted file mode 100644 index 91725fd..0000000 --- a/src/scl/math/secp256k1_scalar.cc +++ /dev/null @@ -1,160 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#include -#include - -#include - -#include "./secp256k1_helpers.h" -#include "scl/math/curves/secp256k1.h" -#include "scl/math/ec.h" -#include "scl/math/ff.h" -#include "scl/math/ff_ops.h" -#include "scl/math/ops_gmp_ff.h" - -using Field = scl::math::Secp256k1::Scalar; -using Elem = Field::ValueType; - -#define NUM_LIMBS 4 - -#define SCL_COPY(out, in, size) \ - do { \ - for (std::size_t i = 0; i < (size); ++i) { \ - *((out) + i) = *((in) + i); \ - } \ - } while (0) - -template <> -scl::math::Number scl::math::Order>() { - return Number::FromString( - "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141"); -} - -static const scl::math::RedParams RD = { - // Prime - { - 0xBFD25E8CD0364141, // - 0xBAAEDCE6AF48A03B, // - 0xFFFFFFFFFFFFFFFE, // - 0xFFFFFFFFFFFFFFFF // - }, - // Montgomery constant - { - 0x4B0DFF665588B13F, // - 0x50A51AC834B9EC24, // - 0x897F30C127CFAB5E, // - 0xD9E8890D6494EF93 // - }}; - -#define PTR(X) (X).data() - -template <> -void scl::math::FieldConvertIn(Elem& out, const int value) { - out = {0}; - MontyInFromInt(PTR(out), value, RD); -} - -template <> -void scl::math::FieldAdd(Elem& out, const Elem& op) { - MontyModAdd(PTR(out), PTR(op), RD); -} - -template <> -void scl::math::FieldSubtract(Elem& out, const Elem& op) { - MontyModSub(PTR(out), PTR(op), RD); -} - -template <> -void scl::math::FieldNegate(Elem& out) { - MontyModNeg(PTR(out), RD); -} - -template <> -void scl::math::FieldMultiply(Elem& out, const Elem& op) { - MontyModMul(PTR(out), PTR(op), RD); -} - -#define ONE \ - { 0x402DA1732FC9BEBF, 0x4551231950B75FC4, 0x1, 0 } - -template <> -void scl::math::FieldInvert(Elem& out) { - static const mp_limb_t kPrimeMinus2[NUM_LIMBS] = { - 0xBFD25E8CD036413F, // - 0xBAAEDCE6AF48A03B, // - 0xFFFFFFFFFFFFFFFE, // - 0xFFFFFFFFFFFFFFFF // - }; - - Elem res = ONE; - MontyModInv(PTR(res), PTR(out), kPrimeMinus2, RD); - out = res; -} - -template <> -bool scl::math::FieldEqual(const Elem& in1, const Elem& in2) { - return CompareValues(PTR(in1), PTR(in2)) == 0; -} - -template <> -void scl::math::FieldFromBytes(Elem& dest, const unsigned char* src) { - MontyFromBytes(PTR(dest), src, RD); -} - -template <> -void scl::math::FieldToBytes(unsigned char* dest, const Elem& src) { - MontyToBytes(dest, PTR(src), RD); -} - -template <> -std::string scl::math::FieldToString(const Elem& in) { - return MontyToString(PTR(in), RD); -} - -template <> -void scl::math::FieldFromString(Elem& out, const std::string& src) { - out = {0}; - MontyFromString(PTR(out), src, RD); -} - -std::size_t scl::math::FFAccess::HigestSetBit( - const scl::math::FF& element) { - return mpn_sizeinbase(PTR(element.m_value), NUM_LIMBS, 2); -} - -bool scl::math::FFAccess::TestBit(const scl::math::FF& element, - std::size_t pos) { - const auto bits_per_limb = static_cast(mp_bits_per_limb); - const auto limb = pos / bits_per_limb; - const auto limb_pos = pos % bits_per_limb; - return ((element.m_value[limb] >> limb_pos) & 1) == 1; -} - -scl::math::FF scl::math::FFAccess::FromMonty( - const scl::math::FF& element) { - mp_limb_t padded[2 * NUM_LIMBS] = {0}; - SCL_COPY(padded, PTR(element.m_value), NUM_LIMBS); - MontyRedc(padded, RD); - - FF r; - SCL_COPY(PTR(r.m_value), padded, NUM_LIMBS); - - return r; -} - -#undef ONE diff --git a/src/scl/net/channel.cc b/src/scl/net/channel.cc deleted file mode 100644 index 1824eeb..0000000 --- a/src/scl/net/channel.cc +++ /dev/null @@ -1,50 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#include "scl/net/channel.h" - -void scl::net::Channel::Send(const scl::net::Packet& packet) { - const auto size_size = sizeof(net::Packet::SizeType); - unsigned char size_buf[size_size]; - net::Packet::SizeType packet_size = packet.Size(); - std::memcpy(size_buf, &packet_size, size_size); - Send(size_buf, size_size); - Send(packet.Get(), packet.Size()); -} - -std::optional scl::net::Channel::Recv(bool block) { - const auto size_size = sizeof(net::Packet::SizeType); - unsigned char size_buf[size_size]; - net::Packet::SizeType packet_size; - - if (block) { - Recv(size_buf, size_size); - } else { - if (HasData()) { - Recv(size_buf, size_size); - } else { - return {}; - } - } - - std::memcpy(&packet_size, size_buf, size_size); - - net::Packet p(packet_size); - Recv(p.Get(), packet_size); - p.SetWritePtr(packet_size); - return p; -} diff --git a/src/scl/net/config.cc b/src/scl/net/config.cc index cd051a3..df3a74a 100644 --- a/src/scl/net/config.cc +++ b/src/scl/net/config.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -23,9 +23,11 @@ #include #include +using namespace scl; + namespace { -void ValidateIdAndSize(std::size_t id, std::size_t n) { +void validateIdAndSize(std::size_t id, std::size_t n) { if (n == 0) { throw std::invalid_argument("n cannot be zero"); } @@ -37,9 +39,8 @@ void ValidateIdAndSize(std::size_t id, std::size_t n) { } // namespace -scl::net::NetworkConfig scl::net::NetworkConfig::Load( - std::size_t id, - const std::string& filename) { +net::NetworkConfig net::NetworkConfig::load(std::size_t id, + const std::string& filename) { std::ifstream file(filename); if (!file.is_open()) { @@ -66,16 +67,15 @@ scl::net::NetworkConfig scl::net::NetworkConfig::Load( info.emplace_back(Party{id, hostname, port}); } - ValidateIdAndSize(id, info.size()); + validateIdAndSize(id, info.size()); return NetworkConfig(id, info); } -scl::net::NetworkConfig scl::net::NetworkConfig::Localhost( - std::size_t id, - std::size_t size, - std::size_t port_base) { - ValidateIdAndSize(id, size); +net::NetworkConfig net::NetworkConfig::localhost(std::size_t id, + std::size_t size, + std::size_t port_base) { + validateIdAndSize(id, size); std::vector info; for (std::size_t i = 0; i < size; ++i) { @@ -86,24 +86,10 @@ scl::net::NetworkConfig scl::net::NetworkConfig::Localhost( return NetworkConfig(id, info); } -std::string scl::net::NetworkConfig::ToString() const { - std::stringstream ss; - ss << "[id=" << m_id << ", "; - std::size_t i = 0; - for (; i < m_parties.size() - 1; i++) { - const auto party = m_parties[i]; - ss << "{" << party.id << ", " << party.hostname << ", " << party.port - << "}, "; - } - const auto last = m_parties[i]; - ss << "{" << last.id << ", " << last.hostname << ", " << last.port << "}]"; - return ss.str(); -} - -void scl::net::NetworkConfig::Validate() { - auto n = NetworkSize(); +void net::NetworkConfig::validate() { + auto n = networkSize(); - if (static_cast(Id()) >= n) { + if (static_cast(id()) >= n) { throw std::invalid_argument("my ID is invalid in config"); } diff --git a/src/scl/net/mem_channel.cc b/src/scl/net/mem_channel.cc deleted file mode 100644 index 5dacae7..0000000 --- a/src/scl/net/mem_channel.cc +++ /dev/null @@ -1,66 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#include "scl/net/mem_channel.h" - -#include - -// used to silence narrowing conversion errors. x will have type std::size_t -#define DIFF_T(x) static_cast::difference_type>((x)) - -void scl::net::MemoryBackedChannel::Send(const unsigned char* src, - std::size_t n) { - m_out->PushBack(std::vector(src, src + n)); -} - -std::size_t scl::net::MemoryBackedChannel::Recv(unsigned char* dst, - std::size_t n) { - std::size_t rem = n; - - // if there's any leftovers from previous calls to recv, then we retrieve - // those first. - const auto leftovers = m_overflow.size(); - if (leftovers > 0) { - const auto to_copy = leftovers > rem ? rem : leftovers; - auto* data = m_overflow.data(); - std::memcpy(dst, data, to_copy); - rem -= to_copy; - m_overflow = - std::vector(m_overflow.begin() + DIFF_T(to_copy), - m_overflow.end()); - } - - while (rem > 0) { - auto data = m_in->Pop(); - const auto to_copy = data.size() > rem ? rem : data.size(); - std::memcpy(dst + (n - rem), data.data(), to_copy); - rem -= to_copy; - - // if we didn't copy all of data, then rem == 0 and we need to save what - // remains to overflow - if (to_copy < data.size()) { - const auto leftovers = data.size() - to_copy; - const auto old_size = m_overflow.size(); - m_overflow.reserve(old_size + leftovers); - m_overflow.insert(m_overflow.begin() + DIFF_T(old_size), - data.begin() + DIFF_T(to_copy), - data.end()); - } - } - - return n; -} diff --git a/src/scl/net/network.cc b/src/scl/net/network.cc index b08e0b0..8658c9e 100644 --- a/src/scl/net/network.cc +++ b/src/scl/net/network.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -17,55 +17,129 @@ #include "scl/net/network.h" +#include +#include +#include +#include +#include #include +#include "scl/coro/coroutine.h" +#include "scl/coro/future.h" #include "scl/net/channel.h" -#include "scl/net/mem_channel.h" +#include "scl/net/config.h" +#include "scl/net/loopback.h" #include "scl/net/sys_iface.h" +#include "scl/net/tcp_channel.h" #include "scl/net/tcp_utils.h" -#include "scl/net/threaded_sender.h" -scl::net::FakeNetwork scl::net::FakeNetwork::Create(unsigned id, - std::size_t n) { - std::vector> channels; - std::vector> remotes; - channels.reserve(n); - remotes.reserve(n); +using namespace scl; +using namespace std::chrono_literals; - for (std::size_t i = 0; i < n; ++i) { - if (i == id) { - channels.emplace_back(MemoryBackedChannel::CreateLoopback()); - remotes.emplace_back(nullptr); +namespace { + +template +coro::Task writePartyId(net::SocketType socket, std::uint32_t party_id) { + SYS::write(socket, &party_id, sizeof(std::uint32_t)); + co_return; +} + +template +coro::Task readPartyId(net::SocketType socket) { + std::uint32_t party_id; + while (true) { + auto read = SYS::read(socket, &party_id, sizeof(std::uint32_t)); + if (read < 0) { + const auto err = SYS::getError(); + if (err == EAGAIN || err == EWOULDBLOCK) { + co_await [sock = socket]() { + return net::details::pollSocket(sock, POLLIN); + }; + } } else { - auto chls = MemoryBackedChannel::CreatePaired(); - channels.emplace_back(chls[0]); - remotes.emplace_back(chls[1]); + break; } } + co_return party_id; +} + +struct SocketAndId { + net::SocketType socket; + std::size_t id; +}; + +template +coro::Task acceptConnection(net::SocketType server_socket) { + while (true) { + if (net::details::pollSocket(server_socket, POLLIN)) { + auto conn = net::details::acceptConnection(server_socket); + net::details::markSocketNonBlocking(conn.socket); - return FakeNetwork{id, Network{channels, id}, remotes}; + auto id = co_await readPartyId(conn.socket); + + co_return {conn.socket, id}; + } else { + co_await 100ms; + } + } } -std::vector scl::net::CreateMemoryBackedNetwork( - std::size_t n) { - std::vector>> channels(n); +template +coro::Task establishConnection(net::Party party, + std::size_t my_id) { + std::size_t attempts = 100; // max attempts. - for (std::size_t i = 0; i < n; ++i) { - channels[i] = std::vector>(n); + while (attempts > 0) { + net::SocketType socket = -1; + + socket = net::details::connectAsClient(party.hostname, (int)party.port); + // TODO: What errors to retry on? + + attempts--; + + if (socket == -1) { + co_await 100ms; + } else { + net::details::markSocketNonBlocking(socket); + co_await writePartyId(socket, my_id); + co_return {socket, party.id}; + } } - std::vector networks; - networks.reserve(n); + throw std::runtime_error("could not establish connection to party"); +} + +} // namespace + +coro::Task net::Network::create(const NetworkConfig& config) { + std::vector> channels(config.networkSize()); + + const std::size_t id = config.id(); + const std::size_t n = config.networkSize(); + channels[id] = LoopbackChannel::create(); + + std::vector> tasks; + + const auto me = config.party(id); + auto server_socket = details::createServerSocket((int)me.port, 128); + details::markSocketNonBlocking(server_socket); for (std::size_t i = 0; i < n; ++i) { - channels[i][i] = MemoryBackedChannel::CreateLoopback(); - for (std::size_t j = i + 1; j < n; ++j) { - auto chls = MemoryBackedChannel::CreatePaired(); - channels[i][j] = chls[0]; - channels[j][i] = chls[1]; + if (i < id) { + tasks.emplace_back( + establishConnection(config.party(i), me.id)); + } else if (i > id) { + tasks.emplace_back(acceptConnection(server_socket)); } - networks.emplace_back(Network{channels[i], i}); } - return networks; + std::vector sais = co_await coro::batch(std::move(tasks)); + + details::SysIFace::close(server_socket); + + for (const SocketAndId& sai : sais) { + channels[sai.id] = std::make_shared>(sai.socket); + } + + co_return Network{channels, config.id()}; } diff --git a/src/scl/simulation/channel.cc b/src/scl/simulation/channel.cc index 118eea1..50c05ca 100644 --- a/src/scl/simulation/channel.cc +++ b/src/scl/simulation/channel.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -17,255 +17,107 @@ #include "scl/simulation/channel.h" -#include +#include -#include "scl/simulation/channel_id.h" -#include "scl/simulation/context.h" +#include "scl/coro/runtime.h" #include "scl/simulation/event.h" -#include "scl/simulation/simulator.h" #include "scl/util/time.h" -using EventPtr = std::shared_ptr; - -EventPtr scl::sim::SimulateClose(std::shared_ptr ctx, ChannelId id) { - const auto lid = id.local; - const auto trt = ctx->Checkpoint(lid); - return std::make_shared(Event::Type::CLOSE, trt, id); -} - -#define SCL_LOCAL_COMP_BEGIN const auto scl__lcb = scl::util::Time::Now() -#define SCL_LOCAL_COMP_END scl::util::Time::Now() - scl__lcb - -EventPtr scl::sim::SimulateSend(std::shared_ptr ctx, - ChannelId id, - const unsigned char* src, - std::size_t n) { - SCL_LOCAL_COMP_BEGIN; - - ctx->Buffer(id)->Write(src, n); - - const auto local_comp_time = SCL_LOCAL_COMP_END; - const auto exec_time = ctx->Checkpoint(id.local) - local_comp_time; - - auto event = - std::make_shared(Event::Type::SEND, exec_time, id, n); - ctx->AddCandidateToRun(id.remote); - ctx->AddWrite(id, n, exec_time); - return event; -} +using namespace scl; namespace { -scl::util::Time::Duration AdjustRecvTime(std::shared_ptr ctx, - scl::sim::ChannelId id, - scl::util::Time::Duration t, - std::size_t n) { - auto rem = n; - - while (rem > 0 && ctx->HasWrite(id)) { - auto& w = ctx->NextWrite(id); - scl::util::Time::Duration recv_time; - if (w.amount > rem) { - const auto delay = - scl::sim::ComputeRecvTime(ctx->ChannelConfiguration(id), rem); - recv_time = w.time + delay; - w.amount -= rem; - rem = 0; - } else { - const auto delay = - scl::sim::ComputeRecvTime(ctx->ChannelConfiguration(id), w.amount); - recv_time = w.time + delay; - rem -= w.amount; - ctx->DeleteWrite(id); - } - - t = std::max(t, recv_time); - } - - return t; +std::size_t totalPacketSize(const net::Packet& packet) { + return packet.size() + sizeof(net::Packet::SizeType); } } // namespace -EventPtr scl::sim::SimulateRecv(std::shared_ptr ctx, - ChannelId id, - unsigned char* dst, - std::size_t n) { - SCL_LOCAL_COMP_BEGIN; - - if (ctx->Buffer(id)->Size() < n) { - ctx->AddCandidateToRun(id.remote); - throw SimulationFailure(); - } - - ctx->Buffer(id)->Read(dst, n); - - const auto local_comp_time = SCL_LOCAL_COMP_END; - const auto exec_time = ctx->Checkpoint(id.local) - local_comp_time; - const auto adjusted_time = AdjustRecvTime(ctx, id.Flip(), exec_time, n); - - return std::make_shared(Event::Type::RECV, - exec_time, - adjusted_time - exec_time, - id, - n); +void sim::details::SimulatedChannel::close() { + util::Time::Duration elapsed = m_context.elapsedTime(); + m_context.recordEvent(Event::closeChannel(elapsed, m_cid)); + m_context.startClock(); } -std::pair scl::sim::SimulateHasData( - std::shared_ptr ctx, - ChannelId id) { - // The other party hasn't had a chance to run yet, so it's not possible to - // determine if there's data available for us. - if (ctx->Trace(id.remote).empty()) { - ctx->AddCandidateToRun(id.remote); - throw SimulationFailure("other party hasnt started yet"); - } - - // We determine if there is data available by inspecting the list of WriteOps - // created by the remote party. Since each WriteOp has a timestamp, we can use - // that to determine if the data would have arrived at us yet. - // - // The rules for what to return, and when to fail the simulation goes as - // follows: - // - // - WriteOp op exists such that op.amount > 0. This op corresponds to the - // data that we would receive the next time we call Recv on this channel. - // - // If it is the case that - // - // op.time + time_to_send_1_byte <= our_current_time, - // - // then we can return has_data == true. Otherwise, we can return false. - // Note that, even if the remote party is behind is in time, we know that - // it is not possible for it to send data that we would receive earlier - // than the data connected to op. - // - // - No WriteOp exists. In this case, we either return has_data == false, or - // we fail the simulation. We can return has_data == false if - // - // remote_current_time - time_to_send_1_byte >= our_current_time - // - // as we know that no Send that the remote party makes, would have arrived - // to us before now. On the other hand, if the above does not hold, then we - // cannot say for sure that the remote party might not send data that we - // would be able to receive now, and so we have to fail the simulation. - - // Time it takes for 1 byte to go from the remote party to us. - const auto offset = ComputeRecvTime(ctx->ChannelConfiguration(id.Flip()), 1); +coro::Task sim::details::SimulatedChannel::send(net::Packet&& packet) { + util::Time::Duration elapsed = m_context.elapsedTime(); + const std::size_t nbytes = totalPacketSize(packet); + m_context.send(m_cid.remote, elapsed); - // Go through each write op of the other party, and find the earliest one. - const auto me_latest = ctx->Checkpoint(id.local); - bool has_data = false; - bool has_result = false; - if (ctx->HasWrite(id.Flip())) { - if (ctx->NextWrite(id.Flip()).time + offset <= me_latest) { - has_data = true; - } else { - has_data = false; - has_result = true; - } - } - - // Handle the case where no WriteOp existed at all. Here we will fail the - // simulation if the remote party is too far behind us in time. - if (!has_data && !has_result) { - const auto other_latest = ctx->LatestTimestamp(id.remote) - offset; - if (!ctx->HasTerminated(id.remote) && other_latest <= me_latest) { - ctx->AddCandidateToRun(id.remote); - throw SimulationFailure("no data, and we're ahead"); - } - } + m_transport->send(m_cid, std::move(packet)); - const auto event = std::make_shared(me_latest, id, has_data); - return {has_data, event}; + m_context.recordEvent(Event::sendData(elapsed, m_cid, nbytes)); + m_context.startClock(); + co_return; } -void scl::sim::Channel::Send(const scl::net::Packet& packet) { - const auto packet_size = packet.Size(); - const auto size_size = sizeof(net::Packet::SizeType); +coro::Task sim::details::SimulatedChannel::send( + const net::Packet& packet) { + util::Time::Duration elapsed = m_context.elapsedTime(); + const std::size_t nbytes = totalPacketSize(packet); + m_context.send(m_cid.remote, elapsed); - // A packet is a size + content, which are sent separately. - scl::net::Channel::Send(packet); + m_transport->send(m_cid, packet); - // Sending the size and conte each generate a "SEND" event. These are removed - // here, and replaced by a single "PACKET_SEND" event that is set to have - // happened at the same time as the first event, and with an amount equal to - // the sum of the two events. - const auto data_event = m_ctx->PopLastEvent(m_id.local); - const auto size_event = m_ctx->PopLastEvent(m_id.local); - const auto event = - std::make_shared(Event::Type::PACKET_SEND, - size_event->Timestamp(), - m_id, - size_size + packet_size); - m_ctx->AddEvent(m_id.local, event); + m_context.recordEvent(Event::sendData(elapsed, m_cid, nbytes)); + m_context.startClock(); + co_return; } -namespace { +coro::Task sim::details::SimulatedChannel::recv() { + util::Time::Duration elapsed = m_context.elapsedTime(); -std::size_t GetDataAmount(scl::sim::Event* event) { - return reinterpret_cast(event)->DataAmount(); -} + m_context.recvStart(m_cid.remote); -} // namespace + // block until there is data available on the transport. + co_await [tp = m_transport, cid = m_cid]() { return tp->hasData(cid); }; -std::optional scl::sim::Channel::Recv(bool block) { - // A packet is received a little differently, depending on whether it blocks - // or not. If the recv is blocking, then we receive a size + content. If the - // receive is non-blocking, then we first check if there's data before - // receiving the size + content. + auto packet = m_transport->recv(m_cid); - auto p = net::Channel::Recv(block); + m_context.recvDone(m_cid.remote); - if (block) { - // Receive was blocking, so we need to remove the two last events, - // corresponding to the receiving the size of the packet, and the packet's - // content. The information in these two events is then turned into a - // PACKET_RECV event. - const auto data_event = m_ctx->PopLastEvent(m_id.local); - const auto size_event = m_ctx->PopLastEvent(m_id.local); + elapsed = m_context.recv(m_cid.remote, totalPacketSize(packet), elapsed); - m_ctx->AddEvent( - m_id.local, - std::make_shared( - size_event->Timestamp() - size_event->Offset(), - size_event->Offset() + data_event->Offset(), - m_id, - GetDataAmount(data_event.get()) + GetDataAmount(size_event.get()), - true)); - } else { - // If the receive was non-blocking, then we either have one event (in case - // there was no data to receive), or three (in case there was data to - // receive). - // - // The extra event here, compared to the blocking case, is an event arising - // from a call to HasData. - if (p.has_value()) { - const auto data_event = m_ctx->PopLastEvent(m_id.local); - const auto size_event = m_ctx->PopLastEvent(m_id.local); - const auto hd_event = m_ctx->PopLastEvent(m_id.local); + const std::size_t nbytes = totalPacketSize(packet); + m_context.recordEvent(Event::recvData(elapsed, m_cid, nbytes)); + m_context.startClock(); + co_return packet; +} - const auto event = std::make_shared( - hd_event->Timestamp(), - size_event->Offset() + data_event->Offset(), - m_id, - GetDataAmount(data_event.get()) + GetDataAmount(size_event.get()), - false); +coro::Task sim::details::SimulatedChannel::hasData() { + util::Time::Duration now = m_context.elapsedTime(); + m_context.recordEvent(Event::hasData(now, m_cid)); - m_ctx->AddEvent(m_id.local, event); + auto has_data = m_transport->hasData(m_cid); - } else { - const auto hd_event = m_ctx->PopLastEvent(m_id.local); + if (!has_data) { + const auto other = m_cid.remote; - m_ctx->AddEvent( - m_id.local, - std::make_shared(hd_event->Timestamp(), - util::Time::Duration::zero(), - m_id, - 0, - false)); - } + // have to consider three cases here: + // + // 1) If the remote party is ahead of us, then any data it sends will + // first + // arrive at some point in the future. + // + // 2) If the remote party is dead, then _no_ data will arrive to us. + // + // 3) If the remote party is trying to receive data from us, then it will + // not have data for us until we send something, which cannot be + // earlier than "now". In particular, we won't receive data from remote + // until at some point after whatever "now" is. + co_await [now, ctx = m_context, other]() { + const auto remote_ahead = now < ctx.currentTimeOf(other); + const auto remote_dead = ctx.dead(other); + const auto remote_waiting_for_us = ctx.receiving(other); + + return remote_ahead || remote_dead || remote_waiting_for_us; + }; + + m_context.startClock(); + // query the transport again. + co_return m_transport->hasData(m_cid); } - return p; -} // LCOV_EXCL_LINE + m_context.startClock(); + co_return true; +} diff --git a/src/scl/simulation/config.cc b/src/scl/simulation/config.cc index a9e36b4..3ab9977 100644 --- a/src/scl/simulation/config.cc +++ b/src/scl/simulation/config.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -22,7 +22,7 @@ using namespace scl; -void sim::ChannelConfig::Builder::Validate() const { +void sim::ChannelConfig::Builder::validate() const { if (m_bandwidth.has_value()) { if (m_bandwidth.value() == 0) { throw std::invalid_argument("bandwidth cannot be 0"); @@ -52,14 +52,14 @@ void sim::ChannelConfig::Builder::Validate() const { } std::ostream& sim::operator<<(std::ostream& os, const ChannelConfig& config) { - if (config.Type() == sim::ChannelConfig::NetworkType::TCP) { + if (config.type() == sim::ChannelConfig::NetworkType::TCP) { os << "SimulationConfig{"; os << "Type: TCP, "; - os << "Bandwidth: " << config.Bandwidth() << " bits/s, "; + os << "Bandwidth: " << config.bandwidth() << " bits/s, "; os << "RTT: " << config.RTT() << " ms, "; os << "MSS: " << config.MSS() << " bytes, "; - os << "PackageLoss: " << 100 * config.PackageLoss() << "%, "; - os << "WindowSize: " << config.WindowSize() << " bytes}"; + os << "PackageLoss: " << 100 * config.packetLoss() << "%, "; + os << "WindowSize: " << config.windowSize() << " bytes}"; } else { os << "SimulationConfig{INSTANT}"; } @@ -67,10 +67,10 @@ std::ostream& sim::operator<<(std::ostream& os, const ChannelConfig& config) { return os; } -sim::ChannelConfig sim::ChannelConfig::Default() { - return ChannelConfig::Builder{}.Build(); +sim::ChannelConfig sim::ChannelConfig::defaultConfig() { + return ChannelConfig::Builder{}.build(); } -sim::ChannelConfig sim::ChannelConfig::Loopback() { - return ChannelConfig::Builder{}.Type(NetworkType::INSTANT).Build(); +sim::ChannelConfig sim::ChannelConfig::loopback() { + return ChannelConfig::Builder{}.type(NetworkType::INSTANT).build(); } diff --git a/src/scl/simulation/context.cc b/src/scl/simulation/context.cc index 832f7a0..8b92759 100644 --- a/src/scl/simulation/context.cc +++ b/src/scl/simulation/context.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -17,141 +17,215 @@ #include "scl/simulation/context.h" +#include +#include +#include + #include "scl/simulation/config.h" #include "scl/simulation/event.h" -#include "scl/simulation/mem_channel_buffer.h" -#include "scl/simulation/simulator.h" +#include "scl/simulation/hook.h" +#include "scl/util/bitmap.h" using namespace scl; +// makes things a bit nicer to read. +using GlobalCtx = sim::details::GlobalContext; -template <> -std::shared_ptr -sim::Context::Create( - std::size_t number_of_parties, - std::shared_ptr config) { - auto ctx = std::make_shared(config); - - ctx->m_nparties = number_of_parties; - ctx->m_traces.resize(number_of_parties); - - for (std::size_t i = 0; i < number_of_parties; ++i) { - ctx->m_buffers[ChannelId(i, i)] = - MemoryBackedChannelBuffer::CreateLoopback(); - for (std::size_t j = i + 1; j < number_of_parties; ++j) { - ChannelId cid(i, j); - auto cp = MemoryBackedChannelBuffer::CreatePaired(); - ctx->m_buffers[cid] = cp[0]; - ctx->m_buffers[cid.Flip()] = cp[1]; - } +namespace { + +std::vector initBitmaps(std::size_t number_of_parties) { + std::vector bms; + bms.reserve(number_of_parties); + for (std::size_t i = 0; i < number_of_parties; i++) { + bms.emplace_back(number_of_parties); + } + return bms; +} + +} // namespace + +GlobalCtx GlobalCtx::create(std::size_t number_of_parties, + std::unique_ptr network_config, + std::vector hooks) { + std::vector traces(number_of_parties); + + for (std::size_t i = 0; i < number_of_parties; i++) { + traces.reserve(1024); + } + + std::unordered_map> sends; + std::vector clocks(number_of_parties); + std::vector recv_map = initBitmaps(number_of_parties); + + return {number_of_parties, + std::move(network_config), + traces, + sends, + clocks, + recv_map, + util::Bitmap(number_of_parties), + std::move(hooks)}; +} + +util::Time::Duration GlobalCtx::LocalContext::lastEventTimestamp() const { + if (!m_gctx.traces[m_id].empty()) { + return m_gctx.traces[m_id].back()->timestamp; } + return util::Time::Duration::zero(); +} - return ctx; -} // LCOV_EXCL_LINE +util::Time::Duration GlobalCtx::LocalContext::elapsedTime() const { + const util::Time::Duration most_recent = lastEventTimestamp(); + return most_recent + (util::Time::now() - m_gctx.clocks[m_id]); +} + +void GlobalCtx::LocalContext::startClock() { + m_gctx.clocks[m_id] = util::Time::now(); +} namespace { -std::size_t Next(std::size_t id, std::size_t n) { - return (id + 1) % n; +// Computes total size in bits that nbytes of data would occupy provided some +// maximum segment size +long double sizeWithHeadersInBits(std::size_t nbytes, + std::size_t mss) noexcept { + static constexpr std::size_t TCP_IP_HEADER = 40; + const std::size_t num_packets = std::ceil((double)nbytes / (double)mss); + return 8 * (nbytes + num_packets * TCP_IP_HEADER); } -} // namespace +// Converts the RTT in a config, assumed to be in ms, to seconds. +long double rttSeconds(const sim::ChannelConfig& config) noexcept { + using namespace std::chrono_literals; + const auto d = std::chrono::milliseconds(config.RTT()); + return d / 1.0s; +} -std::optional sim::Context::NextToRun( - std::optional current) { - // party 0 is always the party to go first. - if (!current.has_value()) { - return 0; - } +// Computes the throughput of a channel assuming a package loss of 0%. +long double throughputNoLoss(const sim::ChannelConfig& config) noexcept { + // Simple throughput formula: + // https://tetcos.com/pdf/v13/Experiments/Mathematical-Modelling-of-TCP-Throughput-Performance.pdf + const long double rtt = rttSeconds(config); + const long double wndz = 8 * (long double)config.windowSize(); + const long double max_throughput = wndz / rtt; - // we end here current throw a SimulationFailure. This only happens when it - // fails to either call Recv or HasData. - if (m_state == State::ROLLBACK) { - // the last party in m_next_party_cand is assumed to be the party for - // which current tried to Recv or HasData from. - const auto next = m_next_party_cand.back(); - - // if this party has already finished, then current will never be able to - // finish, so we crash the simulation here. - if (HasTerminated(next)) { - throw SimulationFailure( - "party tried to receive data from terminated party"); - } + // actual throughput obviously cannot exceed the capacity of the link. + const long double bw = (long double)config.bandwidth(); + const long double actual_throughput = std::min(max_throughput, bw); - // if the party is the same as current, then we are performing a rollback - // because we did not send enough data to ourselves. That data is never - // going to arrive, so there's no hope of saving the simulation. - if (next == current) { - throw SimulationFailure("infinite loop detected"); - } + return actual_throughput; +} - return next; - } +// Computes the throughput of a channel assuming a package loss of > 0%. This +// uses the Mathis formula: +// https://cseweb.ucsd.edu/classes/wi01/cse222/papers/mathis-tcpmodel-ccr97.pdf +long double throughputLoss(const sim::ChannelConfig& config) noexcept { + const long double mss = config.MSS(); + const long double loss_term = std::sqrt(3.0 / (2.0 * config.packetLoss())); + const long double rtt = rttSeconds(config); - std::size_t next = Next(current.value(), m_nparties); - std::size_t terminated = 0; - while (terminated < m_nparties) { - if (!HasTerminated(next)) { - return next; - } - terminated++; - next = Next(next, m_nparties); + return loss_term * (8 * mss / rtt); +} + +// Computes the receive time of some amount of data on a TCP channel. +util::Time::Duration recvTimeTCP(const sim::ChannelConfig& config, + std::size_t n) { + const long double total_size_bits = sizeWithHeadersInBits(n, config.MSS()); + long double actual_tp = throughputNoLoss(config); + + if (config.packetLoss() > 0) { + const long double tp = throughputLoss(config); + actual_tp = std::min(tp, actual_tp); } - return {}; + const long double t = total_size_bits / actual_tp + rttSeconds(config); + const auto t_sec = std::chrono::duration(t); + return std::chrono::duration_cast(t_sec); } -util::Time::Duration sim::Context::Checkpoint(std::size_t id) { - const auto latest = LatestTimestamp(id); - const auto last_checkpoint = m_checkpoint; - UpdateCheckpoint(); - return latest + (m_checkpoint - last_checkpoint); +// Computes the delay that sending an amount of bytes would incur. +util::Time::Duration adjustSendTime(const sim::ChannelConfig& config, + util::Time::Duration send_time, + std::size_t n) { + if (config.type() == sim::ChannelConfig::NetworkType::TCP) { + return send_time + recvTimeTCP(config, n); + } + return send_time; } -void sim::Context::Prepare(std::size_t id) { - if (m_state == State::COMMIT || m_state == State::ROLLBACK) { - // Save the current head of m_traces so we can discard new events if this - // party has to rollback. - m_trace_index = m_traces[id].size(); - m_next_party_cand.clear(); +} // namespace - // Save the current m_writes map. Recv operations will change writes made by - // other parties, so this is the easiest way to make sure Rollback does the - // right thing. - m_writes_backup = m_writes; - for (std::size_t i = 0; i < m_nparties; ++i) { - auto cid = ChannelId(id, i); - m_buffers[cid]->Prepare(); +void GlobalCtx::LocalContext::recordEvent(std::shared_ptr event) { + m_gctx.traces[m_id].emplace_back(event); + + const auto event_type = event->type; + for (const auto& [trigger, hook] : m_gctx.hooks) { + if (trigger.has_value()) { + if (trigger.value() == event_type) { + hook->run(m_id, getContext()); + } + } else { + hook->run(m_id, getContext()); } - } else { - throw std::logic_error("cannot prepare ctx"); } - m_state = State::PREPARE; } -void sim::Context::Commit(std::size_t id) { - if (m_state == State::PREPARE) { - m_writes_backup.clear(); - for (std::size_t i = 0; i < m_nparties; ++i) { - ChannelId cid(id, i); - m_buffers[cid]->Commit(); - } +util::Time::Duration GlobalCtx::LocalContext::recv( + std::size_t sender, + std::size_t nbytes, + util::Time::Duration timestamp) { + // Channel ID corresponding to the channel that the remote party writes to. + const ChannelId id{.local = sender, .remote = m_id}; + const util::Time::Duration send_time = m_gctx.sends[id].front(); + m_gctx.sends[id].pop_front(); + + const ChannelConfig cconf = m_gctx.network_config->get(id); + return std::max(timestamp, adjustSendTime(cconf, send_time, nbytes)); +} - } else { - throw std::logic_error("cannot commit"); +void GlobalCtx::LocalContext::recvStart(std::size_t id) { + m_gctx.recv_map[m_id].set(id, true); +} + +void GlobalCtx::LocalContext::recvDone(std::size_t id) { + m_gctx.recv_map[m_id].set(id, false); +} + +bool GlobalCtx::LocalContext::receiving(std::size_t receiver) const { + return m_gctx.recv_map[receiver].at(m_id); +} + +bool GlobalCtx::LocalContext::dead(std::size_t id) const { + if (m_gctx.traces[id].empty()) { + return false; } - m_state = State::COMMIT; + + const auto last_event_type = m_gctx.traces[id].back()->type; + return last_event_type == EventType::STOP || + last_event_type == EventType::KILLED || + last_event_type == EventType::CANCELLED; } -void sim::Context::Rollback(std::size_t id) { - if (m_state == State::PREPARE) { - m_traces[id].resize(m_trace_index); - m_writes = m_writes_backup; - for (std::size_t i = 0; i < m_nparties; ++i) { - ChannelId cid(id, i); - m_buffers[cid]->Rollback(); - } - } else { - throw std::logic_error("cannot rollback"); +util::Time::Duration GlobalCtx::LocalContext::currentTimeOf( + std::size_t other_party) const { + if (m_gctx.traces[other_party].empty()) { + return util::Time::Duration::zero(); } - m_state = State::ROLLBACK; + return m_gctx.traces[other_party].back()->timestamp; +} + +std::ostream& sim::details::operator<<( + std::ostream& os, + const sim::details::GlobalContext& global_ctx) { + os << "GLOBAL_CTX{"; + os << " number_of_parties=" << global_ctx.number_of_parties << "\n"; + os << " network_config=\n"; + os << " traces=\n"; + os << " sends=\n"; + os << " clocks=\n"; + os << " recv_map=\n"; + os << " cancellation_map=" << global_ctx.cancellation_map << "\n"; + os << " hooks=\n"; + os << "}\n"; + + return os; } diff --git a/src/scl/simulation/event.cc b/src/scl/simulation/event.cc index 99385a8..dfc9506 100644 --- a/src/scl/simulation/event.cc +++ b/src/scl/simulation/event.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -20,154 +20,285 @@ #include #include +#include "scl/simulation/channel_id.h" +#include "scl/util/time.h" + using namespace scl; -namespace { +std::shared_ptr sim::Event::start() { + return std::make_shared(EventType::START, + util::Time::Duration::zero()); +} -auto EventTypeToString(sim::Event::Type type) { - if (type == sim::Event::Type::START) { - return "START"; - } +std::shared_ptr sim::Event::stop(util::Time::Duration timestamp) { + return std::make_shared(EventType::STOP, timestamp); +} - if (type == sim::Event::Type::STOP) { - return "STOP"; - } +std::shared_ptr sim::Event::killed(util::Time::Duration timestamp, + const std::string& reason) { + return std::make_shared(timestamp, reason); +} - if (type == sim::Event::Type::SEND) { - return "SEND"; - } +std::shared_ptr sim::Event::cancelled( + util::Time::Duration timestamp) { + return std::make_shared(EventType::CANCELLED, timestamp); +} - if (type == sim::Event::Type::RECV) { - return "RECV"; - } +std::shared_ptr sim::Event::closeChannel( + util::Time::Duration timestamp, + sim::ChannelId channel_id) { + return std::make_shared(EventType::CLOSE, + timestamp, + channel_id); +} - if (type == sim::Event::Type::HAS_DATA) { - return "HAS_DATA"; - } +std::shared_ptr sim::Event::sendData(util::Time::Duration timestamp, + sim::ChannelId channel_id, + std::size_t amount) { + return std::make_shared(EventType::SEND, + timestamp, + channel_id, + amount); +} - if (type == sim::Event::Type::OUTPUT) { - return "OUTPUT"; - } +std::shared_ptr sim::Event::recvData(util::Time::Duration timestamp, + sim::ChannelId channel_id, + std::size_t amount) { + return std::make_shared(EventType::RECV, + timestamp, + channel_id, + amount); +} - if (type == sim::Event::Type::SLEEP) { - return "SLEEP"; - } +std::shared_ptr sim::Event::hasData(util::Time::Duration timestamp, + sim::ChannelId channel_id) { + return std::make_shared(EventType::HAS_DATA, + timestamp, + channel_id); +} - if (type == sim::Event::Type::SEGMENT_BEGIN) { - return "SEGMENT_BEGIN"; - } +std::shared_ptr sim::Event::sleep( + util::Time::Duration timestamp, + util::Time::Duration sleep_duration) { + return std::make_shared(EventType::SLEEP, + timestamp, + sleep_duration); +} - if (type == sim::Event::Type::SEGMENT_END) { - return "SEGMENT_END"; - } +std::shared_ptr sim::Event::output(util::Time::Duration timestamp) { + return std::make_shared(EventType::OUTPUT, timestamp); +} - if (type == sim::Event::Type::CHECKPOINT) { - return "CHECKPOINT"; - } +std::shared_ptr sim::Event::protocolBegin( + util::Time::Duration timestamp, + const std::string& protocol_name) { + return std::make_shared(EventType::PROTOCOL_BEGIN, + timestamp, + protocol_name); +} - if (type == sim::Event::Type::PACKET_SEND) { - return "PACKET_SEND"; - } +std::shared_ptr sim::Event::protocolEnd( + util::Time::Duration timestamp, + const std::string& protocol_name) { + return std::make_shared(EventType::PROTOCOL_END, + timestamp, + protocol_name); +} - if (type == sim::Event::Type::PACKET_RECV) { - return "PACKET_RECV"; - } +namespace { - if (type == sim::Event::Type::KILLED) { - return "KILLED"; +std::string eventTypeToString(sim::EventType type) { + switch (type) { + case sim::EventType::START: + return "START"; + break; + case sim::EventType::STOP: + return "STOP"; + break; + case sim::EventType::SEND: + return "SEND"; + break; + case sim::EventType::RECV: + return "RECV"; + break; + case sim::EventType::HAS_DATA: + return "HAS_DATA"; + break; + case sim::EventType::OUTPUT: + return "OUTPUT"; + break; + case sim::EventType::SLEEP: + return "SLEEP"; + break; + case sim::EventType::PROTOCOL_BEGIN: + return "PROTOCOL_BEGIN"; + break; + case sim::EventType::PROTOCOL_END: + return "PROTOCOL_END"; + break; + case sim::EventType::KILLED: + return "KILLED"; + break; + case sim::EventType::CANCELLED: + return "CANCELLED"; + break; + // case sim::EventType::CLOSE: + default: + return "CLOSE"; } +} - // if (type == Measurement::Type::CLOSE) - return "CLOSE"; +void writeObj(std::ostream& stream, const std::string& string) { + stream << "\"" << string << "\""; } -void WriteClose(std::ostream& os, const sim::NetworkEvent* m) { - os << " [Local=" << m->LocalParty() << ", Remote=" << m->RemoteParty() << "]"; +void writeKey(std::ostream& stream, const std::string& name) { + writeObj(stream, name); + stream << ":"; } -void WriteSend(std::ostream& os, const sim::NetworkDataEvent* m) { - os << " [" - << "Sender=" << m->LocalParty() << ", Receiver=" << m->RemoteParty() - << ", Amount=" << m->DataAmount() << "]"; +void writeObj(std::ostream& stream, const std::size_t& val) { + stream << val; } -void WriteRecv(std::ostream& os, const sim::NetworkDataEvent* m) { - os << " [" - << "Receiver=" << m->LocalParty() << ", Sender=" << m->RemoteParty() - << ", Amount=" << m->DataAmount() << "]"; +void writeObj(std::ostream& stream, const long double& val) { + stream << val; } -void WritePacketRecv(std::ostream& os, const sim::PacketRecvEvent* m) { - os << " [" - << "Receiver=" << m->LocalParty() << ", Sender=" << m->RemoteParty() - << ", Amount=" << m->DataAmount() << ", Blocking=" << std::boolalpha - << m->Blocking() << "]"; +void writeObj(std::ostream& stream, const util::Time::Duration& d) { + auto t = std::chrono::duration(d).count(); + writeObj(stream, t); } -void WriteSegment(std::ostream& os, const sim::SegmentEvent* m) { - const auto name = m->Name(); - if (name.empty()) { - os << " [Unnamed segment]"; - } else { - os << " [Name=" << name << "]"; - } +void writeObj(std::ostream& stream, const sim::ChannelId& id) { + stream << "{"; + + writeKey(stream, "local"); + writeObj(stream, id.local); + + stream << ","; + + writeKey(stream, "remote"); + writeObj(stream, id.remote); + + stream << "}"; } -void WriteHasData(std::ostream& os, const sim::HasDataEvent* m) { - os << " [Local=" << m->LocalParty() << ", Remote=" << m->RemoteParty() - << ", DataAvailable=" << std::boolalpha << m->HadData() << "]"; +void writeEvent(std::ostream& stream, const sim::ChannelEvent* event) { + stream << "{"; + + writeKey(stream, "channel_id"); + writeObj(stream, event->channel_id); + + stream << "}"; } -void WriteCheckpoint(std::ostream& os, const sim::CheckpointEvent* m) { - os << " [" << m->Id() << "]"; +void writeEvent(std::ostream& stream, const sim::ChannelDataEvent* event) { + stream << "{"; + + writeKey(stream, "channel_id"); + writeObj(stream, event->channel_id); + + stream << ","; + + writeKey(stream, "amount"); + writeObj(stream, event->amount); + + stream << "}"; +} + +void writeEvent(std::ostream& stream, const sim::SleepEvent* event) { + stream << "{"; + + writeKey(stream, "duration"); + writeObj(stream, event->sleep_duration); + + stream << "}"; +} + +void writeEvent(std::ostream& stream, const sim::ProtocolEvent* event) { + stream << "{"; + + writeKey(stream, "name"); + writeObj(stream, event->protocol_name); + + stream << "}"; +} + +void writeEvent(std::ostream& stream, const sim::KillEvent* event) { + stream << "{"; + + writeKey(stream, "reason"); + writeObj(stream, event->reason); + + stream << "}"; } } // namespace -std::ostream& sim::operator<<(std::ostream& os, Event::Type type) { - return os << EventTypeToString(type); +std::ostream& sim::operator<<(std::ostream& stream, + const sim::EventType event_type) { + return stream << eventTypeToString(event_type); } -std::ostream& sim::operator<<(std::ostream& os, const sim::Event* m) { - using namespace std::chrono; - const auto t = m->EventType(); - os << t << " at "; - os << duration(m->Timestamp()).count(); - os << " ms"; - if (m->Offset() > util::Time::Duration::zero()) { - os << " [Offset="; - os << duration(m->Offset()).count(); - os << " ms]"; - } +std::ostream& sim::operator<<(std::ostream& stream, const sim::Event* event) { + stream << "{"; - if (t == sim::Event::Type::SEGMENT_BEGIN || - t == sim::Event::Type::SEGMENT_END) { - WriteSegment(os, dynamic_cast(m)); - } + writeKey(stream, "timestamp"); + writeObj(stream, util::timeToMillis(event->timestamp)); - if (t == sim::Event::Type::CLOSE) { - WriteClose(os, dynamic_cast(m)); - } + stream << ","; - if (t == sim::Event::Type::SEND || t == sim::Event::Type::PACKET_SEND) { - WriteSend(os, dynamic_cast(m)); - } + writeKey(stream, "type"); + writeObj(stream, eventTypeToString(event->type)); - if (t == sim::Event::Type::RECV) { - WriteRecv(os, dynamic_cast(m)); - } + stream << ","; - if (t == sim::Event::Type::PACKET_RECV) { - WritePacketRecv(os, dynamic_cast(m)); - } + writeKey(stream, "metadata"); + + switch (event->type) { + case EventType::CLOSE: + case EventType::HAS_DATA: + writeEvent(stream, dynamic_cast(event)); + break; - if (t == sim::Event::Type::HAS_DATA) { - WriteHasData(os, dynamic_cast(m)); + case EventType::SEND: + case EventType::RECV: + writeEvent(stream, dynamic_cast(event)); + break; + + case EventType::SLEEP: + writeEvent(stream, dynamic_cast(event)); + break; + + case EventType::PROTOCOL_BEGIN: + case EventType::PROTOCOL_END: + writeEvent(stream, dynamic_cast(event)); + break; + + case EventType::KILLED: + writeEvent(stream, dynamic_cast(event)); + break; + + default: + stream << "{}"; + break; } - if (t == sim::Event::Type::CHECKPOINT) { - WriteCheckpoint(os, dynamic_cast(m)); + stream << "}"; + + return stream; +} + +void sim::writeTrace(std::ostream& stream, const sim::SimulationTrace& trace) { + stream << "["; + + if (!trace.empty()) { + for (std::size_t i = 0; i < trace.size() - 1; i++) { + stream << trace[i] << ","; + } + stream << trace[trace.size() - 1]; } - return os; + stream << "]"; } diff --git a/src/scl/simulation/measurement.cc b/src/scl/simulation/measurement.cc deleted file mode 100644 index 4de04d7..0000000 --- a/src/scl/simulation/measurement.cc +++ /dev/null @@ -1,100 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#include "scl/simulation/measurement.h" - -#include - -using namespace scl; - -namespace { - -template -T Zero() { - return 0; -} - -template <> -util::Time::Duration Zero() { - return util::Time::Duration::zero(); -} - -template -T Mean(const sim::Measurement& m) { - T sum = Zero(); - for (const auto& v : m.Samples()) { - sum += v; - } - return sum / m.Size(); -} - -long double Sqrt(long double v) { - return std::sqrt(v); -} - -long double Sqr(long double v) { - return v * v; -} - -util::Time::Duration Sqrt(const util::Time::Duration& v) { - long double u = std::sqrt(v.count()); - std::chrono::duration w(u); - return std::chrono::duration_cast(w); -} - -util::Time::Duration Sqr(const util::Time::Duration& v) { - long double u = v.count(); - std::chrono::duration w(u * u); - return std::chrono::duration_cast(w); -} - -template -T StdDev(const sim::Measurement& m) { - const auto mu = Mean(m); - auto sum = Zero(); - for (const auto& v : m.Samples()) { - sum += Sqr(v - mu); - } - return Sqrt(sum / m.Size()); -} - -} // namespace - -std::ostream& sim::operator<<(std::ostream& os, const sim::TimeMeasurement& m) { - const auto mean = std::chrono::duration(Mean(m)).count(); - const auto std_dev = - std::chrono::duration(StdDev(m)).count(); - - os << "{" - << "\"mean\": " << mean << ", " - << "\"unit\": \"ms\", " - << "\"std_dev\": " << std_dev << "}"; - - return os; -} - -std::ostream& sim::operator<<(std::ostream& os, const sim::DataMeasurement& m) { - const auto mean = Mean(m); - const auto std_dev = StdDev(m); - - os << "{" - << "\"mean\": " << mean << ", " - << "\"unit\": \"B\", " - << "\"std_dev\": " << std_dev << "}"; - - return os; -} diff --git a/src/scl/simulation/result.cc b/src/scl/simulation/result.cc deleted file mode 100644 index 7eed414..0000000 --- a/src/scl/simulation/result.cc +++ /dev/null @@ -1,483 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#include "scl/simulation/result.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "scl/simulation/channel_id.h" -#include "scl/simulation/event.h" -#include "scl/simulation/measurement.h" -#include "scl/util/time.h" - -using namespace scl; - -namespace { - -struct SentRecv { - long double sent = 0; - long double recv = 0; -}; - -using SentRecvMap = std::unordered_map; - -using CheckpointMap = std::unordered_map; - -struct Segment { - // sent/recv to other parties - SentRecvMap sr; - // execution time of the segment - util::Time::Duration dur; - // checkpoints found in the segment - CheckpointMap checkpoints; -}; - -std::string GetNameFromSegmentEvent(std::shared_ptr event) { - return std::dynamic_pointer_cast(event)->Name(); -} - -using NamedSegment = std::pair; - -bool IsRecvEvent(std::shared_ptr ptr) { - return ptr->EventType() == sim::Event::Type::RECV || - ptr->EventType() == sim::Event::Type::PACKET_RECV; -} - -bool IsSendEvent(std::shared_ptr ptr) { - return ptr->EventType() == sim::Event::Type::SEND || - ptr->EventType() == sim::Event::Type::PACKET_SEND; -} - -/** - * @brief Parse a segment from a simulation trace. - * @param start a beginning iterator, pointing to a SEGMENT_BEGIN event. - * @param end an end iterator. - */ -template -NamedSegment ParseSegment(It start, const It end) { - Segment seg; - - std::shared_ptr event = *start; - - const auto name = GetNameFromSegmentEvent(event); - seg.dur = event->Timestamp(); - - start++; - - while (start < end) { - event = *start; - - auto ne = std::dynamic_pointer_cast(event); - - if (ne != nullptr) { - const auto id = ne->RemoteParty(); - if (IsRecvEvent(ne)) { - seg.sr[id].recv += ne->DataAmount(); - } - if (IsSendEvent(ne)) { - seg.sr[id].sent += ne->DataAmount(); - } - } - - if (event->EventType() == sim::Event::Type::CHECKPOINT) { - const auto* ce = dynamic_cast(event.get()); - seg.checkpoints[ce->Id()] = ce->Timestamp(); - } - - if (event->EventType() == sim::Event::Type::SEGMENT_END) { - seg.dur = event->Timestamp() - seg.dur; - return {name, seg}; - } - start++; - } - - // we never saw a SEGMENT_END event, and now there's no more events, which - // means the simulation trace was incomplete/malformed. - throw std::logic_error("incomplete segment"); -} - -/** - * @brief Adds two maps. - * - * Helper function for adding two maps: - * - *

    - *
  • If an entry exists in \p m1 and \p m0, then the values from \p m1 are - * added to those already in \p m0
  • - - *
  • If an entry exists in \p m1 but not in \p m0, then the entry from \p m0 - * is added to \p m1
  • - *
- */ -void UpdateSentRecv(SentRecvMap& m0, const SentRecvMap& m1) { - for (const auto& [k, v] : m1) { - if (m0.find(k) == m0.end()) { - m0[k] = v; - } else { - m0[k].sent += v.sent; - m0[k].recv += v.recv; - } - } -} - -using SegmentMap = std::unordered_map; - -/** - * @brief Merge segments by their name. - * - * This takes a list of name, segment pairs, and merges the information in the - * segments that have the same name. - */ -SegmentMap MergeSegments(const std::vector& segments) { - SegmentMap m; - - m[{}].dur = util::Time::Duration::zero(); - - for (const auto& named_seg : segments) { - const auto name = named_seg.first; - const auto segm = named_seg.second; - - if (m.find(name) == m.end()) { - m[name] = segm; - } else { - m[name].dur += segm.dur; - UpdateSentRecv(m[name].sr, segm.sr); - m[name].checkpoints.insert(segm.checkpoints.begin(), - segm.checkpoints.end()); - } - - m[{}].dur += segm.dur; - UpdateSentRecv(m[{}].sr, segm.sr); - } - - return m; -} // LCOV_EXCL_LINE - -template -void ValidateTraceHeadAndTail(It head, It tail) { - if ((*head)->EventType() != sim::Event::Type::START) { - throw std::logic_error("incomplete trace"); - } - const auto last = (*tail)->EventType(); - if (last != sim::Event::Type::STOP && last != sim::Event::Type::KILLED) { - throw std::logic_error("truncated trace"); - } -} - -void AppendIfMissing(std::vector& list, - const std::string& element) { - if (std::find(list.begin(), list.end(), element) == list.end()) { - list.emplace_back(element); - } -} - -} // namespace - -/** - * @brief Create a result from a list of simulation traces. - */ -sim::Result sim::Result::Create( - const std::vector& traces) { - std::vector segments; - - for (const auto& trace : traces) { - auto b = trace.begin(); - const auto e = trace.end(); - - // sanity check - ValidateTraceHeadAndTail(b, e - 1); - - // Extract each segment - std::vector named_segments; - while (b < e) { - std::shared_ptr event = *b; - - if (event->EventType() == sim::Event::Type::SEGMENT_BEGIN) { - named_segments.emplace_back(ParseSegment(b, e)); - } - - b++; - } - - // Merge segments by name - segments.emplace_back(MergeSegments(named_segments)); - } - - std::vector segment_names; - std::unordered_map segment_measurements; - std::unordered_map checkpoints; - - for (const auto& seg_map : segments) { - for (const auto& [seg_name, seg] : seg_map) { - if (seg_name.has_value()) { - // clang-tidy cannot see that we check if seg_name has a value above, so - // disable the linter here to avoid false negatives. - const auto v = seg_name.value(); // NOLINT - AppendIfMissing(segment_names, v); - } - - for (const auto& [s, c] : seg.checkpoints) { - checkpoints[s].AddSample(c); - } - - segment_measurements[seg_name].duration_m.AddSample(seg.dur); - - SentRecv total; - for (const auto& [cid, sr] : seg.sr) { - segment_measurements[seg_name].channels_m[cid].recv.AddSample(sr.recv); - segment_measurements[seg_name].channels_m[cid].sent.AddSample(sr.sent); - total.recv += sr.recv; - total.sent += sr.sent; - } - - segment_measurements[seg_name].send_recv_m.recv.AddSample(total.recv); - segment_measurements[seg_name].send_recv_m.sent.AddSample(total.sent); - } - } - - return Result(traces, segment_measurements, checkpoints, segment_names); -} - -std::vector sim::Result::Create( - const std::vector>& traces) { - const auto num_parties = traces[0].size(); - const auto num_replications = traces.size(); - - std::vector results; - results.reserve(num_parties); - - for (std::size_t i = 0; i < num_parties; ++i) { - std::vector traces_for_party; - traces_for_party.reserve(num_replications); - for (std::size_t j = 0; j < num_replications; ++j) { - traces_for_party.emplace_back(traces[j][i]); - } - - results.emplace_back(Create(traces_for_party)); - } - - return results; -} - -namespace { - -template -std::vector KeySet(const std::unordered_map& map) { - std::vector keys; - for (const auto& [key, val] : map) { - (void)val; - keys.emplace_back(key); - } - return keys; -} // LCOV_EXCL_LINE - -} // namespace - -std::vector sim::Result::Interactions( - const SegmentName& name) const { - return KeySet(m_measurements.at(name).channels_m); -} - -namespace { - -template -void WriteSegmentTrace(std::ostream& stream, It start, It end) { - while (start != end) { - stream << *(start++) << std::endl; - } -} - -} // namespace - -void sim::Result::WriteTrace(std::ostream& stream, - std::size_t replication, - const sim::Result::SegmentName& name) const { - if (replication >= m_traces.size()) { - throw std::invalid_argument("invalid replication"); - } - - if (!name.has_value()) { - WriteSegmentTrace(stream, - m_traces[replication].begin(), - m_traces[replication].end()); - } else { - const auto& segment_name = name.value(); - bool in_relevant_segment = false; - - for (const auto& e : m_traces[replication]) { - if (in_relevant_segment) { - stream << e << std::endl; - } - - const auto s = std::dynamic_pointer_cast(e); - - if (s != nullptr) { - if (!in_relevant_segment && - s->EventType() == Event::Type::SEGMENT_BEGIN && - s->Name() == segment_name) { - stream << e << std::endl; - in_relevant_segment = true; - } - - if (in_relevant_segment && s->EventType() == Event::Type::SEGMENT_END) { - in_relevant_segment = false; - } - } - } - } -} - -namespace { - -template -void WriteMap(std::ostream& stream, const std::unordered_map& map); - -void WriteObj(std::ostream& stream, const std::string& string) { - stream << "\"" << string << "\""; -} - -void WriteKey(std::ostream& stream, const std::string& name) { - WriteObj(stream, name); - stream << ":"; -} - -void WriteObj(std::ostream& stream, const std::size_t& val) { - stream << val; -} - -void WriteObj(std::ostream& stream, const long double& val) { - stream << val; -} - -void WriteObj(std::ostream& stream, const util::Time::Duration& d) { - auto t = std::chrono::duration(d).count(); - WriteObj(stream, t); -} - -void WriteObj(std::ostream& stream, const std::optional& opt) { - if (opt.has_value()) { - WriteObj(stream, opt.value()); - } else { - stream << "null"; - } -} - -template -void WriteUnit(std::ostream& stream); - -template <> -void WriteUnit(std::ostream& stream) { - WriteObj(stream, std::string{"bytes"}); -} - -template <> -void WriteUnit(std::ostream& stream) { - WriteObj(stream, std::string{"milliseconds"}); -} - -template -void WriteList(std::ostream& stream, const std::vector& items); - -template -void WriteObj(std::ostream& stream, const sim::Measurement& m) { - stream << "{"; - WriteKey(stream, "unit"); - WriteUnit(stream); - stream << ","; - WriteKey(stream, "samples"); - WriteList(stream, m.Samples()); - stream << "}"; -} - -void WriteObj(std::ostream& stream, const sim::SendRecvMeasurement& srm) { - stream << "{"; - WriteKey(stream, "sent"); - WriteObj(stream, srm.sent); - stream << ","; - WriteKey(stream, "recv"); - WriteObj(stream, srm.recv); - stream << "}"; -} - -void WriteObj(std::ostream& stream, const sim::Result::SegmentMeasurement& m) { - stream << "{"; - WriteKey(stream, "time"); - WriteObj(stream, m.duration_m); - stream << ","; - WriteKey(stream, "data"); - WriteObj(stream, m.send_recv_m); - stream << ","; - WriteKey(stream, "channels"); - WriteMap(stream, m.channels_m); - stream << "}"; -} - -template -void WriteObj(std::ostream& stream, const std::pair& pair) { - stream << "{"; - WriteKey(stream, "key"); - WriteObj(stream, pair.first); - stream << ","; - WriteKey(stream, "value"); - WriteObj(stream, pair.second); - stream << "}"; -} - -template -void WriteList(std::ostream& stream, const std::vector& items) { - stream << "["; - for (std::size_t i = 0; i < items.size(); ++i) { - WriteObj(stream, items[i]); - if (i < items.size() - 1) { - stream << ","; - } - } - stream << "]"; -} - -template -void WriteMap(std::ostream& stream, const std::unordered_map& map) { - std::vector> kvs(map.begin(), map.end()); - WriteList(stream, kvs); -} - -} // namespace - -void sim::Result::Write(std::ostream& stream) const { - stream << "{"; - - WriteKey(stream, "names"); - WriteList(stream, m_segment_names); - stream << ","; - - WriteKey(stream, "measurements"); - WriteMap(stream, m_measurements); - stream << ","; - - WriteKey(stream, "checkpoints"); - WriteMap(stream, m_checkpoints); - - stream << "}" << std::endl; -} diff --git a/src/scl/simulation/runtime.cc b/src/scl/simulation/runtime.cc new file mode 100644 index 0000000..8a6c580 --- /dev/null +++ b/src/scl/simulation/runtime.cc @@ -0,0 +1,89 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#include "scl/simulation/runtime.h" + +#include + +#include "scl/simulation/context.h" + +using namespace scl; + +void sim::details::SimulatorRuntime::schedule( + std::coroutine_handle<> coroutine, + std::function&& predicate) { + m_tq.emplace_back(coroutine, std::move(predicate), m_current_pid); +} + +void sim::details::SimulatorRuntime::schedule(std::coroutine_handle<> coroutine, + util::Time::Duration delay) { + auto view = m_ctx.view(m_current_pid); + const auto last = view.lastEventTimestamp(); + view.recordEvent(Event::sleep(last, delay)); + this->schedule(coroutine); +} + +void sim::details::SimulatorRuntime::deschedule( + std::coroutine_handle<> coroutine) { + m_tq.remove_if( + [&coroutine](const Coro& coro) { return coro.coroutine == coroutine; }); +} + +// void sim::details::SimulatorRuntime::removeCancelledCoros() { +// auto b = m_tq.begin(); +// const auto e = m_tq.end(); + +// while (b != e) { +// const auto& [coro, pred, pid] = *b; +// if (pid != MANAGER_PID && m_ctx.cancellation_map.at(pid)) { + +// } +// } +// } + +std::coroutine_handle<> sim::details::SimulatorRuntime::next() { + auto b = m_tq.begin(); + const auto e = m_tq.end(); + + while (b != e) { + const auto [coro, pred, pid] = *b; + + // if we're about to run the manager coro, then we do not wish + // to check if it's been cancelled. + if (pid != MANAGER_PID && m_ctx.cancellation_map.at(pid)) { + } else if (pred()) { + m_tq.erase(b); + m_current_pid = pid; + + // Event timestamps are computed as + // + // E[i].ts = E[i - 1] + (now - last_startClock) + // + // It is therefore important that startClock is called here, since + // otherwise we may end up counting time spent executing another party (or + // just time spent in the simulation runtime), when we compute the + // timestamp of event i. + if (m_current_pid != MANAGER_PID) { + m_ctx.view(m_current_pid).startClock(); + } + return coro; + } + b++; + } + + return std::noop_coroutine(); +} diff --git a/src/scl/simulation/simulate_recv_time.cc b/src/scl/simulation/simulate_recv_time.cc deleted file mode 100644 index a7c8d40..0000000 --- a/src/scl/simulation/simulate_recv_time.cc +++ /dev/null @@ -1,108 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#include -#include -#include - -#include "scl/simulation/config.h" -#include "scl/simulation/simulator.h" - -using namespace scl; - -namespace { - -/** - * @brief Computes the actual number of bytes needing transfer - * @param nbytes the raw number of bytes - * @param mss the maximum segment size - * @param the number of bytes when accounting for overhead. - * - * This function computes the number of packages needed to send \p nbytes worth - * of data and adds an overhead of 40 bytes per package. - */ -long double TransferSizeWithHeadersBits(std::size_t nbytes, - std::size_t mss) noexcept { - static constexpr std::size_t kHeaderSizeBytes = 40; - const std::size_t num_packets = std::ceil((double)nbytes / (double)mss); - return 8 * (nbytes + num_packets * kHeaderSizeBytes); -} - -/** - * @brief Get the RTT from a config in seconds. - */ -long double RoundTripTimeSeconds(const sim::ChannelConfig& config) noexcept { - using namespace std::chrono_literals; - const auto d = std::chrono::milliseconds(config.RTT()); - return d / 1.0s; -} - -/** - * @brief Compute the maximum TCP throughput assuming package loss of 0% - */ -long double ThroughputZeroPackageLoss( - const sim::ChannelConfig& config) noexcept { - // Simple throughput formula: - // https://tetcos.com/pdf/v13/Experiments/Mathematical-Modelling-of-TCP-Throughput-Performance.pdf - const auto rtt = RoundTripTimeSeconds(config); - const auto wndz = 8 * (long double)config.WindowSize(); - const auto max_throughput = wndz / rtt; - - // actual throughput obviously cannot exceed the capacity of the link. - const auto bw = (long double)config.Bandwidth(); - const auto actual_throughput = std::min(max_throughput, bw); - - return actual_throughput; -} - -/** - * @brief Compute TCP throughput assuming package loss using Mathis et. al. - */ -long double ThroughputNonZeroPackageLoss( - const sim::ChannelConfig& config) noexcept { - const auto mss = (long double)config.MSS(); - const auto loss_term = std::sqrt(3.0 / (2.0 * config.PackageLoss())); - const auto rtt = RoundTripTimeSeconds(config); - - return loss_term * (8 * mss / rtt); -} - -util::Time::Duration ComputeRecvTimeTcp(const sim::ChannelConfig& config, - std::size_t n) { - const auto total_size_bits = TransferSizeWithHeadersBits(n, config.MSS()); - auto actual_tp = ThroughputZeroPackageLoss(config); - - if (config.PackageLoss() > 0) { - const auto tp = ThroughputNonZeroPackageLoss(config); - actual_tp = std::min(tp, actual_tp); - } - - const auto t = total_size_bits / actual_tp + RoundTripTimeSeconds(config); - const auto t_sec = std::chrono::duration(t); - return std::chrono::duration_cast(t_sec); -} - -} // namespace - -util::Time::Duration sim::ComputeRecvTime(const ChannelConfig& config, - std::size_t n) { - if (config.Type() == sim::ChannelConfig::NetworkType::TCP) { - return ComputeRecvTimeTcp(config, n); - } - // sim::ChannelConfig::NetworkType::INSTANT - return util::Time::Duration::zero(); -} diff --git a/src/scl/simulation/simulator.cc b/src/scl/simulation/simulator.cc index f043e2e..ca36ff1 100644 --- a/src/scl/simulation/simulator.cc +++ b/src/scl/simulation/simulator.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -17,183 +17,245 @@ #include "scl/simulation/simulator.h" -#include -#include +#include +#include #include -#include -#include -#include -#include -#include -#include -#include - -#include "scl/net/channel.h" -#include "scl/net/network.h" -#include "scl/protocol/base.h" + +#include + +#include "scl/coro/coroutine.h" +#include "scl/net/loopback.h" +#include "scl/simulation/cancellation.h" #include "scl/simulation/channel.h" -#include "scl/simulation/channel_id.h" -#include "scl/simulation/config.h" #include "scl/simulation/context.h" -#include "scl/simulation/env.h" #include "scl/simulation/event.h" -#include "scl/simulation/mem_channel_buffer.h" -#include "scl/simulation/result.h" +#include "scl/simulation/runtime.h" #include "scl/util/time.h" using namespace scl; namespace { -auto CreateEvent(sim::Event::Type t, util::Time::Duration d) { - return std::make_shared(t, d); -} - -auto CreateSegmentEvent(util::Time::Duration t, - const std::string& n, - bool is_end) { - if (is_end) { - return std::make_shared(sim::Event::Type::SEGMENT_END, - t, - n); - } - return std::make_shared(sim::Event::Type::SEGMENT_BEGIN, - t, - n); +std::shared_ptr createChannel( + std::size_t i, + std::size_t j, + sim::details::GlobalContext& ctx, + std::shared_ptr transport) { + sim::ChannelId cid{i, j}; + return std::make_shared(cid, + ctx.view(i), + transport); } -auto CreateNetworks(std::shared_ptr ctx) { +// creates the networks used in this simulation. +std::vector createNetworks(std::size_t n, + sim::details::GlobalContext& ctx) { + auto transport = std::make_shared(); std::vector networks; - const auto n = ctx->NumberOfParties(); networks.reserve(n); - for (std::size_t i = 0; i < n; ++i) { + for (std::size_t i = 0; i < n; i++) { std::vector> channels; channels.reserve(n); - - for (std::size_t j = 0; j < n; ++j) { - sim::ChannelId cid(i, j); - channels.emplace_back(std::make_shared(cid, ctx)); + for (std::size_t j = 0; j < n; j++) { + channels.emplace_back(createChannel(i, j, ctx, transport)); } - networks.emplace_back(channels, i); } return networks; -} // LCOV_EXCL_LINE +} + +struct ClockImpl final : public proto::Clock { + ClockImpl(const sim::details::GlobalContext::LocalContext& view) + : view(view) {} + + util::Time::Duration read() const override { + return view.elapsedTime(); + } -struct RunResult { - std::unique_ptr next; - std::any output; + sim::details::GlobalContext::LocalContext view; }; -RunResult Run(std::shared_ptr ctx, - std::size_t id, - proto::Protocol* protocol, - proto::Env& env) { - RunResult result; +auto createClock(const sim::details::GlobalContext::LocalContext& view) { + return std::make_unique(view); +} - if (ctx->Trace(id).empty()) { - ctx->AddEvent( - id, - CreateEvent(sim::Event::Type::START, util::Time::Duration::zero())); - } +struct EnvAndCtx { + proto::Env env; + sim::details::GlobalContext::LocalContext view; +}; - if (protocol == nullptr) { - // handling of entries which are null. - ctx->AddEvent( - id, - CreateEvent(sim::Event::Type::STOP, util::Time::Duration::zero())); - return result; +std::vector createEnvs(sim::details::GlobalContext& global_ctx) { + const std::size_t n = global_ctx.number_of_parties; + auto networks = createNetworks(n, global_ctx); + std::vector envs; + envs.reserve(n); + for (std::size_t i = 0; i < n; i++) { + auto view = global_ctx.view(i); + envs.emplace_back( + EnvAndCtx{proto::Env{networks[i], createClock(view)}, view}); } - ctx->AddEvent( - id, - CreateSegmentEvent(ctx->LatestTimestamp(id), protocol->Name(), false)); + return envs; +} - ctx->UpdateCheckpoint(); - result.next = protocol->Run(env); - const auto exec_time = ctx->Checkpoint(id); +coro::Task runProtocol(std::size_t id, + sim::Manager* manager, + std::unique_ptr protocol, + EnvAndCtx&& env) { + // A protocol is run for as long as all of the following is true: + // - it's output result contains another protocol to run; + // - it does not produce an uncaught exception; + // - it has not been cancelled. + // + // Running the protocol generates zero or more events, which is ultimately + // what is the interesting stuff that we're interested in. The events are + // generated (roughly speaking) in the following order: + // 1. START + // 2. Repeat as long as protocol->run().next != nullptr: + // 2.1. PROTOCOL_BEGIN + // 2.2. RECV, SEND, CLOSE, HAS_DATA, SLEEP ...} + // 2.3. OUTPUT, in case the protocol generated output + // 2.4. PROTOCOL_END + // 3. STOP + // + // If any point (with two exceptions, listed below) a hook is run which + // cancels the _current_ party (i.e, this party), then a CANCELLED event is + // produced and the function returns; if an exception is thrown, a KILLED + // event is produced with the exception's message and the function returns. + + auto& view = env.view; + + try { + view.recordEvent(sim::Event::start()); + + while (protocol) { + const auto name = protocol->name(); + + view.recordEvent( + sim::Event::protocolBegin(view.lastEventTimestamp(), name)); + + // start the clock of the party. This ensures that any time spent + // book-keeping does not go towards the total running time of the party. + view.startClock(); + auto next = co_await protocol->run(env.env); + + const auto et = view.elapsedTime(); + + if (next.result.has_value()) { + manager->handleProtocolOutput(id, next.result); + view.recordEvent(sim::Event::output(et)); + } - result.output = protocol->Output(); + view.recordEvent(sim::Event::protocolEnd(et, name)); - if (result.output.has_value()) { - ctx->AddEvent(id, CreateEvent(sim::Event::Type::OUTPUT, exec_time)); - } + protocol = std::move(next.next_protocol); + } + + view.recordEvent(sim::Event::stop(view.lastEventTimestamp())); - ctx->AddEvent(id, CreateSegmentEvent(exec_time, protocol->Name(), true)); + // We could keep running, however by suspending here we can allow a + // different party to run. This is especially important if the protocol we + // are running does not contain any suspension points. + co_await []() { return true; }; - if (result.next == nullptr) { - ctx->AddEvent(id, CreateEvent(sim::Event::Type::STOP, exec_time)); + } catch (sim::details::CancellationException& /* ignored */) { + // the simulation was cancelled by this party, so we just stop here. + view.recordEvent(sim::Event::cancelled(view.lastEventTimestamp())); + } catch (std::exception& e) { + // something went wrong, so we mark the protocol as dead and stop. + view.recordEvent(sim::Event::killed(view.lastEventTimestamp(), e.what())); } - return result; + co_return; } -std::vector CreateEnvs(const std::vector& networks, - std::shared_ptr ctx) { - std::vector envs; - envs.reserve(ctx->NumberOfParties()); - for (std::size_t i = 0; i < ctx->NumberOfParties(); ++i) { - envs.emplace_back(proto::Env{networks[i], - std::make_unique(ctx, i), - std::make_unique(ctx, i)}); +// Helper class that runs the protocols we are simulating. This class behaves +// very similar to sim::Batch, but with a specialized await_suspend. +class SimBatch final { + public: + SimBatch(std::vector>&& tasks, + sim::details::GlobalContext& gctx) + : m_tasks(std::move(tasks)), m_gctx(gctx) {} + + bool await_ready() const noexcept { + // keep running until all non-cancelled coroutines have completed. + for (std::size_t i = 0; i < m_tasks.size(); i++) { + if (!m_gctx.cancellation_map.at(i) && !m_tasks[i].ready()) { + return false; + } + } + + return true; } - return envs; -} -auto RunSimulation(std::size_t replication, sim::Manager* manager) { - auto ps = manager->Protocol(); - auto ctx = sim::Context::Create( - ps.size(), - manager->NetworkConfiguration()); + std::coroutine_handle<> await_suspend(std::coroutine_handle<> coroutine) { + sim::details::SimulatorRuntime* srt = + dynamic_cast(m_runtime); + for (std::size_t i = 0; i < m_tasks.size(); i++) { + m_tasks[i].setRuntime(m_runtime); + srt->scheduleWithId(m_tasks[i].m_handle, i); + } - auto networks = CreateNetworks(ctx); - auto envs = CreateEnvs(networks, ctx); + m_runtime->schedule(coroutine, [this]() { return await_ready(); }); - auto next_id = ctx->NextToRun(); + return m_runtime->next(); + } - while (next_id.has_value()) { - auto id = next_id.value(); + void await_resume() { + for (const auto& t : m_tasks) { + t.result(); + } + } - try { - ctx->Prepare(id); + void setRuntime(coro::Runtime* runtime) noexcept { + m_runtime = runtime; + } - auto result = Run(ctx, id, ps[id].get(), envs[id]); + private: + std::vector> m_tasks; + sim::details::GlobalContext& m_gctx; - ps[id] = std::move(result.next); + coro::Runtime* m_runtime; +}; - if (result.output.has_value()) { - manager->HandleOutput(replication, id, result.output); - } +coro::Task runProtocols( + std::vector>&& protocols, + sim::details::GlobalContext& global_ctx, + sim::Manager* manager) { + std::vector> protocol_runs; - if (ps[id] != nullptr && manager->Terminate(id, ctx->GetView())) { - ps[id] = nullptr; - ctx->AddEvent( - id, - CreateEvent(sim::Event::Type::KILLED, ctx->LatestTimestamp(id))); - } + std::vector envs = createEnvs(global_ctx); + for (std::size_t i = 0; i < protocols.size(); i++) { + protocol_runs.emplace_back( + runProtocol(i, manager, std::move(protocols[i]), std::move(envs[i]))); + } + co_await SimBatch(std::move(protocol_runs), global_ctx); +} - ctx->Commit(id); +} // namespace - } catch (sim::SimulationFailure& e) { - ctx->Rollback(id); - } +void sim::simulate(std::unique_ptr manager) { + auto protocol = manager->protocol(); - next_id = ctx->NextToRun(id); - } + // do nothing in case the caller (for whatever reason) wanted to simulate an + // empty protocol. + if (!protocol.empty()) { + auto ctx = details::GlobalContext::create(protocol.size(), + manager->networkConfiguration(), + std::move(manager->m_hooks)); + auto runtime = std::make_unique(ctx); - return ctx->Trace(); -} + // const auto start = util::Time::now(); -} // namespace + runtime->run(runProtocols(std::move(protocol), ctx, manager.get())); -std::vector sim::Simulate(std::unique_ptr manager) { - std::vector> traces; - auto network_conf = manager->NetworkConfiguration(); - for (std::size_t i = 0; i < manager->Replications(); ++i) { - traces.emplace_back(RunSimulation(i, manager.get())); - } + // const auto sim_dur = util::Time::now() - start; + // std::cout << "simulation took " << util::timeToMillis(sim_dur) << "ms\n"; - return Result::Create(traces); + for (std::size_t party_id = 0; party_id < ctx.traces.size(); party_id++) { + manager->handleSimulatorOutput(party_id, ctx.traces[party_id]); + } + } } diff --git a/src/scl/simulation/transport.cc b/src/scl/simulation/transport.cc new file mode 100644 index 0000000..c4146a9 --- /dev/null +++ b/src/scl/simulation/transport.cc @@ -0,0 +1,74 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#include "scl/simulation/transport.h" + +using namespace scl; + +void sim::details::Transport::send(sim::ChannelId cid, net::Packet&& packet) { + m_channels[cid.flip()].push_back(std::move(packet)); +} + +void sim::details::Transport::send(sim::ChannelId cid, + const net::Packet& packet) { + std::size_t idx; + for (idx = 0; idx < m_packets.size(); idx++) { + if (m_packets[idx].packet == packet) { + m_packets[idx].count++; + m_channels[cid.flip()].push_back(idx); + return; + } + } + + // no packet found + m_packets.emplace_back(PktAndCount{packet, 1}); + m_channels[cid.flip()].push_back(idx); +} + +bool sim::details::Transport::hasData(ChannelId cid) const { + if (m_channels.contains(cid)) { + return !m_channels.at(cid).empty(); + } + return false; +} + +net::Packet sim::details::Transport::recv(ChannelId cid) { + // define the variable before assignment to silence a bogus + // maybe-uninitialized error by GCC. + PktOrIdx pkt_or_idx; + + pkt_or_idx = std::move(m_channels.at(cid).front()); + m_channels[cid].pop_front(); + + if (pkt_or_idx.index() == 0) { + // packet that was directly moved to us. + return std::get(pkt_or_idx); + } + + const std::size_t idx = std::get(pkt_or_idx); + + if (m_packets[idx].count == 0) { + throw std::runtime_error("uh oh"); + } + + m_packets[idx].count--; + return m_packets[idx].packet; +} + +void sim::details::Transport::cleanUp(GlobalContext& ctx) { + (void)ctx; +} diff --git a/src/scl/util/cmdline.cc b/src/scl/util/cmdline.cc index 123eced..e35f26c 100644 --- a/src/scl/util/cmdline.cc +++ b/src/scl/util/cmdline.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -19,23 +19,23 @@ #include -namespace {} // namespace +using namespace scl; -bool scl::util::ProgramOptions::Parser::IsArg(std::string_view name) const { - return std::any_of(mArgs.begin(), mArgs.end(), [&](auto a) { +bool util::ProgramOptions::Parser::isArg(std::string_view name) const { + return std::any_of(m_args.begin(), m_args.end(), [&](auto a) { return a.name == name; }); } -bool scl::util::ProgramOptions::Parser::IsFlag(std::string_view name) const { - return std::any_of(mFlags.begin(), mFlags.end(), [&](auto f) { +bool util::ProgramOptions::Parser::isFlag(std::string_view name) const { + return std::any_of(m_flags.begin(), m_flags.end(), [&](auto f) { return f.name == name; }); } namespace { -bool Name(std::string_view opt_name, std::string_view& name) { +bool name(std::string_view opt_name, std::string_view& name) { if (opt_name[0] != '-') { return false; } @@ -43,23 +43,45 @@ bool Name(std::string_view opt_name, std::string_view& name) { return true; } +template +bool hasDuplicates(const std::vector& opts) { + for (std::size_t i = 0; i < opts.size(); i++) { + const std::string_view n = opts[i].name; + for (std::size_t j = i + 1; j < opts.size(); j++) { + if (n == opts[j].name) { + return true; + } + } + } + + return false; +} + } // namespace -scl::util::ProgramOptions scl::util::ProgramOptions::Parser::Parse( - int argc, - char** argv) { - mProgramName = argv[0]; +using ParseRet = std::variant; + +ParseRet util::ProgramOptions::Parser::parseArguments(int argc, char** argv) { + if (hasDuplicates(m_args)) { + return "duplicate argument definition"; + } + + if (hasDuplicates(m_flags)) { + return "duplicate flag definition"; + } + + m_program_name = argv[0]; std::vector cmd_args(argv + 1, argv + argc); const auto help_needed = std::any_of(cmd_args.begin(), cmd_args.end(), [](auto e) { return e == "-help"; }); if (help_needed) { - PrintHelp(); + return ""; } std::unordered_map args; - std::for_each(mArgs.begin(), mArgs.end(), [&args](const auto arg) { + std::for_each(m_args.begin(), m_args.end(), [&args](const auto arg) { if (arg.default_value.has_value()) { args[arg.name] = arg.default_value.value(); } @@ -68,56 +90,60 @@ scl::util::ProgramOptions scl::util::ProgramOptions::Parser::Parse( std::unordered_map flags; std::size_t i = 0; while (i < cmd_args.size()) { - std::string_view name; - if (!Name(cmd_args[i++], name)) { - PrintHelp("argument must begin with '-'"); + std::string_view arg_name; + if (!name(cmd_args[i++], arg_name)) { + return "argument must begin with '-'"; } - if (IsArg(name)) { + if (isArg(arg_name)) { if (i == cmd_args.size()) { - PrintHelp("invalid argument"); + return "invalid argument"; } - args[name] = cmd_args[i++]; - } else if (IsFlag(name)) { - flags[name] = true; + args[arg_name] = cmd_args[i++]; + } else if (isFlag(arg_name)) { + flags[arg_name] = true; } else { - PrintHelp("encountered unknown argument"); + return "encountered unknown argument"; } } // check if we got everything - ForEachRequired(mArgs, [&](const auto arg) { + std::string_view error_msg; + forEachRequired(m_args, [&](const auto arg) { if (args.find(arg.name) == args.end()) { - PrintHelp("missing required argument"); + error_msg = "missing required argument"; } }); - return ProgramOptions(args, flags); + if (error_msg.empty()) { + return ProgramOptions(args, flags); + } + return error_msg; } -void scl::util::ProgramOptions::Parser::ArgListShort( +void util::ProgramOptions::Parser::argListShort( std::ostream& stream, std::string_view program_name) const { stream << "Usage: " << program_name << " "; - ForEachRequired(mArgs, [&stream](const auto arg) { + forEachRequired(m_args, [&stream](const auto arg) { stream << "-" << arg.name << " " << arg.type_hint << " "; }); stream << "[options ...]" << std::endl; } -std::string GetPadding(std::size_t lead) { +std::string getPadding(std::size_t lead) { const static std::size_t padding = 20; const static std::size_t min_padding = 5; const auto psz = lead >= padding + min_padding ? min_padding : padding - lead; return std::string(psz, ' '); } -void WriteArg(std::ostream& stream, const scl::util::ProgramArg& arg) { +void writeArg(std::ostream& stream, const util::ProgramArg& arg) { stream << " -" << arg.name << " '" << arg.type_hint << "'"; if (!arg.description.empty()) { - const auto pad_str = GetPadding(arg.name.size() + arg.type_hint.size() + 5); - stream << pad_str << arg.description << ". "; + const auto pad_str = getPadding(arg.name.size() + arg.type_hint.size() + 5); + stream << pad_str << arg.description << "."; } if (arg.default_value.has_value()) { stream << " [default=" << arg.default_value.value() << "]"; @@ -125,75 +151,64 @@ void WriteArg(std::ostream& stream, const scl::util::ProgramArg& arg) { stream << std::endl; } -void WriteFlag(std::ostream& stream, const scl::util::ProgramFlag& flag) { +void writeFlag(std::ostream& stream, const util::ProgramFlag& flag) { stream << " -" << flag.name; if (!flag.description.empty()) { - const auto pad_str = GetPadding(flag.name.size() + 2); - stream << pad_str << flag.description << ". "; + const auto pad_str = getPadding(flag.name.size() + 2); + stream << pad_str << flag.description << "."; } stream << std::endl; } -template -bool HasRequired(It begin, It end) { - return std::any_of(begin, end, [](const auto a) { return a.required; }); +template +bool hasRequired(IT begin, IT end) { + return std::any_of(begin, end, [](const auto a) { return a.is_required; }); } -template -bool HasOptional(It begin, It end) { - return std::any_of(begin, end, [](const auto a) { return !a.required; }); +template +bool hasOptional(IT begin, IT end) { + return std::any_of(begin, end, [](const auto a) { return !a.is_required; }); } -void scl::util::ProgramOptions::Parser::ArgListLong( - std::ostream& stream) const { - if (!mDescription.empty()) { - stream << std::endl << mDescription << std::endl; +void util::ProgramOptions::Parser::argListLong(std::ostream& stream) const { + if (!m_description.empty()) { + stream << std::endl << m_description << std::endl; } stream << std::endl; - const auto has_req_arg = HasRequired(mArgs.begin(), mArgs.end()); + const auto has_req_arg = hasRequired(m_args.begin(), m_args.end()); if (has_req_arg) { stream << "Required arguments" << std::endl; - ForEachRequired(mArgs, [&stream](const auto a) { WriteArg(stream, a); }); + forEachRequired(m_args, [&stream](const auto a) { writeArg(stream, a); }); stream << std::endl; } - if (HasOptional(mArgs.begin(), mArgs.end())) { - stream << "Optional Arguments" << std::endl; + if (hasOptional(m_args.begin(), m_args.end())) { + stream << "Optional arguments" << std::endl; - ForEachOptional(mArgs, [&stream](const auto a) { WriteArg(stream, a); }); + forEachOptional(m_args, [&stream](const auto a) { writeArg(stream, a); }); stream << std::endl; } - if (!mFlags.empty()) { + if (!m_flags.empty()) { stream << "Flags" << std::endl; - std::for_each(mFlags.begin(), mFlags.end(), [&stream](const auto a) { - WriteFlag(stream, a); + std::for_each(m_flags.begin(), m_flags.end(), [&stream](const auto a) { + writeFlag(stream, a); }); stream << std::endl; } } -void scl::util::ProgramOptions::Parser::PrintHelp(std::string_view error_msg) { +void util::ProgramOptions::Parser::printHelp(std::string_view error_msg) { bool error = !error_msg.empty(); if (error) { std::cerr << "ERROR: " << error_msg << std::endl; } - if (!mProgramName.empty()) { - ArgListShort(std::cout, mProgramName); + if (!m_program_name.empty()) { + argListShort(std::cout, m_program_name); } - ArgListLong(std::cout); - -#ifdef SCL_UTIL_NO_EXIT_ON_ERROR - - throw std::runtime_error(error ? "bad" : "good"); - -#else - - std::exit(error ? 1 : 0); - -#endif + argListLong(std::cout); } diff --git a/src/scl/util/measurement.cc b/src/scl/util/measurement.cc new file mode 100644 index 0000000..e423ec6 --- /dev/null +++ b/src/scl/util/measurement.cc @@ -0,0 +1,80 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#include "scl/util/measurement.h" + +#include + +#include "scl/util/time.h" + +using namespace scl; + +template <> +long double util::Measurement::zero() const { + return 0; +} + +template <> +util::Time::Duration util::Measurement::zero() const { + return util::Time::Duration::zero(); +} + +template <> +long double util::Measurement::square(long double v) const { + return v * v; +} + +template <> +util::Time::Duration util::Measurement::square( + util::Time::Duration dur) const { + long double u = dur.count(); + std::chrono::duration w(u * u); + return std::chrono::duration_cast(w); +} + +template <> +long double util::Measurement::sqrt(long double v) const { + return std::sqrt(v); +} + +template <> +util::Time::Duration util::Measurement::sqrt( + util::Time::Duration dur) const { + long double u = std::sqrt(dur.count()); + std::chrono::duration w(u); + return std::chrono::duration_cast(w); +} + +std::ostream& util::operator<<(std::ostream& os, + const util::TimeMeasurement& measurement) { + os << "{" + << "\"mean\": " << util::timeToMillis(measurement.mean()) << ", " + << "\"unit\": \"ms\", " + << "\"std_dev\": " << util::timeToMillis(measurement.stddev()) << "}"; + + return os; +} + +std::ostream& util::operator<<(std::ostream& os, + const util::DataMeasurement& measurement) { + os << "{" + << "\"mean\": " << measurement.mean() << ", " + << "\"unit\": \"B\", " + << "\"std_dev\": " << measurement.stddev() << "}"; + + return os; +} diff --git a/src/scl/util/prg.cc b/src/scl/util/prg.cc index 9ce3778..9333a09 100644 --- a/src/scl/util/prg.cc +++ b/src/scl/util/prg.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -17,11 +17,12 @@ #include "scl/util/prg.h" -#include -#include +#include #include #include +#include +#include /** * PRG implementation based on AES-CTR with code from @@ -46,11 +47,11 @@ } while (0) #define AES_128_KEY_EXP(k, rcon) \ - Aes128KeyExpansion(k, _mm_aeskeygenassist_si128(k, rcon)) + aes128KeyExpansion(k, _mm_aeskeygenassist_si128(k, rcon)) namespace { -auto Aes128KeyExpansion(__m128i key, __m128i keygened) { +auto aes128KeyExpansion(__m128i key, __m128i keygened) { keygened = _mm_shuffle_epi32(keygened, _MM_SHUFFLE(3, 3, 3, 3)); key = _mm_xor_si128(key, _mm_slli_si128(key, 4)); key = _mm_xor_si128(key, _mm_slli_si128(key, 4)); @@ -58,7 +59,7 @@ auto Aes128KeyExpansion(__m128i key, __m128i keygened) { return _mm_xor_si128(key, keygened); } -void Aes128LoadKey(const unsigned char* enc_key, __m128i* key_schedule) { +void aes128LoadKey(const unsigned char* enc_key, __m128i* key_schedule) { const auto* k = reinterpret_cast(enc_key); key_schedule[0] = _mm_loadu_si128(k); key_schedule[1] = AES_128_KEY_EXP(key_schedule[0], 0x01); @@ -73,54 +74,54 @@ void Aes128LoadKey(const unsigned char* enc_key, __m128i* key_schedule) { key_schedule[10] = AES_128_KEY_EXP(key_schedule[9], 0x36); } -void Aes128Enc(const __m128i* key_schedule, __m128i m, unsigned char* ct) { +void aes128Enc(const __m128i* key_schedule, __m128i m, unsigned char* ct) { DO_ENC_BLOCK(m, key_schedule); _mm_storeu_si128(reinterpret_cast<__m128i*>(ct), m); } -auto create_mask(long counter) { +auto createMask(long counter) { return _mm_set_epi64x(PRG_NONCE, counter); } } // namespace -scl::util::PRG scl::util::PRG::Create(const unsigned char* seed, +scl::util::PRG scl::util::PRG::create(const unsigned char* seed, std::size_t seed_len) { - std::array s = {0}; + std::array s = {0}; if (seed != nullptr) { - if (seed_len > PRG::SeedSize()) { - std::copy(seed, seed + SeedSize(), s.begin()); + if (seed_len > PRG::seedSize()) { + std::copy(seed, seed + seedSize(), s.begin()); } else { std::copy(seed, seed + seed_len, s.begin()); } } PRG prg(s); - prg.Init(); + prg.init(); return prg; } -scl::util::PRG scl::util::PRG::Create() { - return PRG::Create(nullptr, 0); +scl::util::PRG scl::util::PRG::create() { + return PRG::create(nullptr, 0); } -scl::util::PRG scl::util::PRG::Create(const std::string& seed) { - return PRG::Create((const unsigned char*)seed.c_str(), seed.length()); +scl::util::PRG scl::util::PRG::create(const std::string& seed) { + return PRG::create((const unsigned char*)seed.c_str(), seed.length()); } -void scl::util::PRG::Update() { +void scl::util::PRG::update() { m_counter += 1; } -void scl::util::PRG::Init() { - Aes128LoadKey(m_seed.data(), m_state); +void scl::util::PRG::init() { + aes128LoadKey(m_seed.data(), m_state); } -void scl::util::PRG::Reset() { - Init(); +void scl::util::PRG::reset() { + init(); m_counter = PRG_INITIAL_COUNTER; } -void scl::util::PRG::Next(unsigned char* buffer, size_t n) { +void scl::util::PRG::next(unsigned char* buffer, size_t n) { if (n == 0) { return; } @@ -131,13 +132,13 @@ void scl::util::PRG::Next(unsigned char* buffer, size_t n) { nblocks++; } - auto mask = create_mask(m_counter); + auto mask = createMask(m_counter); auto out = std::make_unique(nblocks * BLOCK_SIZE); auto* p = out.get(); for (size_t i = 0; i < nblocks; i++) { - Aes128Enc(m_state, mask, p); - Update(); - mask = create_mask(m_counter); + aes128Enc(m_state, mask, p); + update(); + mask = createMask(m_counter); p += BLOCK_SIZE; } diff --git a/src/scl/util/sha256.cc b/src/scl/util/sha256.cc index 2755306..3aef6fe 100644 --- a/src/scl/util/sha256.cc +++ b/src/scl/util/sha256.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -17,8 +17,8 @@ #include "scl/util/sha256.h" +#include #include -#include /** * SHA-256 implementation based on https://github.com/System-Glitch/SHA256. @@ -26,19 +26,19 @@ namespace { -auto RotR(uint32_t x, unsigned n) { +auto rotR(uint32_t x, unsigned n) { return (x >> n) | (x << (32 - n)); } -auto Sig0(uint32_t x) { - return RotR(x, 7) ^ RotR(x, 18) ^ (x >> 3); +auto sig0(uint32_t x) { + return rotR(x, 7) ^ rotR(x, 18) ^ (x >> 3); } -auto Sig1(uint32_t x) { - return RotR(x, 17) ^ RotR(x, 19) ^ (x >> 10); +auto sig1(uint32_t x) { + return rotR(x, 17) ^ rotR(x, 19) ^ (x >> 10); } -auto Split(std::array& chunk) { +auto split(std::array& chunk) { std::array split; for (std::size_t i = 0, j = 0; i < 16; ++i, j += 4) { split[i] = (chunk[j] << 24) // @@ -48,24 +48,24 @@ auto Split(std::array& chunk) { } for (std::size_t i = 16; i < 64; ++i) { - split[i] = Sig1(split[i - 2]) + Sig0(split[i - 15]); + split[i] = sig1(split[i - 2]) + sig0(split[i - 15]); split[i] += split[i - 7] + split[i - 16]; } return split; } -auto Majority(uint32_t x, uint32_t y, uint32_t z) { +auto majority(uint32_t x, uint32_t y, uint32_t z) { return (x & (y | z)) | (y & z); } -auto Choose(uint32_t x, uint32_t y, uint32_t z) { +auto choose(uint32_t x, uint32_t y, uint32_t z) { return (x & y) ^ (~x & z); } } // namespace -void scl::util::Sha256::Transform() { +void scl::util::Sha256::transform() { // round constants. static constexpr std::array k = { 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, @@ -80,15 +80,15 @@ void scl::util::Sha256::Transform() { 0x5b9cca4f, 0x682e6ff3, 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2}; - const auto m = Split(m_chunk); + const auto m = split(m_chunk); auto s = m_state; for (std::size_t i = 0; i < 64; ++i) { - const auto maj = Majority(s[0], s[1], s[2]); - const auto chs = Choose(s[4], s[5], s[6]); + const auto maj = majority(s[0], s[1], s[2]); + const auto chs = choose(s[4], s[5], s[6]); - const auto xor_a = RotR(s[0], 2) ^ RotR(s[0], 13) ^ RotR(s[0], 22); - const auto xor_e = RotR(s[4], 6) ^ RotR(s[4], 11) ^ RotR(s[4], 25); + const auto xor_a = rotR(s[0], 2) ^ rotR(s[0], 13) ^ rotR(s[0], 22); + const auto xor_e = rotR(s[4], 6) ^ rotR(s[4], 11) ^ rotR(s[4], 25); const auto sum = m[i] + k[i] + s[7] + chs + xor_e; @@ -110,7 +110,7 @@ void scl::util::Sha256::Transform() { } } -void scl::util::Sha256::Pad() { +void scl::util::Sha256::pad() { auto i = m_chunk_pos; const auto end = m_chunk_pos < 56U ? 56U : 64U; @@ -120,7 +120,7 @@ void scl::util::Sha256::Pad() { } if (m_chunk_pos >= 56) { - Transform(); + transform(); std::fill(m_chunk.begin(), m_chunk.begin() + 56, 0); } @@ -135,10 +135,10 @@ void scl::util::Sha256::Pad() { m_chunk[57] = m_total_len >> 48; m_chunk[56] = m_total_len >> 56; - Transform(); + transform(); } -scl::util::Sha256::DigestType scl::util::Sha256::WriteDigest() { +scl::util::Sha256::DigestType scl::util::Sha256::writeDigest() { Sha256::DigestType digest; for (std::size_t i = 0; i < 4; ++i) { @@ -150,18 +150,18 @@ scl::util::Sha256::DigestType scl::util::Sha256::WriteDigest() { return digest; } -void scl::util::Sha256::Hash(const unsigned char* bytes, std::size_t nbytes) { +void scl::util::Sha256::hash(const unsigned char* bytes, std::size_t nbytes) { for (std::size_t i = 0; i < nbytes; ++i) { m_chunk[m_chunk_pos++] = bytes[i]; if (m_chunk_pos == 64) { - Transform(); + transform(); m_total_len += 512; m_chunk_pos = 0; } } } -scl::util::Sha256::DigestType scl::util::Sha256::Write() { - Pad(); - return WriteDigest(); +scl::util::Sha256::DigestType scl::util::Sha256::write() { + pad(); + return writeDigest(); } diff --git a/src/scl/util/sha3.cc b/src/scl/util/sha3.cc index 53c8fc2..7121b10 100644 --- a/src/scl/util/sha3.cc +++ b/src/scl/util/sha3.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -43,7 +43,7 @@ uint64_t RotLeft64(uint64_t x, uint64_t y) { } // namespace -void scl::util::Keccakf(uint64_t state[25]) { +void scl::util::keccakf(uint64_t state[25]) { uint64_t t; uint64_t bc[5]; diff --git a/src/scl/util/str.cc b/src/scl/util/str.cc index 7b98311..7872efc 100644 --- a/src/scl/util/str.cc +++ b/src/scl/util/str.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -20,7 +20,7 @@ #include template <> -std::string scl::util::ToHexString(const __uint128_t& v) { +std::string scl::util::toHexString(const __uint128_t& v) { std::string str; if (v == 0) { str = "0"; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt new file mode 100644 index 0000000..13f7c2f --- /dev/null +++ b/test/CMakeLists.txt @@ -0,0 +1,119 @@ +# SCL --- Secure Computation Library +# Copyright (C) 2024 Anders Dalskov +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +cmake_minimum_required(VERSION 3.5) + +set(SCL_SOURCE_FILES_TEST + scl/util/test_prg.cc + scl/util/test_sha3.cc + scl/util/test_sha256.cc + scl/util/test_ecdsa.cc + scl/util/test_cmdline.cc + scl/util/test_merkle.cc + scl/util/test_bitmap.cc + scl/util/test_measurement.cc + + scl/serialization/test_serializer.cc + + scl/gf7.cc + scl/math/test_mersenne61.cc + scl/math/test_mersenne127.cc + scl/math/test_vector.cc + scl/math/test_matrix.cc + scl/math/test_la.cc + scl/math/test_ff.cc + scl/math/test_z2k.cc + scl/math/test_poly.cc + scl/math/test_array.cc + + scl/math/test_secp256k1.cc + scl/math/test_number.cc + + scl/ss/test_additive.cc + scl/ss/test_shamir.cc + scl/ss/test_feldman.cc + scl/ss/test_pedersen.cc + + scl/coro/test_task.cc + scl/coro/test_batch.cc + + scl/net/util.cc + scl/net/test_config.cc + scl/net/test_loopback.cc + scl/net/test_network.cc + scl/net/test_packet.cc + + scl/protocol/test_protocol.cc + + scl/simulation/test_event.cc + scl/simulation/test_context.cc + scl/simulation/test_config.cc + scl/simulation/test_channel.cc + scl/simulation/test_simulator.cc +) + +Include(FetchContent) + +# get Catch2 and compile it. We cannot use the system provided +# version, since that is compiled with C++14 which causes linker +# errors... +FetchContent_Declare( + Catch2 + GIT_REPOSITORY https://github.com/catchorg/Catch2.git + GIT_TAG v3.4.0 +) +FetchContent_MakeAvailable(Catch2) + +include(CTest) +include(Catch) + +add_compile_definitions(SCL_TEST_DATA_DIR="${CMAKE_SOURCE_DIR}/test/data/") + +add_executable(scl_test ${SCL_SOURCE_FILES_TEST}) +target_link_libraries(scl_test + PRIVATE Catch2::Catch2WithMain + PRIVATE scl + PRIVATE pthread + PRIVATE gmp) +catch_discover_tests(scl_test) + +if(SCL_BUILD_TEST_WITH_COVERAGE) + + if (NOT SCL_GCOV_BIN) + set(SCL_GCOV_BIN "gcov") + endif() + + ## stuff that SCL uses, but which we are not interested in + ## generating coverage of. + set(SCL_COVERAGE_EXCLUDES "") + list(APPEND SCL_COVERAGE_EXCLUDES "'/usr/*'") + list(APPEND SCL_COVERAGE_EXCLUDES "'${CMAKE_BINARY_DIR}/*'") + list(APPEND SCL_COVERAGE_EXCLUDES "'${CMAKE_SOURCE_DIR}/test/*'") + + add_custom_target(coverage + COMMAND lcov --ignore-errors mismatch --gcov-tool ${SCL_GCOV_BIN} -d ${CMAKE_BINARY_DIR} -b ${CMAKE_BINARY_DIR} -z + COMMAND lcov --ignore-errors mismatch --gcov-tool ${SCL_GCOV_BIN} -d ${CMAKE_BINARY_DIR} -b ${CMAKE_BINARY_DIR} -c -i -o cov.base + COMMAND scl_test + COMMAND lcov --ignore-errors mismatch --gcov-tool ${SCL_GCOV_BIN} -d ${CMAKE_BINARY_DIR} -b ${CMAKE_BINARY_DIR} -c -o cov.cap + COMMAND lcov --ignore-errors mismatch --gcov-tool ${SCL_GCOV_BIN} -a cov.base -a cov.cap -o cov.total + COMMAND wc -l cov.total + COMMAND lcov --ignore-errors mismatch --gcov-tool ${SCL_GCOV_BIN} --remove cov.total ${SCL_COVERAGE_EXCLUDES} -o cov.info + COMMAND genhtml --demangle-cpp --ignore-errors mismatch -o coverage cov.info + + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + BYPRODUCTS cov.base cov.cap cov.total cov.info) + +endif() diff --git a/test/scl/coro/test_batch.cc b/test/scl/coro/test_batch.cc new file mode 100644 index 0000000..019591b --- /dev/null +++ b/test/scl/coro/test_batch.cc @@ -0,0 +1,83 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#include +#include +#include + +#include "scl/coro/batch.h" +#include "scl/coro/runtime.h" +#include "scl/coro/task.h" + +using namespace scl; + +namespace { + +coro::Task task() { + co_return 42; +} + +coro::Task batch() { + std::vector> tasks; + tasks.emplace_back(task()); + tasks.emplace_back(task()); + tasks.emplace_back(task()); + + auto rs = co_await coro::batch(std::move(tasks)); + REQUIRE(rs.size() == 3); + REQUIRE(rs[0] == 42); + REQUIRE(rs[1] == 42); + REQUIRE(rs[2] == 42); +} + +} // namespace + +TEST_CASE("Simple batch", "[coro]") { + auto rt = coro::DefaultRuntime::create(); + rt->run(batch()); +} + +namespace { + +coro::Task sleeps() { + using namespace std::chrono_literals; + co_await 100h; + co_return 42; +} + +coro::Task partialBatch() { + std::vector> tasks; + tasks.emplace_back(task()); + tasks.emplace_back(sleeps()); + tasks.emplace_back(task()); + + auto rs = co_await coro::batch(std::move(tasks), 2); + REQUIRE(rs.size() == 3); + REQUIRE(rs[0].has_value()); + REQUIRE_FALSE(rs[1].has_value()); + REQUIRE(rs[2].has_value()); + + REQUIRE(rs[0].value() == 42); + REQUIRE(rs[2].value() == 42); +} + +} // namespace + +TEST_CASE("Partial batch execution", "[coro]") { + auto rt = coro::DefaultRuntime::create(); + rt->run(partialBatch()); +} diff --git a/test/scl/coro/test_task.cc b/test/scl/coro/test_task.cc new file mode 100644 index 0000000..cb49724 --- /dev/null +++ b/test/scl/coro/test_task.cc @@ -0,0 +1,120 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#include +#include + +#include "scl/coro/runtime.h" +#include "scl/coro/task.h" + +using namespace scl; + +namespace { + +coro::Task voidTask(bool& r) { + r = true; + co_return; +} + +} // namespace + +TEST_CASE("void task", "[coro]") { + bool set = false; + auto rt = coro::DefaultRuntime::create(); + + // Tasks are cold start, so not run before being executed by a runtime. + auto void_task = voidTask(set); + REQUIRE(!set); + + rt->run(std::move(void_task)); + + REQUIRE(set); +} + +namespace { + +coro::Task intTask() { + co_return 42; +} + +} // namespace + +TEST_CASE("int task", "[coro]") { + auto rt = coro::DefaultRuntime::create(); + auto r = rt->run(intTask()); + REQUIRE(r == 42); +} + +namespace { + +coro::Task anotherIntTask() { + co_return co_await intTask() + 1; +} + +coro::Task adder() { + auto v0 = co_await intTask(); + auto v1 = co_await anotherIntTask(); + co_return v0 + v1; +} + +} // namespace + +TEST_CASE("adder task", "[coro]") { + auto rt = coro::DefaultRuntime::create(); + // runs until the coroutine returns, even if it awaits. + auto r = rt->run(adder()); + REQUIRE(r == 42 + 43); +} + +namespace { + +coro::Task throws() { + throw std::runtime_error("oops"); +} + +coro::Task voidThrows() { + co_await throws(); +} + +coro::Task nonVoidThrows() { + co_await throws(); + co_return 42; +} + +} // namespace + +TEST_CASE("task throws void", "[coro]") { + auto rt = coro::DefaultRuntime::create(); + REQUIRE_THROWS_MATCHES(rt->run(voidThrows()), + std::runtime_error, + Catch::Matchers::Message("oops")); +} + +TEST_CASE("task throws non-void", "[coro]") { + auto rt = coro::DefaultRuntime::create(); + REQUIRE_THROWS_MATCHES(rt->run(nonVoidThrows()), + std::runtime_error, + Catch::Matchers::Message("oops")); +} + +TEST_CASE("result on unfinished Task", "[coro]") { + auto t1 = nonVoidThrows(); + REQUIRE_THROWS_MATCHES( + t1.result(), + std::logic_error, + Catch::Matchers::Message("result() called on unfinished coroutine")); +} diff --git a/test/scl/gf7.cc b/test/scl/gf7.cc index 65d5083..4b8ead6 100644 --- a/test/scl/gf7.cc +++ b/test/scl/gf7.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -17,24 +17,25 @@ #include "./gf7.h" -#include "scl/math/ff_ops.h" +#include "scl/math/fields/ff_ops.h" -using GF7 = scl::test::GaloisField7; +using namespace scl; + +using GF7 = test::GaloisField7; template <> -void scl::math::FieldConvertIn(unsigned char& out, int v) { +void math::ff::convertTo(unsigned char& out, int v) { auto r = v % 7; out = r < 0 ? 7 + r : r; } template <> -void scl::math::FieldAdd(unsigned char& out, const unsigned char& op) { +void math::ff::add(unsigned char& out, const unsigned char& op) { out = (out + op) % 7; } template <> -void scl::math::FieldSubtract(unsigned char& out, - const unsigned char& op) { +void math::ff::subtract(unsigned char& out, const unsigned char& op) { if (out < op) { out = 7 + out - op; } else { @@ -43,18 +44,17 @@ void scl::math::FieldSubtract(unsigned char& out, } template <> -void scl::math::FieldMultiply(unsigned char& out, - const unsigned char& op) { +void math::ff::multiply(unsigned char& out, const unsigned char& op) { out = (out * op) % 7; } template <> -void scl::math::FieldNegate(unsigned char& out) { +void math::ff::negate(unsigned char& out) { out = (7 - out) % 7; } template <> -void scl::math::FieldInvert(unsigned char& out) { +void math::ff::invert(unsigned char& out) { unsigned char inv; switch (out) { case 1: @@ -80,26 +80,23 @@ void scl::math::FieldInvert(unsigned char& out) { } template <> -bool scl::math::FieldEqual(const unsigned char& in1, - const unsigned char& in2) { +bool math::ff::equal(const unsigned char& in1, const unsigned char& in2) { return in1 == in2; } template <> -void scl::math::FieldFromBytes(unsigned char& dest, - const unsigned char* src) { +void math::ff::fromBytes(unsigned char& dest, const unsigned char* src) { dest = *src; dest = dest % 7; } template <> -void scl::math::FieldToBytes(unsigned char* dest, - const unsigned char& src) { +void math::ff::toBytes(unsigned char* dest, const unsigned char& src) { *dest = src; } template <> -std::string scl::math::FieldToString(const unsigned char& in) { +std::string math::ff::toString(const unsigned char& in) { std::stringstream ss; ss << (int)in; return ss.str(); diff --git a/test/scl/gf7.h b/test/scl/gf7.h index 3c3ac1d..6bf83fc 100644 --- a/test/scl/gf7.h +++ b/test/scl/gf7.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by diff --git a/test/scl/math/fields.h b/test/scl/math/fields.h index 3d4f9ac..5f91f54 100644 --- a/test/scl/math/fields.h +++ b/test/scl/math/fields.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -16,7 +16,8 @@ */ #include "../gf7.h" -#include "scl/math/curves/secp256k1.h" +#include "scl/math/fields/secp256k1_field.h" +#include "scl/math/fields/secp256k1_scalar.h" #include "scl/math/fp.h" namespace scl::test { @@ -25,17 +26,11 @@ using Mersenne61 = math::Fp<61>; using Mersenne127 = math::Fp<127>; using GF7 = math::FF; -#ifdef SCL_ENABLE_EC_TESTS -using Secp256k1_Field = math::FF; -using Secp256k1_Order = math::FF; -#endif +using Secp256k1_Field = math::FF; +using Secp256k1_Order = math::FF; } // namespace scl::test -#ifdef SCL_ENABLE_EC_TESTS #define FIELD_DEFS \ scl::test::Mersenne61, scl::test::Mersenne127, scl::test::GF7, \ scl::test::Secp256k1_Field, scl::test::Secp256k1_Order -#else -#define FIELD_DEFS scl::test::Mersenne61, scl::test::Mersenne127, scl::test::GF7 -#endif diff --git a/test/scl/math/test_array.cc b/test/scl/math/test_array.cc new file mode 100644 index 0000000..d49c447 --- /dev/null +++ b/test/scl/math/test_array.cc @@ -0,0 +1,82 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#include + +#include "scl/math/array.h" +#include "scl/math/curves/secp256k1.h" +#include "scl/math/ec.h" +#include "scl/math/ff.h" +#include "scl/math/fields/mersenne127.h" +#include "scl/serialization/serializer.h" + +using namespace scl; +using G = math::EC; +using F = G::ScalarField; + +TEST_CASE("Array default init", "[math]") { + const G inf; + math::Array p; + + REQUIRE(p == math::Array{{inf, inf, inf, inf}}); + + const auto zero = F::zero(); + math::Array q; + REQUIRE(q == math::Array{{zero, zero, zero}}); +} + +TEST_CASE("Array operations", "[math]") { + math::Array p = {{F(1), F(2), F(4)}}; + math::Array q = {{F(4), F(2), F(1)}}; + + REQUIRE(p + q == math::Array{{F(5), F(4), F(5)}}); + REQUIRE(p - q == math::Array{{F(-3), F(0), F(3)}}); + REQUIRE(p * q == math::Array{{F(4), F(4), F(4)}}); + REQUIRE(q * p == math::Array{{F(4), F(4), F(4)}}); +} + +TEST_CASE("Array operations mixed", "[math]") { + const auto gen = G::generator(); + math::Array g = {{gen, gen, gen}}; + math::Array f = {{F(44), F(55), F(66)}}; + + REQUIRE(g * f == math::Array{{gen * F(44), gen * F(55), gen * F(66)}}); + REQUIRE(f * g == math::Array{{gen * F(44), gen * F(55), gen * F(66)}}); +} + +TEST_CASE("Array to string", "[math]") { + math::Array p; + REQUIRE(p.toString() == "P{EC{POINT_AT_INFINITY}, EC{POINT_AT_INFINITY}}"); +} + +TEST_CASE("Array serialization", "[math]") { + auto prg = util::PRG::create("prod seri"); + auto prod = math::Array::random(prg); + + using S = seri::Serializer>; + + unsigned char buf[S::sizeOf(prod)]; + S::write(prod, buf); + + math::Array p; + + REQUIRE(p != prod); + + S::read(p, buf); + + REQUIRE(p == prod); +} diff --git a/test/scl/math/test_ff.cc b/test/scl/math/test_ff.cc index 2c9e24a..6db4a03 100644 --- a/test/scl/math/test_ff.cc +++ b/test/scl/math/test_ff.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,7 +15,9 @@ * along with this program. If not, see . */ -#include +#include +#include +#include #include #include "fields.h" @@ -26,15 +28,15 @@ using namespace scl; namespace { template -T RandomNonZero(util::PRG& prg) { - auto a = T::Random(prg); +T randomNonZero(util::PRG& prg) { + auto a = T::random(prg); for (std::size_t i = 0; i < 10; ++i) { - if (a == T::Zero()) { - a = T::Random(prg); + if (a == T::zero()) { + a = T::random(prg); } break; } - if (a == T::Zero()) { + if (a == T::zero()) { throw std::logic_error("could not generate a non-zero random value"); } return a; @@ -43,8 +45,8 @@ T RandomNonZero(util::PRG& prg) { // Specialization for the very small field since it's apparently possible to hit // zero 10 times in a row... template <> -test::GF7 RandomNonZero(util::PRG& prg) { - auto a = test::GF7::Random(prg); +test::GF7 randomNonZero(util::PRG& prg) { + auto a = test::GF7::random(prg); if (a != test::GF7(6)) { return a + test::GF7(1); } @@ -62,45 +64,54 @@ test::GF7 RandomNonZero(util::PRG& prg) { TEMPLATE_TEST_CASE("FF Random", "[math][ff]", FIELD_DEFS) { using FF = TestType; - auto prg = util::PRG::Create(); - auto zero = FF::Zero(); + auto prg = util::PRG::create(); + auto zero = FF::zero(); - auto nz = RandomNonZero(prg); + auto nz = randomNonZero(prg); REQUIRE(nz != zero); } TEMPLATE_TEST_CASE("FF Addition", "[math][ff]", FIELD_DEFS) { using FF = TestType; - auto zero = FF::Zero(); - auto prg = util::PRG::Create("FF addition"); + auto zero = FF::zero(); + auto prg = util::PRG::create("FF addition"); REPEAT { - auto x = RandomNonZero(prg); - auto y = RandomNonZero(prg); - auto c = x + y; - REQUIRE(c != x); - REQUIRE(c != y); - REQUIRE(c == y + x); - x += y; - REQUIRE(c == x); + auto a = randomNonZero(prg); + auto b = randomNonZero(prg); + auto c = a + b; + REQUIRE(c != a); + REQUIRE(c != b); + REQUIRE(c == b + a); + a += b; + REQUIRE(c == a); REQUIRE(c + zero == c); + + // post-increment should return old value. + auto old_a = a++; + REQUIRE(a == old_a + FF::one()); + + // pre-increment should return new value. + auto old_b = b; + auto new_b = ++b; + REQUIRE(old_b == new_b - FF::one()); } } TEMPLATE_TEST_CASE("FF Negation", "[math][ff]", FIELD_DEFS) { using FF = TestType; - auto zero = FF::Zero(); + auto zero = FF::zero(); REQUIRE(zero == -zero); - auto prg = util::PRG::Create("FF negation"); + auto prg = util::PRG::create("FF negation"); REPEAT { - auto a = RandomNonZero(prg); - auto a_negated = a.Negated(); + auto a = randomNonZero(prg); + auto a_negated = a.negated(); REQUIRE(a != a_negated); REQUIRE(a + a_negated == zero); REQUIRE(a_negated == -a); - a.Negate(); + a.negate(); REQUIRE(a == a_negated); REQUIRE(a - zero == a); } @@ -109,31 +120,38 @@ TEMPLATE_TEST_CASE("FF Negation", "[math][ff]", FIELD_DEFS) { TEMPLATE_TEST_CASE("FF Subtraction", "[math][ff]", FIELD_DEFS) { using FF = TestType; - auto zero = FF::Zero(); - auto prg = util::PRG::Create("FF subtraction"); + auto zero = FF::zero(); + auto prg = util::PRG::create("FF subtraction"); REPEAT { - auto a = RandomNonZero(prg); - auto b = RandomNonZero(prg); + auto a = randomNonZero(prg); + auto b = randomNonZero(prg); REQUIRE(a - b == -(b - a)); REQUIRE(a - b == -b + a); REQUIRE(a - a == zero); auto c = a - b; a -= b; REQUIRE(c == a); + + auto old_a = a--; + REQUIRE(a == old_a - FF::one()); + + auto old_b = b; + auto new_b = --b; + REQUIRE(old_b == new_b + FF::one()); } } TEMPLATE_TEST_CASE("FF Multiplication", "[math][ff]", FIELD_DEFS) { using FF = TestType; - auto zero = FF::Zero(); - auto prg = util::PRG::Create("FF multiplication"); + auto zero = FF::zero(); + auto prg = util::PRG::create("FF multiplication"); REPEAT { - auto a = RandomNonZero(prg); - auto b = RandomNonZero(prg); + auto a = randomNonZero(prg); + auto b = randomNonZero(prg); REQUIRE(a * b != zero); REQUIRE(a * b == b * a); - auto c = RandomNonZero(prg); + auto c = randomNonZero(prg); REQUIRE(c * (a + b) == c * a + c * b); auto d = a * b; a *= b; @@ -146,18 +164,18 @@ TEMPLATE_TEST_CASE("FF Multiplication", "[math][ff]", FIELD_DEFS) { TEMPLATE_TEST_CASE("FF Inversion", "[math][ff]", FIELD_DEFS) { using FF = TestType; - auto zero = FF::Zero(); + auto zero = FF::zero(); REQUIRE_THROWS_MATCHES( - zero.Inverse(), + zero.inverse(), std::logic_error, Catch::Matchers::Message("0 not invertible modulo prime")); - auto prg = util::PRG::Create("FF inversion"); + auto prg = util::PRG::create("FF inversion"); REPEAT { - auto a = RandomNonZero(prg); - auto a_inverse = a.Inverse(); - REQUIRE(a * a_inverse == FF::One()); - a.Invert(); + auto a = randomNonZero(prg); + auto a_inverse = a.inverse(); + REQUIRE(a * a_inverse == FF::one()); + a.invert(); REQUIRE(a == a_inverse); } } @@ -165,13 +183,13 @@ TEMPLATE_TEST_CASE("FF Inversion", "[math][ff]", FIELD_DEFS) { TEMPLATE_TEST_CASE("FF Division", "[math][ff]", FIELD_DEFS) { using FF = TestType; - auto zero = FF::Zero(); - auto prg = util::PRG::Create("FF division"); + auto zero = FF::zero(); + auto prg = util::PRG::create("FF division"); REPEAT { - auto a = RandomNonZero(prg); - auto b = RandomNonZero(prg); - REQUIRE(a / a == FF::One()); - REQUIRE(a / b == (b / a).Inverse()); + auto a = randomNonZero(prg); + auto b = randomNonZero(prg); + REQUIRE(a / a == FF::one()); + REQUIRE(a / b == (b / a).inverse()); auto c = a / b; a /= b; REQUIRE(c == a); @@ -182,13 +200,13 @@ TEMPLATE_TEST_CASE("FF Division", "[math][ff]", FIELD_DEFS) { TEMPLATE_TEST_CASE("FF serialization", "[math][ff]", FIELD_DEFS) { using FF = TestType; - auto prg = util::PRG::Create("FF serialization"); + auto prg = util::PRG::create("FF serialization"); REPEAT { - auto a = RandomNonZero(prg); - unsigned char buf[FF::ByteSize()] = {0}; - a.Write(buf); + auto a = randomNonZero(prg); + unsigned char buf[FF::byteSize()] = {0}; + a.write(buf); - auto b = FF::Read(buf); + auto b = FF::read(buf); REQUIRE(a == b); } } @@ -196,14 +214,14 @@ TEMPLATE_TEST_CASE("FF serialization", "[math][ff]", FIELD_DEFS) { TEMPLATE_TEST_CASE("FF Exp", "[math][ff]", FIELD_DEFS) { using FF = TestType; - auto prg = util::PRG::Create("FF exp"); + auto prg = util::PRG::create("FF exp"); - auto a = RandomNonZero(prg); + auto a = randomNonZero(prg); - REQUIRE(a == Exp(a, 1)); - REQUIRE(a * a == Exp(a, 2)); + REQUIRE(a == exp(a, 1)); + REQUIRE(a * a == exp(a, 2)); - REQUIRE(a * a * a * a * a * a == Exp(a, 6)); + REQUIRE(a * a * a * a * a * a == exp(a, 6)); - REQUIRE(FF::One() == Exp(a, 0)); + REQUIRE(FF::one() == exp(a, 0)); } diff --git a/test/scl/math/test_la.cc b/test/scl/math/test_la.cc index f5b8b8a..da3463e 100644 --- a/test/scl/math/test_la.cc +++ b/test/scl/math/test_la.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,49 +15,50 @@ * along with this program. If not, see . */ -#include +#include +#include #include "../gf7.h" #include "scl/math/fp.h" -#include "scl/math/la.h" -#include "scl/math/mat.h" +#include "scl/math/matrix.h" using namespace scl; using FF = math::FF; -using Mat = math::Mat; -using Vec = math::Vec; +using Matrix = math::Matrix; +using Vector = math::Vector; -const auto zero = FF::Zero(); -const auto one = FF::One(); +const auto zero = FF::zero(); +const auto one = FF::one(); TEST_CASE("LinAlg GetPivot", "[math][la]") { // [1 0 1] // [0 1 0] // [0 0 0] // clang-format off - Mat A = Mat::FromVector(3, 3, - {one, zero, one, - zero, one, zero, - zero, zero, zero}); + Matrix A = Matrix::fromVector(3, 3, + {one, zero, one, + zero, one, zero, + zero, zero, zero}); // clang-format on - REQUIRE(math::GetPivotInColumn(A, 2) == -1); - REQUIRE(math::GetPivotInColumn(A, 1) == 1); - REQUIRE(math::GetPivotInColumn(A, 0) == 0); + REQUIRE(math::getPivotInColumn(A, 2) == -1); + REQUIRE(math::getPivotInColumn(A, 1) == 1); + REQUIRE(math::getPivotInColumn(A, 0) == 0); A(2, 2) = one; - REQUIRE(math::GetPivotInColumn(A, 2) == 2); + REQUIRE(math::getPivotInColumn(A, 2) == 2); - Mat B(2, 2); - REQUIRE(math::GetPivotInColumn(B, 0) == -1); + Matrix B(2, 2); + REQUIRE(math::getPivotInColumn(B, 0) == -1); } TEST_CASE("LinAlg FindFirstNonZeroRow", "[math][la]") { - Mat A = Mat::FromVector(3, - 3, - {one, zero, one, zero, one, zero, zero, zero, zero}); - REQUIRE(math::FindFirstNonZeroRow(A) == 1); + Matrix A = + Matrix::fromVector(3, + 3, + {one, zero, one, zero, one, zero, zero, zero, zero}); + REQUIRE(math::findFirstNonZeroRow(A) == 1); A(2, 1) = one; - REQUIRE(math::FindFirstNonZeroRow(A) == 2); + REQUIRE(math::findFirstNonZeroRow(A) == 2); } TEST_CASE("LinAlg ExtractSolution", "[math][la]") { @@ -65,88 +66,65 @@ TEST_CASE("LinAlg ExtractSolution", "[math][la]") { // [0 1 0 5] // [0 0 1 2] // clang-format off - Mat A = Mat::FromVector(3, 4, - {one, zero, zero, FF(3), - zero, one, zero, FF(5), - zero, zero, one, FF(2)} + Matrix A = Matrix::fromVector(3, 4, + {one, zero, zero, FF(3), + zero, one, zero, FF(5), + zero, zero, one, FF(2)} ); // clang-format-on - auto x = math::ExtractSolution(A); - REQUIRE(x.Equals(Vec{FF(3), FF(5), FF(2)})); + auto x = math::extractSolution(A); + REQUIRE(x.equals(Vector{FF(3), FF(5), FF(2)})); // [1 3 1 2] // [0 0 1 4] // [0 0 0 0] // clang-format off - Mat B = Mat::FromVector(3, 4, - {FF(1), FF(3), FF(1), FF(2), - FF(0), FF(0), FF(1), FF(4), - FF(0), FF(0), FF(0), FF(0)}); + Matrix B = Matrix::fromVector(3, 4, + {FF(1), FF(3), FF(1), FF(2), + FF(0), FF(0), FF(1), FF(4), + FF(0), FF(0), FF(0), FF(0)}); // clang-format on - auto y = math::ExtractSolution(B); - REQUIRE(y.Equals(Vec{FF(4), FF(4), FF(0)})); + auto y = math::extractSolution(B); + REQUIRE(y.equals(Vector{FF(4), FF(4), FF(0)})); // [0 0 0 0] // [2 0 0 0] // [0 0 0 0] - Mat C(3, 4); + Matrix C(3, 4); C(1, 0) = FF(2); - auto z = math::ExtractSolution(C); - REQUIRE(z.Equals(Vec{zero, one, zero})); + auto z = math::extractSolution(C); + REQUIRE(z.equals(Vector{zero, one, zero})); } TEST_CASE("LinAlg Solve random", "[math][la]") { auto n = 10; - auto prg = util::PRG::Create(); + auto prg = util::PRG::create(); - Mat A = Mat::Random(n, n, prg); - Vec b = Vec::Random(n, prg); - Vec x(n); - math::SolveLinearSystem(x, A, b); + Matrix A = Matrix::random(n, n, prg); + Vector b = Vector::random(n, prg); + Vector x(n); + math::solveLinearSystem(x, A, b); - REQUIRE(A.Multiply(x.ToColumnMatrix()).Equals(b.ToColumnMatrix())); + REQUIRE(A.multiply(x.toColumnMatrix()).equals(b.toColumnMatrix())); } TEST_CASE("LinAlg malformed systems", "[math][la]") { - Vec x; - Mat A(2, 2); - Vec b(3); + Vector x; + Matrix A(2, 2); + Vector b(3); REQUIRE_THROWS_MATCHES( - math::SolveLinearSystem(x, A, b), + math::solveLinearSystem(x, A, b), std::invalid_argument, Catch::Matchers::Message("malformed system of equations")); } TEST_CASE("LinAlg HasSolution", "[math][la]") { - Mat A(2, 3); + Matrix A(2, 3); // Has an all zero row, so no unique solution is possible - REQUIRE_FALSE(math::HasSolution(A, true)); + REQUIRE_FALSE(math::hasSolution(A, true)); // An all zero row implies a free variable, so many solutions exist - REQUIRE(math::HasSolution(A, false)); + REQUIRE(math::hasSolution(A, false)); A(0, 2) = FF(1); - REQUIRE_FALSE(math::HasSolution(A, false)); -} - -TEST_CASE("LinAlg compute inverse", "[math][la]") { - // TODO: This could/should be placed as a helper in the mat class, I think. - - std::size_t n = 10; - auto prg = util::PRG::Create(); - Mat A = Mat::Random(n, n, prg); - Mat I = Mat::Identity(n); - - auto aug = math::CreateAugmentedMatrix(A, I); - REQUIRE_FALSE(aug.IsIdentity()); - math::RowReduceInPlace(aug); - - Mat Ainv(n, n); - for (std::size_t i = 0; i < n; ++i) { - for (std::size_t j = 0; j < n; ++j) { - Ainv(i, j) = aug(i, n + j); - } - } - - REQUIRE_FALSE(A.IsIdentity()); - REQUIRE(A.Multiply(Ainv).IsIdentity()); + REQUIRE_FALSE(math::hasSolution(A, false)); } diff --git a/test/scl/math/test_mat.cc b/test/scl/math/test_matrix.cc similarity index 50% rename from test/scl/math/test_mat.cc rename to test/scl/math/test_matrix.cc index 89ad1fd..10cf649 100644 --- a/test/scl/math/test_mat.cc +++ b/test/scl/math/test_matrix.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,25 +15,25 @@ * along with this program. If not, see . */ -#include -#include +#include +#include #include #include "scl/math/fp.h" -#include "scl/math/mat.h" +#include "scl/math/matrix.h" using namespace scl; using FF = math::Fp<61>; -using Mat = math::Mat; -using Vec = math::Vec; +using Matrix = math::Matrix; +using Vector = math::Vector; namespace { -void Populate(Mat& m, const std::vector& values) { - for (std::size_t i = 0; i < m.Rows(); i++) { - for (std::size_t j = 0; j < m.Cols(); j++) { - m(i, j) = FF(values[i * m.Cols() + j]); +void populate(Matrix& m, const std::vector& values) { + for (std::size_t i = 0; i < m.rows(); i++) { + for (std::size_t j = 0; j < m.cols(); j++) { + m(i, j) = FF(values[i * m.cols() + j]); } } } @@ -41,37 +41,37 @@ void Populate(Mat& m, const std::vector& values) { } // namespace TEST_CASE("Matrix construction", "[math][matrix]") { - Mat m0(2, 2); - Populate(m0, {1, 2, 5, 6}); - Mat m1(2, 2); - Populate(m1, {4, 3, 2, 1}); - - REQUIRE(!m0.Equals(m1)); - REQUIRE(m0.Rows() == 2); - REQUIRE(m0.Cols() == 2); + Matrix m0(2, 2); + populate(m0, {1, 2, 5, 6}); + Matrix m1(2, 2); + populate(m1, {4, 3, 2, 1}); + + REQUIRE(!m0.equals(m1)); + REQUIRE(m0.rows() == 2); + REQUIRE(m0.cols() == 2); REQUIRE(m0(0, 0) == FF(1)); REQUIRE(m0(0, 1) == FF(2)); REQUIRE(m0(1, 0) == FF(5)); REQUIRE(m0(1, 1) == FF(6)); - Mat a(5, 5); - Mat b(5); // square matrix + Matrix a(5, 5); + Matrix b(5); // square matrix // matrices are 0 initialized, so the above matrices are equal - REQUIRE(a.Equals(b)); + REQUIRE(a.equals(b)); - REQUIRE_THROWS_MATCHES(Mat(0, 1), + REQUIRE_THROWS_MATCHES(Matrix(0, 1), std::invalid_argument, Catch::Matchers::Message("n or m cannot be 0")); - REQUIRE_THROWS_MATCHES(Mat(1, 0), + REQUIRE_THROWS_MATCHES(Matrix(1, 0), std::invalid_argument, Catch::Matchers::Message("n or m cannot be 0")); } TEST_CASE("Matrix construction random", "[math][matrix]") { - auto prg = util::PRG::Create(); - Mat mr = Mat::Random(4, 5, prg); - REQUIRE(mr.Rows() == 4); - REQUIRE(mr.Cols() == 5); + auto prg = util::PRG::create(); + Matrix mr = Matrix::random(4, 5, prg); + REQUIRE(mr.rows() == 4); + REQUIRE(mr.cols() == 5); bool not_zero = true; for (std::size_t i = 0; i < 4; i++) { for (std::size_t j = 0; j < 5; j++) { @@ -82,9 +82,9 @@ TEST_CASE("Matrix construction random", "[math][matrix]") { } TEST_CASE("Matrix construction from Vec", "[math][matrix]") { - Mat m = Mat::FromVector(2, 2, {FF(1), FF(2), FF(3), FF(4)}); - REQUIRE(m.Rows() == 2); - REQUIRE(m.Cols() == 2); + Matrix m = Matrix::fromVector(2, 2, {FF(1), FF(2), FF(3), FF(4)}); + REQUIRE(m.rows() == 2); + REQUIRE(m.cols() == 2); std::size_t k = 1; for (std::size_t i = 0; i < 2; ++i) { for (std::size_t j = 0; j < 2; ++j) { @@ -92,14 +92,14 @@ TEST_CASE("Matrix construction from Vec", "[math][matrix]") { } } - REQUIRE_THROWS_MATCHES(Mat::FromVector(2, 2, {FF(1)}), + REQUIRE_THROWS_MATCHES(Matrix::fromVector(2, 2, {FF(1)}), std::invalid_argument, Catch::Matchers::Message("invalid dimensions")); } TEST_CASE("Matrix mutation", "[math][matrix]") { - Mat m0(2, 2); - Populate(m0, {1, 2, 5, 6}); + Matrix m0(2, 2); + populate(m0, {1, 2, 5, 6}); auto m = m0; m(0, 1) = FF(100); @@ -108,132 +108,132 @@ TEST_CASE("Matrix mutation", "[math][matrix]") { } TEST_CASE("Matrix ToString", "[math][matrix]") { - Mat m(3, 2); - Populate(m, {1, 2, 44444, 5, 6, 7}); + Matrix m(3, 2); + populate(m, {1, 2, 44444, 5, 6, 7}); std::string expected = "\n" "[ 1 2 ]\n" "[ ad9c 5 ]\n" "[ 6 7 ]"; - REQUIRE(m.ToString() == expected); + REQUIRE(m.toString() == expected); std::stringstream ss; ss << m; REQUIRE(ss.str() == expected); - Mat m1; - REQUIRE(m1.ToString() == "[ EMPTY MATRIX ]"); + Matrix m1; + REQUIRE(m1.toString() == "[ EMPTY MATRIX ]"); } TEST_CASE("Matrix Addition", "[math][matrix]") { - Mat m0(2, 2); - Populate(m0, {1, 2, 5, 6}); - Mat m1(2, 2); - Populate(m1, {4, 3, 2, 1}); - - auto m2 = m0.Add(m1); - REQUIRE(m2.Rows() == 2); - REQUIRE(m2.Cols() == 2); + Matrix m0(2, 2); + populate(m0, {1, 2, 5, 6}); + Matrix m1(2, 2); + populate(m1, {4, 3, 2, 1}); + + auto m2 = m0.add(m1); + REQUIRE(m2.rows() == 2); + REQUIRE(m2.cols() == 2); REQUIRE(m2(0, 0) == FF(5)); REQUIRE(m2(0, 1) == FF(5)); REQUIRE(m2(1, 0) == FF(7)); REQUIRE(m2(1, 1) == FF(7)); - m2.AddInPlace(m0); - REQUIRE(m2.Equals(m0.Add(m0).Add(m1))); + m2.addInPlace(m0); + REQUIRE(m2.equals(m0.add(m0).add(m1))); } TEST_CASE("Matrix Subtraction", "[math][matrix]") { - Mat m0(2, 2); - Populate(m0, {1, 2, 5, 6}); - Mat m1(2, 2); - Populate(m1, {4, 3, 2, 1}); + Matrix m0(2, 2); + populate(m0, {1, 2, 5, 6}); + Matrix m1(2, 2); + populate(m1, {4, 3, 2, 1}); - auto m2 = m0.Subtract(m1); + auto m2 = m0.subtract(m1); REQUIRE(m2(0, 0) == FF(1) - FF(4)); REQUIRE(m2(0, 1) == FF(2) - FF(3)); REQUIRE(m2(1, 0) == FF(3)); REQUIRE(m2(1, 1) == FF(5)); - m2.SubtractInPlace(m0); - REQUIRE(m2.Equals(m0.Subtract(m0).Subtract(m1))); + m2.subtractInPlace(m0); + REQUIRE(m2.equals(m0.subtract(m0).subtract(m1))); } TEST_CASE("Matrix MultiplyEntryWise", "[math][matrix]") { - Mat m0(2, 2); - Populate(m0, {1, 2, 5, 6}); - Mat m1(2, 2); - Populate(m1, {4, 3, 2, 1}); + Matrix m0(2, 2); + populate(m0, {1, 2, 5, 6}); + Matrix m1(2, 2); + populate(m1, {4, 3, 2, 1}); - auto m2 = m0.MultiplyEntryWise(m1); + auto m2 = m0.multiplyEntryWise(m1); REQUIRE(m2(0, 0) == FF(4)); REQUIRE(m2(0, 1) == FF(6)); REQUIRE(m2(1, 0) == FF(10)); REQUIRE(m2(1, 1) == FF(6)); - m2.MultiplyEntryWiseInPlace(m0); - REQUIRE(m2.Equals(m0.MultiplyEntryWise(m0).MultiplyEntryWise(m1))); + m2.multiplyEntryWiseInPlace(m0); + REQUIRE(m2.equals(m0.multiplyEntryWise(m0).multiplyEntryWise(m1))); } TEST_CASE("Matrix Multiply", "[math][matrix]") { - Mat m0(2, 2); - Populate(m0, {1, 2, 5, 6}); - Mat m1(2, 2); - Populate(m1, {4, 3, 2, 1}); - - auto m2 = m0.Multiply(m1); - REQUIRE(m2.Rows() == 2); - REQUIRE(m2.Cols() == 2); + Matrix m0(2, 2); + populate(m0, {1, 2, 5, 6}); + Matrix m1(2, 2); + populate(m1, {4, 3, 2, 1}); + + auto m2 = m0.multiply(m1); + REQUIRE(m2.rows() == 2); + REQUIRE(m2.cols() == 2); REQUIRE(m2(0, 0) == FF(8)); REQUIRE(m2(0, 1) == FF(5)); REQUIRE(m2(1, 0) == FF(32)); REQUIRE(m2(1, 1) == FF(21)); - Mat m3(2, 10); - Populate(m3, {1, 2, 3, 4, 5, 6, 7, 8, 9, 0, + Matrix m3(2, 10); + populate(m3, {1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}); - auto m5 = m0.Multiply(m3); - REQUIRE(m5.Rows() == 2); - REQUIRE(m5.Cols() == 10); - Mat m4(2, 10); - Populate(m4, {23, 26, 29, 32, 35, 38, 41, 44, 47, 40, + auto m5 = m0.multiply(m3); + REQUIRE(m5.rows() == 2); + REQUIRE(m5.cols() == 10); + Matrix m4(2, 10); + populate(m4, {23, 26, 29, 32, 35, 38, 41, 44, 47, 40, 71, 82, 93, 104, 115, 126, 137, 148, 159, 120}); - REQUIRE(m5.Equals(m4)); + REQUIRE(m5.equals(m4)); REQUIRE_THROWS_MATCHES( - m3.Multiply(m0), + m3.multiply(m0), std::invalid_argument, - Catch::Matchers::Message("matmul: this->Cols() != that->Rows()")); + Catch::Matchers::Message("matmul: this->cols() != that->rows()")); } TEST_CASE("Matrix vector multiply", "[math][matrix]") { - Mat m0(2, 3); - Populate(m0, {1, 2, 3, 4, 5, 6}); - Vec v0 = {FF(1), FF(2), FF(3)}; + Matrix m0(2, 3); + populate(m0, {1, 2, 3, 4, 5, 6}); + Vector v0 = {FF(1), FF(2), FF(3)}; - Vec v1 = m0.Multiply(v0); - REQUIRE(v1.Size() == 2); + Vector v1 = m0.multiply(v0); + REQUIRE(v1.size() == 2); REQUIRE(v1[0] == FF(1 * 1 + 2 * 2 + 3 * 3)); REQUIRE(v1[1] == FF(4 * 1 + 5 * 2 + 6 * 3)); - Vec v2 = {FF(6), FF(7)}; + Vector v2 = {FF(6), FF(7)}; REQUIRE_THROWS_MATCHES( - m0.Multiply(v2), + m0.multiply(v2), std::invalid_argument, - Catch::Matchers::Message("matmul: this->Cols() != vec.Size()")); + Catch::Matchers::Message("matmul: this->cols() != vec.size()")); } TEST_CASE("Matrix ScalarMultiply", "[math][matrix]") { - Mat m0(2, 2); - Populate(m0, {1, 2, 5, 6}); - Mat m1(2, 2); - Populate(m1, {4, 3, 2, 1}); + Matrix m0(2, 2); + populate(m0, {1, 2, 5, 6}); + Matrix m1(2, 2); + populate(m1, {4, 3, 2, 1}); - auto m2 = m0.ScalarMultiply(FF(2)); + auto m2 = m0.scalarMultiply(FF(2)); REQUIRE(m2(0, 0) == FF(2)); REQUIRE(m2(0, 1) == FF(4)); REQUIRE(m2(1, 0) == FF(10)); REQUIRE(m2(1, 1) == FF(12)); - m2.ScalarMultiplyInPlace(FF(2)); + m2.scalarMultiplyInPlace(FF(2)); REQUIRE(m2(0, 0) == FF(4)); REQUIRE(m2(0, 1) == FF(8)); REQUIRE(m2(1, 0) == FF(20)); @@ -241,34 +241,34 @@ TEST_CASE("Matrix ScalarMultiply", "[math][matrix]") { } TEST_CASE("Matrix Transpose", "[math][matrix]") { - Mat m3(2, 3); - Populate(m3, {1, 2, 3, 11, 12, 13}); - auto m4 = m3.Transpose(); - REQUIRE(m4.Rows() == m3.Cols()); - REQUIRE(m4.Cols() == m3.Rows()); - REQUIRE(m4(0, 0) == FF(1)); - REQUIRE(m4(0, 1) == FF(11)); - REQUIRE(m4(1, 0) == FF(2)); - REQUIRE(m4(1, 1) == FF(12)); - REQUIRE(m4(2, 0) == FF(3)); - REQUIRE(m4(2, 1) == FF(13)); + Matrix m3(2, 3); + populate(m3, {1, 2, 3, 11, 12, 13}); + auto m4 = m3.transpose(); + REQUIRE(m4.rows() == m3.cols()); + REQUIRE(m4.cols() == m3.rows()); + REQUIRE(m4(0, 0) == m3(0, 0)); + REQUIRE(m4(0, 1) == m3(1, 0)); + REQUIRE(m4(1, 0) == m3(0, 1)); + REQUIRE(m4(1, 1) == m3(1, 1)); + REQUIRE(m4(2, 0) == m3(0, 2)); + REQUIRE(m4(2, 1) == m3(1, 2)); } TEST_CASE("Matrix check compatability", "[math][matrix]") { - Mat m0(2, 2); - Mat m1(3, 2); - REQUIRE_THROWS_MATCHES(m1.Add(m0), + Matrix m0(2, 2); + Matrix m1(3, 2); + REQUIRE_THROWS_MATCHES(m1.add(m0), std::invalid_argument, Catch::Matchers::Message("incompatible matrices")); } TEST_CASE("Matrix resize", "[math][matrix]") { - auto prg = util::PRG::Create(); - Mat m = Mat::Random(2, 4, prg); + auto prg = util::PRG::create(); + Matrix m = Matrix::random(2, 4, prg); auto copy = m; - copy.Resize(1, 8); - REQUIRE(copy.Rows() == 1); - REQUIRE(copy.Cols() == 8); + copy.resize(1, 8); + REQUIRE(copy.rows() == 1); + REQUIRE(copy.cols() == 8); std::size_t c = 0; for (std::size_t i = 0; i < 2; ++i) { for (std::size_t j = 0; j < 4; ++j) { @@ -276,37 +276,37 @@ TEST_CASE("Matrix resize", "[math][matrix]") { } } - REQUIRE_THROWS_MATCHES(m.Resize(42, 4), + REQUIRE_THROWS_MATCHES(m.resize(42, 4), std::invalid_argument, Catch::Matchers::Message("cannot resize matrix")); } TEST_CASE("Matrix equality", "[math][matrix]") { - auto prg = util::PRG::Create(); - auto m0 = Mat::Random(3, 4, prg); - auto m1 = Mat::Random(3, 4, prg); + auto prg = util::PRG::create(); + auto m0 = Matrix::random(3, 4, prg); + auto m1 = Matrix::random(3, 4, prg); - REQUIRE_FALSE(m0.Equals(m1)); - prg.Reset(); - m1 = Mat::Random(3, 4, prg); - REQUIRE(m0.Equals(m1)); + REQUIRE_FALSE(m0.equals(m1)); + prg.reset(); + m1 = Matrix::random(3, 4, prg); + REQUIRE(m0.equals(m1)); - auto m2 = Mat::Random(2, 2, prg); - REQUIRE_FALSE(m2.Equals(m1)); + auto m2 = Matrix::random(2, 2, prg); + REQUIRE_FALSE(m2.equals(m1)); } TEST_CASE("Matrix isSquare", "[math][matrix]") { - auto prg = util::PRG::Create(); - Mat sq = Mat::Random(2, 2, prg); - REQUIRE(sq.IsSquare()); - Mat nsq = Mat::Random(4, 2, prg); - REQUIRE(!nsq.IsSquare()); + auto prg = util::PRG::create(); + Matrix sq = Matrix::random(2, 2, prg); + REQUIRE(sq.isSquare()); + Matrix nsq = Matrix::random(4, 2, prg); + REQUIRE(!nsq.isSquare()); } TEST_CASE("Matrix identity", "[math][matrix]") { - Mat A = Mat::Identity(10); - REQUIRE(A.Rows() == 10); - REQUIRE(A.Cols() == 10); + Matrix A = Matrix::identity(10); + REQUIRE(A.rows() == 10); + REQUIRE(A.cols() == 10); bool good = true; for (std::size_t i = 0; i < 10; ++i) { for (std::size_t j = 0; j < 10; ++j) { @@ -318,10 +318,54 @@ TEST_CASE("Matrix identity", "[math][matrix]") { } } REQUIRE(good); + + Matrix B(2, 1); + REQUIRE_FALSE(B.isIdentity()); +} + +TEST_CASE("Matrix inversion", "[math][matrix]") { + auto prg = util::PRG::create("mat_inv"); + const auto m = Matrix::random(10, 10, prg); + const auto i = m.invert(); + REQUIRE(m.multiply(i).isIdentity()); +} + +TEST_CASE("Matrix inversion bad", "[math][matrix]") { + auto prg = util::PRG::create("mat_inv2"); + auto m = Matrix::random(5, 6, prg); + REQUIRE_THROWS_MATCHES( + m.invert(), + std::invalid_argument, + Catch::Matchers::Message("cannot invert non-square matrix")); +} + +TEST_CASE("Matrix poly eval and interp", "[math][matrix]") { + auto prg = util::PRG::create("mat_interp"); + + const std::size_t n = 10; + const std::size_t t = 3; + + const math::Vector coeff = math::Vector::random(t, prg); + + const math::Matrix vand = math::Matrix::vandermonde(n, t); + + const auto evals = vand.multiply(coeff); + + REQUIRE(evals.size() == n); + + math::Vector points = {evals[0], evals[4], evals[5]}; + + const math::Matrix vandi = + math::Matrix::vandermonde(t, + t, + math::Vector{FF(1), FF(5), FF(6)}); + + const auto coeff_ = vandi.invert().multiply(points); + REQUIRE(coeff_ == coeff); } TEST_CASE("Matrix vandermonde", "[math][matrix]") { - auto m0 = Mat::Vandermonde(3, 3); + auto m0 = Matrix::vandermonde(3, 3); REQUIRE(m0(0, 0) == FF(1)); REQUIRE(m0(0, 1) == FF(1)); REQUIRE(m0(0, 2) == FF(1)); @@ -333,7 +377,7 @@ TEST_CASE("Matrix vandermonde", "[math][matrix]") { REQUIRE(m0(2, 2) == FF(9)); std::vector xs{FF(3), FF(5), FF(8)}; - auto m1 = Mat::Vandermonde(3, 3, xs); + auto m1 = Matrix::vandermonde(3, 3, xs); REQUIRE(m1(0, 0) == FF(1)); REQUIRE(m1(0, 1) == FF(3)); REQUIRE(m1(0, 2) == FF(9)); @@ -345,7 +389,7 @@ TEST_CASE("Matrix vandermonde", "[math][matrix]") { REQUIRE(m1(2, 2) == FF(64)); xs.emplace_back(FF(55)); - REQUIRE_THROWS_MATCHES(Mat::Vandermonde(3, 3, xs), + REQUIRE_THROWS_MATCHES(Matrix::vandermonde(3, 3, xs), std::invalid_argument, Catch::Matchers::Message("|xs| != number of rows")); } @@ -353,10 +397,10 @@ TEST_CASE("Matrix vandermonde", "[math][matrix]") { TEST_CASE("Matrix HIM", "[math][matrix]") { // TODO: Not a very good test. - auto him = Mat::HyperInvertible(4, 5); + auto him = Matrix::hyperInvertible(4, 5); for (std::size_t i = 0; i < 4; ++i) { for (std::size_t j = 0; j < 5; ++j) { - REQUIRE(him(i, j) != FF::Zero()); + REQUIRE(him(i, j) != FF::zero()); } } } diff --git a/test/scl/math/test_mersenne127.cc b/test/scl/math/test_mersenne127.cc index 7f7abcd..d9f548c 100644 --- a/test/scl/math/test_mersenne127.cc +++ b/test/scl/math/test_mersenne127.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,7 +15,7 @@ * along with this program. If not, see . */ -#include +#include #include #include "scl/math/fp.h" @@ -26,23 +26,23 @@ using Field = math::Fp<127>; using u128 = __uint128_t; TEST_CASE("Mersenne127 defs", "[math][ff]") { - REQUIRE(Field::BitSize() == 127); - REQUIRE(Field::ByteSize() == 16); - REQUIRE(std::string(Field::Name()) == "Mersenne127"); + REQUIRE(Field::bitSize() == 127); + REQUIRE(Field::byteSize() == 16); + REQUIRE(std::string(Field::name()) == "Mersenne127"); } TEST_CASE("Mersenne127 to string", "[math][ff]") { - REQUIRE(Field::Zero().ToString() == "0"); - REQUIRE(Field::One().ToString() == "1"); + REQUIRE(Field::zero().toString() == "0"); + REQUIRE(Field::one().toString() == "1"); Field x(0x7b); - REQUIRE(x.ToString() == "7b"); + REQUIRE(x.toString() == "7b"); - REQUIRE(Field::FromString("80000000000000000000000000000000") == - Field::One()); + REQUIRE(Field::fromString("80000000000000000000000000000000") == + Field::one()); - Field big = Field::FromString("58797a14d0653d22a05c11c60e1aacf4"); - REQUIRE(big.ToString() == "58797a14d0653d22a05c11c60e1aacf4"); + Field big = Field::fromString("58797a14d0653d22a05c11c60e1aacf4"); + REQUIRE(big.toString() == "58797a14d0653d22a05c11c60e1aacf4"); std::stringstream ss; ss << x; @@ -50,19 +50,19 @@ TEST_CASE("Mersenne127 to string", "[math][ff]") { } TEST_CASE("Mersenne127 from string", "[math][ff]") { - auto y = Field::FromString("7b"); + auto y = Field::fromString("7b"); REQUIRE(y == Field(0x7b)); } TEST_CASE("Mersenne127 read/write", "[math][ff]") { - Field big = Field::FromString("58797a14d0653d22a05c11c60e1aacf4"); - unsigned char buffer[Field::ByteSize()]; - big.Write(buffer); - auto y = Field::Read(buffer); + Field big = Field::fromString("58797a14d0653d22a05c11c60e1aacf4"); + unsigned char buffer[Field::byteSize()]; + big.write(buffer); + auto y = Field::read(buffer); REQUIRE(big == y); Field x(0x7b); - x.Write(buffer); - auto z = Field::Read(buffer); + x.write(buffer); + auto z = Field::read(buffer); REQUIRE(z == x); } diff --git a/test/scl/math/test_mersenne61.cc b/test/scl/math/test_mersenne61.cc index d0b35a1..cd84a5e 100644 --- a/test/scl/math/test_mersenne61.cc +++ b/test/scl/math/test_mersenne61.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,7 +15,8 @@ * along with this program. If not, see . */ -#include +#include +#include #include #include "scl/math/fp.h" @@ -25,21 +26,21 @@ using namespace scl; using Field = math::Fp<61>; TEST_CASE("Mersenne61 defs", "[math][ff]") { - REQUIRE(std::string(Field::Name()) == "Mersenne61"); - REQUIRE(Field::BitSize() == 61); - REQUIRE(Field::ByteSize() == 8); + REQUIRE(std::string(Field::name()) == "Mersenne61"); + REQUIRE(Field::bitSize() == 61); + REQUIRE(Field::byteSize() == 8); } TEST_CASE("Mersenne61 to string", "[math][ff]") { - Field zero = Field::Zero(); - Field one = Field::One(); + Field zero = Field::zero(); + Field one = Field::one(); Field x(0x7b); Field big(0x41621e); - REQUIRE(zero.ToString() == "0"); - REQUIRE(one.ToString() == "1"); - REQUIRE(x.ToString() == "7b"); - REQUIRE(big.ToString() == "41621e"); + REQUIRE(zero.toString() == "0"); + REQUIRE(one.toString() == "1"); + REQUIRE(x.toString() == "7b"); + REQUIRE(big.toString() == "41621e"); std::stringstream ss; ss << x; REQUIRE(ss.str() == "7b"); @@ -49,16 +50,16 @@ TEST_CASE("Mersenne61 from string", "[math][ff]") { Field x(0x7b); Field big(0x41621e); - REQUIRE_THROWS_MATCHES(Field::FromString("012"), + REQUIRE_THROWS_MATCHES(Field::fromString("012"), std::invalid_argument, Catch::Matchers::Message("odd-length hex string")); REQUIRE_THROWS_MATCHES( - Field::FromString("1g"), + Field::fromString("1g"), std::invalid_argument, Catch::Matchers::Message("encountered invalid hex character")); - auto y = Field::FromString("7b"); + auto y = Field::fromString("7b"); REQUIRE(x == y); - auto z = Field::FromString("41621E"); + auto z = Field::fromString("41621E"); REQUIRE(z == big); } @@ -66,11 +67,11 @@ TEST_CASE("Mersenne61 read/write", "[math][ff]") { Field x(0x7b); Field big(0x41621e); - unsigned char buffer[Field::ByteSize()]; - x.Write(buffer); - auto y = Field::Read(buffer); + unsigned char buffer[Field::byteSize()]; + x.write(buffer); + auto y = Field::read(buffer); REQUIRE(x == y); - big.Write(buffer); - auto z = Field::Read(buffer); + big.write(buffer); + auto z = Field::read(buffer); REQUIRE(z == big); } diff --git a/test/scl/math/test_number.cc b/test/scl/math/test_number.cc index 8468d3c..57ec17b 100644 --- a/test/scl/math/test_number.cc +++ b/test/scl/math/test_number.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,7 +15,8 @@ * along with this program. If not, see . */ -#include +#include +#include #include #include @@ -33,43 +34,43 @@ using Number = math::Number; #define REPEAT for (std::size_t i = 0; i < SCl_NUMBER_TEST_REPETITIONS; ++i) TEST_CASE("Number create", "[math]") { - auto prg = util::PRG::Create(); + auto prg = util::PRG::create(); Number n0(27); - REQUIRE(n0.ToString() == "Number{1b}"); + REQUIRE(n0.toString() == "Number{1b}"); Number n1(-42); - REQUIRE(n1.ToString() == "Number{-2a}"); + REQUIRE(n1.toString() == "Number{-2a}"); Number zero; - REQUIRE(zero.ToString() == "Number{0}"); + REQUIRE(zero.toString() == "Number{0}"); Number zero_alt(0); - REQUIRE(zero_alt.ToString() == "Number{0}"); - Number r0 = Number::Random(127, prg); - REQUIRE(r0.ToString() == "Number{-27a8004ea0c9708441893d2808ca9457}"); + REQUIRE(zero_alt.toString() == "Number{0}"); + Number r0 = Number::random(127, prg); + REQUIRE(r0.toString() == "Number{-27a8004ea0c9708441893d2808ca9457}"); // the above is only 126 bits, but it's close enough - REQUIRE(r0.BitSize() == 126); - Number r1 = Number::Random(65, prg); - REQUIRE(r1.ToString() == "Number{10584d2a1c30fa50d}"); - REQUIRE(r1.BitSize() == 65); + REQUIRE(r0.bitSize() == 126); + Number r1 = Number::random(65, prg); + REQUIRE(r1.toString() == "Number{10584d2a1c30fa50d}"); + REQUIRE(r1.bitSize() == 65); - Number p = Number::RandomPrime(10, prg); - REQUIRE(p.ToString() == "Number{133}"); // 307 + Number p = Number::randomPrime(10, prg); + REQUIRE(p.toString() == "Number{133}"); // 307 std::stringstream ss; ss << r1; - REQUIRE(ss.str() == r1.ToString()); + REQUIRE(ss.str() == r1.toString()); Number r2(std::move(n0)); REQUIRE(r2 == Number(27)); } TEST_CASE("Number from string", "[math]") { - auto x = Number::FromString("7b"); + auto x = Number::fromString("7b"); REQUIRE(x == Number(0x7b)); } TEST_CASE("Number assignment", "[math]") { - auto prg = util::PRG::Create("Number assignment"); - auto x = Number::Random(100, prg); - auto y = Number::Random(100, prg); + auto prg = util::PRG::create("Number assignment"); + auto x = Number::random(100, prg); + auto y = Number::random(100, prg); REQUIRE(x != y); Number t; @@ -101,10 +102,10 @@ TEST_CASE("Number addition", "[math]") { REQUIRE(a + b == Number(55 + 32)); Number zero(0); - auto prg = util::PRG::Create("Number addition"); + auto prg = util::PRG::create("Number addition"); REPEAT { - auto x = Number::Random(100, prg); - auto y = Number::Random(100, prg); + auto x = Number::random(100, prg); + auto y = Number::random(100, prg); auto z = x + y; REQUIRE(z != x); @@ -124,10 +125,10 @@ TEST_CASE("Number subtraction", "[math]") { REQUIRE(a - b == Number(123 - 555)); Number zero; - auto prg = util::PRG::Create("Number subtraction"); + auto prg = util::PRG::create("Number subtraction"); REPEAT { - auto x = Number::Random(100, prg); - auto y = Number::Random(100, prg); + auto x = Number::random(100, prg); + auto y = Number::random(100, prg); auto z = x - y; REQUIRE(z != x); @@ -154,17 +155,17 @@ TEST_CASE("Number multiplication", "[math]") { Number one(1); Number zero; - auto prg = util::PRG::Create("Number multiplication"); + auto prg = util::PRG::create("Number multiplication"); REPEAT { - auto x = Number::Random(100, prg); - auto y = Number::Random(100, prg); + auto x = Number::random(100, prg); + auto y = Number::random(100, prg); auto z = x * y; REQUIRE(z != x); REQUIRE(z != y); REQUIRE(z == y * x); - auto w = Number::Random(100, prg); + auto w = Number::random(100, prg); REQUIRE(w * (x + y) == w * x + w * y); x *= y; @@ -184,10 +185,10 @@ TEST_CASE("Number division", "[math]") { REQUIRE(a / b == Number(123 / 43)); Number one(1); - auto prg = util::PRG::Create("Number division"); + auto prg = util::PRG::create("Number division"); REPEAT { - auto x = Number::Random(100, prg); - auto y = Number::Random(85, prg); + auto x = Number::random(100, prg); + auto y = Number::random(85, prg); auto z = x / y; REQUIRE(z != x); @@ -234,10 +235,10 @@ TEST_CASE("Number xor", "[math]") { Number b(5545); REQUIRE((a ^ b) == Number(2231 ^ 5545)); - auto prg = util::PRG::Create("Number xor"); + auto prg = util::PRG::create("Number xor"); REPEAT { - auto x = Number::Random(100, prg); - auto y = Number::Random(100, prg); + auto x = Number::random(100, prg); + auto y = Number::random(100, prg); auto z = x ^ y; REQUIRE(z != x); @@ -255,10 +256,10 @@ TEST_CASE("Number or", "[math]") { Number b(5545); REQUIRE((a | b) == Number(2231 | 5545)); - auto prg = util::PRG::Create("Number or"); + auto prg = util::PRG::create("Number or"); REPEAT { - auto x = Number::Random(100, prg); - auto y = Number::Random(100, prg); + auto x = Number::random(100, prg); + auto y = Number::random(100, prg); auto z = x | y; REQUIRE(z != x); @@ -276,10 +277,10 @@ TEST_CASE("Number and", "[math]") { Number b(5545); REQUIRE((a & b) == Number(2231 & 5545)); - auto prg = util::PRG::Create("Number and"); + auto prg = util::PRG::create("Number and"); REPEAT { - auto x = Number::Random(100, prg); - auto y = Number::Random(100, prg); + auto x = Number::random(100, prg); + auto y = Number::random(100, prg); auto z = x & y; REQUIRE(z != x); @@ -301,23 +302,23 @@ TEST_CASE("Number complement", "[math]") { TEST_CASE("Number test bit", "[math]") { Number a(49); // out-of-range returns false - REQUIRE_FALSE(a.TestBit(100)); - - REQUIRE(a.TestBit(0)); - REQUIRE_FALSE(a.TestBit(1)); - REQUIRE_FALSE(a.TestBit(2)); - REQUIRE_FALSE(a.TestBit(3)); - REQUIRE(a.TestBit(4)); - REQUIRE(a.TestBit(5)); + REQUIRE_FALSE(a.testBit(100)); + + REQUIRE(a.testBit(0)); + REQUIRE_FALSE(a.testBit(1)); + REQUIRE_FALSE(a.testBit(2)); + REQUIRE_FALSE(a.testBit(3)); + REQUIRE(a.testBit(4)); + REQUIRE(a.testBit(5)); } TEST_CASE("Number mod inverse invalid", "[math]") { Number a(10); - REQUIRE_THROWS_MATCHES(math::ModInverse(a, Number(0)), + REQUIRE_THROWS_MATCHES(math::modInverse(a, Number(0)), std::invalid_argument, Catch::Matchers::Message("modulus cannot be 0")); - REQUIRE_THROWS_MATCHES(math::ModInverse(a, Number(2)), + REQUIRE_THROWS_MATCHES(math::modInverse(a, Number(2)), std::logic_error, Catch::Matchers::Message("number not invertible")); } @@ -325,44 +326,44 @@ TEST_CASE("Number mod inverse invalid", "[math]") { TEST_CASE("Number read/write", "[math]") { Number a(1234); - REQUIRE(a.BitSize() == 11); - REQUIRE(a.ByteSize() == 2); + REQUIRE(a.bitSize() == 11); + REQUIRE(a.byteSize() == 2); auto buf = - std::make_unique(a.ByteSize() + sizeof(std::uint32_t)); + std::make_unique(a.byteSize() + sizeof(std::uint32_t)); - a.Write(buf.get()); - REQUIRE(a == Number::Read(buf.get())); + a.write(buf.get()); + REQUIRE(a == Number::read(buf.get())); - auto prg = util::PRG::Create("rw"); + auto prg = util::PRG::create("rw"); REPEAT { - const auto x = Number::Random(100, prg); + const auto x = Number::random(100, prg); auto bufx = - std::make_unique(x.ByteSize() + sizeof(std::uint32_t)); - x.Write(bufx.get()); - REQUIRE(x == Number::Read(bufx.get())); + std::make_unique(x.byteSize() + sizeof(std::uint32_t)); + x.write(bufx.get()); + REQUIRE(x == Number::read(bufx.get())); } } TEST_CASE("Number RSA example", "[math]") { - auto prg = util::PRG::Create("rsa"); - const auto p = Number::RandomPrime(512, prg); - const auto q = Number::RandomPrime(512, prg); + auto prg = util::PRG::create("rsa"); + const auto p = Number::randomPrime(512, prg); + const auto q = Number::randomPrime(512, prg); REQUIRE(p != q); const auto n = p * q; - const auto lm = math::LCM(p - Number(1), q - Number(1)); + const auto lm = math::lcm(p - Number(1), q - Number(1)); const auto e = Number(0x10001); - REQUIRE(math::GCD(e, lm) == Number(1)); + REQUIRE(math::gcd(e, lm) == Number(1)); - const auto d = math::ModInverse(e, lm); + const auto d = math::modInverse(e, lm); REQUIRE((d * e) % lm == Number(1)); Number msg(1234); - const auto ctxt = math::ModExp(msg, e, n); + const auto ctxt = math::modExp(msg, e, n); REQUIRE(ctxt != msg); - const auto ptxt = math::ModExp(ctxt, d, n); + const auto ptxt = math::modExp(ctxt, d, n); REQUIRE(ptxt == msg); } diff --git a/test/scl/math/test_poly.cc b/test/scl/math/test_poly.cc index e038965..4409790 100644 --- a/test/scl/math/test_poly.cc +++ b/test/scl/math/test_poly.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,7 +15,9 @@ * along with this program. If not, see . */ -#include +#include +#include +#include #include #include "../math/fields.h" @@ -29,53 +31,53 @@ TEMPLATE_TEST_CASE("Polynomial construct", "[ss][math]", FIELD_DEFS) { using FF = TestType; math::Polynomial p; - REQUIRE(p.Degree() == 0); + REQUIRE(p.degree() == 0); REQUIRE(p[0] == FF()); - REQUIRE(p.IsZero()); + REQUIRE(p.isZero()); math::Polynomial q(FF(123)); - REQUIRE(q.Degree() == 0); - REQUIRE(q.ConstantTerm() == FF(123)); + REQUIRE(q.degree() == 0); + REQUIRE(q.constantTerm() == FF(123)); REQUIRE(q[0] == FF(123)); - math::Vec coeff = {FF(1), FF(2), FF(6)}; - auto x = math::Polynomial::Create(coeff); - REQUIRE(x.Degree() == 2); + math::Vector coeff = {FF(1), FF(2), FF(6)}; + auto x = math::Polynomial::create(coeff); + REQUIRE(x.degree() == 2); REQUIRE(x[0] == FF(1)); REQUIRE(x[1] == FF(2)); REQUIRE(x[2] == FF(6)); - REQUIRE(x.Coefficients() == coeff); + REQUIRE(x.coefficients() == coeff); - math::Vec with_zeros = {FF(1), FF(0), FF(3), FF(0)}; - auto y = math::Polynomial::Create(with_zeros); - REQUIRE(y.Degree() == 2); + math::Vector with_zeros = {FF(1), FF(0), FF(3), FF(0)}; + auto y = math::Polynomial::create(with_zeros); + REQUIRE(y.degree() == 2); REQUIRE(y[0] == FF(1)); REQUIRE(y[1] == FF(0)); REQUIRE(y[2] == FF(3)); - math::Vec empty; - auto z = math::Polynomial::Create(empty); - REQUIRE(z.Degree() == 0); + math::Vector empty; + auto z = math::Polynomial::create(empty); + REQUIRE(z.degree() == 0); REQUIRE(z[0] == FF(0)); } TEMPLATE_TEST_CASE("Polynomial evaluate", "[math][ss]", FIELD_DEFS) { using FF = TestType; - math::Vec coeff = {FF(4), FF(5), FF(1)}; - auto p = math::Polynomial::Create(coeff); - auto x5 = p.Evaluate(FF(5)); + math::Vector coeff = {FF(4), FF(5), FF(1)}; + auto p = math::Polynomial::create(coeff); + auto x5 = p.evaluate(FF(5)); REQUIRE(x5 == FF(54)); } TEMPLATE_TEST_CASE("Polynomial to string", "[math][ss]", FIELD_DEFS) { using FF = TestType; - math::Vec coeff = {FF(4), FF(5), FF(1)}; - auto p = math::Polynomial::Create(coeff); + math::Vector coeff = {FF(4), FF(5), FF(1)}; + auto p = math::Polynomial::create(coeff); - REQUIRE(p.ToString() == "f(x) = 4 + 5x + 1x^2"); - REQUIRE(p.ToString("g", "y") == "g(y) = 4 + 5y + 1y^2"); + REQUIRE(p.toString() == "f(x) = 4 + 5x + 1x^2"); + REQUIRE(p.toString("g", "y") == "g(y) = 4 + 5y + 1y^2"); std::stringstream ss; ss << p; @@ -85,44 +87,44 @@ TEMPLATE_TEST_CASE("Polynomial to string", "[math][ss]", FIELD_DEFS) { TEMPLATE_TEST_CASE("Polynomial addition", "[math][ss]", FIELD_DEFS) { using FF = TestType; - math::Vec c0 = {FF(1), FF(2), FF(3)}; - math::Vec c1 = {FF(5), FF(3), FF(3), FF(1)}; - auto p = math::Polynomial::Create(c0); - auto q = math::Polynomial::Create(c1); - auto e = p.Add(q); - REQUIRE(e.Degree() == q.Degree()); + math::Vector c0 = {FF(1), FF(2), FF(3)}; + math::Vector c1 = {FF(5), FF(3), FF(3), FF(1)}; + auto p = math::Polynomial::create(c0); + auto q = math::Polynomial::create(c1); + auto e = p.add(q); + REQUIRE(e.degree() == q.degree()); REQUIRE(e[0] == FF(6)); REQUIRE(e[1] == FF(5)); REQUIRE(e[2] == FF(6)); REQUIRE(e[3] == FF(1)); - auto d = q.Add(p); + auto d = q.add(p); REQUIRE(d[0] == e[0]); REQUIRE(d[1] == e[1]); REQUIRE(d[2] == e[2]); REQUIRE(d[3] == e[3]); - math::Vec cn = {-FF(1), -FF(2), -FF(3)}; - auto t = math::Polynomial::Create(cn); - auto w = t.Add(p); - REQUIRE(w.Degree() == 0); + math::Vector cn = {-FF(1), -FF(2), -FF(3)}; + auto t = math::Polynomial::create(cn); + auto w = t.add(p); + REQUIRE(w.degree() == 0); } TEMPLATE_TEST_CASE("Polynomial subtraction", "[math][ss]", FIELD_DEFS) { using FF = TestType; - math::Vec c0 = {FF(1), FF(2), FF(3)}; - math::Vec c1 = {FF(5), FF(3), FF(3), FF(1)}; - auto p = math::Polynomial::Create(c0); - auto q = math::Polynomial::Create(c1); - auto e = p.Subtract(q); - REQUIRE(e.Degree() == q.Degree()); + math::Vector c0 = {FF(1), FF(2), FF(3)}; + math::Vector c1 = {FF(5), FF(3), FF(3), FF(1)}; + auto p = math::Polynomial::create(c0); + auto q = math::Polynomial::create(c1); + auto e = p.subtract(q); + REQUIRE(e.degree() == q.degree()); REQUIRE(e[0] == -FF(4)); REQUIRE(e[1] == -FF(1)); REQUIRE(e[2] == FF(0)); REQUIRE(e[3] == -FF(1)); - auto d = q.Subtract(p); + auto d = q.subtract(p); REQUIRE(-d[0] == e[0]); REQUIRE(-d[1] == e[1]); REQUIRE(-d[2] == e[2]); @@ -134,12 +136,12 @@ TEMPLATE_TEST_CASE("Polynomial multiplication", "[math][ss]", FIELD_DEFS) { // (1 + 2x + 3x^2) * (5 + 3x + 3x^2 + x^3) // = 5 + 13x + 24x^2 + 16x^3 + 11x^4 + 3x^5 - math::Vec c0 = {FF(1), FF(2), FF(3)}; - math::Vec c1 = {FF(5), FF(3), FF(3), FF(1)}; - auto p = math::Polynomial::Create(c0); - auto q = math::Polynomial::Create(c1); - auto e = p.Multiply(q); - REQUIRE(e.Degree() == 5); + math::Vector c0 = {FF(1), FF(2), FF(3)}; + math::Vector c1 = {FF(5), FF(3), FF(3), FF(1)}; + auto p = math::Polynomial::create(c0); + auto q = math::Polynomial::create(c1); + auto e = p.multiply(q); + REQUIRE(e.degree() == 5); REQUIRE(e[0] == FF(5)); REQUIRE(e[1] == FF(13)); REQUIRE(e[2] == FF(24)); @@ -151,30 +153,30 @@ TEMPLATE_TEST_CASE("Polynomial multiplication", "[math][ss]", FIELD_DEFS) { TEMPLATE_TEST_CASE("Polynomial division", "[math][ss]", FIELD_DEFS) { using FF = TestType; - math::Vec c0 = {FF(1), FF(2), FF(3)}; - math::Vec c1 = {FF(5), FF(3), FF(3), FF(1)}; - auto p = math::Polynomial::Create(c0); - auto q = math::Polynomial::Create(c1); - auto e = q.Divide(p); - auto x = p.Multiply(e[0]).Add(e[1]); + math::Vector c0 = {FF(1), FF(2), FF(3)}; + math::Vector c1 = {FF(5), FF(3), FF(3), FF(1)}; + auto p = math::Polynomial::create(c0); + auto q = math::Polynomial::create(c1); + auto e = q.divide(p); + auto x = p.multiply(e[0]).add(e[1]); - REQUIRE(x.Degree() == q.Degree()); - for (std::size_t i = 0; i < x.Degree(); ++i) { + REQUIRE(x.degree() == q.degree()); + for (std::size_t i = 0; i < x.degree(); ++i) { REQUIRE(x[i] == q[i]); } math::Polynomial z; - REQUIRE_THROWS_MATCHES(p.Divide(z), + REQUIRE_THROWS_MATCHES(p.divide(z), std::invalid_argument, Catch::Matchers::Message("division by 0")); - auto prg = util::PRG::Create(); - auto c0_ = math::Vec::Random(10, prg); - auto c1_ = math::Vec::Random(9, prg); - auto a = math::Polynomial::Create(c0_); - auto b = math::Polynomial::Create(c1_); - auto qr = a.Divide(b); + auto prg = util::PRG::create(); + auto c0_ = math::Vector::random(10, prg); + auto c1_ = math::Vector::random(9, prg); + auto a = math::Polynomial::create(c0_); + auto b = math::Polynomial::create(c1_); + auto qr = a.divide(b); - auto v = b.Multiply(qr[0]).Add(qr[1]); + auto v = b.multiply(qr[0]).add(qr[1]); REQUIRE(v == a); } diff --git a/test/scl/math/test_secp256k1.cc b/test/scl/math/test_secp256k1.cc index 6e2f2c1..029dec3 100644 --- a/test/scl/math/test_secp256k1.cc +++ b/test/scl/math/test_secp256k1.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,248 +15,267 @@ * along with this program. If not, see . */ -#include +#include +#include #include #include #include "scl/math/curves/secp256k1.h" -#include "scl/math/ec_ops.h" -#include "scl/math/ff.h" -#include "scl/math/fp.h" +#include "scl/math/ec.h" #include "scl/math/number.h" +#include "scl/util/digest.h" +#include "scl/util/hash.h" #include "scl/util/prg.h" using namespace scl; -using Curve = math::EC; +using Curve = math::EC; using Scalar = Curve::ScalarField; using Field = Curve::Field; namespace { -Curve RandomPoint(util::PRG& prg) { - auto r = math::Number::Random(100, prg); - return Curve::Generator() * r; +Curve randomPoint(util::PRG& prg) { + auto r = math::Number::random(100, prg); + return Curve::generator() * r; } } // namespace TEST_CASE("Secp256k1 defs", "[math][ff]") { - REQUIRE(std::string(Field::Name()) == "secp256k1_field"); - REQUIRE(Field::ByteSize() == 32); - REQUIRE(Field::BitSize() == 256); - - REQUIRE(std::string(Scalar::Name()) == "secp256k1_order"); - REQUIRE(Scalar::ByteSize() == 32); - REQUIRE(Scalar::BitSize() == 256); - - REQUIRE(std::string(Curve::Name()) == "secp256k1"); - REQUIRE(Curve::ByteSize() == 33); - REQUIRE(Curve::ByteSize(false) == 65); - REQUIRE(Curve::BitSize() == 264); - REQUIRE(Curve::BitSize(false) == 520); + REQUIRE(std::string(Field::name()) == "secp256k1_field"); + REQUIRE(Field::byteSize() == 32); + REQUIRE(Field::bitSize() == 256); + + REQUIRE(std::string(Scalar::name()) == "secp256k1_order"); + REQUIRE(Scalar::byteSize() == 32); + REQUIRE(Scalar::bitSize() == 256); + + REQUIRE(std::string(Curve::name()) == "secp256k1"); + REQUIRE(Curve::byteSize(true) == 33); + REQUIRE(Curve::byteSize(false) == 65); + REQUIRE(Curve::bitSize(true) == 264); + REQUIRE(Curve::bitSize(false) == 520); } TEST_CASE("Secp256k1 field to string", "[math][ff]") { - REQUIRE(Field(0).ToString() == "0"); + REQUIRE(Field(0).toString() == "0"); - auto prg = util::PRG::Create("Secp256k1 field"); - auto x = Field::Random(prg); - REQUIRE(x.ToString() == + auto prg = util::PRG::create("Secp256k1 field"); + auto x = Field::random(prg); + REQUIRE(x.toString() == "62883be8479ee8f4a3367086d0044440bc7505bc2a2b099e3f71f131eedd42d7"); } TEST_CASE("Secp256k1 field from string", "[math][ff]") { - Field y = Field::FromString("cafe"); + Field y = Field::fromString("cafe"); REQUIRE(y == Field(51966)); - REQUIRE(y.ToString() == "cafe"); + REQUIRE(y.toString() == "cafe"); - Field x = Field::FromString( + Field x = Field::fromString( "fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2a"); - REQUIRE(x.ToString() == + REQUIRE(x.toString() == "fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2a"); std::stringstream ss; ss << x; - REQUIRE(ss.str() == x.ToString()); + REQUIRE(ss.str() == x.toString()); - Field twoPast = Field::FromString( + Field twoPast = Field::fromString( "fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc31"); REQUIRE(twoPast == Field(2)); - REQUIRE(Field::FromString("016") == Field::FromString("16")); + REQUIRE(Field::fromString("016") == Field::fromString("16")); REQUIRE_THROWS_MATCHES( - Field::FromString("ffffffffffffffffffffffffffffffffffffffffffffffffffff" + Field::fromString("ffffffffffffffffffffffffffffffffffffffffffffffffffff" "fffffffffffff"), std::invalid_argument, Catch::Matchers::Message("hex string too large to parse")); } TEST_CASE("Secp256k1 from affine", "[math][ec]") { - auto x = Field::FromString( + auto x = Field::fromString( "e47b4a1c2e13cf0e97c9adf5a645ce388e04317b7830401aabb42e188c9883fa"); - auto y = Field::FromString( + auto y = Field::fromString( "2aafa6e870684327ec92006e6c601a8b6e0fb9ff06ae120cb330a2eee86009ff"); - auto g = Curve::FromAffine(x, y); - REQUIRE(!g.PointAtInfinity()); + auto g = Curve::fromAffine(x, y); + REQUIRE(!g.isPointAtInfinity()); - auto as_affine = g.ToAffine(); + auto as_affine = g.toAffine(); REQUIRE(as_affine[0] == x); REQUIRE(as_affine[1] == y); REQUIRE_THROWS_MATCHES( - Curve::FromAffine(Field(0), Field(0)), + Curve::fromAffine(Field(0), Field(0)), std::invalid_argument, Catch::Matchers::Message("provided (x, y) not on curve")); } TEST_CASE("Secp256k1 point-at-infinity", "[math][ec]") { Curve p; - REQUIRE(p.PointAtInfinity()); + REQUIRE(p.isPointAtInfinity()); } TEST_CASE("Secp256k1 generator", "[math][ec]") { - auto g = Curve::Generator(); + auto g = Curve::generator(); REQUIRE( - g.ToString() == + g.toString() == "EC{79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798, " "483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8}"); std::stringstream ss; ss << g; - REQUIRE(ss.str() == g.ToString()); + REQUIRE(ss.str() == g.toString()); - auto ord = math::Order(); + auto ord = math::order(); - REQUIRE(!g.PointAtInfinity()); + REQUIRE(!g.isPointAtInfinity()); auto poi = g * ord; - REQUIRE(poi.PointAtInfinity()); + REQUIRE(poi.isPointAtInfinity()); auto not_poi = g * (ord - math::Number(1)); REQUIRE(poi != not_poi); - REQUIRE(!not_poi.PointAtInfinity()); + REQUIRE(!not_poi.isPointAtInfinity()); - REQUIRE(poi.ToString() == "EC{POINT_AT_INFINITY}"); + REQUIRE(poi.toString() == "EC{POINT_AT_INFINITY}"); } TEST_CASE("Secp256k1 addition", "[math][ec]") { - auto prg = util::PRG::Create("Secp256k1 addition"); + auto prg = util::PRG::create("Secp256k1 addition"); - auto a = RandomPoint(prg); - auto b = RandomPoint(prg); + auto a = randomPoint(prg); + auto b = randomPoint(prg); REQUIRE(a != b); auto c = a + b; REQUIRE(a != c); REQUIRE(b != c); auto d = a + a; - REQUIRE(d == a.Double()); + REQUIRE(d == a.doublePoint()); a += b; REQUIRE(a == c); REQUIRE(a != b); - REQUIRE((c - a).PointAtInfinity()); + REQUIRE((c - a).isPointAtInfinity()); + + auto x = randomPoint(prg); + auto y = randomPoint(prg); + auto z = x + y; + x.normalize(); + REQUIRE(x + y == y + x); + REQUIRE(z == x + y); } TEST_CASE("Secp256k1 negation", "[math][ec]") { - auto prg = util::PRG::Create("Secp256k1 negation"); + auto prg = util::PRG::create("Secp256k1 negation"); - auto a = RandomPoint(prg); + auto a = randomPoint(prg); auto b = -a; - REQUIRE((a + b).PointAtInfinity()); + REQUIRE((a + b).isPointAtInfinity()); } TEST_CASE("Secp256k1 scalar multiplication", "[math][ec]") { - auto prg = util::PRG::Create("Secp256k1 scalar-mul"); + auto prg = util::PRG::create("Secp256k1 scalar-mul"); - auto a = RandomPoint(prg); + auto a = randomPoint(prg); auto p_minus_1 = Scalar(-1); auto c = a * p_minus_1; - REQUIRE(!c.PointAtInfinity()); - REQUIRE((c + a).PointAtInfinity()); + REQUIRE(!c.isPointAtInfinity()); + REQUIRE((c + a).isPointAtInfinity()); - auto x = Scalar::Random(prg); - auto y = Scalar::Random(prg); + auto x = Scalar::random(prg); + auto y = Scalar::random(prg); REQUIRE((x + y) * a == x * a + y * a); - auto G = Curve::Generator(); + auto G = Curve::generator(); - auto v = Scalar::FromString("03"); - auto u = Scalar::FromString("02"); - auto w = Scalar::FromString("06"); + auto v = Scalar::fromString("03"); + auto u = Scalar::fromString("02"); + auto w = Scalar::fromString("06"); auto P = G * w; auto Q = (G * v) * u; REQUIRE(P == Q); - auto n = math::Number::FromString("06"); + auto n = math::Number::fromString("06"); REQUIRE(n * G == w * G); REQUIRE(n * G == G * n); } TEST_CASE("Secp256k1 negation special case", "[math][ec]") { - using CurveT = math::Secp256k1; - CurveT::ValueType point = {Field(1), Field(0), Field(1)}; - REQUIRE(!math::CurveIsPointAtInfinity(point)); - math::CurveNegate(point); - REQUIRE(math::CurveIsPointAtInfinity(point)); + Curve P; + P.negate(); + REQUIRE(P.isPointAtInfinity()); } TEST_CASE("Secp256k1 serialization", "[math][ec]") { - auto prg = util::PRG::Create(); + auto prg = util::PRG::create(); - REQUIRE(Curve::ByteSize(false) == 32 + 32 + 1); - REQUIRE(Curve::ByteSize(true) == 32 + 1); + REQUIRE(Curve::byteSize(false) == 32 + 32 + 1); + REQUIRE(Curve::byteSize(true) == 32 + 1); - auto a = RandomPoint(prg); - auto buffer = std::make_unique(Curve::ByteSize(false)); - a.Write(buffer.get(), false); + auto a = randomPoint(prg); + auto buffer = std::make_unique(Curve::byteSize(false)); + a.write(buffer.get(), false); REQUIRE(buffer[0] == 0x04); - auto c = Curve::Read(buffer.get()); + auto c = Curve::read(buffer.get()); REQUIRE(a == c); - a.Write(buffer.get(), true); - auto d = Curve::Read(buffer.get()); + a.write(buffer.get(), true); + auto d = Curve::read(buffer.get()); REQUIRE(buffer[0] == 0x01); REQUIRE(a == d); Curve poi; - poi.Write(buffer.get(), false); + poi.write(buffer.get(), false); REQUIRE(buffer[0] == 0x06); - auto e = Curve::Read(buffer.get()); - REQUIRE(e.PointAtInfinity()); + auto e = Curve::read(buffer.get()); + REQUIRE(e.isPointAtInfinity()); - poi.Write(buffer.get(), true); + poi.write(buffer.get(), true); REQUIRE(buffer[0] == 0x02); - auto f = Curve::Read(buffer.get()); - REQUIRE(f.PointAtInfinity()); + auto f = Curve::read(buffer.get()); + REQUIRE(f.isPointAtInfinity()); - auto g = Curve::FromAffine( - Field::FromString("e47b4a1c2e13cf0e97c9adf5a645ce388e04317b7830401aabb4" + auto g = Curve::fromAffine( + Field::fromString("e47b4a1c2e13cf0e97c9adf5a645ce388e04317b7830401aabb4" "2e188c9883fa"), // - Field::FromString("2aafa6e870684327ec92006e6c601a8b6e0fb9ff06ae120cb330" + Field::fromString("2aafa6e870684327ec92006e6c601a8b6e0fb9ff06ae120cb330" "a2eee86009ff") // ); - g.Write(buffer.get()); + g.write(buffer.get(), true); REQUIRE(buffer[0] == 0x01); - auto h = Curve::Read(buffer.get()); + auto h = Curve::read(buffer.get()); REQUIRE(h == g); - auto i = Curve::FromAffine( - Field::FromString("b2d352841ef12627042948c3b3d4ed822fc99a4643d446f8ab9b" + auto i = Curve::fromAffine( + Field::fromString("b2d352841ef12627042948c3b3d4ed822fc99a4643d446f8ab9b" "de5aa5f63d36"), // - Field::FromString("f1a6bd63f76bb38cf80a1d88da5167c1b102288dbab9b04c210b" + Field::fromString("f1a6bd63f76bb38cf80a1d88da5167c1b102288dbab9b04c210b" "d9863d83d0e3") // ); - i.Write(buffer.get()); - auto j = Curve::Read(buffer.get()); + i.write(buffer.get(), true); + auto j = Curve::read(buffer.get()); REQUIRE(i == j); } TEST_CASE("Secp256k1 order", "[math]") { - auto ord = math::Order(); + auto ord = math::order(); REQUIRE( ord == - math::Number::FromString( + math::Number::fromString( "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F")); } + +TEST_CASE("Secp256k1 hashing", "[math]") { + util::Hash<256> hgen; + + const auto digest_gen = hgen.update(Curve::generator()).finalize(); + REQUIRE(util::digestToString(digest_gen) == + "3f0db2047deb5c2c92e336aecdd4ba1d745fcfd0e77a5f8592dda348a3ff5707"); + + util::Hash<256> hpoi; + const auto digest_poi = hpoi.update(Curve{}).finalize(); + REQUIRE(util::digestToString(digest_poi) == + "4fdfe4c2be45edb360dab48435c14be84e087c162cbb421d8c91a3c99e31a82f"); +} diff --git a/test/scl/math/test_vec.cc b/test/scl/math/test_vector.cc similarity index 53% rename from test/scl/math/test_vec.cc rename to test/scl/math/test_vector.cc index b515df2..65465ee 100644 --- a/test/scl/math/test_vec.cc +++ b/test/scl/math/test_vector.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -16,7 +16,8 @@ */ #include -#include +#include +#include #include #include #include @@ -24,17 +25,16 @@ #include "scl/math/curves/secp256k1.h" #include "scl/math/ec.h" #include "scl/math/fp.h" -#include "scl/math/mat.h" -#include "scl/math/vec.h" -#include "scl/util/traits.h" +#include "scl/math/matrix.h" +#include "scl/math/vector.h" using namespace scl; using FF = math::Fp<61>; -using Vec = math::Vec; +using Vector = math::Vector; -auto v0 = Vec{FF(1), FF(2), FF(3)}; -auto v1 = Vec{FF(2), FF(123), FF(5)}; +auto v0 = Vector{FF(1), FF(2), FF(3)}; +auto v1 = Vector{FF(2), FF(123), FF(5)}; TEST_CASE("Vector access", "[math][la]") { REQUIRE(v0[0] == FF(1)); @@ -49,84 +49,84 @@ TEST_CASE("Vector mutate", "[math][la]") { } TEST_CASE("Vector size", "[math][la]") { - REQUIRE(v0.Size() == 3); - REQUIRE(v1.Size() == 3); - auto v3 = Vec(100); - REQUIRE(v3.Size() == 100); + REQUIRE(v0.size() == 3); + REQUIRE(v1.size() == 3); + auto v3 = Vector(100); + REQUIRE(v3.size() == 100); } TEST_CASE("Vector addition", "[math][la]") { - auto v2 = v0.Add(v1); + auto v2 = v0.add(v1); REQUIRE(v2[0] == FF(3)); REQUIRE(v2[1] == FF(125)); REQUIRE(v2[2] == FF(8)); - v2.AddInPlace(v0); - REQUIRE(v2.Equals(v0.Add(v1).Add(v0))); + v2.addInPlace(v0); + REQUIRE(v2.equals(v0.add(v1).add(v0))); } TEST_CASE("Vector subtract", "[math][la]") { - auto v2 = v0.Subtract(v1); - REQUIRE(v2.Equals(Vec{FF(1) - FF(2), FF(2) - FF(123), FF(3) - FF(5)})); - v2.SubtractInPlace(v1); - REQUIRE(v2.Equals(v0.Subtract(v1).Subtract(v1))); + auto v2 = v0.subtract(v1); + REQUIRE(v2.equals(Vector{FF(1) - FF(2), FF(2) - FF(123), FF(3) - FF(5)})); + v2.subtractInPlace(v1); + REQUIRE(v2.equals(v0.subtract(v1).subtract(v1))); } TEST_CASE("Vector multiply entry-wise", "[math][la]") { - auto v2 = v0.MultiplyEntryWise(v1); - REQUIRE(v2.Equals(Vec{FF(2), FF(246), FF(15)})); - v2.MultiplyEntryWiseInPlace(v1); - REQUIRE(v2.Equals(v0.MultiplyEntryWise(v1).MultiplyEntryWise(v1))); + auto v2 = v0.multiplyEntryWise(v1); + REQUIRE(v2.equals(Vector{FF(2), FF(246), FF(15)})); + v2.multiplyEntryWiseInPlace(v1); + REQUIRE(v2.equals(v0.multiplyEntryWise(v1).multiplyEntryWise(v1))); } TEST_CASE("Vector dot", "[math][la]") { - auto dp = v0.Dot(v1); + auto dp = v0.dot(v1); REQUIRE(dp == FF(263)); } TEST_CASE("Vector scalar multiplication", "[math][la]") { - auto v2 = v1.ScalarMultiply(FF(2)); - REQUIRE(v2.Equals(Vec{FF(4), FF(246), FF(10)})); - v2.ScalarMultiplyInPlace(FF(2)); - REQUIRE(v2.Equals(Vec{FF(8), FF(492), FF(20)})); + auto v2 = v1.scalarMultiply(FF(2)); + REQUIRE(v2.equals(Vector{FF(4), FF(246), FF(10)})); + v2.scalarMultiplyInPlace(FF(2)); + REQUIRE(v2.equals(Vector{FF(8), FF(492), FF(20)})); } TEST_CASE("Vector to matrix", "[math][la]") { - auto m0 = v0.ToRowMatrix(); - REQUIRE(m0.Rows() == 1); - REQUIRE(m0.Cols() == 3); - auto m1 = v1.ToColumnMatrix(); - REQUIRE(m1.Rows() == 3); - REQUIRE(m1.Cols() == 1); + auto m0 = v0.toRowMatrix(); + REQUIRE(m0.rows() == 1); + REQUIRE(m0.cols() == 3); + auto m1 = v1.toColumnMatrix(); + REQUIRE(m1.rows() == 3); + REQUIRE(m1.cols() == 1); } TEST_CASE("Vector to string", "[math][la]") { - REQUIRE(v0.ToString() == "[1, 2, 3]"); - REQUIRE(v1.ToString() == "[2, 7b, 5]"); + REQUIRE(v0.toString() == "[1, 2, 3]"); + REQUIRE(v1.toString() == "[2, 7b, 5]"); std::stringstream ss; ss << v0; REQUIRE(ss.str() == "[1, 2, 3]"); - Vec v; - REQUIRE(v.ToString() == "[ EMPTY VECTOR ]"); + Vector v; + REQUIRE(v.toString() == "[ EMPTY VECTOR ]"); } TEST_CASE("Vector incompatible", "[math][la]") { - auto v2 = Vec{FF(2), FF(3)}; - REQUIRE(!v2.Equals(v1)); - REQUIRE_THROWS_MATCHES(v2.Add(v1), + auto v2 = Vector{FF(2), FF(3)}; + REQUIRE(!v2.equals(v1)); + REQUIRE_THROWS_MATCHES(v2.add(v1), std::invalid_argument, Catch::Matchers::Message("Vec sizes mismatch")); } TEST_CASE("Vector to std::vector", "[math][la]") { - auto stl0 = v0.ToStlVector(); + auto stl0 = v0.toStlVector(); REQUIRE(stl0 == std::vector{FF(1), FF(2), FF(3)}); } TEST_CASE("Vector random", "[math][la]") { - auto prg = util::PRG::Create("Vector random"); - auto r = Vec::Random(3, prg); + auto prg = util::PRG::create("Vector random"); + auto r = Vector::random(3, prg); auto zero = FF(); - REQUIRE(r.Size() == 3); + REQUIRE(r.size() == 3); REQUIRE(r[0] != zero); REQUIRE(r[0] != v0[0]); REQUIRE(r[1] != zero); @@ -136,20 +136,20 @@ TEST_CASE("Vector random", "[math][la]") { } TEST_CASE("Vector range", "[math][la]") { - auto v = Vec::Range(1, 4); + auto v = Vector::range(1, 4); REQUIRE(v[0] == FF(1)); REQUIRE(v[1] == FF(2)); REQUIRE(v[2] == FF(3)); - REQUIRE(Vec::Range(1, 1) == Vec{}); + REQUIRE(Vector::range(1, 1).empty()); - REQUIRE_THROWS_MATCHES(Vec::Range(2, 1), + REQUIRE_THROWS_MATCHES(Vector::range(2, 1), std::invalid_argument, Catch::Matchers::Message("invalid range")); } TEST_CASE("Vector iterator", "[math][la]") { - auto v2 = Vec{FF(1), FF(2), FF(3)}; + auto v2 = Vector{FF(1), FF(2), FF(3)}; std::size_t i = 0; for (auto& v : v0) { REQUIRE(v == v2[i++]); @@ -158,40 +158,46 @@ TEST_CASE("Vector iterator", "[math][la]") { auto count = std::count(v2.begin(), v2.end(), FF(2)); REQUIRE(count == 1); - auto v3 = Vec(v2.begin(), v2.end()); - REQUIRE(v3.Equals(v2)); + auto v3 = Vector(v2.begin(), v2.end()); + REQUIRE(v3.equals(v2)); } TEST_CASE("Vector sub vector", "[math][la]") { - auto v = Vec{FF(1), FF(2), FF(3), FF(4)}; - REQUIRE(v.SubVector(1, 2) == Vec{FF(2)}); - REQUIRE(v.SubVector(1, 3) == Vec{FF(2), FF(3)}); - REQUIRE(v.SubVector(1, 1) == Vec{}); - REQUIRE(v.SubVector(2) == Vec{FF(1), FF(2)}); + auto v = Vector{FF(1), FF(2), FF(3), FF(4)}; + REQUIRE(v.subVector(1, 2) == Vector{FF(2)}); + REQUIRE(v.subVector(1, 3) == Vector{FF(2), FF(3)}); + REQUIRE(v.subVector(1, 1).empty()); + REQUIRE(v.subVector(2) == Vector{FF(1), FF(2)}); - REQUIRE_THROWS_MATCHES(v.SubVector(2, 1), + REQUIRE_THROWS_MATCHES(v.subVector(2, 1), std::logic_error, Catch::Matchers::Message("invalid range")); } TEST_CASE("Vector scalar EC", "[math]") { - using Curve = math::EC; + using Curve = math::EC; - auto v = math::Vec{Curve::Generator(), - Curve::Generator(), - Curve::Generator()}; + auto v = math::Vector{Curve::generator(), + Curve::generator(), + Curve::generator()}; const auto s = Curve::ScalarField(123); - auto w = v.ScalarMultiply(s); + auto w = v.scalarMultiply(s); - REQUIRE(w[0] == Curve::Generator() * s); - REQUIRE(w[1] == Curve::Generator() * s); - REQUIRE(w[2] == Curve::Generator() * s); + REQUIRE(w[0] == Curve::generator() * s); + REQUIRE(w[1] == Curve::generator() * s); + REQUIRE(w[2] == Curve::generator() * s); const auto z = math::Number(123); - auto u = w.ScalarMultiply(math::Number(123)); + auto u = w.scalarMultiply(math::Number(123)); REQUIRE(u[0] == w[0] * z); REQUIRE(u[1] == w[1] * z); REQUIRE(u[2] == w[2] * z); } + +TEST_CASE("Vector byte size", "[math]") { + auto v = Vector{FF(1), FF(2)}; + + REQUIRE(v.byteSize() == 2 * FF::byteSize()); +} diff --git a/test/scl/math/test_z2k.cc b/test/scl/math/test_z2k.cc index fbf985a..ee1ecab 100644 --- a/test/scl/math/test_z2k.cc +++ b/test/scl/math/test_z2k.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,7 +15,9 @@ * along with this program. If not, see . */ -#include +#include +#include +#include #include #include "scl/math/z2k.h" @@ -32,83 +34,83 @@ using Big = math::Z2k<123>; TEMPLATE_TEST_CASE("Z2k name", "[math][ring]", RING_DEFS) { using Ring = TestType; - REQUIRE(Ring::Name() == std::string("Z2k")); + REQUIRE(Ring::name() == std::string("Z2k")); } TEST_CASE("Z2k big size", "[math][ring]") { - REQUIRE(Big::BitSize() == 123); - REQUIRE(Big::ByteSize() == 16); + REQUIRE(Big::bitSize() == 123); + REQUIRE(Big::byteSize() == 16); } TEST_CASE("Z2k small size", "[math][ring]") { - REQUIRE(Small::BitSize() == 62); - REQUIRE(Small::ByteSize() == 8); + REQUIRE(Small::bitSize() == 62); + REQUIRE(Small::byteSize() == 8); } TEMPLATE_TEST_CASE("Z2k addition", "[math][ring]", RING_DEFS) { using Ring = TestType; - auto prg = util::PRG::Create("Z2k addition"); + auto prg = util::PRG::create("Z2k addition"); - auto a = Ring::Random(prg); - auto b = Ring::Random(prg); + auto a = Ring::random(prg); + auto b = Ring::random(prg); auto c = a + b; REQUIRE(c != a); REQUIRE(c != b); REQUIRE(c == b + a); a += b; REQUIRE(c == a); - REQUIRE(a + Ring::Zero() == a); + REQUIRE(a + Ring::zero() == a); } TEMPLATE_TEST_CASE("Z2k negation", "[math][ring]", RING_DEFS) { using Ring = TestType; - auto prg = util::PRG::Create("Z2k negation"); + auto prg = util::PRG::create("Z2k negation"); - auto a = Ring::Random(prg); - auto a_negated = a.Negated(); + auto a = Ring::random(prg); + auto a_negated = a.negated(); REQUIRE(a != a_negated); - REQUIRE(a + a_negated == Ring::Zero()); + REQUIRE(a + a_negated == Ring::zero()); REQUIRE(a_negated == -a); - a.Negate(); + a.negate(); REQUIRE(a == a_negated); } TEMPLATE_TEST_CASE("Z2k subtraction", "[math][ring]", RING_DEFS) { using Ring = TestType; - auto prg = util::PRG::Create("Z2k subtraction"); + auto prg = util::PRG::create("Z2k subtraction"); - auto a = Ring::Random(prg); - auto b = Ring::Random(prg); + auto a = Ring::random(prg); + auto b = Ring::random(prg); auto c = a - b; REQUIRE(c == -(b - a)); a -= b; REQUIRE(c == a); - REQUIRE(c - c == Ring::Zero()); - REQUIRE(c - Ring::Zero() == c); + REQUIRE(c - c == Ring::zero()); + REQUIRE(c - Ring::zero() == c); } TEMPLATE_TEST_CASE("Z2k multiplication", "[math][ring]", RING_DEFS) { using Ring = TestType; - auto prg = util::PRG::Create("Z2k multiplication"); + auto prg = util::PRG::create("Z2k multiplication"); - auto a = Ring::Random(prg); - auto b = Ring::Random(prg); + auto a = Ring::random(prg); + auto b = Ring::random(prg); REQUIRE(a * b == b * a); - auto c = Ring::Random(prg); + auto c = Ring::random(prg); REQUIRE(c * (a + b) == c * a + c * b); auto d = a * b; a *= b; REQUIRE(a == d); - REQUIRE(a * Ring::One() == a); + REQUIRE(a * Ring::one() == a); } namespace { -template -T RandomInvertible(util::PRG& prg) { - T z; - while (z.Lsb() == 0) { - z = T::Random(prg); +template +RING randomInvertible(util::PRG& prg) { + RING z; + while (z.lsb() == 0) { + z = RING::random(prg); } return z; } @@ -116,26 +118,26 @@ T RandomInvertible(util::PRG& prg) { TEMPLATE_TEST_CASE("Z2k inverses", "[math][ring]", RING_DEFS) { using Ring = TestType; - auto prg = util::PRG::Create("Z2k inverses"); + auto prg = util::PRG::create("Z2k inverses"); - auto a = RandomInvertible(prg); - auto a_inverse = a.Inverse(); + auto a = randomInvertible(prg); + auto a_inverse = a.inverse(); REQUIRE(a * a_inverse == TestType(1)); REQUIRE_THROWS_MATCHES( - Ring::Zero().Inverse(), + Ring::zero().inverse(), std::logic_error, Catch::Matchers::Message("value not invertible modulo 2^K")); } TEMPLATE_TEST_CASE("Z2k division", "[math][ring]", RING_DEFS) { using Ring = TestType; - auto prg = util::PRG::Create("Z2k division"); + auto prg = util::PRG::create("Z2k division"); - auto a = RandomInvertible(prg); - auto b = RandomInvertible(prg); + auto a = randomInvertible(prg); + auto b = randomInvertible(prg); REQUIRE(a / a == TestType(1)); - REQUIRE(a / b == (b / a).Inverse()); + REQUIRE(a / b == (b / a).inverse()); auto c = a / b; a /= b; REQUIRE(c == a); @@ -143,12 +145,12 @@ TEMPLATE_TEST_CASE("Z2k division", "[math][ring]", RING_DEFS) { TEMPLATE_TEST_CASE("Z2k serialization", "[math][ring]", RING_DEFS) { using Ring = TestType; - auto prg = util::PRG::Create("Z2k serialization"); + auto prg = util::PRG::create("Z2k serialization"); - auto a = Ring::Random(prg); - unsigned char buffer[TestType::ByteSize()]; - a.Write(buffer); - auto b = Ring::Read(buffer); + auto a = Ring::random(prg); + unsigned char buffer[TestType::byteSize()]; + a.write(buffer); + auto b = Ring::read(buffer); REQUIRE(a == b); } @@ -156,7 +158,7 @@ TEMPLATE_TEST_CASE("Z2k to string", "[math][ring]", RING_DEFS) { using Ring = TestType; Ring x(0x7b); - REQUIRE(x.ToString() == "7b"); + REQUIRE(x.toString() == "7b"); std::stringstream ss; ss << x; REQUIRE(ss.str() == "7b"); @@ -170,14 +172,14 @@ TEST_CASE("Z2k truncation", "[math]") { REQUIRE(a == b); - unsigned char buffer_a[Z2k::ByteSize() + 2] = {0}; - unsigned char buffer_b[Z2k::ByteSize() + 2] = {0}; + unsigned char buffer_a[Z2k::byteSize() + 2] = {0}; + unsigned char buffer_b[Z2k::byteSize() + 2] = {0}; buffer_a[4] = 0xff; buffer_a[5] = 0xff; buffer_b[4] = 0xff; buffer_b[5] = 0xff; - a.Write(buffer_a); - b.Write(buffer_b); + a.write(buffer_a); + b.write(buffer_b); REQUIRE(buffer_a[0] == buffer_b[0]); REQUIRE(buffer_a[1] == buffer_b[1]); @@ -189,8 +191,8 @@ TEST_CASE("Z2k truncation", "[math]") { REQUIRE(buffer_b[4] == 0xff); REQUIRE(buffer_b[5] == 0xff); - REQUIRE(a.ToString() == "abcdef11"); - REQUIRE(b.ToString() == "abcdef11"); + REQUIRE(a.toString() == "abcdef11"); + REQUIRE(b.toString() == "abcdef11"); std::stringstream ss_a; std::stringstream ss_b; @@ -198,7 +200,7 @@ TEST_CASE("Z2k truncation", "[math]") { ss_b << b; REQUIRE(ss_a.str() == ss_b.str()); - Z2k c = Z2k::FromString("34abcdef11"); + Z2k c = Z2k::fromString("34abcdef11"); REQUIRE(c == a); REQUIRE(c == b); } diff --git a/test/scl/net/test_channel.cc b/test/scl/net/test_channel.cc deleted file mode 100644 index e7c7ea0..0000000 --- a/test/scl/net/test_channel.cc +++ /dev/null @@ -1,65 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#include -#include - -#include "scl/math/fp.h" -#include "scl/net/channel.h" -#include "scl/net/mem_channel.h" - -using namespace scl; - -struct DummyChannel final : public net::Channel { - void Close() override {} - - void Send(const unsigned char* src, std::size_t n) override { - (void)src; - (void)n; - } - - std::size_t Recv(unsigned char* dst, std::size_t n) override { - (void)dst; - recv++; - return n; - } - - bool HasData() override { - has_data++; - return there_is_data; - } - - std::size_t recv = 0; - std::size_t has_data = 0; - bool there_is_data = false; -}; - -TEST_CASE("Channel non-block recv", "[net]") { - std::unique_ptr chl = std::make_unique(); - DummyChannel* dc = static_cast(chl.get()); - - auto p1 = chl->Recv(false); - REQUIRE(!p1.has_value()); - REQUIRE(dc->has_data == 1); - REQUIRE(dc->recv == 0); - - dc->there_is_data = true; - auto p2 = chl->Recv(false); - REQUIRE(p2.has_value()); - REQUIRE(dc->has_data == 2); - REQUIRE(dc->recv == 2); -} diff --git a/test/scl/net/test_config.cc b/test/scl/net/test_config.cc index 41aea52..cdc5ec9 100644 --- a/test/scl/net/test_config.cc +++ b/test/scl/net/test_config.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,7 +15,8 @@ * along with this program. If not, see . */ -#include +#include +#include #include #include @@ -25,11 +26,11 @@ using namespace scl; TEST_CASE("Config read from file", "[net]") { const auto* filename = SCL_TEST_DATA_DIR "3_parties.txt"; - auto cfg = net::NetworkConfig::Load(0, filename); + auto cfg = net::NetworkConfig::load(0, filename); - REQUIRE(cfg.NetworkSize() == 3); - REQUIRE(cfg.Id() == 0); - auto parties = cfg.Parties(); + REQUIRE(cfg.networkSize() == 3); + REQUIRE(cfg.id() == 0); + auto parties = cfg.parties(); REQUIRE(parties[0].hostname == "1.2.3.4"); REQUIRE(parties[0].port == 8000); REQUIRE(parties[1].hostname == "2.3.4.5"); @@ -38,44 +39,38 @@ TEST_CASE("Config read from file", "[net]") { REQUIRE(parties[2].port == 3000); std::string invalid_empty = SCL_TEST_DATA_DIR "invalid_no_entries.txt"; - REQUIRE_THROWS_MATCHES(net::NetworkConfig::Load(0, invalid_empty), + REQUIRE_THROWS_MATCHES(net::NetworkConfig::load(0, invalid_empty), std::invalid_argument, Catch::Matchers::Message("n cannot be zero")); std::string valid = SCL_TEST_DATA_DIR "3_parties.txt"; - REQUIRE_THROWS_MATCHES(net::NetworkConfig::Load(4, valid), + REQUIRE_THROWS_MATCHES(net::NetworkConfig::load(4, valid), std::invalid_argument, Catch::Matchers::Message("invalid id")); std::string invalid_entry = SCL_TEST_DATA_DIR "invalid_entry.txt"; REQUIRE_THROWS_MATCHES( - net::NetworkConfig::Load(0, invalid_entry), + net::NetworkConfig::load(0, invalid_entry), std::invalid_argument, Catch::Matchers::Message("invalid entry in config file")); std::string invalid_non_existing_file; - REQUIRE_THROWS_MATCHES(net::NetworkConfig::Load(0, invalid_non_existing_file), + REQUIRE_THROWS_MATCHES(net::NetworkConfig::load(0, invalid_non_existing_file), std::invalid_argument, Catch::Matchers::Message("could not open file")); } TEST_CASE("Config configure all parties local", "[net]") { - auto cfg = net::NetworkConfig::Localhost(0, 5); - REQUIRE(cfg.Id() == 0); - REQUIRE(cfg.NetworkSize() == 5); + auto cfg = net::NetworkConfig::localhost(0, 5); + REQUIRE(cfg.id() == 0); + REQUIRE(cfg.networkSize() == 5); std::size_t i = 0; - for (const auto& ci : cfg.Parties()) { + for (const auto& ci : cfg.parties()) { REQUIRE(ci.port == DEFAULT_PORT_OFFSET + i++); REQUIRE(ci.hostname == "127.0.0.1"); } } -TEST_CASE("Config to string", "[net]") { - net::NetworkConfig cfg(1, {{0, "1.2.3.4", 123}, {1, "4.4.4.4", 444}}); - std::string expected = "[id=1, {0, 1.2.3.4, 123}, {1, 4.4.4.4, 444}]"; - REQUIRE(cfg.ToString() == expected); -} - TEST_CASE("Config validation", "[net]") { REQUIRE_THROWS_MATCHES( net::NetworkConfig(2, {{0, "1.2.3.4", 123}, {1, "4.4.4.4", 444}}), diff --git a/test/scl/net/test_loopback.cc b/test/scl/net/test_loopback.cc new file mode 100644 index 0000000..4ded928 --- /dev/null +++ b/test/scl/net/test_loopback.cc @@ -0,0 +1,72 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#include +#include +#include +#include +#include + +#include "scl/coro/batch.h" +#include "scl/coro/coroutine.h" +#include "scl/math/fp.h" +#include "scl/math/vector.h" +#include "scl/net/loopback.h" +#include "scl/util/prg.h" +#include "util.h" + +using namespace scl; + +TEST_CASE("Loopback to self close", "[net]") { + auto channel = net::LoopbackChannel::create(); + + net::Packet p; + p << 1 << 2 << 3; + + auto rt = coro::DefaultRuntime::create(); + + rt->run(channel->send(p)); + auto received = rt->run(channel->recv()); + + REQUIRE(received.read() == 1); + REQUIRE(received.read() == 2); + REQUIRE(received.read() == 3); +} + +TEST_CASE("Loopback send/recv", "[net]") { + auto channels = net::LoopbackChannel::createPaired(); + auto chl0 = channels[0]; + auto chl1 = channels[1]; + + net::Packet p; + p << 1 << 2 << 3; + + auto rt = coro::DefaultRuntime::create(); + + rt->run(chl0->send(p)); + auto received = rt->run(chl1->recv()); + + REQUIRE(received.read() == 1); + REQUIRE(received.read() == 2); + REQUIRE(received.read() == 3); + + rt->run(chl1->send(std::move(p))); + auto received1 = rt->run(chl0->recv()); + REQUIRE(received1.read() == 1); + REQUIRE(received1.read() == 2); + REQUIRE(received1.read() == 3); +} diff --git a/test/scl/net/test_mem_channel.cc b/test/scl/net/test_mem_channel.cc deleted file mode 100644 index aef7490..0000000 --- a/test/scl/net/test_mem_channel.cc +++ /dev/null @@ -1,106 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#include -#include -#include -#include - -#include "scl/math/fp.h" -#include "scl/math/vec.h" -#include "scl/net/mem_channel.h" -#include "scl/util/prg.h" -#include "util.h" - -using namespace scl; - -TEST_CASE("MemoryBackedChannel close", "[net]") { - auto channel = net::MemoryBackedChannel::CreateLoopback(); - channel->Close(); -} - -TEST_CASE("MemoryBackedChannel send/recv", "[net]") { - auto channels = net::MemoryBackedChannel::CreatePaired(); - auto chl0 = channels[0]; - auto chl1 = channels[1]; - - auto prg = util::PRG::Create(); - unsigned char data_in[200] = {0}; - prg.Next(data_in, 200); - - unsigned char data_out[200] = {0}; - REQUIRE(!chl1->HasData()); - chl0->Send(data_in, 200); - REQUIRE(!chl0->HasData()); - REQUIRE(chl1->HasData()); - chl1->Recv(data_out, 200); - REQUIRE(test::BufferEquals(data_in, data_out, 200)); -} - -TEST_CASE("MemoryBackedChannel send chunked", "[net]") { - auto channels = net::MemoryBackedChannel::CreatePaired(); - auto chl0 = channels[0]; - auto chl1 = channels[1]; - - auto prg = util::PRG::Create(); - unsigned char data_in[200] = {0}; - unsigned char data_out[200] = {0}; - - prg.Next(data_in, 200); - - chl0->Send(data_in, 50); - chl0->Send(data_in + 50, 50); - chl0->Send(data_in + 100, 100); - chl1->Recv(data_out, 200); - - REQUIRE(test::BufferEquals(data_in, data_out, 200)); -} - -TEST_CASE("MemoryBackedChannel recv chunked", "[net]") { - auto channels = net::MemoryBackedChannel::CreatePaired(); - auto chl0 = channels[0]; - auto chl1 = channels[1]; - - auto prg = util::PRG::Create(); - unsigned char data_in[200] = {0}; - unsigned char data_out[200] = {0}; - - prg.Next(data_in, 200); - - chl0->Send(data_in, 100); - chl0->Send(data_in + 100, 100); - chl1->Recv(data_out, 100); - chl1->Recv(data_out + 100, 100); - - REQUIRE(test::BufferEquals(data_in, data_out, 200)); -} - -TEST_CASE("MemoryBackedChannel send to self", "[net]") { - auto c = net::MemoryBackedChannel::CreateLoopback(); - unsigned char data_in[200] = {0}; - - c->Send(data_in, 20); - c->Send(data_in + 20, 100); - c->Send(data_in + 120, 80); - - unsigned char data_out[200] = {0}; - c->Recv(data_out, 10); - c->Recv(data_out + 10, 100); - c->Recv(data_out + 110, 90); - - REQUIRE(test::BufferEquals(data_in, data_out, 200)); -} diff --git a/test/scl/net/test_network.cc b/test/scl/net/test_network.cc index b6b4fb2..dea33d4 100644 --- a/test/scl/net/test_network.cc +++ b/test/scl/net/test_network.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,192 +15,67 @@ * along with this program. If not, see . */ -#include +#include #include #include +#include "scl/coro/batch.h" +#include "scl/coro/coroutine.h" #include "scl/net/config.h" #include "scl/net/network.h" #include "scl/net/tcp_channel.h" using namespace scl; -namespace { - -net::Packet GetPacketWithData() { - net::Packet p; - p << (int)123; - p << (float)4.5; - return p; -} - -void CheckPacket(std::optional& p) { - if (p.has_value()) { - REQUIRE(p.value().Read() == 123); - REQUIRE(p.value().Read() == 4.5); - } else { - FAIL("packet did not have data"); - } +TEST_CASE("Network one party", "[net]") { + auto rt = coro::DefaultRuntime::create(); + auto network = + rt->run(net::Network::create(net::NetworkConfig::localhost(0, 1))); + REQUIRE(network.size() == 1); } -} // namespace - -TEST_CASE("Network fake", "[net]") { - auto fake = net::FakeNetwork::Create(0, 3); - - auto network = fake.my_network; - auto remotes = fake.incoming; - - REQUIRE(network.Size() == 3); - REQUIRE(remotes[0] == nullptr); +namespace { - auto p = GetPacketWithData(); +coro::Task> connect3() { + std::vector> networks; + auto conf0 = net::NetworkConfig::localhost(0, 3); + networks.emplace_back(net::Network::create(conf0)); - remotes[1]->Send(p); + auto conf1 = net::NetworkConfig::localhost(1, 3); + networks.emplace_back(net::Network::create(conf1)); - auto rp1 = network.Party(1)->Recv(); - CheckPacket(rp1); + auto conf2 = net::NetworkConfig::localhost(2, 3); + networks.emplace_back(net::Network::create(conf2)); - p.ResetReadPtr(); - network.Party(0)->Send(p); - auto rp0 = network.Party(0)->Recv(); - CheckPacket(rp0); + co_return co_await coro::batch(std::move(networks)); } -TEST_CASE("Network fully connected", "[net]") { - auto networks = net::CreateMemoryBackedNetwork(3); - REQUIRE(networks.size() == 3); - - auto network0 = networks[0]; - auto network1 = networks[1]; - - auto p = GetPacketWithData(); - - // p0 -> p1 - network0.Party(1)->Send(p); - - // p1 <- p0 - auto p10 = network1.Party(0)->Recv(); - CheckPacket(p10); - - auto network2 = networks[2]; - // p2 -> p0 - network2.Party(0)->Send(p); - - // p0 <- p2 - auto p02 = network0.Party(2)->Recv(); - CheckPacket(p02); +coro::Task send(net::Channel* channel, int v) { + net::Packet p; + p << v; + co_await channel->send(p); } -TEST_CASE("Network TCP", "[net]") { - net::Network network0; - net::Network network1; - net::Network network2; - - std::thread t0([&]() { - network0 = net::Network::Create>( - net::NetworkConfig::Localhost(0, 3)); - }); - std::thread t1([&]() { - network1 = net::Network::Create>( - net::NetworkConfig::Localhost(1, 3)); - }); - std::thread t2([&]() { - network2 = net::Network::Create>( - net::NetworkConfig::Localhost(2, 3)); - }); - - t0.join(); - t1.join(); - t2.join(); - - for (std::size_t i = 0; i < 3; ++i) { - // Alive doesn't exist on InMemoryChannel - if (i != 0) { - REQUIRE(((net::TcpChannel<>*)network0.Party(i))->Alive()); - } - if (i != 1) { - REQUIRE(((net::TcpChannel<>*)network1.Party(i))->Alive()); - } - if (i != 2) { - REQUIRE(((net::TcpChannel<>*)network2.Party(i))->Alive()); - } - } - - auto p = GetPacketWithData(); - - network0.Party(2)->Send(p); - - auto p20 = network2.Party(0)->Recv(); - CheckPacket(p20); +coro::Task recv(net::Channel* channel) { + net::Packet p = co_await channel->recv(); + co_return p.read(); } -struct ChannelMock final : net::Channel { - void Close() override { - close_called++; - } - - void Send(const unsigned char* src, std::size_t n) override { - (void)src; - (void)n; - send_called++; - } - - std::size_t Recv(unsigned char* dst, std::size_t n) override { - (void)dst; - (void)n; - return 0; - } - - bool HasData() override { - return false; - } - - std::size_t close_called = 0; - std::size_t send_called = 0; -}; - -TEST_CASE("Network party getters") { - const auto chl0 = std::make_shared(); - const auto chl1 = std::make_shared(); - const auto chl2 = std::make_shared(); - - net::Network nw({chl0, chl1, chl2}, 1); - - auto p = GetPacketWithData(); - REQUIRE(chl2->send_called == 0); - nw.Next()->Send(p); - REQUIRE(chl2->send_called == 2); - - REQUIRE(chl0->send_called == 0); - nw.Previous()->Send(p); - REQUIRE(chl0->send_called == 2); - - REQUIRE_THROWS_MATCHES(nw.Other(), - std::logic_error, - Catch::Matchers::Message( - "other party ambiguous for more than 2 parties")); - - net::Network two_parties({chl0, chl1}, 1); +} // namespace - two_parties.Other()->Send(p); - REQUIRE(chl0->send_called == 4); -} +TEST_CASE("Network TCP", "[net]") { + auto rt = coro::DefaultRuntime::create(); -TEST_CASE("Network close") { - const auto chl0 = std::make_shared(); - const auto chl1 = std::make_shared(); - const auto chl2 = std::make_shared(); + auto networks = rt->run(connect3()); - net::Network nw({chl0, chl1, chl2}, 1); + REQUIRE(networks.size() == 3); - REQUIRE(chl0->close_called == 0); - REQUIRE(chl1->close_called == 0); - REQUIRE(chl2->close_called == 0); + rt->run(send(networks[0].party(1), 123)); + rt->run(send(networks[2].party(0), 456)); - nw.Close(); + auto v = rt->run(recv(networks[1].party(0))); + REQUIRE(v == 123); - REQUIRE(chl0->close_called == 1); - REQUIRE(chl1->close_called == 1); - REQUIRE(chl2->close_called == 1); + auto w = rt->run(recv(networks[0].party(2))); + REQUIRE(w == 456); } diff --git a/test/scl/net/test_packet.cc b/test/scl/net/test_packet.cc index e9f1304..2f79186 100644 --- a/test/scl/net/test_packet.cc +++ b/test/scl/net/test_packet.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,27 +15,28 @@ * along with this program. If not, see . */ -#include +#include #include -#include "scl/math/curves/secp256k1.h" +#include "scl/math/fields/secp256k1_field.h" #include "scl/math/fp.h" -#include "scl/math/mat.h" +#include "scl/math/matrix.h" #include "scl/math/number.h" #include "scl/net/packet.h" +#include "scl/serialization/serializer.h" using namespace scl; using SmallObj = math::Fp<61>; -using LargeObj = math::FF; +using LargeObj = math::FF; TEST_CASE("Packet read/write different types", "[net]") { net::Packet p; p << LargeObj(1234) << SmallObj(33) << LargeObj(5); - REQUIRE(p.Read() == LargeObj(1234)); - REQUIRE(p.Read() == SmallObj(33)); - REQUIRE(p.Read() == LargeObj(5)); + REQUIRE(p.read() == LargeObj(1234)); + REQUIRE(p.read() == SmallObj(33)); + REQUIRE(p.read() == LargeObj(5)); } TEST_CASE("Packet read/write many", "[net]") { @@ -45,11 +46,11 @@ TEST_CASE("Packet read/write many", "[net]") { p << SmallObj((int)i); } - REQUIRE(p.Size() == SmallObj::ByteSize() * 10000); + REQUIRE(p.size() == SmallObj::byteSize() * 10000); bool all_equal = true; for (std::size_t i = 0; i < 10000; ++i) { - all_equal &= p.Read() == SmallObj((int)i); + all_equal &= p.read() == SmallObj((int)i); } REQUIRE(all_equal); @@ -58,21 +59,25 @@ TEST_CASE("Packet read/write many", "[net]") { TEST_CASE("Packet read/write matrix", "[net]") { net::Packet p; - auto prg = util::PRG::Create("packet mat"); - const auto m = math::Mat::Random(10, 3, prg); + auto prg = util::PRG::create("packet mat"); + const auto m = math::Matrix::random(10, 3, prg); p << m; - REQUIRE(p.Read>().Equals(m)); + auto mm = p.read>(); + + REQUIRE(mm.rows() == m.rows()); + REQUIRE(mm.cols() == m.cols()); + REQUIRE(mm.equals(m)); } TEST_CASE("Packet read/write vec", "[net]") { net::Packet p; - auto prg = util::PRG::Create("packet vec"); - const auto v = math::Vec::Random(10, prg); + auto prg = util::PRG::create("packet vec"); + const auto v = math::Vector::random(10, prg); p << v; - REQUIRE(p.Read>() == v); + REQUIRE(p.read>() == v); } TEST_CASE("Packet read/write pointers", "[net]") { @@ -80,14 +85,76 @@ TEST_CASE("Packet read/write pointers", "[net]") { p << 1 << 2 << 3 << 4; - REQUIRE(p.Read() == 1); - REQUIRE(p.Read() == 2); - p.ResetReadPtr(); - REQUIRE(p.Read() == 1); - REQUIRE(p.Read() == 2); + REQUIRE(p.read() == 1); + REQUIRE(p.read() == 2); + p.resetReadPtr(); + REQUIRE(p.read() == 1); + REQUIRE(p.read() == 2); - p.ResetWritePtr(); + p.resetWritePtr(); p << 5 << 6; - REQUIRE(p.Read() == 5); - REQUIRE(p.Read() == 6); + REQUIRE(p.read() == 5); + REQUIRE(p.read() == 6); +} + +TEST_CASE("Packet Write", "[net]") { + net::Packet p; + + const auto w = p.write((int)123); + REQUIRE(w == seri::Serializer::sizeOf(0)); +} + +TEST_CASE("Packet concat", "[net]") { + net::Packet p0; + net::Packet p1; + + p0 << 1 << 2 << LargeObj(44); + p1 << 3 << SmallObj(55) << 4; + + const auto p0_sz = p0.size(); + const auto p1_sz = p1.size(); + + p0 << p1; + + REQUIRE(p0.read() == 1); + REQUIRE(p0.read() == 2); + REQUIRE(p0.read() == LargeObj(44)); + REQUIRE(p0.read() == 3); + REQUIRE(p0.read() == SmallObj(55)); + REQUIRE(p0.read() == 4); + REQUIRE(p0_sz + p1_sz == p0.size()); +} + +TEST_CASE("Packet remaining", "[net]") { + net::Packet p; + + p << 1 << 2 << 3; + + REQUIRE(p.remaining() == p.size()); + p.read(); + + REQUIRE(p.remaining() == p.size() - sizeof(int)); + p.read(); + p.read(); + REQUIRE(p.remaining() == 0); +} + +TEST_CASE("Packet eq", "[net]") { + net::Packet p0; + net::Packet p1; + + REQUIRE(p0 == p1); + + p0 << 2; + REQUIRE_FALSE(p0 == p1); + + p1 << 2; + REQUIRE(p0 == p1); + + p1 << 3; + REQUIRE_FALSE(p0 == p1); + + p1.setWritePtr(sizeof(int)); + + REQUIRE(p1 == p0); } diff --git a/test/scl/net/test_shared_deque.cc b/test/scl/net/test_shared_deque.cc deleted file mode 100644 index 39f7fdd..0000000 --- a/test/scl/net/test_shared_deque.cc +++ /dev/null @@ -1,78 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#include -#include - -#include "scl/net/shared_deque.h" - -using namespace scl; - -TEST_CASE("SharedDeque", "[misc]") { - net::SharedDeque dq; - - dq.PushBack(4); - dq.PushBack(5); - dq.PushBack(2); - - REQUIRE(dq.Peek() == 4); - REQUIRE(dq.Peek() == 4); - - auto four = dq.Pop(); - REQUIRE(four == 4); - - REQUIRE(dq.Peek() == 5); - dq.PopFront(); - auto two = dq.Pop(); - REQUIRE(two == 2); - - REQUIRE(dq.Size() == 0); -} - -TEST_CASE("SharedDeque pop", "[misc]") { - using namespace std::chrono_literals; - - SECTION("Pop") { - net::SharedDeque dq; - int v = 0; - - std::thread t([&]() { v = dq.Pop(); }); - - std::this_thread::sleep_for(20ms); - - REQUIRE(dq.Size() == 0); - dq.PushBack(42); - - t.join(); - REQUIRE(dq.Size() == 0); - REQUIRE(v == 42); - } - - SECTION("PopFront") { - net::SharedDeque dq; - - std::thread t([&]() { dq.PopFront(); }); - - std::this_thread::sleep_for(20ms); - - REQUIRE(dq.Size() == 0); - dq.PushBack(42); - - t.join(); - REQUIRE(dq.Size() == 0); - } -} diff --git a/test/scl/net/test_tcp.cc b/test/scl/net/test_tcp.cc deleted file mode 100644 index 5005b1e..0000000 --- a/test/scl/net/test_tcp.cc +++ /dev/null @@ -1,220 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#include -#include -#include - -#include "scl/net/sys_iface.h" -#include "scl/net/tcp_utils.h" -#include "util.h" - -using namespace scl; - -TEST_CASE("SysIFace GetError", "[net]") { - // This is needed to make the test stable when computing coverage. Not sure - // why. - errno = 0; - REQUIRE(net::SysIFace::GetError() == 0); -} - -#define DEFAULT_SETSOCKOPT \ - static auto SetSockOpt(int s, int l, int o, const void* ov, socklen_t ol) { \ - return net::SysIFace::SetSockOpt(s, l, o, ov, ol); \ - } - -#define DEFAULT_HOST_TO_NET \ - static auto HostToNet(short h) { \ - return net::SysIFace::HostToNet(h); \ - } - -#define DEFAULT_BIND \ - static auto Bind(int s, const struct sockaddr* a, socklen_t al) { \ - return net::SysIFace::Bind(s, a, al); \ - } - -#define DEFAULT_LISTEN \ - static auto Listen(int s, int b) { \ - return net::SysIFace::Listen(s, b); \ - } - -#define DEFAULT_ADDR_TO_BIN \ - static auto AddrToBin(int a, const char* s, void* d) { \ - return net::SysIFace::AddrToBin(a, s, d); \ - } - -#define DEFAULT_CONNECT \ - static auto Connect(int s, const struct sockaddr* a, socklen_t al) { \ - return net::SysIFace::Connect(s, a, al); \ - } - -struct SysIFace_SocketFails { - DEFAULT_BIND; - DEFAULT_HOST_TO_NET; - DEFAULT_LISTEN; - DEFAULT_SETSOCKOPT; - DEFAULT_ADDR_TO_BIN; - DEFAULT_CONNECT; - - static auto GetError() { - return EACCES; - } - - static auto Socket(int domain, int type, int protocol) { - (void)domain; - (void)type; - (void)protocol; - return -1; - } -}; - -TEST_CASE("CreateServerSocket fails on socket", "[net]") { - REQUIRE_THROWS_MATCHES( - net::CreateServerSocket(1, 1), - std::system_error, - Catch::Matchers::Message( - "could not acquire server socket: Permission denied")); -} - -#define DEFAULT_SOCKET \ - static auto Socket(int d, int t, int p) { \ - return net::SysIFace::Socket(d, t, p); \ - } - -struct SysIFace_SetSockOptFails { - DEFAULT_BIND; - DEFAULT_HOST_TO_NET; - DEFAULT_LISTEN; - DEFAULT_SOCKET; - - static auto GetError() { - return EBADF; - } - - static auto SetSockOpt(int sockfd, - int level, - int optname, - const void* optval, - socklen_t optlen) { - (void)sockfd; - (void)level; - (void)optname; - (void)optval; - (void)optlen; - return -1; - } -}; - -TEST_CASE("CreateServerSocket fails on setsockopt", "[net]") { - REQUIRE_THROWS_MATCHES( - net::CreateServerSocket(1, 1), - std::system_error, - Catch::Matchers::Message( - "could not set socket options: Bad file descriptor")); -} - -struct SysIFace_BindFails { - DEFAULT_HOST_TO_NET; - DEFAULT_LISTEN; - DEFAULT_SETSOCKOPT; - DEFAULT_SOCKET; - - static auto GetError() { - return EACCES; - } - - static auto Bind(int sockfd, const struct sockaddr* addr, socklen_t addrlen) { - (void)sockfd; - (void)addr; - (void)addrlen; - return -1; - } -}; - -TEST_CASE("CreateServerSocket fails on bind", "[net]") { - REQUIRE_THROWS_MATCHES( - net::CreateServerSocket(1, 1), - std::system_error, - Catch::Matchers::Message("could not bind socket: Permission denied")); -} - -struct SysIFace_ListenFails { - DEFAULT_HOST_TO_NET; - DEFAULT_SOCKET; - DEFAULT_SETSOCKOPT; - DEFAULT_BIND; - - static auto GetError() { - return EADDRINUSE; - } - - static auto Listen(int sockfd, int backlog) { - (void)sockfd; - (void)backlog; - return -1; - } -}; - -TEST_CASE("CreateServerSocket fails on listen", "[net]") { - const auto port = test::GetPort(); - REQUIRE_THROWS_MATCHES( - net::CreateServerSocket(port, 1), - std::system_error, - Catch::Matchers::Message( - "could not listen on socket: Address already in use")); -} - -struct SysIFace_AcceptFails { - static auto NetToAddr(struct in_addr inp) { - return net::SysIFace::NetToAddr(inp); - } - - static auto GetError() { - return EAGAIN; - } - - static auto Accept(int sockfd, - struct sockaddr* addr, - const socklen_t* addrlen) { - (void)sockfd; - (void)addr; - (void)addrlen; - return -1; - } -}; - -TEST_CASE("AcceptConnection fails on accept", "[net]") { - REQUIRE_THROWS_MATCHES( - net::AcceptConnection(0), - std::system_error, - Catch::Matchers::Message( - "could not accept connection: Resource temporarily unavailable")); -} - -TEST_CASE("ConnectAsClient fails on socket", "[net]") { - REQUIRE_THROWS_MATCHES( - net::ConnectAsClient("127.0.0.1", 1111), - std::system_error, - Catch::Matchers::Message("could not acquire socket: Permission denied")); -} - -TEST_CASE("ConnectAsClient invalid address") { - REQUIRE_THROWS_MATCHES( - net::ConnectAsClient("not a valid hostname", 1111), - std::runtime_error, - Catch::Matchers::Message("invalid hostname")); -} diff --git a/test/scl/net/test_tcp_channel.cc b/test/scl/net/test_tcp_channel.cc deleted file mode 100644 index cece5e7..0000000 --- a/test/scl/net/test_tcp_channel.cc +++ /dev/null @@ -1,244 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#include -#include -#include - -#include "scl/net/sys_iface.h" -#include "scl/net/tcp_channel.h" -#include "scl/net/tcp_utils.h" -#include "scl/util/prg.h" -#include "util.h" - -using namespace scl; - -TEST_CASE("TcpChannel connect and then close", "[net]") { - auto port = test::GetPort(); - - std::shared_ptr> client; - std::shared_ptr> server; - - std::thread clt([&]() { - int socket = net::ConnectAsClient("0.0.0.0", port); - client = std::make_shared>(socket); - }); - - std::thread srv([&]() { - int ssock = net::CreateServerSocket(port, 1); - auto ac = net::AcceptConnection(ssock); - server = std::make_shared>(ac.socket); - net::SysIFace::Close(ssock); - }); - - clt.join(); - srv.join(); - - REQUIRE(client->Alive()); - REQUIRE(server->Alive()); - - client->Close(); - server->Close(); - - REQUIRE(!server->Alive()); - REQUIRE(!client->Alive()); -} - -TEST_CASE("TcpChannel send/recv", "[net]") { - auto port = scl::test::GetPort(); - - std::shared_ptr> client; - std::shared_ptr> server; - - std::thread clt([&]() { - int socket = net::ConnectAsClient("0.0.0.0", port); - client = std::make_shared>(socket); - }); - - std::thread srv([&]() { - int ssock = net::CreateServerSocket(port, 1); - auto ac = net::AcceptConnection(ssock); - server = std::make_shared>(ac.socket); - net::SysIFace::Close(ssock); - }); - - clt.join(); - srv.join(); - - auto prg = util::PRG::Create(); - unsigned char send[200] = {0}; - unsigned char recv[200] = {0}; - prg.Next(send, 200); - - REQUIRE(!server->HasData()); - - client->Send(send, 100); - client->Send(send + 100, 100); - - REQUIRE(server->HasData()); - server->Recv(recv, 20); - server->Recv(recv + 20, 180); - - REQUIRE(test::BufferEquals(send, recv, 200)); -} - -TEST_CASE("TcpChannel recv from closed socket", "[net]") { - auto port = test::GetPort(); - - std::shared_ptr> client; - std::shared_ptr> server; - - std::thread clt([&]() { - int socket = net::ConnectAsClient("0.0.0.0", port); - client = std::make_shared>(socket); - }); - - std::thread srv([&]() { - int ssock = net::CreateServerSocket(port, 1); - auto ac = net::AcceptConnection(ssock); - server = std::make_shared>(ac.socket); - net::SysIFace::Close(ssock); - }); - - clt.join(); - srv.join(); - - client->Close(); - unsigned char buf[3] = {0}; - auto r = server->Recv(buf, 3); - REQUIRE(r == 0); -} - -#define DEFAULT_READ \ - static auto Read(int fd, void* buf, size_t count) { \ - return net::SysIFace::Read(fd, buf, count); \ - } - -#define DEFAULT_CLOSE \ - static auto Close(int fd) { \ - return net::SysIFace::Close(fd); \ - } - -#define DEFAULT_POLL \ - static auto Poll(struct pollfd* fds, nfds_t nfds, int timeout) { \ - return net::SysIFace::Poll(fds, nfds, timeout); \ - } - -#define DEFAULT_WRITE \ - static auto Write(int fd, const void* buf, size_t count) { \ - return net::SysIFace::Write(fd, buf, count); \ - } - -struct SysIFace_WriteFails { - DEFAULT_READ; - DEFAULT_CLOSE; - DEFAULT_POLL; - - static auto GetError() { - return EAGAIN; - } - - static int Write(int fd, const void* buf, size_t count) { - (void)fd; - (void)buf; - (void)count; - return -1; - } -}; - -TEST_CASE("TcpChannel Send fails", "[net]") { - net::TcpChannel c(1); - REQUIRE_THROWS_MATCHES(c.Send(nullptr, 1), - std::system_error, - Catch::Matchers::Message( - "write failed: Resource temporarily unavailable")); -} - -struct SysIFace_ReadFails { - DEFAULT_CLOSE; - DEFAULT_POLL; - DEFAULT_WRITE; - - static auto GetError() { - return EAGAIN; - } - - static int Read(int fd, void* buf, size_t count) { - (void)fd; - (void)buf; - (void)count; - return -1; - } -}; - -TEST_CASE("TcpChannel Recv fails", "[net]") { - net::TcpChannel c(1); - REQUIRE_THROWS_MATCHES(c.Recv(nullptr, 1), - std::system_error, - Catch::Matchers::Message( - "read failed: Resource temporarily unavailable")); -} - -struct SysIFace_CloseFails { - DEFAULT_READ; - DEFAULT_WRITE; - DEFAULT_POLL; - - static auto GetError() { - return EIO; - } - - static int Close(int fd) { - (void)fd; - return -1; - } -}; - -TEST_CASE("TcpChannel Close fails", "[net]") { - net::TcpChannel c(1); - REQUIRE(c.Alive()); - REQUIRE_THROWS_MATCHES( - c.Close(), - std::system_error, - Catch::Matchers::Message("close failed: Input/output error")); - REQUIRE_FALSE(c.Alive()); - c.Close(); -} - -struct SysIFace_PollFails { - DEFAULT_READ; - DEFAULT_WRITE; - DEFAULT_CLOSE; - - static auto GetError() { - return EFAULT; - } - - static int Poll(struct pollfd* fds, nfds_t nfds, int timeout) { - (void)fds; - (void)nfds; - (void)timeout; - return -1; - } -}; - -TEST_CASE("TcpChannel HasData fails", "[net]") { - net::TcpChannel c(1); - REQUIRE_THROWS_MATCHES(c.HasData(), - std::system_error, - Catch::Matchers::Message("poll failed: Bad address")); -} diff --git a/test/scl/net/test_threaded_sender.cc b/test/scl/net/test_threaded_sender.cc deleted file mode 100644 index e5d0cf6..0000000 --- a/test/scl/net/test_threaded_sender.cc +++ /dev/null @@ -1,81 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#include -#include -#include -#include - -#include "scl/net/sys_iface.h" -#include "scl/net/tcp_utils.h" -#include "scl/net/threaded_sender.h" -#include "scl/util/prg.h" -#include "util.h" - -using namespace scl; - -TEST_CASE("ThreadedSender send/recv", "[network]") { - auto port = test::GetPort(); - - std::shared_ptr client; - std::shared_ptr server; - - std::thread clt([&]() { - int socket = net::ConnectAsClient("0.0.0.0", port); - client = std::make_shared(socket); - }); - - std::thread srv([&]() { - int ssock = net::CreateServerSocket(port, 1); - auto ac = net::AcceptConnection(ssock); - server = std::make_shared(ac.socket); - net::SysIFace::Close(ssock); - }); - - clt.join(); - srv.join(); - - auto prg = util::PRG::Create(); - unsigned char send[200] = {0}; - unsigned char recv[200] = {0}; - prg.Next(send, 200); - - REQUIRE(!server->HasData()); - - client->Send(send, 100); - client->Send(send + 100, 100); - - // because the sender returns immediately, there might not be data - // available, so we will try a couple of times before failing. - { - using namespace std::chrono_literals; - auto c = 0; - while (c < 10 && !server->HasData()) { - std::this_thread::sleep_for(100ms); - c++; - } - } - REQUIRE(server->HasData()); - - server->Recv(recv, 20); - server->Recv(recv + 20, 180); - - client->Close(); - server->Close(); - - REQUIRE(test::BufferEquals(send, recv, 200)); -} diff --git a/test/scl/net/util.cc b/test/scl/net/util.cc index 23577a3..71f3285 100644 --- a/test/scl/net/util.cc +++ b/test/scl/net/util.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -17,15 +17,15 @@ #include "util.h" +using namespace scl; + int test_port = SCL_DEFAULT_TEST_PORT; -int scl::test::GetPort() { +int test::getPort() { return test_port++; } -bool scl::test::BufferEquals(const unsigned char* a, - const unsigned char* b, - int n) { +bool test::bufferEquals(const unsigned char* a, const unsigned char* b, int n) { while (n-- > 0 && *a++ == *b++) { ; } diff --git a/test/scl/net/util.h b/test/scl/net/util.h index 829efba..754c53e 100644 --- a/test/scl/net/util.h +++ b/test/scl/net/util.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -31,7 +31,7 @@ namespace scl::test { * @brief Get a fresh port for use in tests that require ports. * @note Not thread safe. */ -int GetPort(); +int getPort(); /** * @brief Test if two buffers are equal. @@ -40,7 +40,7 @@ int GetPort(); * @param n the number of bytes to check * @param true if \p a and \p b coincide on the first \p n bytes. */ -bool BufferEquals(const unsigned char* a, const unsigned char* b, int n); +bool bufferEquals(const unsigned char* a, const unsigned char* b, int n); } // namespace scl::test diff --git a/test/scl/protocol/beaver.h b/test/scl/protocol/beaver.h index 0c9971b..bcd2245 100644 --- a/test/scl/protocol/beaver.h +++ b/test/scl/protocol/beaver.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -21,98 +21,52 @@ #include #include "./triple.h" +#include "scl/coro/task.h" #include "scl/protocol/base.h" #include "scl/protocol/env.h" +#include "scl/protocol/result.h" namespace scl::test { -template -struct BeaverMul { - class Init; - class Finalize; - - static std::unique_ptr Create(FF x, - FF y, - Triple triple) { - return std::make_unique(x, y, triple); - } -}; - -template -class BeaverMul::Init final : public proto::Protocol { +template +class BeaverMul final : public proto::Protocol { public: - Init(FF x, FF y, Triple triple) : m_x(x), m_y(y), m_triple(triple){}; + BeaverMul(SHARE x, SHARE y, Triple triple) + : m_x(x), m_y(y), m_triple(triple) {} - std::unique_ptr Run(proto::Env& env) override { - net::Packet p; + coro::Task run(proto::Env& env) const override { + net::Packet packet; - for (std::size_t i = 0; i < 200; i += 2) { - auto e = m_x + m_triple.a; - auto d = m_y + m_triple.b; - p << e << d; - } + packet << m_x - m_triple.a; // [e] = [x] - [a] + packet << m_y - m_triple.b; // [d] = [y] - [b] - env.network.Party(0)->Send(p); - env.network.Party(1)->Send(p); + co_await env.network.party(0)->send(packet); + co_await env.network.party(1)->send(packet); - return std::make_unique(m_triple); - } + net::Packet packet0 = co_await env.network.party(0)->recv(); + net::Packet packet1 = co_await env.network.party(1)->recv(); - std::string Name() const override { - return "init"; - }; + const auto e0 = packet0.read(); + const auto d0 = packet0.read(); + const auto e1 = packet1.read(); + const auto d1 = packet1.read(); - private: - FF m_x; - FF m_y; - Triple m_triple; -}; + const auto e = e0 + e1; + const auto d = d0 + d1; -template -class BeaverMul::Finalize final : public proto::Protocol { - public: - Finalize(Triple triple) : m_triple(triple){}; - - std::unique_ptr Run(proto::Env& env) override { - auto p0 = env.network.Party(0)->Recv().value(); - auto p1 = env.network.Party(1)->Recv().value(); - - math::Vec output(100); - - std::size_t output_idx = 0; - for (std::size_t i = 0; i < 200; i += 2) { - const auto e0 = p0.Read(); - const auto d0 = p0.Read(); - const auto e1 = p1.Read(); - const auto d1 = p1.Read(); - auto e = e0 + e1; - auto d = d0 + d1; - - // Constant addition - if (env.network.MyId() == 0) { - output[output_idx++] = - e * d - e * m_triple.b - d * m_triple.a + m_triple.c; - } else { - output[output_idx++] = -e * m_triple.b - d * m_triple.a + m_triple.c; - } + // [z] = ed + e[b] + d[a] + [c]. Only party 0 adds constants. + auto z = e * m_triple.b + d * m_triple.a + m_triple.c; + if (env.network.myId() == 0) { + z += e * d; } - m_output = output; - - return nullptr; - }; - - std::string Name() const override { - return "finalize"; - }; - - std::any Output() const override { - return m_output; - }; + co_return proto::ProtocolResult::done(z); + } private: - Triple m_triple; - std::any m_output; + SHARE m_x; + SHARE m_y; + Triple m_triple; }; } // namespace scl::test diff --git a/test/scl/protocol/test_protocol.cc b/test/scl/protocol/test_protocol.cc index f951faf..beeba5a 100644 --- a/test/scl/protocol/test_protocol.cc +++ b/test/scl/protocol/test_protocol.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,138 +15,64 @@ * along with this program. If not, see . */ -#include -#include -#include +#include +#include #include "./beaver.h" #include "./triple.h" +#include "scl/coro/runtime.h" #include "scl/math/fp.h" +#include "scl/net/loopback.h" #include "scl/net/network.h" #include "scl/protocol/base.h" #include "scl/protocol/env.h" +#include "scl/protocol/eval.h" #include "scl/ss/additive.h" using namespace scl; using FF = math::Fp<61>; -auto prg = util::PRG::Create(); -auto xs = ss::AdditiveShare(FF(42), 2, prg); -auto ys = ss::AdditiveShare(FF(11), 2, prg); -auto ts = test::RandomTriple(prg); +auto prg = util::PRG::create(); +auto x = FF(42); +auto y = FF(11); +auto xs = ss::additiveShare(x, 2, prg); +auto ys = ss::additiveShare(y, 2, prg); +auto ts = test::randomTriple2(prg); namespace { -auto CreateEnv(net::Network& network) { - return proto::Env{network, - std::make_unique(), - std::make_unique()}; -} - -} // namespace - -TEST_CASE("Dynamic protocol beaver step-by-step", "[protocol]") { - auto networks = net::CreateMemoryBackedNetwork(2); - - std::unique_ptr p0 = - std::make_unique::Init>(xs[0], ys[0], ts[0]); - std::unique_ptr p1 = - std::make_unique::Init>(xs[1], ys[1], ts[1]); - - auto env0 = CreateEnv(networks[0]); - auto env1 = CreateEnv(networks[1]); - - p0 = p0->Run(env0); - p1 = p1->Run(env1); - - REQUIRE(p0->Run(env0) == nullptr); - REQUIRE(p1->Run(env1) == nullptr); - - auto z0 = std::any_cast>(p0->Output()); - auto z1 = std::any_cast>(p1->Output()); - - REQUIRE(z0.Size() == 100); - REQUIRE(z1.Size() == 100); - - REQUIRE(z0[0] + z1[0] == FF(42) * FF(11)); -} - -TEST_CASE("Dynamic protocol eval beaver", "[protocol]") { - auto networks = net::CreateMemoryBackedNetwork(2); - - auto p0 = test::BeaverMul::Create(xs[0], ys[0], ts[0]); - auto p1 = test::BeaverMul::Create(xs[1], ys[1], ts[1]); - - math::Vec z0; - math::Vec z1; - - std::thread t0([&]() { - proto::Evaluate(std::move(p0), networks[0], [&](const std::any& v) { - z0 = std::any_cast>(v); - }); - }); - std::thread t1([&]() { - proto::Evaluate(std::move(p1), networks[1], [&](const std::any& v) { - z1 = std::any_cast>(v); - }); - }); +std::array createEnvs() { + auto p0p0 = net::LoopbackChannel::create(); + auto p1p1 = net::LoopbackChannel::create(); + auto p0p1 = net::LoopbackChannel::createPaired(); - t0.join(); - t1.join(); - - REQUIRE(z0.Size() == 100); - REQUIRE(z1.Size() == 100); - - for (std::size_t i = 0; i < 100; ++i) { - REQUIRE(z0[i] + z1[i] == FF(42) * FF(11)); - } -} - -TEST_CASE("Dynamic protocol eval null protocol", "[protocol]") { - auto networks = net::CreateMemoryBackedNetwork(1); - proto::Evaluate(nullptr, networks[0]); -} - -TEST_CASE("Protocol env real-time clock", "[protocol]") { - proto::RealTimeClock clock; - - using namespace std::chrono_literals; - std::this_thread::sleep_for(100ms); - - auto d = clock.Read(); - - REQUIRE(d <= 110ms); - REQUIRE(d >= 100ms); + return {proto::createDefaultEnv(net::Network({p0p0, p0p1[0]}, 0)), + proto::createDefaultEnv(net::Network({p1p1, p0p1[1]}, 1))}; } -TEST_CASE("Protocol env real-time clock checkpoint", "[protocol]") { - // https://truong.io/posts/capturing_stdout_for_c++_unit_testing.html - - proto::RealTimeClock clock; +coro::Task runBeaverMulTwoParties() { + auto envs = createEnvs(); - std::stringstream buf; - std::streambuf* coutbuf = std::cout.rdbuf(buf.rdbuf()); + auto beaver0 = std::make_unique>(xs[0], ys[0], ts[0]); + auto beaver1 = std::make_unique>(xs[1], ys[1], ts[1]); - clock.Checkpoint("asd"); + std::vector> protocol_evaluations; + protocol_evaluations.emplace_back( + proto::evaluate(std::move(beaver0), envs[0])); + protocol_evaluations.emplace_back( + proto::evaluate(std::move(beaver1), envs[1])); - auto output = buf.str(); + std::vector shares = + co_await coro::batch(std::move(protocol_evaluations)); - std::cout.rdbuf(coutbuf); - - REQUIRE_THAT(output, Catch::Matchers::StartsWith("asd @")); + co_return shares[0] + shares[1]; } -TEST_CASE("Protocol env Stl thread context", "[protocol]") { - proto::StlThreadContext ctx; - - using namespace std::chrono_literals; - - auto t0 = util::Time::Now(); - - ctx.Sleep(100); - - auto t1 = util::Time::Now(); +} // namespace - REQUIRE(t1 - t0 >= 100ms); +TEST_CASE("Beaver multiplication protocol", "[proto]") { + auto rt = coro::DefaultRuntime::create(); + auto z = rt->run(runBeaverMulTwoParties()); + REQUIRE(z == x * y); } diff --git a/test/scl/protocol/triple.h b/test/scl/protocol/triple.h index e56b88c..637991e 100644 --- a/test/scl/protocol/triple.h +++ b/test/scl/protocol/triple.h @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -25,30 +25,30 @@ namespace scl::test { -template +template struct Triple { - Triple(FF a, FF b, FF c) : a(a), b(b), c(c){}; + Triple(FIELD a, FIELD b, FIELD c) : a(a), b(b), c(c){}; - FF a; - FF b; - FF c; + FIELD a; + FIELD b; + FIELD c; }; -template -std::vector> RandomTriple(util::PRG& prg) { - auto a = FF::Random(prg); - auto b = FF::Random(prg); +template +std::vector> randomTriple2(util::PRG& prg) { + auto a = FIELD::random(prg); + auto b = FIELD::random(prg); auto c = a * b; - auto as = ss::AdditiveShare(a, 2, prg); - auto bs = ss::AdditiveShare(b, 2, prg); - auto cs = ss::AdditiveShare(c, 2, prg); + auto as = ss::additiveShare(a, 2, prg); + auto bs = ss::additiveShare(b, 2, prg); + auto cs = ss::additiveShare(c, 2, prg); return {{as[0], bs[0], cs[0]}, {as[1], bs[1], cs[1]}}; } -template -std::ostream& operator<<(std::ostream& os, const Triple& triple) { +template +std::ostream& operator<<(std::ostream& os, const Triple& triple) { return os << triple.a << " " << triple.b << " " << triple.c; } diff --git a/test/scl/serialization/test_serializer.cc b/test/scl/serialization/test_serializer.cc index 4e82965..0ca3d91 100644 --- a/test/scl/serialization/test_serializer.cc +++ b/test/scl/serialization/test_serializer.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,12 +15,11 @@ * along with this program. If not, see . */ -#include +#include #include "scl/math/fp.h" #include "scl/math/number.h" #include "scl/serialization/serializer.h" -#include "scl/serialization/serializers.h" using namespace scl; @@ -29,21 +28,21 @@ TEST_CASE("Serialization simple types", "[misc]") { const auto int_size = sizeof(int); unsigned char buf[4 * int_size]; - REQUIRE(int_size == Sint::SizeOf(10)); + REQUIRE(int_size == Sint::sizeOf(10)); - Sint::Write(1, buf); - Sint::Write(3, buf + int_size); - Sint::Write(5, buf + 2 * int_size); - Sint::Write(7, buf + 3 * int_size); + Sint::write(1, buf); + Sint::write(3, buf + int_size); + Sint::write(5, buf + 2 * int_size); + Sint::write(7, buf + 3 * int_size); int v; - Sint::Read(v, buf); + Sint::read(v, buf); REQUIRE(v == 1); - Sint::Read(v, buf + int_size); + Sint::read(v, buf + int_size); REQUIRE(v == 3); - Sint::Read(v, buf + 2 * int_size); + Sint::read(v, buf + 2 * int_size); REQUIRE(v == 5); - Sint::Read(v, buf + 3 * int_size); + Sint::read(v, buf + 3 * int_size); REQUIRE(v == 7); } @@ -59,29 +58,31 @@ TEST_CASE("Serialization simple types struct", "[misc]") { SomeStruct s{1, true, 2.5}; unsigned char buf[sizeof(SomeStruct)]; - REQUIRE(Sss::SizeOf(s) == sizeof(SomeStruct)); + REQUIRE(Sss::sizeOf(s) == sizeof(SomeStruct)); - Sss::Write(s, buf); + Sss::write(s, buf); SomeStruct sr; - Sss::Read(sr, buf); + Sss::read(sr, buf); REQUIRE(s.vi == sr.vi); REQUIRE(s.vb == sr.vb); REQUIRE(s.vd == sr.vd); } +constexpr std::size_t VEC_OVERHEAD = sizeof(seri::StlVecSizeType); + TEST_CASE("Serialization vector", "[misc]") { using Sv = seri::Serializer>; std::vector v = {1, 2, 3, 4}; - REQUIRE(Sv::SizeOf(v) == 4 * sizeof(int) + sizeof(std::size_t)); - unsigned char buf[4 * sizeof(int) + sizeof(std::size_t)]; + REQUIRE(Sv::sizeOf(v) == 4 * sizeof(int) + VEC_OVERHEAD); + unsigned char buf[4 * sizeof(int) + VEC_OVERHEAD]; - Sv::Write(v, buf); + Sv::write(v, buf); std::vector w; - Sv::Read(w, buf); + Sv::read(w, buf); REQUIRE(w == v); } @@ -90,14 +91,14 @@ TEST_CASE("Serialization vector vector", "[misc]") { using Sv = seri::Serializer>>; std::vector> v = {{1, 2, 3}, {2, 3}, {5, 6, 7}}; - const auto expected_size = 8 * sizeof(int) + 4 * sizeof(std::size_t); - REQUIRE(Sv::SizeOf(v) == expected_size); + const auto expected_size = 8 * sizeof(int) + 4 * VEC_OVERHEAD; + REQUIRE(Sv::sizeOf(v) == expected_size); unsigned char buf[expected_size]; - Sv::Write(v, buf); + Sv::write(v, buf); std::vector> w; - Sv::Read(w, buf); + Sv::read(w, buf); REQUIRE(v == w); } @@ -107,15 +108,15 @@ TEST_CASE("Serialization Vec", "[misc]") { using Sv = seri::Serializer>; std::vector v = {Fp(1), Fp(2), Fp(3)}; - const auto expected_size = sizeof(std::size_t) + Fp::ByteSize() * 3; - REQUIRE(Sv::SizeOf(v) == expected_size); + const auto expected_size = VEC_OVERHEAD + Fp::byteSize() * 3; + REQUIRE(Sv::sizeOf(v) == expected_size); unsigned char buf[expected_size]; - Sv::Write(v, buf); + Sv::write(v, buf); std::vector w; - Sv::Read(w, buf); + Sv::read(w, buf); REQUIRE(v == w); } @@ -124,11 +125,11 @@ TEST_CASE("Serialization number", "[misc]") { using Sn = seri::Serializer; math::Number a(1234); - auto buf = std::make_unique(Sn::SizeOf(a)); + auto buf = std::make_unique(Sn::sizeOf(a)); - Sn::Write(a, buf.get()); + Sn::write(a, buf.get()); math::Number b; - Sn::Read(b, buf.get()); + Sn::read(b, buf.get()); REQUIRE(a == b); } @@ -140,11 +141,11 @@ TEST_CASE("Serialization number vector", "[misc]") { math::Number(123), math::Number(-10)}; - auto buf = std::make_unique(Sn::SizeOf(nums)); - Sn::Write(nums, buf.get()); + auto buf = std::make_unique(Sn::sizeOf(nums)); + Sn::write(nums, buf.get()); std::vector r; - Sn::Read(r, buf.get()); + Sn::read(r, buf.get()); REQUIRE(nums == r); } diff --git a/test/scl/simulation/test_channel.cc b/test/scl/simulation/test_channel.cc index 4079323..5f22dc4 100644 --- a/test/scl/simulation/test_channel.cc +++ b/test/scl/simulation/test_channel.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,186 +15,74 @@ * along with this program. If not, see . */ -#include +#include #include +#include "scl/coro/runtime.h" +#include "scl/net/loopback.h" +#include "scl/net/packet.h" #include "scl/simulation/channel.h" #include "scl/simulation/config.h" #include "scl/simulation/context.h" #include "scl/simulation/event.h" -#include "scl/simulation/mem_channel_buffer.h" -#include "scl/simulation/simulator.h" +#include "scl/simulation/runtime.h" using namespace scl; +using namespace std::chrono_literals; namespace { -struct InstantNetworkConfig final : sim::NetworkConfig { - sim::ChannelConfig Get(sim::ChannelId channel_id) override { - (void)channel_id; - return sim::ChannelConfig::Loopback(); - } -}; - -auto StartEvent(util::Time::Duration ts) { - return std::make_shared(sim::Event::Type::START, ts); +std::array createChannels( + sim::details::GlobalContext& gctx) { + auto transport = std::make_shared(); + sim::details::SimulatedChannel channel01({0, 1}, gctx.view(0), transport); + sim::details::SimulatedChannel channel10({1, 0}, gctx.view(1), transport); + return {channel01, channel10}; } -auto StopEvent(util::Time::Duration ts) { - return std::make_shared(sim::Event::Type::STOP, ts); +std::size_t getChannelDataEventAmount(std::shared_ptr event_ptr) { + return std::dynamic_pointer_cast(event_ptr)->amount; } } // namespace -TEST_CASE("Channel recv packet blocking", "[sim]") { - auto cfg = std::make_shared(); - auto ctx = sim::Context::Create(2, cfg); - auto chl0 = sim::Channel({0, 1}, ctx); - auto chl1 = sim::Channel({1, 0}, ctx); +TEST_CASE("SimulatedChannel send/recv", "[sim]") { + auto gctx = sim::details::GlobalContext::create( + 2, + std::make_unique(), + {}); + auto channels = createChannels(gctx); - net::Packet p; - p << 123; - ctx->AddEvent(0, StartEvent(util::Time::Duration::zero())); - chl0.Send(p); - const auto t0 = ctx->Trace(0); - REQUIRE(t0.size() == 2); - REQUIRE(t0[0]->EventType() == sim::Event::Type::START); - REQUIRE(t0[1]->EventType() == sim::Event::Type::PACKET_SEND); - - ctx->AddEvent(1, StartEvent(util::Time::Duration::zero())); - chl1.Recv(); - - const auto t1 = ctx->Trace(1); - REQUIRE(t1.size() == 2); - REQUIRE(t1[0]->EventType() == sim::Event::Type::START); - REQUIRE(t1[1]->EventType() == sim::Event::Type::PACKET_RECV); -} + gctx.view(0).recordEvent(sim::Event::start()); + gctx.view(1).recordEvent(sim::Event::start()); -TEST_CASE("Channel recv packet non-blocking", "[sim]") { - auto cfg = std::make_shared(); - auto ctx = sim::Context::Create(2, cfg); - auto chl0 = sim::Channel({0, 1}, ctx); - auto chl1 = sim::Channel({1, 0}, ctx); + auto rt = sim::details::SimulatorRuntime(gctx); net::Packet p; - p << 123; - ctx->AddEvent(0, StartEvent(util::Time::Duration(1000))); - chl0.Send(p); - - ctx->AddEvent(1, StartEvent(util::Time::Duration::zero())); - auto pkt = chl1.Recv(false); - - REQUIRE_FALSE(pkt.has_value()); - auto t0 = ctx->Trace(1); - REQUIRE(t0.size() == 2); - REQUIRE(t0[0]->EventType() == sim::Event::Type::START); - REQUIRE(t0[1]->EventType() == sim::Event::Type::PACKET_RECV); - - ctx->AddEvent(1, StartEvent(ctx->LatestTimestamp(0))); - auto pkt0 = chl1.Recv(false); - - REQUIRE(pkt0.has_value()); - t0 = ctx->Trace(1); - REQUIRE(t0.size() == 4); - REQUIRE(t0[2]->EventType() == sim::Event::Type::START); - REQUIRE(t0[3]->EventType() == sim::Event::Type::PACKET_RECV); -} - -TEST_CASE("Channel recv chunked", "[sim]") { - auto cfg = std::make_shared(); - auto ctx = sim::Context::Create(2, cfg); - auto chl0 = sim::Channel({0, 1}, ctx); - auto chl1 = sim::Channel({1, 0}, ctx); - - unsigned char data[] = {1, 2, 3, 4}; - ctx->AddEvent(0, StartEvent(util::Time::Duration::zero())); - chl0.Send(data, 4); - - ctx->AddEvent(1, StartEvent(util::Time::Duration::zero())); - unsigned char recv[4] = {0}; - - REQUIRE(ctx->HasWrite({0, 1})); - REQUIRE(ctx->NextWrite({0, 1}).amount == 4); - chl1.Recv(recv, 2); - - REQUIRE(ctx->NextWrite({0, 1}).amount == 2); - chl1.Recv(recv + 2, 2); - REQUIRE_FALSE(ctx->HasWrite({0, 1})); - - REQUIRE(data[0] == recv[0]); - REQUIRE(data[1] == recv[1]); - REQUIRE(data[2] == recv[2]); - REQUIRE(data[3] == recv[3]); -} - -TEST_CASE("Channel HasData no data, but not far ahead", "[sim]") { - auto cfg = std::make_shared(); - auto ctx = sim::Context::Create(2, cfg); - - sim::Channel p0({0, 1}, ctx); - sim::Channel p1({1, 0}, ctx); + p << 1 << 2 << 3; - // P1 at time 100000, P0 at time 0. So we can say for sure that P1 does not - // have data for P0. + const std::size_t expected_size = + sizeof(net::Packet::SizeType) + 3 * sizeof(int); - ctx->AddEvent(1, StartEvent(util::Time::Duration(100000))); - ctx->AddEvent(0, StartEvent(util::Time::Duration::zero())); - - ctx->UpdateCheckpoint(); - auto hd = p0.HasData(); - REQUIRE_FALSE(hd); -} - -TEST_CASE("Channel HasData no data, other party terminated", "[sim]") { - auto cfg = std::make_shared(); - auto ctx = sim::Context::Create(2, cfg); - - sim::Channel p0({0, 1}, ctx); - sim::Channel p1({1, 0}, ctx); - - ctx->AddEvent(1, StopEvent(util::Time::Duration::zero())); - ctx->AddEvent(0, StartEvent(util::Time::Duration::zero())); - - ctx->UpdateCheckpoint(); - auto hd = p0.HasData(); - REQUIRE_FALSE(hd); -} - -TEST_CASE("Channel HasData no data, fails", "[sim]") { - auto cfg = std::make_shared(); - auto ctx = sim::Context::Create(3, cfg); - - sim::Channel p0({0, 1}, ctx); - sim::Channel p1({1, 0}, ctx); - - // P1 at time 100000, P0 at time 0. So we can say for sure that P1 does not - // have data for P0. - - ctx->AddEvent(1, StartEvent(util::Time::Duration::zero())); - ctx->AddEvent(0, StartEvent(util::Time::Duration::zero())); - - ctx->UpdateCheckpoint(); - REQUIRE_THROWS_MATCHES(p0.HasData(), - sim::SimulationFailure, - Catch::Matchers::Message("no data, and we're ahead")); - auto next = ctx->NextToRun(0); - REQUIRE(next.value_or(-1) == 1); -} + gctx.view(0).startClock(); + rt.run(channels[0].send(std::move(p))); + REQUIRE(gctx.traces[0].size() == 2); + REQUIRE(gctx.traces[0].back()->type == sim::EventType::SEND); + REQUIRE(getChannelDataEventAmount(gctx.traces[0].back()) == expected_size); -TEST_CASE("Channel HasData other party not started", "[sim]") { - auto cfg = std::make_shared(); - auto ctx = sim::Context::Create(3, cfg); + REQUIRE(gctx.sends[{0, 1}].size() == 1); + auto send_ts = gctx.sends[{0, 1}].front(); + REQUIRE(gctx.traces[0].back()->timestamp == send_ts); - sim::Channel p0({0, 1}, ctx); - sim::Channel p1({1, 0}, ctx); + gctx.view(1).startClock(); + auto pr = rt.run(channels[1].recv()); + REQUIRE(pr.read() == 1); + REQUIRE(pr.read() == 2); + REQUIRE(pr.read() == 3); - ctx->AddEvent(0, StartEvent(util::Time::Duration::zero())); + REQUIRE(gctx.traces[1].size() == 2); + REQUIRE(gctx.traces[1].back()->type == sim::EventType::RECV); + REQUIRE(getChannelDataEventAmount(gctx.traces[1].back()) == expected_size); - ctx->UpdateCheckpoint(); - REQUIRE_THROWS_MATCHES( - p0.HasData(), - sim::SimulationFailure, - Catch::Matchers::Message("other party hasnt started yet")); - auto next = ctx->NextToRun(0); - REQUIRE(next.value_or(-1) == 1); + REQUIRE(gctx.sends[{0, 1}].empty()); } diff --git a/test/scl/simulation/test_config.cc b/test/scl/simulation/test_config.cc index daac6ce..18c0611 100644 --- a/test/scl/simulation/test_config.cc +++ b/test/scl/simulation/test_config.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,12 +15,13 @@ * along with this program. If not, see . */ -#include +#include +#include #include #include #include "scl/simulation/config.h" -#include "scl/simulation/simulator.h" +#include "scl/util/time.h" using namespace scl; using namespace std::chrono_literals; @@ -28,7 +29,7 @@ using namespace std::chrono_literals; namespace { template -void ApproxDuration(util::Time::Duration d, T v, T b) { +void approxDuration(util::Time::Duration d, T v, T b) { if (v > d) { REQUIRE(v - d <= b); } else { @@ -36,87 +37,47 @@ void ApproxDuration(util::Time::Duration d, T v, T b) { } } -std::size_t KB(std::size_t bytes) { - return 1000 * bytes; -} - -std::size_t MB(std::size_t bytes) { - return 1000 * KB(bytes); -} - } // namespace -TEST_CASE("ComputeRecvTime default config", "[sim]") { - // https://wintelguy.com/wanperf.pl - // parameters: - // Link bandwidth (Mbit/s): 1 - // RTT (millisecond): 100 - // Packet loss (%): 0 - // MTU (Byte): 1500 - // L1/L2 frame overhead (Byte): 0 <-- not accounted for in scl - // TCP/IP (v4) header overhead (Byte): 40 - // TCP window (RWND) size (Byte): 65536 - // File size (MByte): 1 - - const auto cfg = sim::ChannelConfig::Default(); - const auto tenMB = MB(10); - const auto t = sim::ComputeRecvTime(cfg, tenMB); - ApproxDuration(t, 82s, 1s); -} - -TEST_CASE("ComputeRecvTime lossy", "[sim]") { - const auto cfg = sim::ChannelConfig::Builder().PackageLoss(0.001).Build(); - const auto tenMB = MB(10); - const auto t = sim::ComputeRecvTime(cfg, tenMB); - ApproxDuration(t, 82s, 1s); -} - -TEST_CASE("ComputeRecvTime lo", "[sim]") { - const auto cfg = sim::ChannelConfig::Loopback(); - const auto amount = MB(10000); - const auto t = sim::ComputeRecvTime(cfg, amount); - REQUIRE(t.count() == 0); -} - TEST_CASE("SimulationConfig default", "[sim]") { - auto cfg = sim::ChannelConfig::Default(); + auto cfg = sim::ChannelConfig::defaultConfig(); - REQUIRE(cfg.Bandwidth() == sim::ChannelConfig::DEFAULT_BANDWIDTH); + REQUIRE(cfg.bandwidth() == sim::ChannelConfig::DEFAULT_BANDWIDTH); REQUIRE(cfg.RTT() == sim::ChannelConfig::DEFAULT_RTT); REQUIRE(cfg.MSS() == sim::ChannelConfig::DEFAULT_MSS); - REQUIRE(cfg.PackageLoss() == sim::ChannelConfig::DEFAULT_PACKAGE_LOSS); - REQUIRE(cfg.WindowSize() == sim::ChannelConfig::DEFAULT_WINDOW_SIZE); + REQUIRE(cfg.packetLoss() == sim::ChannelConfig::DEFAULT_PACKAGE_LOSS); + REQUIRE(cfg.windowSize() == sim::ChannelConfig::DEFAULT_WINDOW_SIZE); } TEST_CASE("SimulationConfig setters", "[sim]") { - auto cfg_it = sim::ChannelConfig::Builder{}.MSS(5000).Build(); + auto cfg_it = sim::ChannelConfig::Builder{}.MSS(5000).build(); REQUIRE(cfg_it.MSS() == 5000); - REQUIRE(cfg_it.Bandwidth() == sim::ChannelConfig::DEFAULT_BANDWIDTH); + REQUIRE(cfg_it.bandwidth() == sim::ChannelConfig::DEFAULT_BANDWIDTH); // Assume rest of properties are also defaulted correctly. } TEST_CASE("SimulationConfig validation", "[sim]") { - REQUIRE_THROWS_MATCHES(sim::ChannelConfig::Builder{}.Bandwidth(0).Build(), + REQUIRE_THROWS_MATCHES(sim::ChannelConfig::Builder{}.bandwidth(0).build(), std::invalid_argument, Catch::Matchers::Message("bandwidth cannot be 0")); - REQUIRE_THROWS_MATCHES(sim::ChannelConfig::Builder{}.MSS(0).Build(), + REQUIRE_THROWS_MATCHES(sim::ChannelConfig::Builder{}.MSS(0).build(), std::invalid_argument, Catch::Matchers::Message("MSS cannot be 0")); REQUIRE_THROWS_MATCHES( - sim::ChannelConfig::Builder{}.PackageLoss(-0.1).Build(), + sim::ChannelConfig::Builder{}.packetLoss(-0.1).build(), std::invalid_argument, Catch::Matchers::Message("package loss percentage cannot be negative")); REQUIRE_THROWS_MATCHES( - sim::ChannelConfig::Builder{}.PackageLoss(1).Build(), + sim::ChannelConfig::Builder{}.packetLoss(1).build(), std::invalid_argument, Catch::Matchers::Message("package loss percentage cannot exceed 100%")); REQUIRE_THROWS_MATCHES( - sim::ChannelConfig::Builder{}.WindowSize(0).Build(), + sim::ChannelConfig::Builder{}.windowSize(0).build(), std::invalid_argument, Catch::Matchers::Message("TCP window size cannot be 0")); } @@ -124,12 +85,12 @@ TEST_CASE("SimulationConfig validation", "[sim]") { TEST_CASE("SimulationConfig to string", "[sim]") { std::stringstream ss; auto cfg = sim::ChannelConfig::Builder{} - .Bandwidth(2) + .bandwidth(2) .MSS(10) .RTT(50) - .PackageLoss(0.01) - .WindowSize(500) - .Build(); + .packetLoss(0.01) + .windowSize(500) + .build(); ss << cfg; // clang-format off REQUIRE(ss.str() == "SimulationConfig{" @@ -143,8 +104,8 @@ TEST_CASE("SimulationConfig to string", "[sim]") { } TEST_CASE("SimulationConfig local", "[sim]") { - auto cfg = sim::ChannelConfig::Loopback(); - REQUIRE(cfg.Type() == sim::ChannelConfig::NetworkType::INSTANT); + auto cfg = sim::ChannelConfig::loopback(); + REQUIRE(cfg.type() == sim::ChannelConfig::NetworkType::INSTANT); std::stringstream ss; ss << cfg; REQUIRE(ss.str() == "SimulationConfig{INSTANT}"); diff --git a/test/scl/simulation/test_context.cc b/test/scl/simulation/test_context.cc index 01d9a89..3af55f7 100644 --- a/test/scl/simulation/test_context.cc +++ b/test/scl/simulation/test_context.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,306 +15,67 @@ * along with this program. If not, see . */ -#include +#include #include #include -#include "scl/simulation/buffer.h" #include "scl/simulation/channel_id.h" #include "scl/simulation/config.h" #include "scl/simulation/context.h" #include "scl/simulation/event.h" -#include "scl/simulation/mem_channel_buffer.h" -#include "scl/simulation/simulator.h" using namespace scl; +using namespace std::chrono_literals; -namespace { - -auto SomeEvent() { - return std::make_shared(sim::Event::Type::START, - util::Time::Duration::zero()); -} - -auto DefaultNetworkConfig() { - return std::make_shared(); -} - -} // namespace - -TEST_CASE("Simulation context add events", "[sim]") { - auto ctx = sim::Context::Create( - 5, - DefaultNetworkConfig()); - - ctx->AddEvent(2, SomeEvent()); - ctx->AddEvent(2, SomeEvent()); - ctx->AddEvent(1, SomeEvent()); - - const auto t = ctx->Trace(); - REQUIRE(t.size() == 5); - REQUIRE(t[2].size() == 2); - REQUIRE(t[1].size() == 1); - REQUIRE(t[0].empty()); -} - -TEST_CASE("Simulation context total run time", "[sim]") { - auto ctx = sim::Context::Create( +TEST_CASE("Context", "[sim]") { + auto gctx = sim::details::GlobalContext::create( 5, - DefaultNetworkConfig()); - - ctx->AddEvent(0, SomeEvent()); - auto t0 = ctx->Checkpoint(0); - auto t1 = ctx->Checkpoint(0); - - // t0 is the difference between 0 and now (so very large). t1 is the - // difference between t0 and now (so kinda small). - REQUIRE(t0 > t1); -} - -namespace { - -struct DummyChannelBuffer final : public sim::ChannelBuffer { - void Read(unsigned char* data, std::size_t n) override { - (void)data; - (void)n; - throw std::logic_error("not supported"); - } - - void Write(const unsigned char* data, std::size_t n) override { - (void)data; - (void)n; - throw std::logic_error("not supported"); - } - - std::size_t Size() override { - throw std::logic_error("not supported"); - } - - void Prepare() override { - prepare_called++; - } - - void Commit() override { - commit_called++; - } - - void Rollback() override { - rollback_called++; - } - - std::size_t prepare_called = 0; - std::size_t commit_called = 0; - std::size_t rollback_called = 0; -}; - -} // namespace - -namespace scl { - -template <> -std::shared_ptr sim::Context::Create( - std::size_t n, - std::shared_ptr config) { - auto ctx = std::make_shared(config); - - ctx->m_nparties = n; - ctx->m_traces.resize(n); - - for (std::size_t i = 0; i < n; ++i) { - for (std::size_t j = 0; j < n; ++j) { - sim::ChannelId cid(i, j); - ctx->m_buffers[cid] = std::make_shared(); - } - } - return ctx; -} - -} // namespace scl - -#define AS_DUMMY(ctx, i, j) \ - std::dynamic_pointer_cast( \ - (ctx)->Buffer(sim::ChannelId((i), (j)))) - -TEST_CASE("Simulation context prepare-commit-rollback", "[sim]") { - auto ctx = - sim::Context::Create(5, DefaultNetworkConfig()); - - ctx->Prepare(0); - - REQUIRE(AS_DUMMY(ctx, 0, 0)->prepare_called == 1); - REQUIRE(AS_DUMMY(ctx, 0, 1)->prepare_called == 1); - REQUIRE(AS_DUMMY(ctx, 0, 2)->prepare_called == 1); - REQUIRE(AS_DUMMY(ctx, 0, 3)->prepare_called == 1); - REQUIRE(AS_DUMMY(ctx, 0, 4)->prepare_called == 1); - REQUIRE(AS_DUMMY(ctx, 1, 0)->prepare_called == 0); - - ctx->AddEvent(0, SomeEvent()); - ctx->AddEvent(0, SomeEvent()); - ctx->AddEvent(0, SomeEvent()); - - REQUIRE(ctx->Trace()[0].size() == 3); + std::make_unique(), + {}); - ctx->Commit(0); - REQUIRE(AS_DUMMY(ctx, 0, 0)->commit_called == 1); - REQUIRE(AS_DUMMY(ctx, 0, 1)->commit_called == 1); - REQUIRE(AS_DUMMY(ctx, 0, 2)->commit_called == 1); - REQUIRE(AS_DUMMY(ctx, 0, 3)->commit_called == 1); - REQUIRE(AS_DUMMY(ctx, 0, 4)->commit_called == 1); - REQUIRE(AS_DUMMY(ctx, 1, 0)->commit_called == 0); + REQUIRE(gctx.traces.size() == 5); - REQUIRE(ctx->Trace()[0].size() == 3); + auto view0 = gctx.view(0); + view0.recordEvent(sim::Event::start()); + view0.startClock(); + REQUIRE(view0.elapsedTime() > 0ms); - ctx->Prepare(2); - - ctx->AddEvent(2, SomeEvent()); - ctx->AddEvent(2, SomeEvent()); - - REQUIRE(ctx->Trace()[2].size() == 2); - - ctx->Rollback(2); - REQUIRE(AS_DUMMY(ctx, 2, 0)->rollback_called == 1); - REQUIRE(AS_DUMMY(ctx, 2, 1)->rollback_called == 1); - REQUIRE(AS_DUMMY(ctx, 2, 2)->rollback_called == 1); - REQUIRE(AS_DUMMY(ctx, 2, 3)->rollback_called == 1); - REQUIRE(AS_DUMMY(ctx, 2, 4)->rollback_called == 1); - REQUIRE(AS_DUMMY(ctx, 0, 1)->rollback_called == 0); - - REQUIRE(ctx->Trace()[2].empty()); - REQUIRE(ctx->Trace()[0].size() == 3); -} - -TEST_CASE("Simulation context invalid prepare-commit-rollback", "[sim]") { - auto ctx = - sim::Context::Create(5, DefaultNetworkConfig()); - - REQUIRE_THROWS_MATCHES(ctx->Commit(0), - std::logic_error, - Catch::Matchers::Message("cannot commit")); - REQUIRE_THROWS_MATCHES(ctx->Rollback(0), - std::logic_error, - Catch::Matchers::Message("cannot rollback")); - - ctx->Prepare(0); - - REQUIRE_THROWS_MATCHES(ctx->Prepare(0), - std::logic_error, - Catch::Matchers::Message("cannot prepare ctx")); + view0.recordEvent(sim::Event::closeChannel(100ms, {0, 0})); + auto view1 = gctx.view(1); + REQUIRE(view1.currentTimeOf(0) == 100ms); } -namespace { - -auto StopEvent() { - return std::make_shared(sim::Event::Type::STOP, - util::Time::Duration::zero()); -} - -auto StartEvent() { - return std::make_shared(sim::Event::Type::START, - util::Time::Duration::zero()); -} - -} // namespace - -TEST_CASE("Simulation context NextToRun simple", "[sim]") { - auto ctx = - sim::Context::Create(3, DefaultNetworkConfig()); - - // First party to run is always party 0 - auto next = ctx->NextToRun(); - REQUIRE(next.has_value()); - REQUIRE(next.value() == 0); // NOLINT - - ctx->AddEvent(0, StartEvent()); - ctx->AddEvent(1, StartEvent()); - ctx->AddEvent(2, StartEvent()); - - // next to run is going to be party 1 - next = ctx->NextToRun(next); - REQUIRE(next.has_value()); - REQUIRE(next.value() == 1); // NOLINT - - // next would be party 2, but it has finished running, so party 0 is next. - ctx->AddEvent(2, StopEvent()); - next = ctx->NextToRun(next); - REQUIRE(next.has_value()); - REQUIRE(next.value() == 0); // NOLINT - - ctx->AddEvent(0, StopEvent()); - next = ctx->NextToRun(next); - REQUIRE(next.has_value()); - REQUIRE(next.value() == 1); // NOLINT - - ctx->AddEvent(1, StopEvent()); - next = ctx->NextToRun(next); - REQUIRE_FALSE(next.has_value()); -} - -TEST_CASE("Simulation context NextToRun fails", "[sim]") { - auto ctx = - sim::Context::Create(3, DefaultNetworkConfig()); - - // 0 running - auto next = ctx->NextToRun(); - ctx->Prepare(0); - - ctx->AddEvent(0, StartEvent()); - ctx->AddCandidateToRun(2); - - ctx->Rollback(0); - - // 2 running - next = ctx->NextToRun(next); - if (!next.has_value()) { - FAIL("no output"); - } else { - REQUIRE(next.value() == 2); - } - - ctx->Prepare(2); - - // party 2 tries to receive from itself, but fails -- so it's gonna get stuck - // in an infinite loop - ctx->AddCandidateToRun(2); - ctx->Rollback(2); - - REQUIRE_THROWS_MATCHES(ctx->NextToRun(next), - scl::sim::SimulationFailure, - Catch::Matchers::Message("infinite loop detected")); - - // Party 1 has stopped running, but party 2 expects data -- so party 2 will - // never be able to finish running because the protocol is malformed. - ctx->AddEvent(1, StopEvent()); - - ctx->Prepare(2); - ctx->AddCandidateToRun(1); +TEST_CASE("Context send", "[sim]") { + auto gctx = sim::details::GlobalContext::create( + 5, + std::make_unique(), + {}); - ctx->Rollback(2); + auto view0 = gctx.view(0); + view0.send(1, 100ms); + REQUIRE(gctx.sends[{0, 1}].front() == 100ms); - REQUIRE_THROWS_MATCHES( - ctx->NextToRun(next), - scl::sim::SimulationFailure, - Catch::Matchers::Message( - "party tried to receive data from terminated party")); + view0.send(1, 150ms); + REQUIRE(gctx.sends[{0, 1}].front() == 100ms); + gctx.sends[{0, 1}].pop_front(); + REQUIRE(gctx.sends[{0, 1}].front() == 150ms); } -TEST_CASE("Simulation context rollback write ops", "[sim]") { - auto ctx = - sim::Context::Create(3, DefaultNetworkConfig()); - - const auto ts = util::Time::Duration::zero(); +TEST_CASE("Context recv", "[sim]") { + auto gctx = sim::details::GlobalContext::create( + 5, + std::make_unique(), + {}); - // party 0 sends to party 1 - ctx->Prepare(0); - ctx->AddWrite({0, 1}, 10, ts); - ctx->Commit(0); + auto view0 = gctx.view(0); + auto view1 = gctx.view(1); - // party 1 receives data from party 0, but then performs a rollback. - ctx->Prepare(1); - REQUIRE(ctx->NextWrite({0, 1}).amount == 10); - ctx->NextWrite({0, 1}).amount = 0; - ctx->Rollback(1); + view0.send(1, 100ms); + auto dur = view1.recv(0, 10, 100ms); + REQUIRE(dur > 100ms); - // the change to the write op above should be undone by the rollback. - REQUIRE(ctx->NextWrite({0, 1}).amount == 10); + view0.send(1, 0ms); + dur = view1.recv(0, 10, 1s); + // data will already have arrived when recv is called. + REQUIRE(dur == 1s); } diff --git a/test/scl/simulation/test_env.cc b/test/scl/simulation/test_env.cc index 3019350..841b031 100644 --- a/test/scl/simulation/test_env.cc +++ b/test/scl/simulation/test_env.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,7 +15,7 @@ * along with this program. If not, see . */ -#include +#include #include #include "scl/simulation/config.h" @@ -103,9 +103,9 @@ TEST_CASE("Simulation env thread", "[sim]") { sim::Clock clock(ctx, 0); ctx->AddEvent(0, SomeEvent(util::Time::Duration(1000ms))); - thread.Sleep(1000000); + thread.Sleep(util::Time::Duration(400h)); auto t0 = clock.Read(); - REQUIRE(t0 > 1000ms + 1000000ms); - REQUIRE(t0 < 1050ms + 1000000ms); + REQUIRE(t0 > 1000ms + 400h); + REQUIRE(t0 < 1050ms + 400h); } diff --git a/test/scl/simulation/test_event.cc b/test/scl/simulation/test_event.cc index e35dd26..50c9aff 100644 --- a/test/scl/simulation/test_event.cc +++ b/test/scl/simulation/test_event.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,132 +15,160 @@ * along with this program. If not, see . */ -#include +#include #include #include "scl/simulation/event.h" using namespace scl; +using namespace std::chrono_literals; namespace { -std::string ToString(sim::Event* e) { - std::stringstream ss; - ss << e; - return ss.str(); +auto str(std::shared_ptr e) { + std::stringstream s; + s << e; + return s.str(); } } // namespace -TEST_CASE("Simulation Event", "[sim]") { +TEST_CASE("Simulation events", "[sim]") { SECTION("START") { - sim::Event e(sim::Event::Type::START, util::Time::Duration::zero()); - REQUIRE(ToString(&e) == "START at 0 ms"); + auto e = sim::Event::start(); + REQUIRE(str(e) == + "{" + "\"timestamp\":0," + "\"type\":\"START\"," + "\"metadata\":{}" + "}"); } SECTION("STOP") { - sim::Event e(sim::Event::Type::STOP, util::Time::Duration::zero()); - REQUIRE(ToString(&e) == "STOP at 0 ms"); + auto e = sim::Event::stop(123ms); + REQUIRE(str(e) == + "{" + "\"timestamp\":123," + "\"type\":\"STOP\"," + "\"metadata\":{}" + "}"); } - const scl::sim::ChannelId cid{2, 5}; + SECTION("CANCELLED") { + auto e = sim::Event::cancelled(123ms); + REQUIRE(str(e) == + "{" + "\"timestamp\":123," + "\"type\":\"CANCELLED\"," + "\"metadata\":{}" + "}"); + } - SECTION("SEND") { - sim::NetworkDataEvent e(sim::Event::Type::SEND, - util::Time::Duration::zero(), - cid, - 100); - REQUIRE(ToString(&e) == "SEND at 0 ms [Sender=2, Receiver=5, Amount=100]"); + SECTION("KILLED") { + auto e = sim::Event::killed(123ms, "foo"); + REQUIRE(str(e) == + "{" + "\"timestamp\":123," + "\"type\":\"KILLED\"," + "\"metadata\":{" + "\"reason\":\"foo\"" + "}" + "}"); } - SECTION("PACKET_SEND") { - sim::NetworkDataEvent e(sim::Event::Type::PACKET_SEND, - util::Time::Duration::zero(), - cid, - 100); - REQUIRE(ToString(&e) == - "PACKET_SEND at 0 ms [Sender=2, Receiver=5, Amount=100]"); + SECTION("CLOSE") { + auto e = sim::Event::closeChannel(123ms, {1, 2}); + REQUIRE(str(e) == + "{" + "\"timestamp\":123," + "\"type\":\"CLOSE\"," + "\"metadata\":{" + "\"channel_id\":{\"local\":1,\"remote\":2}" + "}" + "}"); } - SECTION("PACKET_RECV") { - sim::PacketRecvEvent e(util::Time::Duration::zero(), - util::Time::Duration::zero(), - cid, - 100, - false); - REQUIRE(ToString(&e) == - "PACKET_RECV at 0 ms [Receiver=2, Sender=5, Amount=100, " - "Blocking=false]"); + SECTION("SEND") { + auto e = sim::Event::sendData(123ms, {1, 2}, 10); + REQUIRE(str(e) == + "{" + "\"timestamp\":123," + "\"type\":\"SEND\"," + "\"metadata\":{" + "\"channel_id\":{\"local\":1,\"remote\":2}," + "\"amount\":10" + "}" + "}"); } SECTION("RECV") { - sim::NetworkDataEvent e(sim::Event::Type::RECV, - util::Time::Duration::zero(), - cid, - 100); - REQUIRE(ToString(&e) == "RECV at 0 ms [Receiver=2, Sender=5, Amount=100]"); + auto e = sim::Event::recvData(123ms, {1, 2}, 10); + REQUIRE(str(e) == + "{" + "\"timestamp\":123," + "\"type\":\"RECV\"," + "\"metadata\":{" + "\"channel_id\":{\"local\":1,\"remote\":2}," + "\"amount\":10" + "}" + "}"); } SECTION("HAS_DATA") { - sim::HasDataEvent et(util::Time::Duration::zero(), cid, true); - REQUIRE(ToString(&et) == - "HAS_DATA at 0 ms [Local=2, Remote=5, DataAvailable=true]"); - sim::HasDataEvent ef(util::Time::Duration::zero(), cid, false); - REQUIRE(ToString(&ef) == - "HAS_DATA at 0 ms [Local=2, Remote=5, DataAvailable=false]"); - } - - SECTION("OUTPUT") { - sim::Event e(sim::Event::Type::OUTPUT, util::Time::Duration::zero()); - REQUIRE(ToString(&e) == "OUTPUT at 0 ms"); + auto e = sim::Event::hasData(123ms, {1, 2}); + REQUIRE(str(e) == + "{" + "\"timestamp\":123," + "\"type\":\"HAS_DATA\"," + "\"metadata\":{" + "\"channel_id\":{\"local\":1,\"remote\":2}" + "}" + "}"); } SECTION("SLEEP") { - sim::Event e(sim::Event::Type::SLEEP, util::Time::Duration::zero()); - REQUIRE(ToString(&e) == "SLEEP at 0 ms"); - } - - SECTION("SEGMENT_BEGIN") { - sim::SegmentEvent e(sim::Event::Type::SEGMENT_BEGIN, - util::Time::Duration::zero(), - "foo"); - REQUIRE(ToString(&e) == "SEGMENT_BEGIN at 0 ms [Name=foo]"); - - sim::SegmentEvent unnamed(sim::Event::Type::SEGMENT_BEGIN, - util::Time::Duration::zero(), - ""); - REQUIRE(ToString(&unnamed) == "SEGMENT_BEGIN at 0 ms [Unnamed segment]"); + auto e = sim::Event::sleep(123ms, 100ns); + REQUIRE(str(e) == + "{" + "\"timestamp\":123," + "\"type\":\"SLEEP\"," + "\"metadata\":{" + "\"duration\":0.0001" + "}" + "}"); } - SECTION("SEGMENT_END") { - sim::SegmentEvent e(sim::Event::Type::SEGMENT_END, - util::Time::Duration::zero(), - "foo"); - REQUIRE(ToString(&e) == "SEGMENT_END at 0 ms [Name=foo]"); - } - - SECTION("CLOSE") { - sim::NetworkEvent e(sim::Event::Type::CLOSE, - util::Time::Duration::zero(), - cid); - REQUIRE(ToString(&e) == "CLOSE at 0 ms [Local=2, Remote=5]"); - } - - SECTION("KILLED") { - sim::Event e(sim::Event::Type::KILLED, util::Time::Duration::zero()); - REQUIRE(ToString(&e) == "KILLED at 0 ms"); + SECTION("OUTPUT") { + auto e = sim::Event::output(123ms); + REQUIRE(str(e) == + "{" + "\"timestamp\":123," + "\"type\":\"OUTPUT\"," + "\"metadata\":{}" + "}"); } - SECTION("CHECKPOINT") { - sim::CheckpointEvent e(util::Time::Duration::zero(), "asd"); - REQUIRE(ToString(&e) == "CHECKPOINT at 0 ms [asd]"); + SECTION("PROTOCOL_BEGIN") { + auto e = sim::Event::protocolBegin(123ms, "foo"); + REQUIRE(str(e) == + "{" + "\"timestamp\":123," + "\"type\":\"PROTOCOL_BEGIN\"," + "\"metadata\":{" + "\"name\":\"foo\"" + "}" + "}"); } - SECTION("With offset") { - using namespace std::chrono_literals; - sim::Event e(sim::Event::Type::START, - util::Time::Duration::zero(), - util::Time::Duration(123ms)); - REQUIRE(ToString(&e) == "START at 123 ms [Offset=123 ms]"); + SECTION("PROTOCOL_END") { + auto e = sim::Event::protocolEnd(123ms, "foo"); + REQUIRE(str(e) == + "{" + "\"timestamp\":123," + "\"type\":\"PROTOCOL_END\"," + "\"metadata\":{" + "\"name\":\"foo\"" + "}" + "}"); } } diff --git a/test/scl/simulation/test_manager.cc b/test/scl/simulation/test_manager.cc index b87dc98..f95bc46 100644 --- a/test/scl/simulation/test_manager.cc +++ b/test/scl/simulation/test_manager.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,7 +15,7 @@ * along with this program. If not, see . */ -#include +#include #include #include "scl/simulation/config.h" diff --git a/test/scl/simulation/test_measurement.cc b/test/scl/simulation/test_measurement.cc deleted file mode 100644 index 43f05a8..0000000 --- a/test/scl/simulation/test_measurement.cc +++ /dev/null @@ -1,86 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#include -#include -#include - -#include "scl/simulation/measurement.h" - -using namespace scl; - -TEST_CASE("Measurement to string", "[sim]") { - sim::DataMeasurement dm; - dm.AddSample(123.45); - - std::stringstream ss; - ss << dm; - REQUIRE(ss.str() == "{\"mean\": 123.45, \"unit\": \"B\", \"std_dev\": 0}"); - - sim::TimeMeasurement tm; - tm.AddSample(util::Time::Duration::zero()); - - ss.str(""); - ss << tm; - REQUIRE(ss.str() == "{\"mean\": 0, \"unit\": \"ms\", \"std_dev\": 0}"); -} - -TEST_CASE("Measurement data", "[sim]") { - sim::DataMeasurement dm; - dm.AddSample(2); - dm.AddSample(4); - dm.AddSample(4); - dm.AddSample(4); - dm.AddSample(5); - dm.AddSample(5); - dm.AddSample(7); - dm.AddSample(9); - - REQUIRE(dm.Size() == 8); - REQUIRE(dm.Samples() == std::vector({2, 4, 4, 4, 5, 5, 7, 9})); -} - -TEST_CASE("Measurement time", "[sim]") { - using namespace std::chrono_literals; - - sim::TimeMeasurement tm; - tm.AddSample(2ms); - tm.AddSample(4ms); - tm.AddSample(4ms); - tm.AddSample(4ms); - tm.AddSample(5ms); - tm.AddSample(5ms); - tm.AddSample(7ms); - tm.AddSample(9ms); - - REQUIRE(tm.Size() == 8); - REQUIRE(tm.Samples() == std::vector( - {2ms, 4ms, 4ms, 4ms, 5ms, 5ms, 7ms, 9ms})); -} - -TEST_CASE("Measurement samples", "[sim]") { - sim::DataMeasurement dm; - REQUIRE(dm.Samples().empty()); - - dm.AddSample(42); - REQUIRE(dm.Size() == 1); - REQUIRE(dm.Samples() == std::vector{42}); - - dm.AddSample(22); - REQUIRE(dm.Size() == 2); - REQUIRE(dm.Samples() == std::vector{42, 22}); -} diff --git a/test/scl/simulation/test_mem_channel_buffer.cc b/test/scl/simulation/test_mem_channel_buffer.cc deleted file mode 100644 index e0ab631..0000000 --- a/test/scl/simulation/test_mem_channel_buffer.cc +++ /dev/null @@ -1,135 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#include - -#include "scl/simulation/mem_channel_buffer.h" -#include "scl/util/time.h" - -using namespace scl; - -TEST_CASE("Simulation MemoryBackedChannelBuffer", "[sim]") { - auto p = sim::MemoryBackedChannelBuffer::CreatePaired(); - - auto chl0 = p[0]; - auto chl1 = p[1]; - - REQUIRE(chl0->Size() == 0); - REQUIRE(chl1->Size() == 0); - - std::vector data = {1, 2, 3, 4}; - - chl0->Write(data.data(), data.size()); - - REQUIRE(chl0->Size() == 0); - REQUIRE(chl1->Size() == 4); - - std::vector d(2); - chl1->Read(d.data(), 2); - REQUIRE(d == std::vector{1, 2}); - REQUIRE(chl1->Size() == 2); - - chl1->Read(d.data(), 2); - REQUIRE(d == std::vector{3, 4}); - REQUIRE(chl1->Size() == 0); -} - -TEST_CASE("Simulation MemoryBackedChannelBuffer rollback", "[sim]") { - // In the first two sections below, the prepare/rollback channel only ever - // reads or writes. Never both. - - std::vector data = {1, 2, 3, 4}; - - SECTION("read/writes are correctly rolled back") { - auto p = sim::MemoryBackedChannelBuffer::CreatePaired(); - auto local = p[0]; - auto remote = p[1]; - - local->Prepare(); - - local->Write(data.data(), data.size()); - - REQUIRE(remote->Size() == 4); - - local->Rollback(); - REQUIRE(remote->Size() == 0); - - remote->Write(data.data(), data.size()); - - local->Prepare(); - REQUIRE(local->Size() == 4); - std::vector d(2); - local->Read(d.data(), 2); - REQUIRE(local->Size() == 2); - local->Rollback(); - - REQUIRE(local->Size() == 4); - } - - SECTION("rollback only rolls back since last prepare") { - auto p = sim::MemoryBackedChannelBuffer::CreatePaired(); - auto local = p[0]; - auto remote = p[1]; - - local->Prepare(); - - local->Write(data.data(), data.size()); - local->Commit(); - - local->Prepare(); - local->Write(data.data(), data.size()); - - REQUIRE(remote->Size() == 8); - local->Rollback(); - - REQUIRE(remote->Size() == 4); - - remote->Write(data.data(), data.size()); - - local->Prepare(); - - REQUIRE(local->Size() == 4); - std::vector d(2); - local->Read(d.data(), 2); - REQUIRE(local->Size() == 2); - - local->Rollback(); - REQUIRE(local->Size() == 4); - } - - SECTION("rollback for loopback channel") { - auto lo = sim::MemoryBackedChannelBuffer::CreateLoopback(); - - lo->Prepare(); - lo->Write(data.data(), data.size()); - REQUIRE(lo->Size() == 4); - lo->Commit(); - - lo->Prepare(); - - lo->Write(data.data(), data.size()); - REQUIRE(lo->Size() == 8); - - std::vector d3(3); - lo->Read(d3.data(), 3); - REQUIRE(d3 == std::vector{1, 2, 3}); - REQUIRE(lo->Size() == 5); - - lo->Rollback(); - REQUIRE(lo->Size() == 4); - } -} diff --git a/test/scl/simulation/test_result.cc b/test/scl/simulation/test_result.cc deleted file mode 100644 index 8ee16ae..0000000 --- a/test/scl/simulation/test_result.cc +++ /dev/null @@ -1,189 +0,0 @@ -/* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -#include -#include -#include -#include - -#include "scl/simulation/event.h" -#include "scl/simulation/result.h" - -using namespace scl; - -namespace { - -std::shared_ptr Stop() { - return std::make_shared(sim::Event::Type::STOP, - util::Time::Duration::zero()); -} - -std::shared_ptr Start() { - return std::make_shared(sim::Event::Type::START, - util::Time::Duration::zero()); -} - -std::shared_ptr BeginSegment(const std::string& name = "foo") { - return std::make_shared(sim::Event::Type::SEGMENT_BEGIN, - util::Time::Duration::zero(), - name); -} - -using TraceT = std::vector>; - -#define CREATE_TRACE(...) \ - { \ - { \ - { __VA_ARGS__ } \ - } \ - } - -} // namespace - -TEST_CASE("Simulation result invalid-traces", "[sim]") { - TraceT trace_no_start = CREATE_TRACE(Stop()); - REQUIRE_THROWS_MATCHES(sim::Result::Create(trace_no_start), - std::logic_error, - Catch::Matchers::Message("incomplete trace")); - - TraceT trace_no_stop = CREATE_TRACE(Start()); - REQUIRE_THROWS_MATCHES(sim::Result::Create(trace_no_stop), - std::logic_error, - Catch::Matchers::Message("truncated trace")); - - TraceT trace_invalid_segment = CREATE_TRACE(Start(), BeginSegment(), Stop()); - - REQUIRE_THROWS_MATCHES(sim::Result::Create(trace_invalid_segment), - std::logic_error, - Catch::Matchers::Message("incomplete segment")); -} - -namespace { - -std::shared_ptr EndSegment(const std::string& name = "foo") { - return std::make_shared(sim::Event::Type::SEGMENT_END, - util::Time::Duration::zero(), - name); -} - -std::shared_ptr Send(std::size_t to, - std::size_t from, - std::size_t amount) { - return std::make_shared(sim::Event::Type::SEND, - util::Time::Duration::zero(), - sim::ChannelId{to, from}, - amount); -} - -std::shared_ptr Recv(std::size_t to, - std::size_t from, - std::size_t amount) { - return std::make_shared(sim::Event::Type::RECV, - util::Time::Duration::zero(), - sim::ChannelId{to, from}, - amount); -} - -} // namespace - -TEST_CASE("Simulation result sent recv", "[sim]") { - TraceT trace = CREATE_TRACE(Start(), - BeginSegment(), - Send(0, 1, 123), - Recv(0, 2, 444), - EndSegment(), - BeginSegment("bar"), - Send(0, 3, 42), - Send(0, 1, 22), - EndSegment("bar"), - Stop()); - - auto r = sim::Result::Create(trace); - REQUIRE(r[0].TransferAmounts(2).sent.Samples()[0] == 0); - REQUIRE(r[0].TransferAmounts(2).recv.Samples()[0] == 444); - - REQUIRE(r[0].TransferAmounts(1).sent.Samples()[0] == 123 + 22); - REQUIRE(r[0].TransferAmounts(1, "bar").sent.Samples()[0] == 22); - - std::vector expected = {1, 2, 3}; - REQUIRE_THAT(r[0].Interactions(), Catch::Matchers::UnorderedEquals(expected)); - std::vector expected_bar = {1, 3}; - REQUIRE_THAT(r[0].Interactions("bar"), - Catch::Matchers::UnorderedEquals(expected_bar)); -} - -namespace { - -std::shared_ptr Checkpoint(const std::string& message) { - return std::make_shared(util::Time::Duration::zero(), - message); -} - -} // namespace - -TEST_CASE("Simulation result with checkpoint", "[sim]") { - TraceT trace = CREATE_TRACE(Start(), - BeginSegment(), - Checkpoint("x"), - EndSegment(), - BeginSegment(), - Checkpoint("x"), - Checkpoint("y"), - EndSegment(), - Stop()); - - auto r = sim::Result::Create(trace); - REQUIRE(r[0].Checkpoint("x").Size() == 1); - REQUIRE(r[0].Checkpoint("y").Size() == 1); -} - -TEST_CASE("Simulation result write", "[sim]") { - // TODO: This doesn't really test anything besides that Write is - // stable(-ish). Ideally, the test should check that the result is consistent - // with the content of a file on disk, but that likely requires that Write is - // deterministic, which is not the case because writes a ton of unordered - // maps. - - TraceT trace = CREATE_TRACE(Start(), - BeginSegment(), - Send(0, 1, 123), - Recv(0, 2, 444), - Checkpoint("x"), - EndSegment(), - BeginSegment("bar"), - Send(0, 3, 42), - Send(0, 1, 22), - EndSegment("bar"), - Stop()); - - auto r = sim::Result::Create(trace); - - std::stringstream ss0; - std::stringstream ss1; - r[0].Write(ss0); - r[0].Write(ss1); - REQUIRE(ss0.str() == ss1.str()); -} - -TEST_CASE("Simulation result write trace invalid replication", "[sim]") { - TraceT trace = CREATE_TRACE(Start(), Stop()); - auto r = sim::Result::Create(trace); - - REQUIRE_THROWS_MATCHES(r[0].WriteTrace(std::cout, 42), - std::invalid_argument, - Catch::Matchers::Message("invalid replication")); -} diff --git a/test/scl/simulation/test_simulator.cc b/test/scl/simulation/test_simulator.cc index 13f2607..2b9dde9 100644 --- a/test/scl/simulation/test_simulator.cc +++ b/test/scl/simulation/test_simulator.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -16,9 +16,8 @@ */ #include -#include +#include #include -#include #include #include #include @@ -26,265 +25,155 @@ #include #include -#include "../protocol/beaver.h" -#include "scl/math/fp.h" +#include "scl/coro/coroutine.h" #include "scl/protocol/base.h" #include "scl/protocol/env.h" -#include "scl/simulation/config.h" +#include "scl/protocol/result.h" +#include "scl/simulation/context.h" +#include "scl/simulation/event.h" #include "scl/simulation/manager.h" -#include "scl/simulation/result.h" #include "scl/simulation/simulator.h" -#include "scl/ss/additive.h" -#include "scl/util/prg.h" using namespace scl; using namespace std::chrono_literals; -using FF = math::Fp<61>; -using Parties = std::vector>; - -namespace { +TEST_CASE("Simulator no protocols", "[sim]") { + struct NoProtocolManager final : public sim::Manager { + std::vector> protocol() override { + return {}; + } -template -std::unique_ptr CreateParty(Ts&&... init_args) { - return std::make_unique(std::forward(init_args)...); -} + void handleSimulatorOutput( + std::size_t /* ignored */, + const sim::SimulationTrace& /* ignored */) override { + FAIL(); + } + }; -auto RecvTimeDefaultConf(std::size_t n) { - static const auto dft = sim::ChannelConfig::Default(); - return sim::ComputeRecvTime(dft, n); + sim::simulate(std::make_unique()); } -} // namespace +// Protocol where each party sends a little bit of data to the other. +struct SendRecv final : public proto::Protocol { + coro::Task run(proto::Env& env) const { + net::Packet p; + p << 1 << 2 << 3; -struct SimpleSendRecvProtocol { - struct Sender final : public proto::Protocol { - std::unique_ptr Run(proto::Env& env) override { - net::Packet p; - p << (std::size_t)123; - p << (int)-100; - env.network.Other()->Send(p); - env.network.Close(); - return nullptr; - } - }; + co_await env.network.other()->send(p); + auto r = co_await env.network.other()->recv(); - struct Receiver final : public proto::Protocol { - std::unique_ptr Run(proto::Env& env) override { - auto p = env.network.Other()->Recv(true); + env.network.close(); - if (!p.has_value()) { - throw std::runtime_error("expected data"); - } - auto& v = p.value(); + co_return proto::ProtocolResult::done(); + } +}; - is_correct = v.Read() == 123; - is_correct &= v.Read() == -100; +TEST_CASE("Simulate SendRecv protocol", "[sim]") { + class SendRecvManager final : public sim::Manager { + public: + SendRecvManager(std::size_t n) : n_m(n) {} - env.network.Close(); - return nullptr; + std::vector> protocol() override { + std::vector> p; + for (std::size_t i = 0; i < n_m; i++) { + p.emplace_back(std::make_unique()); + } + return p; }; - std::any Output() const override { - return is_correct; + void handleSimulatorOutput(std::size_t pid, + const sim::SimulationTrace& trace) override { + if (pid == 0 || pid == 1) { + validateTrace(trace); + } else { + FAIL(); + } } - bool is_correct = false; + private: + std::size_t n_m; + + void validateTrace(const sim::SimulationTrace& trace) { + // start, begin, send, recv, close, close, end, stop + REQUIRE(trace.size() == 8); + REQUIRE(trace[0]->type == sim::EventType::START); + REQUIRE(trace[1]->type == sim::EventType::PROTOCOL_BEGIN); + REQUIRE(trace[2]->type == sim::EventType::SEND); + REQUIRE(trace[3]->type == sim::EventType::RECV); + REQUIRE(trace[4]->type == sim::EventType::CLOSE); + REQUIRE(trace[5]->type == sim::EventType::CLOSE); + REQUIRE(trace[6]->type == sim::EventType::PROTOCOL_END); + REQUIRE(trace[7]->type == sim::EventType::STOP); + } }; -}; - -namespace { - -void VerifySendRecvProtocolResult(const std::vector& result) { - REQUIRE(result.size() == 2); - const auto& r0 = result[0]; - REQUIRE(r0.SegmentNames().size() == 1); - REQUIRE(r0.SegmentNames()[0] == proto::Protocol::DEFAULT_NAME); - - const auto et0 = r0.ExecutionTime(); - REQUIRE(et0.Size() == 1); - REQUIRE(et0.Samples()[0] < 1ms); - - const auto et1 = result[1].ExecutionTime(); - REQUIRE(et1.Size() == 1); - const auto bytes_recv = - sizeof(int) + sizeof(std::size_t) + sizeof(net::Packet::SizeType); - REQUIRE(et1.Samples()[0] < RecvTimeDefaultConf(bytes_recv) + 1ms); -} - -} // namespace - -TEST_CASE("Simulate SimpleSendRecvProtocol", "[sim]") { - Parties p; - p.emplace_back(CreateParty()); - p.emplace_back(CreateParty()); - - const auto result = sim::Simulate(std::move(p)); - VerifySendRecvProtocolResult(result); + sim::simulate(std::make_unique(2)); } -TEST_CASE("Simulate SimpleSendRecvProtocol reverse", "[sim]") { - Parties p; - p.emplace_back(CreateParty()); - p.emplace_back(CreateParty()); - - const auto result = sim::Simulate(std::move(p)); - VerifySendRecvProtocolResult({result[1], result[0]}); -} - -namespace { - -void VerifyType(std::shared_ptr event, sim::Event::Type type) { - REQUIRE(event->EventType() == type); -} - -void VerifyTypeString(std::stringstream& ss, - const std::string& event_type_str) { - std::string line; - std::getline(ss, line); - REQUIRE_THAT(line, Catch::Matchers::StartsWith(event_type_str)); -} - -} // namespace - -TEST_CASE("Simulate SimpleSendRecvProtocol trace", "[sim]") { - Parties p; - p.emplace_back(CreateParty()); - p.emplace_back(CreateParty()); - - const auto result = sim::Simulate(std::move(p)); - - SECTION("Sender") { - const auto& sender_trace = result[0].Trace(0); - - VerifyType(sender_trace[0], sim::Event::Type::START); - VerifyType(sender_trace[1], sim::Event::Type::SEGMENT_BEGIN); - VerifyType(sender_trace[2], sim::Event::Type::PACKET_SEND); - VerifyType(sender_trace[3], sim::Event::Type::CLOSE); // to self - VerifyType(sender_trace[4], sim::Event::Type::CLOSE); // to other - VerifyType(sender_trace[5], sim::Event::Type::SEGMENT_END); - VerifyType(sender_trace[6], sim::Event::Type::STOP); - - std::stringstream ss; - result[0].WriteTrace(ss, 0); - VerifyTypeString(ss, "START"); - VerifyTypeString(ss, "SEGMENT_BEGIN"); - VerifyTypeString(ss, "PACKET_SEND"); - VerifyTypeString(ss, "CLOSE"); - VerifyTypeString(ss, "CLOSE"); - VerifyTypeString(ss, "SEGMENT_END"); - VerifyTypeString(ss, "STOP"); +struct Sleepy final : public proto::Protocol { + coro::Task run(proto::Env& /* ignored */) const { + co_await 100s; + co_return proto::ProtocolResult::done(); } +}; - SECTION("Receiver") { - const auto& receiver_trace = result[1].Trace(0); - - VerifyType(receiver_trace[0], sim::Event::Type::START); - VerifyType(receiver_trace[1], sim::Event::Type::SEGMENT_BEGIN); - VerifyType(receiver_trace[2], sim::Event::Type::PACKET_RECV); - VerifyType(receiver_trace[3], sim::Event::Type::CLOSE); // to self - VerifyType(receiver_trace[4], sim::Event::Type::CLOSE); // to other - VerifyType(receiver_trace[5], sim::Event::Type::OUTPUT); - VerifyType(receiver_trace[6], sim::Event::Type::SEGMENT_END); - VerifyType(receiver_trace[7], sim::Event::Type::STOP); - - std::stringstream ss; - result[1].WriteTrace(ss, 0); - - VerifyTypeString(ss, "START"); - VerifyTypeString(ss, "SEGMENT_BEGIN"); - VerifyTypeString(ss, "PACKET_RECV"); - VerifyTypeString(ss, "CLOSE"); - VerifyTypeString(ss, "CLOSE"); - VerifyTypeString(ss, "OUTPUT"); - VerifyTypeString(ss, "SEGMENT_END"); - VerifyTypeString(ss, "STOP"); - } -} - -TEST_CASE("Simulate null protocol", "[sim]") { - Parties p; - p.emplace_back(nullptr); +TEST_CASE("Simulate Sleepy protocol", "[sim]") { + struct SleepyManager final : public sim::Manager { + public: + std::vector> protocol() override { + std::vector> p; + p.emplace_back(std::make_unique()); + return p; + } - const auto result = sim::Simulate(std::move(p)); + void handleSimulatorOutput(std::size_t /* ignored */, + const sim::SimulationTrace& trace) override { + REQUIRE(trace.size() == 5); + REQUIRE(trace[0]->type == sim::EventType::START); + REQUIRE(trace[1]->type == sim::EventType::PROTOCOL_BEGIN); + REQUIRE(trace[2]->type == sim::EventType::SLEEP); + REQUIRE(trace[3]->type == sim::EventType::PROTOCOL_END); + REQUIRE(trace[4]->type == sim::EventType::STOP); + } + }; - REQUIRE(result.size() == 1); - const auto trace = result[0].Trace(0); - REQUIRE(trace[0]->EventType() == sim::Event::Type::START); - REQUIRE(trace[1]->EventType() == sim::Event::Type::STOP); + sim::simulate(std::make_unique()); } -struct PingPongProtocol { - struct Ping final : public proto::Protocol { - std::unique_ptr Run(proto::Env& env) override { - unsigned char data[] = {'a', 'b', 'c'}; - env.network.Other()->Send(data, 3); - env.thread_ctx->Sleep(1000); - return std::make_unique(); +TEST_CASE("Simulate protocol cancellation", "[sim]") { + struct CancelManager final : public sim::Manager { + std::vector> protocol() override { + std::vector> p; + p.emplace_back(std::make_unique()); + p.emplace_back(std::make_unique()); + return p; } - std::string Name() const override { - return "Ping"; + + void handleSimulatorOutput(std::size_t /* ignored */, + const sim::SimulationTrace& trace) override { + // there's two cases here: The party that gets to run first will cancel + // the simulation at the first PROTOCOL_BEGIN event. The other party will + // not get to run at all. We should thus see a trace with 3 events (START, + // PROTOCOL_BEGIN, CANCELLED) and one with 0 events. + + if (trace.size() == 3) { + REQUIRE(trace[0]->type == sim::EventType::START); + REQUIRE(trace[1]->type == sim::EventType::PROTOCOL_BEGIN); + REQUIRE(trace[2]->type == sim::EventType::CANCELLED); + } else if (!trace.empty()) { + FAIL("should not happen"); + } } }; - struct Pong final : public proto::Protocol { - std::unique_ptr Run(proto::Env& env) override { - unsigned char data[3] = {0}; - env.network.Other()->Recv(data, 3); - bool good = data[0] == 'a' && data[1] == 'b' && data[2] == 'c'; - env.clock->Checkpoint(good ? "yay" : "boo"); - return std::make_unique(); - } - std::string Name() const override { - return "Pong"; + struct CancelHook : public sim::Hook { + void run(std::size_t /* ignored */, const sim::SimulationContext& ctx) { + ctx.cancelSimulation(); } }; -}; - -struct PingPongManager final : public sim::Manager { - PingPongManager(std::size_t replications) : sim::Manager(replications) {} - std::vector> Protocol() override { - Parties p; - p.emplace_back(std::make_unique()); - p.emplace_back(std::make_unique()); - return p; - } - - bool Terminate(std::size_t party_id, - const sim::Context::View& view) override { - const auto latest_time = view.Trace(party_id).back()->Timestamp(); - return latest_time > 10s; - } -}; + auto man = std::make_unique(); + man->addHook(sim::EventType::PROTOCOL_BEGIN); -TEST_CASE("Simulate PingPongProtocol", "[sim]") { - auto m = std::make_unique(1); - const auto result = sim::Simulate(std::move(m)); - - const auto last_event_p0 = result[0].Trace(0).back(); - VerifyType(last_event_p0, sim::Event::Type::KILLED); - REQUIRE(last_event_p0->Timestamp() >= 10000ms); - - const auto last_event_p1 = result[1].Trace(0).back(); - VerifyType(last_event_p1, sim::Event::Type::KILLED); - REQUIRE(last_event_p1->Timestamp() >= 10000ms); -} - -TEST_CASE("Simulate PingPongProtocol trace", "[sim]") { - auto m = std::make_unique(1); - const auto result = sim::Simulate(std::move(m)); - - std::stringstream ss; - std::string line; - result[0].WriteTrace(ss, 0, "Ping"); - - // ping/pong runs for 10 iterations - for (std::size_t i = 0; i < 10; ++i) { - VerifyTypeString(ss, "SEGMENT_BEGIN"); - VerifyTypeString(ss, "SEND"); - VerifyTypeString(ss, "SLEEP"); - VerifyTypeString(ss, "SEGMENT_END"); - } + sim::simulate(std::move(man)); } diff --git a/test/scl/ss/test_additive.cc b/test/scl/ss/test_additive.cc index 12e4ce2..0a5eb88 100644 --- a/test/scl/ss/test_additive.cc +++ b/test/scl/ss/test_additive.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,7 +15,7 @@ * along with this program. If not, see . */ -#include +#include #include "scl/math/fp.h" #include "scl/ss/additive.h" @@ -25,17 +25,17 @@ using namespace scl; TEST_CASE("AdditiveSS", "[ss]") { using FF = math::Fp<61>; - auto prg = util::PRG::Create(); + auto prg = util::PRG::create(); auto secret = FF(12345); - auto shares = ss::AdditiveShare(secret, 10, prg); - REQUIRE(shares.Size() == 10); - REQUIRE(shares.Sum() == secret); + auto shares = ss::additiveShare(secret, 10, prg); + REQUIRE(shares.size() == 10); + REQUIRE(shares.sum() == secret); auto x = FF(55555); - auto shr_x = ss::AdditiveShare(x, 10, prg); - auto sum = shares.Add(shr_x); + auto shr_x = ss::additiveShare(x, 10, prg); + auto share_sum = shares.add(shr_x); - REQUIRE(sum.Sum() == secret + x); + REQUIRE(share_sum.sum() == secret + x); } diff --git a/test/scl/ss/test_feldman.cc b/test/scl/ss/test_feldman.cc index d8cab26..4c1aeb8 100644 --- a/test/scl/ss/test_feldman.cc +++ b/test/scl/ss/test_feldman.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,7 +15,7 @@ * along with this program. If not, see . */ -#include +#include #include #include "scl/math/curves/secp256k1.h" @@ -26,18 +26,39 @@ using namespace scl; -TEST_CASE("Feldman", "[ss]") { - using EC = math::EC; - using FF = EC::ScalarField; +using EC = math::EC; +using FF = EC::ScalarField; - auto prg = util::PRG::Create(); +TEST_CASE("Feldman", "[ss]") { + auto prg = util::PRG::create("feldman"); std::size_t t = 4; auto secret = FF(123); - auto sb = ss::FeldmanShare(secret, 4, 24, prg); - REQUIRE(sb.shares.Size() == 24); - REQUIRE(sb.commitments.Size() == t + 1); - REQUIRE(ss::FeldmanVerify({0, secret}, sb.commitments)); - REQUIRE(ss::FeldmanVerify({23, sb.shares[22]}, sb.commitments)); - REQUIRE(ss::ShamirRecoverP(sb.shares.SubVector(5)) == secret); + auto sb = ss::feldmanSecretShare(secret, 4, 24, prg); + REQUIRE(sb.commitments[0] == secret * EC::generator()); + REQUIRE(sb.shares.size() == 24); + REQUIRE(sb.commitments.size() == t + 1); + REQUIRE(ss::feldmanVerify({secret, sb.commitments}, 0)); + REQUIRE(ss::feldmanVerify(secret, sb.commitments, 0)); + REQUIRE(ss::feldmanVerify(sb.getShare(22), 23)); + REQUIRE(ss::shamirRecoverP(sb.shares.subVector(5)) == secret); +} + +TEST_CASE("Feldman hom", "[ss]") { + auto prg = util::PRG::create("feldman hom"); + std::size_t t = 4; + + auto s0 = FF(123); + auto s1 = FF(44); + + auto ss0 = ss::feldmanSecretShare(s0, t, 10, prg); + auto ss1 = ss::feldmanSecretShare(s1, t, 10, prg); + + auto ss2 = ss0.shares.add(ss1.shares); + auto com2 = ss0.commitments.add(ss1.commitments); + + // Check that new commitment works for the sum of the secrets. + REQUIRE(ss::feldmanVerify({s0 + s1, com2}, 0)); + // Check that new commitment works for an individual share. + REQUIRE(ss::feldmanVerify({ss2[5], com2}, 6)); } diff --git a/test/scl/ss/test_pedersen.cc b/test/scl/ss/test_pedersen.cc new file mode 100644 index 0000000..430f787 --- /dev/null +++ b/test/scl/ss/test_pedersen.cc @@ -0,0 +1,136 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#include + +#include "scl/math/curves/secp256k1.h" +#include "scl/math/ec.h" +#include "scl/math/vector.h" +#include "scl/ss/pedersen.h" +#include "scl/ss/shamir.h" +#include "scl/util/prg.h" + +using namespace scl; + +using EC = math::EC; +using FF = EC::ScalarField; + +const EC h = EC::generator() * FF(42); + +TEST_CASE("Pedersen", "[ss]") { + auto prg = util::PRG::create("Pedersen"); + std::size_t t = 4; + + auto rand = FF(42); + auto secret = FF(123); + auto sb = ss::pedersenSecretShare(secret, t, 24, prg, h, rand); + + REQUIRE(sb.shares.size() == 24); + REQUIRE(sb.commitments.size() == t + 1); + REQUIRE(sb.commitments[0] == secret * EC::generator() + rand * h); + + auto sh = ss::shamirRecoverP(sb.shares.subVector(t + 1)); + REQUIRE(sh[0] == secret); + REQUIRE(sh[1] == rand); + REQUIRE(ss::pedersenVerify({sh, sb.commitments}, 0, h)); + // test overload + REQUIRE(ss::pedersenVerify(sh, sb.commitments, 0, h)); +} + +TEST_CASE("Pedersen hom", "[ss]") { + auto prg = util::PRG::create("Pedersen hom"); + std::size_t t = 4; + + auto s0 = FF(123); + auto s1 = FF(44); + + auto ss0 = ss::pedersenSecretShare(s0, t, 10, prg, h); + auto ss1 = ss::pedersenSecretShare(s1, t, 10, prg, h); + + auto ss2 = ss0.shares.add(ss1.shares); + auto com2 = ss0.commitments.add(ss1.commitments); + + REQUIRE(ss::pedersenVerify({ss2[4], com2}, 5, h)); + + auto secret = ss::shamirRecoverP(ss2.subVector(t + 1)); + // the recovered value is a pair {secret, randomness}. + REQUIRE(secret[0] == s0 + s1); + REQUIRE(ss::pedersenVerify({secret, com2}, 0, h)); +} + +namespace { + +std::vector>> getShares(std::size_t n, + std::size_t t) { + auto prg = util::PRG::create("Pedersen apply"); + + std::vector>> shares(n); + + for (std::size_t i = 0; i < n; i++) { + auto secret = FF::random(prg); + auto shrs = ss::pedersenSecretShare(secret, t, n, prg, h); + for (std::size_t j = 0; j < n; j++) { + shares[j].emplace_back(shrs.getShare(j)); + } + } + + return shares; +} + +} // namespace + +TEST_CASE("Pedersen apply id", "[ss]") { + const std::size_t t = 2; + const std::size_t n = 5; + + auto shares_in = getShares(n, t); + std::vector>> shares_out; + const auto id = math::Matrix::identity(n); + + for (std::size_t i = 0; i < n; i++) { + const auto sin = shares_in[i]; + const auto sout = ss::apply(sin.begin(), sin.end(), id); + for (std::size_t j = 0; j < n; j++) { + REQUIRE(shares_in[i][j].share == sout[j].share); + REQUIRE(shares_in[i][j].commitments == sout[j].commitments); + } + } +} + +TEST_CASE("Pedersen apply", "[ss]") { + const std::size_t t = 2; + const std::size_t n = 5; + + auto shares_in = getShares(n, t); + + std::vector>> shares_out; + const auto van = math::Matrix::vandermonde(n - t, n); + + for (std::size_t i = 0; i < n; i++) { + const auto sin = shares_in[i]; + shares_out.emplace_back(ss::apply(sin, van)); + REQUIRE(shares_out[i].size() == n - t); + } + + // verify result + for (std::size_t i = 0; i < n - t; i++) { + for (std::size_t j = 0; j < n; j++) { + const auto sij = shares_out[j][i]; + REQUIRE(ss::pedersenVerify(sij, j + 1, h)); + } + } +} diff --git a/test/scl/ss/test_shamir.cc b/test/scl/ss/test_shamir.cc index 11b09de..413a9ce 100644 --- a/test/scl/ss/test_shamir.cc +++ b/test/scl/ss/test_shamir.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,14 +15,15 @@ * along with this program. If not, see . */ -#include +#include +#include #include #include "../gf7.h" #include "scl/math/fp.h" #include "scl/math/lagrange.h" #include "scl/math/poly.h" -#include "scl/math/vec.h" +#include "scl/math/vector.h" #include "scl/ss/shamir.h" #include "scl/util/prg.h" @@ -31,130 +32,129 @@ using namespace scl; using FF = math::Fp<61>; TEST_CASE("Shamir share passive", "[ss]") { - auto prg = util::PRG::Create("shamir passive"); - const auto shares = ss::ShamirShare(FF(123), 3, 4, prg); + auto prg = util::PRG::create("shamir passive"); + const auto shares = ss::shamirSecretShare(FF(123), 3, 4, prg); - REQUIRE(shares.Size() == 4); - REQUIRE(ss::ShamirRecoverP(shares) == FF(123)); + REQUIRE(shares.size() == 4); + REQUIRE(ss::shamirRecoverP(shares) == FF(123)); } TEST_CASE("Shamir reconstruct", "[ss]") { - auto prg = util::PRG::Create("shamir recons"); - const auto shares = ss::ShamirShare(FF(123), 5, 100, prg); + auto prg = util::PRG::create("shamir recons"); + const auto shares = ss::shamirSecretShare(FF(123), 5, 100, prg); - REQUIRE(shares.Size() == 100); + REQUIRE(shares.size() == 100); const auto lb_0 = - math::ComputeLagrangeBasis({FF(4), FF(5), FF(6), FF(7), FF(8), FF(9)}, + math::computeLagrangeBasis({FF(4), FF(5), FF(6), FF(7), FF(8), FF(9)}, 0); - const auto r_0 = math::UncheckedInnerProd(shares.begin() + 3, - shares.begin() + 9, - lb_0.begin()); - const auto r_0_alt = shares.SubVector(3, 9).Dot(lb_0); + const auto r_0 = + math::innerProd(shares.begin() + 3, shares.begin() + 9, lb_0.begin()); + const auto r_0_alt = shares.subVector(3, 9).dot(lb_0); REQUIRE(r_0 == FF(123)); REQUIRE(r_0_alt == r_0); const auto lb_27 = - math::ComputeLagrangeBasis({FF(4), FF(5), FF(6), FF(7), FF(8), FF(9)}, + math::computeLagrangeBasis({FF(4), FF(5), FF(6), FF(7), FF(8), FF(9)}, 27); - const auto r_27 = math::UncheckedInnerProd(shares.begin() + 3, - shares.begin() + 9, - lb_27.begin()); + const auto r_27 = math::innerProd(shares.begin() + 3, + shares.begin() + 9, + lb_27.begin()); REQUIRE(r_27 == shares[26]); } TEST_CASE("Shamir reconstruct detect", "[ss]") { - auto prg = util::PRG::Create("shamir detect"); - auto shares = ss::ShamirShare(FF(123), 4, 9, prg); + auto prg = util::PRG::create("shamir detect"); + auto shares = ss::shamirSecretShare(FF(123), 4, 9, prg); - REQUIRE(ss::ShamirRecoverD(shares) == FF(123)); + REQUIRE(ss::shamirRecoverD(shares, 4) == FF(123)); shares[2] = FF(4); REQUIRE_THROWS_MATCHES( - ss::ShamirRecoverD(shares), + ss::shamirRecoverD(shares, 4), std::logic_error, Catch::Matchers::Message("error detected during recovery")); } namespace { -math::Vec ShareWithDifferentAlphas(util::PRG& prg, - std::size_t t, - std::size_t n) { - auto c = math::Vec::Random(t + 1, prg); +math::Vector shareWithDifferentAlphas(util::PRG& prg, + std::size_t t, + std::size_t n) { + auto c = math::Vector::random(t + 1, prg); c[0] = FF(123); - const auto p = math::Polynomial::Create(c); + const auto p = math::Polynomial::create(c); std::vector shares; shares.reserve(n); for (std::size_t i = 0; i < n; ++i) { - shares.emplace_back(p.Evaluate(FF{(int)i + 42})); + shares.emplace_back(p.evaluate(FF{(int)i + 42})); } - return math::Vec(shares); + return math::Vector(shares); } } // namespace TEST_CASE("Shamir reconstruct different x and alphas", "[ss]") { - auto prg = util::PRG::Create("shamir detect2"); + auto prg = util::PRG::create("shamir detect2"); - const auto shares = ShareWithDifferentAlphas(prg, 3, 7); - const auto alphas = math::Vec::Range(42, 50); + const auto shares = shareWithDifferentAlphas(prg, 3, 7); + const auto alphas = math::Vector::range(42, 50); - REQUIRE(ss::ShamirRecoverD(shares, alphas, FF(0)) == FF(123)); + REQUIRE(ss::shamirRecoverD(shares, alphas, 3, 3, FF(0)) == FF(123)); - REQUIRE(ss::ShamirRecoverD(shares, alphas, alphas[0]) == shares[0]); + REQUIRE(ss::shamirRecoverD(shares, alphas, 3, 3, alphas[0]) == shares[0]); } TEST_CASE("Shamir reconstruct correct", "[sim]") { - auto prg = util::PRG::Create("shamir correct"); - auto shares = ss::ShamirShare(FF(123), 2, 7, prg); + auto prg = util::PRG::create("shamir correct"); + auto shares = ss::shamirSecretShare(FF(123), 2, 7, prg); - REQUIRE(ss::ShamirRecoverC(shares).f.Evaluate(FF{0}) == FF(123)); + REQUIRE(ss::shamirRecoverC(shares).f.evaluate(FF{0}) == FF(123)); shares[0] = FF(22); shares[1] = FF(23); - REQUIRE(ss::ShamirRecoverC(shares).f.Evaluate(FF{0}) == FF(123)); + REQUIRE(ss::shamirRecoverC(shares).f.evaluate(FF{0}) == FF(123)); shares[2] = FF(24); - REQUIRE_THROWS_MATCHES(ss::ShamirRecoverC(shares), + REQUIRE_THROWS_MATCHES(ss::shamirRecoverC(shares), std::logic_error, Catch::Matchers::Message("could not correct shares")); } TEST_CASE("Shamir reconstruct correct different alphas", "[ss]") { - auto prg = util::PRG::Create("shamir correct2"); + auto prg = util::PRG::create("shamir correct2"); - auto shares = ShareWithDifferentAlphas(prg, 2, 7); - const auto alphas = math::Vec::Range(42, 50); + auto shares = shareWithDifferentAlphas(prg, 2, 7); + const auto alphas = math::Vector::range(42, 50); - REQUIRE(ss::ShamirRecoverC(shares, alphas).f.ConstantTerm() == FF(123)); + REQUIRE(ss::shamirRecoverC(shares, alphas).f.constantTerm() == FF(123)); shares[4] = FF(5555); - const auto r = ss::ShamirRecoverC(shares, alphas); - REQUIRE(r.f.ConstantTerm() == FF(123)); - REQUIRE(r.err.Evaluate(alphas[4]) == FF(0)); + const auto r = ss::shamirRecoverC(shares, alphas); + REQUIRE(r.f.constantTerm() == FF(123)); + REQUIRE(r.err.evaluate(alphas[4]) == FF(0)); } -TEST_CASE("BerlekampWelch", "[ss][math]") { +TEST_CASE("BerlekampWelch wiki reference test", "[ss][math]") { // https://en.wikipedia.org/wiki/Berlekamp%E2%80%93Welch_algorithm#Example using FF = math::FF; - math::Vec bs = {FF(1), FF(5), FF(3), FF(6), FF(3), FF(2), FF(2)}; - math::Vec corrected = {FF(1), FF(6), FF(3), FF(6), FF(1), FF(2), FF(2)}; + math::Vector bs = {FF(1), FF(5), FF(3), FF(6), FF(3), FF(2), FF(2)}; + math::Vector corrected = {FF(1), FF(6), FF(3), FF(6), FF(1), FF(2), FF(2)}; - auto s = ss::ShamirRecoverC(bs); + auto s = ss::shamirRecoverC(bs); // errors - REQUIRE(s.err.Evaluate(FF(2)) == FF::Zero()); - REQUIRE(s.err.Evaluate(FF(5)) == FF::Zero()); + REQUIRE(s.err.evaluate(FF(2)) == FF::zero()); + REQUIRE(s.err.evaluate(FF(5)) == FF::zero()); - for (std::size_t i = 0; i < bs.Size(); ++i) { - REQUIRE(s.f.Evaluate(FF(i + 1)) == corrected[i]); + for (std::size_t i = 0; i < bs.size(); ++i) { + REQUIRE(s.f.evaluate(FF(i + 1)) == corrected[i]); } } diff --git a/test/scl/util/test_bitmap.cc b/test/scl/util/test_bitmap.cc new file mode 100644 index 0000000..b09e03f --- /dev/null +++ b/test/scl/util/test_bitmap.cc @@ -0,0 +1,168 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#include +#include + +#include "scl/serialization/serializer.h" +#include "scl/util/bitmap.h" + +using namespace scl; + +TEST_CASE("Bitmap construct", "[util]") { + util::Bitmap bm(10); + REQUIRE(bm.numberOfBlocks() == 2); + REQUIRE(bm.count() == 0); + + util::Bitmap bm0; + REQUIRE(bm0.numberOfBlocks() == 1); + REQUIRE(bm0.count() == 0); +} + +TEST_CASE("Bitmap get/set", "[util]") { + util::Bitmap bm(10); + + bm.set(0, true); + bm.set(7, true); + bm.set(8, true); + + REQUIRE(bm.at(0)); + REQUIRE(bm.at(7)); + REQUIRE(bm.at(8)); + REQUIRE_FALSE(bm.at(9)); + + REQUIRE(bm.count() == 3); + + bm.set(7, false); + REQUIRE_FALSE(bm.at(7)); + + REQUIRE(bm.count() == 2); +} + +TEST_CASE("Bitmap XOR", "[util]") { + util::Bitmap bm0(10); + util::Bitmap bm1(10); + + bm0.set(0, true); + bm1.set(0, true); + // result will have 0 at position 0 + + bm0.set(4, true); + bm1.set(4, false); + // result will have 1 at position 4 + + auto bm = bm0 ^ bm1; + + REQUIRE(bm.at(0) == false); + REQUIRE(bm.at(4) == true); + REQUIRE(bm.at(5) == false); +} + +TEST_CASE("Bitmap AND", "[util]") { + util::Bitmap bm0(10); + util::Bitmap bm1(10); + + bm0.set(0, true); + bm1.set(0, true); + // result will have 1 at position 0 + + bm0.set(4, true); + bm1.set(4, false); + // result will have 0 at position 4 + + auto bm = bm0 & bm1; + + REQUIRE(bm.at(0) == true); + REQUIRE(bm.at(4) == false); + REQUIRE(bm.at(5) == false); +} + +TEST_CASE("Bitmap OR", "[util]") { + util::Bitmap bm0(10); + util::Bitmap bm1(10); + + bm0.set(0, true); + bm1.set(0, true); + // result will have 1 at position 0 + + bm0.set(4, true); + bm1.set(4, false); + // result will have 1 at position 4 + + auto bm = bm0 | bm1; + + REQUIRE(bm.at(0) == true); + REQUIRE(bm.at(4) == true); + REQUIRE(bm.at(5) == false); +} + +TEST_CASE("Bitmap NEG", "[util]") { + util::Bitmap bm0(10); + + bm0.set(0, true); + bm0.set(4, true); + + auto bm = ~bm0; + + REQUIRE(bm.at(0) == false); + REQUIRE(bm.at(4) == false); + REQUIRE(bm.at(5) == true); +} + +TEST_CASE("Bitmap equal", "[util]") { + util::Bitmap bm0(10); + util::Bitmap bm1(10); + + REQUIRE(bm0 == bm1); + + bm0.set(3, true); + + REQUIRE(bm0 != bm1); +} + +TEST_CASE("Bitmap print", "[util]") { + util::Bitmap bm(10); + + bm.set(2, true); + bm.set(9, true); + + std::stringstream ss; + ss << bm; + REQUIRE(ss.str() == "0000010000000010"); +} + +TEST_CASE("Bitmap serialization", "[util]") { + util::Bitmap bm(10); + + bm.set(3, true); + bm.set(2, true); + bm.set(5, true); + + REQUIRE(bm.numberOfBlocks() == 2); + + constexpr std::size_t overhead = sizeof(seri::StlVecSizeType); + unsigned char buf[2 + overhead]; + + REQUIRE(seri::Serializer::sizeOf(bm) == 2 + overhead); + + REQUIRE(seri::Serializer::write(bm, buf) == 2 + overhead); + + util::Bitmap b; + REQUIRE(seri::Serializer::read(b, buf) == 2 + overhead); + + REQUIRE(b == bm); +} diff --git a/test/scl/util/test_cmdline.cc b/test/scl/util/test_cmdline.cc index f39dc5e..bee8e41 100644 --- a/test/scl/util/test_cmdline.cc +++ b/test/scl/util/test_cmdline.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,7 +15,10 @@ * along with this program. If not, see . */ -#include +#include +#include +#include +#include #include #include "scl/util/cmdline.h" @@ -29,45 +32,72 @@ using namespace scl; std::streambuf* scl_cout = std::cout.rdbuf(scl_cout_buf.rdbuf()) #define CAPTURE_END(output_cout, output_cerr) \ - auto(output_cout) = scl_cout_buf.str(); \ - auto(output_cerr) = scl_cerr_buf.str(); \ + auto output_cout = scl_cout_buf.str(); \ + auto output_cerr = scl_cerr_buf.str(); \ std::cout.rdbuf(scl_cout); \ std::cerr.rdbuf(scl_cerr) #define WITH_EXIT_0(expr) \ REQUIRE_THROWS_MATCHES((expr), \ std::runtime_error, \ - Catch::Matchers::Message("good")) + Catch::Matchers::Message("no error")) #define WITH_EXIT_1(expr) \ REQUIRE_THROWS_MATCHES((expr), \ std::runtime_error, \ - Catch::Matchers::Message("bad")) + Catch::Matchers::Message("error")) TEST_CASE("Cmdline print help", "[util]") { const char* argv[] = {"program", "-help"}; auto p = util::ProgramOptions::Parser("Program description.") - .Add(util::ProgramArg::Optional("x", "y", "default")) - .Add(util::ProgramArg::Required("a", "b", "arg description")) - .Add(util::ProgramFlag("w", "flag description")); + .add(util::ProgramArg::optional("x", "y", "default")) + .add(util::ProgramArg::required("a", "b", "arg description")) + .add(util::ProgramFlag("w", "flag description")); CAPTURE_START; - WITH_EXIT_0(p.Parse(2, (char**)argv)); + WITH_EXIT_0(p.parse(2, (char**)argv, false)); CAPTURE_END(outc, oute); REQUIRE(oute.empty()); - - REQUIRE_THAT(outc, Catch::Matchers::StartsWith("Usage: program")); - REQUIRE_THAT(outc, Catch::Matchers::Contains("Program description.")); - REQUIRE_THAT(outc, Catch::Matchers::Contains("-x 'y'")); - REQUIRE_THAT(outc, Catch::Matchers::Contains("-a 'b'")); - REQUIRE_THAT(outc, Catch::Matchers::Contains("-w")); - REQUIRE_THAT(outc, Catch::Matchers::Contains("arg description.")); - REQUIRE_THAT(outc, Catch::Matchers::Contains("flag description.")); - REQUIRE_THAT(outc, Catch::Matchers::Contains("[default=default]")); + std::string line; + auto outcs = std::istringstream(outc); + + // Usage: program -x y -a b [options ...] + // + // Program description. + // + // Required arguments + // -a 'b' arg description. + // + // Optional Arguments + // -x 'y' [default=default] + // + // Flags + // -w flag description. + + std::getline(outcs, line); + REQUIRE(line == "Usage: program -a b [options ...]"); + std::getline(outcs, line); + std::getline(outcs, line); + REQUIRE(line == "Program description."); + std::getline(outcs, line); + std::getline(outcs, line); + REQUIRE(line == "Required arguments"); + std::getline(outcs, line); + REQUIRE(line == " -a 'b' arg description."); + std::getline(outcs, line); + std::getline(outcs, line); + REQUIRE(line == "Optional arguments"); + std::getline(outcs, line); + REQUIRE(line == " -x 'y' [default=default]"); + std::getline(outcs, line); + std::getline(outcs, line); + REQUIRE(line == "Flags"); + std::getline(outcs, line); + REQUIRE(line == " -w flag description."); } TEST_CASE("Cmdline parse with error", "[util]") { @@ -76,7 +106,7 @@ TEST_CASE("Cmdline parse with error", "[util]") { auto p = util::ProgramOptions::Parser{}; CAPTURE_START; - WITH_EXIT_1(p.Parse(2, (char**)argv)); + WITH_EXIT_1(p.parse(2, (char**)argv, false)); CAPTURE_END(outc, oute); REQUIRE_THAT(outc, Catch::Matchers::StartsWith("Usage: program")); @@ -86,10 +116,10 @@ TEST_CASE("Cmdline parse with error", "[util]") { TEST_CASE("Cmdline parse missing required", "[util]") { const char* argv[] = {"program"}; auto p = - util::ProgramOptions::Parser{}.Add(util::ProgramArg::Required("x", "y")); + util::ProgramOptions::Parser{}.add(util::ProgramArg::required("x", "y")); CAPTURE_START; - WITH_EXIT_1(p.Parse(1, (char**)argv)); + WITH_EXIT_1(p.parse(1, (char**)argv, false)); CAPTURE_END(outc, oute); REQUIRE(oute == "ERROR: missing required argument\n"); @@ -98,10 +128,10 @@ TEST_CASE("Cmdline parse missing required", "[util]") { TEST_CASE("Cmdline parse invalid argument", "[util]") { const char* argv[] = {"program", "-x"}; auto p = - util::ProgramOptions::Parser{}.Add(util::ProgramArg::Required("x", "y")); + util::ProgramOptions::Parser{}.add(util::ProgramArg::required("x", "y")); CAPTURE_START; - WITH_EXIT_1(p.Parse(2, (char**)argv)); + WITH_EXIT_1(p.parse(2, (char**)argv, false)); CAPTURE_END(outc, oute); REQUIRE(oute == "ERROR: invalid argument\n"); @@ -110,78 +140,83 @@ TEST_CASE("Cmdline parse invalid argument", "[util]") { TEST_CASE("Cmdline parse invalid argument name", "[util]") { const char* argv[] = {"program", "x"}; auto p = - util::ProgramOptions::Parser{}.Add(util::ProgramArg::Required("x", "y")); + util::ProgramOptions::Parser{}.add(util::ProgramArg::required("x", "y")); CAPTURE_START; - WITH_EXIT_1(p.Parse(2, (char**)argv)); + WITH_EXIT_1(p.parse(2, (char**)argv, false)); CAPTURE_END(outc, oute); REQUIRE(oute == "ERROR: argument must begin with '-'\n"); } TEST_CASE("Cmdline duplicate arg definition", "[util]") { - auto p = util::ProgramOptions::Parser{}.Add( - util::ProgramArg::Required("x", "int")); + auto p = util::ProgramOptions::Parser{} + .add(util::ProgramArg::required("x", "int")) + .add(util::ProgramArg::required("x", "int")); + const char* argv[] = {"program", "-x", "1 "}; CAPTURE_START; - WITH_EXIT_1(p.Add(util::ProgramArg::Required("x", "int"))); + WITH_EXIT_1(p.parse(3, (char**)argv, false)); CAPTURE_END(outc, oute); REQUIRE(oute == "ERROR: duplicate argument definition\n"); } TEST_CASE("Cmdline duplicate flag definition", "[util]") { - auto p = util::ProgramOptions::Parser{}.Add(util::ProgramFlag("x")); + auto p = util::ProgramOptions::Parser{} + .add(util::ProgramFlag("x")) + .add(util::ProgramFlag("x")); + const char* argv[] = {"program", "-x"}; CAPTURE_START; - WITH_EXIT_1(p.Add(util::ProgramFlag("x"))); + WITH_EXIT_1(p.parse(2, (char**)argv, false)); CAPTURE_END(outc, oute); - REQUIRE(oute == "ERROR: duplicate argument definition\n"); + REQUIRE(oute == "ERROR: duplicate flag definition\n"); } TEST_CASE("Cmdline parse duplicate arg", "[misc]") { const char* argv[] = {"program", "-x", "1", "-x", "2"}; auto p = util::ProgramOptions::Parser{} - .Add(util::ProgramArg::Required("x", "int")) - .Parse(5, (char**)argv); - REQUIRE(p.Get("x") == "2"); + .add(util::ProgramArg::required("x", "int")) + .parse(5, (char**)argv, false); + REQUIRE(p.get("x") == "2"); } TEST_CASE("Cmdline arg", "[util]") { const char* argv[] = {"program", "-x", "100", "-w", "600", "-b", "true"}; auto p = util::ProgramOptions::Parser{} - .Add(util::ProgramArg::Required("x", "int")) - .Add(util::ProgramArg::Required("w", "ulong")) - .Add(util::ProgramArg::Required("b", "bool")) - .Add(util::ProgramArg::Optional("y", "long", "100")) - .Parse(7, (char**)argv); - - REQUIRE(p.Has("x")); - auto v = p.Get("x"); + .add(util::ProgramArg::required("x", "int")) + .add(util::ProgramArg::required("w", "ulong")) + .add(util::ProgramArg::required("b", "bool")) + .add(util::ProgramArg::optional("y", "long", "100")) + .parse(7, (char**)argv, false); + + REQUIRE(p.has("x")); + auto v = p.get("x"); REQUIRE(v == "100"); - auto w = p.Get("x"); + auto w = p.get("x"); REQUIRE(w == 100); - REQUIRE(p.Has("w")); - auto ww = p.Get("w"); + REQUIRE(p.has("w")); + auto ww = p.get("w"); REQUIRE(ww == 600); - REQUIRE(p.Has("b")); - REQUIRE(p.Get("b")); + REQUIRE(p.has("b")); + REQUIRE(p.get("b")); - REQUIRE(p.Has("y")); - REQUIRE(p.Get("y") == 100); + REQUIRE(p.has("y")); + REQUIRE(p.get("y") == 100); } TEST_CASE("Cmdline flag", "[util]") { const char* argv[] = {"program", "-f"}; auto p = util::ProgramOptions::Parser{} - .Add(util::ProgramFlag("f")) - .Add(util::ProgramFlag("h")) - .Parse(2, (char**)argv); + .add(util::ProgramFlag("f")) + .add(util::ProgramFlag("h")) + .parse(2, (char**)argv, false); - REQUIRE(p.FlagSet("f")); - REQUIRE_FALSE(p.FlagSet("h")); - REQUIRE_FALSE(p.FlagSet("g")); + REQUIRE(p.flagSet("f")); + REQUIRE_FALSE(p.flagSet("h")); + REQUIRE_FALSE(p.flagSet("g")); } diff --git a/test/scl/util/test_ecdsa.cc b/test/scl/util/test_ecdsa.cc index 5b5769c..62f9f55 100644 --- a/test/scl/util/test_ecdsa.cc +++ b/test/scl/util/test_ecdsa.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -16,7 +16,7 @@ */ #include -#include +#include #include "scl/math/curves/secp256k1.h" #include "scl/util/hash.h" @@ -25,24 +25,24 @@ using namespace scl; TEST_CASE("ECDSA derive", "[util]") { - auto prg = util::PRG::Create("ecdsa derive"); - const auto sk = util::ECDSA::SecretKey::Random(prg); - const auto pk = util::ECDSA::Derive(sk); - REQUIRE(pk == sk * math::EC::Generator()); + auto prg = util::PRG::create("ecdsa derive"); + const auto sk = util::ECDSA::SecretKey::random(prg); + const auto pk = util::ECDSA::derive(sk); + REQUIRE(pk == sk * math::EC::generator()); } TEST_CASE("ECDSA sign", "[util]") { - auto prg = util::PRG::Create("ecdsa sign"); - const auto m = util::Hash<256>{}.Update("message").Finalize(); - const auto sk = util::ECDSA::SecretKey::Random(prg); + auto prg = util::PRG::create("ecdsa sign"); + const auto m = util::Hash<256>{}.update("message").finalize(); + const auto sk = util::ECDSA::SecretKey::random(prg); const auto sig = util::ECDSA::Sign(sk, m, prg); - const auto pk = util::ECDSA::Derive(sk); - REQUIRE(util::ECDSA::Verify(pk, sig, m)); + const auto pk = util::ECDSA::derive(sk); + REQUIRE(util::ECDSA::verify(pk, sig, m)); const std::array m_small = {1, 2, 3}; const auto sig_small = util::ECDSA::Sign(sk, m_small, prg); - REQUIRE(util::ECDSA::Verify(pk, sig_small, m_small)); + REQUIRE(util::ECDSA::verify(pk, sig_small, m_small)); - REQUIRE_FALSE(util::ECDSA::Verify(pk, sig_small, m)); + REQUIRE_FALSE(util::ECDSA::verify(pk, sig_small, m)); } diff --git a/test/scl/util/test_measurement.cc b/test/scl/util/test_measurement.cc new file mode 100644 index 0000000..d31313a --- /dev/null +++ b/test/scl/util/test_measurement.cc @@ -0,0 +1,120 @@ +/* SCL --- Secure Computation Library + * Copyright (C) 2024 Anders Dalskov + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +#include +#include +#include +#include + +#include "scl/util/measurement.h" + +using namespace scl; +using namespace std::chrono_literals; + +TEST_CASE("Measurement to string", "[util]") { + util::DataMeasurement dm; + dm.addSample(123.45); + + std::stringstream ss; + ss << dm; + REQUIRE(ss.str() == "{\"mean\": 123.45, \"unit\": \"B\", \"std_dev\": 0}"); + + util::TimeMeasurement tm; + tm.addSample(util::Time::Duration::zero()); + + ss.str(""); + ss << tm; + REQUIRE(ss.str() == "{\"mean\": 0, \"unit\": \"ms\", \"std_dev\": 0}"); +} + +TEST_CASE("Measurement mean and stddev", "[util]") { + util::DataMeasurement dm; + dm.addSample(123.42); + dm.addSample(555.21); + REQUIRE_THAT(dm.mean(), Catch::Matchers::WithinRel(339.315, 0.001)); + REQUIRE_THAT(dm.stddev(), Catch::Matchers::WithinRel(305.322, 0.001)); + + util::TimeMeasurement tm; + tm.addSample(123ms); + tm.addSample(444ms); + REQUIRE(tm.mean() == 283.5ms); + REQUIRE(tm.stddev() == 226.981276ms); +} + +TEST_CASE("Measurement data", "[util]") { + util::DataMeasurement dm; + dm.addSample(2); + dm.addSample(4); + dm.addSample(4); + dm.addSample(4); + dm.addSample(5); + dm.addSample(5); + dm.addSample(7); + dm.addSample(9); + + REQUIRE(dm.size() == 8); + REQUIRE(dm.samples() == std::vector({2, 4, 4, 4, 5, 5, 7, 9})); +} + +TEST_CASE("Measurement time", "[util]") { + util::TimeMeasurement tm; + tm.addSample(2ms); + tm.addSample(4ms); + tm.addSample(4ms); + tm.addSample(4ms); + tm.addSample(5ms); + tm.addSample(5ms); + tm.addSample(7ms); + tm.addSample(9ms); + + REQUIRE(tm.size() == 8); + REQUIRE(tm.samples() == std::vector( + {2ms, 4ms, 4ms, 4ms, 5ms, 5ms, 7ms, 9ms})); +} + +TEST_CASE("Measurement samples", "[util]") { + util::DataMeasurement dm; + REQUIRE(dm.samples().empty()); + + dm.addSample(42); + REQUIRE(dm.size() == 1); + REQUIRE(dm.samples() == std::vector{42}); + + dm.addSample(22); + REQUIRE(dm.size() == 2); + REQUIRE(dm.samples() == std::vector{42, 22}); +} + +TEST_CASE("Measurement median", "[util]") { + util::DataMeasurement dm; + util::TimeMeasurement tm; + + REQUIRE(dm.median() == 0); + REQUIRE(tm.median() == 0s); + + dm.addSample(123); + tm.addSample(123s); + + REQUIRE(dm.median() == 123); + REQUIRE(tm.median() == 123s); + + dm.addSample(442); + tm.addSample(442s); + + REQUIRE(dm.median() == 282.5); + REQUIRE(tm.median() == 282.5s); +} diff --git a/test/scl/util/test_merkle.cc b/test/scl/util/test_merkle.cc index 9fe745a..190e1b4 100644 --- a/test/scl/util/test_merkle.cc +++ b/test/scl/util/test_merkle.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,8 +15,10 @@ * along with this program. If not, see . */ -#include +#include +#include "scl/serialization/serializer.h" +#include "scl/util/bitmap.h" #include "scl/util/hash.h" #include "scl/util/merkle.h" @@ -26,63 +28,96 @@ using Mrkl = util::MerkleTree, std::string_view>; namespace { -util::Hash<256>::DigestType Hash(std::string_view thing) { - return util::Hash<256>{}.Update(thing).Finalize(); +util::Hash<256>::DigestType hash(std::string_view data) { + return util::Hash<256>{}.update(data).finalize(); } -util::Hash<256>::DigestType Hash(util::Hash<256>::DigestType a, - util::Hash<256>::DigestType b) { - return util::Hash<256>{}.Update(a).Update(b).Finalize(); +util::Hash<256>::DigestType hash(util::Hash<256>::DigestType left, + util::Hash<256>::DigestType right) { + return util::Hash<256>{}.update(left).update(right).finalize(); } } // namespace TEST_CASE("Merkle hash", "[misc]") { - auto h_abcd = Hash(Hash(Hash("a"), Hash("b")), Hash(Hash("c"), Hash("d"))); - auto m_abcd = Mrkl::Hash({"a", "b", "c", "d"}); + // clang-format off + auto h_abcd = hash( + hash(hash("a"), hash("b")), + hash(hash("c"), hash("d"))); + // clang-format on + + auto m_abcd = Mrkl::hash({"a", "b", "c", "d"}); REQUIRE(h_abcd == m_abcd); - auto h_xyvu = Hash(Hash(Hash("x"), Hash("y")), Hash(Hash("v"), Hash("u"))); - auto h_abcdxyvu = Hash(h_abcd, h_xyvu); + // clang-format off + auto h_xyvu = hash( + hash(hash("x"), hash("y")), + hash(hash("v"), hash("u"))); + // clang-format on + + auto h_abcdxyvu = hash(h_abcd, h_xyvu); - auto m_abcdxyvu = Mrkl::Hash({"a", "b", "c", "d", "x", "y", "v", "u"}); + auto m_abcdxyvu = Mrkl::hash({"a", "b", "c", "d", "x", "y", "v", "u"}); REQUIRE(h_abcdxyvu == m_abcdxyvu); } TEST_CASE("Merkle hash odd size input", "[misc]") { util::Hash<256>::DigestType z_digest; z_digest.fill(0); - auto h_abc = Hash(Hash(Hash("a"), Hash("b")), Hash(Hash("c"), Hash("c"))); - auto m_abc = Mrkl::Hash({"a", "b", "c"}); + // clang-format off + auto h_abc = hash( + hash(hash("a"), hash("b")), + hash(hash("c"), hash("c"))); + // clang-format on + + auto m_abc = Mrkl::hash({"a", "b", "c"}); REQUIRE(h_abc == m_abc); } TEST_CASE("Merkle prove", "[misc]") { std::vector data = {"a", "b", "c", "d", "e"}; - auto root = Mrkl::Hash(data); + auto root = Mrkl::hash(data); - auto h_ab = Hash(Hash("a"), Hash("b")); - auto h_cd = Hash(Hash("c"), Hash("d")); - auto h_ee = Hash(Hash("e"), Hash("e")); - auto h_abcd = Hash(h_ab, h_cd); - auto h_eeee = Hash(h_ee, h_ee); + auto h_ab = hash(hash("a"), hash("b")); + auto h_cd = hash(hash("c"), hash("d")); + auto h_ee = hash(hash("e"), hash("e")); + auto h_abcd = hash(h_ab, h_cd); + auto h_eeee = hash(h_ee, h_ee); - REQUIRE(root == Hash(h_abcd, h_eeee)); + REQUIRE(root == hash(h_abcd, h_eeee)); - auto proof = Mrkl::Prove(data, 3); + auto proof = Mrkl::prove(data, 3); // path = [H_c, H_ab, H_eeee] // direction = [left, left, right] (true, true, false) - REQUIRE(proof.direction.size() == 3); REQUIRE(proof.path.size() == 3); - REQUIRE(proof.direction == std::vector{true, true, false}); + REQUIRE(proof.direction == + util::Bitmap::fromStdVecBool(std::vector{true, true, false})); - REQUIRE(proof.path[0] == Hash("c")); + REQUIRE(proof.path[0] == hash("c")); REQUIRE(proof.path[1] == h_ab); REQUIRE(proof.path[2] == h_eeee); - REQUIRE(Mrkl::Verify("d", root, proof)); + REQUIRE(Mrkl::verify("d", root, proof)); + + using Sr = seri::Serializer; + + // two vectors. One with three digests, and one with 3 bits that fit into one + // byte. + REQUIRE(Sr::sizeOf(proof) == 2 * sizeof(seri::StlVecSizeType) + 3L * 32 + 1); + + unsigned char buf[2 * sizeof(seri::StlVecSizeType) + 3L * 32 + 1]; + Sr::write(proof, buf); + + Mrkl::Proof p; + Sr::read(p, buf); + + REQUIRE(p.direction == proof.direction); + REQUIRE(p.path.size() == proof.path.size()); + for (std::size_t i = 0; i < proof.path.size(); ++i) { + REQUIRE(p.path[i] == proof.path[i]); + } } diff --git a/test/scl/util/test_prg.cc b/test/scl/util/test_prg.cc index 233dd4d..9d0b923 100644 --- a/test/scl/util/test_prg.cc +++ b/test/scl/util/test_prg.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -16,7 +16,9 @@ */ #include -#include +#include +#include +#include #include #include #include @@ -27,7 +29,7 @@ using namespace scl; namespace { -bool LooksUniform(const unsigned char* p, std::size_t len) { +bool looksUniform(const unsigned char* p, std::size_t len) { std::vector buckets(256, 0); for (std::size_t i = 0; i < len; ++i) { buckets[p[i]]++; @@ -45,42 +47,42 @@ bool LooksUniform(const unsigned char* p, std::size_t len) { } // namespace TEST_CASE("PRG construction", "[misc]") { - REQUIRE(util::PRG::SeedSize() == 16); + REQUIRE(util::PRG::seedSize() == 16); - auto zprg = util::PRG::Create(); + auto zprg = util::PRG::create(); // Number arrived at somewhat by trial-and-error const auto n = 300000; unsigned char buffer[n] = {0}; - REQUIRE_FALSE(LooksUniform(buffer, n)); + REQUIRE_FALSE(looksUniform(buffer, n)); - zprg.Next(buffer, n); + zprg.next(buffer, n); - REQUIRE(LooksUniform(buffer, n)); + REQUIRE(looksUniform(buffer, n)); } TEST_CASE("PRG predictable", "[misc]") { unsigned char seed[] = "1234567890abcde"; - auto prg0 = util::PRG::Create(seed, 15); - auto prg1 = util::PRG::Create(seed, 15); + auto prg0 = util::PRG::create(seed, 15); + auto prg1 = util::PRG::create(seed, 15); REQUIRE(prg0.Seed() == prg1.Seed()); - auto bytes0 = prg0.Next(100); - auto bytes1 = prg1.Next(100); + auto bytes0 = prg0.next(100); + auto bytes1 = prg1.next(100); REQUIRE(bytes0 == bytes1); - prg0.Reset(); - auto bytes00 = prg0.Next(100); + prg0.reset(); + auto bytes00 = prg0.next(100); REQUIRE(bytes00 == bytes0); } TEST_CASE("PRG generate random bytes", "[misc]") { - auto prg = util::PRG::Create(); + auto prg = util::PRG::create(); std::vector buffer(100, 0); - prg.Next(buffer, 50); + prg.next(buffer, 50); bool zero = true; for (std::size_t i = 50; i < 100; ++i) { @@ -97,15 +99,15 @@ TEST_CASE("PRG generate random bytes", "[misc]") { std::vector buf = {'c', 'a', 't'}; - prg.Next(buf, 0); + prg.next(buf, 0); REQUIRE(buf == std::vector{'c', 'a', 't'}); } TEST_CASE("PRG invalid calls", "[misc]") { - auto prg = util::PRG::Create(); + auto prg = util::PRG::create(); std::vector buf(10); - REQUIRE_THROWS_MATCHES(prg.Next(buf, 11), + REQUIRE_THROWS_MATCHES(prg.next(buf, 11), std::invalid_argument, Catch::Matchers::Message("n exceeds buffer.size()")); } @@ -113,11 +115,11 @@ TEST_CASE("PRG invalid calls", "[misc]") { TEST_CASE("PRG truncate seed on create", "[misc]") { // Seeds are truncated if they exceed PRG::SeedSize() length. - auto prg0 = util::PRG::Create("0123456789abcdef_bar"); - auto prg1 = util::PRG::Create("0123456789abcdef_foo"); + auto prg0 = util::PRG::create("0123456789abcdef_bar"); + auto prg1 = util::PRG::create("0123456789abcdef_foo"); - auto bytes0 = prg0.Next(100); - auto bytes1 = prg1.Next(100); + auto bytes0 = prg0.next(100); + auto bytes1 = prg1.next(100); REQUIRE(bytes0 == bytes1); } diff --git a/test/scl/util/test_sha256.cc b/test/scl/util/test_sha256.cc index 63d2081..9991e25 100644 --- a/test/scl/util/test_sha256.cc +++ b/test/scl/util/test_sha256.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,7 +15,7 @@ * along with this program. If not, see . */ -#include +#include #include #include "scl/math/curves/secp256k1.h" @@ -32,7 +32,7 @@ TEST_CASE("Sha256 empty hash", "[misc]") { 0x93, 0x4c, 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55}; util::Sha256 hash; - auto digest = hash.Finalize(); + auto digest = hash.finalize(); REQUIRE(digest.size() == 32); REQUIRE(digest == SHA256_empty); } @@ -44,15 +44,15 @@ TEST_CASE("Sha256 abc hash", "[misc]") { 0x7a, 0x9c, 0xb4, 0x10, 0xff, 0x61, 0xf2, 0x00, 0x15, 0xad}; util::Sha256 hash; - hash.Update({'a', 'b', 'c'}); - auto digest = hash.Finalize(); + hash.update({'a', 'b', 'c'}); + auto digest = hash.finalize(); REQUIRE(digest.size() == 32); REQUIRE(digest == SHA256_abc); util::Sha256 hash_; - hash_.Update({'a', 'b'}); - hash_.Update({'c'}); - auto digest_ = hash_.Finalize(); + hash_.update({'a', 'b'}); + hash_.update({'c'}); + auto digest_ = hash_.finalize(); REQUIRE(digest_.size() == 32); REQUIRE(digest_ == SHA256_abc); } @@ -65,27 +65,25 @@ TEST_CASE("Sha256 hash almost complete chunk", "[misc]") { const unsigned char data[57] = {0}; util::Sha256 hash; - hash.Update(data, 57); - REQUIRE(hash.Finalize() == digest); + hash.update(data, 57); + REQUIRE(hash.finalize() == digest); } -#ifdef SCL_ENABLE_EC_TESTS - TEST_CASE("Sha256 bouncycastle reference", "[misc]") { // Reference test showing that serialization + hashing is the same as // bouncycastle in Java. - using Curve = math::EC; - auto pk = Curve::Generator() * math::Number::FromString("a"); + using Curve = math::EC; + auto pk = Curve::generator() * math::Number::fromString("a"); - const auto n = Curve::ByteSize(false); + const auto n = Curve::byteSize(false); unsigned char buf[n] = {0}; - pk.Write(buf, false); + pk.write(buf, false); util::Sha256 hash; - hash.Update(buf, n); + hash.update(buf, n); - auto d = hash.Finalize(); + auto d = hash.finalize(); std::array target = { 0xde, 0xc1, 0x6a, 0xc2, 0x78, 0x99, 0xeb, 0xdf, 0x76, 0x0e, 0xaf, @@ -94,5 +92,3 @@ TEST_CASE("Sha256 bouncycastle reference", "[misc]") { REQUIRE(d == target); } - -#endif diff --git a/test/scl/util/test_sha3.cc b/test/scl/util/test_sha3.cc index 81858ca..92262b2 100644 --- a/test/scl/util/test_sha3.cc +++ b/test/scl/util/test_sha3.cc @@ -1,5 +1,5 @@ /* SCL --- Secure Computation Library - * Copyright (C) 2023 Anders Dalskov + * Copyright (C) 2024 Anders Dalskov * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -15,7 +15,7 @@ * along with this program. If not, see . */ -#include +#include #include "scl/math/fp.h" #include "scl/util/digest.h" @@ -30,7 +30,7 @@ TEST_CASE("Sha3 empty hash", "[misc]") { 0x49, 0xfa, 0x82, 0xd8, 0x0a, 0x4b, 0x80, 0xf8, 0x43, 0x4a}; util::Hash<256> hash; - auto digest = hash.Finalize(); + auto digest = hash.finalize(); REQUIRE(digest == SHA3_256_empty); } @@ -42,7 +42,7 @@ TEST_CASE("Sha3 abc hash", "[misc]") { util::Hash<256> hash; unsigned char abc[] = "abc"; - auto digest = hash.Update(abc, 3).Finalize(); + auto digest = hash.update(abc, 3).finalize(); REQUIRE(digest == SHA3_256_abc); } @@ -59,14 +59,14 @@ TEST_CASE("Sha3-256 reference", "[misc]") { } util::Hash<256> hash0; - auto digest = hash0.Update(buf, 200).Finalize(); + auto digest = hash0.update(buf, 200).finalize(); REQUIRE(digest == SHA3_256_0xa3_200_times); util::Hash<256> hash1; for (std::size_t i = 0; i < 200; ++i) { - hash1.Update(&byte, 1); + hash1.update(&byte, 1); } - REQUIRE(hash1.Finalize() == SHA3_256_0xa3_200_times); + REQUIRE(hash1.finalize() == SHA3_256_0xa3_200_times); } TEST_CASE("Sha3-384 reference", "[misc]") { @@ -83,15 +83,15 @@ TEST_CASE("Sha3-384 reference", "[misc]") { } util::Hash<384> hash0; - auto digest = hash0.Update(buf, 200).Finalize(); + auto digest = hash0.update(buf, 200).finalize(); REQUIRE(digest.size() == 48); REQUIRE(digest == SHA3_384_0xa3_200_times); util::Hash<384> hash1; for (std::size_t i = 0; i < 200; ++i) { - hash1.Update(&byte, 1); + hash1.update(&byte, 1); } - REQUIRE(hash1.Finalize() == SHA3_384_0xa3_200_times); + REQUIRE(hash1.finalize() == SHA3_384_0xa3_200_times); } TEST_CASE("Sha3-512 reference", "[misc]") { @@ -110,34 +110,34 @@ TEST_CASE("Sha3-512 reference", "[misc]") { } util::Hash<512> hash0; - auto digest = hash0.Update(buf, 200).Finalize(); + auto digest = hash0.update(buf, 200).finalize(); REQUIRE(digest.size() == 64); REQUIRE(digest == SHA3_512_0xa3_200_times); util::Hash<512> hash1; for (std::size_t i = 0; i < 200; ++i) { - hash1.Update(&byte, 1); + hash1.update(&byte, 1); } - REQUIRE(hash1.Finalize() == SHA3_512_0xa3_200_times); + REQUIRE(hash1.finalize() == SHA3_512_0xa3_200_times); } TEST_CASE("Sha3 hash vector", "[misc]") { unsigned char ref_buf[] = "hello, world"; util::Hash<256> hash_ref; - auto ref = hash_ref.Update(ref_buf, 12).Finalize(); + auto ref = hash_ref.update(ref_buf, 12).finalize(); util::Hash<256> hash1; std::vector v = {'h', 'e', 'l', 'l', 'o', ',', ' ', 'w', 'o', 'r', 'l', 'd'}; - auto from_vec = hash1.Update(v).Finalize(); + auto from_vec = hash1.update(v).finalize(); REQUIRE(ref == from_vec); } TEST_CASE("Sha3 hash array", "[misc]") { unsigned char abc[] = "abc"; std::array abc_arr = {'a', 'b', 'c'}; - auto ref = util::Hash<256>{}.Update(abc, 3).Finalize(); - auto act = util::Hash<256>{}.Update(abc_arr).Finalize(); + auto ref = util::Hash<256>{}.update(abc, 3).finalize(); + auto act = util::Hash<256>{}.update(abc_arr).finalize(); REQUIRE(ref == act); } @@ -145,7 +145,7 @@ TEST_CASE("Sha3 field elements", "[misc]") { math::Fp<61> x(123); math::Fp<61> y(555); - auto hx = util::Hash<256>{}.Update(x).Finalize(); - auto hy = util::Hash<256>{}.Update(y).Finalize(); + auto hx = util::Hash<256>{}.update(x).finalize(); + auto hy = util::Hash<256>{}.update(y).finalize(); REQUIRE(hx != hy); }