Skip to content

Commit 59399a2

Browse files
committed
Cleanup inconsistent ID type handling and vectorize Base64
1 parent 4b6f6d9 commit 59399a2

File tree

7 files changed

+141
-109
lines changed

7 files changed

+141
-109
lines changed

GraphQLService.cpp

Lines changed: 86 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <iostream>
99
#include <algorithm>
1010
#include <stdexcept>
11+
#include <array>
1112

1213
namespace facebook {
1314
namespace graphql {
@@ -48,7 +49,7 @@ const peg::ast_node& Fragment::getSelection() const
4849
return _selection;
4950
}
5051

51-
uint8_t Base64::verifyFromBase64(unsigned char ch)
52+
uint8_t Base64::verifyFromBase64(char ch)
5253
{
5354
uint8_t result = fromBase64(ch);
5455

@@ -60,70 +61,94 @@ uint8_t Base64::verifyFromBase64(unsigned char ch)
6061
return result;
6162
}
6263

63-
std::vector<unsigned char> Base64::fromBase64(const char* encoded, size_t count)
64+
std::vector<uint8_t> Base64::fromBase64(const char* encoded, size_t count)
6465
{
65-
std::vector<unsigned char> result;
66+
std::vector<uint8_t> result;
6667

6768
if (!count)
6869
{
6970
return result;
7071
}
7172

72-
result.reserve(count * 3 / 4);
73-
while (encoded[0] && encoded[1])
73+
result.reserve((count + (count % 4)) * 3 / 4);
74+
75+
// First decode all of the full unpadded segments 24 bits at a time
76+
while (count >= 4
77+
&& encoded[3] != padding)
7478
{
75-
uint16_t buffer = static_cast<uint16_t>(verifyFromBase64(*encoded++)) << 10;
79+
const uint32_t segment = (static_cast<uint32_t>(verifyFromBase64(encoded[0])) << 18)
80+
| (static_cast<uint32_t>(verifyFromBase64(encoded[1])) << 12)
81+
| (static_cast<uint32_t>(verifyFromBase64(encoded[2])) << 6)
82+
| static_cast<uint32_t>(verifyFromBase64(encoded[3]));
83+
84+
result.emplace_back(static_cast<uint8_t>((segment & 0xFF0000) >> 16));
85+
result.emplace_back(static_cast<uint8_t>((segment & 0xFF00) >> 8));
86+
result.emplace_back(static_cast<uint8_t>(segment & 0xFF));
87+
88+
encoded += 4;
89+
count -= 4;
90+
}
7691

77-
buffer |= static_cast<uint16_t>(verifyFromBase64(*encoded++)) << 4;
78-
result.push_back(static_cast<unsigned char>((buffer & 0xFF00) >> 8));
79-
buffer = (buffer & 0xFF) << 8;
92+
// Get any leftover partial segment with 2 or 3 non-padding characters
93+
if (count > 1)
94+
{
95+
const bool triplet = (count > 2 && padding != encoded[2]);
96+
const uint8_t tail = (triplet ? verifyFromBase64(encoded[2]) : 0);
97+
const uint16_t segment = (static_cast<uint16_t>(verifyFromBase64(encoded[0])) << 10)
98+
| (static_cast<uint16_t>(verifyFromBase64(encoded[1])) << 4)
99+
| (static_cast<uint16_t>(tail) >> 2);
80100

81-
if (!*encoded || '=' == *encoded)
101+
if (triplet)
82102
{
83-
if (0 != buffer
84-
|| (*encoded && (*++encoded != '=' || *++encoded)))
103+
if (tail & 0x3)
85104
{
86105
throw schema_exception({ "invalid padding at the end of a base64 encoded string" });
87106
}
88107

89-
break;
90-
}
91-
92-
buffer |= static_cast<uint16_t>(verifyFromBase64(*encoded++)) << 6;
93-
result.push_back(static_cast<unsigned char>((buffer & 0xFF00) >> 8));
94-
buffer &= 0xFF;
108+
result.emplace_back(static_cast<uint8_t>((segment & 0xFF00) >> 8));
109+
result.emplace_back(static_cast<uint8_t>(segment & 0xFF));
95110

96-
if (!*encoded || '=' == *encoded)
111+
encoded += 3;
112+
count -= 3;
113+
}
114+
else
97115
{
98-
if (0 != buffer
99-
|| (*encoded && *++encoded))
116+
if (segment & 0xFF)
100117
{
101118
throw schema_exception({ "invalid padding at the end of a base64 encoded string" });
102119
}
103120

104-
break;
121+
result.emplace_back(static_cast<uint8_t>((segment & 0xFF00) >> 8));
122+
123+
encoded += 2;
124+
count -= 2;
105125
}
126+
}
106127

107-
buffer |= static_cast<uint16_t>(verifyFromBase64(*encoded++));
108-
result.push_back(static_cast<unsigned char>(buffer & 0xFF));
128+
// Make sure anything that's left is 0 - 2 characters of padding
129+
if ((count > 0 && padding != encoded[0])
130+
|| (count > 1 && padding != encoded[1])
131+
|| count > 2)
132+
{
133+
throw schema_exception({ "invalid padding at the end of a base64 encoded string" });
109134
}
110135

111136
return result;
112137
}
113138

