|
| 1 | +// SPDX-License-Identifier: MIT |
| 2 | +pragma solidity ^0.8.20; |
| 3 | + |
| 4 | +import {Math} from "../math/Math.sol"; |
| 5 | +import {Errors} from "../Errors.sol"; |
| 6 | + |
| 7 | +/** |
| 8 | + * @dev Implementation of secp256r1 verification and recovery functions. |
| 9 | + * |
| 10 | + * The secp256r1 curve (also known as P256) is a NIST standard curve with wide support in modern devices |
| 11 | + * and cryptographic standards. Some notable examples include Apple's Secure Enclave and Android's Keystore |
| 12 | + * as well as authentication protocols like FIDO2. |
| 13 | + * |
| 14 | + * Based on the original https://github.com/itsobvioustech/aa-passkeys-wallet/blob/main/src/Secp256r1.sol[implementation of itsobvioustech]. |
| 15 | + * Heavily inspired in https://github.com/maxrobot/elliptic-solidity/blob/master/contracts/Secp256r1.sol[maxrobot] and |
| 16 | + * https://github.com/tdrerup/elliptic-curve-solidity/blob/master/contracts/curves/EllipticCurve.sol[tdrerup] implementations. |
| 17 | + */ |
| 18 | +library P256 { |
| 19 | + struct JPoint { |
| 20 | + uint256 x; |
| 21 | + uint256 y; |
| 22 | + uint256 z; |
| 23 | + } |
| 24 | + |
| 25 | + /// @dev Generator (x component) |
| 26 | + uint256 internal constant GX = 0x6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296; |
| 27 | + /// @dev Generator (y component) |
| 28 | + uint256 internal constant GY = 0x4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5; |
| 29 | + /// @dev P (size of the field) |
| 30 | + uint256 internal constant P = 0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF; |
| 31 | + /// @dev N (order of G) |
| 32 | + uint256 internal constant N = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551; |
| 33 | + /// @dev A parameter of the weierstrass equation |
| 34 | + uint256 internal constant A = 0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFC; |
| 35 | + /// @dev B parameter of the weierstrass equation |
| 36 | + uint256 internal constant B = 0x5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B; |
| 37 | + |
| 38 | + /// @dev (P + 1) / 4. Useful to compute sqrt |
| 39 | + uint256 private constant P1DIV4 = 0x3fffffffc0000000400000000000000000000000400000000000000000000000; |
| 40 | + |
| 41 | + /// @dev N/2 for excluding higher order `s` values |
| 42 | + uint256 private constant HALF_N = 0x7fffffff800000007fffffffffffffffde737d56d38bcf4279dce5617e3192a8; |
| 43 | + |
| 44 | + /** |
| 45 | + * @dev Verifies a secp256r1 signature using the RIP-7212 precompile and falls back to the Solidity implementation |
| 46 | + * if the precompile is not available. This version should work on all chains, but requires the deployment of more |
| 47 | + * bytecode. |
| 48 | + * |
| 49 | + * @param h - hashed message |
| 50 | + * @param r - signature half R |
| 51 | + * @param s - signature half S |
| 52 | + * @param qx - public key coordinate X |
| 53 | + * @param qy - public key coordinate Y |
| 54 | + * |
| 55 | + * IMPORTANT: This function disallows signatures where the `s` value is above `N/2` to prevent malleability. |
| 56 | + * To flip the `s` value, compute `s = N - s`. |
| 57 | + */ |
| 58 | + function verify(bytes32 h, bytes32 r, bytes32 s, bytes32 qx, bytes32 qy) internal view returns (bool) { |
| 59 | + (bool valid, bool supported) = _tryVerifyNative(h, r, s, qx, qy); |
| 60 | + return supported ? valid : verifySolidity(h, r, s, qx, qy); |
| 61 | + } |
| 62 | + |
| 63 | + /** |
| 64 | + * @dev Same as {verify}, but it will revert if the required precompile is not available. |
| 65 | + * |
| 66 | + * Make sure any logic (code or precompile) deployed at that address is the expected one, |
| 67 | + * otherwise the returned value may be misinterpreted as a positive boolean. |
| 68 | + */ |
| 69 | + function verifyNative(bytes32 h, bytes32 r, bytes32 s, bytes32 qx, bytes32 qy) internal view returns (bool) { |
| 70 | + (bool valid, bool supported) = _tryVerifyNative(h, r, s, qx, qy); |
| 71 | + if (supported) { |
| 72 | + return valid; |
| 73 | + } else { |
| 74 | + revert Errors.MissingPrecompile(address(0x100)); |
| 75 | + } |
| 76 | + } |
| 77 | + |
| 78 | + /** |
| 79 | + * @dev Same as {verify}, but it will return false if the required precompile is not available. |
| 80 | + */ |
| 81 | + function _tryVerifyNative( |
| 82 | + bytes32 h, |
| 83 | + bytes32 r, |
| 84 | + bytes32 s, |
| 85 | + bytes32 qx, |
| 86 | + bytes32 qy |
| 87 | + ) private view returns (bool valid, bool supported) { |
| 88 | + if (!_isProperSignature(r, s) || !isValidPublicKey(qx, qy)) { |
| 89 | + return (false, true); // signature is invalid, and its not because the precompile is missing |
| 90 | + } |
| 91 | + |
| 92 | + (bool success, bytes memory returndata) = address(0x100).staticcall(abi.encode(h, r, s, qx, qy)); |
| 93 | + return (success && returndata.length == 0x20) ? (abi.decode(returndata, (bool)), true) : (false, false); |
| 94 | + } |
| 95 | + |
| 96 | + /** |
| 97 | + * @dev Same as {verify}, but only the Solidity implementation is used. |
| 98 | + */ |
| 99 | + function verifySolidity(bytes32 h, bytes32 r, bytes32 s, bytes32 qx, bytes32 qy) internal view returns (bool) { |
| 100 | + if (!_isProperSignature(r, s) || !isValidPublicKey(qx, qy)) { |
| 101 | + return false; |
| 102 | + } |
| 103 | + |
| 104 | + JPoint[16] memory points = _preComputeJacobianPoints(uint256(qx), uint256(qy)); |
| 105 | + uint256 w = Math.invModPrime(uint256(s), N); |
| 106 | + uint256 u1 = mulmod(uint256(h), w, N); |
| 107 | + uint256 u2 = mulmod(uint256(r), w, N); |
| 108 | + (uint256 x, ) = _jMultShamir(points, u1, u2); |
| 109 | + return ((x % N) == uint256(r)); |
| 110 | + } |
| 111 | + |
| 112 | + /** |
| 113 | + * @dev Public key recovery |
| 114 | + * |
| 115 | + * @param h - hashed message |
| 116 | + * @param v - signature recovery param |
| 117 | + * @param r - signature half R |
| 118 | + * @param s - signature half S |
| 119 | + * |
| 120 | + * IMPORTANT: This function disallows signatures where the `s` value is above `N/2` to prevent malleability. |
| 121 | + * To flip the `s` value, compute `s = N - s` and `v = 1 - v` if (`v = 0 | 1`). |
| 122 | + */ |
| 123 | + function recovery(bytes32 h, uint8 v, bytes32 r, bytes32 s) internal view returns (bytes32, bytes32) { |
| 124 | + if (!_isProperSignature(r, s) || v > 1) { |
| 125 | + return (0, 0); |
| 126 | + } |
| 127 | + |
| 128 | + uint256 rx = uint256(r); |
| 129 | + uint256 ry2 = addmod(mulmod(addmod(mulmod(rx, rx, P), A, P), rx, P), B, P); // weierstrass equation y² = x³ + a.x + b |
| 130 | + uint256 ry = Math.modExp(ry2, P1DIV4, P); // This formula for sqrt work because P ≡ 3 (mod 4) |
| 131 | + if (mulmod(ry, ry, P) != ry2) return (0, 0); // Sanity check |
| 132 | + if (ry % 2 != v % 2) ry = P - ry; |
| 133 | + |
| 134 | + JPoint[16] memory points = _preComputeJacobianPoints(rx, ry); |
| 135 | + uint256 w = Math.invModPrime(uint256(r), N); |
| 136 | + uint256 u1 = mulmod(N - (uint256(h) % N), w, N); |
| 137 | + uint256 u2 = mulmod(uint256(s), w, N); |
| 138 | + (uint256 x, uint256 y) = _jMultShamir(points, u1, u2); |
| 139 | + return (bytes32(x), bytes32(y)); |
| 140 | + } |
| 141 | + |
| 142 | + /** |
| 143 | + * @dev Checks if (x, y) are valid coordinates of a point on the curve. |
| 144 | + * In particular this function checks that x <= P and y <= P. |
| 145 | + */ |
| 146 | + function isValidPublicKey(bytes32 x, bytes32 y) internal pure returns (bool result) { |
| 147 | + assembly ("memory-safe") { |
| 148 | + let p := P |
| 149 | + let lhs := mulmod(y, y, p) // y^2 |
| 150 | + let rhs := addmod(mulmod(addmod(mulmod(x, x, p), A, p), x, p), B, p) // ((x^2 + a) * x) + b = x^3 + ax + b |
| 151 | + result := and(and(lt(x, p), lt(y, p)), eq(lhs, rhs)) // Should conform with the Weierstrass equation |
| 152 | + } |
| 153 | + } |
| 154 | + |
| 155 | + /** |
| 156 | + * @dev Checks if (r, s) is a proper signature. |
| 157 | + * In particular, this checks that `s` is in the "lower-range", making the signature non-malleable. |
| 158 | + */ |
| 159 | + function _isProperSignature(bytes32 r, bytes32 s) private pure returns (bool) { |
| 160 | + return uint256(r) > 0 && uint256(r) < N && uint256(s) > 0 && uint256(s) <= HALF_N; |
| 161 | + } |
| 162 | + |
| 163 | + /** |
| 164 | + * @dev Reduce from jacobian to affine coordinates |
| 165 | + * @param jx - jacobian coordinate x |
| 166 | + * @param jy - jacobian coordinate y |
| 167 | + * @param jz - jacobian coordinate z |
| 168 | + * @return ax - affine coordinate x |
| 169 | + * @return ay - affine coordinate y |
| 170 | + */ |
| 171 | + function _affineFromJacobian(uint256 jx, uint256 jy, uint256 jz) private view returns (uint256 ax, uint256 ay) { |
| 172 | + if (jz == 0) return (0, 0); |
| 173 | + uint256 zinv = Math.invModPrime(jz, P); |
| 174 | + uint256 zzinv = mulmod(zinv, zinv, P); |
| 175 | + uint256 zzzinv = mulmod(zzinv, zinv, P); |
| 176 | + ax = mulmod(jx, zzinv, P); |
| 177 | + ay = mulmod(jy, zzzinv, P); |
| 178 | + } |
| 179 | + |
| 180 | + /** |
| 181 | + * @dev Point addition on the jacobian coordinates |
| 182 | + * Reference: https://www.hyperelliptic.org/EFD/g1p/auto-shortw-jacobian.html#addition-add-1998-cmo-2 |
| 183 | + */ |
| 184 | + function _jAdd( |
| 185 | + JPoint memory p1, |
| 186 | + uint256 x2, |
| 187 | + uint256 y2, |
| 188 | + uint256 z2 |
| 189 | + ) private pure returns (uint256 rx, uint256 ry, uint256 rz) { |
| 190 | + assembly ("memory-safe") { |
| 191 | + let p := P |
| 192 | + let z1 := mload(add(p1, 0x40)) |
| 193 | + let s1 := mulmod(mload(add(p1, 0x20)), mulmod(mulmod(z2, z2, p), z2, p), p) // s1 = y1*z2³ |
| 194 | + let s2 := mulmod(y2, mulmod(mulmod(z1, z1, p), z1, p), p) // s2 = y2*z1³ |
| 195 | + let r := addmod(s2, sub(p, s1), p) // r = s2-s1 |
| 196 | + let u1 := mulmod(mload(p1), mulmod(z2, z2, p), p) // u1 = x1*z2² |
| 197 | + let u2 := mulmod(x2, mulmod(z1, z1, p), p) // u2 = x2*z1² |
| 198 | + let h := addmod(u2, sub(p, u1), p) // h = u2-u1 |
| 199 | + let hh := mulmod(h, h, p) // h² |
| 200 | + |
| 201 | + // x' = r²-h³-2*u1*h² |
| 202 | + rx := addmod( |
| 203 | + addmod(mulmod(r, r, p), sub(p, mulmod(h, hh, p)), p), |
| 204 | + sub(p, mulmod(2, mulmod(u1, hh, p), p)), |
| 205 | + p |
| 206 | + ) |
| 207 | + // y' = r*(u1*h²-x')-s1*h³ |
| 208 | + ry := addmod( |
| 209 | + mulmod(r, addmod(mulmod(u1, hh, p), sub(p, rx), p), p), |
| 210 | + sub(p, mulmod(s1, mulmod(h, hh, p), p)), |
| 211 | + p |
| 212 | + ) |
| 213 | + // z' = h*z1*z2 |
| 214 | + rz := mulmod(h, mulmod(z1, z2, p), p) |
| 215 | + } |
| 216 | + } |
| 217 | + |
| 218 | + /** |
| 219 | + * @dev Point doubling on the jacobian coordinates |
| 220 | + * Reference: https://www.hyperelliptic.org/EFD/g1p/auto-shortw-jacobian.html#doubling-dbl-1998-cmo-2 |
| 221 | + */ |
| 222 | + function _jDouble(uint256 x, uint256 y, uint256 z) private pure returns (uint256 rx, uint256 ry, uint256 rz) { |
| 223 | + assembly ("memory-safe") { |
| 224 | + let p := P |
| 225 | + let yy := mulmod(y, y, p) |
| 226 | + let zz := mulmod(z, z, p) |
| 227 | + let s := mulmod(4, mulmod(x, yy, p), p) // s = 4*x*y² |
| 228 | + let m := addmod(mulmod(3, mulmod(x, x, p), p), mulmod(A, mulmod(zz, zz, p), p), p) // m = 3*x²+a*z⁴ |
| 229 | + let t := addmod(mulmod(m, m, p), sub(p, mulmod(2, s, p)), p) // t = m²-2*s |
| 230 | + |
| 231 | + // x' = t |
| 232 | + rx := t |
| 233 | + // y' = m*(s-t)-8*y⁴ |
| 234 | + ry := addmod(mulmod(m, addmod(s, sub(p, t), p), p), sub(p, mulmod(8, mulmod(yy, yy, p), p)), p) |
| 235 | + // z' = 2*y*z |
| 236 | + rz := mulmod(2, mulmod(y, z, p), p) |
| 237 | + } |
| 238 | + } |
| 239 | + |
| 240 | + /** |
| 241 | + * @dev Compute P·u1 + Q·u2 using the precomputed points for P and Q (see {_preComputeJacobianPoints}). |
| 242 | + * |
| 243 | + * Uses Strauss Shamir trick for EC multiplication |
| 244 | + * https://stackoverflow.com/questions/50993471/ec-scalar-multiplication-with-strauss-shamir-method |
| 245 | + * we optimise on this a bit to do with 2 bits at a time rather than a single bit |
| 246 | + * the individual points for a single pass are precomputed |
| 247 | + * overall this reduces the number of additions while keeping the same number of doublings |
| 248 | + */ |
| 249 | + function _jMultShamir(JPoint[16] memory points, uint256 u1, uint256 u2) private view returns (uint256, uint256) { |
| 250 | + uint256 x = 0; |
| 251 | + uint256 y = 0; |
| 252 | + uint256 z = 0; |
| 253 | + unchecked { |
| 254 | + for (uint256 i = 0; i < 128; ++i) { |
| 255 | + if (z > 0) { |
| 256 | + (x, y, z) = _jDouble(x, y, z); |
| 257 | + (x, y, z) = _jDouble(x, y, z); |
| 258 | + } |
| 259 | + // Read 2 bits of u1, and 2 bits of u2. Combining the two give a lookup index in the table. |
| 260 | + uint256 pos = ((u1 >> 252) & 0xc) | ((u2 >> 254) & 0x3); |
| 261 | + if (pos > 0) { |
| 262 | + if (z == 0) { |
| 263 | + (x, y, z) = (points[pos].x, points[pos].y, points[pos].z); |
| 264 | + } else { |
| 265 | + (x, y, z) = _jAdd(points[pos], x, y, z); |
| 266 | + } |
| 267 | + } |
| 268 | + u1 <<= 2; |
| 269 | + u2 <<= 2; |
| 270 | + } |
| 271 | + } |
| 272 | + return _affineFromJacobian(x, y, z); |
| 273 | + } |
| 274 | + |
| 275 | + /** |
| 276 | + * @dev Precompute a matrice of useful jacobian points associated with a given P. This can be seen as a 4x4 matrix |
| 277 | + * that contains combination of P and G (generator) up to 3 times each. See the table below: |
| 278 | + * |
| 279 | + * ┌────┬─────────────────────┐ |
| 280 | + * │ i │ 0 1 2 3 │ |
| 281 | + * ├────┼─────────────────────┤ |
| 282 | + * │ 0 │ 0 p 2p 3p │ |
| 283 | + * │ 4 │ g g+p g+2p g+3p │ |
| 284 | + * │ 8 │ 2g 2g+p 2g+2p 2g+3p │ |
| 285 | + * │ 12 │ 3g 3g+p 3g+2p 3g+3p │ |
| 286 | + * └────┴─────────────────────┘ |
| 287 | + */ |
| 288 | + function _preComputeJacobianPoints(uint256 px, uint256 py) private pure returns (JPoint[16] memory points) { |
| 289 | + points[0x00] = JPoint(0, 0, 0); // 0,0 |
| 290 | + points[0x01] = JPoint(px, py, 1); // 1,0 (p) |
| 291 | + points[0x04] = JPoint(GX, GY, 1); // 0,1 (g) |
| 292 | + points[0x02] = _jDoublePoint(points[0x01]); // 2,0 (2p) |
| 293 | + points[0x08] = _jDoublePoint(points[0x04]); // 0,2 (2g) |
| 294 | + points[0x03] = _jAddPoint(points[0x01], points[0x02]); // 3,0 (3p) |
| 295 | + points[0x05] = _jAddPoint(points[0x01], points[0x04]); // 1,1 (p+g) |
| 296 | + points[0x06] = _jAddPoint(points[0x02], points[0x04]); // 2,1 (2p+g) |
| 297 | + points[0x07] = _jAddPoint(points[0x03], points[0x04]); // 3,1 (3p+g) |
| 298 | + points[0x09] = _jAddPoint(points[0x01], points[0x08]); // 1,2 (p+2g) |
| 299 | + points[0x0a] = _jAddPoint(points[0x02], points[0x08]); // 2,2 (2p+2g) |
| 300 | + points[0x0b] = _jAddPoint(points[0x03], points[0x08]); // 3,2 (3p+2g) |
| 301 | + points[0x0c] = _jAddPoint(points[0x04], points[0x08]); // 0,3 (g+2g) |
| 302 | + points[0x0d] = _jAddPoint(points[0x01], points[0x0c]); // 1,3 (p+3g) |
| 303 | + points[0x0e] = _jAddPoint(points[0x02], points[0x0c]); // 2,3 (2p+3g) |
| 304 | + points[0x0f] = _jAddPoint(points[0x03], points[0x0C]); // 3,3 (3p+3g) |
| 305 | + } |
| 306 | + |
| 307 | + function _jAddPoint(JPoint memory p1, JPoint memory p2) private pure returns (JPoint memory) { |
| 308 | + (uint256 x, uint256 y, uint256 z) = _jAdd(p1, p2.x, p2.y, p2.z); |
| 309 | + return JPoint(x, y, z); |
| 310 | + } |
| 311 | + |
| 312 | + function _jDoublePoint(JPoint memory p) private pure returns (JPoint memory) { |
| 313 | + (uint256 x, uint256 y, uint256 z) = _jDouble(p.x, p.y, p.z); |
| 314 | + return JPoint(x, y, z); |
| 315 | + } |
| 316 | +} |
0 commit comments