summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/.gitignore4
-rw-r--r--src/Makefile.am18
-rw-r--r--src/args.cc281
-rw-r--r--src/args.hh60
-rw-r--r--src/buffer.cc119
-rw-r--r--src/buffer.hh30
-rw-r--r--src/character.cc31
-rw-r--r--src/character.hh18
-rw-r--r--src/chunked.cc106
-rw-r--r--src/chunked.hh23
-rw-r--r--src/common.hh18
-rw-r--r--src/config.cc176
-rw-r--r--src/config.hh33
-rw-r--r--src/http.cc657
-rw-r--r--src/http.hh141
-rw-r--r--src/io.cc90
-rw-r--r--src/io.hh124
-rw-r--r--src/logger.cc130
-rw-r--r--src/logger.hh33
-rw-r--r--src/looper.cc287
-rw-r--r--src/looper.hh41
-rw-r--r--src/main.cc187
-rw-r--r--src/paths.cc98
-rw-r--r--src/paths.hh18
-rw-r--r--src/proxy.cc1373
-rw-r--r--src/proxy.hh37
-rw-r--r--src/resolver.cc207
-rw-r--r--src/resolver.hh28
-rw-r--r--src/strings.cc94
-rw-r--r--src/strings.hh24
-rw-r--r--src/terminal.cc20
-rw-r--r--src/terminal.hh22
-rw-r--r--src/url.cc901
-rw-r--r--src/url.hh62
-rw-r--r--src/xdg.cc135
-rw-r--r--src/xdg.hh25
36 files changed, 5651 insertions, 0 deletions
diff --git a/src/.gitignore b/src/.gitignore
new file mode 100644
index 0000000..7066278
--- /dev/null
+++ b/src/.gitignore
@@ -0,0 +1,4 @@
+/config.h
+/config.h.in~
+/libtp.a
+/tp
diff --git a/src/Makefile.am b/src/Makefile.am
new file mode 100644
index 0000000..502c82d
--- /dev/null
+++ b/src/Makefile.am
@@ -0,0 +1,18 @@
+MAINTAINERCLEANFILES = Makefile.in config.h.in
+
+AM_CXXFLAGS = @DEFINES@
+
+# Remove ar: `u' modifier ignored since `D' is the default (see `U')
+ARFLAGS = cr
+
+bin_PROGRAMS = tp
+noinst_LIBRARIES = libtp.a
+
+tp_SOURCES = main.cc proxy.cc logger.cc resolver.cc
+tp_LDADD = libtp.a @THREAD_LIBS@
+tp_CXXFLAGS = $(AM_CXXFLAGS) -DVERSION='"@VERSION@"' @THREAD_CFLAGS@
+
+libtp_a_SOURCES = args.cc xdg.cc terminal.cc http.cc url.cc paths.cc \
+ character.cc config.cc strings.cc io.cc looper.cc \
+ buffer.cc chunked.cc
+libtp_a_CXXFLAGS = $(AM_CXXFLAGS) -DSYSCONFDIR='"@SYSCONFDIR@"'
diff --git a/src/args.cc b/src/args.cc
new file mode 100644
index 0000000..cde77b0
--- /dev/null
+++ b/src/args.cc
@@ -0,0 +1,281 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#include "common.hh"
+
+#include <cstring>
+#include <string>
+#include <vector>
+#include <unordered_map>
+
+#include "args.hh"
+#include "character.hh"
+#include "terminal.hh"
+
+namespace {
+
+class ArgsImpl : public Args {
+public:
+ ArgsImpl()
+ : good_(true) {
+ }
+ void add(char short_opt, std::string const& long_opt,
+ std::string const& argument, std::string const& help) override {
+ assert(short_opt == '\0' || short_opts_.count(short_opt) == 0);
+ assert(long_opt.empty() || long_opts_.count(long_opt) == 0);
+ assert(long_opt.find('=') == std::string::npos);
+ auto const index = opts_.size();
+ opts_.push_back(Option(short_opt, long_opt, argument, help));
+ if (short_opt != '\0') {
+ short_opts_.insert(std::make_pair(short_opt, index));
+ }
+ if (!long_opt.empty()) {
+ long_opts_.insert(std::make_pair(long_opt, index));
+ }
+ }
+
+ bool run(int argc, char** argv, std::ostream& out) override {
+ if (argc == 0) {
+ assert(false);
+ good_ = false;
+ return false;
+ }
+ auto start = strrchr(argv[0], '/');
+ if (!start) {
+ start = argv[0];
+ } else {
+ start++;
+ }
+ return run(start, argc, argv, out);
+ }
+
+ bool run(std::string const& prg, int argc, char** argv, std::ostream& out)
+ override {
+ reset();
+
+ std::string opt;
+ for (int a = 1; a < argc; ++a) {
+ if (argv[a][0] == '-') {
+ if (argv[a][1] == '-') {
+ if (argv[a][2] == '\0') {
+ for (++a; a < argc; ++a) {
+ args_.push_back(argv[a]);
+ }
+ return good_;
+ }
+ size_t len = 2;
+ while (argv[a][len] && argv[a][len] != '=') ++len;
+ opt.assign(argv[a] + 2, len - 2);
+ auto i = long_opts_.find(opt);
+ if (i == long_opts_.end()) {
+ out << prg << ": unrecognized option '--" << opt << "'\n";
+ good_ = false;
+ continue;
+ }
+ if (argv[a][len] == '=') {
+ if (opts_[i->second].argument.empty()) {
+ out << prg << ": option '--" << opt << "'"
+ << " doesn't allow an argument\n";
+ good_ = false;
+ continue;
+ } else {
+ opts_[i->second].value = argv[a] + len + 1;
+ opts_[i->second].is_set = true;
+ }
+ } else {
+ if (opts_[i->second].argument.empty()) {
+ opts_[i->second].is_set = true;
+ } else if (a + 1 == argc) {
+ out << prg << ": option '--" << opt << "'"
+ << " requires an argument\n";
+ good_ = false;
+ continue;
+ } else {
+ opts_[i->second].value = argv[++a];
+ opts_[i->second].is_set = true;
+ }
+ }
+ } else {
+ for (auto opt = argv[a] + 1; *opt; ++opt) {
+ auto i = short_opts_.find(*opt);
+ if (i == short_opts_.end()) {
+ out << prg << ": invalid option -- '" << *opt << "'\n";
+ good_ = false;
+ continue;
+ }
+ if (opts_[i->second].argument.empty()) {
+ opts_[i->second].is_set = true;
+ } else if (a + 1 == argc) {
+ out << prg << ": option requires an argument "
+ << " -- '" << *opt << "'\n";
+ good_ = false;
+ continue;
+ } else {
+ opts_[i->second].value = argv[++a];
+ opts_[i->second].is_set = true;
+ }
+ }
+ }
+ } else {
+ args_.push_back(argv[a]);
+ }
+ }
+
+ return good_;
+ }
+ bool good() const override {
+ return good_;
+ }
+
+ bool is_set(char short_opt) const override {
+ auto i = short_opts_.find(short_opt);
+ if (i == short_opts_.end()) return false;
+ return opts_[i->second].is_set;
+ }
+ bool is_set(std::string const& long_opt) const override {
+ auto i = long_opts_.find(long_opt);
+ if (i == long_opts_.end()) return false;
+ return opts_[i->second].is_set;
+ }
+ char const* arg(char short_opt, char const* fallback) const override {
+ auto i = short_opts_.find(short_opt);
+ if (i == short_opts_.end()) return fallback;
+ if (!opts_[i->second].is_set) return fallback;
+ if (opts_[i->second].argument.empty()) return fallback;
+ return opts_[i->second].value.c_str();
+ }
+ char const* arg(std::string const& long_opt,
+ char const* fallback) const override {
+ auto i = long_opts_.find(long_opt);
+ if (i == long_opts_.end()) return fallback;
+ if (!opts_[i->second].is_set) return fallback;
+ if (opts_[i->second].argument.empty()) return fallback;
+ return opts_[i->second].value.c_str();
+ }
+
+ std::vector<std::string> const& arguments() const override {
+ return args_;
+ }
+
+ void print_help(std::ostream& out) const override {
+ print_help(out, Terminal::size().width);
+ }
+
+ void print_help(std::ostream& out, size_t width) const override {
+ size_t left = 0;
+ for (auto const& opt : opts_) {
+ size_t l = 0;
+ if (!opt.long_opt.empty()) {
+ l += 6 + opt.long_opt.size();
+ } else if (opt.short_opt != '\0') {
+ l += 2;
+ } else {
+ continue;
+ }
+ if (!opt.argument.empty()) {
+ l += 1 + opt.argument.size();
+ }
+ if (l > left) left = l;
+ }
+
+ size_t const need = 2 + 2 + left;
+ if (need + 10 > width) {
+ width = need + 10;
+ }
+ size_t const right = width - need;
+
+ for (auto const& opt : opts_) {
+ size_t i = 0;
+ if (!opt.long_opt.empty()) {
+ if (opt.short_opt != '\0') {
+ out << " -" << opt.short_opt << ", ";
+ } else {
+ out << " ";
+ }
+ out << "--" << opt.long_opt;
+ i += 8 + opt.long_opt.size();
+ } else if (opt.short_opt != '\0') {
+ out << " -" << opt.short_opt;
+ i += 4;
+ } else {
+ continue;
+ }
+ if (!opt.argument.empty()) {
+ out << '=' << opt.argument;
+ i += 1 + opt.argument.size();
+ }
+ pad(out, need - i);
+ if (opt.help.size() < right) {
+ out << opt.help << '\n';
+ } else {
+ i = right;
+ while (i > 0 && !Character::isseparator(opt.help, i)) --i;
+ if (i == 0) i = right;
+ out << opt.help.substr(0, i) << '\n';
+ while (true) {
+ while (i < opt.help.size() && Character::isspace(opt.help, i)) ++i;
+ if (i == opt.help.size()) break;
+ size_t j = right - 2;
+ pad(out, width - j);
+ if (i + j >= opt.help.size()) {
+ out << opt.help.substr(i) << '\n';
+ break;
+ }
+ while (j > 0 && !Character::isseparator(opt.help, i + j)) --j;
+ if (j == 0) j = right - 2;
+ out << opt.help.substr(i, j) << '\n';
+ i += j;
+ }
+ }
+ }
+ }
+
+private:
+ struct Option {
+ char const short_opt;
+ std::string const long_opt;
+ std::string const argument;
+ std::string const help;
+
+ bool is_set;
+ std::string value;
+
+ Option(char short_opt, std::string const& long_opt,
+ std::string const& argument, std::string const& help)
+ : short_opt(short_opt), long_opt(long_opt), argument(argument),
+ help(help), is_set(false) {
+ }
+ };
+ bool good_;
+ std::unordered_map<char, size_t> short_opts_;
+ std::unordered_map<std::string, size_t> long_opts_;
+ std::vector<Option> opts_;
+ std::vector<std::string> args_;
+
+ void reset() {
+ good_ = true;
+ args_.clear();
+ for (auto& opt : opts_) {
+ opt.is_set = false;
+ opt.value.clear();
+ }
+ }
+
+ static void pad(std::ostream& out, size_t count) {
+ while (count > 4) {
+ out << " ";
+ count -= 4;
+ }
+ while (count) {
+ out << ' ';
+ --count;
+ }
+ }
+};
+
+} // namespace
+
+// static
+Args* Args::create() {
+ return new ArgsImpl();
+}
+
diff --git a/src/args.hh b/src/args.hh
new file mode 100644
index 0000000..10c56b3
--- /dev/null
+++ b/src/args.hh
@@ -0,0 +1,60 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#ifndef ARGS_HH
+#define ARGS_HH
+
+#include <iostream>
+#include <string>
+#include <vector>
+
+class Args {
+public:
+ virtual ~Args() {}
+
+ static Args* create();
+
+ virtual void add(char short_opt, std::string const& long_opt,
+ std::string const& argument, std::string const& help) = 0;
+ void add(char short_opt, std::string const& long_opt,
+ std::string const& help) {
+ add(short_opt, long_opt, "", help);
+ }
+ void add(std::string const& long_opt, std::string const& help) {
+ add('\0', long_opt, "", help);
+ }
+ void add(std::string const& long_opt, std::string const& argument,
+ std::string const& help) {
+ add('\0', long_opt, argument, help);
+ }
+
+ bool run(int argc, char** argv) {
+ return run(argc, argv, std::cerr);
+ }
+ virtual bool run(int argc, char** argv, std::ostream& out) = 0;
+ bool run(std::string const& prg, int argc, char** argv) {
+ return run(prg, argc, argv, std::cerr);
+ }
+ virtual bool run(
+ std::string const& prg, int argc, char** argv, std::ostream& out) = 0;
+ virtual bool good() const = 0;
+
+ virtual bool is_set(char short_opt) const = 0;
+ virtual bool is_set(std::string const& long_opt) const = 0;
+ virtual char const* arg(char short_opt, char const* fallback) const = 0;
+ virtual char const* arg(std::string const& long_opt,
+ char const* fallback) const = 0;
+
+ virtual std::vector<std::string> const& arguments() const = 0;
+
+ void print_help() const {
+ print_help(std::cout);
+ }
+ virtual void print_help(std::ostream& out) const = 0;
+ virtual void print_help(std::ostream& out, size_t width) const = 0;
+
+protected:
+ Args() {}
+ Args(Args const&) = delete;
+};
+
+#endif // ARGS_HH
diff --git a/src/buffer.cc b/src/buffer.cc
new file mode 100644
index 0000000..d0c5fbb
--- /dev/null
+++ b/src/buffer.cc
@@ -0,0 +1,119 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#include "common.hh"
+
+#include <algorithm>
+#include <cstring>
+
+#include "buffer.hh"
+
+namespace {
+
+class BufferImpl : public Buffer {
+public:
+ BufferImpl(size_t capacity, size_t min_avail)
+ : capacity_(capacity), min_avail_(min_avail) {
+ data_ = capacity_ > 0 ?
+ reinterpret_cast<char*>(malloc(capacity_)) : nullptr;
+ end_ = data_;
+ rptr_ = data_;
+ wptr_ = data_;
+ }
+ ~BufferImpl() {
+ free(data_);
+ }
+
+ void const* read_ptr(size_t* avail) const override {
+ if (avail) *avail = wptr_ - rptr_;
+ return rptr_;
+ }
+
+ void consume(size_t bytes) override {
+ if (bytes == 0) return;
+ assert(rptr_ + bytes <= wptr_);
+ rptr_ += bytes;
+ if (rptr_ == wptr_) {
+ rptr_ = wptr_ = data_;
+ }
+ }
+
+ void* write_ptr(size_t* avail) override {
+ if (wptr_ + min_avail_ > end_) {
+ if (rptr_ > data_) {
+ memmove(data_, rptr_, wptr_ - rptr_);
+ wptr_ -= rptr_ - data_;
+ rptr_ = data_;
+ }
+ if (wptr_ + min_avail_ > end_) {
+ auto new_size = (end_ - data_) + std::max(capacity_, min_avail_);
+ auto tmp = reinterpret_cast<char*>(realloc(data_, new_size));
+ if (tmp) {
+ end_ = tmp + new_size;
+ rptr_ = tmp + (rptr_ - data_);
+ wptr_ = tmp + (wptr_ - data_);
+ data_ = tmp;
+ }
+ }
+ }
+ if (avail) *avail = end_ - wptr_;
+ return wptr_;
+ }
+
+ void commit(size_t bytes) override {
+ if (bytes == 0) return;
+ assert(wptr_ + bytes <= end_);
+ wptr_ += bytes;
+ }
+
+private:
+ size_t capacity_;
+ size_t min_avail_;
+ char* data_;
+ char* end_;
+ char* rptr_;
+ char* wptr_;
+};
+
+} // namespace
+
+// static
+Buffer* Buffer::create(size_t size, size_t min_avail) {
+ return new BufferImpl(std::max(size, min_avail), min_avail);
+}
+
+bool Buffer::empty() const {
+ size_t avail;
+ read_ptr(&avail);
+ return avail == 0;
+}
+
+size_t Buffer::read(void* data, size_t max) {
+ if (max == 0) return 0;
+ size_t avail;
+ auto ptr = read_ptr(&avail);
+ if (avail == 0) return 0;
+ avail = std::min(avail, max);
+ memcpy(data, ptr, avail);
+ commit(avail);
+ return avail;
+}
+
+void Buffer::write(void const* data, size_t size) {
+ if (size == 0) return;
+ auto d = reinterpret_cast<char const*>(data);
+ size_t pos = 0;
+ while (true) {
+ size_t avail;
+ auto ptr = write_ptr(&avail);
+ if (pos + avail < size) {
+ memcpy(ptr, d + pos, avail);
+ pos += avail;
+ commit(avail);
+ } else {
+ memcpy(ptr, d + pos, size - pos);
+ commit(size - pos);
+ return;
+ }
+ }
+}
+
diff --git a/src/buffer.hh b/src/buffer.hh
new file mode 100644
index 0000000..92a7566
--- /dev/null
+++ b/src/buffer.hh
@@ -0,0 +1,30 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#ifndef BUFFER_HH
+#define BUFFER_HH
+
+#include <cstddef>
+
+class Buffer {
+public:
+ virtual ~Buffer() {}
+
+ static Buffer* create(size_t capacity, size_t min_avail);
+
+ bool empty() const;
+
+ virtual void const* read_ptr(size_t* avail) const = 0;
+ virtual void consume(size_t bytes) = 0;
+
+ virtual void* write_ptr(size_t* avail) = 0;
+ virtual void commit(size_t bytes) = 0;
+
+ size_t read(void* data, size_t max);
+ void write(void const* data, size_t size);
+
+protected:
+ Buffer() {}
+ Buffer(Buffer const&) = delete;
+};
+
+#endif // BUFFER_HH
diff --git a/src/character.cc b/src/character.cc
new file mode 100644
index 0000000..98166ee
--- /dev/null
+++ b/src/character.cc
@@ -0,0 +1,31 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#include "common.hh"
+
+#include "character.hh"
+
+// static
+bool Character::isspace(std::string const& str, size_t pos) {
+ switch (str[pos]) {
+ case ' ':
+ case '\t':
+ case '\r':
+ case '\n':
+ return true;
+ }
+ return false;
+}
+
+// static
+bool Character::isseparator(std::string const& str, size_t pos) {
+ if (isspace(str, pos)) return true;
+ switch (str[pos]) {
+ case '.':
+ case ':':
+ case '-':
+ case ',':
+ case ';':
+ return true;
+ }
+ return false;
+}
diff --git a/src/character.hh b/src/character.hh
new file mode 100644
index 0000000..18f98d8
--- /dev/null
+++ b/src/character.hh
@@ -0,0 +1,18 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#ifndef CHARACTER_HH
+#define CHARACTER_HH
+
+#include <string>
+
+class Character {
+public:
+ static bool isspace(std::string const& str, size_t pos);
+ static bool isseparator(std::string const& str, size_t pos);
+
+private:
+ ~Character() {}
+ Character() {}
+};
+
+#endif // CHARACTER_HH
diff --git a/src/chunked.cc b/src/chunked.cc
new file mode 100644
index 0000000..99aea0d
--- /dev/null
+++ b/src/chunked.cc
@@ -0,0 +1,106 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#include "common.hh"
+
+#include <cerrno>
+#include <cstdint>
+#include <cstdlib>
+
+#include "chunked.hh"
+
+namespace {
+
+enum State {
+ CHUNK,
+ IN_CHUNK,
+ TRAILER,
+ DONE,
+};
+
+class ChunkedImpl : public Chunked {
+public:
+ ChunkedImpl()
+ : state_(CHUNK), good_(true) {
+ }
+
+ size_t add(void const* data, size_t avail) override {
+ if (!good_) return 0;
+ auto const start = reinterpret_cast<char const*>(data);
+ auto const end = start + avail;
+ auto d = start;
+ while (true) {
+ if (d == end) return avail;
+ switch (state_) {
+ case CHUNK: {
+ auto p = find_crlf(d, end);
+ if (!p) return d - start;
+ char* x = nullptr;
+ errno = 0;
+ auto tmp = strtoull(d, &x, 16);
+ if (errno || (x != p && (!x || *x != ';'))) {
+ good_ = false;
+ return d - start;
+ }
+ size_ = tmp;
+ d = p + 2;
+ if (size_ == 0) {
+ // Last chunk
+ state_ = TRAILER;
+ } else {
+ state_ = IN_CHUNK;
+ }
+ break;
+ }
+ case IN_CHUNK:
+ if (static_cast<uint64_t>(end - d) < size_) {
+ return avail;
+ }
+ d += size_;
+ state_ = CHUNK;
+ break;
+ case TRAILER: {
+ auto p = find_crlf(d, end);
+ if (!p) return d - start;
+ if (p == d) {
+ state_ = DONE;
+ }
+ d = p + 2;
+ break;
+ }
+ case DONE:
+ return d - start;
+ }
+ }
+ }
+
+ bool good() const override {
+ return good_;
+ }
+
+ bool eof() const override {
+ return state_ == DONE;
+ }
+
+private:
+ char const* find_crlf(char const* start, char const* end) {
+ for (; start != end; ++start) {
+ if (*start == '\r') {
+ if (start + 1 == end) break;
+ if (start[1] == '\n') return start;
+ }
+ }
+ return nullptr;
+ }
+
+ State state_;
+ bool good_;
+ uint64_t size_;
+};
+
+} // namespace
+
+// static
+Chunked* Chunked::create() {
+ return new ChunkedImpl();
+}
+
diff --git a/src/chunked.hh b/src/chunked.hh
new file mode 100644
index 0000000..66d3ae7
--- /dev/null
+++ b/src/chunked.hh
@@ -0,0 +1,23 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#ifndef CHUNKED_HH
+#define CHUNKED_HH
+
+#include <cstddef>
+
+class Chunked {
+public:
+ virtual ~Chunked() { }
+
+ static Chunked* create();
+
+ virtual size_t add(void const* data, size_t avail) = 0;
+ virtual bool good() const = 0;
+ virtual bool eof() const = 0;
+
+protected:
+ Chunked() {}
+ Chunked(Chunked const&) = delete;
+};
+
+#endif // CHUNKED_HH
diff --git a/src/common.hh b/src/common.hh
new file mode 100644
index 0000000..67c8fa5
--- /dev/null
+++ b/src/common.hh
@@ -0,0 +1,18 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#ifndef COMMON_HH
+#define COMMON_HH
+
+#ifdef HAVE_CONFIG_H
+# include "config.h"
+#endif
+
+#if HAVE_VAR_ATTRIBUTE_UNUSED
+# define UNUSED(x) __attribute__((unused)) x ## _unused
+#else
+# define UNUSED(x) x ## _unused
+#endif
+
+#include <cassert>
+
+#endif // COMMON_HH
diff --git a/src/config.cc b/src/config.cc
new file mode 100644
index 0000000..9824044
--- /dev/null
+++ b/src/config.cc
@@ -0,0 +1,176 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#include "common.hh"
+
+#include "config.hh"
+#include "paths.hh"
+#include "strings.hh"
+#include "xdg.hh"
+
+#include <cstring>
+#include <fstream>
+#include <sstream>
+#include <unordered_map>
+
+namespace {
+
+class ConfigImpl : public Config {
+public:
+ ConfigImpl()
+ : good_(true) {
+ }
+
+ bool load_name(std::string const& name) override {
+ name_ = name;
+
+ auto dirs = XDG::config_dirs();
+ bool good = true;
+ std::string error;
+ std::unordered_map<std::string,std::string> data;
+ for (auto& dir : dirs) {
+ std::string tmp_error;
+ std::unordered_map<std::string,std::string> tmp_data;
+ if (load_file(Paths::join(dir, name_), &tmp_data, &tmp_error)) {
+ data.insert(tmp_data.begin(), tmp_data.end());
+ } else {
+ if (good) {
+ // We want the most import error
+ error = tmp_error;
+ good = false;
+ }
+ }
+ }
+ return update(good, error, data);
+ }
+
+ bool load_file(std::string const& filename) override {
+ std::string error;
+ std::unordered_map<std::string,std::string> data;
+ bool good = load_file(filename, &data, &error);
+ return update(good, error, data);
+ }
+
+ bool good() const override {
+ return good_;
+ }
+
+ std::string const& last_error() const override {
+ return last_error_;
+ }
+
+ std::string const& get(std::string const& key,
+ std::string const& fallback) override {
+ auto i = override_.find(key);
+ if (i != override_.end()) return i->second;
+ i = data_.find(key);
+ if (i != data_.end()) return i->second;
+ return fallback;
+ }
+ char const* get(std::string const& key, char const* fallback) override {
+ auto i = override_.find(key);
+ if (i != override_.end()) return i->second.c_str();
+ i = data_.find(key);
+ if (i != data_.end()) return i->second.c_str();
+ return fallback;
+ }
+ bool is_set(std::string const& key) override {
+ return override_.count(key) > 0 || data_.count(key) > 0;
+ }
+
+ bool get(std::string const& key, bool fallback) override {
+ auto ret = get(key, nullptr);
+ if (!ret) return fallback;
+ return strcmp(ret, "true") == 0;
+ }
+
+ void set(std::string const& key, std::string const& value) override {
+ auto ret = override_.insert(std::make_pair(key, value));
+ if (!ret.second) {
+ ret.first->second = value;
+ }
+ }
+
+private:
+ bool update(bool good, std::string const& last_error,
+ std::unordered_map<std::string, std::string>& data) {
+ good_ = good;
+ if (!good_) {
+ last_error_ = last_error;
+ } else {
+ data_.clear();
+ data_.swap(data);
+ }
+ return good_;
+ }
+
+ static bool load_file(std::string const& filename,
+ std::unordered_map<std::string, std::string>* data,
+ std::string* error) {
+ bool good = true;
+ data->clear();
+ error->clear();
+
+ std::ifstream in(filename);
+ // Non existent file is not considered an error
+ if (in) {
+ std::string line;
+ uint32_t count = 0;
+ std::string key, value;
+ while (std::getline(in, line)) {
+ count++;
+ if (line.empty() || line[0] == '#') continue;
+ auto idx = line.find('=');
+ if (idx == 0 || idx == std::string::npos) {
+ std::stringstream ss;
+ if (idx == 0) {
+ ss << filename << ':' << count << ": Invalid line, starts with '='";
+ } else {
+ ss << filename << ':' << count << ": Invalid line, no '=' found";
+ }
+ *error = ss.str();
+ good = false;
+ break;
+ }
+ size_t start = 0, end = idx;
+ key.assign(Strings::trim(line, start, end));
+ if (data->count(key) > 0) {
+ std::stringstream ss;
+ ss << filename << ':' << count << ": '" << key << "' is already set";
+ *error = ss.str();
+ good = false;
+ break;
+ }
+ start = idx + 1;
+ end = line.size();
+ Strings::trim(line, &start, &end);
+ if (line[start] == '"' && line[end - 1] == '"') {
+ value.assign(Strings::unquote(line, start, end));
+ } else {
+ value.assign(line.substr(start, end - start));
+ }
+ (*data)[key] = value;
+ }
+ if (good && in.bad()) {
+ std::stringstream ss;
+ ss << filename << ": I/O error: " << strerror(errno);
+ *error = ss.str();
+ good = false;
+ }
+ if (!good) data->clear();
+ }
+ return good;
+ }
+
+ bool good_;
+ std::string name_;
+ std::string last_error_;
+ std::unordered_map<std::string, std::string> data_;
+ std::unordered_map<std::string, std::string> override_;
+};
+
+} // namespace
+
+// static
+Config* Config::create() {
+ return new ConfigImpl();
+}
diff --git a/src/config.hh b/src/config.hh
new file mode 100644
index 0000000..5262e91
--- /dev/null
+++ b/src/config.hh
@@ -0,0 +1,33 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#ifndef CONFIG_HH
+#define CONFIG_HH
+
+#include <string>
+
+class Config {
+public:
+ virtual ~Config() { }
+
+ static Config* create();
+
+ virtual bool load_name(std::string const& name) = 0;
+ virtual bool load_file(std::string const& filename) = 0;
+
+ virtual bool good() const = 0;
+ virtual std::string const& last_error() const = 0;
+
+ virtual std::string const& get(std::string const& key,
+ std::string const& fallback) = 0;
+ virtual char const* get(std::string const& key, char const* fallback) = 0;
+ virtual bool get(std::string const& key, bool fallback) = 0;
+ virtual bool is_set(std::string const& key) = 0;
+
+ virtual void set(std::string const& key, std::string const& value) = 0;
+
+protected:
+ Config() { }
+ Config(Config const&) = delete;
+};
+
+#endif // CONFIG_HH
diff --git a/src/http.cc b/src/http.cc
new file mode 100644
index 0000000..c043c87
--- /dev/null
+++ b/src/http.cc
@@ -0,0 +1,657 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#include "common.hh"
+
+#include <cstring>
+#include <memory>
+#include <sstream>
+#include <vector>
+
+#include "http.hh"
+
+namespace {
+
+std::string make_string(char const* data, size_t start, size_t end) {
+ assert(start <= end);
+ return std::string(data + start, end - start);
+}
+
+uint16_t number(char const* data, size_t start, size_t end) {
+ uint16_t ret = 0;
+ assert(start < end);
+ for (; start < end; ++start) {
+ ret *= 10;
+ ret += data[start] - '0';
+ }
+ return ret;
+}
+
+inline char lower_ascii(char c) {
+ return (c >= 'A' && c <= 'Z') ? (c - 'A' + 'a') : c;
+}
+
+inline char upper_ascii(char c) {
+ return (c >= 'a' && c <= 'z') ? (c - 'a' + 'A') : c;
+}
+
+bool lower_equal(char const* data, size_t start, size_t end,
+ std::string const& str) {
+ assert(start <= end);
+ if (str.size() != end - start) return false;
+ for (auto i = str.begin(); start < end; ++start, ++i) {
+ if (lower_ascii(*i) != lower_ascii(data[start])) return false;
+ }
+ return true;
+}
+
+bool lower_equal2(char const* data, size_t start, size_t end,
+ char const* str, size_t size) {
+ assert(start <= end);
+ if (size != end - start) return false;
+ for (auto i = str; start < end; ++start, ++i) {
+ if (*i != lower_ascii(data[start])) return false;
+ }
+ return true;
+}
+
+bool allow_header_append(char const* data, size_t start, size_t end) {
+ // These headers doesn't handle being merged with ',' even if the standard
+ // say they must
+ return !lower_equal2(data, start, end, "set-cookie", 10) &&
+ !lower_equal2(data, start, end, "set-cookie2", 11);
+}
+
+enum ParseResult {
+ GOOD,
+ BAD,
+ INCOMPLETE,
+};
+
+class HeaderIteratorImpl : public HeaderIterator {
+public:
+ HeaderIteratorImpl(char const* data, std::vector<size_t> const* headers)
+ : data_(data), headers_(headers), iter_(headers_->begin()) {
+ }
+
+ bool valid() const override {
+ return iter_ != headers_->end();
+ }
+ std::string name() const override {
+ return make_string(data_, iter_[0], iter_[1]);
+ }
+ bool name_equal(std::string const& name) const override {
+ return lower_equal(data_, iter_[0], iter_[1], name);
+ }
+ std::string value() const override {
+ std::string ret = make_string(data_, iter_[2], iter_[3]);
+ if (allow_header_append(data_, iter_[0], iter_[1])) {
+ auto i = iter_ + 4;
+ while (i != headers_->end()) {
+ if (i[0] != i[1]) break;
+ ret.push_back(',');
+ ret.append(data_ + i[2], i[3] - i[2]);
+ i += 4;
+ }
+ }
+ return ret;
+ }
+ void next() override {
+ if (iter_ != headers_->end()) {
+ while (true) {
+ iter_ += 4;
+ if (iter_ == headers_->end() || iter_[0] != iter_[1]) break;
+ }
+ }
+ }
+
+private:
+ char const* const data_;
+ std::vector<size_t> const* const headers_;
+ std::vector<size_t>::const_iterator iter_;
+};
+
+class FilterHeaderIteratorImpl : public HeaderIteratorImpl {
+public:
+ FilterHeaderIteratorImpl(char const* data, std::vector<size_t> const* headers,
+ std::string const& filter)
+ : HeaderIteratorImpl(data, headers), filter_(filter) {
+ check_filter();
+ }
+
+ void next() override {
+ HeaderIteratorImpl::next();
+ check_filter();
+ }
+
+private:
+ void check_filter() {
+ while (true) {
+ if (!valid() || name_equal(filter_)) return;
+ next();
+ }
+ }
+
+ std::string const filter_;
+};
+
+class AbstractHttp : public virtual Http {
+public:
+ AbstractHttp(char const* data, size_t size)
+ : data_(data), size_(size) {
+ }
+
+ std::string proto() const override {
+ return make_string(data_, proto_start_, proto_slash_);
+ }
+
+ bool proto_equal(std::string const& proto) const override {
+ return proto.compare(0, proto.size(),
+ data_ + proto_start_, proto_slash_ - proto_start_) == 0;
+ }
+
+ Version proto_version() const override {
+ Version ret;
+ ret.major = number(data_, proto_slash_ + 1, proto_dot_);
+ ret.minor = number(data_, proto_dot_ + 1, proto_end_);
+ return ret;
+ }
+
+ std::unique_ptr<HeaderIterator> header() const override {
+ return std::unique_ptr<HeaderIterator>(
+ new HeaderIteratorImpl(data_, &headers_));
+ }
+ std::unique_ptr<HeaderIterator> header(
+ std::string const& name) const override {
+ return std::unique_ptr<HeaderIterator>(
+ new FilterHeaderIteratorImpl(data_, &headers_, name));
+ }
+
+ size_t size() const override {
+ return content_start_;
+ }
+
+protected:
+ ParseResult parse_headers() {
+ headers_.clear();
+ while (true) {
+ auto start = content_start_;
+ auto end = find_newline(start, &content_start_);
+ if (end == std::string::npos) return INCOMPLETE;
+ if (end == start) {
+ // The final newline can only be a alone '\r' if the one in front of
+ // it is also '\r', otherwise we expect a missing '\n'
+ if (data_[start - 1] == '\n' && data_[content_start_ - 1] == '\r') {
+ return INCOMPLETE;
+ }
+ break;
+ }
+ if (is_lws(data_[start])) {
+ if (headers_.empty()) return BAD;
+ headers_.push_back(start);
+ headers_.push_back(start);
+ headers_.push_back(start + 1);
+ headers_.push_back(end);
+ } else {
+ auto colon = find(start, ':', end);
+ if (colon == std::string::npos) return BAD;
+ auto value_start = skip_lws(colon + 1, end);
+ while (colon > start && is_lws(data_[colon - 1])) --colon;
+ headers_.push_back(start);
+ headers_.push_back(colon);
+ headers_.push_back(value_start);
+ headers_.push_back(end);
+ }
+ }
+ return GOOD;
+ }
+
+ void new_data(char const* data) {
+ data_ = data;
+ }
+
+ char const* data() const {
+ return data_;
+ }
+
+ size_t data_size() const {
+ return size_;
+ }
+
+ size_t find(size_t start, char c, size_t end) const {
+ assert(start <= end);
+ for (; start < end; ++start) {
+ if (data_[start] == c) return start;
+ }
+ return std::string::npos;
+ }
+
+ static bool is_lws(char c) {
+ return c == ' ' || c == '\t';
+ }
+
+ size_t skip_lws(size_t start, size_t end) const {
+ assert(start <= end);
+ while (start < end && is_lws(data_[start])) ++start;
+ return start;
+ }
+
+ size_t find_newline(size_t start, size_t* next) const {
+ assert(start <= size_);
+ for (; start < size_; ++start) {
+ if (data_[start] == '\r') {
+ if (start + 1 < size_ && data_[start + 1] == '\n') {
+ if (next) *next = start + 2;
+ } else {
+ if (next) *next = start + 1;
+ }
+ return start;
+ } else if (data_[start] == '\n') {
+ if (next) *next = start + 1;
+ return start;
+ }
+ }
+ return std::string::npos;
+ }
+
+ size_t valid_number(size_t start, size_t end) const {
+ assert(start <= end);
+ if (start == end) return std::string::npos;
+ if (data_[start] == '0') {
+ return start + 1;
+ }
+ if (data_[start] < '0' || data_[start] > '9') return std::string::npos;
+ for (++start; start < end; ++start) {
+ if (data_[start] < '0' || data_[start] > '9') break;
+ }
+ return start;
+ }
+
+ char const* data_;
+ size_t const size_;
+ size_t proto_start_;
+ size_t proto_slash_;
+ size_t proto_dot_;
+ size_t proto_end_;
+ std::vector<size_t> headers_;
+ size_t content_start_;
+};
+
+class AbstractHttpResponse : public HttpResponse, protected AbstractHttp {
+public:
+ AbstractHttpResponse(char const* data, size_t size)
+ : AbstractHttp(data, size), good_(false) {
+ }
+
+ bool good() const override {
+ return good_;
+ }
+
+ uint16_t status_code() const override {
+ return number(data_, status_start_, status_end_);
+ }
+
+ std::string status_message() const override {
+ return make_string(data_, status_msg_start_, status_msg_end_);
+ }
+
+ ParseResult parse() {
+ good_ = false;
+ status_msg_end_ = find_newline(0, &content_start_);
+ if (status_msg_end_ == std::string::npos) return INCOMPLETE;
+ proto_start_ = 0;
+ proto_slash_ = find(0, '/', status_msg_end_);
+ if (proto_slash_ == std::string::npos) return BAD;
+ proto_dot_ = valid_number(proto_slash_ + 1, status_msg_end_);
+ if (proto_dot_ == std::string::npos || data_[proto_dot_] != '.') {
+ return BAD;
+ }
+ proto_end_ = valid_number(proto_dot_ + 1, status_msg_end_);
+ if (proto_end_ == std::string::npos || !is_lws(data_[proto_end_])) {
+ return BAD;
+ }
+ status_start_ = skip_lws(proto_end_ + 1, status_msg_end_);
+ status_end_ = valid_number(status_start_, status_msg_end_);
+ if (status_end_ == std::string::npos) return BAD;
+ if (is_lws(data_[status_end_])) {
+ status_msg_start_ = skip_lws(status_end_ + 1, status_msg_end_);
+ } else {
+ status_msg_start_ = status_end_;
+ if (status_msg_start_ != status_msg_end_) return BAD;
+ }
+
+ auto ret = parse_headers();
+ if (ret == GOOD) good_ = true;
+ return ret;
+ }
+
+protected:
+ bool good_;
+ size_t status_start_;
+ size_t status_end_;
+ size_t status_msg_start_;
+ size_t status_msg_end_;
+};
+
+class UniqueHttpResponse : public AbstractHttpResponse {
+public:
+ UniqueHttpResponse(char const* data, size_t size)
+ : AbstractHttpResponse(data, size) {
+ }
+
+ void copy() {
+ assert(!data_);
+ auto tmp = new char[data_size()];
+ memcpy(tmp, data(), data_size());
+ new_data(tmp);
+ data_.reset(tmp);
+ }
+
+private:
+ std::unique_ptr<char[]> data_;
+};
+
+class SharedHttpResponse : public AbstractHttpResponse {
+public:
+ SharedHttpResponse(std::shared_ptr<char> data, size_t offset, size_t len)
+ : AbstractHttpResponse(data.get() + offset, len), data_(data) {
+ }
+
+private:
+ std::shared_ptr<char> data_;
+};
+
+class AbstractHttpRequest : public HttpRequest, protected AbstractHttp {
+public:
+ AbstractHttpRequest(char const* data, size_t size)
+ : AbstractHttp(data, size), good_(false) {
+ }
+
+ bool good() const override {
+ return good_;
+ }
+
+ std::string method() const override {
+ return make_string(data_, 0, method_end_);
+ }
+
+ bool method_equal(std::string const& method) const override {
+ return method.compare(0, method.size(), data_, method_end_) == 0;
+ }
+ std::string url() const override {
+ return make_string(data_, url_start_, url_end_);
+ }
+
+ ParseResult parse() {
+ good_ = false;
+ proto_end_ = find_newline(0, &content_start_);
+ if (proto_end_ == std::string::npos) return INCOMPLETE;
+ method_end_ = 0;
+ while (method_end_ < proto_end_ && !is_lws(data_[method_end_])) {
+ ++method_end_;
+ }
+ if (method_end_ == 0 || method_end_ == proto_end_) return BAD;
+ url_start_ = skip_lws(method_end_ + 1, proto_end_);
+ url_end_ = url_start_;
+ while (url_end_ < proto_end_ && !is_lws(data_[url_end_])) {
+ ++url_end_;
+ }
+ if (url_end_ == url_start_ || url_end_ == proto_end_) return BAD;
+ proto_start_ = skip_lws(url_end_ + 1, proto_end_);
+ proto_slash_ = find(proto_start_, '/', proto_end_);
+ if (proto_slash_ == std::string::npos) return BAD;
+ proto_dot_ = valid_number(proto_slash_ + 1, proto_end_);
+ if (proto_dot_ == std::string::npos || data_[proto_dot_] != '.') {
+ return BAD;
+ }
+ auto tmp = valid_number(proto_dot_ + 1, proto_end_);
+ if (tmp != proto_end_) return BAD;
+
+ auto ret = parse_headers();
+ if (ret == GOOD) good_ = true;
+ return ret;
+ }
+
+protected:
+ bool good_;
+ size_t method_end_;
+ size_t url_start_;
+ size_t url_end_;
+};
+
+class UniqueHttpRequest : public AbstractHttpRequest {
+public:
+ UniqueHttpRequest(char const* data, size_t size)
+ : AbstractHttpRequest(data, size) {
+ }
+
+ void copy() {
+ assert(!data_);
+ auto tmp = new char[data_size()];
+ memcpy(tmp, data(), data_size());
+ new_data(tmp);
+ data_.reset(tmp);
+ }
+
+private:
+ std::unique_ptr<char[]> data_;
+};
+
+class SharedHttpRequest : public AbstractHttpRequest {
+public:
+ SharedHttpRequest(std::shared_ptr<char> data, size_t offset, size_t len)
+ : AbstractHttpRequest(data.get() + offset, len), data_(data) {
+ }
+
+private:
+ std::shared_ptr<char> data_;
+};
+
+class AbstractHttpBuilder {
+public:
+ AbstractHttpBuilder()
+ : set_content_length_(false), set_content_(false) {
+ }
+
+ void add_header(std::string const& name,
+ std::string const& value) {
+ if (name.empty()) {
+ data_.push_back(' ');
+ data_.append(value);
+ data_.append("\r\n", 2);
+ return;
+ }
+ size_t pos = 0;
+ auto c = upper_ascii(name[pos]);
+ if (c != name[pos]) {
+ data_.push_back(c);
+ ++pos;
+ }
+ auto last = pos;
+ for (; pos < name.size(); ++pos) {
+ if (name[pos] == '-') {
+ ++pos;
+ if (pos == name.size()) break;
+ data_.append(name.data() + last, pos - last);
+ auto c = upper_ascii(name[pos]);
+ if (c != name[pos]) {
+ data_.push_back(c);
+ ++pos;
+ }
+ last = pos;
+ }
+ }
+ if (!set_content_length_
+ && lower_equal2(name.data(), 0, name.size(), "content-length", 14)) {
+ set_content_length_ = true;
+ }
+ if (last < pos) {
+ data_.append(name.data() + last, pos - last);
+ }
+ data_.append(": ", 2);
+ data_.append(value);
+ data_.append("\r\n", 2);
+ }
+
+ void set_content(std::string const& content) {
+ set_content_ = true;
+ content_ = content;
+ }
+
+ std::string build() const {
+ std::string ret(data_);
+ if (!set_content_length_ && set_content_) {
+ char tmp[50];
+ auto len = snprintf(tmp, sizeof(tmp),
+ "Content-Length: %zu\r\n", content_.size());
+ ret.append(tmp, len);
+ }
+ ret.append("\r\n", 2);
+ if (set_content_) {
+ ret.append(content_);
+ }
+ return ret;
+ }
+
+protected:
+ std::string data_;
+ std::string content_;
+ bool set_content_length_;
+ bool set_content_;
+};
+
+class HttpRequestBuilderImpl : public HttpRequestBuilder, AbstractHttpBuilder {
+public:
+ HttpRequestBuilderImpl(std::string const& method, std::string const& url,
+ std::string const& proto, Version version) {
+ data_.append(method);
+ data_.push_back(' ');
+ data_.append(url);
+ data_.push_back(' ');
+ data_.append(proto);
+ data_.push_back('/');
+ char tmp[10];
+ data_.append(tmp, snprintf(
+ tmp, 10, "%u", static_cast<unsigned int>(version.major)));
+ data_.push_back('.');
+ data_.append(tmp, snprintf(
+ tmp, 10, "%u", static_cast<unsigned int>(version.minor)));
+ data_.append("\r\n", 2);
+ }
+
+ void add_header(std::string const& name,
+ std::string const& value) override {
+ AbstractHttpBuilder::add_header(name, value);
+ }
+ void set_content(std::string const& content) override {
+ AbstractHttpBuilder::set_content(content);
+ }
+ std::string build() const override {
+ return AbstractHttpBuilder::build();
+ }
+};
+
+class HttpResponseBuilderImpl : public HttpResponseBuilder,
+ AbstractHttpBuilder {
+public:
+ HttpResponseBuilderImpl(std::string const& proto, Version version,
+ uint16_t status_code, std::string const& status) {
+ data_.append(proto);
+ data_.push_back('/');
+ char tmp[10];
+ data_.append(tmp, snprintf(
+ tmp, 10, "%u", static_cast<unsigned int>(version.major)));
+ data_.push_back('.');
+ data_.append(tmp, snprintf(
+ tmp, 10, "%u", static_cast<unsigned int>(version.minor)));
+ data_.push_back(' ');
+ data_.append(tmp, snprintf(
+ tmp, 10, "%u", static_cast<unsigned int>(status_code)));
+ data_.push_back(' ');
+ data_.append(status);
+ data_.append("\r\n", 2);
+ }
+
+ void add_header(std::string const& name,
+ std::string const& value) override {
+ AbstractHttpBuilder::add_header(name, value);
+ }
+ void set_content(std::string const& content) override {
+ AbstractHttpBuilder::set_content(content);
+ }
+ std::string build() const override {
+ return AbstractHttpBuilder::build();
+ }
+};
+
+} // namespace
+
+// static
+HttpResponse* HttpResponse::parse(std::string const& data) {
+ return parse(data.data(), data.size(), true);
+}
+
+// static
+HttpResponse* HttpResponse::parse(char const* data, size_t len, bool copy) {
+ auto ret = std::unique_ptr<UniqueHttpResponse>(
+ new UniqueHttpResponse(data, len));
+ if (ret->parse() == INCOMPLETE) return nullptr;
+ if (copy) ret->copy();
+ return ret.release();
+}
+
+// static
+HttpResponse* HttpResponse::parse(std::shared_ptr<char> data,
+ size_t offset, size_t len) {
+ auto ret = std::unique_ptr<SharedHttpResponse>(
+ new SharedHttpResponse(data, offset, len));
+ if (ret->parse() == INCOMPLETE) return nullptr;
+ return ret.release();
+}
+
+std::string Http::first_header(std::string const& name) const {
+ static std::string empty_str;
+ auto iter = header(name);
+ if (iter->valid()) {
+ return iter->value();
+ }
+ return empty_str;
+}
+
+// static
+HttpRequest* HttpRequest::parse(std::string const& data) {
+ return parse(data.data(), data.size(), true);
+}
+
+// static
+HttpRequest* HttpRequest::parse(char const* data, size_t len, bool copy) {
+ auto ret = std::unique_ptr<UniqueHttpRequest>(
+ new UniqueHttpRequest(data, len));
+ if (ret->parse() == INCOMPLETE) return nullptr;
+ if (copy) ret->copy();
+ return ret.release();
+}
+
+// static
+HttpRequest* HttpRequest::parse(std::shared_ptr<char> data,
+ size_t offset, size_t len) {
+ auto ret = std::unique_ptr<SharedHttpRequest>(
+ new SharedHttpRequest(data, offset, len));
+ if (ret->parse() == INCOMPLETE) return nullptr;
+ return ret.release();
+}
+
+// static
+HttpRequestBuilder* HttpRequestBuilder::create(std::string const& method,
+ std::string const& url,
+ std::string const& proto,
+ Version version) {
+ return new HttpRequestBuilderImpl(method, url, proto, version);
+}
+
+// static
+HttpResponseBuilder* HttpResponseBuilder::create(std::string const& proto,
+ Version version,
+ uint16_t status_code,
+ std::string const& status) {
+ return new HttpResponseBuilderImpl(proto, version, status_code, status);
+}
diff --git a/src/http.hh b/src/http.hh
new file mode 100644
index 0000000..091d3d4
--- /dev/null
+++ b/src/http.hh
@@ -0,0 +1,141 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#ifndef HTTP_HH
+#define HTTP_HH
+
+#include <cstdint>
+#include <memory>
+#include <string>
+
+// glibc defines these even tho it shouldn't if you ask posix
+#ifdef major
+# undef major
+#endif
+#ifdef minor
+# undef minor
+#endif
+
+struct Version {
+ uint16_t major;
+ uint16_t minor;
+
+ Version()
+ : major(0), minor(0) {
+ }
+ Version(uint16_t major, uint16_t minor)
+ : major(major), minor(minor) {
+ }
+};
+
+class HeaderIterator {
+public:
+ virtual ~HeaderIterator() {}
+
+ virtual bool valid() const = 0;
+ virtual std::string name() const = 0;
+ virtual bool name_equal(std::string const& name) const = 0;
+ virtual std::string value() const = 0;
+ virtual void next() = 0;
+
+protected:
+ HeaderIterator() {}
+ HeaderIterator(HeaderIterator const&) = delete;
+};
+
+class Http {
+public:
+ virtual ~Http() {}
+
+ virtual bool good() const = 0;
+
+ virtual std::string proto() const = 0;
+ virtual bool proto_equal(std::string const& proto) const = 0;
+ virtual Version proto_version() const = 0;
+ virtual std::unique_ptr<HeaderIterator> header() const = 0;
+ virtual std::unique_ptr<HeaderIterator> header(
+ std::string const& name) const = 0;
+ std::string first_header(std::string const& name) const;
+
+ virtual size_t size() const = 0;
+
+protected:
+ Http() {}
+ Http(Http const&) = delete;
+};
+
+class HttpResponse : public virtual Http {
+public:
+ virtual ~HttpResponse() {}
+
+ static HttpResponse* parse(std::string const& data);
+ static HttpResponse* parse(char const* data, size_t len, bool copy = true);
+ static HttpResponse* parse(std::shared_ptr<char> data,
+ size_t offset, size_t len);
+
+ virtual uint16_t status_code() const = 0;
+ virtual std::string status_message() const = 0;
+
+protected:
+ HttpResponse() {}
+ HttpResponse(HttpResponse const&) = delete;
+};
+
+class HttpRequest : public virtual Http {
+public:
+ virtual ~HttpRequest() {}
+
+ static HttpRequest* parse(std::string const& data);
+ static HttpRequest* parse(char const* data, size_t len, bool copy = true);
+ static HttpRequest* parse(std::shared_ptr<char> data,
+ size_t offset, size_t len);
+
+ virtual std::string method() const = 0;
+ virtual bool method_equal(std::string const& method) const = 0;
+ virtual std::string url() const = 0;
+
+protected:
+ HttpRequest() {}
+ HttpRequest(HttpRequest const&) = delete;
+};
+
+class HttpResponseBuilder {
+public:
+ virtual ~HttpResponseBuilder() {}
+
+ static HttpResponseBuilder* create(
+ std::string const& proto, Version version,
+ uint16_t status_code, std::string const& status_message);
+
+ virtual void add_header(std::string const& name,
+ std::string const& value) = 0;
+ // This will add a content-length header unless there already is one
+ virtual void set_content(std::string const& content) = 0;
+
+ virtual std::string build() const = 0;
+
+protected:
+ HttpResponseBuilder() {}
+ HttpResponseBuilder(HttpResponseBuilder const&) = delete;
+};
+
+class HttpRequestBuilder {
+public:
+ virtual ~HttpRequestBuilder() {}
+
+ static HttpRequestBuilder* create(
+ std::string const& method, std::string const& url,
+ std::string const& proto, Version version);
+
+ virtual void add_header(std::string const& name,
+ std::string const& value) = 0;
+ // This will add a content-length header unless there already is one
+ virtual void set_content(std::string const& content) = 0;
+
+ virtual std::string build() const = 0;
+
+protected:
+ HttpRequestBuilder() {}
+ HttpRequestBuilder(HttpRequestBuilder const&) = delete;
+};
+
+#endif // HTTP_HH
diff --git a/src/io.cc b/src/io.cc
new file mode 100644
index 0000000..5c33ea2
--- /dev/null
+++ b/src/io.cc
@@ -0,0 +1,90 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#include "common.hh"
+
+#include <algorithm>
+#include <cerrno>
+#include <limits>
+#include <unistd.h>
+
+#include "io.hh"
+
+namespace io {
+
+ssize_t read(int fd, void* buf, size_t max) {
+ while (true) {
+ ssize_t ret = ::read(fd, buf, max);
+ if (ret != -1 || errno != EINTR) return ret;
+ }
+}
+
+ssize_t write(int fd, void const* buf, size_t max) {
+ while (true) {
+ ssize_t ret = ::write(fd, buf, max);
+ if (ret != -1 || errno != EINTR) return ret;
+ }
+}
+
+bool read_all(int fd, void* buf, size_t size) {
+ auto data = reinterpret_cast<char*>(buf);
+ auto end = data + size;
+ while (data < end) {
+ ssize_t ret = ::read(fd, data, std::min(
+ std::numeric_limits<ssize_t>::max(), end - data));
+ if (ret < 0) {
+ if (errno == EINTR) continue;
+ return false;
+ }
+ if (ret == 0) return false;
+ data += ret;
+ }
+ return true;
+}
+
+bool write_all(int fd, void const* buf, size_t size) {
+ auto data = reinterpret_cast<char const*>(buf);
+ auto end = data + size;
+ while (data < end) {
+ ssize_t ret = ::write(fd, data, std::min(
+ std::numeric_limits<ssize_t>::max(), end - data));
+ if (ret < 0) {
+ if (errno == EINTR) continue;
+ return false;
+ }
+ if (ret == 0) return false;
+ data += ret;
+ }
+ return true;
+}
+
+bool auto_fd::reset() {
+ if (fd_ >= 0) {
+ auto tmp = fd_;
+ fd_ = -1;
+ return close(tmp) != 0;
+ }
+ return true;
+}
+
+void auto_pipe::reset() {
+ if (read_ >= 0) {
+ assert(write_ >= 0);
+ close(read_);
+ close(write_);
+ read_ = -1;
+ write_ = -1;
+ }
+}
+
+bool auto_pipe::open() {
+ reset();
+ int fd[2];
+ if (pipe(fd) != 0) {
+ return false;
+ }
+ read_ = fd[0];
+ write_ = fd[1];
+ return true;
+}
+
+} // namespace io
diff --git a/src/io.hh b/src/io.hh
new file mode 100644
index 0000000..3bfccec
--- /dev/null
+++ b/src/io.hh
@@ -0,0 +1,124 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#ifndef IO_HH
+#define IO_HH
+
+#include <sys/types.h>
+
+namespace io {
+
+ssize_t read(int fd, void* buf, size_t max);
+ssize_t write(int fd, void const* buf, size_t max);
+
+bool read_all(int fd, void* buf, size_t size);
+bool write_all(int fd, void const* buf, size_t size);
+
+class auto_fd {
+public:
+ auto_fd()
+ : fd_(-1) {
+ }
+ explicit auto_fd(int fd)
+ : fd_(fd) {
+ }
+ auto_fd(auto_fd&& fd)
+ : fd_(fd.fd_) {
+ fd.fd_ = -1;
+ }
+ auto_fd(auto_fd const&) = delete;
+ ~auto_fd() {
+ reset();
+ }
+
+ auto_fd& operator=(auto_fd&& fd) {
+ reset(fd.release());
+ return *this;
+ }
+ auto_fd& operator=(auto_fd&) = delete;
+
+ int get() const {
+ return fd_;
+ }
+
+ explicit operator bool() const {
+ return fd_ >= 0;
+ }
+
+ int release() {
+ auto ret = fd_;
+ fd_ = -1;
+ return ret;
+ }
+
+ bool reset();
+
+ bool reset(int fd) {
+ auto ret = reset();
+ fd_ = fd;
+ return ret;
+ }
+
+ void swap(auto_fd& fd) {
+ auto tmp = fd_;
+ fd_ = fd.fd_;
+ fd.fd_ = tmp;
+ }
+private:
+ int fd_;
+};
+
+class auto_pipe {
+public:
+ auto_pipe()
+ : read_(-1), write_(-1) {
+ }
+ auto_pipe(auto_pipe&& pipe)
+ : read_(pipe.read_), write_(pipe.write_) {
+ pipe.read_ = -1;
+ pipe.write_ = -1;
+ }
+ auto_pipe(auto_pipe const&) = delete;
+ ~auto_pipe() {
+ reset();
+ }
+
+ auto_pipe& operator=(auto_pipe&& pipe) {
+ reset();
+ read_ = pipe.read_;
+ write_ = pipe.write_;
+ pipe.read_ = -1;
+ pipe.write_ = -1;
+ return *this;
+ }
+ auto_pipe& operator=(auto_pipe&) = delete;
+
+ int read() const {
+ return read_;
+ }
+ int write() const {
+ return write_;
+ }
+
+ explicit operator bool() const {
+ return read_ >= 0 /* && write_ >= 0 */;
+ }
+
+ bool open();
+
+ void reset();
+
+ void swap(auto_pipe& pipe) {
+ auto r = read_;
+ auto w = write_;
+ read_ = pipe.read_;
+ write_ = pipe.write_;
+ pipe.read_ = r;
+ pipe.write_ = w;
+ }
+private:
+ int read_, write_;
+};
+
+} // namespace io
+
+#endif // IO_HH
diff --git a/src/logger.cc b/src/logger.cc
new file mode 100644
index 0000000..29dcb38
--- /dev/null
+++ b/src/logger.cc
@@ -0,0 +1,130 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#include "common.hh"
+
+#ifndef _BSD_SOURCE
+#define _BSD_SOURCE
+#endif
+#ifndef _DEFAULT_SOURCE
+#define _DEFAULT_SOURCE
+#endif
+
+#include <cstdarg>
+#include <cstdio>
+#include <iostream>
+#include <memory>
+#include <syslog.h>
+
+#include "logger.hh"
+
+namespace {
+
+class LoggerStdErr : public Logger {
+public:
+ void out(Level UNUSED(lvl), char const* format, ...) override {
+ char* tmp;
+ va_list args;
+ va_start(args, format);
+ auto ret = vasprintf(&tmp, format, args);
+ va_end(args);
+ if (ret == -1) return;
+ std::cerr << tmp << std::endl;
+ free(tmp);
+ }
+};
+
+class LoggerSyslog : public Logger {
+public:
+ LoggerSyslog(std::string const& name) {
+ openlog(name.c_str(), LOG_PID, LOG_DAEMON);
+ }
+
+ ~LoggerSyslog() override {
+ closelog();
+ }
+
+ void out(Level lvl, char const* format, ...) override {
+ va_list args;
+ va_start(args, format);
+ vsyslog(lvl2prio(lvl), format, args);
+ va_end(args);
+ }
+
+private:
+ static int lvl2prio(Level lvl) {
+ switch (lvl) {
+ case ERR:
+ return LOG_ERR;
+ case WARN:
+ return LOG_WARNING;
+ case INFO:
+ return LOG_INFO;
+ }
+ assert(false);
+ return LOG_INFO;
+ }
+};
+
+class LoggerFile : public Logger {
+public:
+ LoggerFile()
+ : fh_(nullptr) {
+ }
+
+ bool open(std::string const& path) {
+ if (fh_) fclose(fh_);
+ fh_ = fopen(path.c_str(), "a");
+ return fh_ != NULL;
+ }
+
+ ~LoggerFile() override {
+ if (fh_) fclose(fh_);
+ }
+
+ void out(Level lvl, char const* format, ...) override {
+ fputs(lvl2str(lvl), fh_);
+ fwrite(": ", 1, 2, fh_);
+ va_list args;
+ va_start(args, format);
+ vfprintf(fh_, format, args);
+ va_end(args);
+ fputc('\n', fh_);
+ }
+
+private:
+ static char const* lvl2str(Level lvl) {
+ switch (lvl) {
+ case ERR:
+ return "Error";
+ case WARN:
+ return "Warning";
+ case INFO:
+ return "Info";
+ }
+ assert(false);
+ return "Info";
+ }
+
+ FILE* fh_;
+};
+
+} // namespace
+
+// static
+Logger* Logger::create_stderr() {
+ return new LoggerStdErr();
+}
+
+// static
+Logger* Logger::create_syslog(std::string const& name) {
+ return new LoggerSyslog(name);
+}
+
+// static
+Logger* Logger::create_file(std::string const& path) {
+ std::unique_ptr<LoggerFile> ret(new LoggerFile());
+ if (ret->open(path)) {
+ return ret.release();
+ }
+ return nullptr;
+}
diff --git a/src/logger.hh b/src/logger.hh
new file mode 100644
index 0000000..8b0db05
--- /dev/null
+++ b/src/logger.hh
@@ -0,0 +1,33 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#ifndef LOGGER_HH
+#define LOGGER_HH
+
+#include <string>
+
+class Logger {
+public:
+ virtual ~Logger() {}
+
+ enum Level {
+ ERR,
+ WARN,
+ INFO
+ };
+
+ static Logger* create_stderr();
+ static Logger* create_syslog(std::string const& name);
+ static Logger* create_file(std::string const& path);
+
+ virtual void out(Level lvl, char const* format, ...)
+#ifdef HAVE_FUNC_ATTRIBUTE_FORMAT
+ __attribute__ ((format (printf, 3, 4)))
+#endif
+ = 0;
+
+protected:
+ Logger() {}
+ Logger(Logger const&) = delete;
+};
+
+#endif // LOGGER_HH
diff --git a/src/looper.cc b/src/looper.cc
new file mode 100644
index 0000000..0da851b
--- /dev/null
+++ b/src/looper.cc
@@ -0,0 +1,287 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#include "common.hh"
+
+#include <list>
+#include <poll.h>
+#include <vector>
+
+#include "looper.hh"
+
+namespace {
+
+int const read_events = POLLIN | POLLPRI
+#ifdef POLLRDHUP
+ | POLLRDHUP
+#endif
+ ;
+int const write_events = POLLOUT;
+int const error_events = POLLERR | POLLNVAL;
+int const hup_events = POLLHUP;
+
+class LooperImpl : public Looper {
+public:
+ LooperImpl()
+ : fds_to_remove_(0), fds_protected_(0), quit_(false) {
+ }
+
+ ~LooperImpl() override {
+ for (auto& timeout : timeouts_) {
+ delete timeout;
+ }
+ }
+
+ void add(int fd, uint8_t events, FdCallback const& callback) override {
+ if (fd < 0) {
+ assert(false);
+ return;
+ }
+ for (auto it = fds_.begin(); it != fds_.end(); ++it) {
+ if (it->fd == fd) {
+ size_t index = it - fds_.begin();
+ auto& entry = fdentries_[index];
+ if (index >= fds_protected_) {
+ entry.callback = callback;
+ it->events = pollevents(events);
+ return;
+ } else {
+ // Don't call new callback this run, so add it at the end but
+ // remove the old callback as it would be replaced
+ if (!entry.removed) {
+ entry.removed = true;
+ fds_to_remove_++;
+ }
+ }
+ }
+ }
+
+ fds_.emplace_back((struct pollfd) { fd, pollevents(events), 0 });
+ fdentries_.emplace_back(callback);
+ }
+
+ void modify(int fd, uint8_t events) override {
+ if (fd < 0) {
+ assert(false);
+ return;
+ }
+ for (auto it = fds_.begin(); it != fds_.end(); ++it) {
+ if (it->fd == fd) {
+ size_t index = it - fds_.begin();
+ if (index < fds_protected_) {
+ auto entry = fdentries_.begin() + index;
+ // If entry is removed we need to write to the later one
+ if (entry->removed) {
+ continue;
+ }
+ }
+ it->events = pollevents(events);
+ return;
+ }
+ }
+ assert(false);
+ }
+
+ void remove(int fd) override {
+ if (fd < 0) return;
+ for (auto it = fds_.begin(); it != fds_.end(); ++it) {
+ if (it->fd == fd) {
+ size_t index = it - fds_.begin();
+ auto entry = fdentries_.begin() + index;
+ if (index < fds_protected_) {
+ if (!entry->removed) {
+ entry->removed = true;
+ fds_to_remove_++;
+ }
+ } else {
+ fds_.erase(it);
+ fdentries_.erase(entry);
+ }
+ return;
+ }
+ }
+ assert(false);
+ }
+
+ void* schedule(float delay_s, ScheduleCallback const& callback) override {
+ clock::time_point target = clock::now()
+ + std::chrono::duration_cast<clock::duration>(
+ std::chrono::duration<float>(delay_s));
+ auto timeout = new Timeout(target, callback);
+ for (auto it = timeouts_.begin(); it != timeouts_.end(); ++it) {
+ if (target < (*it)->target) {
+ timeouts_.insert(it, timeout);
+ return timeout;
+ }
+ }
+ timeouts_.push_back(timeout);
+ return timeout;
+ }
+
+ void cancel(void* handle) override {
+ auto timeout = reinterpret_cast<Timeout*>(handle);
+ if (!timeout) {
+ assert(false);
+ return;
+ }
+ if (timeout->expired) return;
+ for (auto it = timeouts_.begin(); it != timeouts_.end(); ++it) {
+ if (*it == timeout) {
+ timeouts_.erase(it);
+ delete timeout;
+ return;
+ }
+ }
+ assert(false);
+ }
+
+ void quit() override {
+ quit_ = true;
+ }
+
+ bool run() override {
+ std::vector<Timeout*> expired;
+
+ while (!quit_) {
+ int timeout = -1;
+ if (!timeouts_.empty()) {
+ auto dur = std::chrono::duration_cast<std::chrono::milliseconds>(
+ timeouts_.front()->target - clock::now());
+ if (dur.count() <= 0) {
+ timeout = 0;
+ } else if (dur.count() < std::numeric_limits<int>::max()) {
+ timeout = dur.count();
+ } else {
+ timeout = std::numeric_limits<int>::max();
+ }
+ }
+ auto ret = poll(fds_.data(), fds_.size(), timeout);
+ if (ret < 0) {
+ if (errno == EINTR) continue;
+ return false;
+ }
+ now_ = clock::now();
+ fds_protected_ = fds_.size();
+
+ if (!timeouts_.empty()) {
+ while (timeouts_.front()->target <= now_) {
+ auto timeout = timeouts_.front();
+ timeouts_.pop_front();
+ timeout->expired = true;
+ expired.push_back(timeout);
+ if (timeouts_.empty()) break;
+ }
+
+ for (auto& timeout : expired) {
+ timeout->callback(timeout);
+ }
+ for (auto& timeout : expired) {
+ delete timeout;
+ }
+ expired.clear();
+ }
+
+ // Not using iterators here as that would be unsafe with
+ // add() and remove() modifying the vector outside protected range
+ // while callbacks are called
+ size_t i;
+ for (i = 0; ret > 0 && i < fds_protected_; ++i) {
+ if (fds_[i].revents) {
+ --ret;
+ if (!fdentries_[i].removed) {
+ fdentries_[i].callback(fds_[i].fd, unpollevents(fds_[i].revents));
+ }
+ }
+ }
+ assert(ret == 0);
+ assert(fds_.size() >= fds_protected_);
+ assert(fdentries_.size() >= fds_protected_);
+ for (i = fds_protected_; fds_to_remove_ > 0 && i > 0; --i) {
+ if (fdentries_[i - 1].removed) {
+ --fds_to_remove_;
+ fds_.erase(fds_.begin() + i - 1);
+ fdentries_.erase(fdentries_.begin() + i - 1);
+ }
+ }
+ assert(fds_to_remove_ == 0);
+ fds_protected_ = 0;
+ }
+ return true;
+ }
+
+ clock::time_point now() const override {
+ return now_;
+ }
+
+private:
+ struct FdEntry {
+ FdCallback callback;
+ bool removed;
+
+ FdEntry(FdCallback const& callback)
+ : callback(callback), removed(false) {
+ }
+ };
+
+ struct Timeout {
+ clock::time_point target;
+ ScheduleCallback callback;
+ bool expired;
+
+ Timeout(clock::time_point target, ScheduleCallback const& callback)
+ : target(target), callback(callback), expired(false) {
+ }
+ };
+
+ static uint8_t unpollevents(short events) {
+ uint8_t ret = 0;
+ if (events & read_events) {
+ ret |= EVENT_READ;
+ }
+ if (events & write_events) {
+ ret |= EVENT_WRITE;
+ }
+ if (events & error_events) {
+ ret |= EVENT_ERROR;
+ }
+ if (events & hup_events) {
+ ret |= EVENT_HUP;
+ }
+ return ret;
+ }
+
+ static short pollevents(uint8_t events) {
+ int ret = 0;
+ if (events & EVENT_READ) {
+ ret |= read_events;
+ }
+ if (events & EVENT_WRITE) {
+ ret |= write_events;
+ }
+ return ret;
+ }
+
+ std::vector<struct pollfd> fds_;
+ std::vector<FdEntry> fdentries_;
+ size_t fds_to_remove_;
+ size_t fds_protected_;
+ std::list<Timeout*> timeouts_;
+ bool quit_;
+ clock::time_point now_;
+};
+
+} // namespace
+
+// static
+Looper* Looper::create() {
+ return new LooperImpl();
+}
+
+// static
+const uint8_t Looper::EVENT_READ = 1 << 0;
+// static
+const uint8_t Looper::EVENT_WRITE = 1 << 1;
+// static
+const uint8_t Looper::EVENT_ERROR = 1 << 2;
+// static
+const uint8_t Looper::EVENT_HUP = 1 << 3;
+
diff --git a/src/looper.hh b/src/looper.hh
new file mode 100644
index 0000000..7315220
--- /dev/null
+++ b/src/looper.hh
@@ -0,0 +1,41 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#ifndef LOOPER_HH
+#define LOOPER_HH
+
+#include <chrono>
+#include <functional>
+
+class Looper {
+public:
+ typedef std::chrono::steady_clock clock;
+ typedef std::function<void(int, uint8_t)> FdCallback;
+ typedef std::function<void(void*)> ScheduleCallback;
+ virtual ~Looper() { }
+
+ static const uint8_t EVENT_READ;
+ static const uint8_t EVENT_WRITE;
+ // Always reported, need not be set
+ static const uint8_t EVENT_HUP;
+ static const uint8_t EVENT_ERROR;
+
+ static Looper* create();
+
+ virtual void add(int fd, uint8_t events, FdCallback const& callback) = 0;
+ virtual void modify(int fd, uint8_t events) = 0;
+ virtual void remove(int fd) = 0;
+
+ virtual void* schedule(float delay_s, ScheduleCallback const& callback) = 0;
+ virtual void cancel(void* handle) = 0;
+
+ virtual bool run() = 0;
+ virtual void quit() = 0;
+
+ virtual clock::time_point now() const = 0;
+
+protected:
+ Looper() { }
+ Looper(Looper const&) = delete;
+};
+
+#endif // LOOPER_HH
diff --git a/src/main.cc b/src/main.cc
new file mode 100644
index 0000000..f9e1af6
--- /dev/null
+++ b/src/main.cc
@@ -0,0 +1,187 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#include "common.hh"
+
+#include <cstring>
+#include <memory>
+#include <signal.h>
+#include <unistd.h>
+
+#include "args.hh"
+#include "config.hh"
+#include "io.hh"
+#include "logger.hh"
+#include "proxy.hh"
+
+namespace {
+
+Proxy* g_proxy;
+
+void proxy_quit(int UNUSED(sig)) {
+ g_proxy->quit();
+}
+
+void proxy_reload(int UNUSED(sig)) {
+ g_proxy->reload();
+}
+
+void setup_signals(Proxy* proxy) {
+ g_proxy = proxy;
+ struct sigaction action;
+ memset(&action, 0, sizeof(action));
+ action.sa_handler = proxy_quit;
+ action.sa_flags = SA_RESTART;
+ sigaction(SIGINT, &action, nullptr);
+ sigaction(SIGQUIT, &action, nullptr);
+ action.sa_handler = proxy_reload;
+ sigaction(SIGHUP, &action, nullptr);
+}
+
+std::string get_cwd() {
+ std::string ret;
+ char buf[128];
+ auto ptr = getcwd(buf, 128);
+ if (ptr) {
+ ret.assign(ptr, strlen(ptr));
+ return ret;
+ }
+ if (errno != ERANGE) return ret;
+ ret.resize(1024);
+ while (true) {
+ ptr = getcwd(&ret[0], ret.size());
+ if (ptr) {
+ ret.resize(strlen(ret.data()));
+ break;
+ }
+ if (errno != ERANGE) {
+ ret.resize(0);
+ break;
+ }
+ ret.resize(ret.size() * 2);
+ }
+ return ret;
+}
+
+} // namespace
+
+int main(int argc, char** argv) {
+ std::unique_ptr<Args> args(Args::create());
+ args->add('C', "config", "FILE", "load config from FILE instead of default");
+ args->add('F', "foreground", "do not fork into background and log to stdout");
+ args->add('m', "monitor", "HOST:PORT", "accept monitors on HOST:PORT");
+ args->add('h', "help", "display this text and exit.");
+ args->add('V', "version", "display version and exit.");
+ if (!args->run(argc, argv)) {
+ std::cerr << "Try `tp --help` for usage." << std::endl;
+ return EXIT_FAILURE;
+ }
+ if (args->is_set('h')) {
+ std::cout << "Usage: `tp [OPTIONS...] [BIND][:PORT]`\n"
+ << "Transparent proxy.\n"
+ << '\n';
+ args->print_help();
+ return EXIT_SUCCESS;
+ }
+ if (args->is_set('V')) {
+ std::cout << "TransparentProxy version " VERSION
+ << " written by Joel Klinghed <the_jk@yahoo.com>" << std::endl;
+ return EXIT_SUCCESS;
+ }
+ std::unique_ptr<Config> config(Config::create());
+ switch (args->arguments().size()) {
+ case 0:
+ break;
+ case 1: {
+ auto arg = args->arguments().front();
+ auto colon = arg.find(':');
+ if (colon == std::string::npos) {
+ config->set("proxy_bind", arg);
+ } else {
+ if (colon > 0) {
+ config->set("proxy_bind", arg.substr(0, colon));
+ }
+ config->set("proxy_port", arg.substr(colon + 1));
+ }
+ break;
+ }
+ default:
+ std::cerr << "Too many arguments.\n"
+ << "Try `tp --help` for usage." << std::endl;
+ return EXIT_FAILURE;
+ }
+ if (args->is_set('F')) {
+ config->set("foreground", "true");
+ }
+ auto monitor = args->arg('m', nullptr);
+ if (monitor) {
+ auto str = std::string(monitor);
+ auto colon = str.find(':');
+ if (colon == 0 || colon == std::string::npos) {
+ std::cerr << "Invalid argument to monitor, expected HOST:PORT not: "
+ << str << std::endl;
+ return EXIT_FAILURE;
+ }
+ config->set("monitor", "true");
+ config->set("monitor_bind", str.substr(0, colon));
+ config->set("monitor_port", str.substr(colon + 1));
+ }
+ auto configfile = args->arg('C', nullptr);
+ if (configfile) {
+ config->load_file(configfile);
+ } else {
+ config->load_name("tp");
+ }
+ if (!config->good()) {
+ std::cerr << "Error loading config\n"
+ << config->last_error() << std::endl;
+ return EXIT_FAILURE;
+ }
+ std::unique_ptr<Logger> logger(Logger::create_stderr());
+ std::unique_ptr<Logger> file_logger;
+ auto logfile = args->arg('l', nullptr);
+ bool logfile_from_argument;
+ if (!logfile) {
+ logfile = config->get("logfile", nullptr);
+ logfile_from_argument = false;
+ } else {
+ logfile_from_argument = true;
+ }
+ if (logfile) {
+ if (logfile[0] != '/') {
+ logger->out(Logger::ERR, "Logfile need to be an absolute path, not: %s",
+ logfile);
+ return EXIT_FAILURE;
+ }
+ file_logger.reset(Logger::create_file(logfile));
+ if (!file_logger) {
+ logger->out(Logger::ERR, "Unable to open %s for logging: %s",
+ logfile, strerror(errno));
+ return EXIT_FAILURE;
+ }
+ }
+ io::auto_fd accept_socket(Proxy::setup_accept_socket(config.get(),
+ logger.get()));
+ if (!accept_socket) return EXIT_FAILURE;
+ io::auto_fd monitor_socket;
+ if (config->get("monitor", false)) {
+ monitor_socket.reset(Proxy::setup_monitor_socket(config.get(),
+ logger.get()));
+ }
+ auto foreground = config->get("foreground", false);
+ auto cwd = get_cwd();
+ std::unique_ptr<Proxy> proxy(
+ Proxy::create(config.get(), cwd, configfile,
+ foreground ? "bogus" :
+ (logfile_from_argument ? logfile : nullptr),
+ foreground ? logger.get() : file_logger.get(),
+ accept_socket.release(),
+ monitor_socket.release()));
+ if (!foreground) {
+ if (daemon(0, 0)) {
+ logger->out(Logger::ERR, "Failed to fork: %s", strerror(errno));
+ return EXIT_FAILURE;
+ }
+ }
+ setup_signals(proxy.get());
+ return proxy->run() ? EXIT_SUCCESS : EXIT_FAILURE;
+}
diff --git a/src/paths.cc b/src/paths.cc
new file mode 100644
index 0000000..095076f
--- /dev/null
+++ b/src/paths.cc
@@ -0,0 +1,98 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#include "common.hh"
+
+#include "paths.hh"
+
+// static
+std::string Paths::join(std::string const& base, std::string const& path) {
+ if (path.empty()) {
+ return cleanup(base);
+ } else if (path.front() == '/') {
+ return cleanup(path);
+ }
+ std::string ret(cleanup(base));
+ if (ret == ".") {
+ return cleanup(path);
+ }
+ if (ret.empty() || ret.back() != '/') {
+ ret.push_back('/');
+ }
+ ret.append(cleanup(path));
+ return ret;
+}
+
+namespace {
+
+std::string do_cleanup(std::string const& path) {
+ size_t i = 0, j = 0;
+ std::string ret;
+ bool add_slash = false;
+ if (path.front() == '/') {
+ ret.push_back('/');
+ i = j = 1;
+ }
+ while (true) {
+ if (j == path.size() || path[j] == '/') {
+ auto len = j - i;
+ if (len > 0) {
+ if (len == 1 && path[i] == '.') {
+ i = ++j;
+ continue;
+ }
+ if (add_slash && len == 2 && path[i] == '.' && path[i + 1] == '.') {
+ auto x = ret.find_last_of('/');
+ if (x == std::string::npos) x = 0;
+ ret.erase(x);
+ add_slash = !(ret.empty() || ret == "/");
+ i = ++j;
+ continue;
+ }
+ if (add_slash) {
+ ret.push_back('/');
+ } else {
+ add_slash = true;
+ }
+ ret.append(path.substr(i, len));
+ }
+ if (j == path.size()) {
+ return ret;
+ }
+ i = ++j;
+ } else {
+ ++j;
+ }
+ }
+}
+
+} // namespace
+
+// static
+std::string Paths::cleanup(std::string const& path) {
+ if (path.empty()) return ".";
+ if (path[0] == '.') {
+ if (path.size() == 1) return path;
+ if (path[1] == '/') return do_cleanup(path); // found ./ at beginning
+ }
+ for (size_t pos = 0; pos < path.size(); ++pos) {
+ if (path[pos] == '/') {
+ if (pos > 0 && pos + 1 == path.size()) {
+ // found slash at end
+ return path.substr(0, path.size() - 1);
+ }
+ if (path[pos + 1] == '/') return do_cleanup(path); // found double slash
+ if (path[pos + 1] == '.') {
+ if (pos + 2 == path.size()) return do_cleanup(path); // found /. at end
+ if (path[pos + 2] == '/') return do_cleanup(path); // found /./
+ if (path[pos + 2] == '.') {
+ if (pos + 3 == path.size()) {
+ // found /.. at end
+ return do_cleanup(path);
+ }
+ if (path[pos + 3] == '/') do_cleanup(path); // found /../
+ }
+ }
+ }
+ }
+ return path;
+}
diff --git a/src/paths.hh b/src/paths.hh
new file mode 100644
index 0000000..adec361
--- /dev/null
+++ b/src/paths.hh
@@ -0,0 +1,18 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#ifndef PATHS_HH
+#define PATHS_HH
+
+#include <string>
+
+class Paths {
+public:
+ static std::string join(std::string const& base, std::string const& path);
+ static std::string cleanup(std::string const& path);
+
+private:
+ Paths() {}
+ ~Paths() {}
+};
+
+#endif // PATHS_HH
diff --git a/src/proxy.cc b/src/proxy.cc
new file mode 100644
index 0000000..fcd02f6
--- /dev/null
+++ b/src/proxy.cc
@@ -0,0 +1,1373 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#include "common.hh"
+
+#ifndef _GNU_SOURCE
+#define _GNU_SOURCE
+#endif
+
+#include <chrono>
+#include <cstring>
+#include <fcntl.h>
+#include <memory>
+#include <netdb.h>
+#include <signal.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <vector>
+
+#include "buffer.hh"
+#include "chunked.hh"
+#include "config.hh"
+#include "http.hh"
+#include "io.hh"
+#include "logger.hh"
+#include "looper.hh"
+#include "resolver.hh"
+#include "paths.hh"
+#include "proxy.hh"
+#include "url.hh"
+
+namespace {
+
+auto const NEW_CONNECTION_TIMEOUT = std::chrono::duration<float>(5.0f);
+auto const CONNECTION_TIMEOUT = std::chrono::duration<float>(30.0f);
+
+template<typename T>
+class clients {
+public:
+ class iterator {
+ public:
+ iterator(clients* clients, size_t active)
+ : clients_(clients), index_(0), active_(active) {
+ if (active_) {
+ while (clients_->is_free(index_)) {
+ ++index_;
+ }
+ }
+ }
+ iterator(iterator const& i)
+ : clients_(i.clients_), index_(i.index_), active_(i.active_) {
+ }
+
+ iterator& operator=(iterator const& i) {
+ clients_ = i.clients_;
+ index_ = i.index_;
+ active_ = i.active_;
+ return *this;
+ }
+
+ bool operator==(iterator const& i) const {
+ return active_ == i.active_;
+ }
+ bool operator!=(iterator const& i) const {
+ return !(*this == i);
+ }
+ bool operator<(iterator const& i) const {
+ return active_ < i.active_;
+ }
+ bool operator<=(iterator const& i) const {
+ return (*this < i) || (*this == i);
+ }
+ bool operator>=(iterator const& i) const {
+ return !(*this < i);
+ }
+ bool operator>(iterator const& i) const {
+ return !(*this <= i);
+ }
+
+ iterator& operator++() {
+ next();
+ return *this;
+ }
+
+ iterator operator++(int UNUSED(dummy)) {
+ iterator ret(*this);
+ ++(*this);
+ return ret;
+ }
+
+ T& operator*() {
+ return (*clients_)[index_];
+ }
+ T* operator->() {
+ return &(*clients_)[index_];
+ }
+
+ size_t index() const {
+ return index_;
+ }
+
+ private:
+ void next() {
+ if (active_ == 0) return;
+ --active_;
+ if (active_ == 0) return;
+ ++index_;
+ while (clients_->is_free(index_)) {
+ ++index_;
+ }
+ }
+
+ clients* clients_;
+ size_t index_;
+ size_t active_;
+ };
+
+ clients()
+ : active_(0), max_(0) {
+ }
+
+ void resize(size_t max) {
+ max_ = max;
+ if (max > data_.size()) {
+ data_.resize(max);
+ }
+ }
+
+ bool full() const {
+ return active_ >= max_;
+ }
+
+ iterator begin() {
+ return iterator(this, active_);
+ }
+
+ iterator end() {
+ return iterator(this, 0);
+ }
+
+ T& operator[] (size_t index) {
+ return data_[index];
+ }
+
+ size_t new_client() {
+ size_t index;
+ if (active_ == data_.size()) {
+ index = data_.size();
+ data_.emplace_back();
+ } else {
+ index = rand() % data_.size();
+ while (!is_free(index)) {
+ if (++index == data_.size()) {
+ index = 0;
+ }
+ }
+ }
+ ++active_;
+ return index;
+ }
+
+ void erase(size_t index) {
+ assert(active_ > 0);
+ --active_;
+ assert(is_free(index));
+ if (data_.size() > max_ && index >= max_) {
+ size_t size = data_.size();
+ while (size > max_ && is_free(size - 1)) --size;
+ data_.resize(size);
+ }
+ }
+
+ void clear() {
+ if (active_ > 0) {
+ data_.clear();
+ active_ = 0;
+ }
+ data_.resize(max_);
+ }
+
+private:
+ bool is_free(size_t index) const {
+ return !data_[index].fd;
+ }
+
+ size_t active_;
+ size_t max_;
+ std::vector<T> data_;
+};
+
+struct BaseClient {
+ io::auto_fd fd;
+ bool new_connection;
+ Looper::clock::time_point last;
+ std::unique_ptr<Buffer> in;
+ std::unique_ptr<Buffer> out;
+ uint8_t read_flag;
+ uint8_t write_flag;
+};
+
+enum ContentType {
+ CONTENT_NONE,
+ CONTENT_LEN,
+ CONTENT_CHUNKED,
+ CONTENT_CLOSE
+};
+
+struct Content {
+ ContentType type;
+ uint64_t len;
+ std::unique_ptr<Chunked> chunked;
+};
+
+enum RemoteState {
+ CLOSED,
+ RESOLVING,
+ CONNECTING,
+ CONNECTED,
+ WAITING,
+};
+
+struct RemoteClient : public BaseClient {
+ Content content;
+};
+
+struct Client : public BaseClient {
+ Client()
+ : resolve(nullptr) {
+ }
+ std::unique_ptr<HttpRequest> request;
+ std::unique_ptr<Url> url;
+ Content content;
+ RemoteState remote_state;
+ void* resolve;
+ RemoteClient remote;
+};
+
+struct Monitor : public BaseClient {
+};
+
+class ProxyImpl : public Proxy {
+public:
+ ProxyImpl(Config* config, std::string const& cwd, char const* configfile,
+ char const* logfile, Logger* logger, int accept_fd, int monitor_fd)
+ : config_(config), cwd_(cwd), configfile_(configfile), logfile_(logfile),
+ logger_(logger), accept_socket_(accept_fd), monitor_socket_(monitor_fd),
+ looper_(Looper::create()), resolver_(Resolver::create(looper_.get())),
+ new_timeout_(nullptr), timeout_(nullptr) {
+ setup();
+ }
+ ~ProxyImpl() override {
+ if (accept_socket_) {
+ looper_->remove(accept_socket_.get());
+ }
+ if (monitor_socket_) {
+ looper_->remove(monitor_socket_.get());
+ }
+ if (signal_pipe_) {
+ looper_->remove(signal_pipe_.read());
+ }
+ }
+
+ void quit() override {
+ char cmd = 'q';
+ io::write_all(signal_pipe_.write(), &cmd, 1);
+ }
+
+ void reload() override {
+ char cmd = 'r';
+ io::write_all(signal_pipe_.write(), &cmd, 1);
+ }
+
+ bool run() override;
+
+private:
+ void setup();
+ bool reload_config();
+ void fatal_error();
+ void new_client(int fd, uint8_t events);
+ void new_monitor(int fd, uint8_t events);
+ void new_base(BaseClient* client, int fd);
+ void signal_event(int fd, uint8_t events);
+ bool base_event(BaseClient* client, uint8_t events,
+ size_t index, char const* name);
+ void client_event(size_t index, int fd, uint8_t events);
+ void client_remote_event(size_t index, int fd, uint8_t events);
+ void client_empty_input(size_t index);
+ void monitor_event(size_t index, int fd, uint8_t events);
+ void close_client(size_t index);
+ void close_monitor(size_t index);
+ void close_base(BaseClient* client);
+ void new_timeout();
+ void timeout();
+ float handle_timeout(bool new_conn,
+ std::chrono::duration<float> const& timeout);
+ bool base_send(BaseClient* client, void const* data, size_t size,
+ size_t index, char const* name);
+ void client_error(size_t index,
+ uint16_t status_code, std::string const& status);
+ bool client_request(size_t index);
+ bool client_send(size_t index, void const* ptr, size_t avail);
+ void client_remote_error(size_t index, uint16_t error);
+ void close_client_when_done(size_t index);
+ void client_remote_resolved(size_t index, int fd, bool connected,
+ char const* error);
+
+ Config* const config_;
+ std::string cwd_;
+ char const* const configfile_;
+ char const* const logfile_;
+ Logger* logger_;
+ std::unique_ptr<Logger> priv_logger_;
+ io::auto_fd accept_socket_;
+ io::auto_fd monitor_socket_;
+ io::auto_pipe signal_pipe_;
+ std::unique_ptr<Looper> looper_;
+ std::unique_ptr<Resolver> resolver_;
+ bool good_;
+ void* new_timeout_;
+ void* timeout_;
+
+ clients<Client> clients_;
+ clients<Monitor> monitors_;
+};
+
+size_t get_size(Config* config, Logger* logger, std::string const& name,
+ size_t fallback) {
+ auto value = config->get(name, nullptr);
+ if (!value) return fallback;
+ char* end = nullptr;
+ errno = 0;
+ auto tmp = strtoul(value, &end, 10);
+ if (errno || !end || *end) {
+ logger->out(Logger::WARN,
+ "Invalid value given for %s: %s, using fallback %zu instead",
+ name.c_str(), value, fallback);
+ return fallback;
+ }
+ return static_cast<size_t>(tmp);
+}
+
+void ProxyImpl::setup() {
+ clients_.resize(get_size(config_, logger_, "max_clients", 1024));
+ monitors_.resize(get_size(config_, logger_, "max_monitors", 2));
+ looper_->add(accept_socket_.get(),
+ clients_.full() ? 0 : Looper::EVENT_READ,
+ std::bind(&ProxyImpl::new_client, this,
+ std::placeholders::_1,
+ std::placeholders::_2));
+ if (monitor_socket_) {
+ looper_->add(monitor_socket_.get(),
+ monitors_.full() ? 0 :Looper::EVENT_READ,
+ std::bind(&ProxyImpl::new_monitor, this,
+ std::placeholders::_1,
+ std::placeholders::_2));
+ } else {
+ monitors_.clear();
+ }
+}
+
+bool ProxyImpl::reload_config() {
+ if (configfile_) {
+ auto file = Paths::join(cwd_, configfile_);
+ logger_->out(Logger::INFO, "Reloading config file %s", file.c_str());
+ config_->load_file(file);
+ } else {
+ logger_->out(Logger::INFO, "Reloading config");
+ config_->load_name("tp");
+ }
+ if (!config_->good()) {
+ logger_->out(Logger::WARN, "New config invalid, ignored.");
+ return true;
+ }
+ if (!logfile_) {
+ auto logfile = config_->get("logfile", nullptr);
+ if (logfile) {
+ if (logfile[0] != '/') {
+ logger_->out(Logger::WARN,
+ "Logfile need to be an absolute path, not: %s",
+ logfile);
+ } else {
+ std::unique_ptr<Logger> tmp(Logger::create_file(logfile_));
+ if (tmp) {
+ priv_logger_.swap(tmp);
+ logger_ = priv_logger_.get();
+ } else {
+ logger_->out(Logger::WARN,
+ "Unable to open %s for logging", logfile);
+ }
+ }
+ } else {
+ priv_logger_.reset(Logger::create_syslog("tp"));
+ logger_ = priv_logger_.get();
+ }
+ }
+ auto const old_accept = accept_socket_.get();
+ auto const old_monitor = monitor_socket_.get();
+ looper_->remove(old_accept);
+ if (monitor_socket_) looper_->remove(old_monitor);
+ accept_socket_.reset(setup_accept_socket(config_, logger_));
+ monitor_socket_.reset();
+ if (!accept_socket_) {
+ logger_->out(Logger::ERR, "Unable to bind to new configuration, abort");
+ return false;
+ }
+ if (config_->get("monitor", false)) {
+ monitor_socket_.reset(setup_monitor_socket(config_, logger_));
+ if (!monitor_socket_) {
+ logger_->out(Logger::ERR, "Unable to bind to new configuration, abort");
+ return false;
+ }
+ }
+ setup();
+ return true;
+}
+
+void ProxyImpl::signal_event(int fd, uint8_t events) {
+ assert(fd == signal_pipe_.read());
+ if (events == Looper::EVENT_READ) {
+ char cmd;
+ auto ret = io::read(fd, &cmd, 1);
+ if (ret == 1) {
+ switch (cmd) {
+ case 'q':
+ logger_->out(Logger::INFO, "Exiting ...");
+ looper_->quit();
+ return;
+ case 'r':
+ if (!reload_config()) {
+ fatal_error();
+ return;
+ }
+ break;
+ }
+ }
+ } else {
+ logger_->out(Logger::WARN, "Signal pipes have crashed");
+ signal_pipe_.reset();
+ looper_->remove(fd);
+ }
+}
+
+void ProxyImpl::close_base(BaseClient* client) {
+ if (client->fd) {
+ looper_->remove(client->fd.get());
+ client->fd.reset();
+ }
+ client->in.reset();
+ client->out.reset();
+}
+
+void ProxyImpl::close_client(size_t index) {
+ bool was_full = clients_.full();
+ auto& client = clients_[index];
+ client.request.reset();
+ client.url.reset();
+ client.content.type = CONTENT_NONE;
+ client.content.chunked.reset();
+ client.remote_state = CLOSED;
+ if (client.resolve) {
+ resolver_->cancel(client.resolve);
+ client.resolve = nullptr;
+ }
+ close_base(&client.remote);
+ close_base(&client);
+ clients_.erase(index);
+ if (was_full && !clients_.full() && accept_socket_) {
+ looper_->modify(accept_socket_.get(), Looper::EVENT_READ);
+ }
+}
+
+void ProxyImpl::close_monitor(size_t index) {
+ bool was_full = monitors_.full();
+ auto& monitor = monitors_[index];
+ close_base(&monitor);
+ monitors_.erase(index);
+ if (was_full && !monitors_.full() && monitor_socket_) {
+ looper_->modify(monitor_socket_.get(), Looper::EVENT_READ);
+ }
+}
+
+float ProxyImpl::handle_timeout(bool new_conn,
+ std::chrono::duration<float> const& timeout) {
+ auto now = looper_->now();
+ std::vector<size_t> close;
+ std::vector<bool> remote;
+ float next = -1.0f;
+ for (auto i = clients_.begin(); i != clients_.end(); ++i) {
+ if (i->new_connection != new_conn) continue;
+ auto diff = std::chrono::duration_cast<std::chrono::duration<float>>(
+ ((i->last + timeout) - now)).count();
+ if (diff < 0.0f) {
+ close.push_back(i.index());
+ remote.push_back(false);
+ } else {
+ if (!new_conn && i->remote_state > CLOSED) {
+ auto diff2 = std::chrono::duration_cast<std::chrono::duration<float>>(
+ ((i->remote.last + timeout) - now)).count();
+ if (diff2 < 0.0f) {
+ close.push_back(i.index());
+ remote.push_back(true);
+ } else if (diff2 < diff) {
+ diff = diff2;
+ }
+ }
+ if (next < 0.0f || diff < next) {
+ next = diff;
+ }
+ }
+ }
+ assert(close.size() == remote.size());
+ auto j = remote.rbegin();
+ for (auto i = close.rbegin(); i != close.rend(); ++i, ++j) {
+ if (*j) {
+ client_remote_error(*i, 504);
+ } else {
+ close_client(*i);
+ }
+ }
+ close.clear();
+ for (auto i = monitors_.begin(); i != monitors_.end(); ++i) {
+ if (i->new_connection != new_conn) continue;
+ auto diff = std::chrono::duration_cast<std::chrono::duration<float>>(
+ ((i->last + timeout) - now)).count();
+ if (diff < 0.0f) {
+ close.push_back(i.index());
+ } else if (next < 0.0f || diff < next) {
+ next = diff;
+ }
+ }
+ for (auto i = close.rbegin(); i != close.rend(); ++i) {
+ close_monitor(*i);
+ }
+ return next;
+}
+
+void ProxyImpl::timeout() {
+ assert(timeout_);
+ timeout_ = nullptr;
+ float next = handle_timeout(false, CONNECTION_TIMEOUT);
+ if (next < 0.0f) return;
+ timeout_ = looper_->schedule(next, std::bind(&ProxyImpl::timeout, this));
+}
+
+void ProxyImpl::new_timeout() {
+ assert(new_timeout_);
+ new_timeout_ = nullptr;
+ float next = handle_timeout(true, NEW_CONNECTION_TIMEOUT);
+ if (next < 0.0f) return;
+ new_timeout_ =
+ looper_->schedule(next, std::bind(&ProxyImpl::new_timeout, this));
+}
+
+bool ProxyImpl::base_send(BaseClient* client, void const* data, size_t size,
+ size_t index, char const* name) {
+ if (size == 0) return true;
+ if (!client->out->empty()) {
+ // Already waiting for write event
+ client->out->write(data, size);
+ return true;
+ }
+ auto ret = io::write(client->fd.get(), data, size);
+ if (ret == -1) {
+ if (errno != EAGAIN && errno != EWOULDBLOCK) {
+ logger_->out(Logger::INFO, "%zu: %s write error: %s", index, name,
+ strerror(errno));
+ return false;
+ }
+ client->out->write(data, size);
+ } else {
+ if (static_cast<size_t>(ret) == size) {
+ if (!client->in) {
+ // If input is closed, close after sending all data
+ return false;
+ }
+ return true;
+ }
+ client->out->write(reinterpret_cast<char const*>(data) + ret, size - ret);
+ }
+ client->write_flag = Looper::EVENT_WRITE;
+ looper_->modify(client->fd.get(), client->read_flag | client->write_flag);
+ return true;
+}
+
+bool ProxyImpl::base_event(BaseClient* client, uint8_t events,
+ size_t index, char const* name) {
+ if (events & Looper::EVENT_READ) {
+ if (client->new_connection) {
+ char tmp[1];
+ auto ret = io::read(client->fd.get(), &tmp, 1);
+ if (ret < 0) {
+ if (errno == EAGAIN || errno == EWOULDBLOCK) return true;
+ logger_->out(Logger::INFO, "%zu: %s read error: %s", index, name,
+ strerror(errno));
+ return false;
+ }
+ if (ret == 0) {
+ return false;
+ }
+ client->last = looper_->now();
+ client->new_connection = false;
+ client->in.reset(Buffer::create(8192, 1024));
+ client->out.reset(Buffer::create(8192, 1024));
+ client->in->write(tmp, 1);
+ if (!timeout_) {
+ timeout_ = looper_->schedule(CONNECTION_TIMEOUT.count(),
+ std::bind(&ProxyImpl::timeout, this));
+ }
+ }
+ if (!client->in) {
+ assert(false);
+ return false;
+ }
+ size_t avail;
+ auto ptr = client->in->write_ptr(&avail);
+ assert(avail > 0);
+ auto ret = io::read(client->fd.get(), ptr, avail);
+ if (ret < 0) {
+ if (errno == EAGAIN || errno == EWOULDBLOCK) return true;
+ logger_->out(Logger::INFO, "%zu: %s read error: %s", index, name,
+ strerror(errno));
+ return false;
+ }
+ if (ret == 0) {
+ return false;
+ }
+ client->last = looper_->now();
+ client->in->commit(ret);
+ }
+ if (events & Looper::EVENT_WRITE) {
+ if (client->new_connection) {
+ assert(false);
+ return true;
+ }
+ size_t avail;
+ auto ptr = client->out->read_ptr(&avail);
+ if (avail == 0) {
+ assert(false);
+ if (!client->in) return false;
+ client->write_flag = 0;
+ looper_->modify(client->fd.get(), client->read_flag);
+ return true;
+ }
+ auto ret = io::write(client->fd.get(), ptr, avail);
+ if (ret < 0) {
+ if (errno == EAGAIN || errno == EWOULDBLOCK) return true;
+ logger_->out(Logger::INFO, "%zu: %s write error: %s", index, name,
+ strerror(errno));
+ return false;
+ }
+ if (ret == 0) {
+ return false;
+ }
+ client->last = looper_->now();
+ client->out->consume(ret);
+ if (client->out->empty()) {
+ if (!client->in) return false;
+ client->write_flag = 0;
+ looper_->modify(client->fd.get(), client->read_flag);
+ }
+ }
+ return true;
+}
+
+inline char lower_ascii(char c) {
+ return (c >= 'A' && c <= 'Z') ? (c - 'A' + 'a') : c;
+}
+
+bool lower_equal(char const* data, size_t start, size_t end,
+ std::string const& str) {
+ assert(start <= end);
+ if (str.size() != end - start) return false;
+ for (auto i = str.begin(); start < end; ++start, ++i) {
+ if (lower_ascii(*i) != lower_ascii(data[start])) return false;
+ }
+ return true;
+}
+
+bool header_token_eq(std::string const& value, std::string const& token) {
+ if (value.empty()) return false;
+ auto pos = value.find(';');
+ if (pos == std::string::npos) pos = value.size();
+ return lower_equal(value.data(), 0, pos, token);
+}
+
+void ProxyImpl::client_remote_error(size_t index, uint16_t error) {
+ assert(false);
+ auto& client = clients_[index];
+ if (client.remote_state > CONNECTED) {
+ // Already started to return response, too late to do anything
+ close_client(index);
+ return;
+ }
+ char const* status;
+ switch (error) {
+ case 502:
+ status = "Bad Gateway";
+ break;
+ case 504:
+ status = "Gateway Timeout";
+ break;
+ default:
+ assert(false);
+ error = 500;
+ status = "Internal Server Error";
+ break;
+ }
+
+ client_error(index, error, status);
+}
+
+void ProxyImpl::client_error(size_t index, uint16_t status_code,
+ std::string const& status_message) {
+ auto& client = clients_[index];
+ // No more input
+ client.read_flag = 0;
+ looper_->modify(client.fd.get(), client.write_flag);
+ client.url.reset();
+
+ std::string proto;
+ Version version;
+
+ if (!client.request) {
+ proto = "HTTP";
+ version.major = 1;
+ version.minor = 1;
+ } else {
+ proto = client.request->proto();
+ version = client.request->proto_version();
+ }
+
+ auto resp = std::unique_ptr<HttpResponseBuilder>(
+ HttpResponseBuilder::create(
+ proto, version, status_code, status_message));
+ resp->add_header("Content-Length", "0");
+ resp->add_header("Connection", "close");
+
+ client.in.reset();
+ client.request.reset();
+
+ auto data = resp->build();
+
+ if (!base_send(&client, data.data(), data.size(), index, "Client")) {
+ close_client(index);
+ return;
+ }
+}
+
+bool ProxyImpl::client_request(size_t index) {
+ auto& client = clients_[index];
+ auto version = client.request->proto_version();
+ if (version.major != 1 || version.minor > 1) {
+ client_error(index, 505, "HTTP Version Not Supported");
+ return false;
+ }
+ if (client.url->scheme() != "http") {
+ client_error(index, 501, "Not Implemented");
+ return false;
+ }
+ if (client.url->userinfo()) {
+ client_error(index, 400, "Bad request");
+ return false;
+ }
+ if (client.remote_state == WAITING) {
+ client.remote_state = CONNECTING;
+ client.remote.read_flag = Looper::EVENT_READ;
+ client.remote.write_flag = 0;
+ looper_->modify(client.remote.fd.get(),
+ client.remote.read_flag | client.remote.write_flag);
+ client_remote_event(index, client.remote.fd.get(), Looper::EVENT_WRITE);
+ } else {
+ assert(client.remote_state == CLOSED);
+ client.remote.last = looper_->now();
+ client.remote.new_connection = true;
+ client.remote_state = RESOLVING;
+
+ auto port = client.url->port();
+ if (port == 0) port = 80;
+ client.resolve = resolver_->request(
+ client.url->host(), port,
+ std::bind(&ProxyImpl::client_remote_resolved, this, index,
+ std::placeholders::_1,
+ std::placeholders::_2,
+ std::placeholders::_3));
+ }
+ return true;
+}
+
+void ProxyImpl::client_remote_resolved(size_t index, int fd, bool connected,
+ char const* error) {
+ auto& client = clients_[index];
+ assert(client.resolve);
+ client.resolve = nullptr;
+ assert(client.remote_state == RESOLVING);
+ if (fd < 0) {
+ logger_->out(Logger::INFO, "%zu: Client unable to resolve remote: %s",
+ index, error);
+ client_remote_error(index, 502);
+ return;
+ }
+ client.remote_state = CONNECTING;
+ client.remote.fd.reset(fd);
+ client.remote.last = looper_->now();
+ client.remote.new_connection = false;
+ client.remote.in.reset(Buffer::create(8192, 1024));
+ client.remote.out.reset(Buffer::create(8192, 1024));
+ if (connected) {
+ client.remote.read_flag = Looper::EVENT_READ;
+ client.remote.write_flag = 0;
+ } else {
+ client.remote.read_flag = 0;
+ client.remote.write_flag = Looper::EVENT_WRITE;
+ }
+ looper_->add(client.remote.fd.get(),
+ client.remote.read_flag | client.remote.write_flag,
+ std::bind(&ProxyImpl::client_remote_event, this, index,
+ std::placeholders::_1,
+ std::placeholders::_2));
+ assert(timeout_);
+ if (connected) {
+ client_remote_event(index, fd, Looper::EVENT_WRITE);
+ }
+}
+
+void ProxyImpl::client_event(size_t index, int fd, uint8_t events) {
+ auto& client = clients_[index];
+ assert(client.fd.get() == fd);
+ if (events & Looper::EVENT_HUP) {
+ close_client(index);
+ return;
+ }
+ if (events & Looper::EVENT_ERROR) {
+ logger_->out(Logger::INFO, "%zu: Client connection error", index);
+ close_client(index);
+ return;
+ }
+ if (!base_event(&client, events, index, "Client")) {
+ close_client(index);
+ return;
+ }
+ if (client.new_connection) return;
+ client_empty_input(index);
+}
+
+bool setup_content(Http const* http, Content* content) {
+ assert(content->type == CONTENT_NONE);
+ std::string te = http->first_header("transfer-encoding");
+ if (te.empty() || header_token_eq(te, "identity")) {
+ std::string len = http->first_header("content-length");
+ if (len.empty()) {
+ content->type = CONTENT_CLOSE;
+ return true;
+ }
+ char* end = nullptr;
+ errno = 0;
+ auto tmp = strtoull(len.c_str(), &end, 10);
+ if (errno || !end || *end) {
+ return false;
+ }
+ if (tmp == 0) {
+ content->type = CONTENT_NONE;
+ return true;
+ }
+ content->len = tmp;
+ content->type = CONTENT_LEN;
+ } else {
+ content->chunked.reset(Chunked::create());
+ content->type = CONTENT_CHUNKED;
+ }
+ return true;
+}
+
+void ProxyImpl::client_empty_input(size_t index) {
+ auto& client = clients_[index];
+ while (true) {
+ size_t avail;
+ auto ptr = client.in->read_ptr(&avail);
+ if (avail == 0) return;
+ switch (client.content.type) {
+ case CONTENT_CLOSE:
+ assert(false);
+ // falltrough
+ case CONTENT_NONE: {
+ if (client.remote_state != CLOSED && client.remote_state != WAITING) {
+ // Still working on the last request, wait
+ return;
+ }
+ client.request.reset(
+ HttpRequest::parse(
+ reinterpret_cast<char const*>(ptr), avail, false));
+ if (!client.request) {
+ if (avail >= 1024 * 1024) {
+ logger_->out(Logger::INFO, "%zu: Client too large request %zu",
+ index, avail);
+ close_client(index);
+ }
+ return;
+ }
+ if (!client.request->good()) {
+ client_error(index, 400, "Bad request");
+ return;
+ }
+ if (client.request->method_equal("CONNECT")) {
+ client_error(index, 501, "Not Implemented");
+ return;
+ }
+ client.url.reset(Url::parse(client.request->url()));
+ if (!client.url) {
+ client_error(index, 400, "Bad request");
+ return;
+ }
+ if (!setup_content(client.request.get(), &client.content)) {
+ logger_->out(Logger::INFO, "%zu: Client bad content-length", index);
+ client_error(index, 400, "Bad request");
+ return;
+ }
+ if (client.content.type == CONTENT_CLOSE) {
+ client.content.type = CONTENT_NONE;
+ }
+ if (!client_request(index)) {
+ client.content.type = CONTENT_NONE;
+ return;
+ }
+ break;
+ }
+ case CONTENT_LEN:
+ if (client.remote_state < CONNECTED) {
+ // Request hasn't been sent yet, still collecting data
+ return;
+ }
+ if (avail < client.content.len) {
+ if (!client_send(index, ptr, avail)) {
+ return;
+ }
+ client.in->consume(avail);
+ client.content.len -= avail;
+ return;
+ }
+ if (!client_send(index, ptr, client.content.len)) {
+ return;
+ }
+ client.in->consume(client.content.len);
+ client.content.len = 0;
+ client.content.type = CONTENT_NONE;
+ break;
+ case CONTENT_CHUNKED:
+ if (client.remote_state < CONNECTED) {
+ // Request hasn't been sent yet, still collecting data
+ return;
+ }
+ auto used = client.content.chunked->add(ptr, avail);
+ if (!client.content.chunked->good()) {
+ client_error(index, 400, "Bad request");
+ return;
+ }
+ if (!client_send(index, ptr, used)) {
+ return;
+ }
+ client.in->consume(used);
+ if (client.content.chunked->eof()) {
+ client.content.chunked.reset();
+ client.content.type = CONTENT_NONE;
+ }
+ break;
+ }
+ }
+}
+
+void ProxyImpl::close_client_when_done(size_t index) {
+ auto& client = clients_[index];
+ if (client.out->empty()) {
+ close_client(index);
+ return;
+ }
+ client.remote_state = CLOSED;
+ close_base(&client.remote);
+ client.in.reset();
+ client.read_flag = 0;
+ client.request.reset();
+ client.content.type = CONTENT_NONE;
+ looper_->modify(client.fd.get(), client.write_flag);
+}
+
+void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) {
+ auto& client = clients_[index];
+ assert(client.remote.fd.get() == fd);
+ if (events & Looper::EVENT_HUP) {
+ logger_->out(Logger::INFO, "%zu: Client remote connection closed", index);
+ client_remote_error(index, 502);
+ return;
+ }
+ if (events & Looper::EVENT_ERROR) {
+ logger_->out(Logger::INFO, "%zu: Client remote connection error", index);
+ client_remote_error(index, 502);
+ return;
+ }
+ if (client.remote_state == CONNECTING) {
+ if (events & Looper::EVENT_WRITE) {
+ std::string url(client.url->path_escaped());
+ if (url.empty()) url.push_back('/');
+ auto query = client.url->full_query_escaped();
+ if (query) {
+ url.push_back('?');
+ url.append(query);
+ }
+ auto req = std::unique_ptr<HttpRequestBuilder>(
+ HttpRequestBuilder::create(
+ client.request->method(),
+ url,
+ client.request->proto(),
+ client.request->proto_version()));
+ auto iter = client.request->header();
+ bool have_host = false;
+ for (; iter->valid(); iter->next()) {
+ if (!have_host && iter->name_equal("host")) have_host = true;
+ if (iter->name_equal("proxy-connection") ||
+ iter->name_equal("proxy-authenticate") ||
+ iter->name_equal("proxy-authorization")) {
+ continue;
+ }
+ req->add_header(iter->name(), iter->value());
+ }
+ if (!have_host &&
+ (client.request->proto_version().major == 1 &&
+ client.request->proto_version().minor == 1)) {
+ req->add_header("host", client.url->host());
+ }
+ auto data = req->build();
+ client.in->consume(client.request->size());
+ client.request.reset();
+ client.url.reset();
+ client.remote.out->write(data.data(), data.size());
+ client.remote_state = CONNECTED;
+ client.remote.read_flag = Looper::EVENT_READ;
+ client.remote.write_flag = Looper::EVENT_WRITE;
+ looper_->modify(client.remote.fd.get(),
+ client.remote.read_flag | client.remote.write_flag);
+ client_empty_input(index);
+ } else {
+ return;
+ }
+ }
+ if (!base_event(&client.remote, events, index, "Client remote")) {
+ switch (client.remote_state) {
+ case CONNECTED:
+ switch (client.remote.content.type) {
+ case CONTENT_CLOSE:
+ close_client_when_done(index);
+ break;
+ case CONTENT_CHUNKED:
+ case CONTENT_LEN:
+ case CONTENT_NONE:
+ client_remote_error(index, 502);
+ break;
+ }
+ break;
+ case CONNECTING:
+ case RESOLVING:
+ case CLOSED:
+ assert(false);
+ client_remote_error(index, 502);
+ break;
+ case WAITING:
+ close_base(&client.remote);
+ client.remote_state = CLOSED;
+ break;
+ }
+ return;
+ }
+ while (true) {
+ size_t avail;
+ auto ptr = client.remote.in->read_ptr(&avail);
+ if (avail == 0) return;
+ switch (client.remote_state) {
+ case CONNECTED:
+ switch (client.remote.content.type) {
+ case CONTENT_NONE: {
+ auto response = std::unique_ptr<HttpResponse>(
+ HttpResponse::parse(
+ reinterpret_cast<char const*>(ptr), avail, false));
+ if (!response) {
+ if (avail >= 1024 * 1024) {
+ logger_->out(Logger::INFO,
+ "%zu: Client remote too large request %zu",
+ index, avail);
+ client_remote_error(index, 502);
+ }
+ return;
+ }
+ if (!response->good()) {
+ client_remote_error(index, 502);
+ return;
+ }
+ if (!setup_content(response.get(), &client.remote.content)) {
+ logger_->out(Logger::INFO, "%zu: Client remote bad content-length",
+ index);
+ client.remote.content.type = CONTENT_CLOSE;
+ } else {
+ if (client.remote.content.type == CONTENT_NONE) {
+ client.remote_state = WAITING;
+ client.remote.read_flag = 0;
+ client.remote.write_flag = 0;
+ looper_->modify(client.remote.fd.get(), 0);
+ }
+ }
+ if (!base_send(&client, ptr, response->size(), index, "Client")) {
+ return;
+ }
+ client.remote.in->consume(response->size());
+ break;
+ }
+ case CONTENT_LEN:
+ if (avail < client.remote.content.len) {
+ if (!base_send(&client, ptr, avail, index, "Client")) {
+ return;
+ }
+ client.remote.in->consume(avail);
+ client.remote.content.len -= avail;
+ return;
+ }
+ if (!base_send(&client, ptr, client.remote.content.len,
+ index, "Client")) {
+ return;
+ }
+ client.remote.in->consume(client.remote.content.len);
+ client.remote.content.len = 0;
+ client.remote.content.type = CONTENT_NONE;
+ client.remote_state = WAITING;
+ client.remote.read_flag = 0;
+ client.remote.write_flag = 0;
+ looper_->modify(client.remote.fd.get(), 0);
+ return;
+ case CONTENT_CHUNKED: {
+ auto used = client.remote.content.chunked->add(ptr, avail);
+ if (!client.remote.content.chunked->good()) {
+ logger_->out(Logger::INFO, "%zu: Client remote bad chunked",
+ index);
+ client.remote.content.type = CONTENT_CLOSE;
+ break;
+ }
+ if (!base_send(&client, ptr, used, index, "Client")) {
+ return;
+ }
+ client.remote.in->consume(used);
+ if (client.remote.content.chunked->eof()) {
+ client.remote.content.chunked.reset();
+ client.remote.content.type = CONTENT_NONE;
+ logger_->out(Logger::INFO, "%zu: chunked -> waiting", index);
+ client.remote_state = WAITING;
+ client.remote.read_flag = 0;
+ client.remote.write_flag = 0;
+ looper_->modify(client.remote.fd.get(), 0);
+ }
+ break;
+ }
+ case CONTENT_CLOSE:
+ if (!base_send(&client, ptr, avail, index, "Client")) {
+ return;
+ }
+ client.remote.in->consume(avail);
+ break;
+ }
+ break;
+ case CONNECTING:
+ case RESOLVING:
+ case CLOSED:
+ assert(false);
+ return;
+ case WAITING:
+ return;
+ }
+ }
+}
+
+bool ProxyImpl::client_send(size_t index, void const* ptr, size_t size) {
+ auto& client = clients_[index];
+ assert(!client.request);
+ assert(client.remote_state >= CONNECTED);
+
+ return base_send(&client.remote, ptr, size, index, "Client remote");
+}
+
+void ProxyImpl::monitor_event(size_t index, int fd, uint8_t events) {
+ auto& monitor = monitors_[index];
+ assert(monitor.fd.get() == fd);
+ if (events & Looper::EVENT_HUP) {
+ close_monitor(index);
+ return;
+ }
+ if (events & Looper::EVENT_ERROR) {
+ logger_->out(Logger::INFO, "%zu: Monitor connection error", index);
+ close_monitor(index);
+ return;
+ }
+ if (!base_event(&monitor, events, index, "Monitor")) {
+ close_monitor(index);
+ return;
+ }
+}
+
+void ProxyImpl::new_base(BaseClient* client, int fd) {
+ client->fd.reset(fd);
+ client->new_connection = true;
+ client->read_flag = Looper::EVENT_READ;
+ client->write_flag = 0;
+ client->last = looper_->now();
+ if (!new_timeout_) {
+ new_timeout_ = looper_->schedule(
+ NEW_CONNECTION_TIMEOUT.count(),
+ std::bind(&ProxyImpl::new_timeout, this));
+ }
+}
+
+void ProxyImpl::new_client(int fd, uint8_t events) {
+ assert(fd == accept_socket_.get());
+ if (events == Looper::EVENT_READ) {
+ assert(!clients_.full());
+ while (true) {
+ int ret = accept4(fd, nullptr, nullptr, SOCK_NONBLOCK);
+ if (ret < 0) {
+ if (errno == EAGAIN || errno == EWOULDBLOCK) return;
+ if (errno == EINTR) continue;
+ logger_->out(Logger::WARN, "Accept failed: %s", strerror(errno));
+ return;
+ }
+ auto index = clients_.new_client();
+ new_base(&clients_[index], ret);
+ clients_[index].content.type = CONTENT_NONE;
+ clients_[index].remote_state = CLOSED;
+ clients_[index].remote.content.type = CONTENT_NONE;
+ looper_->add(ret, clients_[index].read_flag | clients_[index].write_flag,
+ std::bind(&ProxyImpl::client_event, this,
+ index,
+ std::placeholders::_1,
+ std::placeholders::_2));
+ break;
+ }
+ if (clients_.full()) looper_->modify(fd, 0);
+ } else {
+ logger_->out(Logger::ERR, "Accept socket died");
+ fatal_error();
+ }
+}
+
+void ProxyImpl::new_monitor(int fd, uint8_t events) {
+ assert(fd == monitor_socket_.get());
+ if (events == Looper::EVENT_READ) {
+ assert(!monitors_.full());
+ while (true) {
+ int ret = accept4(fd, nullptr, nullptr, SOCK_NONBLOCK);
+ if (ret < 0) {
+ if (errno == EAGAIN || errno == EWOULDBLOCK) return;
+ if (errno == EINTR) continue;
+ logger_->out(Logger::WARN, "Accept failed: %s", strerror(errno));
+ return;
+ }
+ auto index = monitors_.new_client();
+ new_base(&monitors_[index], ret);
+ looper_->add(ret, clients_[index].read_flag | clients_[index].write_flag,
+ std::bind(&ProxyImpl::monitor_event, this,
+ index,
+ std::placeholders::_1,
+ std::placeholders::_2));
+ break;
+ }
+ if (monitors_.full()) {
+ looper_->modify(fd, 0);
+ }
+ } else {
+ logger_->out(Logger::ERR, "Monitor socket died");
+ fatal_error();
+ }
+}
+
+void ProxyImpl::fatal_error() {
+ looper_->quit();
+ good_ = false;
+}
+
+bool ProxyImpl::run() {
+ good_ = true;
+ if (!logger_) {
+ priv_logger_.reset(Logger::create_syslog("tp"));
+ logger_ = priv_logger_.get();
+ }
+ {
+ struct sigaction action;
+ memset(&action, 0, sizeof(action));
+ action.sa_handler = SIG_IGN;
+ action.sa_flags = SA_RESTART;
+ sigaction(SIGPIPE, &action, nullptr);
+ }
+ if (!signal_pipe_.open()) {
+ logger_->out(Logger::WARN,
+ "Failed to create pipes, signals wont work: %s",
+ strerror(errno));
+ }
+ if (signal_pipe_) {
+ looper_->add(signal_pipe_.read(), Looper::EVENT_READ,
+ std::bind(&ProxyImpl::signal_event, this,
+ std::placeholders::_1,
+ std::placeholders::_2));
+ }
+ if (!looper_->run()) {
+ logger_->out(Logger::ERR, "poll() failed: %s", strerror(errno));
+ return false;
+ }
+ return good_;
+}
+
+int setup_socket(char const* host, std::string const& port, Logger* logger) {
+ io::auto_fd ret;
+ struct addrinfo hints;
+ memset(&hints, 0, sizeof(hints));
+ hints.ai_family = AF_UNSPEC;
+ hints.ai_socktype = SOCK_STREAM;
+ hints.ai_flags = AI_V4MAPPED | AI_ADDRCONFIG | AI_PASSIVE;
+ struct addrinfo* result;
+ auto retval = getaddrinfo(host, port.c_str(), &hints, &result);
+ if (retval) {
+ logger->out(Logger::ERR, "getaddrinfo failed for %s:%s: %s",
+ host ? host : "*", port.c_str(), gai_strerror(retval));
+ return -1;
+ }
+ auto retp = result;
+ for (; retp; retp = retp->ai_next) {
+ ret.reset(socket(retp->ai_family, retp->ai_socktype,
+ retp->ai_protocol));
+ if (!ret) continue;
+ if (bind(ret.get(), retp->ai_addr, retp->ai_addrlen) == 0)
+ break;
+ }
+ freeaddrinfo(result);
+ if (!retp) {
+ logger->out(Logger::ERR, "Failed to bind %s:%s: %s",
+ host ? host : "*", port.c_str(), strerror(errno));
+ return -1;
+ }
+ if (listen(ret.get(), SOMAXCONN)) {
+ logger->out(Logger::ERR, "Failed to listen: %s", strerror(errno));
+ return -1;
+ }
+ if (fcntl(ret.get(), F_SETFL, O_NONBLOCK)) {
+ logger->out(Logger::ERR, "fcntl(O_NONBLOCK) failed: %s", strerror(errno));
+ return -1;
+ }
+ return ret.release();
+}
+
+} // namespace
+
+// static
+Proxy* Proxy::create(Config* config, std::string const& cwd,
+ char const* configfile,
+ char const* logfile,
+ Logger* logger,
+ int accept_fd,
+ int monitor_fd) {
+ return new ProxyImpl(config, cwd, configfile, logfile, logger,
+ accept_fd, monitor_fd);
+}
+
+// static
+int Proxy::setup_accept_socket(Config* config, Logger* logger) {
+ return setup_socket(config->get("proxy_bind", nullptr),
+ config->get("proxy_port", "8080"), logger);
+}
+
+// static
+int Proxy::setup_monitor_socket(Config* config, Logger* logger) {
+ assert(config->get("monitor", false));
+ return setup_socket(config->get("monitor_bind", "localhost"),
+ config->get("monitor_port", "9000"), logger);
+}
diff --git a/src/proxy.hh b/src/proxy.hh
new file mode 100644
index 0000000..6545152
--- /dev/null
+++ b/src/proxy.hh
@@ -0,0 +1,37 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#ifndef PROXY_HH
+#define PROXY_HH
+
+#include <string>
+
+class Config;
+class Logger;
+
+class Proxy {
+public:
+ virtual ~Proxy() {}
+
+ static Proxy* create(Config* config, std::string const& cwd,
+ char const* configfile,
+ char const* logfile,
+ Logger* logger,
+ int accept_fd, int monitor_fd);
+ static int setup_accept_socket(Config* config, Logger* logger);
+ static int setup_monitor_socket(Config* config, Logger* logger);
+
+ // Called from signal handler, ask the proxy to exit as soon as possible
+ virtual void quit() = 0;
+
+ // Called from signal handler, ask the proxy to reload config as soon
+ // as possible
+ virtual void reload() = 0;
+
+ virtual bool run() = 0;
+
+protected:
+ Proxy() {}
+ Proxy(Proxy const&) = delete;
+};
+
+#endif // PROXY_HH
diff --git a/src/resolver.cc b/src/resolver.cc
new file mode 100644
index 0000000..2623089
--- /dev/null
+++ b/src/resolver.cc
@@ -0,0 +1,207 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#include "common.hh"
+
+#include <condition_variable>
+#include <cstring>
+#include <fcntl.h>
+#include <mutex>
+#include <netdb.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <thread>
+#include <vector>
+
+#include "io.hh"
+#include "looper.hh"
+#include "resolver.hh"
+
+namespace {
+
+size_t const WORKERS = 4;
+
+class ResolverImpl : public Resolver {
+public:
+ ResolverImpl(Looper* looper)
+ : looper_(looper), request_(nullptr), buf_(new char[sizeof(Request*)]),
+ fill_(0), quit_(false) {
+ while (threads_.size() < WORKERS) {
+ threads_.emplace_back(std::bind(&ResolverImpl::worker, this));
+ }
+ if (pipe_.open() && fcntl(pipe_.read(), F_SETFL, O_NONBLOCK) == 0) {
+ looper_->add(pipe_.read(),
+ Looper::EVENT_READ,
+ std::bind(&ResolverImpl::event, this,
+ std::placeholders::_1, std::placeholders::_2));
+ } else {
+ assert(false);
+ }
+ }
+
+ ~ResolverImpl() override {
+ quit_ = true;
+ cond_.notify_all();
+ for (auto& thread : threads_) {
+ thread.join();
+ }
+ }
+
+ void* request(std::string const& host, uint16_t port,
+ Callback const& callback) override {
+ auto req = new Request();
+ req->host = host;
+ req->port = port;
+ req->callback = callback;
+ req->canceled = false;
+ std::unique_lock<std::mutex> lock(mutex_);
+ req->next = request_;
+ request_ = req;
+ cond_.notify_one();
+ return req;
+ }
+
+ void cancel(void* ptr) override {
+ auto req = reinterpret_cast<Request*>(ptr);
+ req->canceled = true;
+ std::unique_lock<std::mutex> lock(mutex_);
+ if (request_ == req) {
+ request_ = req->next;
+ delete req;
+ } else {
+ for (auto r = request_; r->next; r = r->next) {
+ if (r->next == req) {
+ r->next = req->next;
+ delete req;
+ return;
+ }
+ }
+ }
+ }
+
+protected:
+ struct Request {
+ Request* next;
+ std::string host;
+ uint16_t port;
+ Callback callback;
+ bool canceled;
+ io::auto_fd fd;
+ bool connected;
+ std::string error;
+ };
+
+ void event(int fd, uint8_t event) {
+ assert(fd == pipe_.read());
+ if (event & Looper::EVENT_READ) {
+ while (true) {
+ auto ret = io::read(fd, buf_.get() + fill_, sizeof(Request*) - fill_);
+ if (ret == -1) {
+ if (errno == EAGAIN || errno == EWOULDBLOCK) return;
+ assert(false);
+ return;
+ } else if (ret == 0) {
+ assert(false);
+ return;
+ }
+ fill_ += ret;
+ if (fill_ == sizeof(Request*)) {
+ fill_ = 0;
+ auto req = *reinterpret_cast<Request**>(buf_.get());
+ if (!req->canceled) {
+ auto err = req->fd ? nullptr : req->error.c_str();
+ req->callback(req->fd.release(), req->connected, err);
+ }
+ delete req;
+ } else {
+ break;
+ }
+ }
+ } else {
+ assert(false);
+ }
+ }
+
+ void report(Request* req, int fd, bool connected, char const* errmsg) {
+ req->fd.reset(fd);
+ req->connected = connected;
+ if (errmsg) req->error = errmsg;
+ io::write_all(pipe_.write(), &req, sizeof(Request*));
+ }
+
+ void worker() {
+ struct addrinfo hints;
+ memset(&hints, 0, sizeof(hints));
+ hints.ai_family = AF_UNSPEC;
+ hints.ai_socktype = SOCK_STREAM;
+ hints.ai_flags = AI_V4MAPPED | AI_ADDRCONFIG | AI_NUMERICSERV;
+ char tmp[10];
+ while (true) {
+ Request* req;
+ {
+ std::unique_lock<std::mutex> lock(mutex_);
+ while (!quit_ && !request_) {
+ cond_.wait(lock);
+ }
+ if (quit_) return;
+ auto pr = &request_;
+ while ((*pr)->next) {
+ pr = &((*pr)->next);
+ }
+ req = *pr;
+ *pr = nullptr;
+ }
+ snprintf(tmp, sizeof(tmp), "%u", static_cast<unsigned int>(req->port));
+ struct addrinfo* result;
+ auto ret = getaddrinfo(req->host.c_str(), tmp, &hints, &result);
+ if (ret != 0) {
+ report(req, -1, false, gai_strerror(ret));
+ continue;
+ }
+ auto rp = result;
+ for (; rp; rp = rp->ai_next) {
+ io::auto_fd fd(socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol));
+ if (!fd) continue;
+
+ fcntl(fd.get(), F_SETFL, O_NONBLOCK);
+
+ while (true) {
+ ret = connect(fd.get(), rp->ai_addr, rp->ai_addrlen);
+ if (ret == 0 || errno != EINTR) break;
+ }
+ if (ret == 0) {
+ report(req, fd.release(), true, nullptr);
+ break;
+ }
+ if (errno == EINPROGRESS) {
+ report(req, fd.release(), false, nullptr);
+ break;
+ }
+ }
+
+ if (!rp) {
+ freeaddrinfo(result);
+ report(req, -1, false, strerror(errno));
+ continue;
+ }
+
+ freeaddrinfo(result);
+ }
+ }
+
+ Looper* const looper_;
+ Request* request_;
+ io::auto_pipe pipe_;
+ std::mutex mutex_;
+ std::condition_variable cond_;
+ std::unique_ptr<char[]> buf_;
+ size_t fill_;
+ bool quit_;
+ std::vector<std::thread> threads_;
+};
+
+} // namespace
+
+// static
+Resolver* Resolver::create(Looper* looper) {
+ return new ResolverImpl(looper);
+}
diff --git a/src/resolver.hh b/src/resolver.hh
new file mode 100644
index 0000000..d98458f
--- /dev/null
+++ b/src/resolver.hh
@@ -0,0 +1,28 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#ifndef RESOLVER_HH
+#define RESOLVER_HH
+
+#include <functional>
+
+class Looper;
+
+class Resolver {
+public:
+ typedef std::function<void(int fd, bool connected,
+ char const* error)> Callback;
+
+ virtual ~Resolver() {}
+
+ static Resolver* create(Looper* looper);
+
+ virtual void* request(std::string const& host, uint16_t port,
+ Callback const& callback) = 0;
+ virtual void cancel(void* ptr) = 0;
+
+protected:
+ Resolver() {}
+ Resolver(Resolver const&) = delete;
+};
+
+#endif // RESOLVER_HH
diff --git a/src/strings.cc b/src/strings.cc
new file mode 100644
index 0000000..1d74dee
--- /dev/null
+++ b/src/strings.cc
@@ -0,0 +1,94 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#include "common.hh"
+
+#include "character.hh"
+#include "strings.hh"
+
+// static
+void Strings::trim(std::string const& str, size_t* start, size_t* end) {
+ assert(start && end);
+ assert(*start <= *end);
+ assert(*end <= str.size());
+ while (*start < *end && Character::isspace(str, *start)) ++(*start);
+ while (*end > *start && Character::isspace(str, *end - 1)) --(*end);
+}
+
+// static
+std::string Strings::trim(std::string const& str) {
+ return trim(str, 0, str.size());
+}
+
+// static
+std::string Strings::trim(std::string const& str, size_t start, size_t end) {
+ auto s = start, e = end;
+ trim(str, &s, &e);
+ return str.substr(s, e - s);
+}
+
+// static
+std::string Strings::quote(std::string const& str) {
+ return quote(str, 0, str.size());
+}
+
+// static
+std::string Strings::quote(std::string const& str, size_t start, size_t end) {
+ assert(start <= end);
+ assert(end <= str.size());
+ auto i = start;
+ while (i < end) {
+ auto c = str[i];
+ if (c == '"' || c == '\\') break;
+ ++i;
+ }
+ std::string ret("\"");
+ if (i == end) {
+ ret.append(str.substr(start, end - start));
+ } else {
+ ret.append(str.substr(start, i - start));
+ while (true) {
+ ret.push_back('\\');
+ auto j = i++;
+ while (i < end) {
+ auto c = str[i];
+ if (c == '"' || c == '\\') break;
+ ++i;
+ }
+ ret.append(str.substr(j, i - j));
+ if (i == end) break;
+ }
+ }
+ ret.push_back('"');
+ return ret;
+}
+
+// static
+std::string Strings::unquote(std::string const& str) {
+ return unquote(str, 0, str.size());
+}
+
+// static
+std::string Strings::unquote(std::string const& str, size_t start, size_t end) {
+ assert(start < end - 1);
+ assert(end <= str.size());
+ assert(str[start] == '"' && str[end - 1] == '"');
+ ++start;
+ --end;
+ auto i = start;
+ while (i < end && str[i] != '\\') ++i;
+ if (i == end) return str.substr(start, end - start);
+ std::string ret(str.substr(start, i - start));
+ while (true) {
+ ++i; // skip backslash
+ if (i == end) {
+ assert(false);
+ ret.push_back('\\');
+ break;
+ }
+ auto j = i++;
+ while (i < end && str[i] != '\\') ++i;
+ ret.append(str.substr(j, i - j));
+ if (i == end) break;
+ }
+ return ret;
+}
diff --git a/src/strings.hh b/src/strings.hh
new file mode 100644
index 0000000..10e3edf
--- /dev/null
+++ b/src/strings.hh
@@ -0,0 +1,24 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#ifndef STRINGS_HH
+#define STRINGS_HH
+
+#include <string>
+
+class Strings {
+public:
+ static void trim(std::string const& str, size_t* start, size_t* end);
+ static std::string trim(std::string const& str);
+ static std::string trim(std::string const& str, size_t start, size_t end);
+
+ static std::string quote(std::string const& str);
+ static std::string quote(std::string const& str, size_t start, size_t end);
+ static std::string unquote(std::string const& str);
+ static std::string unquote(std::string const& str, size_t start, size_t end);
+
+private:
+ ~Strings() {}
+ Strings() {}
+};
+
+#endif // STRINGS_HH
diff --git a/src/terminal.cc b/src/terminal.cc
new file mode 100644
index 0000000..fd7bf98
--- /dev/null
+++ b/src/terminal.cc
@@ -0,0 +1,20 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#include "common.hh"
+
+#include <cstring>
+#include <sys/ioctl.h>
+#include <unistd.h>
+
+#include "terminal.hh"
+
+// static
+Terminal::Size Terminal::size() {
+ struct winsize size;
+ memset(&size, 0, sizeof(size));
+ ioctl(STDOUT_FILENO, TIOCGWINSZ, &size);
+ if (size.ws_col == 0 || size.ws_row == 0) {
+ return { .width = 80, .height = 25 };
+ }
+ return { .width = size.ws_col, .height = size.ws_row };
+}
diff --git a/src/terminal.hh b/src/terminal.hh
new file mode 100644
index 0000000..414ee4b
--- /dev/null
+++ b/src/terminal.hh
@@ -0,0 +1,22 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#ifndef TERMINAL_HH
+#define TERMINAL_HH
+
+#include <cstdint>
+
+class Terminal {
+public:
+ struct Size {
+ uint32_t width;
+ uint32_t height;
+ };
+
+ static Size size();
+
+private:
+ Terminal() { }
+ ~Terminal() { }
+};
+
+#endif // TERMINAL_HH
diff --git a/src/url.cc b/src/url.cc
new file mode 100644
index 0000000..b4e6c0f
--- /dev/null
+++ b/src/url.cc
@@ -0,0 +1,901 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#include "common.hh"
+
+#include <algorithm>
+#include <cstring>
+#include <iostream>
+#include <memory>
+
+#include "paths.hh"
+#include "url.hh"
+
+namespace {
+
+class UrlImpl : public Url {
+public:
+ UrlImpl()
+ : userinfo_(nullptr), userinfo_unescaped_(nullptr), port_(0),
+ path_(nullptr), path_unescaped_(nullptr), query_(nullptr),
+ query_unescaped_(nullptr), fragment_(nullptr) {
+ }
+ ~UrlImpl() override {
+ delete[] userinfo_;
+ if (userinfo_unescaped_ != userinfo_) delete[] userinfo_unescaped_;
+ delete[] path_;
+ if (path_unescaped_ != path_) delete[] path_unescaped_;
+ delete[] query_;
+ if (query_unescaped_ != query_) delete[] query_unescaped_;
+ delete[] fragment_;
+ }
+ UrlImpl(UrlImpl const& url);
+
+ bool parse(std::string const& url, Url const* base);
+
+ Url* copy() const override {
+ return new UrlImpl(*this);
+ }
+
+ std::string const& scheme() const override {
+ return scheme_;
+ }
+
+ char const* userinfo() const override;
+
+ char const* userinfo_escaped() const override {
+ return userinfo_;
+ }
+
+ std::string const& host() const override {
+ return host_;
+ }
+
+ uint16_t port() const override {
+ return port_;
+ }
+
+ std::string path() const override;
+
+ std::string path_escaped() const override {
+ return path_;
+ }
+
+ bool query(std::string const& name, std::string* value) const override;
+
+ char const* full_query() const override;
+
+ char const* full_query_escaped() const override {
+ return query_;
+ }
+
+ char const* fragment() const override {
+ return fragment_;
+ }
+
+ void print(std::ostream& out, bool path = true, bool query = true,
+ bool fragment = true) const override;
+
+private:
+ char const* parse_authority(char const* pos);
+ char const* parse_query(char const* pos);
+ char const* parse_fragment(char const* pos);
+ bool relative(std::string const& url, Url const* base);
+
+ std::string scheme_;
+ char* userinfo_;
+ mutable char* userinfo_unescaped_;
+ std::string host_;
+ uint16_t port_;
+ char* path_;
+ mutable char* path_unescaped_;
+ char* query_;
+ mutable char* query_unescaped_;
+ char* fragment_;
+};
+
+char* unescape(char* start, char* end, bool query);
+char* escape(char const* str, char const* safe);
+char* dup(char const* start, char const* end);
+void lower(std::string& str);
+
+bool is_ipv4(char const* start, char const* end);
+bool is_ipv6(char const* start, char const* end);
+
+char const subdelims[] = "!$&'()*+,;=";
+
+UrlImpl::UrlImpl(UrlImpl const& url)
+ : Url(), scheme_(url.scheme_), host_(url.host_), port_(url.port_) {
+ userinfo_ = dup(url.userinfo_, nullptr);
+ userinfo_unescaped_ = dup(url.userinfo_unescaped_, nullptr);
+ path_ = dup(url.path_, nullptr);
+ path_unescaped_ = dup(url.path_unescaped_, nullptr);
+ query_ = dup(url.query_, nullptr);
+ query_unescaped_ = dup(url.query_unescaped_, nullptr);
+ fragment_ = dup(url.fragment_, nullptr);
+}
+
+inline bool is_alpha(char c) {
+ return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
+}
+
+inline bool is_digit(char c) {
+ return c >= '0' && c <= '9';
+}
+
+inline bool is_unreserved(char c) {
+ return is_alpha(c) || is_digit(c) || c == '-' || c == '.' || c == '_' ||
+ c == '~';
+}
+
+inline bool is_hex(char c) {
+ return is_digit(c) || (c >= 'A' && c <= 'F') || (c >= 'a' && c <= 'f');
+}
+
+char* emptystr() {
+ char* ret = new char[1];
+ ret[0] = '\0';
+ return ret;
+}
+
+bool UrlImpl::parse(std::string const& url, Url const* base) {
+ char const* pos = url.c_str(), *start;
+
+ // URI = scheme ":" hier-part [ "?" query ] [ "#" fragment ]
+
+ // scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )
+ start = pos;
+ if (!is_alpha(*pos)) {
+ return relative(url, base);
+ }
+ pos++;
+ while (is_alpha(*pos) || is_digit(*pos) || *pos == '+' ||
+ *pos == '-' || *pos == '.') {
+ pos++;
+ }
+ if (*pos != ':') {
+ return relative(url, base);
+ }
+ scheme_.assign(start, pos);
+ lower(scheme_);
+ pos++;
+
+ /*
+ hier-part = "//" authority path-abempty
+ / path-absolute
+ / path-rootless
+ / path-empty
+ */
+ if (memcmp(pos, "//", 2)) {
+ if (pos[0] == '/' && pos[1] != '/') {
+ // path-absolute = "/" [ segment-nz *( "/" segment ) ]
+ start = pos++;
+ } else if (pos[0] != '/') {
+ /*
+ path-rootless = segment-nz *( "/" segment )
+ path-empty = 0<pchar>
+ */
+ start = pos;
+ }
+ while (true) {
+ if (is_unreserved(*pos)) {
+ pos++;
+ } else if (*pos == '%' && is_hex(pos[1]) && is_hex(pos[2])) {
+ pos += 3;
+ } else if (*pos && strchr(subdelims, *pos)) {
+ pos++;
+ } else if (*pos == ':' || *pos == '@' || *pos == '/') {
+ pos++;
+ } else {
+ break;
+ }
+ }
+
+ path_ = dup(start, pos);
+ } else {
+ pos += 2;
+ pos = parse_authority(pos);
+ if (!pos) {
+ return relative(url, base);
+ }
+
+ if (*pos == '/') {
+ /*
+ path-abempty = *( "/" segment )
+ segment = *pchar
+ pchar = unreserved / pct-encoded / sub-delims / ":" / "@"
+ */
+ start = pos++;
+ while (true) {
+ if (is_unreserved(*pos)) {
+ pos++;
+ } else if (*pos == '%' && is_hex(pos[1]) && is_hex(pos[2])) {
+ pos += 3;
+ } else if (*pos && strchr(subdelims, *pos)) {
+ pos++;
+ } else if (*pos == ':' || *pos == '@' || *pos == '/') {
+ pos++;
+ } else {
+ break;
+ }
+ }
+
+ path_ = dup(start, pos);
+ } else {
+ path_ = emptystr();
+ }
+ }
+
+ if (*pos == '?') {
+ pos = parse_query(pos);
+ if (!pos) {
+ return relative(url, base);
+ }
+ }
+
+ if (*pos == '#') {
+ pos = parse_fragment(pos);
+ if (!pos) {
+ return relative(url, base);
+ }
+ }
+
+ if (*pos != '\0') {
+ return relative(url, base);
+ }
+
+ return true;
+}
+
+bool UrlImpl::relative(std::string const& url, Url const* base) {
+ char const* pos = url.c_str();
+ char const* start;
+ if (!base) return false;
+
+ scheme_.clear();
+ if (userinfo_unescaped_ != userinfo_) {
+ delete[] userinfo_unescaped_; userinfo_unescaped_ = nullptr;
+ }
+ delete[] userinfo_; userinfo_ = nullptr;
+ host_.clear();
+ port_ = 0;
+ if (path_unescaped_ != path_) {
+ delete[] path_unescaped_; path_unescaped_ = nullptr;
+ }
+ delete[] path_; path_ = nullptr;
+ if (query_unescaped_ != query_) {
+ delete[] query_unescaped_; query_unescaped_ = nullptr;
+ }
+ delete[] query_; query_ = nullptr;
+ delete[] fragment_; fragment_ = nullptr;
+
+ /*
+ relative-part = "//" authority path-abempty
+ / path-absolute
+ / path-noscheme
+ / path-empty
+ */
+ if (memcmp(pos, "//", 2)) {
+ if (pos[0] == '/' && pos[1] != '/') {
+ // path-absolute = "/" [ segment-nz *( "/" segment ) ]
+ start = pos++;
+ } else if (pos[0] != '/') {
+ /*
+ path-noscheme = segment-nz-nc *( "/" segment )
+ path-empty = 0<pchar>
+ */
+ start = pos;
+ }
+ while (true) {
+ if (is_unreserved(*pos)) {
+ pos++;
+ } else if (*pos == '%' && is_hex(pos[1]) && is_hex(pos[2])) {
+ pos += 3;
+ } else if (*pos && strchr(subdelims, *pos)) {
+ pos++;
+ } else if (*pos == ':' || *pos == '@' || *pos == '/') {
+ pos++;
+ } else {
+ break;
+ }
+ }
+
+ path_ = dup(start, pos);
+ } else {
+ pos += 2;
+ pos = parse_authority(pos);
+ if (!pos) return false;
+
+ if (*pos == '/') {
+ /*
+ path-abempty = *( "/" segment )
+ segment = *pchar
+ pchar = unreserved / pct-encoded / sub-delims / ":" / "@"
+ */
+ start = pos++;
+ while (true) {
+ if (is_unreserved(*pos)) {
+ pos++;
+ } else if (*pos == '%' && is_hex(pos[1]) && is_hex(pos[2])) {
+ pos += 3;
+ } else if (*pos && strchr(subdelims, *pos)) {
+ pos++;
+ } else if (*pos == ':' || *pos == '@' || *pos == '/') {
+ pos++;
+ } else {
+ break;
+ }
+ }
+
+ path_ = dup(start, pos);
+ } else {
+ path_ = emptystr();
+ }
+ }
+
+ if (*pos == '?') {
+ pos = parse_query(pos);
+ if (!pos) return false;
+ }
+
+ if (*pos == '#') {
+ pos = parse_fragment(pos);
+ if (!pos) return false;
+ }
+
+ if (*pos != '\0') return false;
+
+ scheme_ = base->scheme();
+ if (host_.empty()) {
+ auto userinfo = base->userinfo();
+ if (userinfo) userinfo_ = dup(userinfo, userinfo + strlen(userinfo));
+ host_ = base->host();
+ port_ = base->port();
+ }
+ if (path_[0] != '/') {
+ std::string tmp(
+ Paths::join(!base->path().empty() ? base->path() : "/", path_));
+ delete[] path_;
+ path_ = dup(tmp.data(), tmp.data() + tmp.size());
+ }
+ return true;
+}
+
+char const* UrlImpl::parse_authority(char const* pos) {
+ /* authority = [ userinfo "@" ] host [ ":" port ]
+ userinfo = *( unreserved / pct-encoded / sub-delims / ":" )
+ host = IP-literal / IPv4address / reg-name
+ IP-literal = "[" ( IPv6address / IPvFuture ) "]"
+ IPvFuture = "v" 1*HEXDIG "." 1*( unreserved / sub-delims / ":" )
+ IPv6address = 6( h16 ":" ) ls32
+ / "::" 5( h16 ":" ) ls32
+ / [ h16 ] "::" 4( h16 ":" ) ls32
+ / [ *1( h16 ":" ) h16 ] "::" 3( h16 ":" ) ls32
+ / [ *2( h16 ":" ) h16 ] "::" 2( h16 ":" ) ls32
+ / [ *3( h16 ":" ) h16 ] "::" h16 ":" ls32
+ / [ *4( h16 ":" ) h16 ] "::" ls32
+ / [ *5( h16 ":" ) h16 ] "::" h16
+ / [ *6( h16 ":" ) h16 ] "::"
+ ls32 = ( h16 ":" h16 ) / IPv4address
+ h16 = 1*4HEXDIG
+ IPv4address = dec-octet "." dec-octet "." dec-octet "." dec-octet
+ dec-octet = DIGIT ; 0-9
+ / %x31-39 DIGIT ; 10-99
+ / "1" 2DIGIT ; 100-199
+ / "2" %x30-34 DIGIT ; 200-249
+ / "25" %x30-35 ; 250-255
+ reg-name = *( unreserved / pct-encoded / sub-delims )
+ port = *DIGIT
+ */
+ char const* start = pos;
+ char const* at = nullptr, *colon = nullptr;
+ char const* host_start, *host_end, *tmp;
+ while (true) {
+ if (is_unreserved(*pos)) {
+ pos++;
+ } else if (*pos == '%' && is_hex(pos[1]) && is_hex(pos[2])) {
+ pos += 3;
+ } else if (*pos && strchr(subdelims, *pos)) {
+ pos++;
+ } else if (*pos == '[' || *pos == ']') {
+ pos++;
+ } else if (*pos == ':') {
+ colon = pos++;
+ } else if (*pos == '@') {
+ if (at) {
+ return nullptr;
+ }
+ colon = nullptr;
+ at = pos++;
+ } else {
+ break;
+ }
+ }
+
+ // userinfo?
+ if (at) {
+ host_start = at + 1;
+ userinfo_ = dup(start, at);
+ if (strchr(userinfo_, '[') || strchr(userinfo_, ']')) {
+ return nullptr;
+ }
+ } else {
+ host_start = start;
+ }
+
+ // port?
+ if (colon) {
+ tmp = colon + 1;
+ if (tmp < pos && is_digit(*tmp)) {
+ uint16_t v = *tmp - '0';
+ tmp++;
+ host_end = colon;
+ while (tmp < pos) {
+ uint16_t x;
+ if (!is_digit(*tmp)) {
+ host_end = pos;
+ break;
+ }
+ x = v;
+ v *= 10;
+ if (v < x) {
+ return nullptr;
+ }
+ v += *tmp - '0';
+ tmp++;
+ }
+ if (host_end == colon) {
+ port_ = v;
+ }
+ } else {
+ host_end = pos;
+ }
+ } else {
+ host_end = pos;
+ }
+
+ if (*host_start == '[') {
+ if (host_end[-1] != ']' || host_start + 1 >= host_end - 1) {
+ return nullptr;
+ }
+ host_start++;
+ host_end--;
+ if (*host_start == 'v') {
+ if (!is_hex(host_start[1]) || host_start[2] != '.') {
+ return nullptr;
+ }
+ host_.assign(host_start, host_end - host_start);
+ lower(host_);
+ } else {
+ if (!is_ipv6(host_start, host_end)) {
+ return nullptr;
+ }
+ host_.assign(host_start, host_end - host_start);
+ lower(host_);
+ }
+ if (host_.find('[') != std::string::npos ||
+ host_.find(']') != std::string::npos) {
+ return nullptr;
+ }
+ } else {
+ tmp = host_start;
+ while (tmp < host_end) {
+ if (*tmp == '[' || *tmp == ']') {
+ return nullptr;
+ }
+ tmp++;
+ }
+ tmp = unescape(const_cast<char*>(host_start), const_cast<char*>(host_end),
+ false);
+ host_ = tmp;
+ lower(host_);
+ delete[] const_cast<char*>(tmp);
+ }
+ if (host_.empty()) return nullptr;
+ return pos;
+}
+
+char const* UrlImpl::parse_query(char const* pos) {
+ // query = *( pchar / "/" / "?" )
+ char const* start = ++pos;
+ while (true) {
+ if (is_unreserved(*pos)) {
+ pos++;
+ } else if (*pos == '%' && is_hex(pos[1]) && is_hex(pos[2])) {
+ pos += 3;
+ } else if (*pos && strchr(subdelims, *pos)) {
+ pos++;
+ } else if (*pos == ':' || *pos == '@' || *pos == '/' ||
+ *pos == '?') {
+ pos++;
+ } else {
+ break;
+ }
+ }
+
+ query_ = dup(start, pos);
+ return pos;
+}
+
+char const* UrlImpl::parse_fragment(char const* pos) {
+ // fragment = *( pchar / "/" / "?" )
+ char const* start = ++pos;
+ while (true) {
+ if (is_unreserved(*pos)) {
+ pos++;
+ } else if (*pos == '%' && is_hex(pos[1]) && is_hex(pos[2])) {
+ pos += 3;
+ } else if (*pos && strchr(subdelims, *pos)) {
+ pos++;
+ } else if (*pos == ':' || *pos == '@' || *pos == '/' ||
+ *pos == '?') {
+ pos++;
+ } else {
+ break;
+ }
+ }
+
+ fragment_ = unescape((char*) start, (char*) pos, false);
+ return pos;
+}
+
+char const* UrlImpl::userinfo() const {
+ if (userinfo_ && !userinfo_unescaped_) {
+ userinfo_unescaped_ = unescape(userinfo_, nullptr, false);
+ }
+ return userinfo_unescaped_;
+}
+
+std::string UrlImpl::path() const {
+ if (!path_unescaped_) {
+ path_unescaped_ = unescape(path_, nullptr, false);
+ }
+ return path_unescaped_;
+}
+
+bool UrlImpl::query(std::string const& name, std::string* value) const {
+ char* pos;
+ if (!query_ || !*query_) {
+ return false;
+ }
+ pos = query_;
+ while (true) {
+ char* next = pos, *eq;
+ char* tmp;
+ while (*next && *next != '=' && *next != '&') {
+ next++;
+ }
+ if (*next == '=') {
+ eq = next++;
+ while (*next && *next != '&') {
+ next++;
+ }
+ } else {
+ eq = next;
+ }
+ tmp = unescape(pos, eq, true);
+ if (name.compare(tmp) == 0) {
+ delete[] tmp;
+ if (eq != next) {
+ std::unique_ptr<char[]> tmp(unescape(eq + 1, next, true));
+ if (value) value->assign(tmp.get());
+ return true;
+ } else {
+ if (value) value->clear();
+ return true;
+ }
+ }
+ delete[] tmp;
+ if (!*next) {
+ return false;
+ }
+ pos = next + 1;
+ }
+}
+
+char const* UrlImpl::full_query() const {
+ if (query_ && !query_unescaped_) {
+ query_unescaped_ = unescape(query_, nullptr, true);
+ }
+ return query_unescaped_;
+}
+
+bool unhex(char c, uint8_t* ret) {
+ if (is_digit(c)) {
+ *ret = c - '0';
+ return true;
+ } else if (c >= 'A' && c <= 'F') {
+ *ret = c - 'A' + 10;
+ return true;
+ } else if (c >= 'a' && c <= 'f') {
+ *ret = c - 'a' + 10;
+ return true;
+ }
+ return false;
+}
+
+char* unescape(char* start, char* end, bool query) {
+ char* pos = start;
+ char* ret;
+ size_t o;
+ if (!start) {
+ return nullptr;
+ }
+ if (end) {
+ for (; pos < end; pos++) {
+ if (*pos == '%' || (query && *pos == '+')) {
+ break;
+ }
+ }
+ if (pos == end) {
+ return dup(start, end);
+ }
+ } else {
+ for (; *pos; pos++) {
+ if (*pos == '%' || (query && *pos == '+')) {
+ break;
+ }
+ }
+ if (!*pos) {
+ return start;
+ }
+ end = pos + strlen(pos);
+ }
+ ret = new char[end - start + 1];
+ o = 0;
+ while (true) {
+ uint8_t h, l;
+ memcpy(ret + o, start, pos - start);
+ o += pos - start;
+ if (query && *pos == '+') {
+ ret[o++] = ' ';
+ pos++;
+ } else if (pos + 3 <= end && unhex(pos[1], &h) && unhex(pos[2], &l)) {
+ ret[o++] = h << 4 | l;
+ pos += 3;
+ } else {
+ ret[o++] = *(pos++);
+ }
+ start = pos;
+ while (pos < end && *pos != '%' && !(query && *pos == '+')) {
+ pos++;
+ }
+ if (pos == end) {
+ memcpy(ret + o, start, pos - start);
+ o += pos - start;
+ break;
+ }
+ }
+ ret[o] = '\0';
+ return ret;
+}
+
+bool unsafe(char c, char const* safe) {
+ if (is_unreserved(c) || (c && strchr(safe, c))) {
+ return false;
+ }
+ return true;
+}
+
+char hex(uint8_t c) {
+ return c < 10 ? '0' + c : 'A' + c - 10;
+}
+
+char* escape(char const* str, char const* safe) {
+ char const* pos = str;
+ size_t len;
+ char* ret;
+ size_t o;
+ while (*pos && !unsafe(*pos, safe)) {
+ pos++;
+ }
+ if (!*pos) {
+ return (char*) str;
+ }
+ len = strlen(pos);
+ ret = new char[(pos - str) + len * 3 + 1];
+ o = 0;
+ while (true) {
+ memcpy(ret + o, str, pos - str);
+ o += pos - str;
+ ret[o++] = '%';
+ ret[o++] = hex(*((const uint8_t*)pos) >> 4);
+ ret[o++] = hex(*((const uint8_t*)pos) & 0xf);
+ str = ++pos;
+ while (*pos && !unsafe(*pos, safe)) {
+ pos++;
+ }
+ if (!*pos) {
+ memcpy(ret + o, str, pos - str);
+ o += pos - str;
+ break;
+ }
+ }
+ ret[o] = '\0';
+ return ret;
+}
+
+char* dup(char const* start, char const* end) {
+ if (!start) return nullptr;
+ if (!end) end = start + strlen(start);
+ size_t len = end - start;
+ char* ret;
+ assert(start <= end);
+ ret = new char[len + 1];
+ memcpy(ret, start, len);
+ ret[len] = '\0';
+ return ret;
+}
+
+inline char ascii_tolower(char c) {
+ return (c >= 'A' && c <= 'Z') ? ('a' + c - 'A') : c;
+}
+
+void lower(std::string& str) {
+ std::transform(str.begin(), str.end(), str.begin(), ascii_tolower);
+}
+
+bool is_ipv4(char const* start, char const* end) {
+ char const* pos = start;
+ size_t i = 0;
+ while (true) {
+ if (pos[0] == '2' && pos + 2 <= end) {
+ if (pos[1] == '5') {
+ if (pos[2] < '0' || pos[2] > '5') {
+ break;
+ }
+ } else {
+ if (pos[1] < '0' || pos[1] > '4' || !is_digit(pos[2])) {
+ break;
+ }
+ }
+ pos += 3;
+ } else if (pos[0] == '1' && pos + 2 <= end &&
+ is_digit(pos[1]) && is_digit(pos[2])) {
+ pos += 3;
+ } else if ((pos[0] >= '1' && pos[0] <= '9') && pos + 1 <= end &&
+ is_digit(pos[1])) {
+ pos += 2;
+ } else if (is_digit(pos[0])) {
+ pos++;
+ } else {
+ break;
+ }
+
+ i++;
+ if (pos == end || *pos != '.') {
+ break;
+ }
+ pos++;
+ if (pos == end) {
+ return false;
+ }
+ }
+ return i == 4 && pos == end;
+}
+
+size_t walk_hex(char const* start, char const* end, size_t max) {
+ size_t i = 0;
+ while (max-- && start < end) {
+ if (!is_hex(*start)) break;
+ start++;
+ i++;
+ }
+ return i;
+}
+
+bool is_ipv6(char const* start, char const* end) {
+ size_t i = 0, j = 0, x;
+ char const* pos = start;
+ bool empty = false;
+ while (pos < end) {
+ if (*pos == ':') {
+ pos++;
+ if (*pos == ':') {
+ empty = true;
+ pos++;
+ break;
+ }
+ if (i == 0) {
+ return false;
+ }
+ }
+ x = walk_hex(pos, end, 4);
+ if (x == 0) {
+ return false;
+ }
+ pos += x;
+ i++;
+ }
+
+ if (pos < end) {
+ while (true) {
+ x = walk_hex(pos, end, 4);
+ if (x == 0) {
+ return false;
+ }
+ pos += x;
+ if (pos == end) {
+ j++;
+ break;
+ }
+ if (*pos != ':') {
+ pos -= x;
+ break;
+ }
+ pos++;
+ j++;
+ }
+ }
+
+ if (pos != end) {
+ if (!is_ipv4(pos, end)) {
+ return false;
+ }
+ j += 2;
+ }
+
+ if (!empty) {
+ return i == 8;
+ }
+
+ if (i + j > 7) {
+ return false;
+ }
+ return true;
+}
+
+void UrlImpl::print(std::ostream& out, bool path, bool query, bool fragment)
+ const {
+ out << scheme_;
+ if (!host_.empty()) {
+ out << "://";
+ if (userinfo_) {
+ out << userinfo_ << '@';
+ }
+ out << host_;
+ if (port_) {
+ out << ':' << port_;
+ }
+ } else {
+ out << ':';
+ }
+ if (path) {
+ out << path_;
+ }
+ if (query && query_ && *query_) {
+ out << '?' << query_;
+ }
+ if (fragment && fragment_ && *fragment_) {
+ out << '#';
+ char* tmp = escape(fragment_, "/?");
+ out << tmp;
+ if (tmp != fragment_) delete[] tmp;
+ }
+}
+
+bool null_eq(char const* a1, char const* a2) {
+ if (a1 == a2) return true;
+ return a1 && a2 && strcmp(a1, a2) == 0;
+}
+
+} // namespace
+
+// static
+Url* Url::parse(std::string const& url, Url const* base) {
+ UrlImpl* ret = new UrlImpl();
+ if (ret->parse(url, base)) return ret;
+ delete ret;
+ return nullptr;
+}
+
+bool Url::operator==(Url const& url) const {
+ if (scheme() != url.scheme()) return false;
+ if (host() != url.host()) return false;
+ if (port() != url.port()) return false;
+ if (path_escaped() != url.path_escaped()) return false;
+ if (!null_eq(userinfo_escaped(), url.userinfo_escaped())) return false;
+ if (!null_eq(full_query_escaped(), url.full_query_escaped())) return false;
+ if (!null_eq(fragment(), url.fragment())) return false;
+ return true;
+}
+
diff --git a/src/url.hh b/src/url.hh
new file mode 100644
index 0000000..d3b69b7
--- /dev/null
+++ b/src/url.hh
@@ -0,0 +1,62 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#ifndef URL_HH
+#define URL_HH
+
+#include <cstdint>
+#include <string>
+
+// scheme://[userinfo@]host[:port]/path[?query][#fragment]
+class Url {
+public:
+ virtual ~Url() {}
+
+ static Url* parse(std::string const& url, Url const* base = nullptr);
+
+ virtual Url* copy() const = 0;
+
+ bool operator==(Url const& url) const;
+
+ // never empty, unescaped
+ virtual std::string const& scheme() const = 0;
+
+ // may be null, not empty, unescaped
+ virtual char const* userinfo() const = 0;
+
+ // may be null, not empty
+ virtual char const* userinfo_escaped() const = 0;
+
+ // may be empty, unescaped
+ virtual std::string const& host() const = 0;
+
+ // 0 is used for no port set
+ virtual uint16_t port() const = 0;
+
+ // may be empty, unescaped
+ virtual std::string path() const = 0;
+
+ // may be empty
+ virtual std::string path_escaped() const = 0;
+
+ // return true if name was found and value updated,
+ // value may be empty, unescaped
+ virtual bool query(std::string const& name, std::string* value) const = 0;
+
+ // may be null or empty, unescaped
+ virtual char const* full_query() const = 0;
+
+ // may be null or empty
+ virtual char const* full_query_escaped() const = 0;
+
+ // may be null or empty, unescaped
+ virtual char const* fragment() const = 0;
+
+ virtual void print(std::ostream& out, bool path = true, bool query = true,
+ bool fragment = true) const = 0;
+
+protected:
+ Url() {}
+ Url(Url const&) = delete;
+};
+
+#endif // URL_HH
diff --git a/src/xdg.cc b/src/xdg.cc
new file mode 100644
index 0000000..d230e07
--- /dev/null
+++ b/src/xdg.cc
@@ -0,0 +1,135 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#include "common.hh"
+
+#include "paths.hh"
+#include "xdg.hh"
+
+#include <map>
+#include <memory>
+#include <pwd.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+namespace {
+
+std::string home_based_dir(char const* env_name, char const* fallback) {
+ auto env = getenv(env_name);
+ if (!env || !*env) {
+ return Paths::join(XDG::home(), fallback);
+ }
+ return Paths::cleanup(env);
+}
+
+void dirs(char const* dirs, std::vector<std::string>* ret) {
+ auto start = dirs;
+ auto i = dirs;
+ while (true) {
+ if (*i == ':' || !*i) {
+ if (i > start) {
+ auto tmp = Paths::cleanup(std::string(start, i - start));
+ if (tmp[0] == '/') {
+ ret->push_back(tmp);
+ }
+ }
+ if (!*i) break;
+ start = ++i;
+ } else {
+ ++i;
+ }
+ }
+}
+
+void dirs(char const* env_name, char const* fallback,
+ std::vector<std::string>* ret) {
+ auto env = getenv(env_name);
+ if (!env || !*env) {
+ dirs(fallback, ret);
+ } else {
+ dirs(env, ret);
+ }
+}
+
+void remove_dupes(std::vector<std::string>* lst) {
+ std::multimap<size_t, std::string*> mem;
+ bool modifying = false;
+ auto out = lst->begin();
+ for (auto it = lst->begin(); it != lst->end(); ++it) {
+ auto len = it->size();
+ auto pair = mem.equal_range(len);
+ bool remove = false;
+ for (auto j = pair.first; j != pair.second; ++j) {
+ if (*(j->second) == *it) {
+ remove = true;
+ break;
+ }
+ }
+ if (remove) {
+ if (!modifying) {
+ modifying = true;
+ out = it;
+ }
+ } else {
+ if (modifying) {
+ *(out++) = *it;
+ }
+ mem.insert(pair.first, std::make_pair(len, &(*it)));
+ }
+ }
+ if (modifying) {
+ lst->resize(out - lst->begin());
+ }
+}
+
+} // namespace
+
+// static
+std::string XDG::config_home() {
+ return home_based_dir("XDG_CONFIG_HOME", ".config");
+}
+
+// static
+std::vector<std::string> XDG::config_dirs() {
+ std::vector<std::string> ret;
+ ret.push_back(config_home());
+ dirs("XDG_CONFIG_DIRS", SYSCONFDIR "/xdg", &ret);
+ remove_dupes(&ret);
+ return ret;
+}
+
+// static
+std::string XDG::data_home() {
+ return home_based_dir("XDG_DATA_HOME", ".local/share");
+}
+
+// static
+std::vector<std::string> XDG::data_dirs() {
+ std::vector<std::string> ret;
+ ret.push_back(data_home());
+ dirs("XDG_DATA_DIRS", "/usr/local/share/:/usr/share/", &ret);
+ remove_dupes(&ret);
+ return ret;
+}
+
+// static
+std::string XDG::cache_home() {
+ return home_based_dir("XDG_CACHE_HOME", ".cache");
+}
+
+// static
+std::string XDG::home() {
+ auto env = getenv("HOME");
+ if (!env || !*env) {
+ auto size = sysconf(_SC_GETPW_R_SIZE_MAX);
+ if (size == -1) size = 16384;
+ auto buf = std::unique_ptr<char[]>(new char[size]);
+ struct passwd pwd;
+ struct passwd* result;
+ auto ret = getpwuid_r(getuid(), &pwd, buf.get(), size, &result);
+ if (result && ret == 0) {
+ return Paths::cleanup(result->pw_dir);
+ }
+ return ".";
+ }
+ return Paths::cleanup(env);
+}
diff --git a/src/xdg.hh b/src/xdg.hh
new file mode 100644
index 0000000..99d1c24
--- /dev/null
+++ b/src/xdg.hh
@@ -0,0 +1,25 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#ifndef XDG_HH
+#define XDG_HH
+
+#include <string>
+#include <vector>
+
+class XDG {
+public:
+ static std::string config_home();
+ // Returned list is sorted in order of importance, directories may not exist
+ static std::vector<std::string> config_dirs();
+ static std::string data_home();
+ // Returned list is sorted in order of importance, directories may not exist
+ static std::vector<std::string> data_dirs();
+ static std::string cache_home();
+ static std::string home();
+
+private:
+ XDG() {}
+ ~XDG() {}
+};
+
+#endif // XDG_HH