diff options
Diffstat (limited to 'test/socket_test.cc')
| -rw-r--r-- | test/socket_test.cc | 168 |
1 files changed, 168 insertions, 0 deletions
diff --git a/test/socket_test.cc b/test/socket_test.cc new file mode 100644 index 0000000..f50d306 --- /dev/null +++ b/test/socket_test.cc @@ -0,0 +1,168 @@ +#include "common.hh" + +#include "buffer.hh" +#include "io.hh" +#include "looper.hh" +#include "socket_test.hh" + +#include <sys/socket.h> +#include <utility> + +namespace { + +constexpr size_t kBaseSize = 512 * 1024; +constexpr size_t kMaxSize = 10 * 1024 * 1024; + +class ClientImpl : public SocketTest::Client { +public: + ClientImpl(std::shared_ptr<Looper> looper, unique_fd&& fd) + : looper_(std::move(looper)), fd_(std::move(fd)), + in_(Buffer::growing(kBaseSize, kMaxSize)), + out_(Buffer::growing(kBaseSize, kMaxSize)) { + looper_->add(fd_.get(), Looper::EVENT_READ, + std::bind(&ClientImpl::event, this, std::placeholders::_1)); + } + + ~ClientImpl() override { + if (!closed_) + looper_->remove(fd_.get()); + } + + bool closed() const override { + return closed_; + } + + void write(std::function<void(Buffer*)> writer) override { + ASSERT_FALSE(closed_); + bool empty = out_->empty(); + writer(out_.get()); + if (empty) { + size_t bytes; + ASSERT_TRUE(io::drain(out_.get(), fd_.get(), &bytes)); + if (out_->empty()) + return; + } + update(); + } + + void write(std::string_view data) override { + ASSERT_FALSE(closed_); + if (data.empty()) + return; + bool empty = out_->empty(); + auto size = Buffer::write(out_.get(), data.data(), data.size()); + ASSERT_EQ(data.size(), size); + if (empty) { + size_t bytes; + ASSERT_TRUE(io::drain(out_.get(), fd_.get(), &bytes)); + if (out_->empty()) + return; + } + update(); + } + + void read(std::function<void(RoBuffer*)> reader) override { + bool full = in_->full(); + + reader(in_.get()); + + if (full && !in_->full()) + update(); + } + + std::string_view received() const override { + size_t avail; + auto* ptr = in_->rbuf(1, avail); + return std::string_view(ptr, avail); + } + + void forget(size_t bytes) override { + bool full = in_->full(); + in_->rcommit(bytes); + + if (full && bytes > 0) + update(); + } + + void wait(Logger* logger) override { + ASSERT_FALSE(waiting_); + ASSERT_FALSE(closed_); + waiting_ = true; + looper_->run(logger); + waiting_ = false; + } + +private: + void event(uint8_t ev) { + if (ev & Looper::EVENT_ERROR) { + FAIL(); + } + bool need_update = false; + if (ev & Looper::EVENT_READ) { + switch (io::fill(fd_.get(), in_.get())) { + case io::Return::OK: + break; + case io::Return::ERR: + FAIL(); + case io::Return::CLOSED: + ASSERT_TRUE(out_->empty()); + closed_ = true; + looper_->remove(fd_.get()); + fd_.reset(); + if (waiting_) + looper_->quit(); + return; + } + if (in_->full()) + need_update = true; + } + if (ev & Looper::EVENT_WRITE) { + ASSERT_TRUE(io::drain(out_.get(), fd_.get())); + if (out_->empty()) + need_update = true; + } + + if (waiting_) + looper_->quit(); + + if (need_update) + update(); + } + + void update() { + uint8_t events = 0; + if (!in_->full()) + events |= Looper::EVENT_READ; + if (!out_->empty()) + events |= Looper::EVENT_WRITE; + looper_->update(fd_.get(), events); + } + + std::shared_ptr<Looper> looper_; + unique_fd fd_; + std::unique_ptr<Buffer> in_; + std::unique_ptr<Buffer> out_; + bool closed_{false}; + bool waiting_{false}; +}; + +} // namespace + +SocketTest::SocketTest() + : looper_(Looper::create()) {} + +std::pair<unique_fd, unique_fd> SocketTest::create_pair() { + int ret[2]; + if (socketpair(AF_UNIX, SOCK_STREAM, 0, ret) == 0) { + if (io::make_nonblocking(ret[0]) && io::make_nonblocking(ret[1])) { + return std::make_pair(unique_fd(ret[0]), unique_fd(ret[1])); + } + io::close(ret[0]); + io::close(ret[1]); + } + return std::make_pair(nullptr, nullptr); +} + +std::unique_ptr<SocketTest::Client> SocketTest::create_client(unique_fd&& fd) { + return std::make_unique<ClientImpl>(looper_, std::move(fd)); +} |
