8
8
#include < iostream>
9
9
#include < algorithm>
10
10
#include < stdexcept>
11
+ #include < array>
11
12
12
13
namespace facebook {
13
14
namespace graphql {
@@ -48,7 +49,7 @@ const peg::ast_node& Fragment::getSelection() const
48
49
return _selection;
49
50
}
50
51
51
- uint8_t Base64::verifyFromBase64 (unsigned char ch)
52
+ uint8_t Base64::verifyFromBase64 (char ch)
52
53
{
53
54
uint8_t result = fromBase64 (ch);
54
55
@@ -60,70 +61,94 @@ uint8_t Base64::verifyFromBase64(unsigned char ch)
60
61
return result;
61
62
}
62
63
63
- std::vector<unsigned char > Base64::fromBase64 (const char * encoded, size_t count)
64
+ std::vector<uint8_t > Base64::fromBase64 (const char * encoded, size_t count)
64
65
{
65
- std::vector<unsigned char > result;
66
+ std::vector<uint8_t > result;
66
67
67
68
if (!count)
68
69
{
69
70
return result;
70
71
}
71
72
72
- result.reserve (count * 3 / 4 );
73
- while (encoded[0 ] && encoded[1 ])
73
+ result.reserve ((count + (count % 4 )) * 3 / 4 );
74
+
75
+ // First decode all of the full unpadded segments 24 bits at a time
76
+ while (count >= 4
77
+ && encoded[3 ] != padding)
74
78
{
75
- uint16_t buffer = static_cast <uint16_t >(verifyFromBase64 (*encoded++)) << 10 ;
79
+ const uint32_t segment = (static_cast <uint32_t >(verifyFromBase64 (encoded[0 ])) << 18 )
80
+ | (static_cast <uint32_t >(verifyFromBase64 (encoded[1 ])) << 12 )
81
+ | (static_cast <uint32_t >(verifyFromBase64 (encoded[2 ])) << 6 )
82
+ | static_cast <uint32_t >(verifyFromBase64 (encoded[3 ]));
83
+
84
+ result.emplace_back (static_cast <uint8_t >((segment & 0xFF0000 ) >> 16 ));
85
+ result.emplace_back (static_cast <uint8_t >((segment & 0xFF00 ) >> 8 ));
86
+ result.emplace_back (static_cast <uint8_t >(segment & 0xFF ));
87
+
88
+ encoded += 4 ;
89
+ count -= 4 ;
90
+ }
76
91
77
- buffer |= static_cast <uint16_t >(verifyFromBase64 (*encoded++)) << 4 ;
78
- result.push_back (static_cast <unsigned char >((buffer & 0xFF00 ) >> 8 ));
79
- buffer = (buffer & 0xFF ) << 8 ;
92
+ // Get any leftover partial segment with 2 or 3 non-padding characters
93
+ if (count > 1 )
94
+ {
95
+ const bool triplet = (count > 2 && padding != encoded[2 ]);
96
+ const uint8_t tail = (triplet ? verifyFromBase64 (encoded[2 ]) : 0 );
97
+ const uint16_t segment = (static_cast <uint16_t >(verifyFromBase64 (encoded[0 ])) << 10 )
98
+ | (static_cast <uint16_t >(verifyFromBase64 (encoded[1 ])) << 4 )
99
+ | (static_cast <uint16_t >(tail) >> 2 );
80
100
81
- if (!*encoded || ' = ' == *encoded )
101
+ if (triplet )
82
102
{
83
- if (0 != buffer
84
- || (*encoded && (*++encoded != ' =' || *++encoded)))
103
+ if (tail & 0x3 )
85
104
{
86
105
throw schema_exception ({ " invalid padding at the end of a base64 encoded string" });
87
106
}
88
107
89
- break ;
90
- }
91
-
92
- buffer |= static_cast <uint16_t >(verifyFromBase64 (*encoded++)) << 6 ;
93
- result.push_back (static_cast <unsigned char >((buffer & 0xFF00 ) >> 8 ));
94
- buffer &= 0xFF ;
108
+ result.emplace_back (static_cast <uint8_t >((segment & 0xFF00 ) >> 8 ));
109
+ result.emplace_back (static_cast <uint8_t >(segment & 0xFF ));
95
110
96
- if (!*encoded || ' =' == *encoded)
111
+ encoded += 3 ;
112
+ count -= 3 ;
113
+ }
114
+ else
97
115
{
98
- if (0 != buffer
99
- || (*encoded && *++encoded))
116
+ if (segment & 0xFF )
100
117
{
101
118
throw schema_exception ({ " invalid padding at the end of a base64 encoded string" });
102
119
}
103
120
104
- break ;
121
+ result.emplace_back (static_cast <uint8_t >((segment & 0xFF00 ) >> 8 ));
122
+
123
+ encoded += 2 ;
124
+ count -= 2 ;
105
125
}
126
+ }
106
127
107
- buffer |= static_cast <uint16_t >(verifyFromBase64 (*encoded++));
108
- result.push_back (static_cast <unsigned char >(buffer & 0xFF ));
128
+ // Make sure anything that's left is 0 - 2 characters of padding
129
+ if ((count > 0 && padding != encoded[0 ])
130
+ || (count > 1 && padding != encoded[1 ])
131
+ || count > 2 )
132
+ {
133
+ throw schema_exception ({ " invalid padding at the end of a base64 encoded string" });
109
134
}
110
135
111
136
return result;
112
137
}
113
138
114
- unsigned char Base64::verifyToBase64 (uint8_t i)
139
+ char Base64::verifyToBase64 (uint8_t i)
115
140
{
116
141
unsigned char result = toBase64 (i);
117
142
118
- if (result == ' = ' )
143
+ if (result == padding )
119
144
{
120
145
throw std::logic_error (" invalid 6-bit value" );
121
146
}
122
147
123
148
return result;
124
149
}
125
150
126
- std::string Base64::toBase64 (const std::vector<unsigned char >& bytes)
151
+ std::string Base64::toBase64 (const std::vector<uint8_t >& bytes)
127
152
{
128
153
std::string result;
129
154
@@ -132,38 +157,43 @@ std::string Base64::toBase64(const std::vector<unsigned char>& bytes)
132
157
return result;
133
158
}
134
159
135
- auto itr = bytes.cbegin ();
136
- const auto itrEnd = bytes.cend ();
137
- const size_t count = bytes.size ();
160
+ size_t count = bytes.size ();
161
+ const uint8_t * data = bytes.data ();
138
162
139
163
result.reserve ((count + (count % 3 )) * 4 / 3 );
140
- while (itr != itrEnd)
141
- {
142
- uint16_t buffer = static_cast <uint8_t >(*itr++) << 8 ;
143
-
144
- result.push_back (verifyToBase64 ((buffer & 0xFC00 ) >> 10 ));
145
-
146
- if (itr == itrEnd)
147
- {
148
- result.push_back (verifyToBase64 ((buffer & 0x03F0 ) >> 4 ));
149
- result.append (" ==" );
150
- break ;
151
- }
152
164
153
- buffer |= static_cast <uint8_t >(*itr++);
154
- result.push_back (verifyToBase64 ((buffer & 0x03F0 ) >> 4 ));
155
- buffer = buffer << 8 ;
165
+ // First encode all of the full unpadded segments 24 bits at a time
166
+ while (count >= 3 )
167
+ {
168
+ const uint32_t segment = (static_cast <uint32_t >(data[0 ]) << 16 )
169
+ | (static_cast <uint32_t >(data[1 ]) << 8 )
170
+ | static_cast <uint32_t >(data[2 ]);
171
+
172
+ result.append ({
173
+ verifyToBase64 ((segment & 0xFC0000 ) >> 18 ),
174
+ verifyToBase64 ((segment & 0x3F000 ) >> 12 ),
175
+ verifyToBase64 ((segment & 0xFC0 ) >> 6 ),
176
+ verifyToBase64 (segment & 0x3F )
177
+ });
156
178
157
- if (itr == itrEnd)
158
- {
159
- result.push_back (verifyToBase64 ((buffer & 0x0FC0 ) >> 6 ));
160
- result.push_back (' =' );
161
- break ;
162
- }
179
+ data += 3 ;
180
+ count -= 3 ;
181
+ }
163
182
164
- buffer |= static_cast <uint8_t >(*itr++);
165
- result.push_back (verifyToBase64 ((buffer & 0x0FC0 ) >> 6 ));
166
- result.push_back (verifyToBase64 (buffer & 0x3F ));
183
+ // Get any leftover partial segment with 1 or 2 bytes
184
+ if (count > 0 )
185
+ {
186
+ const bool pair = (count > 1 );
187
+ const uint16_t segment = (static_cast <uint16_t >(data[0 ]) << 8 )
188
+ | (pair ? static_cast <uint16_t >(data[1 ]) : 0 );
189
+ const std::array<char , 4 > remainder {
190
+ verifyToBase64 ((segment & 0xFC00 ) >> 10 ),
191
+ verifyToBase64 ((segment & 0x3F0 ) >> 4 ),
192
+ (pair ? verifyToBase64 ((segment & 0xF ) << 2 ) : padding),
193
+ padding
194
+ };
195
+
196
+ result.append (remainder.data (), remainder.size ());
167
197
}
168
198
169
199
return result;
@@ -229,7 +259,7 @@ rapidjson::Document ModifiedArgument<rapidjson::Document>::convert(const rapidjs
229
259
}
230
260
231
261
template <>
232
- std::vector<unsigned char > ModifiedArgument<std::vector<unsigned char >>::convert(const rapidjson::Value& value)
262
+ std::vector<uint8_t > ModifiedArgument<std::vector<uint8_t >>::convert(const rapidjson::Value& value)
233
263
{
234
264
if (!value.IsString ())
235
265
{
@@ -282,7 +312,7 @@ rapidjson::Document ModifiedResult<rapidjson::Document>::convert(rapidjson::Docu
282
312
}
283
313
284
314
template <>
285
- rapidjson::Document ModifiedResult<std::vector<unsigned char >>::convert(std::vector<unsigned char >&& result, ResolverParams&&)
315
+ rapidjson::Document ModifiedResult<std::vector<uint8_t >>::convert(std::vector<uint8_t >&& result, ResolverParams&&)
286
316
{
287
317
rapidjson::Document document (rapidjson::Type::kStringType );
288
318
0 commit comments