Skip to content

Base64.decode #5765

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/solid-cobras-talk.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'openzeppelin-solidity': minor
---

`Base64`: Add a new `decode` function that parses base64 encoded strings.
166 changes: 141 additions & 25 deletions contracts/utils/Base64.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,42 +3,57 @@

pragma solidity ^0.8.20;

import {SafeCast} from "./math/SafeCast.sol";

/**
* @dev Provides a set of functions to operate with Base64 strings.
*/
library Base64 {
/**
* @dev Base64 Encoding/Decoding Table
* See sections 4 and 5 of https://datatracker.ietf.org/doc/html/rfc4648
*/
string internal constant _TABLE = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
string internal constant _TABLE_URL = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
using SafeCast for bool;

error InvalidBase64Digit(uint8);

/**
* @dev Converts a `bytes` to its Base64 `string` representation.
*/
function encode(bytes memory data) internal pure returns (string memory) {
return _encode(data, _TABLE, true);
return string(_encode(data, false));
}

/**
* @dev Converts a `bytes` to its Base64Url `string` representation.
* Output is not padded with `=` as specified in https://www.rfc-editor.org/rfc/rfc4648[rfc4648].
*/
function encodeURL(bytes memory data) internal pure returns (string memory) {
return _encode(data, _TABLE_URL, false);
return string(_encode(data, true));
}

/**
* @dev Internal table-agnostic conversion
* @dev Converts a Base64 `string` to the `bytes` it represents.
*
* * Supports padded and unpadded inputs.
* * Supports both encoding ({encode} and {encodeURL}) seamlessly.
* * Reverts with {InvalidBase64Digit} if the input contains a character that is neither `=` nor part of the Base64 or Base64Url alphabets.
*/
function _encode(bytes memory data, string memory table, bool withPadding) private pure returns (string memory) {
function decode(string memory data) internal pure returns (bytes memory) {
    // View the string as its underlying bytes and hand it over to the internal decoder.
    bytes memory encoded = bytes(data);
    return _decode(encoded);
}

/**
* @dev Internal table-agnostic encoding
*
* If padding is enabled, uses the Base64 table, otherwise use the Base64Url table.
* See sections 4 and 5 of https://datatracker.ietf.org/doc/html/rfc4648
*/
function _encode(bytes memory data, bool urlAndFilenameSafe) private pure returns (bytes memory result) {
/**
* Inspired by Brecht Devos (Brechtpd) implementation - MIT licence
* https://github.com/Brechtpd/base64/blob/e78d9fd951e7b0977ddca77d92dc85183770daf4/base64.sol
*/
if (data.length == 0) return "";

// Padding is enabled by default, but disabled when the "urlAndFilenameSafe" alphabet is used
//
// If padding is enabled, the final length should be `bytes` data length divided by 3 rounded up and then
// multiplied by 4 so that it leaves room for padding the last chunk
// - `data.length + 2` -> Prepare for division rounding up
Expand All @@ -52,13 +67,20 @@ library Base64 {
// - ` + 2` -> Prepare for division rounding up
// - `/ 3` -> Number of 3-bytes chunks (rounded up)
// This is equivalent to: Math.ceil((4 * data.length) / 3)
uint256 resultLength = withPadding ? 4 * ((data.length + 2) / 3) : (4 * data.length + 2) / 3;

string memory result = new string(resultLength);
uint256 resultLength = urlAndFilenameSafe ? (4 * data.length + 2) / 3 : 4 * ((data.length + 2) / 3);

assembly ("memory-safe") {
// Prepare the lookup table (skip the first "length" byte)
let tablePtr := add(table, 1)
result := mload(0x40)

// Store the encoding table in the scratch space (and fmp ptr) to avoid memory allocation
//
// Base64 (ascii) A B C D E F G H I J K L M N O P Q R S T U V W X Y Z a b c d e f g h i j k l m n o p q r s t u v w x y z 0 1 2 3 4 5 6 7 8 9 + /
// Base64 (hex) 4142434445464748494a4b4c4d4e4f505152535455565758595a6162636465666768696a6b6c6d6e6f707172737475767778797a303132333435363738392b2f
// Base64Url (ascii) A B C D E F G H I J K L M N O P Q R S T U V W X Y Z a b c d e f g h i j k l m n o p q r s t u v w x y z 0 1 2 3 4 5 6 7 8 9 - _
// Base64Url (hex) 4142434445464748494a4b4c4d4e4f505152535455565758595a6162636465666768696a6b6c6d6e6f707172737475767778797a303132333435363738392d5f
// xor (hex) 00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000670
mstore(0x1f, "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdef")
mstore(0x3f, xor("ghijklmnopqrstuvwxyz0123456789+/", mul(urlAndFilenameSafe, 0x670)))

// Prepare result pointer, jump over length
let resultPtr := add(result, 0x20)
Expand All @@ -83,37 +105,131 @@ library Base64 {
// Use this as an index into the lookup table, mload an entire word
// so the desired character is in the least significant byte, and
// mstore8 this least significant byte into the result and continue.

mstore8(resultPtr, mload(add(tablePtr, and(shr(18, input), 0x3F))))
mstore8(resultPtr, mload(and(shr(18, input), 0x3F)))
resultPtr := add(resultPtr, 1) // Advance

mstore8(resultPtr, mload(add(tablePtr, and(shr(12, input), 0x3F))))
mstore8(resultPtr, mload(and(shr(12, input), 0x3F)))
resultPtr := add(resultPtr, 1) // Advance

mstore8(resultPtr, mload(add(tablePtr, and(shr(6, input), 0x3F))))
mstore8(resultPtr, mload(and(shr(6, input), 0x3F)))
resultPtr := add(resultPtr, 1) // Advance

mstore8(resultPtr, mload(add(tablePtr, and(input, 0x3F))))
mstore8(resultPtr, mload(and(input, 0x3F)))
resultPtr := add(resultPtr, 1) // Advance
}

// Reset the value that was cached
mstore(afterPtr, afterCache)

if withPadding {
if iszero(urlAndFilenameSafe) {
// When data `bytes` is not exactly 3 bytes long
// it is padded with `=` characters at the end
switch mod(mload(data), 3)
case 1 {
mstore8(sub(resultPtr, 1), 0x3d)
mstore8(sub(resultPtr, 2), 0x3d)
resultPtr := add(resultPtr, 2)
}
case 2 {
mstore8(sub(resultPtr, 1), 0x3d)
resultPtr := add(resultPtr, 1)
}
}

// Store result length and update FMP to reserve allocated space
mstore(result, resultLength)
mstore(0x40, resultPtr)
}
}

/**
 * @dev Internal decoding
 *
 * Accepts characters from both the Base64 and the Base64Url alphabets, as well as the `=` padding character.
 * Works on padded and unpadded inputs.
 *
 * Reverts with {InvalidBase64Digit}, carrying the offending character, if a byte that is neither `=` nor part
 * of one of the two alphabets is read.
 */
function _decode(bytes memory data) private pure returns (bytes memory result) {
    bytes4 errorSelector = InvalidBase64Digit.selector;

    uint256 dataLength = data.length;
    if (dataLength == 0) return "";

    // Every complete group of 4 characters decodes to 3 bytes.
    uint256 resultLength = (dataLength / 4) * 3;
    if (dataLength % 4 == 0) {
        // Length is a multiple of 4: each of the (up to two) trailing `=` removes one byte from the output.
        resultLength -= (data[dataLength - 1] == "=").toUint() + (data[dataLength - 2] == "=").toUint();
    } else {
        // Unpadded input: a trailing partial group of `k` characters contributes `k - 1` bytes.
        resultLength += (dataLength % 4) - 1;
    }

    assembly ("memory-safe") {
        result := mload(0x40)

        // Temporarily store the reverse lookup table in memory. It spans from 0x00 to 0x50, using:
        // - all 64 bytes of scratch space
        // - part of the FMP slot (at location 0x40), which is rewritten at the end of this block
        //
        // The table is indexed by `character - 43` (43 is `+`, the lowest valid character). Each entry holds
        // the 6-bit value of the character, `0x00` for `=`, or `0xff` for characters that are not valid.
        //
        // The stores are ordered from the highest offset to the lowest so that each one overwrites the zero
        // padding written by the previous one.
        mstore(0x30, 0x2425262728292a2b2c2d2e2f30313233)
        mstore(0x20, 0x0a0b0c0d0e0f10111213141516171819ffffffff3fff1a1b1c1d1e1f20212223)
        mstore(0x00, 0x3eff3eff3f3435363738393a3b3c3dffffff00ffffff00010203040506070809)

        // Prepare result pointer, jump over length
        let dataPtr := data
        let resultPtr := add(result, 0x20)
        let endPtr := add(resultPtr, resultLength)

        // In some cases, the last iteration will read bytes after the end of the data. We cache the value, and
        // set it to "==" (fake padding) to make sure no dirty bytes are read in that section.
        let afterPtr := add(add(data, 0x20), dataLength)
        let afterCache := mload(afterPtr)
        mstore(afterPtr, shl(240, 0x3d3d))

        // loop while not everything is decoded
        for {} lt(resultPtr, endPtr) {} {
            // Advance by 4 bytes. Since `dataPtr` starts at the length slot, the 4 least significant bytes of
            // the word loaded below are the next 4 characters of the input.
            dataPtr := add(dataPtr, 4)

            // Read a 4 bytes chunk of data
            let input := mload(dataPtr)

            // Decode each byte in the chunk as a 6 bit block, and align them to form a block of 3 bytes.
            //
            // For each character, `x = character - 43` is used both as an index in the lookup table and as a
            // bit position in the validity mask `0xffffffd0ffffffc47ff5` (bit `x` is set if, and only if,
            // character `x + 43` is accepted). Characters below 43 underflow to a huge shift and characters
            // above 122 shift past the mask: both produce 0 and are rejected.
            //
            // On failure, the original character is recovered as `x + 43` and returned as the error argument.
            let a := sub(byte(28, input), 43)
            // slither-disable-next-line incorrect-shift
            if iszero(and(shl(a, 1), 0xffffffd0ffffffc47ff5)) {
                mstore(0, errorSelector)
                mstore(4, add(a, 43))
                revert(0, 0x24)
            }
            let b := sub(byte(29, input), 43)
            // slither-disable-next-line incorrect-shift
            if iszero(and(shl(b, 1), 0xffffffd0ffffffc47ff5)) {
                mstore(0, errorSelector)
                mstore(4, add(b, 43))
                revert(0, 0x24)
            }
            let c := sub(byte(30, input), 43)
            // slither-disable-next-line incorrect-shift
            if iszero(and(shl(c, 1), 0xffffffd0ffffffc47ff5)) {
                mstore(0, errorSelector)
                mstore(4, add(c, 43))
                revert(0, 0x24)
            }
            let d := sub(byte(31, input), 43)
            // slither-disable-next-line incorrect-shift
            if iszero(and(shl(d, 1), 0xffffffd0ffffffc47ff5)) {
                mstore(0, errorSelector)
                mstore(4, add(d, 43))
                revert(0, 0x24)
            }

            // `byte(0, mload(x))` is the table entry at offset `x`. The four 6-bit values are packed into the
            // 24 most significant bits of the word, so that the 3 decoded bytes land at `resultPtr`.
            mstore(
                resultPtr,
                or(
                    or(shl(250, byte(0, mload(a))), shl(244, byte(0, mload(b)))),
                    or(shl(238, byte(0, mload(c))), shl(232, byte(0, mload(d))))
                )
            )

            resultPtr := add(resultPtr, 3)
        }

        // Reset the value that was cached
        mstore(afterPtr, afterCache)

        // Store result length and update FMP to reserve allocated space
        mstore(result, resultLength)
        mstore(0x40, endPtr)
    }
}
}
2 changes: 2 additions & 0 deletions test/utils/Base64.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ import {Base64} from "@openzeppelin/contracts/utils/Base64.sol";
contract Base64Test is Test {
function testEncode(bytes memory input) external pure {
assertEq(Base64.encode(input), vm.toBase64(input));
assertEq(Base64.decode(Base64.encode(input)), input);
}

function testEncodeURL(bytes memory input) external pure {
assertEq(Base64.encodeURL(input), _removePadding(vm.toBase64URL(input)));
assertEq(Base64.decode(Base64.encodeURL(input)), input);
}

function _removePadding(string memory inputStr) internal pure returns (string memory) {
Expand Down
21 changes: 16 additions & 5 deletions test/utils/Base64.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ async function fixture() {
return { mock };
}

describe('Strings', function () {
describe('Base64', function () {
beforeEach(async function () {
Object.assign(this, await loadFixture(fixture));
});
Expand All @@ -27,8 +27,9 @@ describe('Strings', function () {
])
it(title, async function () {
const buffer = Buffer.from(input, 'ascii');
expect(await this.mock.$encode(buffer)).to.equal(ethers.encodeBase64(buffer));
expect(await this.mock.$encode(buffer)).to.equal(expected);
await expect(this.mock.$encode(buffer)).to.eventually.equal(ethers.encodeBase64(buffer));
await expect(this.mock.$encode(buffer)).to.eventually.equal(expected);
await expect(this.mock.$decode(expected)).to.eventually.equal(ethers.hexlify(buffer));
});
});

Expand All @@ -43,11 +44,21 @@ describe('Strings', function () {
])
it(title, async function () {
const buffer = Buffer.from(input, 'ascii');
expect(await this.mock.$encodeURL(buffer)).to.equal(base64toBase64Url(ethers.encodeBase64(buffer)));
expect(await this.mock.$encodeURL(buffer)).to.equal(expected);
await expect(this.mock.$encodeURL(buffer)).to.eventually.equal(base64toBase64Url(ethers.encodeBase64(buffer)));
await expect(this.mock.$encodeURL(buffer)).to.eventually.equal(expected);
await expect(this.mock.$decode(expected)).to.eventually.equal(ethers.hexlify(buffer));
});
});

it('Decode invalid base64 string', async function () {
// ord('$') < 43
await expect(this.mock.$decode('dGVzd$==')).to.be.reverted;
// ord('~') > 122
await expect(this.mock.$decode('dGVzd~==')).to.be.reverted;
// ord('@') in range, but '@' not in the dictionary
await expect(this.mock.$decode('dGVzd@==')).to.be.reverted;
});

it('Encode reads beyond the input buffer into dirty memory', async function () {
const mock = await ethers.deployContract('Base64Dirty');
const buffer32 = ethers.id('example');
Expand Down