diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/.gitignore | 4 | ||||
| -rw-r--r-- | src/Makefile.am | 18 | ||||
| -rw-r--r-- | src/args.cc | 281 | ||||
| -rw-r--r-- | src/args.hh | 60 | ||||
| -rw-r--r-- | src/buffer.cc | 119 | ||||
| -rw-r--r-- | src/buffer.hh | 30 | ||||
| -rw-r--r-- | src/character.cc | 31 | ||||
| -rw-r--r-- | src/character.hh | 18 | ||||
| -rw-r--r-- | src/chunked.cc | 106 | ||||
| -rw-r--r-- | src/chunked.hh | 23 | ||||
| -rw-r--r-- | src/common.hh | 18 | ||||
| -rw-r--r-- | src/config.cc | 176 | ||||
| -rw-r--r-- | src/config.hh | 33 | ||||
| -rw-r--r-- | src/http.cc | 657 | ||||
| -rw-r--r-- | src/http.hh | 141 | ||||
| -rw-r--r-- | src/io.cc | 90 | ||||
| -rw-r--r-- | src/io.hh | 124 | ||||
| -rw-r--r-- | src/logger.cc | 130 | ||||
| -rw-r--r-- | src/logger.hh | 33 | ||||
| -rw-r--r-- | src/looper.cc | 287 | ||||
| -rw-r--r-- | src/looper.hh | 41 | ||||
| -rw-r--r-- | src/main.cc | 187 | ||||
| -rw-r--r-- | src/paths.cc | 98 | ||||
| -rw-r--r-- | src/paths.hh | 18 | ||||
| -rw-r--r-- | src/proxy.cc | 1373 | ||||
| -rw-r--r-- | src/proxy.hh | 37 | ||||
| -rw-r--r-- | src/resolver.cc | 207 | ||||
| -rw-r--r-- | src/resolver.hh | 28 | ||||
| -rw-r--r-- | src/strings.cc | 94 | ||||
| -rw-r--r-- | src/strings.hh | 24 | ||||
| -rw-r--r-- | src/terminal.cc | 20 | ||||
| -rw-r--r-- | src/terminal.hh | 22 | ||||
| -rw-r--r-- | src/url.cc | 901 | ||||
| -rw-r--r-- | src/url.hh | 62 | ||||
| -rw-r--r-- | src/xdg.cc | 135 | ||||
| -rw-r--r-- | src/xdg.hh | 25 |
36 files changed, 5651 insertions, 0 deletions
diff --git a/src/.gitignore b/src/.gitignore new file mode 100644 index 0000000..7066278 --- /dev/null +++ b/src/.gitignore @@ -0,0 +1,4 @@ +/config.h +/config.h.in~ +/libtp.a +/tp diff --git a/src/Makefile.am b/src/Makefile.am new file mode 100644 index 0000000..502c82d --- /dev/null +++ b/src/Makefile.am @@ -0,0 +1,18 @@ +MAINTAINERCLEANFILES = Makefile.in config.h.in + +AM_CXXFLAGS = @DEFINES@ + +# Remove ar: `u' modifier ignored since `D' is the default (see `U') +ARFLAGS = cr + +bin_PROGRAMS = tp +noinst_LIBRARIES = libtp.a + +tp_SOURCES = main.cc proxy.cc logger.cc resolver.cc +tp_LDADD = libtp.a @THREAD_LIBS@ +tp_CXXFLAGS = $(AM_CXXFLAGS) -DVERSION='"@VERSION@"' @THREAD_CFLAGS@ + +libtp_a_SOURCES = args.cc xdg.cc terminal.cc http.cc url.cc paths.cc \ + character.cc config.cc strings.cc io.cc looper.cc \ + buffer.cc chunked.cc +libtp_a_CXXFLAGS = $(AM_CXXFLAGS) -DSYSCONFDIR='"@SYSCONFDIR@"' diff --git a/src/args.cc b/src/args.cc new file mode 100644 index 0000000..cde77b0 --- /dev/null +++ b/src/args.cc @@ -0,0 +1,281 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include <cstring> +#include <string> +#include <vector> +#include <unordered_map> + +#include "args.hh" +#include "character.hh" +#include "terminal.hh" + +namespace { + +class ArgsImpl : public Args { +public: + ArgsImpl() + : good_(true) { + } + void add(char short_opt, std::string const& long_opt, + std::string const& argument, std::string const& help) override { + assert(short_opt == '\0' || short_opts_.count(short_opt) == 0); + assert(long_opt.empty() || long_opts_.count(long_opt) == 0); + assert(long_opt.find('=') == std::string::npos); + auto const index = opts_.size(); + opts_.push_back(Option(short_opt, long_opt, argument, help)); + if (short_opt != '\0') { + short_opts_.insert(std::make_pair(short_opt, index)); + } + if (!long_opt.empty()) { + long_opts_.insert(std::make_pair(long_opt, index)); + } + } + + bool run(int argc, char** argv, std::ostream& out) override { + if (argc == 0) { + assert(false); + good_ = false; + return false; + } + auto start = strrchr(argv[0], '/'); + if (!start) { + start = argv[0]; + } else { + start++; + } + return run(start, argc, argv, out); + } + + bool run(std::string const& prg, int argc, char** argv, std::ostream& out) + override { + reset(); + + std::string opt; + for (int a = 1; a < argc; ++a) { + if (argv[a][0] == '-') { + if (argv[a][1] == '-') { + if (argv[a][2] == '\0') { + for (++a; a < argc; ++a) { + args_.push_back(argv[a]); + } + return good_; + } + size_t len = 2; + while (argv[a][len] && argv[a][len] != '=') ++len; + opt.assign(argv[a] + 2, len - 2); + auto i = long_opts_.find(opt); + if (i == long_opts_.end()) { + out << prg << ": unrecognized option '--" << opt << "'\n"; + good_ = false; + continue; + } + if (argv[a][len] == '=') { + if (opts_[i->second].argument.empty()) { + out << prg << ": option '--" << opt << "'" + << " doesn't allow an argument\n"; + good_ = false; + continue; + } else { + opts_[i->second].value = argv[a] + len + 1; + opts_[i->second].is_set = true; + } + } else { + if (opts_[i->second].argument.empty()) { + opts_[i->second].is_set = true; + } else if (a + 1 == argc) { + out << prg << ": option '--" << opt << "'" + << " requires an argument\n"; + good_ = false; + continue; + } else { + opts_[i->second].value = argv[++a]; + opts_[i->second].is_set = true; + } + } + } else { + for (auto opt = argv[a] + 1; *opt; ++opt) { + auto i = short_opts_.find(*opt); + if (i == short_opts_.end()) { + out << prg << ": invalid option -- '" << *opt << "'\n"; + good_ = false; + continue; + } + if (opts_[i->second].argument.empty()) { + opts_[i->second].is_set = true; + } else if (a + 1 == argc) { + out << prg << ": option requires an argument " + << " -- '" << *opt << "'\n"; + good_ = false; + continue; + } else { + opts_[i->second].value = argv[++a]; + opts_[i->second].is_set = true; + } + } + } + } else { + args_.push_back(argv[a]); + } + } + + return good_; + } + bool good() const override { + return good_; + } + + bool is_set(char short_opt) const override { + auto i = short_opts_.find(short_opt); + if (i == short_opts_.end()) return false; + return opts_[i->second].is_set; + } + bool is_set(std::string const& long_opt) const override { + auto i = long_opts_.find(long_opt); + if (i == long_opts_.end()) return false; + return opts_[i->second].is_set; + } + char const* arg(char short_opt, char const* fallback) const override { + auto i = short_opts_.find(short_opt); + if (i == short_opts_.end()) return fallback; + if (!opts_[i->second].is_set) return fallback; + if (opts_[i->second].argument.empty()) return fallback; + return opts_[i->second].value.c_str(); + } + char const* arg(std::string const& long_opt, + char const* fallback) const override { + auto i = long_opts_.find(long_opt); + if (i == long_opts_.end()) return fallback; + if (!opts_[i->second].is_set) return fallback; + if (opts_[i->second].argument.empty()) return fallback; + return opts_[i->second].value.c_str(); + } + + std::vector<std::string> const& arguments() const override { + return args_; + } + + void print_help(std::ostream& out) const override { + print_help(out, Terminal::size().width); + } + + void print_help(std::ostream& out, size_t width) const override { + size_t left = 0; + for (auto const& opt : opts_) { + size_t l = 0; + if (!opt.long_opt.empty()) { + l += 6 + opt.long_opt.size(); + } else if (opt.short_opt != '\0') { + l += 2; + } else { + continue; + } + if (!opt.argument.empty()) { + l += 1 + opt.argument.size(); + } + if (l > left) left = l; + } + + size_t const need = 2 + 2 + left; + if (need + 10 > width) { + width = need + 10; + } + size_t const right = width - need; + + for (auto const& opt : opts_) { + size_t i = 0; + if (!opt.long_opt.empty()) { + if (opt.short_opt != '\0') { + out << " -" << opt.short_opt << ", "; + } else { + out << " "; + } + out << "--" << opt.long_opt; + i += 8 + opt.long_opt.size(); + } else if (opt.short_opt != '\0') { + out << " -" << opt.short_opt; + i += 4; + } else { + continue; + } + if (!opt.argument.empty()) { + out << '=' << opt.argument; + i += 1 + opt.argument.size(); + } + pad(out, need - i); + if (opt.help.size() < right) { + out << opt.help << '\n'; + } else { + i = right; + while (i > 0 && !Character::isseparator(opt.help, i)) --i; + if (i == 0) i = right; + out << opt.help.substr(0, i) << '\n'; + while (true) { + while (i < opt.help.size() && Character::isspace(opt.help, i)) ++i; + if (i == opt.help.size()) break; + size_t j = right - 2; + pad(out, width - j); + if (i + j >= opt.help.size()) { + out << opt.help.substr(i) << '\n'; + break; + } + while (j > 0 && !Character::isseparator(opt.help, i + j)) --j; + if (j == 0) j = right - 2; + out << opt.help.substr(i, j) << '\n'; + i += j; + } + } + } + } + +private: + struct Option { + char const short_opt; + std::string const long_opt; + std::string const argument; + std::string const help; + + bool is_set; + std::string value; + + Option(char short_opt, std::string const& long_opt, + std::string const& argument, std::string const& help) + : short_opt(short_opt), long_opt(long_opt), argument(argument), + help(help), is_set(false) { + } + }; + bool good_; + std::unordered_map<char, size_t> short_opts_; + std::unordered_map<std::string, size_t> long_opts_; + std::vector<Option> opts_; + std::vector<std::string> args_; + + void reset() { + good_ = true; + args_.clear(); + for (auto& opt : opts_) { + opt.is_set = false; + opt.value.clear(); + } + } + + static void pad(std::ostream& out, size_t count) { + while (count > 4) { + out << " "; + count -= 4; + } + while (count) { + out << ' '; + --count; + } + } +}; + +} // namespace + +// static +Args* Args::create() { + return new ArgsImpl(); +} + diff --git a/src/args.hh b/src/args.hh new file mode 100644 index 0000000..10c56b3 --- /dev/null +++ b/src/args.hh @@ -0,0 +1,60 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#ifndef ARGS_HH +#define ARGS_HH + +#include <iostream> +#include <string> +#include <vector> + +class Args { +public: + virtual ~Args() {} + + static Args* create(); + + virtual void add(char short_opt, std::string const& long_opt, + std::string const& argument, std::string const& help) = 0; + void add(char short_opt, std::string const& long_opt, + std::string const& help) { + add(short_opt, long_opt, "", help); + } + void add(std::string const& long_opt, std::string const& help) { + add('\0', long_opt, "", help); + } + void add(std::string const& long_opt, std::string const& argument, + std::string const& help) { + add('\0', long_opt, argument, help); + } + + bool run(int argc, char** argv) { + return run(argc, argv, std::cerr); + } + virtual bool run(int argc, char** argv, std::ostream& out) = 0; + bool run(std::string const& prg, int argc, char** argv) { + return run(prg, argc, argv, std::cerr); + } + virtual bool run( + std::string const& prg, int argc, char** argv, std::ostream& out) = 0; + virtual bool good() const = 0; + + virtual bool is_set(char short_opt) const = 0; + virtual bool is_set(std::string const& long_opt) const = 0; + virtual char const* arg(char short_opt, char const* fallback) const = 0; + virtual char const* arg(std::string const& long_opt, + char const* fallback) const = 0; + + virtual std::vector<std::string> const& arguments() const = 0; + + void print_help() const { + print_help(std::cout); + } + virtual void print_help(std::ostream& out) const = 0; + virtual void print_help(std::ostream& out, size_t width) const = 0; + +protected: + Args() {} + Args(Args const&) = delete; +}; + +#endif // ARGS_HH diff --git a/src/buffer.cc b/src/buffer.cc new file mode 100644 index 0000000..d0c5fbb --- /dev/null +++ b/src/buffer.cc @@ -0,0 +1,119 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include <algorithm> +#include <cstring> + +#include "buffer.hh" + +namespace { + +class BufferImpl : public Buffer { +public: + BufferImpl(size_t capacity, size_t min_avail) + : capacity_(capacity), min_avail_(min_avail) { + data_ = capacity_ > 0 ? + reinterpret_cast<char*>(malloc(capacity_)) : nullptr; + end_ = data_; + rptr_ = data_; + wptr_ = data_; + } + ~BufferImpl() { + free(data_); + } + + void const* read_ptr(size_t* avail) const override { + if (avail) *avail = wptr_ - rptr_; + return rptr_; + } + + void consume(size_t bytes) override { + if (bytes == 0) return; + assert(rptr_ + bytes <= wptr_); + rptr_ += bytes; + if (rptr_ == wptr_) { + rptr_ = wptr_ = data_; + } + } + + void* write_ptr(size_t* avail) override { + if (wptr_ + min_avail_ > end_) { + if (rptr_ > data_) { + memmove(data_, rptr_, wptr_ - rptr_); + wptr_ -= rptr_ - data_; + rptr_ = data_; + } + if (wptr_ + min_avail_ > end_) { + auto new_size = (end_ - data_) + std::max(capacity_, min_avail_); + auto tmp = reinterpret_cast<char*>(realloc(data_, new_size)); + if (tmp) { + end_ = tmp + new_size; + rptr_ = tmp + (rptr_ - data_); + wptr_ = tmp + (wptr_ - data_); + data_ = tmp; + } + } + } + if (avail) *avail = end_ - wptr_; + return wptr_; + } + + void commit(size_t bytes) override { + if (bytes == 0) return; + assert(wptr_ + bytes <= end_); + wptr_ += bytes; + } + +private: + size_t capacity_; + size_t min_avail_; + char* data_; + char* end_; + char* rptr_; + char* wptr_; +}; + +} // namespace + +// static +Buffer* Buffer::create(size_t size, size_t min_avail) { + return new BufferImpl(std::max(size, min_avail), min_avail); +} + +bool Buffer::empty() const { + size_t avail; + read_ptr(&avail); + return avail == 0; +} + +size_t Buffer::read(void* data, size_t max) { + if (max == 0) return 0; + size_t avail; + auto ptr = read_ptr(&avail); + if (avail == 0) return 0; + avail = std::min(avail, max); + memcpy(data, ptr, avail); + commit(avail); + return avail; +} + +void Buffer::write(void const* data, size_t size) { + if (size == 0) return; + auto d = reinterpret_cast<char const*>(data); + size_t pos = 0; + while (true) { + size_t avail; + auto ptr = write_ptr(&avail); + if (pos + avail < size) { + memcpy(ptr, d + pos, avail); + pos += avail; + commit(avail); + } else { + memcpy(ptr, d + pos, size - pos); + commit(size - pos); + return; + } + } +} + diff --git a/src/buffer.hh b/src/buffer.hh new file mode 100644 index 0000000..92a7566 --- /dev/null +++ b/src/buffer.hh @@ -0,0 +1,30 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#ifndef BUFFER_HH +#define BUFFER_HH + +#include <cstddef> + +class Buffer { +public: + virtual ~Buffer() {} + + static Buffer* create(size_t capacity, size_t min_avail); + + bool empty() const; + + virtual void const* read_ptr(size_t* avail) const = 0; + virtual void consume(size_t bytes) = 0; + + virtual void* write_ptr(size_t* avail) = 0; + virtual void commit(size_t bytes) = 0; + + size_t read(void* data, size_t max); + void write(void const* data, size_t size); + +protected: + Buffer() {} + Buffer(Buffer const&) = delete; +}; + +#endif // BUFFER_HH diff --git a/src/character.cc b/src/character.cc new file mode 100644 index 0000000..98166ee --- /dev/null +++ b/src/character.cc @@ -0,0 +1,31 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include "character.hh" + +// static +bool Character::isspace(std::string const& str, size_t pos) { + switch (str[pos]) { + case ' ': + case '\t': + case '\r': + case '\n': + return true; + } + return false; +} + +// static +bool Character::isseparator(std::string const& str, size_t pos) { + if (isspace(str, pos)) return true; + switch (str[pos]) { + case '.': + case ':': + case '-': + case ',': + case ';': + return true; + } + return false; +} diff --git a/src/character.hh b/src/character.hh new file mode 100644 index 0000000..18f98d8 --- /dev/null +++ b/src/character.hh @@ -0,0 +1,18 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#ifndef CHARACTER_HH +#define CHARACTER_HH + +#include <string> + +class Character { +public: + static bool isspace(std::string const& str, size_t pos); + static bool isseparator(std::string const& str, size_t pos); + +private: + ~Character() {} + Character() {} +}; + +#endif // CHARACTER_HH diff --git a/src/chunked.cc b/src/chunked.cc new file mode 100644 index 0000000..99aea0d --- /dev/null +++ b/src/chunked.cc @@ -0,0 +1,106 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include <cerrno> +#include <cstdint> +#include <cstdlib> + +#include "chunked.hh" + +namespace { + +enum State { + CHUNK, + IN_CHUNK, + TRAILER, + DONE, +}; + +class ChunkedImpl : public Chunked { +public: + ChunkedImpl() + : state_(CHUNK), good_(true) { + } + + size_t add(void const* data, size_t avail) override { + if (!good_) return 0; + auto const start = reinterpret_cast<char const*>(data); + auto const end = start + avail; + auto d = start; + while (true) { + if (d == end) return avail; + switch (state_) { + case CHUNK: { + auto p = find_crlf(d, end); + if (!p) return d - start; + char* x = nullptr; + errno = 0; + auto tmp = strtoull(d, &x, 16); + if (errno || (x != p && (!x || *x != ';'))) { + good_ = false; + return d - start; + } + size_ = tmp; + d = p + 2; + if (size_ == 0) { + // Last chunk + state_ = TRAILER; + } else { + state_ = IN_CHUNK; + } + break; + } + case IN_CHUNK: + if (static_cast<uint64_t>(end - d) < size_) { + return avail; + } + d += size_; + state_ = CHUNK; + break; + case TRAILER: { + auto p = find_crlf(d, end); + if (!p) return d - start; + if (p == d) { + state_ = DONE; + } + d = p + 2; + break; + } + case DONE: + return d - start; + } + } + } + + bool good() const override { + return good_; + } + + bool eof() const override { + return state_ == DONE; + } + +private: + char const* find_crlf(char const* start, char const* end) { + for (; start != end; ++start) { + if (*start == '\r') { + if (start + 1 == end) break; + if (start[1] == '\n') return start; + } + } + return nullptr; + } + + State state_; + bool good_; + uint64_t size_; +}; + +} // namespace + +// static +Chunked* Chunked::create() { + return new ChunkedImpl(); +} + diff --git a/src/chunked.hh b/src/chunked.hh new file mode 100644 index 0000000..66d3ae7 --- /dev/null +++ b/src/chunked.hh @@ -0,0 +1,23 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#ifndef CHUNKED_HH +#define CHUNKED_HH + +#include <cstddef> + +class Chunked { +public: + virtual ~Chunked() { } + + static Chunked* create(); + + virtual size_t add(void const* data, size_t avail) = 0; + virtual bool good() const = 0; + virtual bool eof() const = 0; + +protected: + Chunked() {} + Chunked(Chunked const&) = delete; +}; + +#endif // CHUNKED_HH diff --git a/src/common.hh b/src/common.hh new file mode 100644 index 0000000..67c8fa5 --- /dev/null +++ b/src/common.hh @@ -0,0 +1,18 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#ifndef COMMON_HH +#define COMMON_HH + +#ifdef HAVE_CONFIG_H +# include "config.h" +#endif + +#if HAVE_VAR_ATTRIBUTE_UNUSED +# define UNUSED(x) __attribute__((unused)) x ## _unused +#else +# define UNUSED(x) x ## _unused +#endif + +#include <cassert> + +#endif // COMMON_HH diff --git a/src/config.cc b/src/config.cc new file mode 100644 index 0000000..9824044 --- /dev/null +++ b/src/config.cc @@ -0,0 +1,176 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include "config.hh" +#include "paths.hh" +#include "strings.hh" +#include "xdg.hh" + +#include <cstring> +#include <fstream> +#include <sstream> +#include <unordered_map> + +namespace { + +class ConfigImpl : public Config { +public: + ConfigImpl() + : good_(true) { + } + + bool load_name(std::string const& name) override { + name_ = name; + + auto dirs = XDG::config_dirs(); + bool good = true; + std::string error; + std::unordered_map<std::string,std::string> data; + for (auto& dir : dirs) { + std::string tmp_error; + std::unordered_map<std::string,std::string> tmp_data; + if (load_file(Paths::join(dir, name_), &tmp_data, &tmp_error)) { + data.insert(tmp_data.begin(), tmp_data.end()); + } else { + if (good) { + // We want the most import error + error = tmp_error; + good = false; + } + } + } + return update(good, error, data); + } + + bool load_file(std::string const& filename) override { + std::string error; + std::unordered_map<std::string,std::string> data; + bool good = load_file(filename, &data, &error); + return update(good, error, data); + } + + bool good() const override { + return good_; + } + + std::string const& last_error() const override { + return last_error_; + } + + std::string const& get(std::string const& key, + std::string const& fallback) override { + auto i = override_.find(key); + if (i != override_.end()) return i->second; + i = data_.find(key); + if (i != data_.end()) return i->second; + return fallback; + } + char const* get(std::string const& key, char const* fallback) override { + auto i = override_.find(key); + if (i != override_.end()) return i->second.c_str(); + i = data_.find(key); + if (i != data_.end()) return i->second.c_str(); + return fallback; + } + bool is_set(std::string const& key) override { + return override_.count(key) > 0 || data_.count(key) > 0; + } + + bool get(std::string const& key, bool fallback) override { + auto ret = get(key, nullptr); + if (!ret) return fallback; + return strcmp(ret, "true") == 0; + } + + void set(std::string const& key, std::string const& value) override { + auto ret = override_.insert(std::make_pair(key, value)); + if (!ret.second) { + ret.first->second = value; + } + } + +private: + bool update(bool good, std::string const& last_error, + std::unordered_map<std::string, std::string>& data) { + good_ = good; + if (!good_) { + last_error_ = last_error; + } else { + data_.clear(); + data_.swap(data); + } + return good_; + } + + static bool load_file(std::string const& filename, + std::unordered_map<std::string, std::string>* data, + std::string* error) { + bool good = true; + data->clear(); + error->clear(); + + std::ifstream in(filename); + // Non existent file is not considered an error + if (in) { + std::string line; + uint32_t count = 0; + std::string key, value; + while (std::getline(in, line)) { + count++; + if (line.empty() || line[0] == '#') continue; + auto idx = line.find('='); + if (idx == 0 || idx == std::string::npos) { + std::stringstream ss; + if (idx == 0) { + ss << filename << ':' << count << ": Invalid line, starts with '='"; + } else { + ss << filename << ':' << count << ": Invalid line, no '=' found"; + } + *error = ss.str(); + good = false; + break; + } + size_t start = 0, end = idx; + key.assign(Strings::trim(line, start, end)); + if (data->count(key) > 0) { + std::stringstream ss; + ss << filename << ':' << count << ": '" << key << "' is already set"; + *error = ss.str(); + good = false; + break; + } + start = idx + 1; + end = line.size(); + Strings::trim(line, &start, &end); + if (line[start] == '"' && line[end - 1] == '"') { + value.assign(Strings::unquote(line, start, end)); + } else { + value.assign(line.substr(start, end - start)); + } + (*data)[key] = value; + } + if (good && in.bad()) { + std::stringstream ss; + ss << filename << ": I/O error: " << strerror(errno); + *error = ss.str(); + good = false; + } + if (!good) data->clear(); + } + return good; + } + + bool good_; + std::string name_; + std::string last_error_; + std::unordered_map<std::string, std::string> data_; + std::unordered_map<std::string, std::string> override_; +}; + +} // namespace + +// static +Config* Config::create() { + return new ConfigImpl(); +} diff --git a/src/config.hh b/src/config.hh new file mode 100644 index 0000000..5262e91 --- /dev/null +++ b/src/config.hh @@ -0,0 +1,33 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#ifndef CONFIG_HH +#define CONFIG_HH + +#include <string> + +class Config { +public: + virtual ~Config() { } + + static Config* create(); + + virtual bool load_name(std::string const& name) = 0; + virtual bool load_file(std::string const& filename) = 0; + + virtual bool good() const = 0; + virtual std::string const& last_error() const = 0; + + virtual std::string const& get(std::string const& key, + std::string const& fallback) = 0; + virtual char const* get(std::string const& key, char const* fallback) = 0; + virtual bool get(std::string const& key, bool fallback) = 0; + virtual bool is_set(std::string const& key) = 0; + + virtual void set(std::string const& key, std::string const& value) = 0; + +protected: + Config() { } + Config(Config const&) = delete; +}; + +#endif // CONFIG_HH diff --git a/src/http.cc b/src/http.cc new file mode 100644 index 0000000..c043c87 --- /dev/null +++ b/src/http.cc @@ -0,0 +1,657 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include <cstring> +#include <memory> +#include <sstream> +#include <vector> + +#include "http.hh" + +namespace { + +std::string make_string(char const* data, size_t start, size_t end) { + assert(start <= end); + return std::string(data + start, end - start); +} + +uint16_t number(char const* data, size_t start, size_t end) { + uint16_t ret = 0; + assert(start < end); + for (; start < end; ++start) { + ret *= 10; + ret += data[start] - '0'; + } + return ret; +} + +inline char lower_ascii(char c) { + return (c >= 'A' && c <= 'Z') ? (c - 'A' + 'a') : c; +} + +inline char upper_ascii(char c) { + return (c >= 'a' && c <= 'z') ? (c - 'a' + 'A') : c; +} + +bool lower_equal(char const* data, size_t start, size_t end, + std::string const& str) { + assert(start <= end); + if (str.size() != end - start) return false; + for (auto i = str.begin(); start < end; ++start, ++i) { + if (lower_ascii(*i) != lower_ascii(data[start])) return false; + } + return true; +} + +bool lower_equal2(char const* data, size_t start, size_t end, + char const* str, size_t size) { + assert(start <= end); + if (size != end - start) return false; + for (auto i = str; start < end; ++start, ++i) { + if (*i != lower_ascii(data[start])) return false; + } + return true; +} + +bool allow_header_append(char const* data, size_t start, size_t end) { + // These headers doesn't handle being merged with ',' even if the standard + // say they must + return !lower_equal2(data, start, end, "set-cookie", 10) && + !lower_equal2(data, start, end, "set-cookie2", 11); +} + +enum ParseResult { + GOOD, + BAD, + INCOMPLETE, +}; + +class HeaderIteratorImpl : public HeaderIterator { +public: + HeaderIteratorImpl(char const* data, std::vector<size_t> const* headers) + : data_(data), headers_(headers), iter_(headers_->begin()) { + } + + bool valid() const override { + return iter_ != headers_->end(); + } + std::string name() const override { + return make_string(data_, iter_[0], iter_[1]); + } + bool name_equal(std::string const& name) const override { + return lower_equal(data_, iter_[0], iter_[1], name); + } + std::string value() const override { + std::string ret = make_string(data_, iter_[2], iter_[3]); + if (allow_header_append(data_, iter_[0], iter_[1])) { + auto i = iter_ + 4; + while (i != headers_->end()) { + if (i[0] != i[1]) break; + ret.push_back(','); + ret.append(data_ + i[2], i[3] - i[2]); + i += 4; + } + } + return ret; + } + void next() override { + if (iter_ != headers_->end()) { + while (true) { + iter_ += 4; + if (iter_ == headers_->end() || iter_[0] != iter_[1]) break; + } + } + } + +private: + char const* const data_; + std::vector<size_t> const* const headers_; + std::vector<size_t>::const_iterator iter_; +}; + +class FilterHeaderIteratorImpl : public HeaderIteratorImpl { +public: + FilterHeaderIteratorImpl(char const* data, std::vector<size_t> const* headers, + std::string const& filter) + : HeaderIteratorImpl(data, headers), filter_(filter) { + check_filter(); + } + + void next() override { + HeaderIteratorImpl::next(); + check_filter(); + } + +private: + void check_filter() { + while (true) { + if (!valid() || name_equal(filter_)) return; + next(); + } + } + + std::string const filter_; +}; + +class AbstractHttp : public virtual Http { +public: + AbstractHttp(char const* data, size_t size) + : data_(data), size_(size) { + } + + std::string proto() const override { + return make_string(data_, proto_start_, proto_slash_); + } + + bool proto_equal(std::string const& proto) const override { + return proto.compare(0, proto.size(), + data_ + proto_start_, proto_slash_ - proto_start_) == 0; + } + + Version proto_version() const override { + Version ret; + ret.major = number(data_, proto_slash_ + 1, proto_dot_); + ret.minor = number(data_, proto_dot_ + 1, proto_end_); + return ret; + } + + std::unique_ptr<HeaderIterator> header() const override { + return std::unique_ptr<HeaderIterator>( + new HeaderIteratorImpl(data_, &headers_)); + } + std::unique_ptr<HeaderIterator> header( + std::string const& name) const override { + return std::unique_ptr<HeaderIterator>( + new FilterHeaderIteratorImpl(data_, &headers_, name)); + } + + size_t size() const override { + return content_start_; + } + +protected: + ParseResult parse_headers() { + headers_.clear(); + while (true) { + auto start = content_start_; + auto end = find_newline(start, &content_start_); + if (end == std::string::npos) return INCOMPLETE; + if (end == start) { + // The final newline can only be a alone '\r' if the one in front of + // it is also '\r', otherwise we expect a missing '\n' + if (data_[start - 1] == '\n' && data_[content_start_ - 1] == '\r') { + return INCOMPLETE; + } + break; + } + if (is_lws(data_[start])) { + if (headers_.empty()) return BAD; + headers_.push_back(start); + headers_.push_back(start); + headers_.push_back(start + 1); + headers_.push_back(end); + } else { + auto colon = find(start, ':', end); + if (colon == std::string::npos) return BAD; + auto value_start = skip_lws(colon + 1, end); + while (colon > start && is_lws(data_[colon - 1])) --colon; + headers_.push_back(start); + headers_.push_back(colon); + headers_.push_back(value_start); + headers_.push_back(end); + } + } + return GOOD; + } + + void new_data(char const* data) { + data_ = data; + } + + char const* data() const { + return data_; + } + + size_t data_size() const { + return size_; + } + + size_t find(size_t start, char c, size_t end) const { + assert(start <= end); + for (; start < end; ++start) { + if (data_[start] == c) return start; + } + return std::string::npos; + } + + static bool is_lws(char c) { + return c == ' ' || c == '\t'; + } + + size_t skip_lws(size_t start, size_t end) const { + assert(start <= end); + while (start < end && is_lws(data_[start])) ++start; + return start; + } + + size_t find_newline(size_t start, size_t* next) const { + assert(start <= size_); + for (; start < size_; ++start) { + if (data_[start] == '\r') { + if (start + 1 < size_ && data_[start + 1] == '\n') { + if (next) *next = start + 2; + } else { + if (next) *next = start + 1; + } + return start; + } else if (data_[start] == '\n') { + if (next) *next = start + 1; + return start; + } + } + return std::string::npos; + } + + size_t valid_number(size_t start, size_t end) const { + assert(start <= end); + if (start == end) return std::string::npos; + if (data_[start] == '0') { + return start + 1; + } + if (data_[start] < '0' || data_[start] > '9') return std::string::npos; + for (++start; start < end; ++start) { + if (data_[start] < '0' || data_[start] > '9') break; + } + return start; + } + + char const* data_; + size_t const size_; + size_t proto_start_; + size_t proto_slash_; + size_t proto_dot_; + size_t proto_end_; + std::vector<size_t> headers_; + size_t content_start_; +}; + +class AbstractHttpResponse : public HttpResponse, protected AbstractHttp { +public: + AbstractHttpResponse(char const* data, size_t size) + : AbstractHttp(data, size), good_(false) { + } + + bool good() const override { + return good_; + } + + uint16_t status_code() const override { + return number(data_, status_start_, status_end_); + } + + std::string status_message() const override { + return make_string(data_, status_msg_start_, status_msg_end_); + } + + ParseResult parse() { + good_ = false; + status_msg_end_ = find_newline(0, &content_start_); + if (status_msg_end_ == std::string::npos) return INCOMPLETE; + proto_start_ = 0; + proto_slash_ = find(0, '/', status_msg_end_); + if (proto_slash_ == std::string::npos) return BAD; + proto_dot_ = valid_number(proto_slash_ + 1, status_msg_end_); + if (proto_dot_ == std::string::npos || data_[proto_dot_] != '.') { + return BAD; + } + proto_end_ = valid_number(proto_dot_ + 1, status_msg_end_); + if (proto_end_ == std::string::npos || !is_lws(data_[proto_end_])) { + return BAD; + } + status_start_ = skip_lws(proto_end_ + 1, status_msg_end_); + status_end_ = valid_number(status_start_, status_msg_end_); + if (status_end_ == std::string::npos) return BAD; + if (is_lws(data_[status_end_])) { + status_msg_start_ = skip_lws(status_end_ + 1, status_msg_end_); + } else { + status_msg_start_ = status_end_; + if (status_msg_start_ != status_msg_end_) return BAD; + } + + auto ret = parse_headers(); + if (ret == GOOD) good_ = true; + return ret; + } + +protected: + bool good_; + size_t status_start_; + size_t status_end_; + size_t status_msg_start_; + size_t status_msg_end_; +}; + +class UniqueHttpResponse : public AbstractHttpResponse { +public: + UniqueHttpResponse(char const* data, size_t size) + : AbstractHttpResponse(data, size) { + } + + void copy() { + assert(!data_); + auto tmp = new char[data_size()]; + memcpy(tmp, data(), data_size()); + new_data(tmp); + data_.reset(tmp); + } + +private: + std::unique_ptr<char[]> data_; +}; + +class SharedHttpResponse : public AbstractHttpResponse { +public: + SharedHttpResponse(std::shared_ptr<char> data, size_t offset, size_t len) + : AbstractHttpResponse(data.get() + offset, len), data_(data) { + } + +private: + std::shared_ptr<char> data_; +}; + +class AbstractHttpRequest : public HttpRequest, protected AbstractHttp { +public: + AbstractHttpRequest(char const* data, size_t size) + : AbstractHttp(data, size), good_(false) { + } + + bool good() const override { + return good_; + } + + std::string method() const override { + return make_string(data_, 0, method_end_); + } + + bool method_equal(std::string const& method) const override { + return method.compare(0, method.size(), data_, method_end_) == 0; + } + std::string url() const override { + return make_string(data_, url_start_, url_end_); + } + + ParseResult parse() { + good_ = false; + proto_end_ = find_newline(0, &content_start_); + if (proto_end_ == std::string::npos) return INCOMPLETE; + method_end_ = 0; + while (method_end_ < proto_end_ && !is_lws(data_[method_end_])) { + ++method_end_; + } + if (method_end_ == 0 || method_end_ == proto_end_) return BAD; + url_start_ = skip_lws(method_end_ + 1, proto_end_); + url_end_ = url_start_; + while (url_end_ < proto_end_ && !is_lws(data_[url_end_])) { + ++url_end_; + } + if (url_end_ == url_start_ || url_end_ == proto_end_) return BAD; + proto_start_ = skip_lws(url_end_ + 1, proto_end_); + proto_slash_ = find(proto_start_, '/', proto_end_); + if (proto_slash_ == std::string::npos) return BAD; + proto_dot_ = valid_number(proto_slash_ + 1, proto_end_); + if (proto_dot_ == std::string::npos || data_[proto_dot_] != '.') { + return BAD; + } + auto tmp = valid_number(proto_dot_ + 1, proto_end_); + if (tmp != proto_end_) return BAD; + + auto ret = parse_headers(); + if (ret == GOOD) good_ = true; + return ret; + } + +protected: + bool good_; + size_t method_end_; + size_t url_start_; + size_t url_end_; +}; + +class UniqueHttpRequest : public AbstractHttpRequest { +public: + UniqueHttpRequest(char const* data, size_t size) + : AbstractHttpRequest(data, size) { + } + + void copy() { + assert(!data_); + auto tmp = new char[data_size()]; + memcpy(tmp, data(), data_size()); + new_data(tmp); + data_.reset(tmp); + } + +private: + std::unique_ptr<char[]> data_; +}; + +class SharedHttpRequest : public AbstractHttpRequest { +public: + SharedHttpRequest(std::shared_ptr<char> data, size_t offset, size_t len) + : AbstractHttpRequest(data.get() + offset, len), data_(data) { + } + +private: + std::shared_ptr<char> data_; +}; + +class AbstractHttpBuilder { +public: + AbstractHttpBuilder() + : set_content_length_(false), set_content_(false) { + } + + void add_header(std::string const& name, + std::string const& value) { + if (name.empty()) { + data_.push_back(' '); + data_.append(value); + data_.append("\r\n", 2); + return; + } + size_t pos = 0; + auto c = upper_ascii(name[pos]); + if (c != name[pos]) { + data_.push_back(c); + ++pos; + } + auto last = pos; + for (; pos < name.size(); ++pos) { + if (name[pos] == '-') { + ++pos; + if (pos == name.size()) break; + data_.append(name.data() + last, pos - last); + auto c = upper_ascii(name[pos]); + if (c != name[pos]) { + data_.push_back(c); + ++pos; + } + last = pos; + } + } + if (!set_content_length_ + && lower_equal2(name.data(), 0, name.size(), "content-length", 14)) { + set_content_length_ = true; + } + if (last < pos) { + data_.append(name.data() + last, pos - last); + } + data_.append(": ", 2); + data_.append(value); + data_.append("\r\n", 2); + } + + void set_content(std::string const& content) { + set_content_ = true; + content_ = content; + } + + std::string build() const { + std::string ret(data_); + if (!set_content_length_ && set_content_) { + char tmp[50]; + auto len = snprintf(tmp, sizeof(tmp), + "Content-Length: %zu\r\n", content_.size()); + ret.append(tmp, len); + } + ret.append("\r\n", 2); + if (set_content_) { + ret.append(content_); + } + return ret; + } + +protected: + std::string data_; + std::string content_; + bool set_content_length_; + bool set_content_; +}; + +class HttpRequestBuilderImpl : public HttpRequestBuilder, AbstractHttpBuilder { +public: + HttpRequestBuilderImpl(std::string const& method, std::string const& url, + std::string const& proto, Version version) { + data_.append(method); + data_.push_back(' '); + data_.append(url); + data_.push_back(' '); + data_.append(proto); + data_.push_back('/'); + char tmp[10]; + data_.append(tmp, snprintf( + tmp, 10, "%u", static_cast<unsigned int>(version.major))); + data_.push_back('.'); + data_.append(tmp, snprintf( + tmp, 10, "%u", static_cast<unsigned int>(version.minor))); + data_.append("\r\n", 2); + } + + void add_header(std::string const& name, + std::string const& value) override { + AbstractHttpBuilder::add_header(name, value); + } + void set_content(std::string const& content) override { + AbstractHttpBuilder::set_content(content); + } + std::string build() const override { + return AbstractHttpBuilder::build(); + } +}; + +class HttpResponseBuilderImpl : public HttpResponseBuilder, + AbstractHttpBuilder { +public: + HttpResponseBuilderImpl(std::string const& proto, Version version, + uint16_t status_code, std::string const& status) { + data_.append(proto); + data_.push_back('/'); + char tmp[10]; + data_.append(tmp, snprintf( + tmp, 10, "%u", static_cast<unsigned int>(version.major))); + data_.push_back('.'); + data_.append(tmp, snprintf( + tmp, 10, "%u", static_cast<unsigned int>(version.minor))); + data_.push_back(' '); + data_.append(tmp, snprintf( + tmp, 10, "%u", static_cast<unsigned int>(status_code))); + data_.push_back(' '); + data_.append(status); + data_.append("\r\n", 2); + } + + void add_header(std::string const& name, + std::string const& value) override { + AbstractHttpBuilder::add_header(name, value); + } + void set_content(std::string const& content) override { + AbstractHttpBuilder::set_content(content); + } + std::string build() const override { + return AbstractHttpBuilder::build(); + } +}; + +} // namespace + +// static +HttpResponse* HttpResponse::parse(std::string const& data) { + return parse(data.data(), data.size(), true); +} + +// static +HttpResponse* HttpResponse::parse(char const* data, size_t len, bool copy) { + auto ret = std::unique_ptr<UniqueHttpResponse>( + new UniqueHttpResponse(data, len)); + if (ret->parse() == INCOMPLETE) return nullptr; + if (copy) ret->copy(); + return ret.release(); +} + +// static +HttpResponse* HttpResponse::parse(std::shared_ptr<char> data, + size_t offset, size_t len) { + auto ret = std::unique_ptr<SharedHttpResponse>( + new SharedHttpResponse(data, offset, len)); + if (ret->parse() == INCOMPLETE) return nullptr; + return ret.release(); +} + +std::string Http::first_header(std::string const& name) const { + static std::string empty_str; + auto iter = header(name); + if (iter->valid()) { + return iter->value(); + } + return empty_str; +} + +// static +HttpRequest* HttpRequest::parse(std::string const& data) { + return parse(data.data(), data.size(), true); +} + +// static +HttpRequest* HttpRequest::parse(char const* data, size_t len, bool copy) { + auto ret = std::unique_ptr<UniqueHttpRequest>( + new UniqueHttpRequest(data, len)); + if (ret->parse() == INCOMPLETE) return nullptr; + if (copy) ret->copy(); + return ret.release(); +} + +// static +HttpRequest* HttpRequest::parse(std::shared_ptr<char> data, + size_t offset, size_t len) { + auto ret = std::unique_ptr<SharedHttpRequest>( + new SharedHttpRequest(data, offset, len)); + if (ret->parse() == INCOMPLETE) return nullptr; + return ret.release(); +} + +// static +HttpRequestBuilder* HttpRequestBuilder::create(std::string const& method, + std::string const& url, + std::string const& proto, + Version version) { + return new HttpRequestBuilderImpl(method, url, proto, version); +} + +// static +HttpResponseBuilder* HttpResponseBuilder::create(std::string const& proto, + Version version, + uint16_t status_code, + std::string const& status) { + return new HttpResponseBuilderImpl(proto, version, status_code, status); +} diff --git a/src/http.hh b/src/http.hh new file mode 100644 index 0000000..091d3d4 --- /dev/null +++ b/src/http.hh @@ -0,0 +1,141 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#ifndef HTTP_HH +#define HTTP_HH + +#include <cstdint> +#include <memory> +#include <string> + +// glibc defines these even tho it shouldn't if you ask posix +#ifdef major +# undef major +#endif +#ifdef minor +# undef minor +#endif + +struct Version { + uint16_t major; + uint16_t minor; + + Version() + : major(0), minor(0) { + } + Version(uint16_t major, uint16_t minor) + : major(major), minor(minor) { + } +}; + +class HeaderIterator { +public: + virtual ~HeaderIterator() {} + + virtual bool valid() const = 0; + virtual std::string name() const = 0; + virtual bool name_equal(std::string const& name) const = 0; + virtual std::string value() const = 0; + virtual void next() = 0; + +protected: + HeaderIterator() {} + HeaderIterator(HeaderIterator const&) = delete; +}; + +class Http { +public: + virtual ~Http() {} + + virtual bool good() const = 0; + + virtual std::string proto() const = 0; + virtual bool proto_equal(std::string const& proto) const = 0; + virtual Version proto_version() const = 0; + virtual std::unique_ptr<HeaderIterator> header() const = 0; + virtual std::unique_ptr<HeaderIterator> header( + std::string const& name) const = 0; + std::string first_header(std::string const& name) const; + + virtual size_t size() const = 0; + +protected: + Http() {} + Http(Http const&) = delete; +}; + +class HttpResponse : public virtual Http { +public: + virtual ~HttpResponse() {} + + static HttpResponse* parse(std::string const& data); + static HttpResponse* parse(char const* data, size_t len, bool copy = true); + static HttpResponse* parse(std::shared_ptr<char> data, + size_t offset, size_t len); + + virtual uint16_t status_code() const = 0; + virtual std::string status_message() const = 0; + +protected: + HttpResponse() {} + HttpResponse(HttpResponse const&) = delete; +}; + +class HttpRequest : public virtual Http { +public: + virtual ~HttpRequest() {} + + static HttpRequest* parse(std::string const& data); + static HttpRequest* parse(char const* data, size_t len, bool copy = true); + static HttpRequest* parse(std::shared_ptr<char> data, + size_t offset, size_t len); + + virtual std::string method() const = 0; + virtual bool method_equal(std::string const& method) const = 0; + virtual std::string url() const = 0; + +protected: + HttpRequest() {} + HttpRequest(HttpRequest const&) = delete; +}; + +class HttpResponseBuilder { +public: + virtual ~HttpResponseBuilder() {} + + static HttpResponseBuilder* create( + std::string const& proto, Version version, + uint16_t status_code, std::string const& status_message); + + virtual void add_header(std::string const& name, + std::string const& value) = 0; + // This will add a content-length header unless there already is one + virtual void set_content(std::string const& content) = 0; + + virtual std::string build() const = 0; + +protected: + HttpResponseBuilder() {} + HttpResponseBuilder(HttpResponseBuilder const&) = delete; +}; + +class HttpRequestBuilder { +public: + virtual ~HttpRequestBuilder() {} + + static HttpRequestBuilder* create( + std::string const& method, std::string const& url, + std::string const& proto, Version version); + + virtual void add_header(std::string const& name, + std::string const& value) = 0; + // This will add a content-length header unless there already is one + virtual void set_content(std::string const& content) = 0; + + virtual std::string build() const = 0; + +protected: + HttpRequestBuilder() {} + HttpRequestBuilder(HttpRequestBuilder const&) = delete; +}; + +#endif // HTTP_HH diff --git a/src/io.cc b/src/io.cc new file mode 100644 index 0000000..5c33ea2 --- /dev/null +++ b/src/io.cc @@ -0,0 +1,90 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include <algorithm> +#include <cerrno> +#include <limits> +#include <unistd.h> + +#include "io.hh" + +namespace io { + +ssize_t read(int fd, void* buf, size_t max) { + while (true) { + ssize_t ret = ::read(fd, buf, max); + if (ret != -1 || errno != EINTR) return ret; + } +} + +ssize_t write(int fd, void const* buf, size_t max) { + while (true) { + ssize_t ret = ::write(fd, buf, max); + if (ret != -1 || errno != EINTR) return ret; + } +} + +bool read_all(int fd, void* buf, size_t size) { + auto data = reinterpret_cast<char*>(buf); + auto end = data + size; + while (data < end) { + ssize_t ret = ::read(fd, data, std::min( + std::numeric_limits<ssize_t>::max(), end - data)); + if (ret < 0) { + if (errno == EINTR) continue; + return false; + } + if (ret == 0) return false; + data += ret; + } + return true; +} + +bool write_all(int fd, void const* buf, size_t size) { + auto data = reinterpret_cast<char const*>(buf); + auto end = data + size; + while (data < end) { + ssize_t ret = ::write(fd, data, std::min( + std::numeric_limits<ssize_t>::max(), end - data)); + if (ret < 0) { + if (errno == EINTR) continue; + return false; + } + if (ret == 0) return false; + data += ret; + } + return true; +} + +bool auto_fd::reset() { + if (fd_ >= 0) { + auto tmp = fd_; + fd_ = -1; + return close(tmp) != 0; + } + return true; +} + +void auto_pipe::reset() { + if (read_ >= 0) { + assert(write_ >= 0); + close(read_); + close(write_); + read_ = -1; + write_ = -1; + } +} + +bool auto_pipe::open() { + reset(); + int fd[2]; + if (pipe(fd) != 0) { + return false; + } + read_ = fd[0]; + write_ = fd[1]; + return true; +} + +} // namespace io diff --git a/src/io.hh b/src/io.hh new file mode 100644 index 0000000..3bfccec --- /dev/null +++ b/src/io.hh @@ -0,0 +1,124 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#ifndef IO_HH +#define IO_HH + +#include <sys/types.h> + +namespace io { + +ssize_t read(int fd, void* buf, size_t max); +ssize_t write(int fd, void const* buf, size_t max); + +bool read_all(int fd, void* buf, size_t size); +bool write_all(int fd, void const* buf, size_t size); + +class auto_fd { +public: + auto_fd() + : fd_(-1) { + } + explicit auto_fd(int fd) + : fd_(fd) { + } + auto_fd(auto_fd&& fd) + : fd_(fd.fd_) { + fd.fd_ = -1; + } + auto_fd(auto_fd const&) = delete; + ~auto_fd() { + reset(); + } + + auto_fd& operator=(auto_fd&& fd) { + reset(fd.release()); + return *this; + } + auto_fd& operator=(auto_fd&) = delete; + + int get() const { + return fd_; + } + + explicit operator bool() const { + return fd_ >= 0; + } + + int release() { + auto ret = fd_; + fd_ = -1; + return ret; + } + + bool reset(); + + bool reset(int fd) { + auto ret = reset(); + fd_ = fd; + return ret; + } + + void swap(auto_fd& fd) { + auto tmp = fd_; + fd_ = fd.fd_; + fd.fd_ = tmp; + } +private: + int fd_; +}; + +class auto_pipe { +public: + auto_pipe() + : read_(-1), write_(-1) { + } + auto_pipe(auto_pipe&& pipe) + : read_(pipe.read_), write_(pipe.write_) { + pipe.read_ = -1; + pipe.write_ = -1; + } + auto_pipe(auto_pipe const&) = delete; + ~auto_pipe() { + reset(); + } + + auto_pipe& operator=(auto_pipe&& pipe) { + reset(); + read_ = pipe.read_; + write_ = pipe.write_; + pipe.read_ = -1; + pipe.write_ = -1; + return *this; + } + auto_pipe& operator=(auto_pipe&) = delete; + + int read() const { + return read_; + } + int write() const { + return write_; + } + + explicit operator bool() const { + return read_ >= 0 /* && write_ >= 0 */; + } + + bool open(); + + void reset(); + + void swap(auto_pipe& pipe) { + auto r = read_; + auto w = write_; + read_ = pipe.read_; + write_ = pipe.write_; + pipe.read_ = r; + pipe.write_ = w; + } +private: + int read_, write_; +}; + +} // namespace io + +#endif // IO_HH diff --git a/src/logger.cc b/src/logger.cc new file mode 100644 index 0000000..29dcb38 --- /dev/null +++ b/src/logger.cc @@ -0,0 +1,130 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#ifndef _BSD_SOURCE +#define _BSD_SOURCE +#endif +#ifndef _DEFAULT_SOURCE +#define _DEFAULT_SOURCE +#endif + +#include <cstdarg> +#include <cstdio> +#include <iostream> +#include <memory> +#include <syslog.h> + +#include "logger.hh" + +namespace { + +class LoggerStdErr : public Logger { +public: + void out(Level UNUSED(lvl), char const* format, ...) override { + char* tmp; + va_list args; + va_start(args, format); + auto ret = vasprintf(&tmp, format, args); + va_end(args); + if (ret == -1) return; + std::cerr << tmp << std::endl; + free(tmp); + } +}; + +class LoggerSyslog : public Logger { +public: + LoggerSyslog(std::string const& name) { + openlog(name.c_str(), LOG_PID, LOG_DAEMON); + } + + ~LoggerSyslog() override { + closelog(); + } + + void out(Level lvl, char const* format, ...) override { + va_list args; + va_start(args, format); + vsyslog(lvl2prio(lvl), format, args); + va_end(args); + } + +private: + static int lvl2prio(Level lvl) { + switch (lvl) { + case ERR: + return LOG_ERR; + case WARN: + return LOG_WARNING; + case INFO: + return LOG_INFO; + } + assert(false); + return LOG_INFO; + } +}; + +class LoggerFile : public Logger { +public: + LoggerFile() + : fh_(nullptr) { + } + + bool open(std::string const& path) { + if (fh_) fclose(fh_); + fh_ = fopen(path.c_str(), "a"); + return fh_ != NULL; + } + + ~LoggerFile() override { + if (fh_) fclose(fh_); + } + + void out(Level lvl, char const* format, ...) override { + fputs(lvl2str(lvl), fh_); + fwrite(": ", 1, 2, fh_); + va_list args; + va_start(args, format); + vfprintf(fh_, format, args); + va_end(args); + fputc('\n', fh_); + } + +private: + static char const* lvl2str(Level lvl) { + switch (lvl) { + case ERR: + return "Error"; + case WARN: + return "Warning"; + case INFO: + return "Info"; + } + assert(false); + return "Info"; + } + + FILE* fh_; +}; + +} // namespace + +// static +Logger* Logger::create_stderr() { + return new LoggerStdErr(); +} + +// static +Logger* Logger::create_syslog(std::string const& name) { + return new LoggerSyslog(name); +} + +// static +Logger* Logger::create_file(std::string const& path) { + std::unique_ptr<LoggerFile> ret(new LoggerFile()); + if (ret->open(path)) { + return ret.release(); + } + return nullptr; +} diff --git a/src/logger.hh b/src/logger.hh new file mode 100644 index 0000000..8b0db05 --- /dev/null +++ b/src/logger.hh @@ -0,0 +1,33 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#ifndef LOGGER_HH +#define LOGGER_HH + +#include <string> + +class Logger { +public: + virtual ~Logger() {} + + enum Level { + ERR, + WARN, + INFO + }; + + static Logger* create_stderr(); + static Logger* create_syslog(std::string const& name); + static Logger* create_file(std::string const& path); + + virtual void out(Level lvl, char const* format, ...) +#ifdef HAVE_FUNC_ATTRIBUTE_FORMAT + __attribute__ ((format (printf, 3, 4))) +#endif + = 0; + +protected: + Logger() {} + Logger(Logger const&) = delete; +}; + +#endif // LOGGER_HH diff --git a/src/looper.cc b/src/looper.cc new file mode 100644 index 0000000..0da851b --- /dev/null +++ b/src/looper.cc @@ -0,0 +1,287 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include <list> +#include <poll.h> +#include <vector> + +#include "looper.hh" + +namespace { + +int const read_events = POLLIN | POLLPRI +#ifdef POLLRDHUP + | POLLRDHUP +#endif + ; +int const write_events = POLLOUT; +int const error_events = POLLERR | POLLNVAL; +int const hup_events = POLLHUP; + +class LooperImpl : public Looper { +public: + LooperImpl() + : fds_to_remove_(0), fds_protected_(0), quit_(false) { + } + + ~LooperImpl() override { + for (auto& timeout : timeouts_) { + delete timeout; + } + } + + void add(int fd, uint8_t events, FdCallback const& callback) override { + if (fd < 0) { + assert(false); + return; + } + for (auto it = fds_.begin(); it != fds_.end(); ++it) { + if (it->fd == fd) { + size_t index = it - fds_.begin(); + auto& entry = fdentries_[index]; + if (index >= fds_protected_) { + entry.callback = callback; + it->events = pollevents(events); + return; + } else { + // Don't call new callback this run, so add it at the end but + // remove the old callback as it would be replaced + if (!entry.removed) { + entry.removed = true; + fds_to_remove_++; + } + } + } + } + + fds_.emplace_back((struct pollfd) { fd, pollevents(events), 0 }); + fdentries_.emplace_back(callback); + } + + void modify(int fd, uint8_t events) override { + if (fd < 0) { + assert(false); + return; + } + for (auto it = fds_.begin(); it != fds_.end(); ++it) { + if (it->fd == fd) { + size_t index = it - fds_.begin(); + if (index < fds_protected_) { + auto entry = fdentries_.begin() + index; + // If entry is removed we need to write to the later one + if (entry->removed) { + continue; + } + } + it->events = pollevents(events); + return; + } + } + assert(false); + } + + void remove(int fd) override { + if (fd < 0) return; + for (auto it = fds_.begin(); it != fds_.end(); ++it) { + if (it->fd == fd) { + size_t index = it - fds_.begin(); + auto entry = fdentries_.begin() + index; + if (index < fds_protected_) { + if (!entry->removed) { + entry->removed = true; + fds_to_remove_++; + } + } else { + fds_.erase(it); + fdentries_.erase(entry); + } + return; + } + } + assert(false); + } + + void* schedule(float delay_s, ScheduleCallback const& callback) override { + clock::time_point target = clock::now() + + std::chrono::duration_cast<clock::duration>( + std::chrono::duration<float>(delay_s)); + auto timeout = new Timeout(target, callback); + for (auto it = timeouts_.begin(); it != timeouts_.end(); ++it) { + if (target < (*it)->target) { + timeouts_.insert(it, timeout); + return timeout; + } + } + timeouts_.push_back(timeout); + return timeout; + } + + void cancel(void* handle) override { + auto timeout = reinterpret_cast<Timeout*>(handle); + if (!timeout) { + assert(false); + return; + } + if (timeout->expired) return; + for (auto it = timeouts_.begin(); it != timeouts_.end(); ++it) { + if (*it == timeout) { + timeouts_.erase(it); + delete timeout; + return; + } + } + assert(false); + } + + void quit() override { + quit_ = true; + } + + bool run() override { + std::vector<Timeout*> expired; + + while (!quit_) { + int timeout = -1; + if (!timeouts_.empty()) { + auto dur = std::chrono::duration_cast<std::chrono::milliseconds>( + timeouts_.front()->target - clock::now()); + if (dur.count() <= 0) { + timeout = 0; + } else if (dur.count() < std::numeric_limits<int>::max()) { + timeout = dur.count(); + } else { + timeout = std::numeric_limits<int>::max(); + } + } + auto ret = poll(fds_.data(), fds_.size(), timeout); + if (ret < 0) { + if (errno == EINTR) continue; + return false; + } + now_ = clock::now(); + fds_protected_ = fds_.size(); + + if (!timeouts_.empty()) { + while (timeouts_.front()->target <= now_) { + auto timeout = timeouts_.front(); + timeouts_.pop_front(); + timeout->expired = true; + expired.push_back(timeout); + if (timeouts_.empty()) break; + } + + for (auto& timeout : expired) { + timeout->callback(timeout); + } + for (auto& timeout : expired) { + delete timeout; + } + expired.clear(); + } + + // Not using iterators here as that would be unsafe with + // add() and remove() modifying the vector outside protected range + // while callbacks are called + size_t i; + for (i = 0; ret > 0 && i < fds_protected_; ++i) { + if (fds_[i].revents) { + --ret; + if (!fdentries_[i].removed) { + fdentries_[i].callback(fds_[i].fd, unpollevents(fds_[i].revents)); + } + } + } + assert(ret == 0); + assert(fds_.size() >= fds_protected_); + assert(fdentries_.size() >= fds_protected_); + for (i = fds_protected_; fds_to_remove_ > 0 && i > 0; --i) { + if (fdentries_[i - 1].removed) { + --fds_to_remove_; + fds_.erase(fds_.begin() + i - 1); + fdentries_.erase(fdentries_.begin() + i - 1); + } + } + assert(fds_to_remove_ == 0); + fds_protected_ = 0; + } + return true; + } + + clock::time_point now() const override { + return now_; + } + +private: + struct FdEntry { + FdCallback callback; + bool removed; + + FdEntry(FdCallback const& callback) + : callback(callback), removed(false) { + } + }; + + struct Timeout { + clock::time_point target; + ScheduleCallback callback; + bool expired; + + Timeout(clock::time_point target, ScheduleCallback const& callback) + : target(target), callback(callback), expired(false) { + } + }; + + static uint8_t unpollevents(short events) { + uint8_t ret = 0; + if (events & read_events) { + ret |= EVENT_READ; + } + if (events & write_events) { + ret |= EVENT_WRITE; + } + if (events & error_events) { + ret |= EVENT_ERROR; + } + if (events & hup_events) { + ret |= EVENT_HUP; + } + return ret; + } + + static short pollevents(uint8_t events) { + int ret = 0; + if (events & EVENT_READ) { + ret |= read_events; + } + if (events & EVENT_WRITE) { + ret |= write_events; + } + return ret; + } + + std::vector<struct pollfd> fds_; + std::vector<FdEntry> fdentries_; + size_t fds_to_remove_; + size_t fds_protected_; + std::list<Timeout*> timeouts_; + bool quit_; + clock::time_point now_; +}; + +} // namespace + +// static +Looper* Looper::create() { + return new LooperImpl(); +} + +// static +const uint8_t Looper::EVENT_READ = 1 << 0; +// static +const uint8_t Looper::EVENT_WRITE = 1 << 1; +// static +const uint8_t Looper::EVENT_ERROR = 1 << 2; +// static +const uint8_t Looper::EVENT_HUP = 1 << 3; + diff --git a/src/looper.hh b/src/looper.hh new file mode 100644 index 0000000..7315220 --- /dev/null +++ b/src/looper.hh @@ -0,0 +1,41 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#ifndef LOOPER_HH +#define LOOPER_HH + +#include <chrono> +#include <functional> + +class Looper { +public: + typedef std::chrono::steady_clock clock; + typedef std::function<void(int, uint8_t)> FdCallback; + typedef std::function<void(void*)> ScheduleCallback; + virtual ~Looper() { } + + static const uint8_t EVENT_READ; + static const uint8_t EVENT_WRITE; + // Always reported, need not be set + static const uint8_t EVENT_HUP; + static const uint8_t EVENT_ERROR; + + static Looper* create(); + + virtual void add(int fd, uint8_t events, FdCallback const& callback) = 0; + virtual void modify(int fd, uint8_t events) = 0; + virtual void remove(int fd) = 0; + + virtual void* schedule(float delay_s, ScheduleCallback const& callback) = 0; + virtual void cancel(void* handle) = 0; + + virtual bool run() = 0; + virtual void quit() = 0; + + virtual clock::time_point now() const = 0; + +protected: + Looper() { } + Looper(Looper const&) = delete; +}; + +#endif // LOOPER_HH diff --git a/src/main.cc b/src/main.cc new file mode 100644 index 0000000..f9e1af6 --- /dev/null +++ b/src/main.cc @@ -0,0 +1,187 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include <cstring> +#include <memory> +#include <signal.h> +#include <unistd.h> + +#include "args.hh" +#include "config.hh" +#include "io.hh" +#include "logger.hh" +#include "proxy.hh" + +namespace { + +Proxy* g_proxy; + +void proxy_quit(int UNUSED(sig)) { + g_proxy->quit(); +} + +void proxy_reload(int UNUSED(sig)) { + g_proxy->reload(); +} + +void setup_signals(Proxy* proxy) { + g_proxy = proxy; + struct sigaction action; + memset(&action, 0, sizeof(action)); + action.sa_handler = proxy_quit; + action.sa_flags = SA_RESTART; + sigaction(SIGINT, &action, nullptr); + sigaction(SIGQUIT, &action, nullptr); + action.sa_handler = proxy_reload; + sigaction(SIGHUP, &action, nullptr); +} + +std::string get_cwd() { + std::string ret; + char buf[128]; + auto ptr = getcwd(buf, 128); + if (ptr) { + ret.assign(ptr, strlen(ptr)); + return ret; + } + if (errno != ERANGE) return ret; + ret.resize(1024); + while (true) { + ptr = getcwd(&ret[0], ret.size()); + if (ptr) { + ret.resize(strlen(ret.data())); + break; + } + if (errno != ERANGE) { + ret.resize(0); + break; + } + ret.resize(ret.size() * 2); + } + return ret; +} + +} // namespace + +int main(int argc, char** argv) { + std::unique_ptr<Args> args(Args::create()); + args->add('C', "config", "FILE", "load config from FILE instead of default"); + args->add('F', "foreground", "do not fork into background and log to stdout"); + args->add('m', "monitor", "HOST:PORT", "accept monitors on HOST:PORT"); + args->add('h', "help", "display this text and exit."); + args->add('V', "version", "display version and exit."); + if (!args->run(argc, argv)) { + std::cerr << "Try `tp --help` for usage." << std::endl; + return EXIT_FAILURE; + } + if (args->is_set('h')) { + std::cout << "Usage: `tp [OPTIONS...] [BIND][:PORT]`\n" + << "Transparent proxy.\n" + << '\n'; + args->print_help(); + return EXIT_SUCCESS; + } + if (args->is_set('V')) { + std::cout << "TransparentProxy version " VERSION + << " written by Joel Klinghed <the_jk@yahoo.com>" << std::endl; + return EXIT_SUCCESS; + } + std::unique_ptr<Config> config(Config::create()); + switch (args->arguments().size()) { + case 0: + break; + case 1: { + auto arg = args->arguments().front(); + auto colon = arg.find(':'); + if (colon == std::string::npos) { + config->set("proxy_bind", arg); + } else { + if (colon > 0) { + config->set("proxy_bind", arg.substr(0, colon)); + } + config->set("proxy_port", arg.substr(colon + 1)); + } + break; + } + default: + std::cerr << "Too many arguments.\n" + << "Try `tp --help` for usage." << std::endl; + return EXIT_FAILURE; + } + if (args->is_set('F')) { + config->set("foreground", "true"); + } + auto monitor = args->arg('m', nullptr); + if (monitor) { + auto str = std::string(monitor); + auto colon = str.find(':'); + if (colon == 0 || colon == std::string::npos) { + std::cerr << "Invalid argument to monitor, expected HOST:PORT not: " + << str << std::endl; + return EXIT_FAILURE; + } + config->set("monitor", "true"); + config->set("monitor_bind", str.substr(0, colon)); + config->set("monitor_port", str.substr(colon + 1)); + } + auto configfile = args->arg('C', nullptr); + if (configfile) { + config->load_file(configfile); + } else { + config->load_name("tp"); + } + if (!config->good()) { + std::cerr << "Error loading config\n" + << config->last_error() << std::endl; + return EXIT_FAILURE; + } + std::unique_ptr<Logger> logger(Logger::create_stderr()); + std::unique_ptr<Logger> file_logger; + auto logfile = args->arg('l', nullptr); + bool logfile_from_argument; + if (!logfile) { + logfile = config->get("logfile", nullptr); + logfile_from_argument = false; + } else { + logfile_from_argument = true; + } + if (logfile) { + if (logfile[0] != '/') { + logger->out(Logger::ERR, "Logfile need to be an absolute path, not: %s", + logfile); + return EXIT_FAILURE; + } + file_logger.reset(Logger::create_file(logfile)); + if (!file_logger) { + logger->out(Logger::ERR, "Unable to open %s for logging: %s", + logfile, strerror(errno)); + return EXIT_FAILURE; + } + } + io::auto_fd accept_socket(Proxy::setup_accept_socket(config.get(), + logger.get())); + if (!accept_socket) return EXIT_FAILURE; + io::auto_fd monitor_socket; + if (config->get("monitor", false)) { + monitor_socket.reset(Proxy::setup_monitor_socket(config.get(), + logger.get())); + } + auto foreground = config->get("foreground", false); + auto cwd = get_cwd(); + std::unique_ptr<Proxy> proxy( + Proxy::create(config.get(), cwd, configfile, + foreground ? "bogus" : + (logfile_from_argument ? logfile : nullptr), + foreground ? logger.get() : file_logger.get(), + accept_socket.release(), + monitor_socket.release())); + if (!foreground) { + if (daemon(0, 0)) { + logger->out(Logger::ERR, "Failed to fork: %s", strerror(errno)); + return EXIT_FAILURE; + } + } + setup_signals(proxy.get()); + return proxy->run() ? EXIT_SUCCESS : EXIT_FAILURE; +} diff --git a/src/paths.cc b/src/paths.cc new file mode 100644 index 0000000..095076f --- /dev/null +++ b/src/paths.cc @@ -0,0 +1,98 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include "paths.hh" + +// static +std::string Paths::join(std::string const& base, std::string const& path) { + if (path.empty()) { + return cleanup(base); + } else if (path.front() == '/') { + return cleanup(path); + } + std::string ret(cleanup(base)); + if (ret == ".") { + return cleanup(path); + } + if (ret.empty() || ret.back() != '/') { + ret.push_back('/'); + } + ret.append(cleanup(path)); + return ret; +} + +namespace { + +std::string do_cleanup(std::string const& path) { + size_t i = 0, j = 0; + std::string ret; + bool add_slash = false; + if (path.front() == '/') { + ret.push_back('/'); + i = j = 1; + } + while (true) { + if (j == path.size() || path[j] == '/') { + auto len = j - i; + if (len > 0) { + if (len == 1 && path[i] == '.') { + i = ++j; + continue; + } + if (add_slash && len == 2 && path[i] == '.' && path[i + 1] == '.') { + auto x = ret.find_last_of('/'); + if (x == std::string::npos) x = 0; + ret.erase(x); + add_slash = !(ret.empty() || ret == "/"); + i = ++j; + continue; + } + if (add_slash) { + ret.push_back('/'); + } else { + add_slash = true; + } + ret.append(path.substr(i, len)); + } + if (j == path.size()) { + return ret; + } + i = ++j; + } else { + ++j; + } + } +} + +} // namespace + +// static +std::string Paths::cleanup(std::string const& path) { + if (path.empty()) return "."; + if (path[0] == '.') { + if (path.size() == 1) return path; + if (path[1] == '/') return do_cleanup(path); // found ./ at beginning + } + for (size_t pos = 0; pos < path.size(); ++pos) { + if (path[pos] == '/') { + if (pos > 0 && pos + 1 == path.size()) { + // found slash at end + return path.substr(0, path.size() - 1); + } + if (path[pos + 1] == '/') return do_cleanup(path); // found double slash + if (path[pos + 1] == '.') { + if (pos + 2 == path.size()) return do_cleanup(path); // found /. at end + if (path[pos + 2] == '/') return do_cleanup(path); // found /./ + if (path[pos + 2] == '.') { + if (pos + 3 == path.size()) { + // found /.. at end + return do_cleanup(path); + } + if (path[pos + 3] == '/') do_cleanup(path); // found /../ + } + } + } + } + return path; +} diff --git a/src/paths.hh b/src/paths.hh new file mode 100644 index 0000000..adec361 --- /dev/null +++ b/src/paths.hh @@ -0,0 +1,18 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#ifndef PATHS_HH +#define PATHS_HH + +#include <string> + +class Paths { +public: + static std::string join(std::string const& base, std::string const& path); + static std::string cleanup(std::string const& path); + +private: + Paths() {} + ~Paths() {} +}; + +#endif // PATHS_HH diff --git a/src/proxy.cc b/src/proxy.cc new file mode 100644 index 0000000..fcd02f6 --- /dev/null +++ b/src/proxy.cc @@ -0,0 +1,1373 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#include <chrono> +#include <cstring> +#include <fcntl.h> +#include <memory> +#include <netdb.h> +#include <signal.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <vector> + +#include "buffer.hh" +#include "chunked.hh" +#include "config.hh" +#include "http.hh" +#include "io.hh" +#include "logger.hh" +#include "looper.hh" +#include "resolver.hh" +#include "paths.hh" +#include "proxy.hh" +#include "url.hh" + +namespace { + +auto const NEW_CONNECTION_TIMEOUT = std::chrono::duration<float>(5.0f); +auto const CONNECTION_TIMEOUT = std::chrono::duration<float>(30.0f); + +template<typename T> +class clients { +public: + class iterator { + public: + iterator(clients* clients, size_t active) + : clients_(clients), index_(0), active_(active) { + if (active_) { + while (clients_->is_free(index_)) { + ++index_; + } + } + } + iterator(iterator const& i) + : clients_(i.clients_), index_(i.index_), active_(i.active_) { + } + + iterator& operator=(iterator const& i) { + clients_ = i.clients_; + index_ = i.index_; + active_ = i.active_; + return *this; + } + + bool operator==(iterator const& i) const { + return active_ == i.active_; + } + bool operator!=(iterator const& i) const { + return !(*this == i); + } + bool operator<(iterator const& i) const { + return active_ < i.active_; + } + bool operator<=(iterator const& i) const { + return (*this < i) || (*this == i); + } + bool operator>=(iterator const& i) const { + return !(*this < i); + } + bool operator>(iterator const& i) const { + return !(*this <= i); + } + + iterator& operator++() { + next(); + return *this; + } + + iterator operator++(int UNUSED(dummy)) { + iterator ret(*this); + ++(*this); + return ret; + } + + T& operator*() { + return (*clients_)[index_]; + } + T* operator->() { + return &(*clients_)[index_]; + } + + size_t index() const { + return index_; + } + + private: + void next() { + if (active_ == 0) return; + --active_; + if (active_ == 0) return; + ++index_; + while (clients_->is_free(index_)) { + ++index_; + } + } + + clients* clients_; + size_t index_; + size_t active_; + }; + + clients() + : active_(0), max_(0) { + } + + void resize(size_t max) { + max_ = max; + if (max > data_.size()) { + data_.resize(max); + } + } + + bool full() const { + return active_ >= max_; + } + + iterator begin() { + return iterator(this, active_); + } + + iterator end() { + return iterator(this, 0); + } + + T& operator[] (size_t index) { + return data_[index]; + } + + size_t new_client() { + size_t index; + if (active_ == data_.size()) { + index = data_.size(); + data_.emplace_back(); + } else { + index = rand() % data_.size(); + while (!is_free(index)) { + if (++index == data_.size()) { + index = 0; + } + } + } + ++active_; + return index; + } + + void erase(size_t index) { + assert(active_ > 0); + --active_; + assert(is_free(index)); + if (data_.size() > max_ && index >= max_) { + size_t size = data_.size(); + while (size > max_ && is_free(size - 1)) --size; + data_.resize(size); + } + } + + void clear() { + if (active_ > 0) { + data_.clear(); + active_ = 0; + } + data_.resize(max_); + } + +private: + bool is_free(size_t index) const { + return !data_[index].fd; + } + + size_t active_; + size_t max_; + std::vector<T> data_; +}; + +struct BaseClient { + io::auto_fd fd; + bool new_connection; + Looper::clock::time_point last; + std::unique_ptr<Buffer> in; + std::unique_ptr<Buffer> out; + uint8_t read_flag; + uint8_t write_flag; +}; + +enum ContentType { + CONTENT_NONE, + CONTENT_LEN, + CONTENT_CHUNKED, + CONTENT_CLOSE +}; + +struct Content { + ContentType type; + uint64_t len; + std::unique_ptr<Chunked> chunked; +}; + +enum RemoteState { + CLOSED, + RESOLVING, + CONNECTING, + CONNECTED, + WAITING, +}; + +struct RemoteClient : public BaseClient { + Content content; +}; + +struct Client : public BaseClient { + Client() + : resolve(nullptr) { + } + std::unique_ptr<HttpRequest> request; + std::unique_ptr<Url> url; + Content content; + RemoteState remote_state; + void* resolve; + RemoteClient remote; +}; + +struct Monitor : public BaseClient { +}; + +class ProxyImpl : public Proxy { +public: + ProxyImpl(Config* config, std::string const& cwd, char const* configfile, + char const* logfile, Logger* logger, int accept_fd, int monitor_fd) + : config_(config), cwd_(cwd), configfile_(configfile), logfile_(logfile), + logger_(logger), accept_socket_(accept_fd), monitor_socket_(monitor_fd), + looper_(Looper::create()), resolver_(Resolver::create(looper_.get())), + new_timeout_(nullptr), timeout_(nullptr) { + setup(); + } + ~ProxyImpl() override { + if (accept_socket_) { + looper_->remove(accept_socket_.get()); + } + if (monitor_socket_) { + looper_->remove(monitor_socket_.get()); + } + if (signal_pipe_) { + looper_->remove(signal_pipe_.read()); + } + } + + void quit() override { + char cmd = 'q'; + io::write_all(signal_pipe_.write(), &cmd, 1); + } + + void reload() override { + char cmd = 'r'; + io::write_all(signal_pipe_.write(), &cmd, 1); + } + + bool run() override; + +private: + void setup(); + bool reload_config(); + void fatal_error(); + void new_client(int fd, uint8_t events); + void new_monitor(int fd, uint8_t events); + void new_base(BaseClient* client, int fd); + void signal_event(int fd, uint8_t events); + bool base_event(BaseClient* client, uint8_t events, + size_t index, char const* name); + void client_event(size_t index, int fd, uint8_t events); + void client_remote_event(size_t index, int fd, uint8_t events); + void client_empty_input(size_t index); + void monitor_event(size_t index, int fd, uint8_t events); + void close_client(size_t index); + void close_monitor(size_t index); + void close_base(BaseClient* client); + void new_timeout(); + void timeout(); + float handle_timeout(bool new_conn, + std::chrono::duration<float> const& timeout); + bool base_send(BaseClient* client, void const* data, size_t size, + size_t index, char const* name); + void client_error(size_t index, + uint16_t status_code, std::string const& status); + bool client_request(size_t index); + bool client_send(size_t index, void const* ptr, size_t avail); + void client_remote_error(size_t index, uint16_t error); + void close_client_when_done(size_t index); + void client_remote_resolved(size_t index, int fd, bool connected, + char const* error); + + Config* const config_; + std::string cwd_; + char const* const configfile_; + char const* const logfile_; + Logger* logger_; + std::unique_ptr<Logger> priv_logger_; + io::auto_fd accept_socket_; + io::auto_fd monitor_socket_; + io::auto_pipe signal_pipe_; + std::unique_ptr<Looper> looper_; + std::unique_ptr<Resolver> resolver_; + bool good_; + void* new_timeout_; + void* timeout_; + + clients<Client> clients_; + clients<Monitor> monitors_; +}; + +size_t get_size(Config* config, Logger* logger, std::string const& name, + size_t fallback) { + auto value = config->get(name, nullptr); + if (!value) return fallback; + char* end = nullptr; + errno = 0; + auto tmp = strtoul(value, &end, 10); + if (errno || !end || *end) { + logger->out(Logger::WARN, + "Invalid value given for %s: %s, using fallback %zu instead", + name.c_str(), value, fallback); + return fallback; + } + return static_cast<size_t>(tmp); +} + +void ProxyImpl::setup() { + clients_.resize(get_size(config_, logger_, "max_clients", 1024)); + monitors_.resize(get_size(config_, logger_, "max_monitors", 2)); + looper_->add(accept_socket_.get(), + clients_.full() ? 0 : Looper::EVENT_READ, + std::bind(&ProxyImpl::new_client, this, + std::placeholders::_1, + std::placeholders::_2)); + if (monitor_socket_) { + looper_->add(monitor_socket_.get(), + monitors_.full() ? 0 :Looper::EVENT_READ, + std::bind(&ProxyImpl::new_monitor, this, + std::placeholders::_1, + std::placeholders::_2)); + } else { + monitors_.clear(); + } +} + +bool ProxyImpl::reload_config() { + if (configfile_) { + auto file = Paths::join(cwd_, configfile_); + logger_->out(Logger::INFO, "Reloading config file %s", file.c_str()); + config_->load_file(file); + } else { + logger_->out(Logger::INFO, "Reloading config"); + config_->load_name("tp"); + } + if (!config_->good()) { + logger_->out(Logger::WARN, "New config invalid, ignored."); + return true; + } + if (!logfile_) { + auto logfile = config_->get("logfile", nullptr); + if (logfile) { + if (logfile[0] != '/') { + logger_->out(Logger::WARN, + "Logfile need to be an absolute path, not: %s", + logfile); + } else { + std::unique_ptr<Logger> tmp(Logger::create_file(logfile_)); + if (tmp) { + priv_logger_.swap(tmp); + logger_ = priv_logger_.get(); + } else { + logger_->out(Logger::WARN, + "Unable to open %s for logging", logfile); + } + } + } else { + priv_logger_.reset(Logger::create_syslog("tp")); + logger_ = priv_logger_.get(); + } + } + auto const old_accept = accept_socket_.get(); + auto const old_monitor = monitor_socket_.get(); + looper_->remove(old_accept); + if (monitor_socket_) looper_->remove(old_monitor); + accept_socket_.reset(setup_accept_socket(config_, logger_)); + monitor_socket_.reset(); + if (!accept_socket_) { + logger_->out(Logger::ERR, "Unable to bind to new configuration, abort"); + return false; + } + if (config_->get("monitor", false)) { + monitor_socket_.reset(setup_monitor_socket(config_, logger_)); + if (!monitor_socket_) { + logger_->out(Logger::ERR, "Unable to bind to new configuration, abort"); + return false; + } + } + setup(); + return true; +} + +void ProxyImpl::signal_event(int fd, uint8_t events) { + assert(fd == signal_pipe_.read()); + if (events == Looper::EVENT_READ) { + char cmd; + auto ret = io::read(fd, &cmd, 1); + if (ret == 1) { + switch (cmd) { + case 'q': + logger_->out(Logger::INFO, "Exiting ..."); + looper_->quit(); + return; + case 'r': + if (!reload_config()) { + fatal_error(); + return; + } + break; + } + } + } else { + logger_->out(Logger::WARN, "Signal pipes have crashed"); + signal_pipe_.reset(); + looper_->remove(fd); + } +} + +void ProxyImpl::close_base(BaseClient* client) { + if (client->fd) { + looper_->remove(client->fd.get()); + client->fd.reset(); + } + client->in.reset(); + client->out.reset(); +} + +void ProxyImpl::close_client(size_t index) { + bool was_full = clients_.full(); + auto& client = clients_[index]; + client.request.reset(); + client.url.reset(); + client.content.type = CONTENT_NONE; + client.content.chunked.reset(); + client.remote_state = CLOSED; + if (client.resolve) { + resolver_->cancel(client.resolve); + client.resolve = nullptr; + } + close_base(&client.remote); + close_base(&client); + clients_.erase(index); + if (was_full && !clients_.full() && accept_socket_) { + looper_->modify(accept_socket_.get(), Looper::EVENT_READ); + } +} + +void ProxyImpl::close_monitor(size_t index) { + bool was_full = monitors_.full(); + auto& monitor = monitors_[index]; + close_base(&monitor); + monitors_.erase(index); + if (was_full && !monitors_.full() && monitor_socket_) { + looper_->modify(monitor_socket_.get(), Looper::EVENT_READ); + } +} + +float ProxyImpl::handle_timeout(bool new_conn, + std::chrono::duration<float> const& timeout) { + auto now = looper_->now(); + std::vector<size_t> close; + std::vector<bool> remote; + float next = -1.0f; + for (auto i = clients_.begin(); i != clients_.end(); ++i) { + if (i->new_connection != new_conn) continue; + auto diff = std::chrono::duration_cast<std::chrono::duration<float>>( + ((i->last + timeout) - now)).count(); + if (diff < 0.0f) { + close.push_back(i.index()); + remote.push_back(false); + } else { + if (!new_conn && i->remote_state > CLOSED) { + auto diff2 = std::chrono::duration_cast<std::chrono::duration<float>>( + ((i->remote.last + timeout) - now)).count(); + if (diff2 < 0.0f) { + close.push_back(i.index()); + remote.push_back(true); + } else if (diff2 < diff) { + diff = diff2; + } + } + if (next < 0.0f || diff < next) { + next = diff; + } + } + } + assert(close.size() == remote.size()); + auto j = remote.rbegin(); + for (auto i = close.rbegin(); i != close.rend(); ++i, ++j) { + if (*j) { + client_remote_error(*i, 504); + } else { + close_client(*i); + } + } + close.clear(); + for (auto i = monitors_.begin(); i != monitors_.end(); ++i) { + if (i->new_connection != new_conn) continue; + auto diff = std::chrono::duration_cast<std::chrono::duration<float>>( + ((i->last + timeout) - now)).count(); + if (diff < 0.0f) { + close.push_back(i.index()); + } else if (next < 0.0f || diff < next) { + next = diff; + } + } + for (auto i = close.rbegin(); i != close.rend(); ++i) { + close_monitor(*i); + } + return next; +} + +void ProxyImpl::timeout() { + assert(timeout_); + timeout_ = nullptr; + float next = handle_timeout(false, CONNECTION_TIMEOUT); + if (next < 0.0f) return; + timeout_ = looper_->schedule(next, std::bind(&ProxyImpl::timeout, this)); +} + +void ProxyImpl::new_timeout() { + assert(new_timeout_); + new_timeout_ = nullptr; + float next = handle_timeout(true, NEW_CONNECTION_TIMEOUT); + if (next < 0.0f) return; + new_timeout_ = + looper_->schedule(next, std::bind(&ProxyImpl::new_timeout, this)); +} + +bool ProxyImpl::base_send(BaseClient* client, void const* data, size_t size, + size_t index, char const* name) { + if (size == 0) return true; + if (!client->out->empty()) { + // Already waiting for write event + client->out->write(data, size); + return true; + } + auto ret = io::write(client->fd.get(), data, size); + if (ret == -1) { + if (errno != EAGAIN && errno != EWOULDBLOCK) { + logger_->out(Logger::INFO, "%zu: %s write error: %s", index, name, + strerror(errno)); + return false; + } + client->out->write(data, size); + } else { + if (static_cast<size_t>(ret) == size) { + if (!client->in) { + // If input is closed, close after sending all data + return false; + } + return true; + } + client->out->write(reinterpret_cast<char const*>(data) + ret, size - ret); + } + client->write_flag = Looper::EVENT_WRITE; + looper_->modify(client->fd.get(), client->read_flag | client->write_flag); + return true; +} + +bool ProxyImpl::base_event(BaseClient* client, uint8_t events, + size_t index, char const* name) { + if (events & Looper::EVENT_READ) { + if (client->new_connection) { + char tmp[1]; + auto ret = io::read(client->fd.get(), &tmp, 1); + if (ret < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) return true; + logger_->out(Logger::INFO, "%zu: %s read error: %s", index, name, + strerror(errno)); + return false; + } + if (ret == 0) { + return false; + } + client->last = looper_->now(); + client->new_connection = false; + client->in.reset(Buffer::create(8192, 1024)); + client->out.reset(Buffer::create(8192, 1024)); + client->in->write(tmp, 1); + if (!timeout_) { + timeout_ = looper_->schedule(CONNECTION_TIMEOUT.count(), + std::bind(&ProxyImpl::timeout, this)); + } + } + if (!client->in) { + assert(false); + return false; + } + size_t avail; + auto ptr = client->in->write_ptr(&avail); + assert(avail > 0); + auto ret = io::read(client->fd.get(), ptr, avail); + if (ret < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) return true; + logger_->out(Logger::INFO, "%zu: %s read error: %s", index, name, + strerror(errno)); + return false; + } + if (ret == 0) { + return false; + } + client->last = looper_->now(); + client->in->commit(ret); + } + if (events & Looper::EVENT_WRITE) { + if (client->new_connection) { + assert(false); + return true; + } + size_t avail; + auto ptr = client->out->read_ptr(&avail); + if (avail == 0) { + assert(false); + if (!client->in) return false; + client->write_flag = 0; + looper_->modify(client->fd.get(), client->read_flag); + return true; + } + auto ret = io::write(client->fd.get(), ptr, avail); + if (ret < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) return true; + logger_->out(Logger::INFO, "%zu: %s write error: %s", index, name, + strerror(errno)); + return false; + } + if (ret == 0) { + return false; + } + client->last = looper_->now(); + client->out->consume(ret); + if (client->out->empty()) { + if (!client->in) return false; + client->write_flag = 0; + looper_->modify(client->fd.get(), client->read_flag); + } + } + return true; +} + +inline char lower_ascii(char c) { + return (c >= 'A' && c <= 'Z') ? (c - 'A' + 'a') : c; +} + +bool lower_equal(char const* data, size_t start, size_t end, + std::string const& str) { + assert(start <= end); + if (str.size() != end - start) return false; + for (auto i = str.begin(); start < end; ++start, ++i) { + if (lower_ascii(*i) != lower_ascii(data[start])) return false; + } + return true; +} + +bool header_token_eq(std::string const& value, std::string const& token) { + if (value.empty()) return false; + auto pos = value.find(';'); + if (pos == std::string::npos) pos = value.size(); + return lower_equal(value.data(), 0, pos, token); +} + +void ProxyImpl::client_remote_error(size_t index, uint16_t error) { + assert(false); + auto& client = clients_[index]; + if (client.remote_state > CONNECTED) { + // Already started to return response, too late to do anything + close_client(index); + return; + } + char const* status; + switch (error) { + case 502: + status = "Bad Gateway"; + break; + case 504: + status = "Gateway Timeout"; + break; + default: + assert(false); + error = 500; + status = "Internal Server Error"; + break; + } + + client_error(index, error, status); +} + +void ProxyImpl::client_error(size_t index, uint16_t status_code, + std::string const& status_message) { + auto& client = clients_[index]; + // No more input + client.read_flag = 0; + looper_->modify(client.fd.get(), client.write_flag); + client.url.reset(); + + std::string proto; + Version version; + + if (!client.request) { + proto = "HTTP"; + version.major = 1; + version.minor = 1; + } else { + proto = client.request->proto(); + version = client.request->proto_version(); + } + + auto resp = std::unique_ptr<HttpResponseBuilder>( + HttpResponseBuilder::create( + proto, version, status_code, status_message)); + resp->add_header("Content-Length", "0"); + resp->add_header("Connection", "close"); + + client.in.reset(); + client.request.reset(); + + auto data = resp->build(); + + if (!base_send(&client, data.data(), data.size(), index, "Client")) { + close_client(index); + return; + } +} + +bool ProxyImpl::client_request(size_t index) { + auto& client = clients_[index]; + auto version = client.request->proto_version(); + if (version.major != 1 || version.minor > 1) { + client_error(index, 505, "HTTP Version Not Supported"); + return false; + } + if (client.url->scheme() != "http") { + client_error(index, 501, "Not Implemented"); + return false; + } + if (client.url->userinfo()) { + client_error(index, 400, "Bad request"); + return false; + } + if (client.remote_state == WAITING) { + client.remote_state = CONNECTING; + client.remote.read_flag = Looper::EVENT_READ; + client.remote.write_flag = 0; + looper_->modify(client.remote.fd.get(), + client.remote.read_flag | client.remote.write_flag); + client_remote_event(index, client.remote.fd.get(), Looper::EVENT_WRITE); + } else { + assert(client.remote_state == CLOSED); + client.remote.last = looper_->now(); + client.remote.new_connection = true; + client.remote_state = RESOLVING; + + auto port = client.url->port(); + if (port == 0) port = 80; + client.resolve = resolver_->request( + client.url->host(), port, + std::bind(&ProxyImpl::client_remote_resolved, this, index, + std::placeholders::_1, + std::placeholders::_2, + std::placeholders::_3)); + } + return true; +} + +void ProxyImpl::client_remote_resolved(size_t index, int fd, bool connected, + char const* error) { + auto& client = clients_[index]; + assert(client.resolve); + client.resolve = nullptr; + assert(client.remote_state == RESOLVING); + if (fd < 0) { + logger_->out(Logger::INFO, "%zu: Client unable to resolve remote: %s", + index, error); + client_remote_error(index, 502); + return; + } + client.remote_state = CONNECTING; + client.remote.fd.reset(fd); + client.remote.last = looper_->now(); + client.remote.new_connection = false; + client.remote.in.reset(Buffer::create(8192, 1024)); + client.remote.out.reset(Buffer::create(8192, 1024)); + if (connected) { + client.remote.read_flag = Looper::EVENT_READ; + client.remote.write_flag = 0; + } else { + client.remote.read_flag = 0; + client.remote.write_flag = Looper::EVENT_WRITE; + } + looper_->add(client.remote.fd.get(), + client.remote.read_flag | client.remote.write_flag, + std::bind(&ProxyImpl::client_remote_event, this, index, + std::placeholders::_1, + std::placeholders::_2)); + assert(timeout_); + if (connected) { + client_remote_event(index, fd, Looper::EVENT_WRITE); + } +} + +void ProxyImpl::client_event(size_t index, int fd, uint8_t events) { + auto& client = clients_[index]; + assert(client.fd.get() == fd); + if (events & Looper::EVENT_HUP) { + close_client(index); + return; + } + if (events & Looper::EVENT_ERROR) { + logger_->out(Logger::INFO, "%zu: Client connection error", index); + close_client(index); + return; + } + if (!base_event(&client, events, index, "Client")) { + close_client(index); + return; + } + if (client.new_connection) return; + client_empty_input(index); +} + +bool setup_content(Http const* http, Content* content) { + assert(content->type == CONTENT_NONE); + std::string te = http->first_header("transfer-encoding"); + if (te.empty() || header_token_eq(te, "identity")) { + std::string len = http->first_header("content-length"); + if (len.empty()) { + content->type = CONTENT_CLOSE; + return true; + } + char* end = nullptr; + errno = 0; + auto tmp = strtoull(len.c_str(), &end, 10); + if (errno || !end || *end) { + return false; + } + if (tmp == 0) { + content->type = CONTENT_NONE; + return true; + } + content->len = tmp; + content->type = CONTENT_LEN; + } else { + content->chunked.reset(Chunked::create()); + content->type = CONTENT_CHUNKED; + } + return true; +} + +void ProxyImpl::client_empty_input(size_t index) { + auto& client = clients_[index]; + while (true) { + size_t avail; + auto ptr = client.in->read_ptr(&avail); + if (avail == 0) return; + switch (client.content.type) { + case CONTENT_CLOSE: + assert(false); + // falltrough + case CONTENT_NONE: { + if (client.remote_state != CLOSED && client.remote_state != WAITING) { + // Still working on the last request, wait + return; + } + client.request.reset( + HttpRequest::parse( + reinterpret_cast<char const*>(ptr), avail, false)); + if (!client.request) { + if (avail >= 1024 * 1024) { + logger_->out(Logger::INFO, "%zu: Client too large request %zu", + index, avail); + close_client(index); + } + return; + } + if (!client.request->good()) { + client_error(index, 400, "Bad request"); + return; + } + if (client.request->method_equal("CONNECT")) { + client_error(index, 501, "Not Implemented"); + return; + } + client.url.reset(Url::parse(client.request->url())); + if (!client.url) { + client_error(index, 400, "Bad request"); + return; + } + if (!setup_content(client.request.get(), &client.content)) { + logger_->out(Logger::INFO, "%zu: Client bad content-length", index); + client_error(index, 400, "Bad request"); + return; + } + if (client.content.type == CONTENT_CLOSE) { + client.content.type = CONTENT_NONE; + } + if (!client_request(index)) { + client.content.type = CONTENT_NONE; + return; + } + break; + } + case CONTENT_LEN: + if (client.remote_state < CONNECTED) { + // Request hasn't been sent yet, still collecting data + return; + } + if (avail < client.content.len) { + if (!client_send(index, ptr, avail)) { + return; + } + client.in->consume(avail); + client.content.len -= avail; + return; + } + if (!client_send(index, ptr, client.content.len)) { + return; + } + client.in->consume(client.content.len); + client.content.len = 0; + client.content.type = CONTENT_NONE; + break; + case CONTENT_CHUNKED: + if (client.remote_state < CONNECTED) { + // Request hasn't been sent yet, still collecting data + return; + } + auto used = client.content.chunked->add(ptr, avail); + if (!client.content.chunked->good()) { + client_error(index, 400, "Bad request"); + return; + } + if (!client_send(index, ptr, used)) { + return; + } + client.in->consume(used); + if (client.content.chunked->eof()) { + client.content.chunked.reset(); + client.content.type = CONTENT_NONE; + } + break; + } + } +} + +void ProxyImpl::close_client_when_done(size_t index) { + auto& client = clients_[index]; + if (client.out->empty()) { + close_client(index); + return; + } + client.remote_state = CLOSED; + close_base(&client.remote); + client.in.reset(); + client.read_flag = 0; + client.request.reset(); + client.content.type = CONTENT_NONE; + looper_->modify(client.fd.get(), client.write_flag); +} + +void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) { + auto& client = clients_[index]; + assert(client.remote.fd.get() == fd); + if (events & Looper::EVENT_HUP) { + logger_->out(Logger::INFO, "%zu: Client remote connection closed", index); + client_remote_error(index, 502); + return; + } + if (events & Looper::EVENT_ERROR) { + logger_->out(Logger::INFO, "%zu: Client remote connection error", index); + client_remote_error(index, 502); + return; + } + if (client.remote_state == CONNECTING) { + if (events & Looper::EVENT_WRITE) { + std::string url(client.url->path_escaped()); + if (url.empty()) url.push_back('/'); + auto query = client.url->full_query_escaped(); + if (query) { + url.push_back('?'); + url.append(query); + } + auto req = std::unique_ptr<HttpRequestBuilder>( + HttpRequestBuilder::create( + client.request->method(), + url, + client.request->proto(), + client.request->proto_version())); + auto iter = client.request->header(); + bool have_host = false; + for (; iter->valid(); iter->next()) { + if (!have_host && iter->name_equal("host")) have_host = true; + if (iter->name_equal("proxy-connection") || + iter->name_equal("proxy-authenticate") || + iter->name_equal("proxy-authorization")) { + continue; + } + req->add_header(iter->name(), iter->value()); + } + if (!have_host && + (client.request->proto_version().major == 1 && + client.request->proto_version().minor == 1)) { + req->add_header("host", client.url->host()); + } + auto data = req->build(); + client.in->consume(client.request->size()); + client.request.reset(); + client.url.reset(); + client.remote.out->write(data.data(), data.size()); + client.remote_state = CONNECTED; + client.remote.read_flag = Looper::EVENT_READ; + client.remote.write_flag = Looper::EVENT_WRITE; + looper_->modify(client.remote.fd.get(), + client.remote.read_flag | client.remote.write_flag); + client_empty_input(index); + } else { + return; + } + } + if (!base_event(&client.remote, events, index, "Client remote")) { + switch (client.remote_state) { + case CONNECTED: + switch (client.remote.content.type) { + case CONTENT_CLOSE: + close_client_when_done(index); + break; + case CONTENT_CHUNKED: + case CONTENT_LEN: + case CONTENT_NONE: + client_remote_error(index, 502); + break; + } + break; + case CONNECTING: + case RESOLVING: + case CLOSED: + assert(false); + client_remote_error(index, 502); + break; + case WAITING: + close_base(&client.remote); + client.remote_state = CLOSED; + break; + } + return; + } + while (true) { + size_t avail; + auto ptr = client.remote.in->read_ptr(&avail); + if (avail == 0) return; + switch (client.remote_state) { + case CONNECTED: + switch (client.remote.content.type) { + case CONTENT_NONE: { + auto response = std::unique_ptr<HttpResponse>( + HttpResponse::parse( + reinterpret_cast<char const*>(ptr), avail, false)); + if (!response) { + if (avail >= 1024 * 1024) { + logger_->out(Logger::INFO, + "%zu: Client remote too large request %zu", + index, avail); + client_remote_error(index, 502); + } + return; + } + if (!response->good()) { + client_remote_error(index, 502); + return; + } + if (!setup_content(response.get(), &client.remote.content)) { + logger_->out(Logger::INFO, "%zu: Client remote bad content-length", + index); + client.remote.content.type = CONTENT_CLOSE; + } else { + if (client.remote.content.type == CONTENT_NONE) { + client.remote_state = WAITING; + client.remote.read_flag = 0; + client.remote.write_flag = 0; + looper_->modify(client.remote.fd.get(), 0); + } + } + if (!base_send(&client, ptr, response->size(), index, "Client")) { + return; + } + client.remote.in->consume(response->size()); + break; + } + case CONTENT_LEN: + if (avail < client.remote.content.len) { + if (!base_send(&client, ptr, avail, index, "Client")) { + return; + } + client.remote.in->consume(avail); + client.remote.content.len -= avail; + return; + } + if (!base_send(&client, ptr, client.remote.content.len, + index, "Client")) { + return; + } + client.remote.in->consume(client.remote.content.len); + client.remote.content.len = 0; + client.remote.content.type = CONTENT_NONE; + client.remote_state = WAITING; + client.remote.read_flag = 0; + client.remote.write_flag = 0; + looper_->modify(client.remote.fd.get(), 0); + return; + case CONTENT_CHUNKED: { + auto used = client.remote.content.chunked->add(ptr, avail); + if (!client.remote.content.chunked->good()) { + logger_->out(Logger::INFO, "%zu: Client remote bad chunked", + index); + client.remote.content.type = CONTENT_CLOSE; + break; + } + if (!base_send(&client, ptr, used, index, "Client")) { + return; + } + client.remote.in->consume(used); + if (client.remote.content.chunked->eof()) { + client.remote.content.chunked.reset(); + client.remote.content.type = CONTENT_NONE; + logger_->out(Logger::INFO, "%zu: chunked -> waiting", index); + client.remote_state = WAITING; + client.remote.read_flag = 0; + client.remote.write_flag = 0; + looper_->modify(client.remote.fd.get(), 0); + } + break; + } + case CONTENT_CLOSE: + if (!base_send(&client, ptr, avail, index, "Client")) { + return; + } + client.remote.in->consume(avail); + break; + } + break; + case CONNECTING: + case RESOLVING: + case CLOSED: + assert(false); + return; + case WAITING: + return; + } + } +} + +bool ProxyImpl::client_send(size_t index, void const* ptr, size_t size) { + auto& client = clients_[index]; + assert(!client.request); + assert(client.remote_state >= CONNECTED); + + return base_send(&client.remote, ptr, size, index, "Client remote"); +} + +void ProxyImpl::monitor_event(size_t index, int fd, uint8_t events) { + auto& monitor = monitors_[index]; + assert(monitor.fd.get() == fd); + if (events & Looper::EVENT_HUP) { + close_monitor(index); + return; + } + if (events & Looper::EVENT_ERROR) { + logger_->out(Logger::INFO, "%zu: Monitor connection error", index); + close_monitor(index); + return; + } + if (!base_event(&monitor, events, index, "Monitor")) { + close_monitor(index); + return; + } +} + +void ProxyImpl::new_base(BaseClient* client, int fd) { + client->fd.reset(fd); + client->new_connection = true; + client->read_flag = Looper::EVENT_READ; + client->write_flag = 0; + client->last = looper_->now(); + if (!new_timeout_) { + new_timeout_ = looper_->schedule( + NEW_CONNECTION_TIMEOUT.count(), + std::bind(&ProxyImpl::new_timeout, this)); + } +} + +void ProxyImpl::new_client(int fd, uint8_t events) { + assert(fd == accept_socket_.get()); + if (events == Looper::EVENT_READ) { + assert(!clients_.full()); + while (true) { + int ret = accept4(fd, nullptr, nullptr, SOCK_NONBLOCK); + if (ret < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) return; + if (errno == EINTR) continue; + logger_->out(Logger::WARN, "Accept failed: %s", strerror(errno)); + return; + } + auto index = clients_.new_client(); + new_base(&clients_[index], ret); + clients_[index].content.type = CONTENT_NONE; + clients_[index].remote_state = CLOSED; + clients_[index].remote.content.type = CONTENT_NONE; + looper_->add(ret, clients_[index].read_flag | clients_[index].write_flag, + std::bind(&ProxyImpl::client_event, this, + index, + std::placeholders::_1, + std::placeholders::_2)); + break; + } + if (clients_.full()) looper_->modify(fd, 0); + } else { + logger_->out(Logger::ERR, "Accept socket died"); + fatal_error(); + } +} + +void ProxyImpl::new_monitor(int fd, uint8_t events) { + assert(fd == monitor_socket_.get()); + if (events == Looper::EVENT_READ) { + assert(!monitors_.full()); + while (true) { + int ret = accept4(fd, nullptr, nullptr, SOCK_NONBLOCK); + if (ret < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) return; + if (errno == EINTR) continue; + logger_->out(Logger::WARN, "Accept failed: %s", strerror(errno)); + return; + } + auto index = monitors_.new_client(); + new_base(&monitors_[index], ret); + looper_->add(ret, clients_[index].read_flag | clients_[index].write_flag, + std::bind(&ProxyImpl::monitor_event, this, + index, + std::placeholders::_1, + std::placeholders::_2)); + break; + } + if (monitors_.full()) { + looper_->modify(fd, 0); + } + } else { + logger_->out(Logger::ERR, "Monitor socket died"); + fatal_error(); + } +} + +void ProxyImpl::fatal_error() { + looper_->quit(); + good_ = false; +} + +bool ProxyImpl::run() { + good_ = true; + if (!logger_) { + priv_logger_.reset(Logger::create_syslog("tp")); + logger_ = priv_logger_.get(); + } + { + struct sigaction action; + memset(&action, 0, sizeof(action)); + action.sa_handler = SIG_IGN; + action.sa_flags = SA_RESTART; + sigaction(SIGPIPE, &action, nullptr); + } + if (!signal_pipe_.open()) { + logger_->out(Logger::WARN, + "Failed to create pipes, signals wont work: %s", + strerror(errno)); + } + if (signal_pipe_) { + looper_->add(signal_pipe_.read(), Looper::EVENT_READ, + std::bind(&ProxyImpl::signal_event, this, + std::placeholders::_1, + std::placeholders::_2)); + } + if (!looper_->run()) { + logger_->out(Logger::ERR, "poll() failed: %s", strerror(errno)); + return false; + } + return good_; +} + +int setup_socket(char const* host, std::string const& port, Logger* logger) { + io::auto_fd ret; + struct addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = AI_V4MAPPED | AI_ADDRCONFIG | AI_PASSIVE; + struct addrinfo* result; + auto retval = getaddrinfo(host, port.c_str(), &hints, &result); + if (retval) { + logger->out(Logger::ERR, "getaddrinfo failed for %s:%s: %s", + host ? host : "*", port.c_str(), gai_strerror(retval)); + return -1; + } + auto retp = result; + for (; retp; retp = retp->ai_next) { + ret.reset(socket(retp->ai_family, retp->ai_socktype, + retp->ai_protocol)); + if (!ret) continue; + if (bind(ret.get(), retp->ai_addr, retp->ai_addrlen) == 0) + break; + } + freeaddrinfo(result); + if (!retp) { + logger->out(Logger::ERR, "Failed to bind %s:%s: %s", + host ? host : "*", port.c_str(), strerror(errno)); + return -1; + } + if (listen(ret.get(), SOMAXCONN)) { + logger->out(Logger::ERR, "Failed to listen: %s", strerror(errno)); + return -1; + } + if (fcntl(ret.get(), F_SETFL, O_NONBLOCK)) { + logger->out(Logger::ERR, "fcntl(O_NONBLOCK) failed: %s", strerror(errno)); + return -1; + } + return ret.release(); +} + +} // namespace + +// static +Proxy* Proxy::create(Config* config, std::string const& cwd, + char const* configfile, + char const* logfile, + Logger* logger, + int accept_fd, + int monitor_fd) { + return new ProxyImpl(config, cwd, configfile, logfile, logger, + accept_fd, monitor_fd); +} + +// static +int Proxy::setup_accept_socket(Config* config, Logger* logger) { + return setup_socket(config->get("proxy_bind", nullptr), + config->get("proxy_port", "8080"), logger); +} + +// static +int Proxy::setup_monitor_socket(Config* config, Logger* logger) { + assert(config->get("monitor", false)); + return setup_socket(config->get("monitor_bind", "localhost"), + config->get("monitor_port", "9000"), logger); +} diff --git a/src/proxy.hh b/src/proxy.hh new file mode 100644 index 0000000..6545152 --- /dev/null +++ b/src/proxy.hh @@ -0,0 +1,37 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#ifndef PROXY_HH +#define PROXY_HH + +#include <string> + +class Config; +class Logger; + +class Proxy { +public: + virtual ~Proxy() {} + + static Proxy* create(Config* config, std::string const& cwd, + char const* configfile, + char const* logfile, + Logger* logger, + int accept_fd, int monitor_fd); + static int setup_accept_socket(Config* config, Logger* logger); + static int setup_monitor_socket(Config* config, Logger* logger); + + // Called from signal handler, ask the proxy to exit as soon as possible + virtual void quit() = 0; + + // Called from signal handler, ask the proxy to reload config as soon + // as possible + virtual void reload() = 0; + + virtual bool run() = 0; + +protected: + Proxy() {} + Proxy(Proxy const&) = delete; +}; + +#endif // PROXY_HH diff --git a/src/resolver.cc b/src/resolver.cc new file mode 100644 index 0000000..2623089 --- /dev/null +++ b/src/resolver.cc @@ -0,0 +1,207 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include <condition_variable> +#include <cstring> +#include <fcntl.h> +#include <mutex> +#include <netdb.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <thread> +#include <vector> + +#include "io.hh" +#include "looper.hh" +#include "resolver.hh" + +namespace { + +size_t const WORKERS = 4; + +class ResolverImpl : public Resolver { +public: + ResolverImpl(Looper* looper) + : looper_(looper), request_(nullptr), buf_(new char[sizeof(Request*)]), + fill_(0), quit_(false) { + while (threads_.size() < WORKERS) { + threads_.emplace_back(std::bind(&ResolverImpl::worker, this)); + } + if (pipe_.open() && fcntl(pipe_.read(), F_SETFL, O_NONBLOCK) == 0) { + looper_->add(pipe_.read(), + Looper::EVENT_READ, + std::bind(&ResolverImpl::event, this, + std::placeholders::_1, std::placeholders::_2)); + } else { + assert(false); + } + } + + ~ResolverImpl() override { + quit_ = true; + cond_.notify_all(); + for (auto& thread : threads_) { + thread.join(); + } + } + + void* request(std::string const& host, uint16_t port, + Callback const& callback) override { + auto req = new Request(); + req->host = host; + req->port = port; + req->callback = callback; + req->canceled = false; + std::unique_lock<std::mutex> lock(mutex_); + req->next = request_; + request_ = req; + cond_.notify_one(); + return req; + } + + void cancel(void* ptr) override { + auto req = reinterpret_cast<Request*>(ptr); + req->canceled = true; + std::unique_lock<std::mutex> lock(mutex_); + if (request_ == req) { + request_ = req->next; + delete req; + } else { + for (auto r = request_; r->next; r = r->next) { + if (r->next == req) { + r->next = req->next; + delete req; + return; + } + } + } + } + +protected: + struct Request { + Request* next; + std::string host; + uint16_t port; + Callback callback; + bool canceled; + io::auto_fd fd; + bool connected; + std::string error; + }; + + void event(int fd, uint8_t event) { + assert(fd == pipe_.read()); + if (event & Looper::EVENT_READ) { + while (true) { + auto ret = io::read(fd, buf_.get() + fill_, sizeof(Request*) - fill_); + if (ret == -1) { + if (errno == EAGAIN || errno == EWOULDBLOCK) return; + assert(false); + return; + } else if (ret == 0) { + assert(false); + return; + } + fill_ += ret; + if (fill_ == sizeof(Request*)) { + fill_ = 0; + auto req = *reinterpret_cast<Request**>(buf_.get()); + if (!req->canceled) { + auto err = req->fd ? nullptr : req->error.c_str(); + req->callback(req->fd.release(), req->connected, err); + } + delete req; + } else { + break; + } + } + } else { + assert(false); + } + } + + void report(Request* req, int fd, bool connected, char const* errmsg) { + req->fd.reset(fd); + req->connected = connected; + if (errmsg) req->error = errmsg; + io::write_all(pipe_.write(), &req, sizeof(Request*)); + } + + void worker() { + struct addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = AI_V4MAPPED | AI_ADDRCONFIG | AI_NUMERICSERV; + char tmp[10]; + while (true) { + Request* req; + { + std::unique_lock<std::mutex> lock(mutex_); + while (!quit_ && !request_) { + cond_.wait(lock); + } + if (quit_) return; + auto pr = &request_; + while ((*pr)->next) { + pr = &((*pr)->next); + } + req = *pr; + *pr = nullptr; + } + snprintf(tmp, sizeof(tmp), "%u", static_cast<unsigned int>(req->port)); + struct addrinfo* result; + auto ret = getaddrinfo(req->host.c_str(), tmp, &hints, &result); + if (ret != 0) { + report(req, -1, false, gai_strerror(ret)); + continue; + } + auto rp = result; + for (; rp; rp = rp->ai_next) { + io::auto_fd fd(socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol)); + if (!fd) continue; + + fcntl(fd.get(), F_SETFL, O_NONBLOCK); + + while (true) { + ret = connect(fd.get(), rp->ai_addr, rp->ai_addrlen); + if (ret == 0 || errno != EINTR) break; + } + if (ret == 0) { + report(req, fd.release(), true, nullptr); + break; + } + if (errno == EINPROGRESS) { + report(req, fd.release(), false, nullptr); + break; + } + } + + if (!rp) { + freeaddrinfo(result); + report(req, -1, false, strerror(errno)); + continue; + } + + freeaddrinfo(result); + } + } + + Looper* const looper_; + Request* request_; + io::auto_pipe pipe_; + std::mutex mutex_; + std::condition_variable cond_; + std::unique_ptr<char[]> buf_; + size_t fill_; + bool quit_; + std::vector<std::thread> threads_; +}; + +} // namespace + +// static +Resolver* Resolver::create(Looper* looper) { + return new ResolverImpl(looper); +} diff --git a/src/resolver.hh b/src/resolver.hh new file mode 100644 index 0000000..d98458f --- /dev/null +++ b/src/resolver.hh @@ -0,0 +1,28 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#ifndef RESOLVER_HH +#define RESOLVER_HH + +#include <functional> + +class Looper; + +class Resolver { +public: + typedef std::function<void(int fd, bool connected, + char const* error)> Callback; + + virtual ~Resolver() {} + + static Resolver* create(Looper* looper); + + virtual void* request(std::string const& host, uint16_t port, + Callback const& callback) = 0; + virtual void cancel(void* ptr) = 0; + +protected: + Resolver() {} + Resolver(Resolver const&) = delete; +}; + +#endif // RESOLVER_HH diff --git a/src/strings.cc b/src/strings.cc new file mode 100644 index 0000000..1d74dee --- /dev/null +++ b/src/strings.cc @@ -0,0 +1,94 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include "character.hh" +#include "strings.hh" + +// static +void Strings::trim(std::string const& str, size_t* start, size_t* end) { + assert(start && end); + assert(*start <= *end); + assert(*end <= str.size()); + while (*start < *end && Character::isspace(str, *start)) ++(*start); + while (*end > *start && Character::isspace(str, *end - 1)) --(*end); +} + +// static +std::string Strings::trim(std::string const& str) { + return trim(str, 0, str.size()); +} + +// static +std::string Strings::trim(std::string const& str, size_t start, size_t end) { + auto s = start, e = end; + trim(str, &s, &e); + return str.substr(s, e - s); +} + +// static +std::string Strings::quote(std::string const& str) { + return quote(str, 0, str.size()); +} + +// static +std::string Strings::quote(std::string const& str, size_t start, size_t end) { + assert(start <= end); + assert(end <= str.size()); + auto i = start; + while (i < end) { + auto c = str[i]; + if (c == '"' || c == '\\') break; + ++i; + } + std::string ret("\""); + if (i == end) { + ret.append(str.substr(start, end - start)); + } else { + ret.append(str.substr(start, i - start)); + while (true) { + ret.push_back('\\'); + auto j = i++; + while (i < end) { + auto c = str[i]; + if (c == '"' || c == '\\') break; + ++i; + } + ret.append(str.substr(j, i - j)); + if (i == end) break; + } + } + ret.push_back('"'); + return ret; +} + +// static +std::string Strings::unquote(std::string const& str) { + return unquote(str, 0, str.size()); +} + +// static +std::string Strings::unquote(std::string const& str, size_t start, size_t end) { + assert(start < end - 1); + assert(end <= str.size()); + assert(str[start] == '"' && str[end - 1] == '"'); + ++start; + --end; + auto i = start; + while (i < end && str[i] != '\\') ++i; + if (i == end) return str.substr(start, end - start); + std::string ret(str.substr(start, i - start)); + while (true) { + ++i; // skip backslash + if (i == end) { + assert(false); + ret.push_back('\\'); + break; + } + auto j = i++; + while (i < end && str[i] != '\\') ++i; + ret.append(str.substr(j, i - j)); + if (i == end) break; + } + return ret; +} diff --git a/src/strings.hh b/src/strings.hh new file mode 100644 index 0000000..10e3edf --- /dev/null +++ b/src/strings.hh @@ -0,0 +1,24 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#ifndef STRINGS_HH +#define STRINGS_HH + +#include <string> + +class Strings { +public: + static void trim(std::string const& str, size_t* start, size_t* end); + static std::string trim(std::string const& str); + static std::string trim(std::string const& str, size_t start, size_t end); + + static std::string quote(std::string const& str); + static std::string quote(std::string const& str, size_t start, size_t end); + static std::string unquote(std::string const& str); + static std::string unquote(std::string const& str, size_t start, size_t end); + +private: + ~Strings() {} + Strings() {} +}; + +#endif // STRINGS_HH diff --git a/src/terminal.cc b/src/terminal.cc new file mode 100644 index 0000000..fd7bf98 --- /dev/null +++ b/src/terminal.cc @@ -0,0 +1,20 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include <cstring> +#include <sys/ioctl.h> +#include <unistd.h> + +#include "terminal.hh" + +// static +Terminal::Size Terminal::size() { + struct winsize size; + memset(&size, 0, sizeof(size)); + ioctl(STDOUT_FILENO, TIOCGWINSZ, &size); + if (size.ws_col == 0 || size.ws_row == 0) { + return { .width = 80, .height = 25 }; + } + return { .width = size.ws_col, .height = size.ws_row }; +} diff --git a/src/terminal.hh b/src/terminal.hh new file mode 100644 index 0000000..414ee4b --- /dev/null +++ b/src/terminal.hh @@ -0,0 +1,22 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#ifndef TERMINAL_HH +#define TERMINAL_HH + +#include <cstdint> + +class Terminal { +public: + struct Size { + uint32_t width; + uint32_t height; + }; + + static Size size(); + +private: + Terminal() { } + ~Terminal() { } +}; + +#endif // TERMINAL_HH diff --git a/src/url.cc b/src/url.cc new file mode 100644 index 0000000..b4e6c0f --- /dev/null +++ b/src/url.cc @@ -0,0 +1,901 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include <algorithm> +#include <cstring> +#include <iostream> +#include <memory> + +#include "paths.hh" +#include "url.hh" + +namespace { + +class UrlImpl : public Url { +public: + UrlImpl() + : userinfo_(nullptr), userinfo_unescaped_(nullptr), port_(0), + path_(nullptr), path_unescaped_(nullptr), query_(nullptr), + query_unescaped_(nullptr), fragment_(nullptr) { + } + ~UrlImpl() override { + delete[] userinfo_; + if (userinfo_unescaped_ != userinfo_) delete[] userinfo_unescaped_; + delete[] path_; + if (path_unescaped_ != path_) delete[] path_unescaped_; + delete[] query_; + if (query_unescaped_ != query_) delete[] query_unescaped_; + delete[] fragment_; + } + UrlImpl(UrlImpl const& url); + + bool parse(std::string const& url, Url const* base); + + Url* copy() const override { + return new UrlImpl(*this); + } + + std::string const& scheme() const override { + return scheme_; + } + + char const* userinfo() const override; + + char const* userinfo_escaped() const override { + return userinfo_; + } + + std::string const& host() const override { + return host_; + } + + uint16_t port() const override { + return port_; + } + + std::string path() const override; + + std::string path_escaped() const override { + return path_; + } + + bool query(std::string const& name, std::string* value) const override; + + char const* full_query() const override; + + char const* full_query_escaped() const override { + return query_; + } + + char const* fragment() const override { + return fragment_; + } + + void print(std::ostream& out, bool path = true, bool query = true, + bool fragment = true) const override; + +private: + char const* parse_authority(char const* pos); + char const* parse_query(char const* pos); + char const* parse_fragment(char const* pos); + bool relative(std::string const& url, Url const* base); + + std::string scheme_; + char* userinfo_; + mutable char* userinfo_unescaped_; + std::string host_; + uint16_t port_; + char* path_; + mutable char* path_unescaped_; + char* query_; + mutable char* query_unescaped_; + char* fragment_; +}; + +char* unescape(char* start, char* end, bool query); +char* escape(char const* str, char const* safe); +char* dup(char const* start, char const* end); +void lower(std::string& str); + +bool is_ipv4(char const* start, char const* end); +bool is_ipv6(char const* start, char const* end); + +char const subdelims[] = "!$&'()*+,;="; + +UrlImpl::UrlImpl(UrlImpl const& url) + : Url(), scheme_(url.scheme_), host_(url.host_), port_(url.port_) { + userinfo_ = dup(url.userinfo_, nullptr); + userinfo_unescaped_ = dup(url.userinfo_unescaped_, nullptr); + path_ = dup(url.path_, nullptr); + path_unescaped_ = dup(url.path_unescaped_, nullptr); + query_ = dup(url.query_, nullptr); + query_unescaped_ = dup(url.query_unescaped_, nullptr); + fragment_ = dup(url.fragment_, nullptr); +} + +inline bool is_alpha(char c) { + return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'); +} + +inline bool is_digit(char c) { + return c >= '0' && c <= '9'; +} + +inline bool is_unreserved(char c) { + return is_alpha(c) || is_digit(c) || c == '-' || c == '.' || c == '_' || + c == '~'; +} + +inline bool is_hex(char c) { + return is_digit(c) || (c >= 'A' && c <= 'F') || (c >= 'a' && c <= 'f'); +} + +char* emptystr() { + char* ret = new char[1]; + ret[0] = '\0'; + return ret; +} + +bool UrlImpl::parse(std::string const& url, Url const* base) { + char const* pos = url.c_str(), *start; + + // URI = scheme ":" hier-part [ "?" query ] [ "#" fragment ] + + // scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." ) + start = pos; + if (!is_alpha(*pos)) { + return relative(url, base); + } + pos++; + while (is_alpha(*pos) || is_digit(*pos) || *pos == '+' || + *pos == '-' || *pos == '.') { + pos++; + } + if (*pos != ':') { + return relative(url, base); + } + scheme_.assign(start, pos); + lower(scheme_); + pos++; + + /* + hier-part = "//" authority path-abempty + / path-absolute + / path-rootless + / path-empty + */ + if (memcmp(pos, "//", 2)) { + if (pos[0] == '/' && pos[1] != '/') { + // path-absolute = "/" [ segment-nz *( "/" segment ) ] + start = pos++; + } else if (pos[0] != '/') { + /* + path-rootless = segment-nz *( "/" segment ) + path-empty = 0<pchar> + */ + start = pos; + } + while (true) { + if (is_unreserved(*pos)) { + pos++; + } else if (*pos == '%' && is_hex(pos[1]) && is_hex(pos[2])) { + pos += 3; + } else if (*pos && strchr(subdelims, *pos)) { + pos++; + } else if (*pos == ':' || *pos == '@' || *pos == '/') { + pos++; + } else { + break; + } + } + + path_ = dup(start, pos); + } else { + pos += 2; + pos = parse_authority(pos); + if (!pos) { + return relative(url, base); + } + + if (*pos == '/') { + /* + path-abempty = *( "/" segment ) + segment = *pchar + pchar = unreserved / pct-encoded / sub-delims / ":" / "@" + */ + start = pos++; + while (true) { + if (is_unreserved(*pos)) { + pos++; + } else if (*pos == '%' && is_hex(pos[1]) && is_hex(pos[2])) { + pos += 3; + } else if (*pos && strchr(subdelims, *pos)) { + pos++; + } else if (*pos == ':' || *pos == '@' || *pos == '/') { + pos++; + } else { + break; + } + } + + path_ = dup(start, pos); + } else { + path_ = emptystr(); + } + } + + if (*pos == '?') { + pos = parse_query(pos); + if (!pos) { + return relative(url, base); + } + } + + if (*pos == '#') { + pos = parse_fragment(pos); + if (!pos) { + return relative(url, base); + } + } + + if (*pos != '\0') { + return relative(url, base); + } + + return true; +} + +bool UrlImpl::relative(std::string const& url, Url const* base) { + char const* pos = url.c_str(); + char const* start; + if (!base) return false; + + scheme_.clear(); + if (userinfo_unescaped_ != userinfo_) { + delete[] userinfo_unescaped_; userinfo_unescaped_ = nullptr; + } + delete[] userinfo_; userinfo_ = nullptr; + host_.clear(); + port_ = 0; + if (path_unescaped_ != path_) { + delete[] path_unescaped_; path_unescaped_ = nullptr; + } + delete[] path_; path_ = nullptr; + if (query_unescaped_ != query_) { + delete[] query_unescaped_; query_unescaped_ = nullptr; + } + delete[] query_; query_ = nullptr; + delete[] fragment_; fragment_ = nullptr; + + /* + relative-part = "//" authority path-abempty + / path-absolute + / path-noscheme + / path-empty + */ + if (memcmp(pos, "//", 2)) { + if (pos[0] == '/' && pos[1] != '/') { + // path-absolute = "/" [ segment-nz *( "/" segment ) ] + start = pos++; + } else if (pos[0] != '/') { + /* + path-noscheme = segment-nz-nc *( "/" segment ) + path-empty = 0<pchar> + */ + start = pos; + } + while (true) { + if (is_unreserved(*pos)) { + pos++; + } else if (*pos == '%' && is_hex(pos[1]) && is_hex(pos[2])) { + pos += 3; + } else if (*pos && strchr(subdelims, *pos)) { + pos++; + } else if (*pos == ':' || *pos == '@' || *pos == '/') { + pos++; + } else { + break; + } + } + + path_ = dup(start, pos); + } else { + pos += 2; + pos = parse_authority(pos); + if (!pos) return false; + + if (*pos == '/') { + /* + path-abempty = *( "/" segment ) + segment = *pchar + pchar = unreserved / pct-encoded / sub-delims / ":" / "@" + */ + start = pos++; + while (true) { + if (is_unreserved(*pos)) { + pos++; + } else if (*pos == '%' && is_hex(pos[1]) && is_hex(pos[2])) { + pos += 3; + } else if (*pos && strchr(subdelims, *pos)) { + pos++; + } else if (*pos == ':' || *pos == '@' || *pos == '/') { + pos++; + } else { + break; + } + } + + path_ = dup(start, pos); + } else { + path_ = emptystr(); + } + } + + if (*pos == '?') { + pos = parse_query(pos); + if (!pos) return false; + } + + if (*pos == '#') { + pos = parse_fragment(pos); + if (!pos) return false; + } + + if (*pos != '\0') return false; + + scheme_ = base->scheme(); + if (host_.empty()) { + auto userinfo = base->userinfo(); + if (userinfo) userinfo_ = dup(userinfo, userinfo + strlen(userinfo)); + host_ = base->host(); + port_ = base->port(); + } + if (path_[0] != '/') { + std::string tmp( + Paths::join(!base->path().empty() ? base->path() : "/", path_)); + delete[] path_; + path_ = dup(tmp.data(), tmp.data() + tmp.size()); + } + return true; +} + +char const* UrlImpl::parse_authority(char const* pos) { + /* authority = [ userinfo "@" ] host [ ":" port ] + userinfo = *( unreserved / pct-encoded / sub-delims / ":" ) + host = IP-literal / IPv4address / reg-name + IP-literal = "[" ( IPv6address / IPvFuture ) "]" + IPvFuture = "v" 1*HEXDIG "." 1*( unreserved / sub-delims / ":" ) + IPv6address = 6( h16 ":" ) ls32 + / "::" 5( h16 ":" ) ls32 + / [ h16 ] "::" 4( h16 ":" ) ls32 + / [ *1( h16 ":" ) h16 ] "::" 3( h16 ":" ) ls32 + / [ *2( h16 ":" ) h16 ] "::" 2( h16 ":" ) ls32 + / [ *3( h16 ":" ) h16 ] "::" h16 ":" ls32 + / [ *4( h16 ":" ) h16 ] "::" ls32 + / [ *5( h16 ":" ) h16 ] "::" h16 + / [ *6( h16 ":" ) h16 ] "::" + ls32 = ( h16 ":" h16 ) / IPv4address + h16 = 1*4HEXDIG + IPv4address = dec-octet "." dec-octet "." dec-octet "." dec-octet + dec-octet = DIGIT ; 0-9 + / %x31-39 DIGIT ; 10-99 + / "1" 2DIGIT ; 100-199 + / "2" %x30-34 DIGIT ; 200-249 + / "25" %x30-35 ; 250-255 + reg-name = *( unreserved / pct-encoded / sub-delims ) + port = *DIGIT + */ + char const* start = pos; + char const* at = nullptr, *colon = nullptr; + char const* host_start, *host_end, *tmp; + while (true) { + if (is_unreserved(*pos)) { + pos++; + } else if (*pos == '%' && is_hex(pos[1]) && is_hex(pos[2])) { + pos += 3; + } else if (*pos && strchr(subdelims, *pos)) { + pos++; + } else if (*pos == '[' || *pos == ']') { + pos++; + } else if (*pos == ':') { + colon = pos++; + } else if (*pos == '@') { + if (at) { + return nullptr; + } + colon = nullptr; + at = pos++; + } else { + break; + } + } + + // userinfo? + if (at) { + host_start = at + 1; + userinfo_ = dup(start, at); + if (strchr(userinfo_, '[') || strchr(userinfo_, ']')) { + return nullptr; + } + } else { + host_start = start; + } + + // port? + if (colon) { + tmp = colon + 1; + if (tmp < pos && is_digit(*tmp)) { + uint16_t v = *tmp - '0'; + tmp++; + host_end = colon; + while (tmp < pos) { + uint16_t x; + if (!is_digit(*tmp)) { + host_end = pos; + break; + } + x = v; + v *= 10; + if (v < x) { + return nullptr; + } + v += *tmp - '0'; + tmp++; + } + if (host_end == colon) { + port_ = v; + } + } else { + host_end = pos; + } + } else { + host_end = pos; + } + + if (*host_start == '[') { + if (host_end[-1] != ']' || host_start + 1 >= host_end - 1) { + return nullptr; + } + host_start++; + host_end--; + if (*host_start == 'v') { + if (!is_hex(host_start[1]) || host_start[2] != '.') { + return nullptr; + } + host_.assign(host_start, host_end - host_start); + lower(host_); + } else { + if (!is_ipv6(host_start, host_end)) { + return nullptr; + } + host_.assign(host_start, host_end - host_start); + lower(host_); + } + if (host_.find('[') != std::string::npos || + host_.find(']') != std::string::npos) { + return nullptr; + } + } else { + tmp = host_start; + while (tmp < host_end) { + if (*tmp == '[' || *tmp == ']') { + return nullptr; + } + tmp++; + } + tmp = unescape(const_cast<char*>(host_start), const_cast<char*>(host_end), + false); + host_ = tmp; + lower(host_); + delete[] const_cast<char*>(tmp); + } + if (host_.empty()) return nullptr; + return pos; +} + +char const* UrlImpl::parse_query(char const* pos) { + // query = *( pchar / "/" / "?" ) + char const* start = ++pos; + while (true) { + if (is_unreserved(*pos)) { + pos++; + } else if (*pos == '%' && is_hex(pos[1]) && is_hex(pos[2])) { + pos += 3; + } else if (*pos && strchr(subdelims, *pos)) { + pos++; + } else if (*pos == ':' || *pos == '@' || *pos == '/' || + *pos == '?') { + pos++; + } else { + break; + } + } + + query_ = dup(start, pos); + return pos; +} + +char const* UrlImpl::parse_fragment(char const* pos) { + // fragment = *( pchar / "/" / "?" ) + char const* start = ++pos; + while (true) { + if (is_unreserved(*pos)) { + pos++; + } else if (*pos == '%' && is_hex(pos[1]) && is_hex(pos[2])) { + pos += 3; + } else if (*pos && strchr(subdelims, *pos)) { + pos++; + } else if (*pos == ':' || *pos == '@' || *pos == '/' || + *pos == '?') { + pos++; + } else { + break; + } + } + + fragment_ = unescape((char*) start, (char*) pos, false); + return pos; +} + +char const* UrlImpl::userinfo() const { + if (userinfo_ && !userinfo_unescaped_) { + userinfo_unescaped_ = unescape(userinfo_, nullptr, false); + } + return userinfo_unescaped_; +} + +std::string UrlImpl::path() const { + if (!path_unescaped_) { + path_unescaped_ = unescape(path_, nullptr, false); + } + return path_unescaped_; +} + +bool UrlImpl::query(std::string const& name, std::string* value) const { + char* pos; + if (!query_ || !*query_) { + return false; + } + pos = query_; + while (true) { + char* next = pos, *eq; + char* tmp; + while (*next && *next != '=' && *next != '&') { + next++; + } + if (*next == '=') { + eq = next++; + while (*next && *next != '&') { + next++; + } + } else { + eq = next; + } + tmp = unescape(pos, eq, true); + if (name.compare(tmp) == 0) { + delete[] tmp; + if (eq != next) { + std::unique_ptr<char[]> tmp(unescape(eq + 1, next, true)); + if (value) value->assign(tmp.get()); + return true; + } else { + if (value) value->clear(); + return true; + } + } + delete[] tmp; + if (!*next) { + return false; + } + pos = next + 1; + } +} + +char const* UrlImpl::full_query() const { + if (query_ && !query_unescaped_) { + query_unescaped_ = unescape(query_, nullptr, true); + } + return query_unescaped_; +} + +bool unhex(char c, uint8_t* ret) { + if (is_digit(c)) { + *ret = c - '0'; + return true; + } else if (c >= 'A' && c <= 'F') { + *ret = c - 'A' + 10; + return true; + } else if (c >= 'a' && c <= 'f') { + *ret = c - 'a' + 10; + return true; + } + return false; +} + +char* unescape(char* start, char* end, bool query) { + char* pos = start; + char* ret; + size_t o; + if (!start) { + return nullptr; + } + if (end) { + for (; pos < end; pos++) { + if (*pos == '%' || (query && *pos == '+')) { + break; + } + } + if (pos == end) { + return dup(start, end); + } + } else { + for (; *pos; pos++) { + if (*pos == '%' || (query && *pos == '+')) { + break; + } + } + if (!*pos) { + return start; + } + end = pos + strlen(pos); + } + ret = new char[end - start + 1]; + o = 0; + while (true) { + uint8_t h, l; + memcpy(ret + o, start, pos - start); + o += pos - start; + if (query && *pos == '+') { + ret[o++] = ' '; + pos++; + } else if (pos + 3 <= end && unhex(pos[1], &h) && unhex(pos[2], &l)) { + ret[o++] = h << 4 | l; + pos += 3; + } else { + ret[o++] = *(pos++); + } + start = pos; + while (pos < end && *pos != '%' && !(query && *pos == '+')) { + pos++; + } + if (pos == end) { + memcpy(ret + o, start, pos - start); + o += pos - start; + break; + } + } + ret[o] = '\0'; + return ret; +} + +bool unsafe(char c, char const* safe) { + if (is_unreserved(c) || (c && strchr(safe, c))) { + return false; + } + return true; +} + +char hex(uint8_t c) { + return c < 10 ? '0' + c : 'A' + c - 10; +} + +char* escape(char const* str, char const* safe) { + char const* pos = str; + size_t len; + char* ret; + size_t o; + while (*pos && !unsafe(*pos, safe)) { + pos++; + } + if (!*pos) { + return (char*) str; + } + len = strlen(pos); + ret = new char[(pos - str) + len * 3 + 1]; + o = 0; + while (true) { + memcpy(ret + o, str, pos - str); + o += pos - str; + ret[o++] = '%'; + ret[o++] = hex(*((const uint8_t*)pos) >> 4); + ret[o++] = hex(*((const uint8_t*)pos) & 0xf); + str = ++pos; + while (*pos && !unsafe(*pos, safe)) { + pos++; + } + if (!*pos) { + memcpy(ret + o, str, pos - str); + o += pos - str; + break; + } + } + ret[o] = '\0'; + return ret; +} + +char* dup(char const* start, char const* end) { + if (!start) return nullptr; + if (!end) end = start + strlen(start); + size_t len = end - start; + char* ret; + assert(start <= end); + ret = new char[len + 1]; + memcpy(ret, start, len); + ret[len] = '\0'; + return ret; +} + +inline char ascii_tolower(char c) { + return (c >= 'A' && c <= 'Z') ? ('a' + c - 'A') : c; +} + +void lower(std::string& str) { + std::transform(str.begin(), str.end(), str.begin(), ascii_tolower); +} + +bool is_ipv4(char const* start, char const* end) { + char const* pos = start; + size_t i = 0; + while (true) { + if (pos[0] == '2' && pos + 2 <= end) { + if (pos[1] == '5') { + if (pos[2] < '0' || pos[2] > '5') { + break; + } + } else { + if (pos[1] < '0' || pos[1] > '4' || !is_digit(pos[2])) { + break; + } + } + pos += 3; + } else if (pos[0] == '1' && pos + 2 <= end && + is_digit(pos[1]) && is_digit(pos[2])) { + pos += 3; + } else if ((pos[0] >= '1' && pos[0] <= '9') && pos + 1 <= end && + is_digit(pos[1])) { + pos += 2; + } else if (is_digit(pos[0])) { + pos++; + } else { + break; + } + + i++; + if (pos == end || *pos != '.') { + break; + } + pos++; + if (pos == end) { + return false; + } + } + return i == 4 && pos == end; +} + +size_t walk_hex(char const* start, char const* end, size_t max) { + size_t i = 0; + while (max-- && start < end) { + if (!is_hex(*start)) break; + start++; + i++; + } + return i; +} + +bool is_ipv6(char const* start, char const* end) { + size_t i = 0, j = 0, x; + char const* pos = start; + bool empty = false; + while (pos < end) { + if (*pos == ':') { + pos++; + if (*pos == ':') { + empty = true; + pos++; + break; + } + if (i == 0) { + return false; + } + } + x = walk_hex(pos, end, 4); + if (x == 0) { + return false; + } + pos += x; + i++; + } + + if (pos < end) { + while (true) { + x = walk_hex(pos, end, 4); + if (x == 0) { + return false; + } + pos += x; + if (pos == end) { + j++; + break; + } + if (*pos != ':') { + pos -= x; + break; + } + pos++; + j++; + } + } + + if (pos != end) { + if (!is_ipv4(pos, end)) { + return false; + } + j += 2; + } + + if (!empty) { + return i == 8; + } + + if (i + j > 7) { + return false; + } + return true; +} + +void UrlImpl::print(std::ostream& out, bool path, bool query, bool fragment) + const { + out << scheme_; + if (!host_.empty()) { + out << "://"; + if (userinfo_) { + out << userinfo_ << '@'; + } + out << host_; + if (port_) { + out << ':' << port_; + } + } else { + out << ':'; + } + if (path) { + out << path_; + } + if (query && query_ && *query_) { + out << '?' << query_; + } + if (fragment && fragment_ && *fragment_) { + out << '#'; + char* tmp = escape(fragment_, "/?"); + out << tmp; + if (tmp != fragment_) delete[] tmp; + } +} + +bool null_eq(char const* a1, char const* a2) { + if (a1 == a2) return true; + return a1 && a2 && strcmp(a1, a2) == 0; +} + +} // namespace + +// static +Url* Url::parse(std::string const& url, Url const* base) { + UrlImpl* ret = new UrlImpl(); + if (ret->parse(url, base)) return ret; + delete ret; + return nullptr; +} + +bool Url::operator==(Url const& url) const { + if (scheme() != url.scheme()) return false; + if (host() != url.host()) return false; + if (port() != url.port()) return false; + if (path_escaped() != url.path_escaped()) return false; + if (!null_eq(userinfo_escaped(), url.userinfo_escaped())) return false; + if (!null_eq(full_query_escaped(), url.full_query_escaped())) return false; + if (!null_eq(fragment(), url.fragment())) return false; + return true; +} + diff --git a/src/url.hh b/src/url.hh new file mode 100644 index 0000000..d3b69b7 --- /dev/null +++ b/src/url.hh @@ -0,0 +1,62 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#ifndef URL_HH +#define URL_HH + +#include <cstdint> +#include <string> + +// scheme://[userinfo@]host[:port]/path[?query][#fragment] +class Url { +public: + virtual ~Url() {} + + static Url* parse(std::string const& url, Url const* base = nullptr); + + virtual Url* copy() const = 0; + + bool operator==(Url const& url) const; + + // never empty, unescaped + virtual std::string const& scheme() const = 0; + + // may be null, not empty, unescaped + virtual char const* userinfo() const = 0; + + // may be null, not empty + virtual char const* userinfo_escaped() const = 0; + + // may be empty, unescaped + virtual std::string const& host() const = 0; + + // 0 is used for no port set + virtual uint16_t port() const = 0; + + // may be empty, unescaped + virtual std::string path() const = 0; + + // may be empty + virtual std::string path_escaped() const = 0; + + // return true if name was found and value updated, + // value may be empty, unescaped + virtual bool query(std::string const& name, std::string* value) const = 0; + + // may be null or empty, unescaped + virtual char const* full_query() const = 0; + + // may be null or empty + virtual char const* full_query_escaped() const = 0; + + // may be null or empty, unescaped + virtual char const* fragment() const = 0; + + virtual void print(std::ostream& out, bool path = true, bool query = true, + bool fragment = true) const = 0; + +protected: + Url() {} + Url(Url const&) = delete; +}; + +#endif // URL_HH diff --git a/src/xdg.cc b/src/xdg.cc new file mode 100644 index 0000000..d230e07 --- /dev/null +++ b/src/xdg.cc @@ -0,0 +1,135 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include "paths.hh" +#include "xdg.hh" + +#include <map> +#include <memory> +#include <pwd.h> +#include <sys/types.h> +#include <unistd.h> + +namespace { + +std::string home_based_dir(char const* env_name, char const* fallback) { + auto env = getenv(env_name); + if (!env || !*env) { + return Paths::join(XDG::home(), fallback); + } + return Paths::cleanup(env); +} + +void dirs(char const* dirs, std::vector<std::string>* ret) { + auto start = dirs; + auto i = dirs; + while (true) { + if (*i == ':' || !*i) { + if (i > start) { + auto tmp = Paths::cleanup(std::string(start, i - start)); + if (tmp[0] == '/') { + ret->push_back(tmp); + } + } + if (!*i) break; + start = ++i; + } else { + ++i; + } + } +} + +void dirs(char const* env_name, char const* fallback, + std::vector<std::string>* ret) { + auto env = getenv(env_name); + if (!env || !*env) { + dirs(fallback, ret); + } else { + dirs(env, ret); + } +} + +void remove_dupes(std::vector<std::string>* lst) { + std::multimap<size_t, std::string*> mem; + bool modifying = false; + auto out = lst->begin(); + for (auto it = lst->begin(); it != lst->end(); ++it) { + auto len = it->size(); + auto pair = mem.equal_range(len); + bool remove = false; + for (auto j = pair.first; j != pair.second; ++j) { + if (*(j->second) == *it) { + remove = true; + break; + } + } + if (remove) { + if (!modifying) { + modifying = true; + out = it; + } + } else { + if (modifying) { + *(out++) = *it; + } + mem.insert(pair.first, std::make_pair(len, &(*it))); + } + } + if (modifying) { + lst->resize(out - lst->begin()); + } +} + +} // namespace + +// static +std::string XDG::config_home() { + return home_based_dir("XDG_CONFIG_HOME", ".config"); +} + +// static +std::vector<std::string> XDG::config_dirs() { + std::vector<std::string> ret; + ret.push_back(config_home()); + dirs("XDG_CONFIG_DIRS", SYSCONFDIR "/xdg", &ret); + remove_dupes(&ret); + return ret; +} + +// static +std::string XDG::data_home() { + return home_based_dir("XDG_DATA_HOME", ".local/share"); +} + +// static +std::vector<std::string> XDG::data_dirs() { + std::vector<std::string> ret; + ret.push_back(data_home()); + dirs("XDG_DATA_DIRS", "/usr/local/share/:/usr/share/", &ret); + remove_dupes(&ret); + return ret; +} + +// static +std::string XDG::cache_home() { + return home_based_dir("XDG_CACHE_HOME", ".cache"); +} + +// static +std::string XDG::home() { + auto env = getenv("HOME"); + if (!env || !*env) { + auto size = sysconf(_SC_GETPW_R_SIZE_MAX); + if (size == -1) size = 16384; + auto buf = std::unique_ptr<char[]>(new char[size]); + struct passwd pwd; + struct passwd* result; + auto ret = getpwuid_r(getuid(), &pwd, buf.get(), size, &result); + if (result && ret == 0) { + return Paths::cleanup(result->pw_dir); + } + return "."; + } + return Paths::cleanup(env); +} diff --git a/src/xdg.hh b/src/xdg.hh new file mode 100644 index 0000000..99d1c24 --- /dev/null +++ b/src/xdg.hh @@ -0,0 +1,25 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#ifndef XDG_HH +#define XDG_HH + +#include <string> +#include <vector> + +class XDG { +public: + static std::string config_home(); + // Returned list is sorted in order of importance, directories may not exist + static std::vector<std::string> config_dirs(); + static std::string data_home(); + // Returned list is sorted in order of importance, directories may not exist + static std::vector<std::string> data_dirs(); + static std::string cache_home(); + static std::string home(); + +private: + XDG() {} + ~XDG() {} +}; + +#endif // XDG_HH |
