diff options
Diffstat (limited to 'src/resolver.cc')
| -rw-r--r-- | src/resolver.cc | 207 |
1 files changed, 207 insertions, 0 deletions
diff --git a/src/resolver.cc b/src/resolver.cc new file mode 100644 index 0000000..2623089 --- /dev/null +++ b/src/resolver.cc @@ -0,0 +1,207 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include <condition_variable> +#include <cstring> +#include <fcntl.h> +#include <mutex> +#include <netdb.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <thread> +#include <vector> + +#include "io.hh" +#include "looper.hh" +#include "resolver.hh" + +namespace { + +size_t const WORKERS = 4; + +class ResolverImpl : public Resolver { +public: + ResolverImpl(Looper* looper) + : looper_(looper), request_(nullptr), buf_(new char[sizeof(Request*)]), + fill_(0), quit_(false) { + while (threads_.size() < WORKERS) { + threads_.emplace_back(std::bind(&ResolverImpl::worker, this)); + } + if (pipe_.open() && fcntl(pipe_.read(), F_SETFL, O_NONBLOCK) == 0) { + looper_->add(pipe_.read(), + Looper::EVENT_READ, + std::bind(&ResolverImpl::event, this, + std::placeholders::_1, std::placeholders::_2)); + } else { + assert(false); + } + } + + ~ResolverImpl() override { + quit_ = true; + cond_.notify_all(); + for (auto& thread : threads_) { + thread.join(); + } + } + + void* request(std::string const& host, uint16_t port, + Callback const& callback) override { + auto req = new Request(); + req->host = host; + req->port = port; + req->callback = callback; + req->canceled = false; + std::unique_lock<std::mutex> lock(mutex_); + req->next = request_; + request_ = req; + cond_.notify_one(); + return req; + } + + void cancel(void* ptr) override { + auto req = reinterpret_cast<Request*>(ptr); + req->canceled = true; + std::unique_lock<std::mutex> lock(mutex_); + if (request_ == req) { + request_ = req->next; + delete req; + } else { + for (auto r = request_; r->next; r = r->next) { + if (r->next == req) { + r->next = req->next; + delete req; + return; + } + } + } + } + +protected: + struct Request { + Request* next; + std::string host; + uint16_t port; + Callback callback; + bool canceled; + io::auto_fd fd; + bool connected; + std::string error; + }; + + void event(int fd, uint8_t event) { + assert(fd == pipe_.read()); + if (event & Looper::EVENT_READ) { + while (true) { + auto ret = io::read(fd, buf_.get() + fill_, sizeof(Request*) - fill_); + if (ret == -1) { + if (errno == EAGAIN || errno == EWOULDBLOCK) return; + assert(false); + return; + } else if (ret == 0) { + assert(false); + return; + } + fill_ += ret; + if (fill_ == sizeof(Request*)) { + fill_ = 0; + auto req = *reinterpret_cast<Request**>(buf_.get()); + if (!req->canceled) { + auto err = req->fd ? nullptr : req->error.c_str(); + req->callback(req->fd.release(), req->connected, err); + } + delete req; + } else { + break; + } + } + } else { + assert(false); + } + } + + void report(Request* req, int fd, bool connected, char const* errmsg) { + req->fd.reset(fd); + req->connected = connected; + if (errmsg) req->error = errmsg; + io::write_all(pipe_.write(), &req, sizeof(Request*)); + } + + void worker() { + struct addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = AI_V4MAPPED | AI_ADDRCONFIG | AI_NUMERICSERV; + char tmp[10]; + while (true) { + Request* req; + { + std::unique_lock<std::mutex> lock(mutex_); + while (!quit_ && !request_) { + cond_.wait(lock); + } + if (quit_) return; + auto pr = &request_; + while ((*pr)->next) { + pr = &((*pr)->next); + } + req = *pr; + *pr = nullptr; + } + snprintf(tmp, sizeof(tmp), "%u", static_cast<unsigned int>(req->port)); + struct addrinfo* result; + auto ret = getaddrinfo(req->host.c_str(), tmp, &hints, &result); + if (ret != 0) { + report(req, -1, false, gai_strerror(ret)); + continue; + } + auto rp = result; + for (; rp; rp = rp->ai_next) { + io::auto_fd fd(socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol)); + if (!fd) continue; + + fcntl(fd.get(), F_SETFL, O_NONBLOCK); + + while (true) { + ret = connect(fd.get(), rp->ai_addr, rp->ai_addrlen); + if (ret == 0 || errno != EINTR) break; + } + if (ret == 0) { + report(req, fd.release(), true, nullptr); + break; + } + if (errno == EINPROGRESS) { + report(req, fd.release(), false, nullptr); + break; + } + } + + if (!rp) { + freeaddrinfo(result); + report(req, -1, false, strerror(errno)); + continue; + } + + freeaddrinfo(result); + } + } + + Looper* const looper_; + Request* request_; + io::auto_pipe pipe_; + std::mutex mutex_; + std::condition_variable cond_; + std::unique_ptr<char[]> buf_; + size_t fill_; + bool quit_; + std::vector<std::thread> threads_; +}; + +} // namespace + +// static +Resolver* Resolver::create(Looper* looper) { + return new ResolverImpl(looper); +} |
