diff options
Diffstat (limited to 'src/looper.cc')
| -rw-r--r-- | src/looper.cc | 287 |
1 files changed, 287 insertions, 0 deletions
diff --git a/src/looper.cc b/src/looper.cc new file mode 100644 index 0000000..0da851b --- /dev/null +++ b/src/looper.cc @@ -0,0 +1,287 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include <list> +#include <poll.h> +#include <vector> + +#include "looper.hh" + +namespace { + +int const read_events = POLLIN | POLLPRI +#ifdef POLLRDHUP + | POLLRDHUP +#endif + ; +int const write_events = POLLOUT; +int const error_events = POLLERR | POLLNVAL; +int const hup_events = POLLHUP; + +class LooperImpl : public Looper { +public: + LooperImpl() + : fds_to_remove_(0), fds_protected_(0), quit_(false) { + } + + ~LooperImpl() override { + for (auto& timeout : timeouts_) { + delete timeout; + } + } + + void add(int fd, uint8_t events, FdCallback const& callback) override { + if (fd < 0) { + assert(false); + return; + } + for (auto it = fds_.begin(); it != fds_.end(); ++it) { + if (it->fd == fd) { + size_t index = it - fds_.begin(); + auto& entry = fdentries_[index]; + if (index >= fds_protected_) { + entry.callback = callback; + it->events = pollevents(events); + return; + } else { + // Don't call new callback this run, so add it at the end but + // remove the old callback as it would be replaced + if (!entry.removed) { + entry.removed = true; + fds_to_remove_++; + } + } + } + } + + fds_.emplace_back((struct pollfd) { fd, pollevents(events), 0 }); + fdentries_.emplace_back(callback); + } + + void modify(int fd, uint8_t events) override { + if (fd < 0) { + assert(false); + return; + } + for (auto it = fds_.begin(); it != fds_.end(); ++it) { + if (it->fd == fd) { + size_t index = it - fds_.begin(); + if (index < fds_protected_) { + auto entry = fdentries_.begin() + index; + // If entry is removed we need to write to the later one + if (entry->removed) { + continue; + } + } + it->events = pollevents(events); + return; + } + } + assert(false); + } + + void remove(int fd) override { + if (fd < 0) return; + for (auto it = fds_.begin(); it != fds_.end(); ++it) { + if (it->fd == fd) { + size_t index = it - fds_.begin(); + auto entry = fdentries_.begin() + index; + if (index < fds_protected_) { + if (!entry->removed) { + entry->removed = true; + fds_to_remove_++; + } + } else { + fds_.erase(it); + fdentries_.erase(entry); + } + return; + } + } + assert(false); + } + + void* schedule(float delay_s, ScheduleCallback const& callback) override { + clock::time_point target = clock::now() + + std::chrono::duration_cast<clock::duration>( + std::chrono::duration<float>(delay_s)); + auto timeout = new Timeout(target, callback); + for (auto it = timeouts_.begin(); it != timeouts_.end(); ++it) { + if (target < (*it)->target) { + timeouts_.insert(it, timeout); + return timeout; + } + } + timeouts_.push_back(timeout); + return timeout; + } + + void cancel(void* handle) override { + auto timeout = reinterpret_cast<Timeout*>(handle); + if (!timeout) { + assert(false); + return; + } + if (timeout->expired) return; + for (auto it = timeouts_.begin(); it != timeouts_.end(); ++it) { + if (*it == timeout) { + timeouts_.erase(it); + delete timeout; + return; + } + } + assert(false); + } + + void quit() override { + quit_ = true; + } + + bool run() override { + std::vector<Timeout*> expired; + + while (!quit_) { + int timeout = -1; + if (!timeouts_.empty()) { + auto dur = std::chrono::duration_cast<std::chrono::milliseconds>( + timeouts_.front()->target - clock::now()); + if (dur.count() <= 0) { + timeout = 0; + } else if (dur.count() < std::numeric_limits<int>::max()) { + timeout = dur.count(); + } else { + timeout = std::numeric_limits<int>::max(); + } + } + auto ret = poll(fds_.data(), fds_.size(), timeout); + if (ret < 0) { + if (errno == EINTR) continue; + return false; + } + now_ = clock::now(); + fds_protected_ = fds_.size(); + + if (!timeouts_.empty()) { + while (timeouts_.front()->target <= now_) { + auto timeout = timeouts_.front(); + timeouts_.pop_front(); + timeout->expired = true; + expired.push_back(timeout); + if (timeouts_.empty()) break; + } + + for (auto& timeout : expired) { + timeout->callback(timeout); + } + for (auto& timeout : expired) { + delete timeout; + } + expired.clear(); + } + + // Not using iterators here as that would be unsafe with + // add() and remove() modifying the vector outside protected range + // while callbacks are called + size_t i; + for (i = 0; ret > 0 && i < fds_protected_; ++i) { + if (fds_[i].revents) { + --ret; + if (!fdentries_[i].removed) { + fdentries_[i].callback(fds_[i].fd, unpollevents(fds_[i].revents)); + } + } + } + assert(ret == 0); + assert(fds_.size() >= fds_protected_); + assert(fdentries_.size() >= fds_protected_); + for (i = fds_protected_; fds_to_remove_ > 0 && i > 0; --i) { + if (fdentries_[i - 1].removed) { + --fds_to_remove_; + fds_.erase(fds_.begin() + i - 1); + fdentries_.erase(fdentries_.begin() + i - 1); + } + } + assert(fds_to_remove_ == 0); + fds_protected_ = 0; + } + return true; + } + + clock::time_point now() const override { + return now_; + } + +private: + struct FdEntry { + FdCallback callback; + bool removed; + + FdEntry(FdCallback const& callback) + : callback(callback), removed(false) { + } + }; + + struct Timeout { + clock::time_point target; + ScheduleCallback callback; + bool expired; + + Timeout(clock::time_point target, ScheduleCallback const& callback) + : target(target), callback(callback), expired(false) { + } + }; + + static uint8_t unpollevents(short events) { + uint8_t ret = 0; + if (events & read_events) { + ret |= EVENT_READ; + } + if (events & write_events) { + ret |= EVENT_WRITE; + } + if (events & error_events) { + ret |= EVENT_ERROR; + } + if (events & hup_events) { + ret |= EVENT_HUP; + } + return ret; + } + + static short pollevents(uint8_t events) { + int ret = 0; + if (events & EVENT_READ) { + ret |= read_events; + } + if (events & EVENT_WRITE) { + ret |= write_events; + } + return ret; + } + + std::vector<struct pollfd> fds_; + std::vector<FdEntry> fdentries_; + size_t fds_to_remove_; + size_t fds_protected_; + std::list<Timeout*> timeouts_; + bool quit_; + clock::time_point now_; +}; + +} // namespace + +// static +Looper* Looper::create() { + return new LooperImpl(); +} + +// static +const uint8_t Looper::EVENT_READ = 1 << 0; +// static +const uint8_t Looper::EVENT_WRITE = 1 << 1; +// static +const uint8_t Looper::EVENT_ERROR = 1 << 2; +// static +const uint8_t Looper::EVENT_HUP = 1 << 3; + |
