From c85b624d28564a6f785b25000e2b7825592a919d Mon Sep 17 00:00:00 2001 From: Joel Klinghed Date: Tue, 26 Sep 2017 20:09:31 +0200 Subject: Initial commit --- src/monitor.cc | 410 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 410 insertions(+) create mode 100644 src/monitor.cc (limited to 'src/monitor.cc') diff --git a/src/monitor.cc b/src/monitor.cc new file mode 100644 index 0000000..e5a231d --- /dev/null +++ b/src/monitor.cc @@ -0,0 +1,410 @@ +#include "common.hh" + +#include +#include +#include +#include + +#include "looper.hh" +#include "monitor.hh" + +namespace { + +class MonitorImpl : public Monitor { +public: + explicit MonitorImpl(std::shared_ptr const& looper) + : looper_(looper), discover_timer_(0), connect_(false), discover_fd_(-1), + port_(0) { + } + + ~MonitorImpl() override { + disconnect(); + } + + void connect(std::string const& netname, + std::string const& hostname, + uint16_t port) override { + netname_ = netname; + hostname_ = hostname; + port_ = port; + + if (channel_) { + // TODO(the_jk): Should we reconnect here? + return; + } + close_discover(); + schedule_discover((rand() % 1000) / 1000.0); + } + + void disconnect() override { + close_discover(); + disconnect_channel(); + } + + size_t machines() const override { + return machines_.size(); + } + + uint32_t id(size_t index) const override { + auto it = machines_.begin(); + while (index--) ++it; + if (it == machines_.end()) { + assert(false); + return 0xffffffff; + } + return it->first; + } + + Machine machine_at(size_t index) const override { + auto it = machines_.begin(); + while (index--) ++it; + if (it == machines_.end()) { + assert(false); + return EMPTY; + } + return it->second; + } + + Machine machine(uint32_t id) const override { + auto it = machines_.find(id); + if (it == machines_.end()) { + assert(false); + return EMPTY; + } + return it->second; + } + + void add_observer(Observer* observer) override { + observers_.push_back(observer); + } + +private: + static Machine const EMPTY; + + struct Job { + uint32_t source; + uint32_t target; + }; + + void schedule_discover(double delay) { + if (discover_timer_) return; + discover_timer_ = looper_->schedule(delay, + std::bind(&MonitorImpl::discover, + this, + std::placeholders::_1, + std::placeholders::_2)); + } + + void discover(Looper*, uint32_t timer) { + assert(discover_timer_ == timer); + discover_timer_ = 0; + check_discover(); + } + + void close_discover() { + if (discover_timer_) { + looper_->cancel(discover_timer_); + discover_timer_ = 0; + } + if (!discover_) return; + connect_ = false; + if (discover_fd_ >= 0) { + looper_->remove(discover_fd_); + discover_fd_ = -1; + } + discover_.reset(); + } + + void check_discover() { + if (channel_) return; + + if (!discover_) { + discover_.reset(new DiscoverSched(netname_, 2, hostname_, port_)); + } + + channel_.reset(discover_->try_get_scheduler()); + + if (channel_) { + connected(); + return; + } + + if (discover_->timed_out()) { + close_discover(); + check_discover(); + return; + } + + if (discover_fd_ < 0) { + auto fd = discover_->connect_fd(); + if (!connect_ && fd >= 0) { + discover_fd_ = fd; + connect_ = true; + looper_->add(fd, Looper::EV_WRITE, + std::bind(&MonitorImpl::discover_fd, this, + std::placeholders::_1, + std::placeholders::_2, + std::placeholders::_3)); + // Use connect() timeout + return; + } + fd = discover_->listen_fd(); + if (fd >= 0) { + discover_fd_ = fd; + connect_ = false; + looper_->add(fd, Looper::EV_READ, + std::bind(&MonitorImpl::discover_fd, this, + std::placeholders::_1, + std::placeholders::_2, + std::placeholders::_3)); + } + } + + schedule_discover(1.0 + (rand() % 1000) / 1000.0); + } + + void discover_fd(Looper*, int fd, uint8_t) { + assert(fd == discover_fd_); + looper_->remove(discover_fd_); + discover_fd_ = -1; + check_discover(); + } + + void connected() { + std::cerr << "connected" << std::endl; + current_netname_ = discover_->schedulerName(); + current_hostname_ = discover_->networkName(); + + channel_->setBulkTransfer(); + + close_discover(); + + looper_->add(channel_->fd, Looper::EV_READ, + std::bind(&MonitorImpl::msg, this, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3)); + + if (!channel_->send_msg(MonLoginMsg())) { + std::cerr << "login failed" << std::endl; + disconnect_channel(); + schedule_discover(0.0); + return; + } + + for (auto* observer : observers_) { + observer->state(this, CONNECTED); + } + } + + void disconnect_channel() { + if (!channel_) return; + std::cerr << "disconnect" << std::endl; + looper_->remove(channel_->fd); + channel_.reset(); + + for (auto* observer : observers_) { + observer->state(this, SEARCHING); + } + } + + void msg(Looper*, int, uint8_t event) { + if (event & Looper::EV_ERR) { + disconnect(); + schedule_discover(0.0); + return; + } + + while (true) { + if (!channel_->read_a_bit()) { + disconnect(); + schedule_discover(0.0); + return; + } + if (!channel_->has_msg()) break; + std::unique_ptr msg(channel_->get_msg()); + if (!msg) { + disconnect(); + schedule_discover(0.0); + return; + } + switch (msg->type) { + case M_END: + disconnect(); + schedule_discover(1.0); + return; + case M_MON_STATS: + handle_stats(static_cast(msg.get())); + break; + case M_MON_GET_CS: + handle_job(static_cast(msg.get())); + break; + case M_MON_JOB_BEGIN: + handle_job(static_cast(msg.get())); + break; + case M_MON_JOB_DONE: + handle_job(static_cast(msg.get())); + break; + case M_MON_LOCAL_JOB_BEGIN: + handle_local_job(static_cast(msg.get())); + break; + case M_JOB_LOCAL_DONE: + handle_local_job(static_cast(msg.get())); + break; + default: + break; + } + } + } + + void handle_job(MonGetCSMsg const* msg) { + pending_jobs_[msg->job_id] = msg->clientid; + } + + void handle_job(MonJobBeginMsg const* msg) { + auto it = pending_jobs_.find(msg->job_id); + if (it == pending_jobs_.end()) { + assert(false); + return; + } + auto source = it->second; + pending_jobs_.erase(it); + job_begin(msg->job_id, source, msg->hostid); + } + + void handle_job(MonJobDoneMsg const* msg) { + job_done(msg->job_id); + } + + void handle_local_job(MonLocalJobBeginMsg const* msg) { + job_begin(msg->job_id, msg->hostid, msg->hostid); + } + + void handle_local_job(JobLocalDoneMsg const* msg) { + job_done(msg->job_id); + } + + void job_begin(uint32_t job_id, uint32_t source, uint32_t target) { + auto& job = active_jobs_[job_id]; + job.source = source; + job.target = target; + for (auto* observer : observers_) { + observer->added_job(this, job.source, job.target); + } + } + + void job_done(uint32_t job_id) { + auto it = active_jobs_.find(job_id); + if (it == active_jobs_.end()) { + assert(false); + return; + } + auto source = it->second.source; + auto target = it->second.target; + active_jobs_.erase(it); + for (auto* observer : observers_) { + observer->removed_job(this, source, target); + } + } + + void handle_stats(MonStatsMsg const* msg) { + std::cerr << msg->hostid << " " << msg->statmsg << "***" << std::endl; + + auto& machine = machines_[msg->hostid]; + auto const known = !machine.name.empty(); + if (update(msg->statmsg, &machine)) { + if (machine.name.empty()) { + machines_.erase(msg->hostid); + if (known) { + for (auto* observer : observers_) { + observer->removed_machine(this, msg->hostid); + } + } + } else { + if (known) { + for (auto* observer : observers_) { + observer->updated_machine(this, msg->hostid); + } + } else { + for (auto* observer : observers_) { + observer->added_machine(this, msg->hostid); + } + } + } + } else { + if (!known) { + machines_.erase(msg->hostid); + } + } + } + + bool update(std::string const& msg, Machine* machine) { + std::string name, ip; + unsigned max_jobs = 0; + + size_t last = 0; + while (true) { + auto pos = msg.find(':', last); + if (pos == std::string::npos) break; + auto end = msg.find('\n', pos + 1); + if (end == std::string::npos) end = msg.size(); + auto key = msg.substr(last, pos - last); + if (key == "Name") { + name = msg.substr(pos + 1, end - pos - 1); + auto dot = name.find('.'); + if (dot != std::string::npos) name = name.substr(0, dot); + } else if (key == "IP") { + ip = msg.substr(pos + 1, end - pos - 1); + } else if (key == "MaxJobs") { + errno = 0; + char* end_ptr; + auto tmp = strtoul(msg.c_str() + pos + 1, &end_ptr, 10); + if (errno == 0 && tmp > 0 && end_ptr == msg.c_str() + end) { + max_jobs = static_cast(tmp); + } + } + last = end + 1; + } + + if (name.empty()) name = ip; + + bool changed = false; + if (name != machine->name) { + machine->name = name; + changed = true; + } + if (max_jobs != machine->max_jobs) { + machine->max_jobs = max_jobs; + changed = true; + } + return changed; + } + + std::shared_ptr looper_; + std::vector observers_; + std::unique_ptr channel_; + std::unique_ptr discover_; + uint32_t discover_timer_; + bool connect_; + int discover_fd_; + + // Requested netname, hostname and port + std::string netname_; + std::string hostname_; + uint16_t port_; + + // Actually connected netname and hostname + std::string current_netname_; + std::string current_hostname_; + + std::unordered_map machines_; + std::unordered_map active_jobs_; + std::unordered_map pending_jobs_; +}; + +Monitor::Machine const MonitorImpl::EMPTY; + +} // namespace + +// static +Monitor* Monitor::create(std::shared_ptr const& looper) { + return new MonitorImpl(looper); +} -- cgit v1.3