Skip to content

Commit 8f8cc3d

Browse files
authored
Merge pull request #1943 from grumpycoders/psyqo-coroutine-fixes
Fixing a few issues in psyqo's coroutines.
2 parents 74a9e52 + 4b9b900 commit 8f8cc3d

File tree

1 file changed

+66
-31
lines changed

1 file changed

+66
-31
lines changed

src/mips/psyqo/coroutine.hh

Lines changed: 66 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ SOFTWARE.
3434

3535
#include "common/psxlibc/ucontext.h"
3636
#include "common/syscalls/syscalls.h"
37+
#include "psyqo/kernel.hh"
3738

3839
namespace psyqo {
3940

@@ -43,7 +44,7 @@ namespace psyqo {
4344
* @details C++20 introduced the concept of coroutines in the language. This
4445
* type can be used to properly hold a coroutine and yield and resume it
4546
* within psyqo. An important caveat of using coroutines is that the language
46-
* insist on calling `new` and `delete` silently within the coroutine object.
47+
* insists on calling `new` and `delete` silently within the coroutine object.
4748
* This may be a problem for users who don't want to use the heap.
4849
*
4950
* @tparam T The type the coroutine returns. `void` by default.
@@ -55,16 +56,50 @@ struct Coroutine {
5556
typedef typename std::conditional<std::is_void<T>::value, Empty, T>::type SafeT;
5657

5758
Coroutine() = default;
58-
Coroutine(Coroutine &&other) = default;
59-
Coroutine &operator=(Coroutine &&other) = default;
59+
60+
Coroutine(Coroutine &&other) {
61+
if (m_handle) m_handle.destroy();
62+
m_handle = nullptr;
63+
m_handle = other.m_handle;
64+
m_value = eastl::move(other.m_value);
65+
m_suspended = other.m_suspended;
66+
m_earlyResume = other.m_earlyResume;
67+
68+
other.m_handle = nullptr;
69+
other.m_value = SafeT{};
70+
other.m_suspended = true;
71+
other.m_earlyResume = false;
72+
}
73+
74+
Coroutine &operator=(Coroutine &&other) {
75+
if (this != &other) {
76+
if (m_handle) m_handle.destroy();
77+
m_handle = nullptr;
78+
m_handle = other.m_handle;
79+
m_value = eastl::move(other.m_value);
80+
m_suspended = other.m_suspended;
81+
m_earlyResume = other.m_earlyResume;
82+
83+
other.m_handle = nullptr;
84+
other.m_value = SafeT{};
85+
other.m_suspended = true;
86+
other.m_earlyResume = false;
87+
}
88+
return *this;
89+
}
90+
6091
Coroutine(Coroutine const &) = delete;
6192
Coroutine &operator=(Coroutine const &) = delete;
93+
~Coroutine() {
94+
if (m_handle) m_handle.destroy();
95+
m_handle = nullptr;
96+
}
6297

6398
/**
6499
* @brief The awaiter type.
65100
*
66101
* @details The awaiter type is the type that is used to suspend the coroutine
67-
* after scheduling an asychronous operation. The keyword `co_await` can be used
102+
* after scheduling an asynchronous operation. The keyword `co_await` can be used
68103
* on an instance of the object to suspend the current coroutine. Creating an
69104
* instance of this object is done by calling `coroutine.awaiter()`.
70105
*/
@@ -158,16 +193,15 @@ struct Coroutine {
158193
std::suspend_always final_suspend() noexcept { return {}; }
159194
void unhandled_exception() {}
160195
void return_void() {
161-
auto awaitingCoroutine = m_awaitingCoroutine;
162-
if (awaitingCoroutine) {
163-
// This doesn't feel right, but I don't know how to do it otherwise,
164-
// since the coroutine is a template and I can't forward the type.
165-
__builtin_coro_resume(awaitingCoroutine);
196+
if (m_awaitingCoroutine) {
197+
Kernel::queueCallback([h = m_awaitingCoroutine]() { h.resume(); });
198+
m_awaitingCoroutine = nullptr;
166199
}
167200
}
168201
[[no_unique_address]] Empty m_value;
169-
void *m_awaitingCoroutine = nullptr;
202+
std::coroutine_handle<> m_awaitingCoroutine;
170203
};
204+
171205
struct PromiseValue {
172206
Coroutine<T> get_return_object() {
173207
return Coroutine{eastl::move(std::coroutine_handle<Promise>::from_promise(*this))};
@@ -177,21 +211,21 @@ struct Coroutine {
177211
void unhandled_exception() {}
178212
void return_value(T &&value) {
179213
m_value = eastl::move(value);
180-
auto awaitingCoroutine = m_awaitingCoroutine;
181-
if (awaitingCoroutine) {
182-
// This doesn't feel right, but I don't know how to do it otherwise,
183-
// since the coroutine is a template and I can't forward the type.
184-
__builtin_coro_resume(awaitingCoroutine);
214+
if (m_awaitingCoroutine) {
215+
Kernel::queueCallback([h = m_awaitingCoroutine]() { h.resume(); });
216+
m_awaitingCoroutine = nullptr;
185217
}
186218
}
187219
T m_value;
188-
void *m_awaitingCoroutine = nullptr;
220+
std::coroutine_handle<> m_awaitingCoroutine;
189221
};
222+
190223
typedef typename std::conditional<std::is_void<T>::value, PromiseVoid, PromiseValue>::type Promise;
224+
191225
Coroutine(std::coroutine_handle<Promise> &&handle) : m_handle(eastl::move(handle)) {}
226+
192227
std::coroutine_handle<Promise> m_handle;
193228
[[no_unique_address]] SafeT m_value;
194-
void *m_awaitingCoroutine = nullptr;
195229
bool m_suspended = true;
196230
bool m_earlyResume = false;
197231

@@ -201,31 +235,32 @@ struct Coroutine {
201235
constexpr bool await_ready() { return m_handle.done(); }
202236
template <typename U>
203237
constexpr void await_suspend(std::coroutine_handle<U> h) {
204-
auto &promise = m_handle.promise();
205-
promise.m_awaitingCoroutine = h.address();
238+
m_handle.promise().m_awaitingCoroutine = h;
206239
resume();
207240
}
208-
constexpr SafeT await_resume() {
209-
SafeT value = eastl::move(m_handle.promise().m_value);
210-
m_handle.destroy();
211-
return value;
241+
constexpr T await_resume() {
242+
if constexpr (std::is_void<T>::value) {
243+
return;
244+
} else {
245+
return eastl::move(m_handle.promise().m_value);
246+
}
212247
}
213248
};
214249

215250
class StackfulBase {
216251
protected:
217-
void initializeInternal(eastl::function<void()>&& func, void* ss_sp, unsigned ss_size);
252+
void initializeInternal(eastl::function<void()> &&func, void *ss_sp, unsigned ss_size);
218253
void resume();
219254
void yield();
220255
[[nodiscard]] bool isAlive() const { return m_isAlive; }
221256

222257
StackfulBase() = default;
223-
StackfulBase(const StackfulBase&) = delete;
224-
StackfulBase& operator=(const StackfulBase&) = delete;
258+
StackfulBase(const StackfulBase &) = delete;
259+
StackfulBase &operator=(const StackfulBase &) = delete;
225260

226261
private:
227-
static void trampoline(void* arg) {
228-
StackfulBase* self = static_cast<StackfulBase*>(arg);
262+
static void trampoline(void *arg) {
263+
StackfulBase *self = static_cast<StackfulBase *>(arg);
229264
self->trampoline();
230265
}
231266
void trampoline();
@@ -254,16 +289,16 @@ class Stackful : public StackfulBase {
254289
static constexpr unsigned c_stackSize = (StackSize + 7) & ~7;
255290

256291
Stackful() = default;
257-
Stackful(const Stackful&) = delete;
258-
Stackful& operator=(const Stackful&) = delete;
292+
Stackful(const Stackful &) = delete;
293+
Stackful &operator=(const Stackful &) = delete;
259294

260295
/**
261296
* @brief Initialize the coroutine with a function and an argument.
262297
*
263298
* @param func Function to be executed by the coroutine.
264299
* @param arg Argument to be passed to the function.
265300
*/
266-
void initialize(eastl::function<void()>&& func) {
301+
void initialize(eastl::function<void()> &&func) {
267302
initializeInternal(eastl::move(func), m_stack.data, c_stackSize);
268303
}
269304

0 commit comments

Comments
 (0)