Skip to content

Commit 090abd8

Browse files
committed
Eliminate a bunch of branching, pass a temporary context instead
1 parent 585ec39 commit 090abd8

File tree

2 files changed

+79
-67
lines changed

2 files changed

+79
-67
lines changed

zasm/src/zasm/src/encoder/encoder.context.hpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,17 @@ namespace zasm
3333
Section::Attribs attribs{};
3434
};
3535

36+
enum class EncoderFlags : std::uint32_t
37+
{
38+
none = 0,
39+
temporary = 1U << 0,
40+
};
41+
ZASM_ENABLE_ENUM_OPERATORS(EncoderFlags);
42+
3643
struct EncoderContext
3744
{
3845
public:
46+
EncoderFlags flags{};
3947
detail::ProgramState* program{};
4048
bool needsExtraPass{};
4149
std::size_t nodeIndex{};
@@ -45,7 +53,6 @@ namespace zasm
4553
std::int64_t va{};
4654
std::int32_t offset{};
4755
std::int32_t instrSize{};
48-
4956

5057
struct LabelLink
5158
{
@@ -105,6 +112,11 @@ namespace zasm
105112
{
106113
assert(id != Label::Id::Invalid);
107114

115+
if ((flags & EncoderFlags::temporary) != EncoderFlags::none)
116+
{
117+
return std::nullopt;
118+
}
119+
108120
const auto& entry = getOrCreateLabelLink(id);
109121
if (entry.boundVA == -1)
110122
{
@@ -113,5 +125,16 @@ namespace zasm
113125

114126
return entry.boundVA;
115127
}
128+
129+
std::uint32_t getNodeSize(std::size_t nodeIndex) const
130+
{
131+
if ((flags & EncoderFlags::temporary) != EncoderFlags::none)
132+
{
133+
return 0;
134+
}
135+
136+
assert(nodeIndex < nodes.size());
137+
return nodes[nodeIndex].length;
138+
}
116139
};
117140
} // namespace zasm

zasm/src/zasm/src/encoder/encoder.cpp

Lines changed: 55 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,17 @@ namespace zasm
1818

1919
struct EncoderState
2020
{
21-
EncoderContext* ctx{};
21+
EncoderContext& ctx;
2222
ZydisEncoderRequest req{};
2323
std::size_t operandIndex{};
2424
RelocationType relocKind{};
2525
RelocationData relocData{};
2626
Label::Id relocLabel{ Label::Id::Invalid };
27+
28+
EncoderState(EncoderContext& ctx_) noexcept
29+
: ctx(ctx_)
30+
{
31+
}
2732
};
2833

2934
// NOTE: This value has to be at least larger than 0xFFFF to be used with imm32/rel32 displacement.
@@ -97,8 +102,14 @@ namespace zasm
97102
return encoderVariantData[mnemonic]; // NOLINT
98103
}
99104

