From 87774d8981ae7a079492d8949e205065ba72a8e4 Mon Sep 17 00:00:00 2001 From: Joel Klinghed Date: Thu, 16 Mar 2017 23:28:09 +0100 Subject: Add basic console monitor and implement monitor support --- src/.gitignore | 2 + src/Makefile.am | 11 +- src/buffer.cc | 11 +- src/buffer.hh | 2 + src/chunked.cc | 44 ++++- src/chunked.hh | 4 + src/http.cc | 6 + src/http.hh | 1 + src/ios_save.hh | 31 ++++ src/monitor-cmd.cc | 271 ++++++++++++++++++++++++++++++ src/monitor.cc | 476 +++++++++++++++++++++++++++++++++++++++++++++++++++++ src/monitor.hh | 68 ++++++++ src/proxy.cc | 471 ++++++++++++++++++++++++++++++++++++++++++++++++---- 13 files changed, 1362 insertions(+), 36 deletions(-) create mode 100644 src/ios_save.hh create mode 100644 src/monitor-cmd.cc create mode 100644 src/monitor.cc create mode 100644 src/monitor.hh diff --git a/src/.gitignore b/src/.gitignore index 7066278..e4fbae2 100644 --- a/src/.gitignore +++ b/src/.gitignore @@ -1,4 +1,6 @@ /config.h /config.h.in~ +/libmonitor.a /libtp.a /tp +/tp-monitor diff --git a/src/Makefile.am b/src/Makefile.am index 502c82d..84c4401 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -5,8 +5,8 @@ AM_CXXFLAGS = @DEFINES@ # Remove ar: `u' modifier ignored since `D' is the default (see `U') ARFLAGS = cr -bin_PROGRAMS = tp -noinst_LIBRARIES = libtp.a +bin_PROGRAMS = tp tp-monitor +noinst_LIBRARIES = libtp.a libmonitor.a tp_SOURCES = main.cc proxy.cc logger.cc resolver.cc tp_LDADD = libtp.a @THREAD_LIBS@ @@ -16,3 +16,10 @@ libtp_a_SOURCES = args.cc xdg.cc terminal.cc http.cc url.cc paths.cc \ character.cc config.cc strings.cc io.cc looper.cc \ buffer.cc chunked.cc libtp_a_CXXFLAGS = $(AM_CXXFLAGS) -DSYSCONFDIR='"@SYSCONFDIR@"' + +libmonitor_a_SOURCES = monitor.cc resolver.cc +libmonitor_CXXFLAGS = $(AM_CXXFLAGS) -DVERSION='"@VERSION@"' @THREAD_CFLAGS@ + +tp_monitor_SOURCES = monitor-cmd.cc +tp_monitor_LDADD = libmonitor.a libtp.a @THREAD_LIBS@ +tp_monitor_CXXFLAGS = $(AM_CXXFLAGS) -DVERSION='"@VERSION@"' @THREAD_CFLAGS@ diff --git a/src/buffer.cc b/src/buffer.cc index d0c5fbb..4bde195 100644 --- a/src/buffer.cc +++ b/src/buffer.cc @@ -12,7 +12,8 @@ namespace { class BufferImpl : public Buffer { public: BufferImpl(size_t capacity, size_t min_avail) - : capacity_(capacity), min_avail_(min_avail) { + : capacity_(capacity), min_avail_( + std::max(min_avail, static_cast(1))) { data_ = capacity_ > 0 ? reinterpret_cast(malloc(capacity_)) : nullptr; end_ = data_; @@ -117,3 +118,11 @@ void Buffer::write(void const* data, size_t size) { } } +void Buffer::clear() { + while (true) { + size_t avail; + read_ptr(&avail); + if (avail == 0) return; + consume(avail); + } +} diff --git a/src/buffer.hh b/src/buffer.hh index 92a7566..a648924 100644 --- a/src/buffer.hh +++ b/src/buffer.hh @@ -22,6 +22,8 @@ public: size_t read(void* data, size_t max); void write(void const* data, size_t size); + void clear(); + protected: Buffer() {} Buffer(Buffer const&) = delete; diff --git a/src/chunked.cc b/src/chunked.cc index 99aea0d..dfafd89 100644 --- a/src/chunked.cc +++ b/src/chunked.cc @@ -13,6 +13,7 @@ namespace { enum State { CHUNK, IN_CHUNK, + END_CHUNK, TRAILER, DONE, }; @@ -51,13 +52,29 @@ public: } break; } - case IN_CHUNK: - if (static_cast(end - d) < size_) { + case IN_CHUNK: { + uint64_t left = end - d; + if (left < size_) { + this->data(d, left); + size_ -= left; return avail; } + this->data(d, size_); d += size_; + state_ = END_CHUNK; + break; + } + case END_CHUNK: { + auto p = find_crlf(d, end); + if (!p) return d - start; + if (p != d) { + good_ = false; + return d - start; + } + d = p + 2; state_ = CHUNK; break; + } case TRAILER: { auto p = find_crlf(d, end); if (!p) return d - start; @@ -81,6 +98,10 @@ public: return state_ == DONE; } +protected: + virtual void data(void const* UNUSED(data), size_t UNUSED(size)) { + } + private: char const* find_crlf(char const* start, char const* end) { for (; start != end; ++start) { @@ -97,6 +118,20 @@ private: uint64_t size_; }; +class ChunkedCallbackImpl : public ChunkedImpl { +public: + ChunkedCallbackImpl(DataCallback const& callback) + : callback_(callback) { + } + + void data(void const* data, size_t size) override { + callback_(data, size); + } + +private: + DataCallback const callback_; +}; + } // namespace // static @@ -104,3 +139,8 @@ Chunked* Chunked::create() { return new ChunkedImpl(); } +// static +Chunked* Chunked::create(DataCallback const& callback) { + return new ChunkedCallbackImpl(callback); +} + diff --git a/src/chunked.hh b/src/chunked.hh index 66d3ae7..511ae55 100644 --- a/src/chunked.hh +++ b/src/chunked.hh @@ -4,12 +4,16 @@ #define CHUNKED_HH #include +#include class Chunked { public: virtual ~Chunked() { } + typedef std::function DataCallback; + static Chunked* create(); + static Chunked* create(DataCallback const& callback); virtual size_t add(void const* data, size_t avail) = 0; virtual bool good() const = 0; diff --git a/src/http.cc b/src/http.cc index c043c87..26911cb 100644 --- a/src/http.cc +++ b/src/http.cc @@ -377,10 +377,16 @@ public: bool method_equal(std::string const& method) const override { return method.compare(0, method.size(), data_, method_end_) == 0; } + std::string url() const override { return make_string(data_, url_start_, url_end_); } + bool url_equal(std::string const& url) const override { + return url.compare(0, url.size(), + data_ + url_start_, url_end_ - url_start_) == 0; + } + ParseResult parse() { good_ = false; proto_end_ = find_newline(0, &content_start_); diff --git a/src/http.hh b/src/http.hh index 091d3d4..9a084d2 100644 --- a/src/http.hh +++ b/src/http.hh @@ -92,6 +92,7 @@ public: virtual std::string method() const = 0; virtual bool method_equal(std::string const& method) const = 0; virtual std::string url() const = 0; + virtual bool url_equal(std::string const& url) const = 0; protected: HttpRequest() {} diff --git a/src/ios_save.hh b/src/ios_save.hh new file mode 100644 index 0000000..1abcf53 --- /dev/null +++ b/src/ios_save.hh @@ -0,0 +1,31 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#ifndef IOS_SAVE_HH +#define IOS_SAVE_HH + +#include + +class ios_save { +public: + ios_save(std::ostream& ostream) + : ostream_(ostream), save_(nullptr) { + save_.copyfmt(ostream); + } + + ios_save(ios_save const& save) + : ostream_(save.ostream_), save_(nullptr) { + save_.copyfmt(save.save_); + } + + ~ios_save() { + ostream_.copyfmt(save_); + } + + ios_save& operator=(ios_save const&) = delete; + +private: + std::ostream& ostream_; + std::ios save_; +}; + +#endif // IOS_SAVE_HH diff --git a/src/monitor-cmd.cc b/src/monitor-cmd.cc new file mode 100644 index 0000000..5221db0 --- /dev/null +++ b/src/monitor-cmd.cc @@ -0,0 +1,271 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include +#include +#include +#include +#include +#include + +#include "args.hh" +#include "buffer.hh" +#include "io.hh" +#include "ios_save.hh" +#include "looper.hh" +#include "resolver.hh" +#include "monitor.hh" + +namespace { + +class Delegate : public Monitor::Delegate { +public: + Delegate(std::ostream& out, bool interleave, Looper* looper) + : out_(out), interleave_(interleave), looper_(looper), attached_(false) { + } + + void state(Monitor* monitor, Monitor::State state) override { + switch (state) { + case Monitor::DISCONNECTED: + std::cout << "# Disconnected" << std::endl; + looper_->quit(); + break; + case Monitor::CONNECTING: + break; + case Monitor::CONNECTED: + if (attached_) { + std::cout << "# Detached" << std::endl; + attached_ = false; + } else { + std::cout << "# Connected" << std::endl; + monitor->attach(); + } + break; + case Monitor::ATTACHED: + std::cout << "# Attached" << std::endl; + attached_ = true; + break; + } + } + + void error(Monitor* UNUSED(monitor), std::string const& error) override { + std::cerr << "# Error: " << error << std::endl; + } + + void package( + Monitor* UNUSED(monitor), Monitor::Package const& package) override { + packages_.insert(std::make_pair(package.id, package)); + } + + void package_data(Monitor* UNUSED(monitor), uint32_t id, + char const* data, size_t size, bool last) override { + auto it = packages_.find(id); + if (it == packages_.end()) { + assert(false); + return; + } + if (interleave_) { + print_package(it->second, last, data, size); + if (last) packages_.erase(it); + return; + } + auto buf = data_.find(id); + if (last) { + if (buf == data_.end()) { + print_package(it->second, true, data, size); + } else { + buf->second->write(data, size); + size_t avail; + auto ptr = buf->second->read_ptr(&avail); + print_package(it->second, true, + reinterpret_cast(ptr), avail); + data_.erase(buf); + } + packages_.erase(it); + } else { + if (buf == data_.end()) { + buf = data_.insert(std::make_pair(id, Buffer::create(8192, 0))).first; + } + buf->second->write(data, size); + } + } + +private: + void print_package( + Monitor::Package& pkg, bool last, char const* data, size_t size) { + if (size == 0 && !last) return; + { + ios_save save(out_); + out_ << "*** " << pkg.timestamp.tv_sec << '.' + << std::setfill('0') << std::setw(9) << pkg.timestamp.tv_nsec + << '\n'; + } + out_ << "* Source: " << pkg.source_host << ':' << pkg.source_port << '\n' + << "* Target: " << pkg.target_host << ':' << pkg.target_port << '\n'; + if (interleave_) { + auto offset = offset_.find(pkg.id); + if (offset != offset_.end()) { + out_ << "* Bytes: " << offset->second << "-"; + if (last) { + out_ << (offset->second + size); + offset_.erase(offset); + } else { + offset->second += size; + } + out_ << '\n'; + } else if (!last) { + offset_.insert(std::make_pair(pkg.id, size)); + out_ << "* Bytes: 0-"; + } + } + out_ << "* Size: " << size << '\n'; + if (size > 0) { + ios_save save(out_); + auto d = reinterpret_cast(data); + out_.flags(std::ios::hex); + out_.fill('0'); + for (size_t i = 0; i < size; i += 16) { + out_ << std::setw(8) << i << ' '; + unsigned j = 0; + for (; j < 8; ++j) { + auto k = i + j; + if (k >= size) break; + out_ << ' ' << std::setw(2) << static_cast(d[k]); + } + for (; j < 8; ++j) { + out_ << " "; + } + out_ << ' '; + for (; j < 16; ++j) { + auto k = i + j; + if (k >= size) break; + out_ << ' ' << std::setw(2) << static_cast(d[k]); + } + for (; j < 16; ++j) { + out_ << " "; + } + out_ << " |"; + j = 0; + for (; j < 16; ++j) { + auto k = i + j; + if (k >= size) break; + out_ << printable(data[k]); + } + for (; j < 16; ++j) { + out_ << ' '; + } + out_ << "|\n"; + } + } + out_ << std::endl; + } + + static char printable(char c) { + return (c & 0x80 || c < ' ' || c >= 0x7f) ? '.' : c; + } + + std::ostream& out_; + bool interleave_; + Looper* looper_; + bool attached_; + std::unordered_map packages_; + // Used when interleaving + std::unordered_map offset_; + // Used when not interleaving + std::unordered_map> data_; +}; + +io::auto_pipe signal_pipe; + +void signal(int UNUSED(signum)) { + io::write_all(signal_pipe.write(), "", 1); + std::cerr << "# Caught signal" << std::endl; +} + +void quit_loop(Looper* looper, int UNUSED(fd), uint8_t UNUSED(events)) { + looper->quit(); +} + +bool run(std::ostream& out, bool interleave, + std::string const& host, uint16_t port) { + std::unique_ptr looper(Looper::create()); + std::unique_ptr resolver(Resolver::create(looper.get())); + std::unique_ptr delegate( + new Delegate(out, interleave, looper.get())); + std::unique_ptr monitor( + Monitor::create(looper.get(), resolver.get(), delegate.get())); + std::cout << "# Connecting to " << host << ':' << port << std::endl; + monitor->connect(host, port); + if (signal_pipe.open()) { + struct sigaction action; + memset(&action, 0, sizeof(action)); + action.sa_handler = signal; + sigaction(SIGINT, &action, nullptr); + sigaction(SIGTERM, &action, nullptr); + looper->add(signal_pipe.read(), Looper::EVENT_READ, + std::bind(&quit_loop, looper.get(), std::placeholders::_1, + std::placeholders::_2)); + } + return looper->run(); +} + +} // namespace + +int main(int argc, char** argv) { + std::unique_ptr args(Args::create()); + args->add('i', "interleave", + "unless set packages are not output until they are complete"); + args->add('o', "output", "FILE", "output packages to FILE instead of stdout"); + args->add('h', "help", "display this text and exit"); + args->add('V', "version", "display version and exit"); + if (!args->run(argc, argv)) { + std::cerr << "Try `tp-monitor --help` for usage." << std::endl; + return EXIT_FAILURE; + } + if (args->is_set('h')) { + std::cout << "Usage: `tp-monitor [OPTIONS...] HOST[:PORT]`\n" + << "Connect to transparent proxy running on HOST[:PORT] and " + << "print logged packages\n" + << std::endl; + args->print_help(); + return EXIT_SUCCESS; + } + if (args->is_set('V')) { + std::cout << "TransparentProxy monitor version " VERSION + << " written by Joel Klinghed " << std::endl; + return EXIT_SUCCESS; + } + if (args->arguments().size() != 1) { + std::cerr << "Unexpected number of arguments, expects one argument." + << std::endl; + return EXIT_FAILURE; + } + std::string host = args->arguments().front(); + uint16_t port; + auto i = host.rfind(':'); + if (i != std::string::npos && i > 0 && host[i - 1] != ':') { + errno = 0; + char* end = nullptr; + auto tmp = strtoul(host.c_str() + i + 1, &end, 10); + if (errno || !end || *end) { + std::cerr << "Invalid port number: " << host.substr(i + 1) << std::endl; + return EXIT_FAILURE; + } + port = tmp & 0xffff; + } else { + port = 9000; + } + auto interleave = args->is_set('i'); + auto output = args->arg('o', nullptr); + if (output) { + std::ofstream out(output, std::ofstream::trunc); + if (!out.good()) { + std::cerr << "Unable to open " << output << " for writing." << std::endl; + return EXIT_FAILURE; + } + return run(out, interleave, host, port); + } else { + return run(std::cout, interleave, host, port); + } +} diff --git a/src/monitor.cc b/src/monitor.cc new file mode 100644 index 0000000..673ad14 --- /dev/null +++ b/src/monitor.cc @@ -0,0 +1,476 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include + +#include "buffer.hh" +#include "chunked.hh" +#include "http.hh" +#include "io.hh" +#include "looper.hh" +#include "monitor.hh" +#include "resolver.hh" + +namespace { + +Version const http_version = { 1, 1 }; + +class MonitorImpl : public Monitor { +public: + MonitorImpl(Looper* looper, Resolver* resolver, Delegate* delegate) + : looper_(looper), resolver_(resolver), delegate_(delegate), + state_(DISCONNECTED) { + } + + ~MonitorImpl() { + do_disconnect(true); + } + + void connect(std::string const& host, uint16_t port) override { + do_disconnect(); + new_state(CONNECTING); + if (state_ == CONNECTING) { + resolv_ = resolver_->request( + host, port, std::bind(&MonitorImpl::resolved, this, + std::placeholders::_1, + std::placeholders::_2, + std::placeholders::_3)); + } + } + + void disconnect() override { + do_disconnect(); + new_state(DISCONNECTED); + } + + void attach() override { + if (state_ != CONNECTED) return; + new_state(ATTACHED); + send_attach(); + } + + void detach() override { + if (state_ != ATTACHED) return; + new_state(CONNECTED); + send_detach(); + } + + State state() const override { + return state_; + } + +private: + void do_disconnect(bool skip_detach = false) { + switch (state_) { + case ATTACHED: + if (!skip_detach) detach(); + // Fallthrough + case CONNECTED: + looper_->remove(sock_.get()); + sock_.reset(); + break; + case CONNECTING: + if (sock_) { + looper_->remove(sock_.get()); + sock_.reset(); + } + if (resolv_) { + resolver_->cancel(resolv_); + resolv_ = nullptr; + } + break; + case DISCONNECTED: + break; + } + } + + void new_state(State state) { + if (state_ == state) return; + state_ = state; + delegate_->state(this, state_); + } + + void resolved(int fd, bool connected, char const* error) { + resolv_ = nullptr; + if (fd < 0) { + delegate_->error(this, error); + new_state(DISCONNECTED); + return; + } + sock_.reset(fd); + if (!in_) { + in_.reset(Buffer::create(65536, 4096)); + } else { + in_->clear(); + } + if (!out_) { + out_.reset(Buffer::create(8192, 1024)); + } else { + out_->clear(); + } + looper_->add(sock_.get(), Looper::EVENT_WRITE, + std::bind(&MonitorImpl::event, this, std::placeholders::_1, + std::placeholders::_2)); + sent_hello_ = false; + if (connected) { + event(sock_.get(), Looper::EVENT_WRITE); + } + } + + bool parse_content_length(HttpResponse const* resp, uint64_t* bytes) { + std::string len = resp->first_header("content-length"); + char* end = nullptr; + errno = 0; + auto tmp = strtoull(len.c_str(), &end, 10); + if (errno || !end || *end) { + return false; + } + if (bytes) *bytes = tmp; + return true; + } + + bool setup_attach(HttpResponse const* resp) { + assert(!active_attach_); + std::string te = resp->first_header("transfer-encoding"); + if (te != "chunked") return false; + package_fill_ = 0; + active_attach_.reset( + Chunked::create( + std::bind(&MonitorImpl::package, this, std::placeholders::_1, + std::placeholders::_2))); + return true; + } + + static uint64_t read_u64(uint8_t const* data) { + return static_cast(read_u32(data)) << 32 | read_u32(data + 4); + } + + static uint32_t read_u32(uint8_t const* data) { + return static_cast(read_u16(data)) << 16 | read_u16(data + 2); + } + + static uint16_t read_u16(uint8_t const* data) { + return data[0] << 8 | data[1]; + } + + void package(void const* data, size_t size) { + auto d = reinterpret_cast(data); + auto const end = d + size; + while (d < end) { + size_t avail = sizeof(package_) - package_fill_; + if (avail == 0) break; + if (d + avail > end) avail = end - d; + memcpy(package_ + package_fill_, d, avail); + package_fill_ += avail; + d += avail; + size_t offset = 0; + while (offset + 5 < package_fill_) { + uint16_t size = read_u16(package_ + offset + 3); + if (offset + size > package_fill_) break; + size_t o = 5; + if (size >= 29 && memcmp(package_ + offset, "PKG", 3) == 0) { + Package pkg; + pkg.id = read_u32(package_ + offset + o); + o += 4; + pkg.timestamp.tv_sec = read_u64(package_ + offset + o); + o += 8; + pkg.timestamp.tv_nsec = read_u32(package_ + offset + o); + o += 4; + pkg.flags = read_u16(package_ + offset + o); + o += 2; + pkg.source_port = read_u16(package_ + offset + o); + o += 2; + pkg.target_port = read_u16(package_ + offset + o); + o += 2; + auto len = read_u16(package_ + offset + o); + o += 2; + if (o + len + 2 <= size) { + pkg.source_host.assign( + reinterpret_cast(package_) + offset + o, len); + o += len; + len = read_u16(package_ + offset + o); + o += 2; + if (o + len <= size) { + pkg.target_host.assign( + reinterpret_cast(package_) + offset + o, len); + o += len; + bool last = !(pkg.flags & 0x01); + pkg.flags >>= 1; + delegate_->package(this, pkg); + if (o < size || last) { + delegate_->package_data( + this, pkg.id, + reinterpret_cast(package_) + offset + o, size - o, + last); + } + } + } + } else if (size >= 10 && memcmp(package_ + offset, "DAT", 3) == 0) { + uint32_t id = read_u32(package_ + offset + o); + o += 4; + uint8_t flags = package_[offset + o]; + ++o; + delegate_->package_data( + this, id, + reinterpret_cast(package_) + offset + o, size - o, + !(flags & 0x01)); + } + offset += size; + } + if (offset > 0) { + package_fill_ -= offset; + memmove(package_, package_ + offset, package_fill_); + } + } + } + + void consume_attach() { + assert(active_attach_); + while (true) { + size_t avail; + auto ptr = in_->read_ptr(&avail); + if (avail == 0) return; + auto used = active_attach_->add(ptr, avail); + if (!active_attach_->good()) { + delegate_->error(this, "Bad chunked data"); + disconnect(); + return; + } + if (used == 0) return; + in_->consume(used); + if (active_attach_->eof()) { + active_attach_.reset(); + new_state(CONNECTED); + return; + } + } + } + + void event(int fd, uint8_t events) { + if (fd != sock_.get() || events == 0) { + assert(false); + return; + } + + if (events & (Looper::EVENT_HUP | Looper::EVENT_ERROR)) { + if (state_ == CONNECTING && !sent_hello_) { + delegate_->error(this, "Connection denied"); + } else { + delegate_->error(this, "Connection lost"); + } + disconnect(); + return; + } + + if (events & Looper::EVENT_READ) { + while (true) { + size_t avail; + auto ptr = in_->write_ptr(&avail); + if (avail == 0) { + assert(false); + break; + } + auto ret = io::read(sock_.get(), ptr, avail); + if (ret == -1) { + if (errno == EWOULDBLOCK || errno == EAGAIN) break; + delegate_->error(this, "Read error"); + disconnect(); + return; + } + if (ret == 0) { + delegate_->error(this, "Connection closed"); + disconnect(); + return; + } + in_->commit(ret); + if (static_cast(ret) < avail) break; + } + } + if (events & Looper::EVENT_WRITE) { + if (state_ == CONNECTING && !sent_hello_) { + sent_hello_ = true; + send_hello(); + return; + } + + size_t avail; + while (true) { + auto ptr = out_->read_ptr(&avail); + if (avail == 0) break; + auto ret = io::write(sock_.get(), ptr, avail); + if (ret == -1) { + if (errno == EWOULDBLOCK || errno == EAGAIN) break; + delegate_->error(this, "Write error"); + disconnect(); + return; + } + if (ret == 0) { + delegate_->error(this, "Connection lost"); + disconnect(); + return; + } + out_->consume(ret); + if (static_cast(ret) < avail) break; + } + looper_->modify(sock_.get(), Looper::EVENT_READ + | (avail == 0 ? 0 : Looper::EVENT_WRITE)); + } + + if (active_attach_) { + consume_attach(); + if (active_attach_) return; + } + + while (true) { + size_t avail; + auto ptr = in_->read_ptr(&avail); + if (avail == 0) return; + if (content_skip_ > 0) { + if (content_skip_ > avail) { + in_->consume(avail); + content_skip_ -= avail; + return; + } + in_->consume(content_skip_); + content_skip_ = 0; + ptr = in_->read_ptr(&avail); + } + auto resp = std::unique_ptr( + HttpResponse::parse( + reinterpret_cast(ptr), avail, false)); + if (!resp) { + if (avail > 1024 * 1024) { + delegate_->error(this, "Server sending too much unexpected data"); + disconnect(); + } + return; + } + if (!resp->good()) { + delegate_->error(this, "Server sent invalid response"); + disconnect(); + return; + } + switch (state_) { + case CONNECTING: + if (resp->status_code() != 200) { + delegate_->error( + this, "Unexpected server response: " + resp->status_message()); + disconnect(); + return; + } + if (!parse_content_length(resp.get(), &content_skip_)) { + delegate_->error( + this, "Invalid server response, bad content length"); + disconnect(); + return; + } + in_->consume(resp->size()); + new_state(CONNECTED); + active_attach_.reset(); + continue; + case ATTACHED: + if (resp->status_code() != 200) { + delegate_->error( + this, "Unexpected server response: " + resp->status_message()); + new_state(CONNECTED); + if (!parse_content_length(resp.get(), &content_skip_)) { + delegate_->error( + this, "Invalid server response, bad content length"); + disconnect(); + return; + } + in_->consume(resp->size()); + } else { + if (!setup_attach(resp.get())) { + delegate_->error( + this, "Invalid server response, bad chunked attach"); + disconnect(); + return; + } + in_->consume(resp->size()); + consume_attach(); + if (active_attach_) return; + } + continue; + default: + delegate_->error( + this, "Unexpected server response: " + resp->status_message()); + disconnect(); + return; + } + } + } + + void send_hello() { + auto request = std::unique_ptr( + HttpRequestBuilder::create("GET", "/hello", "HTTP", http_version)); + request->add_header("X-TP-Monitor-Version", VERSION); + request->add_header("Content-Length", "0"); + auto data = request->build(); + send(data.data(), data.size()); + } + + void send_attach() { + auto request = std::unique_ptr( + HttpRequestBuilder::create("GET", "/attach", "HTTP", http_version)); + request->add_header("Content-Length", "0"); + auto data = request->build(); + send(data.data(), data.size()); + } + + void send_detach() { + auto request = std::unique_ptr( + HttpRequestBuilder::create("GET", "/detach", "HTTP", http_version)); + request->add_header("Content-Length", "0"); + auto data = request->build(); + send(data.data(), data.size()); + } + + void send(void const* data, size_t size) { + if (size == 0) return; + + if (out_->empty()) { + ssize_t ret = io::write(sock_.get(), data, size); + if (ret == -1) { + if (errno != EWOULDBLOCK && errno != EAGAIN) { + delegate_->error(this, "Write error"); + disconnect(); + return; + } + } else if (ret == 0) { + delegate_->error(this, "Write error"); + disconnect(); + return; + } else if (static_cast(ret) == size) { + return; + } + out_->write(reinterpret_cast(data) + ret, size - ret); + looper_->modify(sock_.get(), Looper::EVENT_READ | Looper::EVENT_WRITE); + } else { + out_->write(data, size); + } + } + + Looper* const looper_; + Resolver* const resolver_; + Delegate* const delegate_; + State state_; + void* resolv_; + io::auto_fd sock_; + std::unique_ptr in_; + std::unique_ptr out_; + bool sent_hello_; + std::unique_ptr active_attach_; + uint64_t content_skip_; + uint8_t package_[65535]; + size_t package_fill_; +}; + +} // namespace + +// static +Monitor* Monitor::create( + Looper* looper, Resolver* resolver, Delegate* delegate) { + return new MonitorImpl(looper, resolver, delegate); +} diff --git a/src/monitor.hh b/src/monitor.hh new file mode 100644 index 0000000..83652ee --- /dev/null +++ b/src/monitor.hh @@ -0,0 +1,68 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#ifndef MONITOR_HH +#define MONITOR_HH + +#include +#include +#include + +class Looper; +class Resolver; + +class Monitor { +public: + enum State { + // Default state, if a connection is broken this is the new state + DISCONNECTED = 0, + // After connect() is called but before the connection is truly setup + CONNECTING, + // Connect() succeeded, not attached + CONNECTED, + // Connected and attached. + ATTACHED, + }; + + struct Package { + uint32_t id; + struct timespec timestamp; + uint16_t flags; + std::string source_host; + uint16_t source_port; + std::string target_host; + uint16_t target_port; + }; + + class Delegate { + public: + virtual ~Delegate() {} + + virtual void state(Monitor* monitor, State state) = 0; + virtual void error(Monitor* monitor, std::string const& error) = 0; + virtual void package(Monitor* monitor, Package const& package) = 0; + virtual void package_data(Monitor* monitor, uint32_t id, + char const* data, size_t size, bool last) = 0; + + protected: + Delegate() {} + }; + + virtual ~Monitor() {} + + static Monitor* create( + Looper* looper, Resolver* resolver, Delegate* delegate); + + virtual void connect(std::string const& host, uint16_t port) = 0; + virtual void disconnect() = 0; + + virtual void attach() = 0; + virtual void detach() = 0; + + virtual State state() const = 0; + +protected: + Monitor() {} + Monitor(Monitor const&) = delete; +}; + +#endif // MONITOR_HH diff --git a/src/proxy.cc b/src/proxy.cc index dcc52df..5abe257 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -6,6 +6,7 @@ #define _GNU_SOURCE #endif +#include #include #include #include @@ -17,6 +18,7 @@ #include #include #include +#include #include "buffer.hh" #include "chunked.hh" @@ -229,6 +231,7 @@ struct RemoteClient : public BaseClient { Content content; std::string host; uint16_t port; + uint32_t pkg_id; }; struct Client : public BaseClient { @@ -242,9 +245,13 @@ struct Client : public BaseClient { RemoteState remote_state; void* resolve; RemoteClient remote; + std::string source_host; + uint16_t source_port; + uint32_t pkg_id; }; struct Monitor : public BaseClient { + bool got_hello; }; class ProxyImpl : public Proxy { @@ -254,7 +261,8 @@ public: : config_(config), cwd_(cwd), configfile_(configfile), logfile_(logfile), logger_(logger), accept_socket_(accept_fd), monitor_socket_(monitor_fd), looper_(Looper::create()), resolver_(Resolver::create(looper_.get())), - new_timeout_(nullptr), timeout_(nullptr) { + new_timeout_(nullptr), timeout_(nullptr), next_package_id_(1), + monitor_send_proxied_(false) { setup(); } ~ProxyImpl() override { @@ -307,11 +315,33 @@ private: void client_error(size_t index, uint16_t status_code, std::string const& status); bool client_request(size_t index); - bool client_send(size_t index, void const* ptr, size_t avail); void client_remote_error(size_t index, uint16_t error); void close_client_when_done(size_t index); void client_remote_resolved(size_t index, int fd, bool connected, char const* error); + void monitor_error(size_t index, + uint16_t status_code, std::string const& status); + bool support_monitor_version(size_t index, std::string const& version); + bool monitor_send_chunked(size_t index, + void const* header, size_t header_size, + void const* data, size_t data_size); + uint32_t get_next_package_id(); + void send_attached_package(uint32_t id, uint16_t flags, + std::string const& source_host, + uint16_t source_port, + std::string const& target_host, + uint16_t target_port, + bool last); + void send_attached_package2(uint8_t* buffer, size_t size, + uint32_t id, uint16_t flags, + std::string const& source_host, + uint16_t source_port, + std::string const& target_host, + uint16_t target_port, + bool last); + void send_attached_data(uint32_t id, void const* ptr, size_t size, bool last); + void send_attached(void const* header, size_t header_size, + void const* data, size_t data_size); Config* const config_; std::string cwd_; @@ -327,9 +357,12 @@ private: bool good_; void* new_timeout_; void* timeout_; + uint32_t next_package_id_; + bool monitor_send_proxied_; clients clients_; clients monitors_; + std::unordered_set attached_; }; size_t get_size(Config* config, Logger* logger, std::string const& name, @@ -349,6 +382,7 @@ size_t get_size(Config* config, Logger* logger, std::string const& name, } void ProxyImpl::setup() { + monitor_send_proxied_ = !config_->get("monitor_proxy_request", false); clients_.resize(get_size(config_, logger_, "max_clients", 1024)); monitors_.resize(get_size(config_, logger_, "max_monitors", 2)); looper_->add(accept_socket_.get(), @@ -380,6 +414,7 @@ bool ProxyImpl::reload_config() { logger_->out(Logger::WARN, "New config invalid, ignored."); return true; } + monitor_send_proxied_ = !config_->get("monitor_proxy_request", false); if (!logfile_) { auto logfile = config_->get("logfile", nullptr); if (logfile) { @@ -461,6 +496,16 @@ void ProxyImpl::close_base(BaseClient* client) { void ProxyImpl::close_client(size_t index) { bool was_full = clients_.full(); auto& client = clients_[index]; + if (client.pkg_id != 0) { + size_t avail; + auto ptr = client.in->read_ptr(&avail); + if (avail) { + send_attached_data(client.pkg_id, ptr, avail, true); + } else { + send_attached_data(client.pkg_id, nullptr, 0, true); + } + client.pkg_id = 0; + } client.request.reset(); client.url.reset(); client.connect.reset(); @@ -471,6 +516,16 @@ void ProxyImpl::close_client(size_t index) { resolver_->cancel(client.resolve); client.resolve = nullptr; } + if (client.remote.pkg_id) { + size_t avail; + auto ptr = client.remote.in->read_ptr(&avail); + if (avail) { + send_attached_data(client.remote.pkg_id, ptr, avail, true); + } else { + send_attached_data(client.remote.pkg_id, nullptr, 0, true); + } + client.remote.pkg_id = 0; + } close_base(&client.remote); close_base(&client); clients_.erase(index); @@ -482,6 +537,10 @@ void ProxyImpl::close_client(size_t index) { void ProxyImpl::close_monitor(size_t index) { bool was_full = monitors_.full(); auto& monitor = monitors_[index]; + auto it = attached_.find(index); + if (it != attached_.end()) { + attached_.erase(it); + } close_base(&monitor); monitors_.erase(index); if (was_full && !monitors_.full() && monitor_socket_) { @@ -530,6 +589,8 @@ float ProxyImpl::handle_timeout(bool new_conn, close.clear(); for (auto i = monitors_.begin(); i != monitors_.end(); ++i) { if (i->new_connection != new_conn) continue; + // Monitors are safe from timeout after hello + if (i->got_hello) continue; auto diff = std::chrono::duration_cast>( ((i->last + timeout) - now)).count(); if (diff < 0.0f) { @@ -728,21 +789,9 @@ void ProxyImpl::client_error(size_t index, uint16_t status_code, client.url.reset(); client.connect.reset(); - std::string proto; - Version version; - - if (!client.request) { - proto = "HTTP"; - version.major = 1; - version.minor = 1; - } else { - proto = client.request->proto(); - version = client.request->proto_version(); - } - auto resp = std::unique_ptr( HttpResponseBuilder::create( - proto, version, status_code, status_message)); + "HTTP", Version(1, 0), status_code, status_message)); resp->add_header("Content-Length", "0"); resp->add_header("Connection", "close"); @@ -781,6 +830,13 @@ bool ProxyImpl::client_request(size_t index) { host = client.url->host(); port = client.url->port(); } + if (port == 0) port = 80; + client.pkg_id = get_next_package_id(); + if (client.pkg_id != 0) { + send_attached_package(client.pkg_id, 0, + client.source_host, client.source_port, + host, port, false); + } if (client.remote_state == WAITING) { if (client.connect || host != client.remote.host || port != client.remote.port) { @@ -803,7 +859,6 @@ bool ProxyImpl::client_request(size_t index) { client.remote.host = host; client.remote.port = port; - if (port == 0) port = 80; client.resolve = resolver_->request( host, port, std::bind(&ProxyImpl::client_remote_resolved, this, index, @@ -909,7 +964,10 @@ void ProxyImpl::client_empty_input(size_t index) { // falltrough case CONTENT_NONE: { if (client.connect && client.remote_state == CONNECTED) { - if (!client_send(index, ptr, avail)) { + if (client.pkg_id != 0) { + send_attached_data(client.pkg_id, ptr, avail, false); + } + if (!base_send(&client.remote, ptr, avail, index, "Client remote")) { return; } client.in->consume(avail); @@ -919,6 +977,10 @@ void ProxyImpl::client_empty_input(size_t index) { // Still working on the last request, wait return; } + if (client.pkg_id != 0) { + send_attached_data(client.pkg_id, nullptr, 0, true); + client.pkg_id = 0; + } client.request.reset( HttpRequest::parse( reinterpret_cast(ptr), avail, false)); @@ -968,14 +1030,22 @@ void ProxyImpl::client_empty_input(size_t index) { return; } if (avail < client.content.len) { - if (!client_send(index, ptr, avail)) { + if (client.pkg_id != 0) { + send_attached_data(client.pkg_id, ptr, avail, false); + } + if (!base_send(&client.remote, ptr, avail, index, "Client remote")) { return; } client.in->consume(avail); client.content.len -= avail; return; } - if (!client_send(index, ptr, client.content.len)) { + if (client.pkg_id != 0) { + send_attached_data(client.pkg_id, ptr, client.content.len, true); + client.pkg_id = 0; + } + if (!base_send(&client.remote, ptr, client.content.len, index, + "Client remote")) { return; } client.in->consume(client.content.len); @@ -992,11 +1062,17 @@ void ProxyImpl::client_empty_input(size_t index) { client_error(index, 400, "Bad request"); return; } - if (!client_send(index, ptr, used)) { + if (client.pkg_id != 0) { + send_attached_data(client.pkg_id, ptr, used, false); + } + if (!base_send(&client.remote, ptr, used, index, "Client remote")) { return; } client.in->consume(used); if (client.content.chunked->eof()) { + if (client.pkg_id != 0) { + send_attached_data(client.pkg_id, nullptr, 0, true); + } client.content.chunked.reset(); client.content.type = CONTENT_NONE; } @@ -1017,6 +1093,10 @@ void ProxyImpl::close_client_when_done(size_t index) { client.read_flag = 0; client.request.reset(); client.content.type = CONTENT_NONE; + if (client.pkg_id != 0) { + send_attached_data(client.pkg_id, nullptr, 0, true); + client.pkg_id = 0; + } looper_->modify(client.fd.get(), client.write_flag); } @@ -1045,6 +1125,12 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) { client.in->consume(client.request->size()); client.request.reset(); client.remote.content.type = CONTENT_CLOSE; + client.remote.pkg_id = get_next_package_id(); + if (client.remote.pkg_id) { + send_attached_package(client.remote.pkg_id, 0, + client.remote.host, client.remote.port, + client.source_host, client.source_port, false); + } if (!base_send(&client, data.data(), data.size(), index, "Client")) { return; } @@ -1080,10 +1166,30 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) { req->add_header("host", client.url->host()); } auto data = req->build(); + if (client.pkg_id != 0) { + bool const last = client.content.type == CONTENT_NONE; + if (monitor_send_proxied_) { + send_attached_data(client.pkg_id, data.data(), data.size(), last); + } else { + auto ptr = client.in->read_ptr(nullptr); + send_attached_data(client.pkg_id, + ptr, client.request->size(), last); + } + if (last) { + client.pkg_id = 0; + } + } client.in->consume(client.request->size()); client.request.reset(); client.url.reset(); client.remote.out->write(data.data(), data.size()); + + client.remote.pkg_id = get_next_package_id(); + if (client.remote.pkg_id) { + send_attached_package(client.remote.pkg_id, 0, + client.remote.host, client.remote.port, + client.source_host, client.source_port, false); + } } client.remote_state = CONNECTED; client.remote.read_flag = Looper::EVENT_READ; @@ -1117,6 +1223,10 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) { client_remote_error(index, 502); break; case WAITING: + if (client.remote.pkg_id != 0) { + send_attached_data(client.remote.pkg_id, nullptr, 0, true); + client.remote.pkg_id = 0; + } close_base(&client.remote); client.remote_state = CLOSED; break; @@ -1159,6 +1269,16 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) { looper_->modify(client.remote.fd.get(), 0); } } + if (client.remote.pkg_id != 0) { + if (client.remote.content.type == CONTENT_NONE) { + send_attached_data(client.remote.pkg_id, + ptr, response->size(), true); + client.remote.pkg_id = 0; + } else { + send_attached_data(client.remote.pkg_id, + ptr, response->size(), false); + } + } if (!base_send(&client, ptr, response->size(), index, "Client")) { return; } @@ -1167,6 +1287,9 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) { } case CONTENT_LEN: if (avail < client.remote.content.len) { + if (client.remote.pkg_id != 0) { + send_attached_data(client.remote.pkg_id, ptr, avail, false); + } if (!base_send(&client, ptr, avail, index, "Client")) { return; } @@ -1174,6 +1297,11 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) { client.remote.content.len -= avail; return; } + if (client.remote.pkg_id != 0) { + send_attached_data(client.remote.pkg_id, ptr, + client.remote.content.len, true); + client.remote.pkg_id = 0; + } if (!base_send(&client, ptr, client.remote.content.len, index, "Client")) { return; @@ -1194,14 +1322,20 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) { client.remote.content.type = CONTENT_CLOSE; break; } + if (client.remote.pkg_id != 0) { + send_attached_data(client.remote.pkg_id, ptr, used, false); + } if (!base_send(&client, ptr, used, index, "Client")) { return; } client.remote.in->consume(used); if (client.remote.content.chunked->eof()) { + if (client.remote.pkg_id != 0) { + send_attached_data(client.remote.pkg_id, nullptr, 0, true); + client.remote.pkg_id = 0; + } client.remote.content.chunked.reset(); client.remote.content.type = CONTENT_NONE; - logger_->out(Logger::INFO, "%zu: chunked -> waiting", index); client.remote_state = WAITING; client.remote.read_flag = 0; client.remote.write_flag = 0; @@ -1210,6 +1344,9 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) { break; } case CONTENT_CLOSE: + if (client.remote.pkg_id != 0) { + send_attached_data(client.remote.pkg_id, ptr, avail, false); + } if (!base_send(&client, ptr, avail, index, "Client")) { return; } @@ -1228,12 +1365,51 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) { } } -bool ProxyImpl::client_send(size_t index, void const* ptr, size_t size) { - auto& client = clients_[index]; - assert(!client.request); - assert(client.remote_state >= CONNECTED); +void ProxyImpl::monitor_error(size_t index, uint16_t status_code, + std::string const& status_message) { + auto& monitor = monitors_[index]; + // No more input + monitor.read_flag = 0; + looper_->modify(monitor.fd.get(), monitor.write_flag); + + auto resp = std::unique_ptr( + HttpResponseBuilder::create( + "HTTP", Version(1, 0), status_code, status_message)); + resp->add_header("Content-Length", "0"); + resp->add_header("Connection", "close"); - return base_send(&client.remote, ptr, size, index, "Client remote"); + monitor.in.reset(); + + auto data = resp->build(); + + if (!base_send(&monitor, data.data(), data.size(), index, "Monitor")) { + close_monitor(index); + } +} + +bool ProxyImpl::support_monitor_version( + size_t index, std::string const& version) { + if (version.empty()) return false; + // TODO: Actually do some version check here + return version.compare(VERSION) == 0; +} + +bool ProxyImpl::monitor_send_chunked( + size_t index, void const* header, size_t header_size, + void const* data, size_t data_size) { + auto& monitor = monitors_[index]; + char chunked_header[10]; + auto len = snprintf(chunked_header, 10, "%zx\r\n", header_size + data_size); + if (!base_send(&monitor, chunked_header, len, index, "Monitor")) return false; + if (header_size > 0 && !base_send(&monitor, + header, header_size, index, "Monitor")) { + return false; + } + if (data_size > 0 + && !base_send(&monitor, data, data_size, index, "Monitor")) { + return false; + } + return base_send(&monitor, "\r\n", 2, index, "Monitor"); } void ProxyImpl::monitor_event(size_t index, int fd, uint8_t events) { @@ -1252,6 +1428,87 @@ void ProxyImpl::monitor_event(size_t index, int fd, uint8_t events) { close_monitor(index); return; } + + while (true) { + size_t avail; + auto ptr = monitor.in->read_ptr(&avail); + if (avail == 0) return; + auto request = std::unique_ptr( + HttpRequest::parse( + reinterpret_cast(ptr), avail, false)); + if (!request) { + if (avail >= 1024 * 1024) { + logger_->out(Logger::INFO, "%zu: Monitor too large request %zu", + index, avail); + close_monitor(index); + } + return; + } + if (!request->good() + || !request->proto_equal("HTTP") + || !request->method_equal("GET") + // Only support 1.1 or above + || request->proto_version().major == 0 + || (request->proto_version().major == 1 + && request->proto_version().minor == 0)) { + monitor_error(index, 400, "Bad request"); + return; + } + Content content; + content.type = CONTENT_NONE; + if (!setup_content(request.get(), &content)) { + monitor_error(index, 400, "Bad request"); + return; + } + if (content.type != CONTENT_NONE) { + monitor_error(index, 400, "Bad request"); + return; + } + std::unique_ptr resp; + if (request->url_equal("/hello")) { + auto version = request->first_header("x-tp-monitor-version"); + if (support_monitor_version(index, version)) { + resp.reset(HttpResponseBuilder::create( + "HTTP", Version(1, 1), 200, "OK")); + monitor.got_hello = true; + } else { + resp.reset(HttpResponseBuilder::create( + "HTTP", Version(1, 1), 500, "Unsupported version")); + } + resp->add_header("X-TP-Version", VERSION); + resp->add_header("Content-Length", "0"); + } else if (!monitor.got_hello) { + monitor_error(index, 500, "Unexpected request"); + return; + } else if (request->url_equal("/attach")) { + attached_.insert(index); + resp.reset(HttpResponseBuilder::create( + "HTTP", Version(1, 1), 200, "OK")); + resp->add_header("Transfer-Encoding", "chunked"); + } else if (request->url_equal("/detach")) { + auto it = attached_.find(index); + if (it != attached_.end()) { + if (!monitor_send_chunked(index, nullptr, 0, nullptr, 0)) { + close_monitor(index); + return; + } + attached_.erase(it); + } + resp.reset(HttpResponseBuilder::create( + "HTTP", Version(1, 1), 200, "OK")); + resp->add_header("Content-Length", "0"); + } else { + resp.reset(HttpResponseBuilder::create( + "HTTP", Version(1, 1), 404, "Not found")); + resp->add_header("Content-Length", "0"); + } + auto data = resp->build(); + monitor.in->consume(request->size()); + if (!base_send(&monitor, data.data(), data.size(), index, "Monitor")) { + close_monitor(index); + return; + } + } } void ProxyImpl::new_base(BaseClient* client, int fd) { @@ -1286,12 +1543,20 @@ int my_accept4(int sockfd, struct sockaddr* addr, socklen_t* addrlen, } #endif // HAVE_ACCEPT4 +union big_addr { + struct sockaddr_in addr_in; + struct sockaddr_in6 addr_in6; +}; + void ProxyImpl::new_client(int fd, uint8_t events) { assert(fd == accept_socket_.get()); if (events == Looper::EVENT_READ) { assert(!clients_.full()); while (true) { - int ret = accept4(fd, nullptr, nullptr, SOCK_NONBLOCK); + big_addr addr; + socklen_t len = sizeof(addr); + auto a = reinterpret_cast(&addr); + int ret = accept4(fd, a, &len, SOCK_NONBLOCK); if (ret < 0) { if (errno == EAGAIN || errno == EWOULDBLOCK) return; if (errno == EINTR) continue; @@ -1300,9 +1565,28 @@ void ProxyImpl::new_client(int fd, uint8_t events) { } auto index = clients_.new_client(); new_base(&clients_[index], ret); + clients_[index].source_host.clear(); + if (len == sizeof(struct sockaddr_in) && a->sa_family == AF_INET) { + char tmp[INET_ADDRSTRLEN]; + if (inet_ntop(AF_INET, &addr.addr_in.sin_addr, tmp, sizeof(tmp))) { + clients_[index].source_host = tmp; + } + clients_[index].source_port = ntohs(addr.addr_in.sin_port); + } else if (len == sizeof(struct sockaddr_in6) + && a->sa_family == AF_INET6) { + char tmp[INET6_ADDRSTRLEN]; + if (inet_ntop(AF_INET6, &addr.addr_in6.sin6_addr, tmp, sizeof(tmp))) { + clients_[index].source_host = tmp; + } + clients_[index].source_port = ntohs(addr.addr_in6.sin6_port); + } else { + clients_[index].source_port = 0; + } clients_[index].content.type = CONTENT_NONE; clients_[index].remote_state = CLOSED; clients_[index].remote.content.type = CONTENT_NONE; + clients_[index].pkg_id = 0; + clients_[index].remote.pkg_id = 0; looper_->add(ret, clients_[index].read_flag | clients_[index].write_flag, std::bind(&ProxyImpl::client_event, this, index, @@ -1331,16 +1615,16 @@ void ProxyImpl::new_monitor(int fd, uint8_t events) { } auto index = monitors_.new_client(); new_base(&monitors_[index], ret); - looper_->add(ret, clients_[index].read_flag | clients_[index].write_flag, + monitors_[index].got_hello = false; + looper_->add(ret, monitors_[index].read_flag + | monitors_[index].write_flag, std::bind(&ProxyImpl::monitor_event, this, index, std::placeholders::_1, std::placeholders::_2)); break; } - if (monitors_.full()) { - looper_->modify(fd, 0); - } + if (monitors_.full()) looper_->modify(fd, 0); } else { logger_->out(Logger::ERR, "Monitor socket died"); fatal_error(); @@ -1383,6 +1667,131 @@ bool ProxyImpl::run() { return good_; } +uint32_t ProxyImpl::get_next_package_id() { + if (attached_.empty()) return 0; + while (!next_package_id_) { + ++next_package_id_; + } + return next_package_id_++; +} + +void write_u16(uint8_t* dst, uint16_t value) { + dst[0] = value >> 8; + dst[1] = value & 0xff; +} + +void write_u32(uint8_t* dst, uint32_t value) { + write_u16(dst, value >> 16); + write_u16(dst + 2, value & 0xffff); +} + +void write_u64(uint8_t* dst, uint64_t value) { + write_u32(dst, value >> 32); + write_u32(dst + 4, value & 0xffffffff); +} + +void ProxyImpl::send_attached_package(uint32_t id, uint16_t flags, + std::string const& source_host, + uint16_t source_port, + std::string const& target_host, + uint16_t target_port, + bool last) { + if (id == 0) { + assert(false); + return; + } + if (attached_.empty()) return; + uint8_t data[256]; + size_t need = 2 + 3 + 4 + 8 + 4 + 2 + 2 + 2 + 2 + 2 + source_host.size() + + target_host.size(); + if (need <= sizeof(data)) { + send_attached_package2(data, need, id, flags, source_host, source_port, + target_host, target_port, last); + } else { + // TODO: Might need better handling of really long source_host/target_host + auto p = std::unique_ptr(new uint8_t[need]); + send_attached_package2(p.get(), need, id, flags, source_host, source_port, + target_host, target_port, last); + } +} + +void ProxyImpl::send_attached_package2(uint8_t* buffer, size_t size, + uint32_t id, uint16_t flags, + std::string const& source_host, + uint16_t source_port, + std::string const& target_host, + uint16_t target_port, + bool last) { + buffer[0] = 'P'; + buffer[1] = 'K'; + buffer[2] = 'G'; + write_u16(buffer + 3, size); + write_u32(buffer + 5, id); + auto dur = looper_->now().time_since_epoch(); + auto sec = std::chrono::duration_cast(dur); + auto nsec = std::chrono::duration_cast(dur - sec); + write_u64(buffer + 9, sec.count()); + write_u32(buffer + 17, nsec.count()); + write_u16(buffer + 21, (last ? 0 : 1) | (flags << 1)); + write_u16(buffer + 23, source_port); + write_u16(buffer + 25, target_port); + write_u16(buffer + 27, source_host.size()); + memcpy(buffer + 29, source_host.data(), source_host.size()); + write_u16(buffer + 29 + source_host.size(), target_host.size()); + memcpy(buffer + 31 + source_host.size(), + target_host.data(), target_host.size()); + send_attached(buffer, size, nullptr, 0); +} + +void ProxyImpl::send_attached_data(uint32_t id, void const* ptr, size_t size, + bool last) { + if (id == 0) { + assert(false); + return; + } + if (attached_.empty()) return; + if (size == 0 && !last) return; + uint8_t data[10]; + data[0] = 'D'; + data[1] = 'A'; + data[2] = 'T'; + write_u32(data + 5, id); + if (size == 0) { + assert(last); + assert(ptr == nullptr); + write_u16(data + 3, 10); + data[9] = 0; + send_attached(data, 10, nullptr, 0); + } else { + size_t max = 0xffff - 10; + auto p = reinterpret_cast(ptr); + data[9] = 1; + while (size > max) { + write_u16(data + 3, 0xffff); + send_attached(data, 10, p, max); + p += max; + size -= max; + } + data[9] = last ? 0 : 1; + write_u16(data + 3, size + 10); + send_attached(data, 10, p, size); + } +} + +void ProxyImpl::send_attached(void const* header, size_t header_size, + void const* data, size_t data_size) { + auto it = attached_.begin(); + while (it != attached_.end()) { + if (monitor_send_chunked(*it, header, header_size, data, data_size)) { + ++it; + } else { + auto index = *it; + it = attached_.erase(it); + close_monitor(index); + } + } +} + int setup_socket(char const* host, std::string const& port, Logger* logger) { io::auto_fd ret; struct addrinfo hints; -- cgit v1.2.3-70-g09d2