114-
unsigned char Base64::verifyToBase64(uint8_t i)
139+
char Base64::verifyToBase64(uint8_t i)
115140
{
116141
unsigned char result = toBase64(i);
117142

118-
if (result == '=')
143+
if (result == padding)
119144
{
120145
throw std::logic_error("invalid 6-bit value");
121146
}
122147

123148
return result;
124149
}
125150

126-
std::string Base64::toBase64(const std::vector<unsigned char>& bytes)
151+
std::string Base64::toBase64(const std::vector<uint8_t>& bytes)
127152
{
128153
std::string result;
129154

@@ -132,38 +157,43 @@ std::string Base64::toBase64(const std::vector<unsigned char>& bytes)
132157
return result;
133158
}
134159

135-
auto itr = bytes.cbegin();
136-
const auto itrEnd = bytes.cend();
137-
const size_t count = bytes.size();
160+
size_t count = bytes.size();
161+
const uint8_t* data = bytes.data();
138162

139163
result.reserve((count + (count % 3)) * 4 / 3);
140-
while (itr != itrEnd)
141-
{
142-
uint16_t buffer = static_cast<uint8_t>(*itr++) << 8;
143-
144-
result.push_back(verifyToBase64((buffer & 0xFC00) >> 10));
145-
146-
if (itr == itrEnd)
147-
{
148-
result.push_back(verifyToBase64((buffer & 0x03F0) >> 4));
149-
result.append("==");
150-
break;
151-
}
152164

153-
buffer |= static_cast<uint8_t>(*itr++);
154-
result.push_back(verifyToBase64((buffer & 0x03F0) >> 4));
155-
buffer = buffer << 8;
165+
// First encode all of the full unpadded segments 24 bits at a time
166+
while (count >= 3)
167+
{
168+
const uint32_t segment = (static_cast<uint32_t>(data[0]) << 16)
169+
| (static_cast<uint32_t>(data[1]) << 8)
170+
| static_cast<uint32_t>(data[2]);
171+
172+
result.append({
173+
verifyToBase64((segment & 0xFC0000) >> 18),
174+
verifyToBase64((segment & 0x3F000) >> 12),
175+
verifyToBase64((segment & 0xFC0) >> 6),
176+
verifyToBase64(segment & 0x3F)
177+
});
156178

157-
if (itr == itrEnd)
158-
{
159-
result.push_back(verifyToBase64((buffer & 0x0FC0) >> 6));
160-
result.push_back('=');
161-
break;
162-
}
179+
data += 3;
180+
count -= 3;
181+
}
163182

164-
buffer |= static_cast<uint8_t>(*itr++);
165-
result.push_back(verifyToBase64((buffer & 0x0FC0) >> 6));
166-
result.push_back(verifyToBase64(buffer & 0x3F));
183+
// Get any leftover partial segment with 1 or 2 bytes
184+
if (count > 0)
185+
{
186+
const bool pair = (count > 1);
187+
const uint16_t segment = (static_cast<uint16_t>(data[0]) << 8)
188+
| (pair ? static_cast<uint16_t>(data[1]) : 0);
189+
const std::array<char, 4> remainder {
190+
verifyToBase64((segment & 0xFC00) >> 10),
191+
verifyToBase64((segment & 0x3F0) >> 4),
192+
(pair ? verifyToBase64((segment & 0xF) << 2) : padding),
193+
padding
194+
};
195+
196+
result.append(remainder.data(), remainder.size());
167197
}
168198

169199
return result;
@@ -229,7 +259,7 @@ rapidjson::Document ModifiedArgument<rapidjson::Document>::convert(const rapidjs
229259
}
230260

231261
template <>
232-
std::vector<unsigned char> ModifiedArgument<std::vector<unsigned char>>::convert(const rapidjson::Value& value)
262+
std::vector<uint8_t> ModifiedArgument<std::vector<uint8_t>>::convert(const rapidjson::Value& value)
233263
{
234264
if (!value.IsString())
235265
{
@@ -282,7 +312,7 @@ rapidjson::Document ModifiedResult<rapidjson::Document>::convert(rapidjson::Docu
282312
}
283313

284314
template <>
285-
rapidjson::Document ModifiedResult<std::vector<unsigned char>>::convert(std::vector<unsigned char>&& result, ResolverParams&&)
315+
rapidjson::Document ModifiedResult<std::vector<uint8_t>>::convert(std::vector<uint8_t>&& result, ResolverParams&&)
286316
{
287317
rapidjson::Document document(rapidjson::Type::kStringType);
288318

GraphQLService.h

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class Base64
7676
{
7777
public:
7878
// Map a single Base64-encoded character to its 6-bit integer value.
79-
static constexpr uint8_t fromBase64(unsigned char ch) noexcept
79+
static constexpr uint8_t fromBase64(char ch) noexcept
8080
{
8181
return (ch >= 'A' && ch <= 'Z' ? ch - 'A'
8282
: (ch >= 'a' && ch <= 'z' ? ch - 'a' + 26
@@ -86,27 +86,29 @@ class Base64
8686
}
8787

8888
// Convert a Base64-encoded string to a vector of bytes.
89-
static std::vector<unsigned char> fromBase64(const char* encoded, size_t count);
89+
static std::vector<uint8_t> fromBase64(const char* encoded, size_t count);
9090

9191
// Map a single 6-bit integer value to its Base64-encoded character.
92-
static constexpr unsigned char toBase64(uint8_t i) noexcept
92+
static constexpr char toBase64(uint8_t i) noexcept
9393
{
94-
return (i < 26 ? i + static_cast<uint8_t>('A')
95-
: (i < 52 ? i - 26 + static_cast<uint8_t>('a')
96-
: (i < 62 ? i - 52 + static_cast<uint8_t>('0')
94+
return (i < 26 ? static_cast<char>(i + static_cast<uint8_t>('A'))
95+
: (i < 52 ? static_cast<char>(i - 26 + static_cast<uint8_t>('a'))
96+
: (i < 62 ? static_cast<char>(i - 52 + static_cast<uint8_t>('0'))
9797
: (i == 62 ? '+'
98-
: (i == 63 ? '/' : '=')))));
98+
: (i == 63 ? '/' : padding)))));
9999
}
100100

101101
// Convert a set of bytes to Base64.
102-
static std::string toBase64(const std::vector<unsigned char>& bytes);
102+
static std::string toBase64(const std::vector<uint8_t>& bytes);
103103

104104
private:
105+
static constexpr char padding = '=';
106+
105107
// Throw a schema_exception if the character is out of range.
106-
static uint8_t verifyFromBase64(unsigned char ch);
108+
static uint8_t verifyFromBase64(char ch);
107109

108110
// Throw a logic_error if the integer is out of range.
109-
static unsigned char verifyToBase64(uint8_t i);
111+
static char verifyToBase64(uint8_t i);
110112
};
111113

112114
// Types can be wrapped in non-null or list types in GraphQL. Since nullability is a more special case
@@ -250,7 +252,7 @@ using IntArgument = ModifiedArgument<int>;
250252
using FloatArgument = ModifiedArgument<double>;
251253
using StringArgument = ModifiedArgument<std::string>;
252254
using BooleanArgument = ModifiedArgument<bool>;
253-
using IdArgument = ModifiedArgument<std::vector<unsigned char>>;
255+
using IdArgument = ModifiedArgument<std::vector<uint8_t>>;
254256
using ScalarArgument = ModifiedArgument<rapidjson::Document>;
255257

256258
// Each type should handle fragments with type conditions matching its own

SchemaGenerator.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ const CppTypeMap Generator::s_builtinCppTypes= {
3333
"double",
3434
"std::string",
3535
"bool",
36-
"std::vector<unsigned char>",
36+
"std::vector<uint8_t>",
3737
};
3838

3939
const std::string Generator::s_scalarCppType = R"cpp(rapidjson::Document)cpp";

Today.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,22 @@ namespace facebook {
1010
namespace graphql {
1111
namespace today {
1212

13-
Appointment::Appointment(std::vector<unsigned char>&& id, std::string&& when, std::string&& subject, bool isNow)
13+
Appointment::Appointment(std::vector<uint8_t>&& id, std::string&& when, std::string&& subject, bool isNow)
1414
: _id(std::move(id))
1515
, _when(std::move(when))
1616
, _subject(std::move(subject))
1717
, _isNow(isNow)
1818
{
1919
}
2020

21-
Task::Task(std::vector<unsigned char>&& id, std::string&& title, bool isComplete)
21+
Task::Task(std::vector<uint8_t>&& id, std::string&& title, bool isComplete)
2222
: _id(std::move(id))
2323
, _title(std::move(title))
2424
, _isComplete(isComplete)
2525
{
2626
}
2727

28-
Folder::Folder(std::vector<unsigned char>&& id, std::string&& name, int unreadCount)
28+
Folder::Folder(std::vector<uint8_t>&& id, std::string&& name, int unreadCount)
2929
: _id(std::move(id))
3030
, _name(std::move(name))
3131
, _unreadCount(unreadCount)
@@ -48,7 +48,7 @@ void Query::loadAppointments() const
4848
}
4949
}
5050

51-
std::shared_ptr<Appointment> Query::findAppointment(const std::vector<unsigned char>& id) const
51+
std::shared_ptr<Appointment> Query::findAppointment(const std::vector<uint8_t>& id) const
5252
{
5353
loadAppointments();
5454

@@ -74,7 +74,7 @@ void Query::loadTasks() const
7474
}
7575
}
7676

77-
std::shared_ptr<Task> Query::findTask(const std::vector<unsigned char>& id) const
77+
std::shared_ptr<Task> Query::findTask(const std::vector<uint8_t>& id) const
7878
{
7979
loadTasks();
8080

@@ -100,7 +100,7 @@ void Query::loadUnreadCounts() const
100100
}
101101
}
102102

103-
std::shared_ptr<Folder> Query::findUnreadCount(const std::vector<unsigned char>& id) const
103+
std::shared_ptr<Folder> Query::findUnreadCount(const std::vector<uint8_t>& id) const
104104
{
105105
loadUnreadCounts();
106106

@@ -117,7 +117,7 @@ std::shared_ptr<Folder> Query::findUnreadCount(const std::vector<unsigned char>&
117117
return nullptr;
118118
}
119119

120-
std::shared_ptr<service::Object> Query::getNode(std::vector<unsigned char>&& id) const
120+
std::shared_ptr<service::Object> Query::getNode(std::vector<uint8_t>&& id) const
121121
{
122122
auto appointment = findAppointment(id);
123123

@@ -263,38 +263,38 @@ std::shared_ptr<object::FolderConnection> Query::getUnreadCounts(std::unique_ptr
263263
return std::static_pointer_cast<object::FolderConnection>(connection);
264264
}
265265

266-
std::vector<std::shared_ptr<object::Appointment>> Query::getAppointmentsById(std::vector<std::vector<unsigned char>>&& ids) const
266+
std::vector<std::shared_ptr<object::Appointment>> Query::getAppointmentsById(std::vector<std::vector<uint8_t>>&& ids) const
267267
{
268268
std::vector<std::shared_ptr<object::Appointment>> result(ids.size());
269269

270270
std::transform(ids.cbegin(), ids.cend(), result.begin(),
271-
[this](const std::vector<unsigned char>& id)
271+
[this](const std::vector<uint8_t>& id)
272272
{
273273
return std::static_pointer_cast<object::Appointment>(findAppointment(id));
274274
});
275275

276276
return result;
277277
}
278278

279-
std::vector<std::shared_ptr<object::Task>> Query::getTasksById(std::vector<std::vector<unsigned char>>&& ids) const
279+
std::vector<std::shared_ptr<object::Task>> Query::getTasksById(std::vector<std::vector<uint8_t>>&& ids) const
280280
{
281281
std::vector<std::shared_ptr<object::Task>> result(ids.size());
282282

283283
std::transform(ids.cbegin(), ids.cend(), result.begin(),
284-
[this](const std::vector<unsigned char>& id)
284+
[this](const std::vector<uint8_t>& id)
285285
{
286286
return std::static_pointer_cast<object::Task>(findTask(id));
287287
});
288288

289289
return result;
290290
}
291291

292-
std::vector<std::shared_ptr<object::Folder>> Query::getUnreadCountsById(std::vector<std::vector<unsigned char>>&& ids) const
292+
std::vector<std::shared_ptr<object::Folder>> Query::getUnreadCountsById(std::vector<std::vector<uint8_t>>&& ids) const
293293
{
294294
std::vector<std::shared_ptr<object::Folder>> result(ids.size());
295295

296296
std::transform(ids.cbegin(), ids.cend(), result.begin(),
297-
[this](const std::vector<unsigned char>& id)
297+
[this](const std::vector<uint8_t>& id)
298298
{
299299
return std::static_pointer_cast<object::Folder>(findUnreadCount(id));
300300
});

0 commit comments

Comments
 (0)