100-
static bool isLabelExternal(detail::ProgramState* state, Label::Id labelId)
105+
static bool isLabelExternal(EncoderContext& ctx, Label::Id labelId)
101106
{
107+
if ((ctx.flags & EncoderFlags::temporary) != EncoderFlags::none)
108+
{
109+
return false;
110+
}
111+
112+
const auto state = ctx.program;
102113
const auto idx = static_cast<std::size_t>(labelId);
103114
if (idx >= state->labels.size())
104115
{
@@ -137,7 +148,8 @@ namespace zasm
137148
return res;
138149
}
139150

140-
static Error buildOperand_(ZydisEncoderOperand& dst, [[maybe_unused]] EncoderState& state, const Reg& src) noexcept
151+
static Error buildOperand_(
152+
EncoderContext&, ZydisEncoderOperand& dst, [[maybe_unused]] EncoderState& state, const Reg& src) noexcept
141153
{
142154
dst.type = ZydisOperandType::ZYDIS_OPERAND_TYPE_REGISTER;
143155
dst.reg.value = static_cast<ZydisRegister>(src.getId());
@@ -147,8 +159,6 @@ namespace zasm
147159

148160
static int64_t getTemporaryRel(EncoderState& state, const EncodeVariantsInfo& encodeInfo) noexcept
149161
{
150-
auto* ctx = state.ctx;
151-
152162
std::int64_t tempRel = 0;
153163

154164
if (encodeInfo.canEncodeRel32())
@@ -163,22 +173,19 @@ namespace zasm
163173
return tempRel;
164174
}
165175

166-
static Error buildOperand_(ZydisEncoderOperand& dst, EncoderState& state, const Label& src)
176+
static Error buildOperand_(EncoderContext& ctx, ZydisEncoderOperand& dst, EncoderState& state, const Label& src)
167177
{
168-
auto* ctx = state.ctx;
169-
170-
const auto curVA = ctx != nullptr ? ctx->va : 0;
171178
const auto& encodeInfo = getEncodeVariantInfo(state.req.mnemonic);
172179

173180
// Initially a temporary placeholder.
174-
std::int64_t immValue = curVA + getTemporaryRel(state, encodeInfo);
181+
std::int64_t immValue = ctx.va + getTemporaryRel(state, encodeInfo);
175182

176-
if (ctx != nullptr && !isLabelExternal(ctx->program, src.getId()))
183+
if (!isLabelExternal(ctx, src.getId()))
177184
{
178-
auto labelVA = ctx->getLabelAddress(src.getId());
185+
auto labelVA = ctx.getLabelAddress(src.getId());
179186
if (!labelVA.has_value())
180187
{
181-
ctx->needsExtraPass = true;
188+
ctx.needsExtraPass = true;
182189
}
183190
else
184191
{
@@ -188,8 +195,7 @@ namespace zasm
188195

189196
if (encodeInfo.isControlFlow)
190197
{
191-
const auto instrSize = ctx != nullptr ? ctx->instrSize : 0;
192-
const auto rel = immValue - (curVA + instrSize);
198+
const auto rel = immValue - (ctx.va + ctx.instrSize);
193199

194200
if (!encodeInfo.canEncodeRel32())
195201
{
@@ -228,10 +234,8 @@ namespace zasm
228234
return ErrorCode::None;
229235
}
230236

231-
static Error buildOperand_(ZydisEncoderOperand& dst, EncoderState& state, const Imm& src)
237+
static Error buildOperand_(EncoderContext&, ZydisEncoderOperand& dst, EncoderState& state, const Imm& src)
232238
{
233-
auto* ctx = state.ctx;
234-
235239
auto desiredBranchType = ZydisBranchType::ZYDIS_BRANCH_TYPE_NONE;
236240

237241
dst.type = ZydisOperandType::ZYDIS_OPERAND_TYPE_IMMEDIATE;
@@ -240,10 +244,8 @@ namespace zasm
240244
return ErrorCode::None;
241245
}
242246

243-
static Error buildOperand_(ZydisEncoderOperand& dst, EncoderState& state, const Mem& src)
247+
static Error buildOperand_(EncoderContext& ctx, ZydisEncoderOperand& dst, EncoderState& state, const Mem& src)
244248
{
245-
auto* ctx = state.ctx;
246-
247249
dst.type = ZydisOperandType::ZYDIS_OPERAND_TYPE_MEMORY;
248250
dst.mem.base = static_cast<ZydisRegister>(src.getBase().getId());
249251
dst.mem.index = static_cast<ZydisRegister>(src.getIndex().getId());
@@ -258,38 +260,29 @@ namespace zasm
258260

259261
std::int64_t displacement = src.getDisplacement();
260262

261-
const auto address = ctx != nullptr ? ctx->va : 0;
262-
263263
bool usingLabel = false;
264264
bool externalLabel = false;
265265
bool isDisplacementValid = true;
266266

267267
if (const auto labelId = src.getLabelId(); labelId != Label::Id::Invalid)
268268
{
269-
if (ctx != nullptr)
270-
{
271-
externalLabel = isLabelExternal(ctx->program, labelId);
269+
externalLabel = isLabelExternal(ctx, labelId);
272270

273-
auto labelVA = ctx->getLabelAddress(labelId);
274-
if (labelVA.has_value())
275-
{
276-
displacement += *labelVA;
277-
}
278-
else
279-
{
280-
displacement += address + kTemporaryRel32Value;
281-
isDisplacementValid = false;
282-
if (!externalLabel)
283-
{
284-
ctx->needsExtraPass = true;
285-
}
286-
}
271+
auto labelVA = ctx.getLabelAddress(labelId);
272+
if (labelVA.has_value())
273+
{
274+
displacement += *labelVA;
287275
}
288276
else
289277
{
290-
displacement = kTemporaryRel32Value;
278+
displacement += ctx.va + kTemporaryRel32Value;
291279
isDisplacementValid = false;
280+
if (!externalLabel)
281+
{
282+
ctx.needsExtraPass = true;
283+
}
292284
}
285+
293286
usingLabel = true;
294287
}
295288

@@ -307,8 +300,7 @@ namespace zasm
307300
{
308301
if (isDisplacementValid)
309302
{
310-
const auto instrSize = ctx != nullptr ? ctx->instrSize : 0;
311-
const auto rel = displacement - (address + instrSize);
303+
const auto rel = displacement - (ctx.va + ctx.instrSize);
312304
if (std::abs(rel) > std::numeric_limits<std::int32_t>::max())
313305
{
314306
char msg[128];
@@ -343,15 +335,16 @@ namespace zasm
343335
}
344336

345337
static Error buildOperand_(
346-
ZydisEncoderOperand& dst, [[maybe_unused]] EncoderState& state, [[maybe_unused]] const Operand::None& src) noexcept
338+
EncoderContext&, ZydisEncoderOperand& dst, [[maybe_unused]] EncoderState& state,
339+
[[maybe_unused]] const Operand::None& src) noexcept
347340
{
348341
dst.type = ZydisOperandType::ZYDIS_OPERAND_TYPE_UNUSED;
349342
return ErrorCode::None;
350343
}
351344

352-
static Error buildOperand(ZydisEncoderOperand& dst, EncoderState& state, const Operand& src)
345+
static Error buildOperand(EncoderContext& ctx, ZydisEncoderOperand& dst, EncoderState& state, const Operand& src)
353346
{
354-
return src.visit([&dst, &state](auto&& src2) { return buildOperand_(dst, state, src2); });
347+
return src.visit([&](auto&& src2) { return buildOperand_(ctx, dst, state, src2); });
355348
}
356349

357350
static void fixupIs4Operands(ZydisEncoderRequest& req) noexcept
@@ -432,7 +425,7 @@ namespace zasm
432425
}
433426

434427
static Error encode_(
435-
EncoderResult& res, EncoderContext* ctx, MachineMode mode, Instruction::Attribs attribs, Instruction::Mnemonic mnemonic,
428+
EncoderResult& res, EncoderContext& ctx, MachineMode mode, Instruction::Attribs attribs, Instruction::Mnemonic mnemonic,
436429
size_t numOps, const Operand* operands)
437430
{
438431
if (!validateMachineMode(mode))
@@ -442,8 +435,7 @@ namespace zasm
442435

443436
res.buffer.length = 0;
444437

445-
EncoderState state{};
446-
state.ctx = ctx;
438+
EncoderState state{ ctx };
447439

448440
ZydisEncoderRequest& req = state.req;
449441
if (mode == MachineMode::AMD64)
@@ -481,7 +473,7 @@ namespace zasm
481473
{
482474
auto& dstOp = req.operands[state.operandIndex]; // NOLINT
483475
const auto& srcOp = operands[state.operandIndex]; // NOLINT
484-
if (auto opStatus = buildOperand(dstOp, state, srcOp); opStatus != ErrorCode::None)
476+
if (auto opStatus = buildOperand(ctx, dstOp, state, srcOp); opStatus != ErrorCode::None)
485477
{
486478
return opStatus;
487479
}
@@ -491,8 +483,7 @@ namespace zasm
491483
fixupIs4Operands(req);
492484

493485
std::size_t bufLen = res.buffer.data.size();
494-
const auto curAddress = ctx != nullptr ? ctx->va : 0;
495-
switch (auto status = ZydisEncoderEncodeInstructionAbsolute(&req, res.buffer.data.data(), &bufLen, curAddress); status)
486+
switch (auto status = ZydisEncoderEncodeInstructionAbsolute(&req, res.buffer.data.data(), &bufLen, ctx.va); status)
496487
{
497488
case ZYAN_STATUS_SUCCESS:
498489
break;
@@ -509,26 +500,14 @@ namespace zasm
509500
return ErrorCode::None;
510501
}
511502

512-
Expected<EncoderResult, Error> encode(
513-
MachineMode mode, Instruction::Attribs attribs, Instruction::Mnemonic mnemonic, std::size_t numOps,
514-
const Operand* operands)
515-
{
516-
EncoderResult res;
517-
if (auto err = encode_(res, nullptr, mode, attribs, mnemonic, numOps, operands); err != ErrorCode::None)
518-
{
519-
return makeUnexpected(err);
520-
}
521-
return res;
522-
}
523-
524503
static Expected<EncoderResult, Error> encodeWithContext(
525504
EncoderContext& ctx, MachineMode mode, Instruction::Attribs prefixes, Instruction::Mnemonic mnemonic,
526505
std::size_t numOps, const Operand* operands)
527506
{
528507
EncoderResult res;
529-
ctx.instrSize = ctx.nodes[ctx.nodeIndex].length;
508+
ctx.instrSize = ctx.getNodeSize(ctx.nodeIndex);
530509

531-
if (const auto encodeError = encode_(res, &ctx, mode, prefixes, mnemonic, numOps, operands);
510+
if (const auto encodeError = encode_(res, ctx, mode, prefixes, mnemonic, numOps, operands);
532511
encodeError != ErrorCode::None)
533512
{
534513
return makeUnexpected(encodeError);
@@ -537,6 +516,16 @@ namespace zasm
537516
return res;
538517
}
539518

519+
Expected<EncoderResult, Error> encode(
520+
MachineMode mode, Instruction::Attribs attribs, Instruction::Mnemonic mnemonic, std::size_t numOps,
521+
const Operand* operands)
522+
{
523+
EncoderContext tempCtx{};
524+
tempCtx.flags |= EncoderFlags::temporary;
525+
526+
return encodeWithContext(tempCtx, mode, attribs, mnemonic, numOps, operands);
527+
}
528+
540529
Expected<EncoderResult, Error> encode(EncoderContext& ctx, MachineMode mode, const Instruction& instr)
541530
{
542531
const auto& ops = instr.getOperands();

0 commit comments

Comments
 (0)