Skip to content

Add write, read, drained #68

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ coverage:
status:
patch:
default:
target: 90%
target: 80%
threshold: 0%
if_ci_failed: error
project:
Expand Down
9 changes: 1 addition & 8 deletions include/msd/blocking_iterator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

#include <cstddef>
#include <iterator>
#include <mutex>

namespace msd {

Expand Down Expand Up @@ -76,13 +75,7 @@ class blocking_iterator {
* @return true if the channel is not closed or not empty (continue iterating).
* @return false if the channel is closed and empty (stop iterating).
*/
bool operator!=(blocking_iterator<Channel>) const
{
std::unique_lock<std::mutex> lock{chan_.mtx_};
chan_.waitBeforeRead(lock);

return !(chan_.closed() && chan_.empty());
}
bool operator!=(blocking_iterator<Channel>) const { return !chan_.drained(); }

private:
Channel& chan_;
Expand Down
138 changes: 98 additions & 40 deletions include/msd/channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#ifndef MSD_CHANNEL_HPP_
#define MSD_CHANNEL_HPP_

#include <atomic>
#include <condition_variable>
#include <cstdlib>
#include <mutex>
Expand Down Expand Up @@ -82,34 +81,102 @@ class channel {
/**
* @brief Pops an element from the channel.
*
* @tparam Type The type of the elements
* @tparam Type The type of the elements.
*/
template <typename Type>
friend channel<Type>& operator>>(channel<Type>&, Type&);

/**
* @brief Pushes an element into the channel.
*
* @tparam Type The type of the elements.
*
* @param value The element to be pushed into the channel.
*
* @return true If an element was successfully pushed into the channel.
* @return false If the channel is closed.
*/
template <typename Type>
bool write(Type&& value)
{
{
std::unique_lock<std::mutex> lock{mtx_};
waitBeforeWrite(lock);

if (is_closed_) {
return false;
}

queue_.push(std::forward<Type>(value));
++size_;
}

cnd_.notify_one();

return true;
}

/**
* @brief Pops an element from the channel.
*
* @param out Reference to the variable where the popped element will be stored.
*
* @return true If an element was successfully read from the channel.
* @return false If the channel is closed and empty.
*/
bool read(T& out)
{
{
std::unique_lock<std::mutex> lock{mtx_};
waitBeforeRead(lock);

if (is_closed_ && size_ == 0) {
return false;
}

if (!(size_ == 0)) {
out = std::move(queue_.front());
queue_.pop();
--size_;
}
}

cnd_.notify_one();

return true;
}

/**
* @brief Returns the current size of the channel.
*
* @return The number of elements in the channel.
*/
NODISCARD size_type constexpr size() const noexcept { return size_; }
NODISCARD size_type size() const noexcept
{
std::unique_lock<std::mutex> lock{mtx_};
return size_;
}

/**
* @brief Checks if the channel is empty.
*
* @return true If the channel contains no elements.
* @return false Otherwise.
*/
NODISCARD bool constexpr empty() const noexcept { return size_ == 0; }
NODISCARD bool empty() const noexcept
{
std::unique_lock<std::mutex> lock{mtx_};
return size_ == 0;
}

/**
* @brief Closes the channel.
* @brief Closes the channel, no longer accepting new elements.
*/
void close() noexcept
{
{
std::unique_lock<std::mutex> lock{mtx_};
is_closed_.store(true, std::memory_order_seq_cst);
is_closed_ = true;
}
cnd_.notify_all();
}
Expand All @@ -120,7 +187,23 @@ class channel {
* @return true If no more elements can be added to the channel.
* @return false Otherwise.
*/
NODISCARD bool closed() const noexcept { return is_closed_.load(std::memory_order_seq_cst); }
NODISCARD bool closed() const noexcept
{
std::unique_lock<std::mutex> lock{mtx_};
return is_closed_;
}

/**
* @brief Checks if the channel has been closed and is empty.
*
* @return true If nothing can be read anymore from the channel.
* @return false Otherwise.
*/
NODISCARD bool drained() noexcept
{
std::unique_lock<std::mutex> lock{mtx_};
return is_closed_ && size_ == 0;
}

/**
* @brief Returns an iterator to the beginning of the channel.
Expand All @@ -146,16 +229,16 @@ class channel {
virtual ~channel() = default;

private:
const size_type cap_{0};
std::queue<T> queue_;
std::atomic<std::size_t> size_{0};
std::mutex mtx_;
std::size_t size_{0};
const size_type cap_{0};
mutable std::mutex mtx_;
std::condition_variable cnd_;
std::atomic<bool> is_closed_{false};
bool is_closed_{false};

void waitBeforeRead(std::unique_lock<std::mutex>& lock)
{
cnd_.wait(lock, [this]() { return !empty() || closed(); });
cnd_.wait(lock, [this]() { return !(size_ == 0) || is_closed_; });
};

void waitBeforeWrite(std::unique_lock<std::mutex>& lock)
Expand All @@ -171,42 +254,17 @@ class channel {
template <typename T>
channel<typename std::decay<T>::type>& operator<<(channel<typename std::decay<T>::type>& chan, T&& value)
{
{
std::unique_lock<std::mutex> lock{chan.mtx_};
chan.waitBeforeWrite(lock);

if (chan.closed()) {
throw closed_channel{"cannot write on closed channel"};
}

chan.queue_.push(std::forward<T>(value));
++chan.size_;
if (!chan.write(std::forward<T>(value))) {
throw closed_channel{"cannot write on closed channel"};
}

chan.cnd_.notify_one();

return chan;
}

template <typename T>
channel<T>& operator>>(channel<T>& chan, T& out)
{
{
std::unique_lock<std::mutex> lock{chan.mtx_};
chan.waitBeforeRead(lock);

if (chan.closed() && chan.empty()) {
return chan;
}

if (!chan.empty()) {
out = std::move(chan.queue_.front());
chan.queue_.pop();
--chan.size_;
}
}

chan.cnd_.notify_one();
chan.read(out);

return chan;
}
Expand Down
83 changes: 83 additions & 0 deletions tests/channel_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,30 @@ TEST(ChannelTest, PushAndFetch)
EXPECT_EQ(4, out);
}

TEST(ChannelTest, WriteAndRead)
{
msd::channel<int> channel;

int in = 1;
EXPECT_TRUE(channel.write(in));

const int cin = 3;
EXPECT_TRUE(channel.write(cin));

channel.close();
EXPECT_FALSE(channel.write(2));

int out = 0;

EXPECT_TRUE(channel.read(out));
EXPECT_EQ(1, out);

EXPECT_TRUE(channel.read(out));
EXPECT_EQ(3, out);

EXPECT_FALSE(channel.read(out));
}

TEST(ChannelTest, PushAndFetchWithBufferedChannel)
{
msd::channel<int> channel{2};
Expand Down Expand Up @@ -167,6 +191,23 @@ TEST(ChannelTest, close)
EXPECT_THROW(channel << std::move(in), msd::closed_channel);
}

TEST(ChannelTest, drained)
{
msd::channel<int> channel;
EXPECT_FALSE(channel.drained());

int in = 1;
channel << in;

channel.close();
EXPECT_FALSE(channel.drained());

int out = 0;
channel >> out;
EXPECT_EQ(1, out);
EXPECT_TRUE(channel.drained());
}

TEST(ChannelTest, Iterator)
{
msd::channel<int> channel;
Expand Down Expand Up @@ -238,3 +279,45 @@ TEST(ChannelTest, Multithreading)

EXPECT_EQ(expected, sum_numbers);
}

TEST(ChannelTest, ReadWriteClose)
{
const int numbers = 10000;
const std::int64_t expected_sum = 50005000;
constexpr std::size_t kThreadsToReadFrom = 20;

msd::channel<int> channel{kThreadsToReadFrom};
std::atomic<std::int64_t> sum{0};
std::atomic<std::int64_t> nums{0};

std::thread writer([&channel]() {
for (int i = 1; i <= numbers; ++i) {
channel << i;
}
channel.close();
});

std::vector<std::thread> readers;
for (std::size_t i = 0; i < kThreadsToReadFrom; ++i) {
readers.emplace_back([&channel, &sum, &nums]() {
while (true) {
int value = 0;

if (!channel.read(value)) {
return;
}

sum += value;
++nums;
}
});
}

writer.join();
for (auto& reader : readers) {
reader.join();
}

EXPECT_EQ(sum, expected_sum);
EXPECT_EQ(nums, numbers);
}