summaryrefslogtreecommitdiff
path: root/test/socket_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'test/socket_test.cc')
-rw-r--r--test/socket_test.cc168
1 files changed, 168 insertions, 0 deletions
diff --git a/test/socket_test.cc b/test/socket_test.cc
new file mode 100644
index 0000000..f50d306
--- /dev/null
+++ b/test/socket_test.cc
@@ -0,0 +1,168 @@
+#include "common.hh"
+
+#include "buffer.hh"
+#include "io.hh"
+#include "looper.hh"
+#include "socket_test.hh"
+
+#include <sys/socket.h>
+#include <utility>
+
+namespace {
+
+constexpr size_t kBaseSize = 512 * 1024;
+constexpr size_t kMaxSize = 10 * 1024 * 1024;
+
+class ClientImpl : public SocketTest::Client {
+public:
+ ClientImpl(std::shared_ptr<Looper> looper, unique_fd&& fd)
+ : looper_(std::move(looper)), fd_(std::move(fd)),
+ in_(Buffer::growing(kBaseSize, kMaxSize)),
+ out_(Buffer::growing(kBaseSize, kMaxSize)) {
+ looper_->add(fd_.get(), Looper::EVENT_READ,
+ std::bind(&ClientImpl::event, this, std::placeholders::_1));
+ }
+
+ ~ClientImpl() override {
+ if (!closed_)
+ looper_->remove(fd_.get());
+ }
+
+ bool closed() const override {
+ return closed_;
+ }
+
+ void write(std::function<void(Buffer*)> writer) override {
+ ASSERT_FALSE(closed_);
+ bool empty = out_->empty();
+ writer(out_.get());
+ if (empty) {
+ size_t bytes;
+ ASSERT_TRUE(io::drain(out_.get(), fd_.get(), &bytes));
+ if (out_->empty())
+ return;
+ }
+ update();
+ }
+
+ void write(std::string_view data) override {
+ ASSERT_FALSE(closed_);
+ if (data.empty())
+ return;
+ bool empty = out_->empty();
+ auto size = Buffer::write(out_.get(), data.data(), data.size());
+ ASSERT_EQ(data.size(), size);
+ if (empty) {
+ size_t bytes;
+ ASSERT_TRUE(io::drain(out_.get(), fd_.get(), &bytes));
+ if (out_->empty())
+ return;
+ }
+ update();
+ }
+
+ void read(std::function<void(RoBuffer*)> reader) override {
+ bool full = in_->full();
+
+ reader(in_.get());
+
+ if (full && !in_->full())
+ update();
+ }
+
+ std::string_view received() const override {
+ size_t avail;
+ auto* ptr = in_->rbuf(1, avail);
+ return std::string_view(ptr, avail);
+ }
+
+ void forget(size_t bytes) override {
+ bool full = in_->full();
+ in_->rcommit(bytes);
+
+ if (full && bytes > 0)
+ update();
+ }
+
+ void wait(Logger* logger) override {
+ ASSERT_FALSE(waiting_);
+ ASSERT_FALSE(closed_);
+ waiting_ = true;
+ looper_->run(logger);
+ waiting_ = false;
+ }
+
+private:
+ void event(uint8_t ev) {
+ if (ev & Looper::EVENT_ERROR) {
+ FAIL();
+ }
+ bool need_update = false;
+ if (ev & Looper::EVENT_READ) {
+ switch (io::fill(fd_.get(), in_.get())) {
+ case io::Return::OK:
+ break;
+ case io::Return::ERR:
+ FAIL();
+ case io::Return::CLOSED:
+ ASSERT_TRUE(out_->empty());
+ closed_ = true;
+ looper_->remove(fd_.get());
+ fd_.reset();
+ if (waiting_)
+ looper_->quit();
+ return;
+ }
+ if (in_->full())
+ need_update = true;
+ }
+ if (ev & Looper::EVENT_WRITE) {
+ ASSERT_TRUE(io::drain(out_.get(), fd_.get()));
+ if (out_->empty())
+ need_update = true;
+ }
+
+ if (waiting_)
+ looper_->quit();
+
+ if (need_update)
+ update();
+ }
+
+ void update() {
+ uint8_t events = 0;
+ if (!in_->full())
+ events |= Looper::EVENT_READ;
+ if (!out_->empty())
+ events |= Looper::EVENT_WRITE;
+ looper_->update(fd_.get(), events);
+ }
+
+ std::shared_ptr<Looper> looper_;
+ unique_fd fd_;
+ std::unique_ptr<Buffer> in_;
+ std::unique_ptr<Buffer> out_;
+ bool closed_{false};
+ bool waiting_{false};
+};
+
+} // namespace
+
+SocketTest::SocketTest()
+ : looper_(Looper::create()) {}
+
+std::pair<unique_fd, unique_fd> SocketTest::create_pair() {
+ int ret[2];
+ if (socketpair(AF_UNIX, SOCK_STREAM, 0, ret) == 0) {
+ if (io::make_nonblocking(ret[0]) && io::make_nonblocking(ret[1])) {
+ return std::make_pair(unique_fd(ret[0]), unique_fd(ret[1]));
+ }
+ io::close(ret[0]);
+ io::close(ret[1]);
+ }
+ return std::make_pair(nullptr, nullptr);
+}
+
+std::unique_ptr<SocketTest::Client> SocketTest::create_client(unique_fd&& fd) {
+ return std::make_unique<ClientImpl>(looper_, std::move(fd));
+}