#include "common.hh"
#include "buffer.hh"
#include "io.hh"
#include "looper.hh"
#include "socket_test.hh"

#include <gtest/gtest.h>
#include <sys/socket.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <string_view>
#include <utility>

namespace {

// Sizes passed to Buffer::growing() for both the receive and send buffer of
// every client.
constexpr size_t kBaseSize = 512 * 1024;
constexpr size_t kMaxSize = 10 * 1024 * 1024;

// SocketTest::Client backed by a socket fd registered with a shared Looper.
//
// Received bytes are collected in in_; bytes to send are queued in out_.
// update() keeps the looper interest in sync with the buffers: EVENT_READ is
// requested while in_ is not full, EVENT_WRITE while out_ holds unsent data.
class ClientImpl final : public SocketTest::Client {
 public:
  ClientImpl(std::shared_ptr<Looper> looper, unique_fd&& fd)
      : looper_(std::move(looper)),
        fd_(std::move(fd)),
        in_(Buffer::growing(kBaseSize, kMaxSize)),
        out_(Buffer::growing(kBaseSize, kMaxSize)) {
    looper_->add(fd_.get(), Looper::EVENT_READ,
                 [this](uint8_t ev) { event(ev); });
  }

  ~ClientImpl() override {
    // Once closed_ is set, shutdown() has already removed the fd.
    if (!closed_) looper_->remove(fd_.get());
  }

  bool closed() const override { return closed_; }

  // Lets |writer| append to the send buffer, then tries to send it.
  void write(std::function<void(Buffer*)> writer) override {
    ASSERT_FALSE(closed_);
    bool const was_empty = out_->empty();
    writer(out_.get());
    flush_out(was_empty);
  }

  // Appends |data| to the send buffer, then tries to send it.
  void write(std::string_view data) override {
    ASSERT_FALSE(closed_);
    if (data.empty()) return;
    bool const was_empty = out_->empty();
    auto size = Buffer::write(out_.get(), data.data(), data.size());
    ASSERT_EQ(data.size(), size);
    flush_out(was_empty);
  }

  // Lets |reader| consume from the receive buffer. If that frees room in a
  // previously full buffer, reading from the socket is re-enabled.
  void read(std::function<void(Buffer*)> reader) override {
    bool const was_full = in_->full();
    reader(in_.get());
    if (was_full && !in_->full()) update();
  }

  // Returns the received data that has not been forgotten yet.
  // NOTE(review): the view presumably points into in_'s storage; treat it as
  // invalidated by any later read()/forget()/wait().
  std::string_view received() const override {
    size_t avail{0};
    auto* ptr = in_->rbuf(1, avail);
    return std::string_view(ptr, avail);
  }

  // Drops the first |bytes| bytes of received data.
  void forget(size_t bytes) override {
    bool const was_full = in_->full();
    in_->rcommit(bytes);
    if (was_full && bytes > 0) update();
  }

  // Runs the looper until an event handler calls quit(); event() does so
  // after every handled event and on close/error while waiting_ is set.
  void wait(Logger* logger) override {
    ASSERT_FALSE(waiting_);
    ASSERT_FALSE(closed_);
    waiting_ = true;
    looper_->run(logger);
    waiting_ = false;
  }

 private:
  // Looper callback for fd_.
  //
  // gtest fatal assertions only return from this handler. Every failing path
  // therefore calls shutdown() first. That unregisters the fd and quits a
  // pending wait(), so the test reports the failure instead of hanging.
  void event(uint8_t ev) {
    if (ev & Looper::EVENT_ERROR) {
      shutdown();
      FAIL();
    }
    bool need_update = false;
    if (ev & Looper::EVENT_READ) {
      switch (io::fill(fd_.get(), in_.get())) {
        case io::Return::OK:
          break;
        case io::Return::ERR:
          shutdown();
          FAIL();
        case io::Return::CLOSED:
          shutdown();
          // The peer closed while we still had unsent data queued.
          ASSERT_TRUE(out_->empty());
          return;
      }
      // A full receive buffer must stop EVENT_READ interest.
      if (in_->full()) need_update = true;
    }
    if (ev & Looper::EVENT_WRITE) {
      bool const drained = io::drain(out_.get(), fd_.get());
      if (!drained) shutdown();
      ASSERT_TRUE(drained);
      // Nothing left to send: drop EVENT_WRITE interest.
      if (out_->empty()) need_update = true;
    }
    if (waiting_) looper_->quit();
    if (need_update) update();
  }

  // Called after data was appended to out_. If out_ was empty beforehand,
  // EVENT_WRITE is not requested (see update()), so try to send right away.
  // Only register for writability if something is left over.
  void flush_out(bool was_empty) {
    if (was_empty) {
      size_t bytes{0};
      ASSERT_TRUE(io::drain(out_.get(), fd_.get(), &bytes));
      if (out_->empty()) return;
    }
    update();
  }

  // Marks the client closed, unregisters and closes fd_, and stops a running
  // wait(). Callers must not touch fd_ or call update() afterwards.
  void shutdown() {
    closed_ = true;
    looper_->remove(fd_.get());
    fd_.reset();
    if (waiting_) looper_->quit();
  }

  // Recomputes the looper interest for fd_ from the buffer state.
  void update() {
    uint8_t events = 0;
    if (!in_->full()) events |= Looper::EVENT_READ;
    if (!out_->empty()) events |= Looper::EVENT_WRITE;
    looper_->update(fd_.get(), events);
  }

  std::shared_ptr<Looper> looper_;
  unique_fd fd_;
  std::unique_ptr<Buffer> in_;
  std::unique_ptr<Buffer> out_;
  bool closed_{false};
  bool waiting_{false};
};

}  // namespace

SocketTest::SocketTest() : looper_(Looper::create()) {}

// Creates a connected, non-blocking AF_UNIX stream socket pair. On failure,
// any created fds are closed and a pair built from nullptr is returned.
std::pair<unique_fd, unique_fd> SocketTest::create_pair() {
  int fds[2];
  if (socketpair(AF_UNIX, SOCK_STREAM, 0, fds) == 0) {
    if (io::make_nonblocking(fds[0]) && io::make_nonblocking(fds[1])) {
      return std::make_pair(unique_fd(fds[0]), unique_fd(fds[1]));
    }
    io::close(fds[0]);
    io::close(fds[1]);
  }
  return std::make_pair(nullptr, nullptr);
}

// Wraps |fd| in a client driven by this fixture's looper.
std::unique_ptr<SocketTest::Client> SocketTest::create_client(unique_fd&& fd) {
  return std::make_unique<ClientImpl>(looper_, std::move(fd));
}