summaryrefslogtreecommitdiff
path: root/src/http_protocol.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/http_protocol.cc')
-rw-r--r--src/http_protocol.cc625
1 files changed, 625 insertions, 0 deletions
diff --git a/src/http_protocol.cc b/src/http_protocol.cc
new file mode 100644
index 0000000..5915f77
--- /dev/null
+++ b/src/http_protocol.cc
@@ -0,0 +1,625 @@
+// -*- mode: c++; c-basic-offset: 2; -*-
+
+#include "common.hh"
+
+#include <deque>
+#include <memory>
+#include <string.h>
+
+#if HAVE_ZLIB
+#include <zlib.h>
+#endif
+#if HAVE_BZIP2
+#include <bzlib.h>
+#endif
+
+#include "chunked.hh"
+#include "gui_attrtext.hh"
+#include "gui_hexdump.hh"
+#include "http.hh"
+#include "http_protocol.hh"
+#include "utf.hh"
+
+namespace {
+
+class Filter {
+public:
+ virtual ~Filter() {}
+
+ virtual void set_output(Filter* filter) = 0;
+ virtual void write(void const* data, size_t size, bool last) = 0;
+ virtual void error() = 0;
+ virtual void incomplete() = 0;
+
+protected:
+ Filter() {}
+ Filter(Filter const&) = delete;
+ Filter& operator=(Filter const&) = delete;
+};
+
+class AbstractFilter : public Filter {
+public:
+ void set_output(Filter* filter) override {
+ output_ = filter;
+ }
+
+ void error() override {
+ if (output_) output_->error();
+ }
+
+ void incomplete() override {
+ if (output_) output_->incomplete();
+ }
+
+protected:
+ AbstractFilter()
+ : output_(nullptr) {
+ }
+
+ Filter* output_;
+};
+
+class ChunkedFilter : public AbstractFilter {
+public:
+ ChunkedFilter()
+ : chunked_(Chunked::create(std::bind(&ChunkedFilter::output, this,
+ std::placeholders::_1,
+ std::placeholders::_2))) {
+ }
+
+ void write(void const* data, size_t size, bool last) override {
+ if (!chunked_->good()) return;
+ auto ptr = reinterpret_cast<char const*>(data);
+ while (size > 0) {
+ size_t used;
+ if (buffer_.empty()) {
+ used = chunked_->add(ptr, size);
+ } else {
+ size_t old = buffer_.size();
+ buffer_.append(ptr, size);
+ used = chunked_->add(buffer_.data(), buffer_.size());
+ if (used < old) {
+ buffer_.erase(0, used);
+ size = 0;
+ break;
+ }
+ buffer_.clear();
+ used -= old;
+ }
+ if (used == 0) break;
+ ptr += used;
+ size -= used;
+ }
+ if (size > 0) {
+ buffer_.append(ptr, size);
+ }
+ if (!chunked_->good()) {
+ error();
+ } else if (last) {
+ if (chunked_->eof()) {
+ if (!buffer_.empty()) {
+ error();
+ }
+ } else {
+ incomplete();
+ }
+ }
+ }
+
+
+private:
+ void output(void const* data, size_t size) {
+ if (output_) output_->write(data, size, chunked_->eof());
+ }
+
+ std::string buffer_;
+ std::unique_ptr<Chunked> chunked_;
+};
+
+#if HAVE_ZLIB
+class DeflateFilter : public AbstractFilter {
+public:
+ DeflateFilter()
+ : error_(false), first_try_(true), eof_(false) {
+ memset(&stream_, 0, sizeof(stream_));
+ auto ret = inflateInit2(&stream_, 15 + 32);
+ assert(ret == Z_OK);
+ }
+
+ ~DeflateFilter() override {
+ inflateEnd(&stream_);
+ }
+
+ void write(void const* data, size_t size, bool last) override {
+ if (error_) return;
+ stream_.next_in = const_cast<Bytef*>(reinterpret_cast<Bytef const*>(data));
+ stream_.avail_in = static_cast<uInt>(size);
+ while (stream_.avail_in) {
+ auto status = inflate();
+ if (status == Z_BUF_ERROR) {
+ break;
+ } else if (status == Z_STREAM_END) {
+ eof_ = true;
+ break;
+ } else if (status == Z_NEED_DICT) {
+ error();
+ return;
+ } else if (status != Z_OK) {
+ if (first_try_) {
+ first_try_ = false;
+ if (inflateReset2(&stream_, -15) == Z_OK) {
+ stream_.next_in = const_cast<Bytef*>(
+ reinterpret_cast<Bytef const*>(data));
+ stream_.avail_in = static_cast<uInt>(size);
+ continue;
+ }
+ }
+ error();
+ return;
+ }
+ }
+
+ if ((last || eof_) && stream_.avail_in) {
+ error();
+ return;
+ }
+
+ if (last && !eof_) {
+ incomplete();
+ return;
+ }
+ }
+
+private:
+ int inflate() {
+ Bytef buf[8196];
+ while (true) {
+ stream_.next_out = buf;
+ stream_.avail_out = sizeof(buf);
+
+ auto status = ::inflate(&stream_, Z_NO_FLUSH);
+ auto used = sizeof(buf) - stream_.avail_out;
+ if (used) {
+ first_try_ = false;
+ if (output_) output_->write(buf, used, eof_ || status == Z_STREAM_END);
+ }
+ if (status != Z_OK) {
+ return status;
+ }
+ }
+ }
+
+ void error() override {
+ if (error_) return;
+ error_ = true;
+ AbstractFilter::error();
+ }
+
+ struct z_stream_s stream_;
+ bool error_;
+ bool first_try_;
+ bool eof_;
+};
+
+// gzip and deflate use the same compression
+// algo but different containers. But it is
+// common enough that clients/servers mix
+// them together (or do a Microsoft, a.k.a.
+// something completely against spec)
+// so instead we use one filter that tries
+// some different offsets given to inflateInit2
+class GZipFilter : public DeflateFilter {
+};
+#endif
+
+#if HAVE_BZIP2
+class Bzip2Filter : public AbstractFilter {
+public:
+ Bzip2Filter()
+ : error_(false), eof_(false) {
+ memset(&stream_, 0, sizeof(stream_));
+ auto ret = BZ2_bzDecompressInit(&stream_, 0, 0);
+ assert(ret == BZ_OK);
+ }
+
+ ~Bzip2Filter() override {
+ BZ2_bzDecompressEnd(&stream_);
+ }
+
+ void write(void const* data, size_t size, bool last) override {
+ if (error_) return;
+ if (eof_ && size) {
+ error();
+ return;
+ }
+
+ stream_.next_in = const_cast<char*>(reinterpret_cast<char const*>(data));
+ stream_.avail_in = static_cast<unsigned>(size);
+
+ while (stream_.avail_in) {
+ char buf[4096];
+ stream_.next_out = buf;
+ stream_.avail_out = sizeof(buf);
+
+ auto ret = BZ2_bzDecompress(&stream_);
+ if (output_) {
+ output_->write(buf, stream_.next_out - buf, ret == BZ_STREAM_END);
+ }
+ if (ret == BZ_DATA_ERROR || ret == BZ_DATA_ERROR_MAGIC) {
+ error();
+ return;
+ }
+ if (ret == BZ_STREAM_END) {
+ eof_ = true;
+ if (stream_.avail_in) error();
+ break;
+ }
+ if (ret != BZ_OK) {
+ assert(false);
+ error();
+ return;
+ }
+ }
+
+ if (last && stream_.avail_in) {
+ error();
+ return;
+ }
+
+ if (last && !eof_) {
+ incomplete();
+ return;
+ }
+ }
+
+private:
+ void error() override {
+ if (error_) return;
+ assert(false);
+ error_ = true;
+ AbstractFilter::error();
+ }
+
+ bz_stream stream_;
+ bool error_;
+ bool eof_;
+};
+#endif
+
+class OutputFilter : public Filter {
+public:
+ void set_output(Filter*) override {
+ assert(false);
+ }
+};
+
+class HexOutput : public OutputFilter {
+public:
+ HexOutput(AttributedText* text)
+ : text_(text) {
+ }
+
+ void write(void const* data, size_t size, bool last) override {
+ buffer_.append(reinterpret_cast<char const*>(data), size);
+ if (last) {
+ HexDump::write(text_, HexDump::ADDRESS | HexDump::CHARS, buffer_.data(),
+ 0, buffer_.size());
+ }
+ }
+
+ void error() override {
+ if (!buffer_.empty()) write("", 0, true); // Write out what we got
+ text_->append("\nDecoding failed, invalid data\n");
+ }
+
+ void incomplete() override {
+ if (!buffer_.empty()) write("", 0, true); // Write out what we got
+ text_->append("\nNeed more data...\n");
+ }
+
+private:
+ AttributedText* const text_;
+ std::string buffer_;
+};
+
+class TextOutput : public OutputFilter {
+public:
+ TextOutput(AttributedText* text)
+ : text_(text) {
+ }
+
+ void write(void const* data, size_t size, bool last) override {
+ if (hex_) {
+ hex_->write(data, size, last);
+ return;
+ }
+ auto d = reinterpret_cast<char const*>(data);
+ if (!buf_.empty()) {
+ buf_.append(d, size);
+ for (size_t i = 0; i < 4; ++i) {
+ if (i >= buf_.size()) break;
+ if (valid_utf8(buf_.data(), buf_.size() - i)) {
+ if (last && i > 0) break;
+ text_->append(buf_.data(), buf_.size() - i);
+ buf_.erase(0, buf_.size() - i);
+ return;
+ }
+ }
+ } else {
+ for (size_t i = 0; i < 4; ++i) {
+ if (i >= size) break;
+ if (valid_utf8(d, size - i)) {
+ if (last && i > 0) break;
+ text_->append(d, size - i);
+ if (i > 0) buf_.append(d + size - i, i);
+ return;
+ }
+ }
+ }
+
+ buf_.assign(text_->text());
+ buf_.append(d, size);
+ text_->clear();
+ hex_.reset(new HexOutput(text_));
+ hex_->write(buf_.data(), buf_.size(), last);
+ buf_.clear();
+ }
+
+ void error() override {
+ text_->append("\nDecoding failed, invalid data\n");
+ }
+
+ void incomplete() override {
+ text_->append("\nNeed more data...\n");
+ }
+
+private:
+ AttributedText* const text_;
+ std::string buf_;
+ std::unique_ptr<HexOutput> hex_;
+};
+
+class StreamOutput : public OutputFilter {
+public:
+ StreamOutput(std::ostream* out)
+ : out_(out) {
+ }
+
+ void write(void const* data, size_t size, bool last) override {
+ out_->write(reinterpret_cast<char const*>(data), size);
+ if (last) out_->flush();
+ }
+
+ void error() override {
+ *out_ << "\nDecoding failed, invalid data\n";
+ }
+
+ void incomplete() override {
+ *out_ << "\nNeed more data...\n";
+ }
+
+private:
+ std::ostream* const out_;
+};
+
+static Filter* match_compress_filter(HeaderTokenIterator* token) {
+#if HAVE_ZLIB
+ if (token->token_equal("deflate")) {
+ return new DeflateFilter();
+ }
+ if (token->token_equal("gzip") || token->token_equal("x-gzip")) {
+ return new GZipFilter();
+ }
+#endif
+#if HAVE_BZIP2
+ if (token->token_equal("bzip2")) {
+ return new Bzip2Filter();
+ }
+#endif
+ return nullptr;
+}
+
+class HttpMatch : public Protocol::Match {
+ static std::string const HTTP;
+
+protected:
+ HttpMatch(Http const* http)
+ : http_(http) {
+ }
+
+ std::string const& name() const override {
+ return HTTP;
+ }
+
+ void full(void const* data, size_t size, AttributedText* text) override {
+ auto iter = http_->header();
+ while (iter->valid()) {
+ text->append(iter->name());
+ text->append(": ");
+ text->append(iter->value());
+ text->append("\n");
+ iter->next();
+ }
+ text->append("\n");
+
+ print_content(data, size, [=](bool print_as_text) -> OutputFilter* {
+ if (print_as_text) return new TextOutput(text);
+ return new HexOutput(text);
+ });
+ }
+
+ bool content(void const* data, size_t size, std::ostream* out) override {
+ print_content(data, size, [=](bool) {
+ return new StreamOutput(out);
+ });
+ return true;
+ }
+
+protected:
+ void set_http(Http* http) {
+ http_ = http;
+ }
+
+private:
+ void print_content(void const* data, size_t size,
+ std::function<OutputFilter*(bool)> const& factory) const {
+ bool print_as_text;
+ auto token = http_->header_tokens("content-type");
+ if (token->valid() && token->token_equal("text")) {
+ print_as_text = true;
+ } else {
+ print_as_text = false;
+ }
+
+ std::deque<std::unique_ptr<Filter>> filters;
+ token = http_->header_tokens("content-encoding");
+ while (token->valid()) {
+ if (!token->token_equal("identity")) {
+ auto filter = match_compress_filter(token.get());
+ if (filter) {
+ filters.emplace_front(filter);
+ } else {
+ print_as_text = false;
+ // If there is a unknown content encoding then the next ones
+ // in the chain can't work, so reset the list
+ filters.clear();
+ }
+ }
+ token->next();
+ }
+
+ token = http_->header_tokens("transfer-encoding");
+ while (token->valid()) {
+ if (token->token_equal("chunked")) {
+ filters.emplace_front(new ChunkedFilter());
+ } else if (!token->token_equal("identity")) {
+ auto filter = match_compress_filter(token.get());
+ if (filter) {
+ filters.emplace_front(filter);
+ } else {
+ print_as_text = false;
+ // If there is a unknown transfer encoding then the next ones
+ // in the chain can't work, so reset the list
+ filters.clear();
+ }
+ }
+ token->next();
+ }
+
+ filters.emplace_back(factory(print_as_text));
+ auto it = filters.begin();
+ auto it2 = it + 1;
+ while (it2 != filters.end()) {
+ (*it)->set_output(it2->get());
+ it = it2++;
+ }
+ filters.front()->write(reinterpret_cast<char const*>(data) + http_->size(),
+ size - http_->size(), true);
+ }
+
+ Http const* http_;
+};
+
+// static
+std::string const HttpMatch::HTTP = "HTTP";
+
+class ResponseMatch : public HttpMatch {
+public:
+ ResponseMatch(std::unique_ptr<HttpResponse>&& resp)
+ : HttpMatch(resp.get()), resp_(std::move(resp)) {
+ }
+
+ void full(void const* data, size_t size, AttributedText* text) override {
+ check_ptr(data, size);
+ text->append(resp_->proto());
+ text->append("/");
+ char tmp[50];
+ snprintf(tmp, sizeof(tmp), "%u.%u %u ", resp_->proto_version().major,
+ resp_->proto_version().minor, resp_->status_code());
+ text->append(tmp);
+ text->append(resp_->status_message());
+ text->append("\n");
+ HttpMatch::full(data, size, text);
+ }
+
+ bool content(void const* data, size_t size, std::ostream* out) override {
+ check_ptr(data, size);
+ return HttpMatch::content(data, size, out);
+ }
+
+private:
+ void check_ptr(void const* data, size_t size) {
+ // This check allows us to use copy == false
+ if (data == resp_->data()) return;
+ resp_.reset(HttpResponse::parse(
+ reinterpret_cast<char const*>(data), size, false));
+ set_http(resp_.get());
+ }
+
+ std::unique_ptr<HttpResponse> resp_;
+};
+
+class RequestMatch : public HttpMatch {
+public:
+ RequestMatch(std::unique_ptr<HttpRequest>&& req)
+ : HttpMatch(req.get()), req_(std::move(req)) {
+ }
+
+ void full(void const* data, size_t size, AttributedText* text) override {
+ check_ptr(data, size);
+ text->append(req_->method());
+ text->append(" ");
+ text->append(req_->url());
+ text->append(" ");
+ text->append(req_->proto());
+ text->append("/");
+ char tmp[50];
+ snprintf(tmp, sizeof(tmp), "%u.%u\n", req_->proto_version().major,
+ req_->proto_version().minor);
+ text->append(tmp);
+ HttpMatch::full(data, size, text);
+ }
+
+ bool content(void const* data, size_t size, std::ostream* out) override {
+ check_ptr(data, size);
+ return HttpMatch::content(data, size, out);
+ }
+
+private:
+ void check_ptr(void const* data, size_t size) {
+ // This check allows us to use copy == false
+ if (data == req_->data()) return;
+ req_.reset(HttpRequest::parse(
+ reinterpret_cast<char const*>(data), size, false));
+ set_http(req_.get());
+ }
+
+ std::unique_ptr<HttpRequest> req_;
+};
+
+class HttpProtocolImpl : public HttpProtocol {
+public:
+ HttpProtocolImpl() {
+ }
+
+ Match* match(void const* data, size_t size) const override {
+ auto resp = std::unique_ptr<HttpResponse>(
+ HttpResponse::parse(reinterpret_cast<char const*>(data), size, false));
+ if (resp && resp->good()) {
+ return new ResponseMatch(std::move(resp));
+ }
+ auto req = std::unique_ptr<HttpRequest>(
+ HttpRequest::parse(reinterpret_cast<char const*>(data), size, false));
+ if (req && req->good()) {
+ return new RequestMatch(std::move(req));
+ }
+ return nullptr;
+ }
+};
+
+} // namespace
+
+// static
+HttpProtocol* HttpProtocol::create() {
+ return new HttpProtocolImpl();
+}