summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJoel Klinghed <the_jk@yahoo.com>2017-03-16 23:28:09 +0100
committerJoel Klinghed <the_jk@yahoo.com>2017-03-16 23:38:19 +0100
commit87774d8981ae7a079492d8949e205065ba72a8e4 (patch)
treef056ffbdfb436143db1d968ffc7c82b1cb3d79a3 /src
parent719d90a40e83e870be19f8d46cc55caed618aa35 (diff)
Add basic console monitor and implement monitor support
Diffstat (limited to 'src')
-rw-r--r--src/.gitignore2
-rw-r--r--src/Makefile.am11
-rw-r--r--src/buffer.cc11
-rw-r--r--src/buffer.hh2
-rw-r--r--src/chunked.cc44
-rw-r--r--src/chunked.hh4
-rw-r--r--src/http.cc6
-rw-r--r--src/http.hh1
-rw-r--r--src/ios_save.hh31
-rw-r--r--src/monitor-cmd.cc271
-rw-r--r--src/monitor.cc476
-rw-r--r--src/monitor.hh68
-rw-r--r--src/proxy.cc471
13 files changed, 1362 insertions, 36 deletions
diff --git a/src/.gitignore b/src/.gitignore
index 7066278..e4fbae2 100644
--- a/src/.gitignore
+++ b/src/.gitignore
@@ -1,4 +1,6 @@
/config.h
/config.h.in~
+/libmonitor.a
/libtp.a
/tp
+/tp-monitor
diff --git a/src/Makefile.am b/src/Makefile.am
index 502c82d..84c4401 100644
--- a/src/Makefile.am
+++ b/src/Makefile.am
@@ -5,8 +5,8 @@ AM_CXXFLAGS = @DEFINES@
# Remove ar: `u' modifier ignored since `D' is the default (see `U')
ARFLAGS = cr
-bin_PROGRAMS = tp
-noinst_LIBRARIES = libtp.a
+bin_PROGRAMS = tp tp-monitor
+noinst_LIBRARIES = libtp.a libmonitor.a
tp_SOURCES = main.cc proxy.cc logger.cc resolver.cc
tp_LDADD = libtp.a @THREAD_LIBS@
@@ -16,3 +16,10 @@ libtp_a_SOURCES = args.cc xdg.cc terminal.cc http.cc url.cc paths.cc \
character.cc config.cc strings.cc io.cc looper.cc \
buffer.cc chunked.cc
libtp_a_CXXFLAGS = $(AM_CXXFLAGS) -DSYSCONFDIR='"@SYSCONFDIR@"'
+
+libmonitor_a_SOURCES = monitor.cc resolver.cc
+libmonitor_CXXFLAGS = $(AM_CXXFLAGS) -DVERSION='"@VERSION@"' @THREAD_CFLAGS@
+
+tp_monitor_SOURCES = monitor-cmd.cc
+tp_monitor_LDADD = libmonitor.a libtp.a @THREAD_LIBS@
+tp_monitor_CXXFLAGS = $(AM_CXXFLAGS) -DVERSION='"@VERSION@"' @THREAD_CFLAGS@
diff --git a/src/buffer.cc b/src/buffer.cc
index d0c5fbb..4bde195 100644
--- a/src/buffer.cc
+++ b/src/buffer.cc
@@ -12,7 +12,8 @@ namespace {
class BufferImpl : public Buffer {
public:
BufferImpl(size_t capacity, size_t min_avail)
- : capacity_(capacity), min_avail_(min_avail) {
+ : capacity_(capacity), min_avail_(
+ std::max(min_avail, static_cast<size_t>(1))) {
data_ = capacity_ > 0 ?
reinterpret_cast<char*>(malloc(capacity_)) : nullptr;
end_ = data_;
@@ -117,3 +118,11 @@ void Buffer::write(void const* data, size_t size) {
}
}
+void Buffer::clear() {
+ while (true) {
+ size_t avail;
+ read_ptr(&avail);
+ if (avail == 0) return;
+ consume(avail);
+ }
+}
diff --git a/src/buffer.hh b/src/buffer.hh
index 92a7566..a648924 100644
--- a/src/buffer.hh
+++ b/src/buffer.hh
@@ -22,6 +22,8 @@ public:
size_t read(void* data, size_t max);
void write(void const* data, size_t size);
+ void clear();
+
protected:
Buffer() {}
Buffer(Buffer const&) = delete;
diff --git a/src/chunked.cc b/src/chunked.cc
index 99aea0d..dfafd89 100644
--- a/src/chunked.cc
+++ b/src/chunked.cc
@@ -13,6 +13,7 @@ namespace {
enum State {
CHUNK,
IN_CHUNK,
+ END_CHUNK,
TRAILER,
DONE,
};
@@ -51,13 +52,29 @@ public:
}
break;
}
- case IN_CHUNK:
- if (static_cast<uint64_t>(end - d) < size_) {
+ case IN_CHUNK: {
+ uint64_t left = end - d;
+ if (left < size_) {
+ this->data(d, left);
+ size_ -= left;
return avail;
}
+ this->data(d, size_);
d += size_;
+ state_ = END_CHUNK;
+ break;
+ }
+ case END_CHUNK: {
+ auto p = find_crlf(d, end);
+ if (!p) return d - start;
+ if (p != d) {
+ good_ = false;
+ return d - start;
+ }
+ d = p + 2;
state_ = CHUNK;
break;
+ }
case TRAILER: {
auto p = find_crlf(d, end);
if (!p) return d - start;
@@ -81,6 +98,10 @@ public:
return state_ == DONE;
}
+protected:
+ virtual void data(void const* UNUSED(data), size_t UNUSED(size)) {
+ }
+
private:
char const* find_crlf(char const* start, char const* end) {
for (; start != end; ++start) {
@@ -97,6 +118,20 @@ private:
uint64_t size_;
};
+class ChunkedCallbackImpl : public ChunkedImpl {
+public:
+ ChunkedCallbackImpl(DataCallback const& callback)
+ : callback_(callback) {
+ }
+
+ void data(void const* data, size_t size) override {
+ callback_(data, size);
+ }
+
+private:
+ DataCallback const callback_;
+};
+
} // namespace
// static
@@ -104,3 +139,8 @@ Chunked* Chunked::create() {
return new ChunkedImpl();
}
+// static
+Chunked* Chunked::create(DataCallback const& callback) {
+ return new ChunkedCallbackImpl(callback);
+}
+
diff --git a/src/chunked.hh b/src/chunked.hh
index 66d3ae7..511ae55 100644
--- a/src/chunked.hh
+++ b/src/chunked.hh
@@ -4,12 +4,16 @@
#define CHUNKED_HH
#include <cstddef>
+#include <functional>
class Chunked {
public:
virtual ~Chunked() { }
+ typedef std::function<void(void const* data, size_t avail)> DataCallback;
+
static Chunked* create();
+ static Chunked* create(DataCallback const& callback);
virtual size_t add(void const* data, size_t avail) = 0;
virtual bool good() const = 0;
diff --git a/src/http.cc b/src/http.cc
index c043c87..26911cb 100644
--- a/src/http.cc
+++ b/src/http.cc
@@ -377,10 +377,16 @@ public:
bool method_equal(std::string const& method) const override {
return method.compare(0, method.size(), data_, method_end_) == 0;
}
+
std::string url() const override {
return make_string(data_, url_start_, url_end_);
}
+ bool url_equal(std::string const& url) const override {
+ return url.compare(0, url.size(),
+ data_ + url_start_, url_end_ - url_start_) == 0;
+ }
+
ParseResult parse() {
good_ = false;
proto_end_ = find_newline(0, &content_start_);
diff --git a/src/http.hh b/src/http.hh
index 091d3d4..9a084d2 100644
--- a/src/http.hh
+++ b/src/http.hh
@@ -92,6 +92,7 @@ public:
virtual std::string method() const = 0;
virtual bool method_equal(std::string const& method) const = 0;
virtual std::string url() const = 0;
+ virtual bool url_equal(std::string const& url) const = 0;
protected:
HttpRequest() {}
diff --git a/src/ios_save.hh b/src/ios_save.hh
new file mode 100644
index 0000000..1abcf53
--- /dev/null
+++ b/src/ios_save.hh
@@ -0,0 +1,31 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#ifndef IOS_SAVE_HH
+#define IOS_SAVE_HH
+
+#include <ostream>
+
+class ios_save {
+public:
+ ios_save(std::ostream& ostream)
+ : ostream_(ostream), save_(nullptr) {
+ save_.copyfmt(ostream);
+ }
+
+ ios_save(ios_save const& save)
+ : ostream_(save.ostream_), save_(nullptr) {
+ save_.copyfmt(save.save_);
+ }
+
+ ~ios_save() {
+ ostream_.copyfmt(save_);
+ }
+
+ ios_save& operator=(ios_save const&) = delete;
+
+private:
+ std::ostream& ostream_;
+ std::ios save_;
+};
+
+#endif // IOS_SAVE_HH
diff --git a/src/monitor-cmd.cc b/src/monitor-cmd.cc
new file mode 100644
index 0000000..5221db0
--- /dev/null
+++ b/src/monitor-cmd.cc
@@ -0,0 +1,271 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#include "common.hh"
+
+#include <cstring>
+#include <fstream>
+#include <iomanip>
+#include <iostream>
+#include <signal.h>
+#include <unordered_map>
+
+#include "args.hh"
+#include "buffer.hh"
+#include "io.hh"
+#include "ios_save.hh"
+#include "looper.hh"
+#include "resolver.hh"
+#include "monitor.hh"
+
+namespace {
+
+class Delegate : public Monitor::Delegate {
+public:
+ Delegate(std::ostream& out, bool interleave, Looper* looper)
+ : out_(out), interleave_(interleave), looper_(looper), attached_(false) {
+ }
+
+ void state(Monitor* monitor, Monitor::State state) override {
+ switch (state) {
+ case Monitor::DISCONNECTED:
+ std::cout << "# Disconnected" << std::endl;
+ looper_->quit();
+ break;
+ case Monitor::CONNECTING:
+ break;
+ case Monitor::CONNECTED:
+ if (attached_) {
+ std::cout << "# Detached" << std::endl;
+ attached_ = false;
+ } else {
+ std::cout << "# Connected" << std::endl;
+ monitor->attach();
+ }
+ break;
+ case Monitor::ATTACHED:
+ std::cout << "# Attached" << std::endl;
+ attached_ = true;
+ break;
+ }
+ }
+
+ void error(Monitor* UNUSED(monitor), std::string const& error) override {
+ std::cerr << "# Error: " << error << std::endl;
+ }
+
+ void package(
+ Monitor* UNUSED(monitor), Monitor::Package const& package) override {
+ packages_.insert(std::make_pair(package.id, package));
+ }
+
+ void package_data(Monitor* UNUSED(monitor), uint32_t id,
+ char const* data, size_t size, bool last) override {
+ auto it = packages_.find(id);
+ if (it == packages_.end()) {
+ assert(false);
+ return;
+ }
+ if (interleave_) {
+ print_package(it->second, last, data, size);
+ if (last) packages_.erase(it);
+ return;
+ }
+ auto buf = data_.find(id);
+ if (last) {
+ if (buf == data_.end()) {
+ print_package(it->second, true, data, size);
+ } else {
+ buf->second->write(data, size);
+ size_t avail;
+ auto ptr = buf->second->read_ptr(&avail);
+ print_package(it->second, true,
+ reinterpret_cast<char const*>(ptr), avail);
+ data_.erase(buf);
+ }
+ packages_.erase(it);
+ } else {
+ if (buf == data_.end()) {
+ buf = data_.insert(std::make_pair(id, Buffer::create(8192, 0))).first;
+ }
+ buf->second->write(data, size);
+ }
+ }
+
+private:
+ void print_package(
+ Monitor::Package& pkg, bool last, char const* data, size_t size) {
+ if (size == 0 && !last) return;
+ {
+ ios_save save(out_);
+ out_ << "*** " << pkg.timestamp.tv_sec << '.'
+ << std::setfill('0') << std::setw(9) << pkg.timestamp.tv_nsec
+ << '\n';
+ }
+ out_ << "* Source: " << pkg.source_host << ':' << pkg.source_port << '\n'
+ << "* Target: " << pkg.target_host << ':' << pkg.target_port << '\n';
+ if (interleave_) {
+ auto offset = offset_.find(pkg.id);
+ if (offset != offset_.end()) {
+ out_ << "* Bytes: " << offset->second << "-";
+ if (last) {
+ out_ << (offset->second + size);
+ offset_.erase(offset);
+ } else {
+ offset->second += size;
+ }
+ out_ << '\n';
+ } else if (!last) {
+ offset_.insert(std::make_pair(pkg.id, size));
+ out_ << "* Bytes: 0-";
+ }
+ }
+ out_ << "* Size: " << size << '\n';
+ if (size > 0) {
+ ios_save save(out_);
+ auto d = reinterpret_cast<uint8_t const*>(data);
+ out_.flags(std::ios::hex);
+ out_.fill('0');
+ for (size_t i = 0; i < size; i += 16) {
+ out_ << std::setw(8) << i << ' ';
+ unsigned j = 0;
+ for (; j < 8; ++j) {
+ auto k = i + j;
+ if (k >= size) break;
+ out_ << ' ' << std::setw(2) << static_cast<int>(d[k]);
+ }
+ for (; j < 8; ++j) {
+ out_ << " ";
+ }
+ out_ << ' ';
+ for (; j < 16; ++j) {
+ auto k = i + j;
+ if (k >= size) break;
+ out_ << ' ' << std::setw(2) << static_cast<int>(d[k]);
+ }
+ for (; j < 16; ++j) {
+ out_ << " ";
+ }
+ out_ << " |";
+ j = 0;
+ for (; j < 16; ++j) {
+ auto k = i + j;
+ if (k >= size) break;
+ out_ << printable(data[k]);
+ }
+ for (; j < 16; ++j) {
+ out_ << ' ';
+ }
+ out_ << "|\n";
+ }
+ }
+ out_ << std::endl;
+ }
+
+ static char printable(char c) {
+ return (c & 0x80 || c < ' ' || c >= 0x7f) ? '.' : c;
+ }
+
+ std::ostream& out_;
+ bool interleave_;
+ Looper* looper_;
+ bool attached_;
+ std::unordered_map<uint32_t, Monitor::Package> packages_;
+ // Used when interleaving
+ std::unordered_map<uint32_t, uint64_t> offset_;
+ // Used when not interleaving
+ std::unordered_map<uint32_t, std::unique_ptr<Buffer>> data_;
+};
+
+io::auto_pipe signal_pipe;
+
+void signal(int UNUSED(signum)) {
+ io::write_all(signal_pipe.write(), "", 1);
+ std::cerr << "# Caught signal" << std::endl;
+}
+
+void quit_loop(Looper* looper, int UNUSED(fd), uint8_t UNUSED(events)) {
+ looper->quit();
+}
+
+bool run(std::ostream& out, bool interleave,
+ std::string const& host, uint16_t port) {
+ std::unique_ptr<Looper> looper(Looper::create());
+ std::unique_ptr<Resolver> resolver(Resolver::create(looper.get()));
+ std::unique_ptr<Delegate> delegate(
+ new Delegate(out, interleave, looper.get()));
+ std::unique_ptr<Monitor> monitor(
+ Monitor::create(looper.get(), resolver.get(), delegate.get()));
+ std::cout << "# Connecting to " << host << ':' << port << std::endl;
+ monitor->connect(host, port);
+ if (signal_pipe.open()) {
+ struct sigaction action;
+ memset(&action, 0, sizeof(action));
+ action.sa_handler = signal;
+ sigaction(SIGINT, &action, nullptr);
+ sigaction(SIGTERM, &action, nullptr);
+ looper->add(signal_pipe.read(), Looper::EVENT_READ,
+ std::bind(&quit_loop, looper.get(), std::placeholders::_1,
+ std::placeholders::_2));
+ }
+ return looper->run();
+}
+
+} // namespace
+
+int main(int argc, char** argv) {
+ std::unique_ptr<Args> args(Args::create());
+ args->add('i', "interleave",
+ "unless set packages are not output until they are complete");
+ args->add('o', "output", "FILE", "output packages to FILE instead of stdout");
+ args->add('h', "help", "display this text and exit");
+ args->add('V', "version", "display version and exit");
+ if (!args->run(argc, argv)) {
+ std::cerr << "Try `tp-monitor --help` for usage." << std::endl;
+ return EXIT_FAILURE;
+ }
+ if (args->is_set('h')) {
+ std::cout << "Usage: `tp-monitor [OPTIONS...] HOST[:PORT]`\n"
+ << "Connect to transparent proxy running on HOST[:PORT] and "
+ << "print logged packages\n"
+ << std::endl;
+ args->print_help();
+ return EXIT_SUCCESS;
+ }
+ if (args->is_set('V')) {
+ std::cout << "TransparentProxy monitor version " VERSION
+ << " written by Joel Klinghed <the_jk@yahoo.com>" << std::endl;
+ return EXIT_SUCCESS;
+ }
+ if (args->arguments().size() != 1) {
+ std::cerr << "Unexpected number of arguments, expects one argument."
+ << std::endl;
+ return EXIT_FAILURE;
+ }
+ std::string host = args->arguments().front();
+ uint16_t port;
+ auto i = host.rfind(':');
+ if (i != std::string::npos && i > 0 && host[i - 1] != ':') {
+ errno = 0;
+ char* end = nullptr;
+ auto tmp = strtoul(host.c_str() + i + 1, &end, 10);
+ if (errno || !end || *end) {
+ std::cerr << "Invalid port number: " << host.substr(i + 1) << std::endl;
+ return EXIT_FAILURE;
+ }
+ port = tmp & 0xffff;
+ } else {
+ port = 9000;
+ }
+ auto interleave = args->is_set('i');
+ auto output = args->arg('o', nullptr);
+ if (output) {
+ std::ofstream out(output, std::ofstream::trunc);
+ if (!out.good()) {
+ std::cerr << "Unable to open " << output << " for writing." << std::endl;
+ return EXIT_FAILURE;
+ }
+ return run(out, interleave, host, port);
+ } else {
+ return run(std::cout, interleave, host, port);
+ }
+}
diff --git a/src/monitor.cc b/src/monitor.cc
new file mode 100644
index 0000000..673ad14
--- /dev/null
+++ b/src/monitor.cc
@@ -0,0 +1,476 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#include "common.hh"
+
+#include <cstring>
+
+#include "buffer.hh"
+#include "chunked.hh"
+#include "http.hh"
+#include "io.hh"
+#include "looper.hh"
+#include "monitor.hh"
+#include "resolver.hh"
+
+namespace {
+
+Version const http_version = { 1, 1 };
+
+class MonitorImpl : public Monitor {
+public:
+ MonitorImpl(Looper* looper, Resolver* resolver, Delegate* delegate)
+ : looper_(looper), resolver_(resolver), delegate_(delegate),
+ state_(DISCONNECTED) {
+ }
+
+ ~MonitorImpl() {
+ do_disconnect(true);
+ }
+
+ void connect(std::string const& host, uint16_t port) override {
+ do_disconnect();
+ new_state(CONNECTING);
+ if (state_ == CONNECTING) {
+ resolv_ = resolver_->request(
+ host, port, std::bind(&MonitorImpl::resolved, this,
+ std::placeholders::_1,
+ std::placeholders::_2,
+ std::placeholders::_3));
+ }
+ }
+
+ void disconnect() override {
+ do_disconnect();
+ new_state(DISCONNECTED);
+ }
+
+ void attach() override {
+ if (state_ != CONNECTED) return;
+ new_state(ATTACHED);
+ send_attach();
+ }
+
+ void detach() override {
+ if (state_ != ATTACHED) return;
+ new_state(CONNECTED);
+ send_detach();
+ }
+
+ State state() const override {
+ return state_;
+ }
+
+private:
+ void do_disconnect(bool skip_detach = false) {
+ switch (state_) {
+ case ATTACHED:
+ if (!skip_detach) detach();
+ // Fallthrough
+ case CONNECTED:
+ looper_->remove(sock_.get());
+ sock_.reset();
+ break;
+ case CONNECTING:
+ if (sock_) {
+ looper_->remove(sock_.get());
+ sock_.reset();
+ }
+ if (resolv_) {
+ resolver_->cancel(resolv_);
+ resolv_ = nullptr;
+ }
+ break;
+ case DISCONNECTED:
+ break;
+ }
+ }
+
+ void new_state(State state) {
+ if (state_ == state) return;
+ state_ = state;
+ delegate_->state(this, state_);
+ }
+
+ void resolved(int fd, bool connected, char const* error) {
+ resolv_ = nullptr;
+ if (fd < 0) {
+ delegate_->error(this, error);
+ new_state(DISCONNECTED);
+ return;
+ }
+ sock_.reset(fd);
+ if (!in_) {
+ in_.reset(Buffer::create(65536, 4096));
+ } else {
+ in_->clear();
+ }
+ if (!out_) {
+ out_.reset(Buffer::create(8192, 1024));
+ } else {
+ out_->clear();
+ }
+ looper_->add(sock_.get(), Looper::EVENT_WRITE,
+ std::bind(&MonitorImpl::event, this, std::placeholders::_1,
+ std::placeholders::_2));
+ sent_hello_ = false;
+ if (connected) {
+ event(sock_.get(), Looper::EVENT_WRITE);
+ }
+ }
+
+ bool parse_content_length(HttpResponse const* resp, uint64_t* bytes) {
+ std::string len = resp->first_header("content-length");
+ char* end = nullptr;
+ errno = 0;
+ auto tmp = strtoull(len.c_str(), &end, 10);
+ if (errno || !end || *end) {
+ return false;
+ }
+ if (bytes) *bytes = tmp;
+ return true;
+ }
+
+ bool setup_attach(HttpResponse const* resp) {
+ assert(!active_attach_);
+ std::string te = resp->first_header("transfer-encoding");
+ if (te != "chunked") return false;
+ package_fill_ = 0;
+ active_attach_.reset(
+ Chunked::create(
+ std::bind(&MonitorImpl::package, this, std::placeholders::_1,
+ std::placeholders::_2)));
+ return true;
+ }
+
+ static uint64_t read_u64(uint8_t const* data) {
+ return static_cast<uint64_t>(read_u32(data)) << 32 | read_u32(data + 4);
+ }
+
+ static uint32_t read_u32(uint8_t const* data) {
+ return static_cast<uint32_t>(read_u16(data)) << 16 | read_u16(data + 2);
+ }
+
+ static uint16_t read_u16(uint8_t const* data) {
+ return data[0] << 8 | data[1];
+ }
+
+ void package(void const* data, size_t size) {
+ auto d = reinterpret_cast<char const*>(data);
+ auto const end = d + size;
+ while (d < end) {
+ size_t avail = sizeof(package_) - package_fill_;
+ if (avail == 0) break;
+ if (d + avail > end) avail = end - d;
+ memcpy(package_ + package_fill_, d, avail);
+ package_fill_ += avail;
+ d += avail;
+ size_t offset = 0;
+ while (offset + 5 < package_fill_) {
+ uint16_t size = read_u16(package_ + offset + 3);
+ if (offset + size > package_fill_) break;
+ size_t o = 5;
+ if (size >= 29 && memcmp(package_ + offset, "PKG", 3) == 0) {
+ Package pkg;
+ pkg.id = read_u32(package_ + offset + o);
+ o += 4;
+ pkg.timestamp.tv_sec = read_u64(package_ + offset + o);
+ o += 8;
+ pkg.timestamp.tv_nsec = read_u32(package_ + offset + o);
+ o += 4;
+ pkg.flags = read_u16(package_ + offset + o);
+ o += 2;
+ pkg.source_port = read_u16(package_ + offset + o);
+ o += 2;
+ pkg.target_port = read_u16(package_ + offset + o);
+ o += 2;
+ auto len = read_u16(package_ + offset + o);
+ o += 2;
+ if (o + len + 2 <= size) {
+ pkg.source_host.assign(
+ reinterpret_cast<char*>(package_) + offset + o, len);
+ o += len;
+ len = read_u16(package_ + offset + o);
+ o += 2;
+ if (o + len <= size) {
+ pkg.target_host.assign(
+ reinterpret_cast<char*>(package_) + offset + o, len);
+ o += len;
+ bool last = !(pkg.flags & 0x01);
+ pkg.flags >>= 1;
+ delegate_->package(this, pkg);
+ if (o < size || last) {
+ delegate_->package_data(
+ this, pkg.id,
+ reinterpret_cast<char*>(package_) + offset + o, size - o,
+ last);
+ }
+ }
+ }
+ } else if (size >= 10 && memcmp(package_ + offset, "DAT", 3) == 0) {
+ uint32_t id = read_u32(package_ + offset + o);
+ o += 4;
+ uint8_t flags = package_[offset + o];
+ ++o;
+ delegate_->package_data(
+ this, id,
+ reinterpret_cast<char*>(package_) + offset + o, size - o,
+ !(flags & 0x01));
+ }
+ offset += size;
+ }
+ if (offset > 0) {
+ package_fill_ -= offset;
+ memmove(package_, package_ + offset, package_fill_);
+ }
+ }
+ }
+
+ void consume_attach() {
+ assert(active_attach_);
+ while (true) {
+ size_t avail;
+ auto ptr = in_->read_ptr(&avail);
+ if (avail == 0) return;
+ auto used = active_attach_->add(ptr, avail);
+ if (!active_attach_->good()) {
+ delegate_->error(this, "Bad chunked data");
+ disconnect();
+ return;
+ }
+ if (used == 0) return;
+ in_->consume(used);
+ if (active_attach_->eof()) {
+ active_attach_.reset();
+ new_state(CONNECTED);
+ return;
+ }
+ }
+ }
+
+ void event(int fd, uint8_t events) {
+ if (fd != sock_.get() || events == 0) {
+ assert(false);
+ return;
+ }
+
+ if (events & (Looper::EVENT_HUP | Looper::EVENT_ERROR)) {
+ if (state_ == CONNECTING && !sent_hello_) {
+ delegate_->error(this, "Connection denied");
+ } else {
+ delegate_->error(this, "Connection lost");
+ }
+ disconnect();
+ return;
+ }
+
+ if (events & Looper::EVENT_READ) {
+ while (true) {
+ size_t avail;
+ auto ptr = in_->write_ptr(&avail);
+ if (avail == 0) {
+ assert(false);
+ break;
+ }
+ auto ret = io::read(sock_.get(), ptr, avail);
+ if (ret == -1) {
+ if (errno == EWOULDBLOCK || errno == EAGAIN) break;
+ delegate_->error(this, "Read error");
+ disconnect();
+ return;
+ }
+ if (ret == 0) {
+ delegate_->error(this, "Connection closed");
+ disconnect();
+ return;
+ }
+ in_->commit(ret);
+ if (static_cast<size_t>(ret) < avail) break;
+ }
+ }
+ if (events & Looper::EVENT_WRITE) {
+ if (state_ == CONNECTING && !sent_hello_) {
+ sent_hello_ = true;
+ send_hello();
+ return;
+ }
+
+ size_t avail;
+ while (true) {
+ auto ptr = out_->read_ptr(&avail);
+ if (avail == 0) break;
+ auto ret = io::write(sock_.get(), ptr, avail);
+ if (ret == -1) {
+ if (errno == EWOULDBLOCK || errno == EAGAIN) break;
+ delegate_->error(this, "Write error");
+ disconnect();
+ return;
+ }
+ if (ret == 0) {
+ delegate_->error(this, "Connection lost");
+ disconnect();
+ return;
+ }
+ out_->consume(ret);
+ if (static_cast<size_t>(ret) < avail) break;
+ }
+ looper_->modify(sock_.get(), Looper::EVENT_READ
+ | (avail == 0 ? 0 : Looper::EVENT_WRITE));
+ }
+
+ if (active_attach_) {
+ consume_attach();
+ if (active_attach_) return;
+ }
+
+ while (true) {
+ size_t avail;
+ auto ptr = in_->read_ptr(&avail);
+ if (avail == 0) return;
+ if (content_skip_ > 0) {
+ if (content_skip_ > avail) {
+ in_->consume(avail);
+ content_skip_ -= avail;
+ return;
+ }
+ in_->consume(content_skip_);
+ content_skip_ = 0;
+ ptr = in_->read_ptr(&avail);
+ }
+ auto resp = std::unique_ptr<HttpResponse>(
+ HttpResponse::parse(
+ reinterpret_cast<char const*>(ptr), avail, false));
+ if (!resp) {
+ if (avail > 1024 * 1024) {
+ delegate_->error(this, "Server sending too much unexpected data");
+ disconnect();
+ }
+ return;
+ }
+ if (!resp->good()) {
+ delegate_->error(this, "Server sent invalid response");
+ disconnect();
+ return;
+ }
+ switch (state_) {
+ case CONNECTING:
+ if (resp->status_code() != 200) {
+ delegate_->error(
+ this, "Unexpected server response: " + resp->status_message());
+ disconnect();
+ return;
+ }
+ if (!parse_content_length(resp.get(), &content_skip_)) {
+ delegate_->error(
+ this, "Invalid server response, bad content length");
+ disconnect();
+ return;
+ }
+ in_->consume(resp->size());
+ new_state(CONNECTED);
+ active_attach_.reset();
+ continue;
+ case ATTACHED:
+ if (resp->status_code() != 200) {
+ delegate_->error(
+ this, "Unexpected server response: " + resp->status_message());
+ new_state(CONNECTED);
+ if (!parse_content_length(resp.get(), &content_skip_)) {
+ delegate_->error(
+ this, "Invalid server response, bad content length");
+ disconnect();
+ return;
+ }
+ in_->consume(resp->size());
+ } else {
+ if (!setup_attach(resp.get())) {
+ delegate_->error(
+ this, "Invalid server response, bad chunked attach");
+ disconnect();
+ return;
+ }
+ in_->consume(resp->size());
+ consume_attach();
+ if (active_attach_) return;
+ }
+ continue;
+ default:
+ delegate_->error(
+ this, "Unexpected server response: " + resp->status_message());
+ disconnect();
+ return;
+ }
+ }
+ }
+
+ void send_hello() {
+ auto request = std::unique_ptr<HttpRequestBuilder>(
+ HttpRequestBuilder::create("GET", "/hello", "HTTP", http_version));
+ request->add_header("X-TP-Monitor-Version", VERSION);
+ request->add_header("Content-Length", "0");
+ auto data = request->build();
+ send(data.data(), data.size());
+ }
+
+ void send_attach() {
+ auto request = std::unique_ptr<HttpRequestBuilder>(
+ HttpRequestBuilder::create("GET", "/attach", "HTTP", http_version));
+ request->add_header("Content-Length", "0");
+ auto data = request->build();
+ send(data.data(), data.size());
+ }
+
+ void send_detach() {
+ auto request = std::unique_ptr<HttpRequestBuilder>(
+ HttpRequestBuilder::create("GET", "/detach", "HTTP", http_version));
+ request->add_header("Content-Length", "0");
+ auto data = request->build();
+ send(data.data(), data.size());
+ }
+
+ void send(void const* data, size_t size) {
+ if (size == 0) return;
+
+ if (out_->empty()) {
+ ssize_t ret = io::write(sock_.get(), data, size);
+ if (ret == -1) {
+ if (errno != EWOULDBLOCK && errno != EAGAIN) {
+ delegate_->error(this, "Write error");
+ disconnect();
+ return;
+ }
+ } else if (ret == 0) {
+ delegate_->error(this, "Write error");
+ disconnect();
+ return;
+ } else if (static_cast<size_t>(ret) == size) {
+ return;
+ }
+ out_->write(reinterpret_cast<char const*>(data) + ret, size - ret);
+ looper_->modify(sock_.get(), Looper::EVENT_READ | Looper::EVENT_WRITE);
+ } else {
+ out_->write(data, size);
+ }
+ }
+
+ Looper* const looper_;
+ Resolver* const resolver_;
+ Delegate* const delegate_;
+ State state_;
+ void* resolv_;
+ io::auto_fd sock_;
+ std::unique_ptr<Buffer> in_;
+ std::unique_ptr<Buffer> out_;
+ bool sent_hello_;
+ std::unique_ptr<Chunked> active_attach_;
+ uint64_t content_skip_;
+ uint8_t package_[65535];
+ size_t package_fill_;
+};
+
+} // namespace
+
+// static
+Monitor* Monitor::create(
+ Looper* looper, Resolver* resolver, Delegate* delegate) {
+ return new MonitorImpl(looper, resolver, delegate);
+}
diff --git a/src/monitor.hh b/src/monitor.hh
new file mode 100644
index 0000000..83652ee
--- /dev/null
+++ b/src/monitor.hh
@@ -0,0 +1,68 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#ifndef MONITOR_HH
+#define MONITOR_HH
+
+#include <cstdint>
+#include <memory>
+#include <string>
+
+class Looper;
+class Resolver;
+
+class Monitor {
+public:
+ enum State {
+ // Default state, if a connection is broken this is the new state
+ DISCONNECTED = 0,
+ // After connect() is called but before the connection is truly setup
+ CONNECTING,
+ // Connect() succeeded, not attached
+ CONNECTED,
+ // Connected and attached.
+ ATTACHED,
+ };
+
+ struct Package {
+ uint32_t id;
+ struct timespec timestamp;
+ uint16_t flags;
+ std::string source_host;
+ uint16_t source_port;
+ std::string target_host;
+ uint16_t target_port;
+ };
+
+ class Delegate {
+ public:
+ virtual ~Delegate() {}
+
+ virtual void state(Monitor* monitor, State state) = 0;
+ virtual void error(Monitor* monitor, std::string const& error) = 0;
+ virtual void package(Monitor* monitor, Package const& package) = 0;
+ virtual void package_data(Monitor* monitor, uint32_t id,
+ char const* data, size_t size, bool last) = 0;
+
+ protected:
+ Delegate() {}
+ };
+
+ virtual ~Monitor() {}
+
+ static Monitor* create(
+ Looper* looper, Resolver* resolver, Delegate* delegate);
+
+ virtual void connect(std::string const& host, uint16_t port) = 0;
+ virtual void disconnect() = 0;
+
+ virtual void attach() = 0;
+ virtual void detach() = 0;
+
+ virtual State state() const = 0;
+
+protected:
+ Monitor() {}
+ Monitor(Monitor const&) = delete;
+};
+
+#endif // MONITOR_HH
diff --git a/src/proxy.cc b/src/proxy.cc
index dcc52df..5abe257 100644
--- a/src/proxy.cc
+++ b/src/proxy.cc
@@ -6,6 +6,7 @@
#define _GNU_SOURCE
#endif
+#include <arpa/inet.h>
#include <cerrno>
#include <chrono>
#include <cstring>
@@ -17,6 +18,7 @@
#include <sys/types.h>
#include <unistd.h>
#include <vector>
+#include <unordered_set>
#include "buffer.hh"
#include "chunked.hh"
@@ -229,6 +231,7 @@ struct RemoteClient : public BaseClient {
Content content;
std::string host;
uint16_t port;
+ uint32_t pkg_id;
};
struct Client : public BaseClient {
@@ -242,9 +245,13 @@ struct Client : public BaseClient {
RemoteState remote_state;
void* resolve;
RemoteClient remote;
+ std::string source_host;
+ uint16_t source_port;
+ uint32_t pkg_id;
};
struct Monitor : public BaseClient {
+ bool got_hello;
};
class ProxyImpl : public Proxy {
@@ -254,7 +261,8 @@ public:
: config_(config), cwd_(cwd), configfile_(configfile), logfile_(logfile),
logger_(logger), accept_socket_(accept_fd), monitor_socket_(monitor_fd),
looper_(Looper::create()), resolver_(Resolver::create(looper_.get())),
- new_timeout_(nullptr), timeout_(nullptr) {
+ new_timeout_(nullptr), timeout_(nullptr), next_package_id_(1),
+ monitor_send_proxied_(false) {
setup();
}
~ProxyImpl() override {
@@ -307,11 +315,33 @@ private:
void client_error(size_t index,
uint16_t status_code, std::string const& status);
bool client_request(size_t index);
- bool client_send(size_t index, void const* ptr, size_t avail);
void client_remote_error(size_t index, uint16_t error);
void close_client_when_done(size_t index);
void client_remote_resolved(size_t index, int fd, bool connected,
char const* error);
+ void monitor_error(size_t index,
+ uint16_t status_code, std::string const& status);
+ bool support_monitor_version(size_t index, std::string const& version);
+ bool monitor_send_chunked(size_t index,
+ void const* header, size_t header_size,
+ void const* data, size_t data_size);
+ uint32_t get_next_package_id();
+ void send_attached_package(uint32_t id, uint16_t flags,
+ std::string const& source_host,
+ uint16_t source_port,
+ std::string const& target_host,
+ uint16_t target_port,
+ bool last);
+ void send_attached_package2(uint8_t* buffer, size_t size,
+ uint32_t id, uint16_t flags,
+ std::string const& source_host,
+ uint16_t source_port,
+ std::string const& target_host,
+ uint16_t target_port,
+ bool last);
+ void send_attached_data(uint32_t id, void const* ptr, size_t size, bool last);
+ void send_attached(void const* header, size_t header_size,
+ void const* data, size_t data_size);
Config* const config_;
std::string cwd_;
@@ -327,9 +357,12 @@ private:
bool good_;
void* new_timeout_;
void* timeout_;
+ uint32_t next_package_id_;
+ bool monitor_send_proxied_;
clients<Client> clients_;
clients<Monitor> monitors_;
+ std::unordered_set<size_t> attached_;
};
size_t get_size(Config* config, Logger* logger, std::string const& name,
@@ -349,6 +382,7 @@ size_t get_size(Config* config, Logger* logger, std::string const& name,
}
void ProxyImpl::setup() {
+ monitor_send_proxied_ = !config_->get("monitor_proxy_request", false);
clients_.resize(get_size(config_, logger_, "max_clients", 1024));
monitors_.resize(get_size(config_, logger_, "max_monitors", 2));
looper_->add(accept_socket_.get(),
@@ -380,6 +414,7 @@ bool ProxyImpl::reload_config() {
logger_->out(Logger::WARN, "New config invalid, ignored.");
return true;
}
+ monitor_send_proxied_ = !config_->get("monitor_proxy_request", false);
if (!logfile_) {
auto logfile = config_->get("logfile", nullptr);
if (logfile) {
@@ -461,6 +496,16 @@ void ProxyImpl::close_base(BaseClient* client) {
void ProxyImpl::close_client(size_t index) {
bool was_full = clients_.full();
auto& client = clients_[index];
+ if (client.pkg_id != 0) {
+ size_t avail;
+ auto ptr = client.in->read_ptr(&avail);
+ if (avail) {
+ send_attached_data(client.pkg_id, ptr, avail, true);
+ } else {
+ send_attached_data(client.pkg_id, nullptr, 0, true);
+ }
+ client.pkg_id = 0;
+ }
client.request.reset();
client.url.reset();
client.connect.reset();
@@ -471,6 +516,16 @@ void ProxyImpl::close_client(size_t index) {
resolver_->cancel(client.resolve);
client.resolve = nullptr;
}
+ if (client.remote.pkg_id) {
+ size_t avail;
+ auto ptr = client.remote.in->read_ptr(&avail);
+ if (avail) {
+ send_attached_data(client.remote.pkg_id, ptr, avail, true);
+ } else {
+ send_attached_data(client.remote.pkg_id, nullptr, 0, true);
+ }
+ client.remote.pkg_id = 0;
+ }
close_base(&client.remote);
close_base(&client);
clients_.erase(index);
@@ -482,6 +537,10 @@ void ProxyImpl::close_client(size_t index) {
void ProxyImpl::close_monitor(size_t index) {
bool was_full = monitors_.full();
auto& monitor = monitors_[index];
+ auto it = attached_.find(index);
+ if (it != attached_.end()) {
+ attached_.erase(it);
+ }
close_base(&monitor);
monitors_.erase(index);
if (was_full && !monitors_.full() && monitor_socket_) {
@@ -530,6 +589,8 @@ float ProxyImpl::handle_timeout(bool new_conn,
close.clear();
for (auto i = monitors_.begin(); i != monitors_.end(); ++i) {
if (i->new_connection != new_conn) continue;
+ // Monitors are safe from timeout after hello
+ if (i->got_hello) continue;
auto diff = std::chrono::duration_cast<std::chrono::duration<float>>(
((i->last + timeout) - now)).count();
if (diff < 0.0f) {
@@ -728,21 +789,9 @@ void ProxyImpl::client_error(size_t index, uint16_t status_code,
client.url.reset();
client.connect.reset();
- std::string proto;
- Version version;
-
- if (!client.request) {
- proto = "HTTP";
- version.major = 1;
- version.minor = 1;
- } else {
- proto = client.request->proto();
- version = client.request->proto_version();
- }
-
auto resp = std::unique_ptr<HttpResponseBuilder>(
HttpResponseBuilder::create(
- proto, version, status_code, status_message));
+ "HTTP", Version(1, 0), status_code, status_message));
resp->add_header("Content-Length", "0");
resp->add_header("Connection", "close");
@@ -781,6 +830,13 @@ bool ProxyImpl::client_request(size_t index) {
host = client.url->host();
port = client.url->port();
}
+ if (port == 0) port = 80;
+ client.pkg_id = get_next_package_id();
+ if (client.pkg_id != 0) {
+ send_attached_package(client.pkg_id, 0,
+ client.source_host, client.source_port,
+ host, port, false);
+ }
if (client.remote_state == WAITING) {
if (client.connect || host != client.remote.host
|| port != client.remote.port) {
@@ -803,7 +859,6 @@ bool ProxyImpl::client_request(size_t index) {
client.remote.host = host;
client.remote.port = port;
- if (port == 0) port = 80;
client.resolve = resolver_->request(
host, port,
std::bind(&ProxyImpl::client_remote_resolved, this, index,
@@ -909,7 +964,10 @@ void ProxyImpl::client_empty_input(size_t index) {
// falltrough
case CONTENT_NONE: {
if (client.connect && client.remote_state == CONNECTED) {
- if (!client_send(index, ptr, avail)) {
+ if (client.pkg_id != 0) {
+ send_attached_data(client.pkg_id, ptr, avail, false);
+ }
+ if (!base_send(&client.remote, ptr, avail, index, "Client remote")) {
return;
}
client.in->consume(avail);
@@ -919,6 +977,10 @@ void ProxyImpl::client_empty_input(size_t index) {
// Still working on the last request, wait
return;
}
+ if (client.pkg_id != 0) {
+ send_attached_data(client.pkg_id, nullptr, 0, true);
+ client.pkg_id = 0;
+ }
client.request.reset(
HttpRequest::parse(
reinterpret_cast<char const*>(ptr), avail, false));
@@ -968,14 +1030,22 @@ void ProxyImpl::client_empty_input(size_t index) {
return;
}
if (avail < client.content.len) {
- if (!client_send(index, ptr, avail)) {
+ if (client.pkg_id != 0) {
+ send_attached_data(client.pkg_id, ptr, avail, false);
+ }
+ if (!base_send(&client.remote, ptr, avail, index, "Client remote")) {
return;
}
client.in->consume(avail);
client.content.len -= avail;
return;
}
- if (!client_send(index, ptr, client.content.len)) {
+ if (client.pkg_id != 0) {
+ send_attached_data(client.pkg_id, ptr, client.content.len, true);
+ client.pkg_id = 0;
+ }
+ if (!base_send(&client.remote, ptr, client.content.len, index,
+ "Client remote")) {
return;
}
client.in->consume(client.content.len);
@@ -992,11 +1062,17 @@ void ProxyImpl::client_empty_input(size_t index) {
client_error(index, 400, "Bad request");
return;
}
- if (!client_send(index, ptr, used)) {
+ if (client.pkg_id != 0) {
+ send_attached_data(client.pkg_id, ptr, used, false);
+ }
+ if (!base_send(&client.remote, ptr, used, index, "Client remote")) {
return;
}
client.in->consume(used);
if (client.content.chunked->eof()) {
+ if (client.pkg_id != 0) {
+ send_attached_data(client.pkg_id, nullptr, 0, true);
+ }
client.content.chunked.reset();
client.content.type = CONTENT_NONE;
}
@@ -1017,6 +1093,10 @@ void ProxyImpl::close_client_when_done(size_t index) {
client.read_flag = 0;
client.request.reset();
client.content.type = CONTENT_NONE;
+ if (client.pkg_id != 0) {
+ send_attached_data(client.pkg_id, nullptr, 0, true);
+ client.pkg_id = 0;
+ }
looper_->modify(client.fd.get(), client.write_flag);
}
@@ -1045,6 +1125,12 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) {
client.in->consume(client.request->size());
client.request.reset();
client.remote.content.type = CONTENT_CLOSE;
+ client.remote.pkg_id = get_next_package_id();
+ if (client.remote.pkg_id) {
+ send_attached_package(client.remote.pkg_id, 0,
+ client.remote.host, client.remote.port,
+ client.source_host, client.source_port, false);
+ }
if (!base_send(&client, data.data(), data.size(), index, "Client")) {
return;
}
@@ -1080,10 +1166,30 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) {
req->add_header("host", client.url->host());
}
auto data = req->build();
+ if (client.pkg_id != 0) {
+ bool const last = client.content.type == CONTENT_NONE;
+ if (monitor_send_proxied_) {
+ send_attached_data(client.pkg_id, data.data(), data.size(), last);
+ } else {
+ auto ptr = client.in->read_ptr(nullptr);
+ send_attached_data(client.pkg_id,
+ ptr, client.request->size(), last);
+ }
+ if (last) {
+ client.pkg_id = 0;
+ }
+ }
client.in->consume(client.request->size());
client.request.reset();
client.url.reset();
client.remote.out->write(data.data(), data.size());
+
+ client.remote.pkg_id = get_next_package_id();
+ if (client.remote.pkg_id) {
+ send_attached_package(client.remote.pkg_id, 0,
+ client.remote.host, client.remote.port,
+ client.source_host, client.source_port, false);
+ }
}
client.remote_state = CONNECTED;
client.remote.read_flag = Looper::EVENT_READ;
@@ -1117,6 +1223,10 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) {
client_remote_error(index, 502);
break;
case WAITING:
+ if (client.remote.pkg_id != 0) {
+ send_attached_data(client.remote.pkg_id, nullptr, 0, true);
+ client.remote.pkg_id = 0;
+ }
close_base(&client.remote);
client.remote_state = CLOSED;
break;
@@ -1159,6 +1269,16 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) {
looper_->modify(client.remote.fd.get(), 0);
}
}
+ if (client.remote.pkg_id != 0) {
+ if (client.remote.content.type == CONTENT_NONE) {
+ send_attached_data(client.remote.pkg_id,
+ ptr, response->size(), true);
+ client.remote.pkg_id = 0;
+ } else {
+ send_attached_data(client.remote.pkg_id,
+ ptr, response->size(), false);
+ }
+ }
if (!base_send(&client, ptr, response->size(), index, "Client")) {
return;
}
@@ -1167,6 +1287,9 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) {
}
case CONTENT_LEN:
if (avail < client.remote.content.len) {
+ if (client.remote.pkg_id != 0) {
+ send_attached_data(client.remote.pkg_id, ptr, avail, false);
+ }
if (!base_send(&client, ptr, avail, index, "Client")) {
return;
}
@@ -1174,6 +1297,11 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) {
client.remote.content.len -= avail;
return;
}
+ if (client.remote.pkg_id != 0) {
+ send_attached_data(client.remote.pkg_id, ptr,
+ client.remote.content.len, true);
+ client.remote.pkg_id = 0;
+ }
if (!base_send(&client, ptr, client.remote.content.len,
index, "Client")) {
return;
@@ -1194,14 +1322,20 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) {
client.remote.content.type = CONTENT_CLOSE;
break;
}
+ if (client.remote.pkg_id != 0) {
+ send_attached_data(client.remote.pkg_id, ptr, used, false);
+ }
if (!base_send(&client, ptr, used, index, "Client")) {
return;
}
client.remote.in->consume(used);
if (client.remote.content.chunked->eof()) {
+ if (client.remote.pkg_id != 0) {
+ send_attached_data(client.remote.pkg_id, nullptr, 0, true);
+ client.remote.pkg_id = 0;
+ }
client.remote.content.chunked.reset();
client.remote.content.type = CONTENT_NONE;
- logger_->out(Logger::INFO, "%zu: chunked -> waiting", index);
client.remote_state = WAITING;
client.remote.read_flag = 0;
client.remote.write_flag = 0;
@@ -1210,6 +1344,9 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) {
break;
}
case CONTENT_CLOSE:
+ if (client.remote.pkg_id != 0) {
+ send_attached_data(client.remote.pkg_id, ptr, avail, false);
+ }
if (!base_send(&client, ptr, avail, index, "Client")) {
return;
}
@@ -1228,12 +1365,51 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) {
}
}
-bool ProxyImpl::client_send(size_t index, void const* ptr, size_t size) {
- auto& client = clients_[index];
- assert(!client.request);
- assert(client.remote_state >= CONNECTED);
+void ProxyImpl::monitor_error(size_t index, uint16_t status_code,
+ std::string const& status_message) {
+ auto& monitor = monitors_[index];
+ // No more input
+ monitor.read_flag = 0;
+ looper_->modify(monitor.fd.get(), monitor.write_flag);
+
+ auto resp = std::unique_ptr<HttpResponseBuilder>(
+ HttpResponseBuilder::create(
+ "HTTP", Version(1, 0), status_code, status_message));
+ resp->add_header("Content-Length", "0");
+ resp->add_header("Connection", "close");
- return base_send(&client.remote, ptr, size, index, "Client remote");
+ monitor.in.reset();
+
+ auto data = resp->build();
+
+ if (!base_send(&monitor, data.data(), data.size(), index, "Monitor")) {
+ close_monitor(index);
+ }
+}
+
+bool ProxyImpl::support_monitor_version(
+ size_t index, std::string const& version) {
+ if (version.empty()) return false;
+ // TODO: Actually do some version check here
+ return version.compare(VERSION) == 0;
+}
+
+bool ProxyImpl::monitor_send_chunked(
+ size_t index, void const* header, size_t header_size,
+ void const* data, size_t data_size) {
+ auto& monitor = monitors_[index];
+ char chunked_header[10];
+ auto len = snprintf(chunked_header, 10, "%zx\r\n", header_size + data_size);
+ if (!base_send(&monitor, chunked_header, len, index, "Monitor")) return false;
+ if (header_size > 0 && !base_send(&monitor,
+ header, header_size, index, "Monitor")) {
+ return false;
+ }
+ if (data_size > 0
+ && !base_send(&monitor, data, data_size, index, "Monitor")) {
+ return false;
+ }
+ return base_send(&monitor, "\r\n", 2, index, "Monitor");
}
void ProxyImpl::monitor_event(size_t index, int fd, uint8_t events) {
@@ -1252,6 +1428,87 @@ void ProxyImpl::monitor_event(size_t index, int fd, uint8_t events) {
close_monitor(index);
return;
}
+
+ while (true) {
+ size_t avail;
+ auto ptr = monitor.in->read_ptr(&avail);
+ if (avail == 0) return;
+ auto request = std::unique_ptr<HttpRequest>(
+ HttpRequest::parse(
+ reinterpret_cast<char const*>(ptr), avail, false));
+ if (!request) {
+ if (avail >= 1024 * 1024) {
+ logger_->out(Logger::INFO, "%zu: Monitor too large request %zu",
+ index, avail);
+ close_monitor(index);
+ }
+ return;
+ }
+ if (!request->good()
+ || !request->proto_equal("HTTP")
+ || !request->method_equal("GET")
+ // Only support 1.1 or above
+ || request->proto_version().major == 0
+ || (request->proto_version().major == 1
+ && request->proto_version().minor == 0)) {
+ monitor_error(index, 400, "Bad request");
+ return;
+ }
+ Content content;
+ content.type = CONTENT_NONE;
+ if (!setup_content(request.get(), &content)) {
+ monitor_error(index, 400, "Bad request");
+ return;
+ }
+ if (content.type != CONTENT_NONE) {
+ monitor_error(index, 400, "Bad request");
+ return;
+ }
+ std::unique_ptr<HttpResponseBuilder> resp;
+ if (request->url_equal("/hello")) {
+ auto version = request->first_header("x-tp-monitor-version");
+ if (support_monitor_version(index, version)) {
+ resp.reset(HttpResponseBuilder::create(
+ "HTTP", Version(1, 1), 200, "OK"));
+ monitor.got_hello = true;
+ } else {
+ resp.reset(HttpResponseBuilder::create(
+ "HTTP", Version(1, 1), 500, "Unsupported version"));
+ }
+ resp->add_header("X-TP-Version", VERSION);
+ resp->add_header("Content-Length", "0");
+ } else if (!monitor.got_hello) {
+ monitor_error(index, 500, "Unexpected request");
+ return;
+ } else if (request->url_equal("/attach")) {
+ attached_.insert(index);
+ resp.reset(HttpResponseBuilder::create(
+ "HTTP", Version(1, 1), 200, "OK"));
+ resp->add_header("Transfer-Encoding", "chunked");
+ } else if (request->url_equal("/detach")) {
+ auto it = attached_.find(index);
+ if (it != attached_.end()) {
+ if (!monitor_send_chunked(index, nullptr, 0, nullptr, 0)) {
+ close_monitor(index);
+ return;
+ }
+ attached_.erase(it);
+ }
+ resp.reset(HttpResponseBuilder::create(
+ "HTTP", Version(1, 1), 200, "OK"));
+ resp->add_header("Content-Length", "0");
+ } else {
+ resp.reset(HttpResponseBuilder::create(
+ "HTTP", Version(1, 1), 404, "Not found"));
+ resp->add_header("Content-Length", "0");
+ }
+ auto data = resp->build();
+ monitor.in->consume(request->size());
+ if (!base_send(&monitor, data.data(), data.size(), index, "Monitor")) {
+ close_monitor(index);
+ return;
+ }
+ }
}
void ProxyImpl::new_base(BaseClient* client, int fd) {
@@ -1286,12 +1543,20 @@ int my_accept4(int sockfd, struct sockaddr* addr, socklen_t* addrlen,
}
#endif // HAVE_ACCEPT4
+union big_addr {
+ struct sockaddr_in addr_in;
+ struct sockaddr_in6 addr_in6;
+};
+
void ProxyImpl::new_client(int fd, uint8_t events) {
assert(fd == accept_socket_.get());
if (events == Looper::EVENT_READ) {
assert(!clients_.full());
while (true) {
- int ret = accept4(fd, nullptr, nullptr, SOCK_NONBLOCK);
+ big_addr addr;
+ socklen_t len = sizeof(addr);
+ auto a = reinterpret_cast<struct sockaddr*>(&addr);
+ int ret = accept4(fd, a, &len, SOCK_NONBLOCK);
if (ret < 0) {
if (errno == EAGAIN || errno == EWOULDBLOCK) return;
if (errno == EINTR) continue;
@@ -1300,9 +1565,28 @@ void ProxyImpl::new_client(int fd, uint8_t events) {
}
auto index = clients_.new_client();
new_base(&clients_[index], ret);
+ clients_[index].source_host.clear();
+ if (len == sizeof(struct sockaddr_in) && a->sa_family == AF_INET) {
+ char tmp[INET_ADDRSTRLEN];
+ if (inet_ntop(AF_INET, &addr.addr_in.sin_addr, tmp, sizeof(tmp))) {
+ clients_[index].source_host = tmp;
+ }
+ clients_[index].source_port = ntohs(addr.addr_in.sin_port);
+ } else if (len == sizeof(struct sockaddr_in6)
+ && a->sa_family == AF_INET6) {
+ char tmp[INET6_ADDRSTRLEN];
+ if (inet_ntop(AF_INET6, &addr.addr_in6.sin6_addr, tmp, sizeof(tmp))) {
+ clients_[index].source_host = tmp;
+ }
+ clients_[index].source_port = ntohs(addr.addr_in6.sin6_port);
+ } else {
+ clients_[index].source_port = 0;
+ }
clients_[index].content.type = CONTENT_NONE;
clients_[index].remote_state = CLOSED;
clients_[index].remote.content.type = CONTENT_NONE;
+ clients_[index].pkg_id = 0;
+ clients_[index].remote.pkg_id = 0;
looper_->add(ret, clients_[index].read_flag | clients_[index].write_flag,
std::bind(&ProxyImpl::client_event, this,
index,
@@ -1331,16 +1615,16 @@ void ProxyImpl::new_monitor(int fd, uint8_t events) {
}
auto index = monitors_.new_client();
new_base(&monitors_[index], ret);
- looper_->add(ret, clients_[index].read_flag | clients_[index].write_flag,
+ monitors_[index].got_hello = false;
+ looper_->add(ret, monitors_[index].read_flag
+ | monitors_[index].write_flag,
std::bind(&ProxyImpl::monitor_event, this,
index,
std::placeholders::_1,
std::placeholders::_2));
break;
}
- if (monitors_.full()) {
- looper_->modify(fd, 0);
- }
+ if (monitors_.full()) looper_->modify(fd, 0);
} else {
logger_->out(Logger::ERR, "Monitor socket died");
fatal_error();
@@ -1383,6 +1667,131 @@ bool ProxyImpl::run() {
return good_;
}
+uint32_t ProxyImpl::get_next_package_id() {
+ if (attached_.empty()) return 0;
+ while (!next_package_id_) {
+ ++next_package_id_;
+ }
+ return next_package_id_++;
+}
+
+void write_u16(uint8_t* dst, uint16_t value) {
+ dst[0] = value >> 8;
+ dst[1] = value & 0xff;
+}
+
+void write_u32(uint8_t* dst, uint32_t value) {
+ write_u16(dst, value >> 16);
+ write_u16(dst + 2, value & 0xffff);
+}
+
+void write_u64(uint8_t* dst, uint64_t value) {
+ write_u32(dst, value >> 32);
+ write_u32(dst + 4, value & 0xffffffff);
+}
+
+void ProxyImpl::send_attached_package(uint32_t id, uint16_t flags,
+ std::string const& source_host,
+ uint16_t source_port,
+ std::string const& target_host,
+ uint16_t target_port,
+ bool last) {
+ if (id == 0) {
+ assert(false);
+ return;
+ }
+ if (attached_.empty()) return;
+ uint8_t data[256];
+ size_t need = 2 + 3 + 4 + 8 + 4 + 2 + 2 + 2 + 2 + 2 + source_host.size() +
+ target_host.size();
+ if (need <= sizeof(data)) {
+ send_attached_package2(data, need, id, flags, source_host, source_port,
+ target_host, target_port, last);
+ } else {
+ // TODO: Might need better handling of really long source_host/target_host
+ auto p = std::unique_ptr<uint8_t[]>(new uint8_t[need]);
+ send_attached_package2(p.get(), need, id, flags, source_host, source_port,
+ target_host, target_port, last);
+ }
+}
+
+void ProxyImpl::send_attached_package2(uint8_t* buffer, size_t size,
+ uint32_t id, uint16_t flags,
+ std::string const& source_host,
+ uint16_t source_port,
+ std::string const& target_host,
+ uint16_t target_port,
+ bool last) {
+ buffer[0] = 'P';
+ buffer[1] = 'K';
+ buffer[2] = 'G';
+ write_u16(buffer + 3, size);
+ write_u32(buffer + 5, id);
+ auto dur = looper_->now().time_since_epoch();
+ auto sec = std::chrono::duration_cast<std::chrono::seconds>(dur);
+ auto nsec = std::chrono::duration_cast<std::chrono::nanoseconds>(dur - sec);
+ write_u64(buffer + 9, sec.count());
+ write_u32(buffer + 17, nsec.count());
+ write_u16(buffer + 21, (last ? 0 : 1) | (flags << 1));
+ write_u16(buffer + 23, source_port);
+ write_u16(buffer + 25, target_port);
+ write_u16(buffer + 27, source_host.size());
+ memcpy(buffer + 29, source_host.data(), source_host.size());
+ write_u16(buffer + 29 + source_host.size(), target_host.size());
+ memcpy(buffer + 31 + source_host.size(),
+ target_host.data(), target_host.size());
+ send_attached(buffer, size, nullptr, 0);
+}
+
+void ProxyImpl::send_attached_data(uint32_t id, void const* ptr, size_t size,
+ bool last) {
+ if (id == 0) {
+ assert(false);
+ return;
+ }
+ if (attached_.empty()) return;
+ if (size == 0 && !last) return;
+ uint8_t data[10];
+ data[0] = 'D';
+ data[1] = 'A';
+ data[2] = 'T';
+ write_u32(data + 5, id);
+ if (size == 0) {
+ assert(last);
+ assert(ptr == nullptr);
+ write_u16(data + 3, 10);
+ data[9] = 0;
+ send_attached(data, 10, nullptr, 0);
+ } else {
+ size_t max = 0xffff - 10;
+ auto p = reinterpret_cast<char const*>(ptr);
+ data[9] = 1;
+ while (size > max) {
+ write_u16(data + 3, 0xffff);
+ send_attached(data, 10, p, max);
+ p += max;
+ size -= max;
+ }
+ data[9] = last ? 0 : 1;
+ write_u16(data + 3, size + 10);
+ send_attached(data, 10, p, size);
+ }
+}
+
+void ProxyImpl::send_attached(void const* header, size_t header_size,
+ void const* data, size_t data_size) {
+ auto it = attached_.begin();
+ while (it != attached_.end()) {
+ if (monitor_send_chunked(*it, header, header_size, data, data_size)) {
+ ++it;
+ } else {
+ auto index = *it;
+ it = attached_.erase(it);
+ close_monitor(index);
+ }
+ }
+}
+
int setup_socket(char const* host, std::string const& port, Logger* logger) {
io::auto_fd ret;
struct addrinfo hints;