diff options
| -rw-r--r-- | configure.ac | 22 | ||||
| -rw-r--r-- | m4/ax_check_bzlib.m4 | 75 | ||||
| -rw-r--r-- | m4/ax_check_zlib.m4 | 142 | ||||
| -rw-r--r-- | src/.gitignore | 5 | ||||
| -rw-r--r-- | src/Makefile.am | 32 | ||||
| -rw-r--r-- | src/http.cc | 154 | ||||
| -rw-r--r-- | src/http.hh | 18 | ||||
| -rw-r--r-- | src/http_protocol.cc | 625 | ||||
| -rw-r--r-- | src/http_protocol.hh | 13 | ||||
| -rw-r--r-- | src/protocol-main.cc | 119 | ||||
| -rw-r--r-- | src/protocol.hh | 44 | ||||
| -rw-r--r-- | src/protocols.cc | 477 | ||||
| -rw-r--r-- | src/protocols.hh | 51 | ||||
| -rw-r--r-- | src/proxy.cc | 30 | ||||
| -rw-r--r-- | test/test-http.cc | 60 |
15 files changed, 1826 insertions, 41 deletions
diff --git a/configure.ac b/configure.ac index c49c2e3..e649ca1 100644 --- a/configure.ac +++ b/configure.ac @@ -179,6 +179,27 @@ AM_CONDITIONAL([HAVE_GTK],[test "x$have_gtk" = "x1"]) AM_CONDITIONAL([HAVE_GUI],[test "x$have_gtk" = "x1" -o "x$have_qt" = "x1"]) +comp_names='' +# zlib +AX_CHECK_ZLIB([ + comp_names="${comp_names}zlib " + AC_DEFINE([HAVE_ZLIB], [1], [define to 1 if you have zlib]) + ZLIB_CFLAGS="-I${ZLIB_HOME}/include" + ZLIB_LIBS="-L${ZLIB_HOME}/lib -lz" + ]) +AC_SUBST([ZLIB_CFLAGS]) +AC_SUBST([ZLIB_LIBS]) + +# bzip2 +AX_CHECK_BZLIB([ + comp_names="${comp_names}bzip2 " + AC_DEFINE([HAVE_BZIP2], [1], [define to 1 if you have bzip2]) + BZIP2_CFLAGS="-I${BZLIB_HOME}/include" + BZIP2_LIBS="-L${BZLIB_HOME}/lib -lbz2" + ]) +AC_SUBST([BZIP2_CFLAGS]) +AC_SUBST([BZIP2_LIBS]) + AC_ARG_ENABLE([update-mimedb], [AC_HELP_STRING([--disable-update-mimedb], [disable the update-mime-database after install])], [enable_update_mimedb=$enableval], @@ -196,4 +217,5 @@ AC_CONFIG_HEADERS([src/config.h]) AC_OUTPUT([Makefile src/Makefile data/Makefile test/Makefile]) AC_MSG_NOTICE([SSL library used: $ssl_name]) +AC_MSG_NOTICE([Compression libs: $comp_names]) AC_MSG_NOTICE([GUI toolkit: $gui_name]) diff --git a/m4/ax_check_bzlib.m4 b/m4/ax_check_bzlib.m4 new file mode 100644 index 0000000..8c79c68 --- /dev/null +++ b/m4/ax_check_bzlib.m4 @@ -0,0 +1,75 @@ +AC_DEFUN([AX_CHECK_BZLIB], +# +# Handle user hints +# +[AC_MSG_CHECKING(if bzip2 is wanted) +bzlib_places="/usr/local /usr /opt/local /sw" +AC_ARG_WITH([bzip2], +[ --with-bzip2=DIR root directory path of bzip2 installation @<:@defaults to + /usr/local or /usr if not found in /usr/local@:>@ + --without-bzip2 to disable bzip2 usage completely], +[if test "$withval" != no ; then + AC_MSG_RESULT(yes) + if test -d "$withval" + then + bzlib_places="$withval $bzlib_places" + else + AC_MSG_WARN([Sorry, $withval does not exist, checking usual places]) + fi +else + bzlib_places= + AC_MSG_RESULT(no) +fi], +[AC_MSG_RESULT(yes)]) + +# +# Locate bzip2, if wanted +# +if test -n "${bzlib_places}" +then + # check the user supplied or any other more or less 'standard' place: + # Most UNIX systems : /usr/local and /usr + # MacPorts / Fink on OSX : /opt/local respectively /sw + for bzlib_HOME in ${bzlib_places} ; do + if test -f "${bzlib_HOME}/include/bzlib.h"; then break; fi + bzlib_HOME="" + done + + bzlib_OLD_LDFLAGS=$LDFLAGS + bzlib_OLD_CPPFLAGS=$CPPFLAGS + if test -n "${bzlib_HOME}"; then + LDFLAGS="$LDFLAGS -L${bzlib_HOME}/lib" + CPPFLAGS="$CPPFLAGS -I${bzlib_HOME}/include" + fi + AC_LANG_SAVE + AC_LANG_C + AC_CHECK_LIB([bz2], [BZ2_bzDecompressInit], [bzlib_cv_libbz2=yes], [bzlib_cv_libbz2=no]) + AC_CHECK_HEADER([bzlib.h], [bzlib_cv_bzlib_h=yes], [bzlib_cv_bzlib_h=no]) + AC_LANG_RESTORE + if test "$bzlib_cv_libbz2" = "yes" && test "$bzlib_cv_bzlib_h" = "yes" + then + # + # If both library and header were found, action-if-found + # + m4_ifblank([$1],[ + CPPFLAGS="$CPPFLAGS -I${bzlib_HOME}/include" + LDFLAGS="$LDFLAGS -L${bzlib_HOME}/lib" + LIBS="-lbz2 $LIBS" + AC_DEFINE([HAVE_BZLIB], [1], + [Define to 1 if you have `bz2' library (-lbz2)]) + ],[ + # Restore variables + LDFLAGS="$bzlib_OLD_LDFLAGS" + CPPFLAGS="$bzlib_OLD_CPPFLAGS" + $1 + ]) + else + # + # If either header or library was not found, action-if-not-found + # + m4_default([$2],[ + AC_MSG_ERROR([either specify a valid bzip2 installation with --with-bzip2=DIR or disable bzip2 usage with --without-bzip2]) + ]) + fi +fi +]) diff --git a/m4/ax_check_zlib.m4 b/m4/ax_check_zlib.m4 new file mode 100644 index 0000000..7a8a800 --- /dev/null +++ b/m4/ax_check_zlib.m4 @@ -0,0 +1,142 @@ +# =========================================================================== +# https://www.gnu.org/software/autoconf-archive/ax_check_zlib.html +# =========================================================================== +# +# SYNOPSIS +# +# AX_CHECK_ZLIB([action-if-found], [action-if-not-found]) +# +# DESCRIPTION +# +# This macro searches for an installed zlib library. If nothing was +# specified when calling configure, it searches first in /usr/local and +# then in /usr, /opt/local and /sw. If the --with-zlib=DIR is specified, +# it will try to find it in DIR/include/zlib.h and DIR/lib/libz.a. If +# --without-zlib is specified, the library is not searched at all. +# +# If either the header file (zlib.h) or the library (libz) is not found, +# shell commands 'action-if-not-found' is run. If 'action-if-not-found' is +# not specified, the configuration exits on error, asking for a valid zlib +# installation directory or --without-zlib. +# +# If both header file and library are found, shell commands +# 'action-if-found' is run. If 'action-if-found' is not specified, the +# default action appends '-I${ZLIB_HOME}/include' to CPFLAGS, appends +# '-L$ZLIB_HOME}/lib' to LDFLAGS, prepends '-lz' to LIBS, and calls +# AC_DEFINE(HAVE_LIBZ). You should use autoheader to include a definition +# for this symbol in a config.h file. Sample usage in a C/C++ source is as +# follows: +# +# #ifdef HAVE_LIBZ +# #include <zlib.h> +# #endif /* HAVE_LIBZ */ +# +# LICENSE +# +# Copyright (c) 2008 Loic Dachary <loic@senga.org> +# Copyright (c) 2010 Bastien Chevreux <bach@chevreux.org> +# +# This program is free software; you can redistribute it and/or modify it +# under the terms of the GNU General Public License as published by the +# Free Software Foundation; either version 2 of the License, or (at your +# option) any later version. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General +# Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with this program. If not, see <https://www.gnu.org/licenses/>. +# +# As a special exception, the respective Autoconf Macro's copyright owner +# gives unlimited permission to copy, distribute and modify the configure +# scripts that are the output of Autoconf when processing the Macro. You +# need not follow the terms of the GNU General Public License when using +# or distributing such scripts, even though portions of the text of the +# Macro appear in them. The GNU General Public License (GPL) does govern +# all other use of the material that constitutes the Autoconf Macro. +# +# This special exception to the GPL applies to versions of the Autoconf +# Macro released by the Autoconf Archive. When you make and distribute a +# modified version of the Autoconf Macro, you may extend this special +# exception to the GPL to apply to your modified version as well. + +#serial 15 + +AU_ALIAS([CHECK_ZLIB], [AX_CHECK_ZLIB]) +AC_DEFUN([AX_CHECK_ZLIB], +# +# Handle user hints +# +[AC_MSG_CHECKING(if zlib is wanted) +zlib_places="/usr/local /usr /opt/local /sw" +AC_ARG_WITH([zlib], +[ --with-zlib=DIR root directory path of zlib installation @<:@defaults to + /usr/local or /usr if not found in /usr/local@:>@ + --without-zlib to disable zlib usage completely], +[if test "$withval" != no ; then + AC_MSG_RESULT(yes) + if test -d "$withval" + then + zlib_places="$withval $zlib_places" + else + AC_MSG_WARN([Sorry, $withval does not exist, checking usual places]) + fi +else + zlib_places= + AC_MSG_RESULT(no) +fi], +[AC_MSG_RESULT(yes)]) + +# +# Locate zlib, if wanted +# +if test -n "${zlib_places}" +then + # check the user supplied or any other more or less 'standard' place: + # Most UNIX systems : /usr/local and /usr + # MacPorts / Fink on OSX : /opt/local respectively /sw + for ZLIB_HOME in ${zlib_places} ; do + if test -f "${ZLIB_HOME}/include/zlib.h"; then break; fi + ZLIB_HOME="" + done + + ZLIB_OLD_LDFLAGS=$LDFLAGS + ZLIB_OLD_CPPFLAGS=$CPPFLAGS + if test -n "${ZLIB_HOME}"; then + LDFLAGS="$LDFLAGS -L${ZLIB_HOME}/lib" + CPPFLAGS="$CPPFLAGS -I${ZLIB_HOME}/include" + fi + AC_LANG_SAVE + AC_LANG_C + AC_CHECK_LIB([z], [inflateEnd], [zlib_cv_libz=yes], [zlib_cv_libz=no]) + AC_CHECK_HEADER([zlib.h], [zlib_cv_zlib_h=yes], [zlib_cv_zlib_h=no]) + AC_LANG_RESTORE + if test "$zlib_cv_libz" = "yes" && test "$zlib_cv_zlib_h" = "yes" + then + # + # If both library and header were found, action-if-found + # + m4_ifblank([$1],[ + CPPFLAGS="$CPPFLAGS -I${ZLIB_HOME}/include" + LDFLAGS="$LDFLAGS -L${ZLIB_HOME}/lib" + LIBS="-lz $LIBS" + AC_DEFINE([HAVE_ZLIB], [1], + [Define to 1 if you have `z' library (-lz)]) + ],[ + # Restore variables + LDFLAGS="$ZLIB_OLD_LDFLAGS" + CPPFLAGS="$ZLIB_OLD_CPPFLAGS" + $1 + ]) + else + # + # If either header or library was not found, action-if-not-found + # + m4_default([$2],[ + AC_MSG_ERROR([either specify a valid zlib installation with --with-zlib=DIR or disable zlib usage with --without-zlib]) + ]) + fi +fi +]) diff --git a/src/.gitignore b/src/.gitignore index 5c67110..4e68b1a 100644 --- a/src/.gitignore +++ b/src/.gitignore @@ -1,11 +1,14 @@ /config.h /config.h.in~ /libattrstr.a +/libhexdump.a +/libmitm.a /libmonitor.a /libmonitor_gui.a -/libmitm.a +/libprotocol.a /libproxy.a /libtp.a +/protocol /tp /tp-genca /tp-monitor diff --git a/src/Makefile.am b/src/Makefile.am index 9622262..3ace3ea 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -5,8 +5,9 @@ AM_CXXFLAGS = @DEFINES@ # Remove ar: `u' modifier ignored since `D' is the default (see `U') ARFLAGS = cr -bin_PROGRAMS = tp tp-monitor -noinst_LIBRARIES = libtp.a libproxy.a libmonitor.a libattrstr.a +bin_PROGRAMS = tp tp-monitor protocol +noinst_LIBRARIES = libtp.a libproxy.a libmonitor.a libattrstr.a libprotocol.a \ + libhexdump.a if HAVE_SSL bin_PROGRAMS += tp-genca noinst_LIBRARIES += libmitm.a @@ -55,15 +56,20 @@ tp_genca_SOURCES = genca.cc logger.cc logger.hh tp_genca_LDADD = libmitm.a libtp.a @SSL_LIBS@ tp_genca_CXXFLAGS = $(AM_CXXFLAGS) -DVERSION='"@VERSION@"' -libmonitor_a_SOURCES = monitor.cc monitor.hh gui_hexdump.cc gui_hexdump.hh +libhexdump_a_SOURCES = gui_hexdump.cc gui_hexdump.hh + +libmonitor_a_SOURCES = monitor.cc monitor.hh libmonitor_a_CXXFLAGS = $(AM_CXXFLAGS) -DVERSION='"@VERSION@"' @THREAD_CFLAGS@ -tp_monitor_SOURCES = monitor-cmd.cc ios_save.hh resolver.cc resolver.hh \ - gui_plainattrtext.cc gui_plainattrtext.hh \ - gui_attrtext.cc gui_attrtext.hh -tp_monitor_LDADD = libmonitor.a libtp.a @THREAD_LIBS@ +tp_monitor_SOURCES = monitor-cmd.cc ios_save.hh resolver.cc resolver.hh +tp_monitor_LDADD = libmonitor.a libhexdump.a libattrstr.a libtp.a @THREAD_LIBS@ tp_monitor_CXXFLAGS = $(AM_CXXFLAGS) -DVERSION='"@VERSION@"' @THREAD_CFLAGS@ +libprotocol_a_SOURCES = protocols.hh protocols.cc protocol.hh \ + http_protocol.hh http_protocol.cc +libprotocol_a_CXXFLAGS = $(AM_CXXFLAGS) -DVERSION='"@VERSION@"' \ + @THREAD_CFLAGS@ @ZLIB_CFLAGS@ @BZIP2_CFLAGS@ + libmonitor_gui_a_SOURCES = monitor-gui.cc gui_config.cc gui_config.hh \ gui_about.hh gui_file.hh gui_main.hh \ gui_window.hh gui_formapply.hh gui_menu.hh \ @@ -75,11 +81,18 @@ libmonitor_gui_a_CXXFLAGS = $(AM_CXXFLAGS) -DVERSION='"@VERSION@"' \ libattrstr_a_SOURCES = gui_attrtext.cc gui_attrtext.hh \ gui_htmlattrtext.cc gui_htmlattrtext.hh \ + gui_plainattrtext.cc gui_plainattrtext.hh \ observers.hh +protocol_SOURCES = protocol-main.cc +protocol_CXXFLAGS = $(AM_CXXFLAGS) -DVERSION='"@VERSION@"' @THREAD_CFLAGS@ +protocol_LDADD = libprotocol.a libhexdump.a libattrstr.a libtp.a @THREAD_LIBS@ \ + @ZLIB_LIBS@ @BZIP2_LIBS@ + tp_monitor_gtk_SOURCES = gui_gtk.cc tp_monitor_gtk_LDADD = libmonitor_gui.a libattrstr.a libproxy.a libmonitor.a \ - libtp.a @GTK_LIBS@ @THREAD_LIBS@ @PCAP_LIBS@ + libhexdump.a libtp.a @GTK_LIBS@ @THREAD_LIBS@ \ + @PCAP_LIBS@ if HAVE_SSL tp_monitor_gtk_LDADD += libmitm.a @SSL_LIBS@ endif @@ -88,7 +101,8 @@ tp_monitor_gtk_CXXFLAGS = $(AM_CXXFLAGS) -DVERSION='"@VERSION@"' \ tp_monitor_qt_SOURCES = gui_qt.cc tp_monitor_qt_LDADD = libmonitor_gui.a libattrstr.a libmonitor.a libproxy.a \ - libtp.a @QT_LIBS@ @THREAD_LIBS@ @PCAP_LIBS@ + libhexdump.a libtp.a @QT_LIBS@ @THREAD_LIBS@ \ + @PCAP_LIBS@ if HAVE_SSL tp_monitor_qt_LDADD += libmitm.a @SSL_LIBS@ endif diff --git a/src/http.cc b/src/http.cc index 26911cb..3b8a67b 100644 --- a/src/http.cc +++ b/src/http.cc @@ -34,6 +34,28 @@ inline char upper_ascii(char c) { return (c >= 'a' && c <= 'z') ? (c - 'a' + 'A') : c; } +inline bool is_lws(char c) { + return c == ' ' || c == '\t'; +} + +inline bool is_char(char c) { + return !(c & 0x80); +} + +inline bool is_ctl(char c) { + return c < ' ' || c == 0x7f; +} + +inline bool is_separator(char c) { + return is_lws(c) || c == '(' || c == ')' || c == '<' || c == '>' || c == '@' + || c == ',' || c == ';' || c == ':' || c == '\\' || c == '\"' || c == '/' + || c == '[' || c == ']' || c == '?' || c == '=' || c == '{' || c == '}'; +} + +inline bool is_token(char c) { + return is_char(c) && !is_ctl(c) && !is_separator(c); +} + bool lower_equal(char const* data, size_t start, size_t end, std::string const& str) { assert(start <= end); @@ -134,6 +156,122 @@ private: std::string const filter_; }; +class HeaderTokenIteratorImpl : public HeaderTokenIterator { +public: + HeaderTokenIteratorImpl(std::unique_ptr<HeaderIterator>&& header) + : header_(std::move(header)), start_(0), middle_(0), end_(0) { + check_token(); + } + + bool valid() const override { + return header_->valid(); + } + + std::string token() const override { + return make_string(header_->value().data(), start_, middle_); + } + + bool token_equal(std::string const& token) const override { + return lower_equal(header_->value().data(), start_, middle_, token); + } + + void next() override { + start_ = end_; + check_token(); + } + +private: + static size_t skip_lws(std::string const& str, size_t pos) { + while (pos < str.size() && is_lws(str[pos])) ++pos; + return pos; + } + + static size_t skip_token(std::string const& str, size_t pos) { + assert(is_token(str[pos])); + ++pos; + while (pos < str.size() && is_token(str[pos])) ++pos; + return pos; + } + + static size_t skip_quoted(std::string const& str, size_t pos) { + assert(str[pos] == '"'); + ++pos; + while (pos < str.size()) { + if (str[pos] == '\\') { + pos += 2; + } else if (str[pos] == '\"') { + ++pos; + break; + } else { + ++pos; + } + } + return pos; + } + + void check_token() { + while (true) { + if (!header_->valid()) return; + auto const& value = header_->value(); + start_ = skip_lws(value, start_); + if (start_ >= value.size()) { + header_->next(); + start_ = 0; + continue; + } + if (!is_token(value[start_])) { + if (value[start_] != ';') { + ++start_; + while (start_ < value.size() + && !(is_lws(value[start_]) || value[start_] == ',' + || value[start_] == ';')) { + ++start_; + } + if (start_ < value.size() && value[start_] != ';') { + continue; + } + } + // This will cause us to loop again after paramters + // are read + middle_ = start_; + } else { + middle_ = skip_token(value, start_); + } + end_ = middle_; + while (true) { + end_ = skip_lws(value, end_); + if (end_ == value.size() || value[end_] != ';') break; + end_ = skip_lws(value, end_ + 1); + if (!is_token(value[end_])) { + while (end_ < value.size() && !is_separator(value[end_])) ++end_; + continue; + } + end_ = skip_token(value, end_); + end_ = skip_lws(value, end_); + if (end_ == value.size() || value[end_] != '=') break; + end_ = skip_lws(value, end_ + 1); + if (end_ < value.size() && value[end_] == '"') { + end_ = skip_quoted(value, end_); + } else { + if (!is_token(value[end_])) { + while (end_ < value.size() && !is_separator(value[end_])) ++end_; + continue; + } + end_ = skip_token(value, end_); + } + } + if (end_ < value.size() && value[end_] == ',') ++end_; + if (start_ < middle_) return; + start_ = end_; + } + } + + std::unique_ptr<HeaderIterator> header_; + size_t start_; + size_t middle_; + size_t end_; +}; + class AbstractHttp : public virtual Http { public: AbstractHttp(char const* data, size_t size) @@ -166,6 +304,12 @@ public: new FilterHeaderIteratorImpl(data_, &headers_, name)); } + std::unique_ptr<HeaderTokenIterator> header_tokens(std::string const& name) + const override { + return std::unique_ptr<HeaderTokenIterator>( + new HeaderTokenIteratorImpl(header(name))); + } + size_t size() const override { return content_start_; } @@ -209,7 +353,7 @@ protected: data_ = data; } - char const* data() const { + char const* data() const override { return data_; } @@ -225,10 +369,6 @@ protected: return std::string::npos; } - static bool is_lws(char c) { - return c == ' ' || c == '\t'; - } - size_t skip_lws(size_t start, size_t end) const { assert(start <= end); while (start < end && is_lws(data_[start])) ++start; @@ -601,7 +741,7 @@ HttpResponse* HttpResponse::parse(char const* data, size_t len, bool copy) { auto ret = std::unique_ptr<UniqueHttpResponse>( new UniqueHttpResponse(data, len)); if (ret->parse() == INCOMPLETE) return nullptr; - if (copy) ret->copy(); + if (copy && ret->good()) ret->copy(); return ret.release(); } @@ -633,7 +773,7 @@ HttpRequest* HttpRequest::parse(char const* data, size_t len, bool copy) { auto ret = std::unique_ptr<UniqueHttpRequest>( new UniqueHttpRequest(data, len)); if (ret->parse() == INCOMPLETE) return nullptr; - if (copy) ret->copy(); + if (copy && ret->good()) ret->copy(); return ret.release(); } diff --git a/src/http.hh b/src/http.hh index 9a084d2..2a4353d 100644 --- a/src/http.hh +++ b/src/http.hh @@ -42,6 +42,20 @@ protected: HeaderIterator(HeaderIterator const&) = delete; }; +class HeaderTokenIterator { +public: + virtual ~HeaderTokenIterator() {} + + virtual bool valid() const = 0; + virtual std::string token() const = 0; + virtual bool token_equal(std::string const& name) const = 0; + virtual void next() = 0; + +protected: + HeaderTokenIterator() {} + HeaderTokenIterator(HeaderTokenIterator const&) = delete; +}; + class Http { public: virtual ~Http() {} @@ -55,7 +69,9 @@ public: virtual std::unique_ptr<HeaderIterator> header( std::string const& name) const = 0; std::string first_header(std::string const& name) const; - + virtual std::unique_ptr<HeaderTokenIterator> header_tokens( + std::string const& name) const = 0; + virtual char const* data() const = 0; virtual size_t size() const = 0; protected: diff --git a/src/http_protocol.cc b/src/http_protocol.cc new file mode 100644 index 0000000..5915f77 --- /dev/null +++ b/src/http_protocol.cc @@ -0,0 +1,625 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include <deque> +#include <memory> +#include <string.h> + +#if HAVE_ZLIB +#include <zlib.h> +#endif +#if HAVE_BZIP2 +#include <bzlib.h> +#endif + +#include "chunked.hh" +#include "gui_attrtext.hh" +#include "gui_hexdump.hh" +#include "http.hh" +#include "http_protocol.hh" +#include "utf.hh" + +namespace { + +class Filter { +public: + virtual ~Filter() {} + + virtual void set_output(Filter* filter) = 0; + virtual void write(void const* data, size_t size, bool last) = 0; + virtual void error() = 0; + virtual void incomplete() = 0; + +protected: + Filter() {} + Filter(Filter const&) = delete; + Filter& operator=(Filter const&) = delete; +}; + +class AbstractFilter : public Filter { +public: + void set_output(Filter* filter) override { + output_ = filter; + } + + void error() override { + if (output_) output_->error(); + } + + void incomplete() override { + if (output_) output_->incomplete(); + } + +protected: + AbstractFilter() + : output_(nullptr) { + } + + Filter* output_; +}; + +class ChunkedFilter : public AbstractFilter { +public: + ChunkedFilter() + : chunked_(Chunked::create(std::bind(&ChunkedFilter::output, this, + std::placeholders::_1, + std::placeholders::_2))) { + } + + void write(void const* data, size_t size, bool last) override { + if (!chunked_->good()) return; + auto ptr = reinterpret_cast<char const*>(data); + while (size > 0) { + size_t used; + if (buffer_.empty()) { + used = chunked_->add(ptr, size); + } else { + size_t old = buffer_.size(); + buffer_.append(ptr, size); + used = chunked_->add(buffer_.data(), buffer_.size()); + if (used < old) { + buffer_.erase(0, used); + size = 0; + break; + } + buffer_.clear(); + used -= old; + } + if (used == 0) break; + ptr += used; + size -= used; + } + if (size > 0) { + buffer_.append(ptr, size); + } + if (!chunked_->good()) { + error(); + } else if (last) { + if (chunked_->eof()) { + if (!buffer_.empty()) { + error(); + } + } else { + incomplete(); + } + } + } + + +private: + void output(void const* data, size_t size) { + if (output_) output_->write(data, size, chunked_->eof()); + } + + std::string buffer_; + std::unique_ptr<Chunked> chunked_; +}; + +#if HAVE_ZLIB +class DeflateFilter : public AbstractFilter { +public: + DeflateFilter() + : error_(false), first_try_(true), eof_(false) { + memset(&stream_, 0, sizeof(stream_)); + auto ret = inflateInit2(&stream_, 15 + 32); + assert(ret == Z_OK); + } + + ~DeflateFilter() override { + inflateEnd(&stream_); + } + + void write(void const* data, size_t size, bool last) override { + if (error_) return; + stream_.next_in = const_cast<Bytef*>(reinterpret_cast<Bytef const*>(data)); + stream_.avail_in = static_cast<uInt>(size); + while (stream_.avail_in) { + auto status = inflate(); + if (status == Z_BUF_ERROR) { + break; + } else if (status == Z_STREAM_END) { + eof_ = true; + break; + } else if (status == Z_NEED_DICT) { + error(); + return; + } else if (status != Z_OK) { + if (first_try_) { + first_try_ = false; + if (inflateReset2(&stream_, -15) == Z_OK) { + stream_.next_in = const_cast<Bytef*>( + reinterpret_cast<Bytef const*>(data)); + stream_.avail_in = static_cast<uInt>(size); + continue; + } + } + error(); + return; + } + } + + if ((last || eof_) && stream_.avail_in) { + error(); + return; + } + + if (last && !eof_) { + incomplete(); + return; + } + } + +private: + int inflate() { + Bytef buf[8196]; + while (true) { + stream_.next_out = buf; + stream_.avail_out = sizeof(buf); + + auto status = ::inflate(&stream_, Z_NO_FLUSH); + auto used = sizeof(buf) - stream_.avail_out; + if (used) { + first_try_ = false; + if (output_) output_->write(buf, used, eof_ || status == Z_STREAM_END); + } + if (status != Z_OK) { + return status; + } + } + } + + void error() override { + if (error_) return; + error_ = true; + AbstractFilter::error(); + } + + struct z_stream_s stream_; + bool error_; + bool first_try_; + bool eof_; +}; + +// gzip and deflate use the same compression +// algo but different containers. But it is +// common enough that clients/servers mix +// them together (or do a Microsoft, a.k.a. +// something completely against spec) +// so instead we use one filter that tries +// some different offsets given to inflateInit2 +class GZipFilter : public DeflateFilter { +}; +#endif + +#if HAVE_BZIP2 +class Bzip2Filter : public AbstractFilter { +public: + Bzip2Filter() + : error_(false), eof_(false) { + memset(&stream_, 0, sizeof(stream_)); + auto ret = BZ2_bzDecompressInit(&stream_, 0, 0); + assert(ret == BZ_OK); + } + + ~Bzip2Filter() override { + BZ2_bzDecompressEnd(&stream_); + } + + void write(void const* data, size_t size, bool last) override { + if (error_) return; + if (eof_ && size) { + error(); + return; + } + + stream_.next_in = const_cast<char*>(reinterpret_cast<char const*>(data)); + stream_.avail_in = static_cast<unsigned>(size); + + while (stream_.avail_in) { + char buf[4096]; + stream_.next_out = buf; + stream_.avail_out = sizeof(buf); + + auto ret = BZ2_bzDecompress(&stream_); + if (output_) { + output_->write(buf, stream_.next_out - buf, ret == BZ_STREAM_END); + } + if (ret == BZ_DATA_ERROR || ret == BZ_DATA_ERROR_MAGIC) { + error(); + return; + } + if (ret == BZ_STREAM_END) { + eof_ = true; + if (stream_.avail_in) error(); + break; + } + if (ret != BZ_OK) { + assert(false); + error(); + return; + } + } + + if (last && stream_.avail_in) { + error(); + return; + } + + if (last && !eof_) { + incomplete(); + return; + } + } + +private: + void error() override { + if (error_) return; + assert(false); + error_ = true; + AbstractFilter::error(); + } + + bz_stream stream_; + bool error_; + bool eof_; +}; +#endif + +class OutputFilter : public Filter { +public: + void set_output(Filter*) override { + assert(false); + } +}; + +class HexOutput : public OutputFilter { +public: + HexOutput(AttributedText* text) + : text_(text) { + } + + void write(void const* data, size_t size, bool last) override { + buffer_.append(reinterpret_cast<char const*>(data), size); + if (last) { + HexDump::write(text_, HexDump::ADDRESS | HexDump::CHARS, buffer_.data(), + 0, buffer_.size()); + } + } + + void error() override { + if (!buffer_.empty()) write("", 0, true); // Write out what we got + text_->append("\nDecoding failed, invalid data\n"); + } + + void incomplete() override { + if (!buffer_.empty()) write("", 0, true); // Write out what we got + text_->append("\nNeed more data...\n"); + } + +private: + AttributedText* const text_; + std::string buffer_; +}; + +class TextOutput : public OutputFilter { +public: + TextOutput(AttributedText* text) + : text_(text) { + } + + void write(void const* data, size_t size, bool last) override { + if (hex_) { + hex_->write(data, size, last); + return; + } + auto d = reinterpret_cast<char const*>(data); + if (!buf_.empty()) { + buf_.append(d, size); + for (size_t i = 0; i < 4; ++i) { + if (i >= buf_.size()) break; + if (valid_utf8(buf_.data(), buf_.size() - i)) { + if (last && i > 0) break; + text_->append(buf_.data(), buf_.size() - i); + buf_.erase(0, buf_.size() - i); + return; + } + } + } else { + for (size_t i = 0; i < 4; ++i) { + if (i >= size) break; + if (valid_utf8(d, size - i)) { + if (last && i > 0) break; + text_->append(d, size - i); + if (i > 0) buf_.append(d + size - i, i); + return; + } + } + } + + buf_.assign(text_->text()); + buf_.append(d, size); + text_->clear(); + hex_.reset(new HexOutput(text_)); + hex_->write(buf_.data(), buf_.size(), last); + buf_.clear(); + } + + void error() override { + text_->append("\nDecoding failed, invalid data\n"); + } + + void incomplete() override { + text_->append("\nNeed more data...\n"); + } + +private: + AttributedText* const text_; + std::string buf_; + std::unique_ptr<HexOutput> hex_; +}; + +class StreamOutput : public OutputFilter { +public: + StreamOutput(std::ostream* out) + : out_(out) { + } + + void write(void const* data, size_t size, bool last) override { + out_->write(reinterpret_cast<char const*>(data), size); + if (last) out_->flush(); + } + + void error() override { + *out_ << "\nDecoding failed, invalid data\n"; + } + + void incomplete() override { + *out_ << "\nNeed more data...\n"; + } + +private: + std::ostream* const out_; +}; + +static Filter* match_compress_filter(HeaderTokenIterator* token) { +#if HAVE_ZLIB + if (token->token_equal("deflate")) { + return new DeflateFilter(); + } + if (token->token_equal("gzip") || token->token_equal("x-gzip")) { + return new GZipFilter(); + } +#endif +#if HAVE_BZIP2 + if (token->token_equal("bzip2")) { + return new Bzip2Filter(); + } +#endif + return nullptr; +} + +class HttpMatch : public Protocol::Match { + static std::string const HTTP; + +protected: + HttpMatch(Http const* http) + : http_(http) { + } + + std::string const& name() const override { + return HTTP; + } + + void full(void const* data, size_t size, AttributedText* text) override { + auto iter = http_->header(); + while (iter->valid()) { + text->append(iter->name()); + text->append(": "); + text->append(iter->value()); + text->append("\n"); + iter->next(); + } + text->append("\n"); + + print_content(data, size, [=](bool print_as_text) -> OutputFilter* { + if (print_as_text) return new TextOutput(text); + return new HexOutput(text); + }); + } + + bool content(void const* data, size_t size, std::ostream* out) override { + print_content(data, size, [=](bool) { + return new StreamOutput(out); + }); + return true; + } + +protected: + void set_http(Http* http) { + http_ = http; + } + +private: + void print_content(void const* data, size_t size, + std::function<OutputFilter*(bool)> const& factory) const { + bool print_as_text; + auto token = http_->header_tokens("content-type"); + if (token->valid() && token->token_equal("text")) { + print_as_text = true; + } else { + print_as_text = false; + } + + std::deque<std::unique_ptr<Filter>> filters; + token = http_->header_tokens("content-encoding"); + while (token->valid()) { + if (!token->token_equal("identity")) { + auto filter = match_compress_filter(token.get()); + if (filter) { + filters.emplace_front(filter); + } else { + print_as_text = false; + // If there is a unknown content encoding then the next ones + // in the chain can't work, so reset the list + filters.clear(); + } + } + token->next(); + } + + token = http_->header_tokens("transfer-encoding"); + while (token->valid()) { + if (token->token_equal("chunked")) { + filters.emplace_front(new ChunkedFilter()); + } else if (!token->token_equal("identity")) { + auto filter = match_compress_filter(token.get()); + if (filter) { + filters.emplace_front(filter); + } else { + print_as_text = false; + // If there is a unknown transfer encoding then the next ones + // in the chain can't work, so reset the list + filters.clear(); + } + } + token->next(); + } + + filters.emplace_back(factory(print_as_text)); + auto it = filters.begin(); + auto it2 = it + 1; + while (it2 != filters.end()) { + (*it)->set_output(it2->get()); + it = it2++; + } + filters.front()->write(reinterpret_cast<char const*>(data) + http_->size(), + size - http_->size(), true); + } + + Http const* http_; +}; + +// static +std::string const HttpMatch::HTTP = "HTTP"; + +class ResponseMatch : public HttpMatch { +public: + ResponseMatch(std::unique_ptr<HttpResponse>&& resp) + : HttpMatch(resp.get()), resp_(std::move(resp)) { + } + + void full(void const* data, size_t size, AttributedText* text) override { + check_ptr(data, size); + text->append(resp_->proto()); + text->append("/"); + char tmp[50]; + snprintf(tmp, sizeof(tmp), "%u.%u %u ", resp_->proto_version().major, + resp_->proto_version().minor, resp_->status_code()); + text->append(tmp); + text->append(resp_->status_message()); + text->append("\n"); + HttpMatch::full(data, size, text); + } + + bool content(void const* data, size_t size, std::ostream* out) override { + check_ptr(data, size); + return HttpMatch::content(data, size, out); + } + +private: + void check_ptr(void const* data, size_t size) { + // This check allows us to use copy == false + if (data == resp_->data()) return; + resp_.reset(HttpResponse::parse( + reinterpret_cast<char const*>(data), size, false)); + set_http(resp_.get()); + } + + std::unique_ptr<HttpResponse> resp_; +}; + +class RequestMatch : public HttpMatch { +public: + RequestMatch(std::unique_ptr<HttpRequest>&& req) + : HttpMatch(req.get()), req_(std::move(req)) { + } + + void full(void const* data, size_t size, AttributedText* text) override { + check_ptr(data, size); + text->append(req_->method()); + text->append(" "); + text->append(req_->url()); + text->append(" "); + text->append(req_->proto()); + text->append("/"); + char tmp[50]; + snprintf(tmp, sizeof(tmp), "%u.%u\n", req_->proto_version().major, + req_->proto_version().minor); + text->append(tmp); + HttpMatch::full(data, size, text); + } + + bool content(void const* data, size_t size, std::ostream* out) override { + check_ptr(data, size); + return HttpMatch::content(data, size, out); + } + +private: + void check_ptr(void const* data, size_t size) { + // This check allows us to use copy == false + if (data == req_->data()) return; + req_.reset(HttpRequest::parse( + reinterpret_cast<char const*>(data), size, false)); + set_http(req_.get()); + } + + std::unique_ptr<HttpRequest> req_; +}; + +class HttpProtocolImpl : public HttpProtocol { +public: + HttpProtocolImpl() { + } + + Match* match(void const* data, size_t size) const override { + auto resp = std::unique_ptr<HttpResponse>( + HttpResponse::parse(reinterpret_cast<char const*>(data), size, false)); + if (resp && resp->good()) { + return new ResponseMatch(std::move(resp)); + } + auto req = std::unique_ptr<HttpRequest>( + HttpRequest::parse(reinterpret_cast<char const*>(data), size, false)); + if (req && req->good()) { + return new RequestMatch(std::move(req)); + } + return nullptr; + } +}; + +} // namespace + +// static +HttpProtocol* HttpProtocol::create() { + return new HttpProtocolImpl(); +} diff --git a/src/http_protocol.hh b/src/http_protocol.hh new file mode 100644 index 0000000..f4a2b11 --- /dev/null +++ b/src/http_protocol.hh @@ -0,0 +1,13 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#ifndef HTTP_PROTOCOL_HH +#define HTTP_PROTOCOL_HH + +#include "protocol.hh" + +class HttpProtocol : public virtual Protocol { +public: + static HttpProtocol* create(); +}; + +#endif // HTTP_PROTOCOL_HH diff --git a/src/protocol-main.cc b/src/protocol-main.cc new file mode 100644 index 0000000..5f22f86 --- /dev/null +++ b/src/protocol-main.cc @@ -0,0 +1,119 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include <fstream> +#include <iostream> +#include <unistd.h> + +#include "args.hh" +#include "gui_attrtext.hh" +#include "gui_htmlattrtext.hh" +#include "gui_plainattrtext.hh" +#include "looper.hh" +#include "protocols.hh" + +namespace { + +std::ostream* g_out; +bool g_use_html; + +class Receiver : public Protocols::Listener { +public: + Receiver(Looper* looper) + : looper_(looper) { + } + + void text(Protocols*, size_t, std::string const&, + std::unique_ptr<AttributedText>&& text) { + if (g_out) { + if (g_use_html) { + *g_out << "<html><body>" + << static_cast<HtmlAttributedText*>(text.get())->html() + << "</body></html>" << std::endl; + } else { + *g_out << text->text() << std::endl; + } + } + looper_->quit(); + } + + void content(Protocols*, size_t, std::string const&, std::ostream*) { + looper_->quit(); + } + +private: + Looper* looper_; +}; + +bool run(std::istream& in, std::ostream& out, bool html, bool content) { + g_use_html = html; + std::unique_ptr<Looper> looper(Looper::create()); + std::unique_ptr<Receiver> receiver(new Receiver(looper.get())); + std::unique_ptr<Protocols> protocols( + Protocols::create(1, 65536, 1, looper.get(), receiver.get())); + char buffer[8192]; + std::string data; + in.read(buffer, 8192); + data.append(buffer, in.gcount()); + protocols->add(42, data.data(), data.size()); + while (in.good()) { + in.read(buffer, 8192); + if (in.gcount() == 0) break; + data.append(buffer, in.gcount()); + protocols->update(42, data.data(), data.size()); + } + if (content) { + protocols->content(42, &out); + } else { + g_out = &out; + protocols->text(42); + } + looper->run(); + protocols.reset(); + g_out = nullptr; + return true; +} + +} // namespace + +// static +AttributedText* AttributedText::create() { + if (g_use_html) return HtmlAttributedText::create(); + return PlainAttributedText::create(); +} + +int main(int argc, char** argv) { + std::unique_ptr<Args> args(Args::create()); + args->add('H', "html", "generate HTML output"); + args->add('C', "content", "print content if possible"); + args->add('h', "help", "display this text and exit."); + if (!args->run(argc, argv)) { + std::cerr << "Try `protocol --help` for usage." << std::endl; + return EXIT_FAILURE; + } + if (args->is_set('h')) { + std::cout << "Usage: `protocol [OPTIONS...] [INPUT]`\n" + << "Runs protocols on INPUT or STDIN and prints result.\n" + << '\n'; + args->print_help(); + return EXIT_SUCCESS; + } + switch (args->arguments().size()) { + case 0: + return run(std::cin, std::cout, args->is_set('H'), args->is_set('C')); + case 1: { + std::ifstream in(args->arguments().front()); + if (!in.good()) { + std::cerr << "Unable to open " << args->arguments().front() + << " for reading." << std::endl; + return EXIT_FAILURE; + } + return run(in, std::cout, args->is_set('H'), args->is_set('C')); + } + default: + std::cerr << "Too many arguments.\n" + << "Try `protocol --help` for usage." << std::endl; + return EXIT_FAILURE; + } +} diff --git a/src/protocol.hh b/src/protocol.hh new file mode 100644 index 0000000..e010092 --- /dev/null +++ b/src/protocol.hh @@ -0,0 +1,44 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#ifndef PROTOCOL_HH +#define PROTOCOL_HH + +#include <stddef.h> +#include <ostream> + +class AttributedText; + +class Protocol { +public: + class Match { + public: + virtual ~Match() {} + + virtual std::string const& name() const = 0; + virtual void full(void const* data, size_t size, AttributedText* text) = 0; + virtual void append(void const* data, size_t /* offset */, size_t size, + AttributedText* text) { + full(data, size, text); + } + virtual bool content(void const* /* data */, size_t /* size */, + std::ostream* /* out */) { + return false; + } + + protected: + Match() {} + Match(Match const&) = delete; + Match& operator=(Match const&) = delete; + }; + + virtual ~Protocol() {} + + virtual Match* match(void const* data, size_t size) const = 0; + +protected: + Protocol() {} + Protocol(Protocol const&) = delete; + Protocol& operator=(Protocol const&) = delete; +}; + +#endif // PROTOCOL_HH diff --git a/src/protocols.cc b/src/protocols.cc new file mode 100644 index 0000000..a1eaa58 --- /dev/null +++ b/src/protocols.cc @@ -0,0 +1,477 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#include "common.hh" + +#include <algorithm> +#include <chrono> +#include <condition_variable> +#include <deque> +#include <mutex> +#include <string.h> +#include <thread> +#include <unordered_map> +#include <unordered_set> +#include <vector> + +#include "gui_attrtext.hh" +#include "gui_plainattrtext.hh" +#include "http_protocol.hh" +#include "io.hh" +#include "looper.hh" +#include "protocol.hh" +#include "protocols.hh" + +namespace { + +class ProtocolsImpl : public Protocols { +public: + ProtocolsImpl(size_t workers, size_t buffer, size_t cache, Looper* looper, + Listener* listener) + : listener_(listener), workers_size_( + std::max(static_cast<size_t>(1), workers)), buffer_size_(buffer), + cache_size_(cache), looper_(looper), + wanted_(std::string::npos), quit_(false) { + protocols_.emplace_back(HttpProtocol::create()); + wanted_poke_.open(); + if (wanted_poke_) { + looper_->add(wanted_poke_.read(), Looper::EVENT_READ, + std::bind(&ProtocolsImpl::notify, this, + std::placeholders::_1, + std::placeholders::_2)); + + } + } + + ~ProtocolsImpl() override { + { + std::unique_lock<std::mutex> lock(mutex_); + quit_ = true; + cond_.notify_all(); + } + for (auto& worker : workers_) { + worker.join(); + } + if (wanted_poke_) { + looper_->remove(wanted_poke_.read()); + } + } + + void add(size_t id, void const* data, size_t size) override { + std::unique_lock<std::mutex> lock(mutex_); + auto& entry = entries_[id]; + if (entry.data_) { + assert(false); + return; + } + entry.data_ = data; + entry.size_ = size; + + reschedule_with_lock(); + } + + void update(size_t id, void const* data, size_t size) override { + std::unique_lock<std::mutex> lock(mutex_); + auto& entry = entries_[id]; + if (!entry.data_) { + assert(false); + entries_.erase(id); + return; + } + entry.data_ = data; + if (entry.size_ == size) return; + auto old = entry.size_; + entry.size_ = size; + if (old >= buffer_size_) return; + + for (auto it = cache_.begin(); it != cache_.end(); ++it) { + if (it->id_ == id) { + queue_.emplace_back(id, it->size_, + std::move(it->match_), std::move(it->text_)); + cache_.erase(it); + if (active_.size() < workers_size_) cond_.notify_one(); + return; + } + } + } + + void remove(size_t id) override { + auto it = entries_.find(id); + if (it == entries_.end()) { + assert(false); + return; + } + + std::unique_lock<std::mutex> lock(mutex_); + for (auto it = cache_.begin(); it != cache_.end(); ++it) { + if (it->id_ == id) { + cache_.erase(it); + reschedule_with_lock(); + return; + } + } + + if (wanted_ == id) wanted_ = std::string::npos; + + for (auto it = active_.begin(); it != active_.end(); ++it) { + if (*it == id) { + active_.erase(it); + reschedule_with_lock(); + return; + } + } + + for (auto it = queue_.begin(); it != queue_.end(); ++it) { + if (it->id_ == id) { + queue_.erase(it); + reschedule_with_lock(); + return; + } + } + } + + void clear() override { + std::unique_lock<std::mutex> lock(mutex_); + cache_.clear(); + wanted_ = std::string::npos; + active_.clear(); + queue_.clear(); + } + + void text(size_t id) override { + auto it = entries_.find(id); + if (it == entries_.end()) { + assert(false); + return; + } + + std::unique_lock<std::mutex> lock(mutex_); + for (auto& entry : cache_) { + if (entry.id_ == id) { + entry.last_ = std::chrono::steady_clock::now(); + if (!entry.text_) { + // Already an active entry + assert(false); + return; + } + std::unique_ptr<AttributedText> text(std::move(entry.text_)); + std::string name(entry.match_ ? entry.match_->name() : ""); + lock.unlock(); + listener_->text(this, id, name, std::move(text)); + return; + } + } + + wanted_ = id; + + for (auto& active : active_) { + if (active == id) { + // Already working on it + return; + } + } + + for (auto it = queue_.begin(); it != queue_.end(); ++it) { + if (it->id_ == id) { + // Already in queue, remove so we can add it first + if (it == queue_.begin()) { + return; + } + auto offset = it->offset_; + std::unique_ptr<Protocol::Match> match(std::move(it->match_)); + std::unique_ptr<AttributedText> text(std::move(it->text_)); + queue_.erase(it); + queue_.emplace_front(id, offset, std::move(match), std::move(text)); + return; + } + } + + queue_.emplace_front(id); + if (active_.size() < workers_size_) cond_.notify_one(); + } + + void free(size_t id, std::unique_ptr<AttributedText>&& text) override { + std::unique_lock<std::mutex> lock(mutex_); + for (auto& entry : cache_) { + if (entry.id_ == id) { + if (entry.text_) { + // Entry updated while shown + return; + } + entry.text_.swap(text); + break; + } + } + } + + void content(size_t id, std::ostream* out) override { + auto it = entries_.find(id); + if (it == entries_.end()) { + assert(false); + return; + } + + std::unique_lock<std::mutex> lock(mutex_); + content_queue_.emplace_back(id, out); + if (active_.size() < workers_size_) cond_.notify_one(); + } + +private: + struct Entry { + void const* data_; + size_t size_; + + Entry() + : data_(nullptr), size_(0) { + } + }; + + struct CacheEntry { + std::chrono::steady_clock::time_point last_; + size_t id_; + size_t size_; + std::unique_ptr<Protocol::Match> match_; + std::unique_ptr<AttributedText> text_; + + CacheEntry(size_t id, size_t size, std::unique_ptr<Protocol::Match>&& match, + std::unique_ptr<AttributedText>&& text) + : last_(std::chrono::steady_clock::now()), id_(id), size_(size), + match_(std::move(match)), text_(std::move(text)) { + } + }; + + struct QueueEntry { + size_t id_; + size_t offset_; + std::unique_ptr<Protocol::Match> match_; + std::unique_ptr<AttributedText> text_; + + QueueEntry(size_t id) + : id_(id), offset_(0) { + } + + QueueEntry(size_t id, size_t offset, + std::unique_ptr<Protocol::Match>&& match, + std::unique_ptr<AttributedText>&& text) + : id_(id), offset_(offset), match_(std::move(match)), + text_(std::move(text)) { + } + }; + + struct ContentEntry { + size_t id_; + std::string name_; + std::ostream* out_; + + ContentEntry(size_t id, std::ostream* out) + : id_(id), out_(out) { + } + + ContentEntry(size_t id, std::string const& name, std::ostream* out) + : id_(id), name_(name), out_(out) { + } + }; + + void reschedule_with_lock() { + if (cache_.size() < cache_size_) { + size_t want = cache_.size() - cache_size_; + size_t queued = active_.size() + queue_.size(); + if (want < queued) { + want = 0; + } else { + want -= queued; + } + if (want) { + std::unordered_set<size_t> taken; + for (auto const& cache : cache_) taken.insert(cache.id_); + for (auto const& active : active_) taken.insert(active); + for (auto const& queue : queue_) taken.insert(queue.id_); + for (auto const& pair : entries_) { + if (pair.second.size_ == 0) continue; // No point caching 0 bytes + if (taken.count(pair.first)) continue; // Already in cache or queues + queue_.emplace_back(pair.first); + if (!--want) break; + } + } + } + + if (queue_.empty()) return; + if (active_.size() == workers_size_) return; + while (workers_.size() < workers_size_) { + workers_.emplace_back(&ProtocolsImpl::worker, this, buffer_size_); + } + cond_.notify_all(); + } + + void worker(size_t buffer_size) { + std::unique_ptr<AttributedText> text; + std::unique_ptr<Protocol::Match> match; + std::ostream* out = nullptr; + size_t id = 0; + size_t offset = 0; + size_t size; + std::unique_ptr<char[]> buf(new char[buffer_size]); + char const* ptr; + size_t fill; + while (true) { + { + std::unique_lock<std::mutex> lock(mutex_); + if (out) { + content_done_with_lock(id, match ? match->name() : "", out); + out = nullptr; + id = 0; + match.reset(); + } else if (text) { + cache_entry_with_lock(id, size, std::move(match), std::move(text)); + assert(!match); + assert(!text); + id = 0; + offset = 0; + } + while (true) { + if (quit_) return; + if (!content_queue_.empty()) { + id = content_queue_.front().id_; + out = content_queue_.front().out_; + content_queue_.pop_front(); + break; + } + if (!queue_.empty()) { + id = queue_.front().id_; + offset = queue_.front().offset_; + match.swap(queue_.front().match_); + text.swap(queue_.front().text_); + queue_.pop_front(); + active_.emplace_back(id); + break; + } + cond_.wait(lock); + } + + auto const& entry = entries_[id]; + size = entry.size_; + if (out) { + fill = size; + ptr = reinterpret_cast<char const*>(entry.data_); + } else { + fill = std::min(size, buffer_size); + memcpy(buf.get(), entry.data_, fill); + ptr = buf.get(); + } + } + + if (!out && !text) text.reset(AttributedText::create()); + if (!match) { + for (auto const& protocol : protocols_) { + match.reset(protocol->match(ptr, fill)); + if (match) break; + } + offset = 0; + } + if (match) { + if (out) { + if (!match->content(ptr, fill, out)) { + out->write(ptr, fill); + } + } else { + if (offset == 0) { + match->full(ptr, fill, text.get()); + } else { + match->append(ptr, offset, fill, text.get()); + } + } + } + } + } + + void cache_entry_with_lock(size_t id, size_t size, + std::unique_ptr<Protocol::Match>&& match, + std::unique_ptr<AttributedText>&& text) { + active_.erase(std::find(active_.begin(), active_.end(), id)); + while (cache_.size() >= cache_size_) { + auto oldest = cache_.end(); + for (auto it = cache_.begin(); it != cache_.end(); ++it) { + if (!it->text_) continue; // Do not remove active + if (oldest == cache_.end() || it->last_ < oldest->last_) { + oldest = it; + } + } + if (oldest == cache_.end()) break; + cache_.erase(oldest); + } + cache_.emplace_back(id, size, std::move(match), std::move(text)); + if (wanted_ == id && wanted_poke_) { + io::write(wanted_poke_.write(), "a", 1); + } + } + + void content_done_with_lock(size_t id, std::string const& name, + std::ostream* out) { + content_done_.emplace_back(id, name, out); + if (wanted_poke_) { + io::write(wanted_poke_.write(), "b", 1); + } + } + + void notify(int fd, uint8_t) { + char tmp[10]; + io::read(fd, tmp, sizeof(tmp)); + std::unique_lock<std::mutex> lock(mutex_); + while (!content_done_.empty()) { + ContentEntry entry(content_done_.front()); + content_done_.pop_front(); + lock.unlock(); + listener_->content(this, entry.id_, entry.name_, entry.out_); + lock.lock(); + } + if (wanted_ == std::string::npos) return; + for (auto& entry : cache_) { + if (entry.id_ == wanted_) { + wanted_ = std::string::npos; + entry.last_ = std::chrono::steady_clock::now(); + if (!entry.text_) { + // Already an active entry + assert(false); + return; + } + auto id = entry.id_; + std::unique_ptr<AttributedText> text(std::move(entry.text_)); + std::string name(entry.match_ ? entry.match_->name() : ""); + lock.unlock(); + listener_->text(this, id, name, std::move(text)); + return; + } + } + } + + Listener* const listener_; + size_t const workers_size_; + size_t const buffer_size_; + size_t const cache_size_; + Looper* const looper_; + std::unordered_map<size_t, Entry> entries_; + + std::vector<std::unique_ptr<Protocol>> protocols_; + std::vector<std::thread> workers_; + + io::auto_pipe wanted_poke_; + + std::mutex mutex_; + std::condition_variable cond_; + size_t wanted_; + bool quit_; + std::vector<CacheEntry> cache_; + std::vector<size_t> active_; + std::deque<QueueEntry> queue_; + + std::deque<ContentEntry> content_done_; + std::deque<ContentEntry> content_queue_; +}; + +} // namespace + +// static +Protocols* Protocols::create(size_t workers, size_t buffer, size_t cache, + Looper* looper, Listener* listener) { + assert(listener); + return new ProtocolsImpl(workers, buffer, cache, looper, listener); +} diff --git a/src/protocols.hh b/src/protocols.hh new file mode 100644 index 0000000..8f3da0e --- /dev/null +++ b/src/protocols.hh @@ -0,0 +1,51 @@ +// -*- mode: c++; c-basic-offset: 2; -*- + +#ifndef PROTOCOLS_HH +#define PROTOCOLS_HH + +#include <memory> +#include <ostream> +#include <stddef.h> + +class AttributedText; +class Looper; + +class Protocols { +public: + class Listener { + public: + virtual ~Listener() {} + + virtual void text(Protocols* protocols, + size_t id, std::string const& protocol, + std::unique_ptr<AttributedText>&& text) = 0; + virtual void content(Protocols* protocols, size_t id, + std::string const& protocol, std::ostream* out) = 0; + + protected: + Listener() {} + }; + + virtual ~Protocols() {} + + static Protocols* create(size_t workers, size_t buffer, size_t cache, + Looper* looper, Listener* listener); + + virtual void clear() = 0; + + virtual void add(size_t id, void const* data, size_t size) = 0; + virtual void update(size_t id, void const* data, size_t size) = 0; + virtual void remove(size_t id) = 0; + + virtual void text(size_t id) = 0; + virtual void free(size_t id, std::unique_ptr<AttributedText>&& text) = 0; + + virtual void content(size_t id, std::ostream* out) = 0; + +protected: + Protocols() {} + Protocols(Protocols const&) = delete; + Protocols& operator=(Protocols const&) = delete; +}; + +#endif // PROTOCOLS_HH diff --git a/src/proxy.cc b/src/proxy.cc index 3909541..5eca4fb 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -822,27 +822,6 @@ bool ProxyImpl::base_event(BaseClient* client, uint8_t events, return true; } -inline char lower_ascii(char c) { - return (c >= 'A' && c <= 'Z') ? (c - 'A' + 'a') : c; -} - -bool lower_equal(char const* data, size_t start, size_t end, - std::string const& str) { - assert(start <= end); - if (str.size() != end - start) return false; - for (auto i = str.begin(); start < end; ++start, ++i) { - if (lower_ascii(*i) != lower_ascii(data[start])) return false; - } - return true; -} - -bool header_token_eq(std::string const& value, std::string const& token) { - if (value.empty()) return false; - auto pos = value.find(';'); - if (pos == std::string::npos) pos = value.size(); - return lower_equal(value.data(), 0, pos, token); -} - void ProxyImpl::client_remote_error(size_t index, uint16_t error) { auto& client = clients_[index]; if (client.remote_state > CONNECTED) { @@ -1014,8 +993,13 @@ void ProxyImpl::client_event(size_t index, int fd, uint8_t events) { bool setup_content(Http const* http, Content* content) { assert(content->type == CONTENT_NONE); - std::string te = http->first_header("transfer-encoding"); - if (te.empty() || header_token_eq(te, "identity")) { + auto iter = http->header_tokens("transfer-encoding"); + bool chunked = false; + while (iter->valid()) { + chunked = iter->token_equal("chunked"); + iter->next(); + } + if (!chunked) { std::string len = http->first_header("content-length"); if (len.empty()) { content->type = CONTENT_CLOSE; diff --git a/test/test-http.cc b/test/test-http.cc index 4b1b62f..73f5261 100644 --- a/test/test-http.cc +++ b/test/test-http.cc @@ -326,6 +326,38 @@ bool resp(std::string const& name, std::string const& out, return true; } +bool tokens(std::string const& in, char const* header, ...) { + std::unique_ptr<HttpResponse> resp(HttpResponse::parse(in)); + if (!resp) { + std::cerr << "tokens:" << header << ": Expected valid http" << std::endl; + return false; + } + auto iter = resp->header_tokens(header); + va_list tokens; + va_start(tokens, header); + while (true) { + auto token = va_arg(tokens, char const*); + if (!token) break; + if (!iter->valid()) { + std::cerr << "tokens:" << header << ": Expected " << token << " got " + << "no more tokens" << std::endl; + return false; + } + if (iter->token().compare(token)) { + std::cerr << "tokens:" << header << ": Expected " << token << " got " + << iter->token() << std::endl; + return false; + } + iter->next(); + } + if (iter->valid()) { + std::cerr << "tokens:" << header << ": Expected no more tokens got " + << iter->token() << std::endl; + return false; + } + return true; +} + } // namespace int main() { @@ -473,5 +505,33 @@ int main() { "connection", "close", nullptr)); + RUN(tokens("HTTP/1.1 200 OK\r\n" + "\r\n", "Transfer-Encoding", nullptr)); + RUN(tokens("HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n", "Transfer-Encoding", "chunked", nullptr)); + RUN(tokens("HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked, stuff; param=foo\r\n" + "\r\n", "Transfer-Encoding", "chunked", "stuff", nullptr)); + RUN(tokens("HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked, stuff; param=\"foo\"\r\n" + "\r\n", "Transfer-Encoding", "chunked", "stuff", nullptr)); + RUN(tokens("HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked;param=\"\\\"\",stuff\r\n" + "\r\n", "Transfer-Encoding", "chunked", "stuff", nullptr)); + RUN(tokens("HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked;p1;p2=p3;=p4, stuff\r\n" + "\r\n", "Transfer-Encoding", "chunked", "stuff", nullptr)); + RUN(tokens("HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked,stuff\r\n" + "\r\n", "Transfer-Encoding", "chunked", "stuff", nullptr)); + RUN(tokens("HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "Transfer-Encoding: stuff\r\n" + "\r\n", "Transfer-Encoding", "chunked", "stuff", nullptr)); + RUN(tokens("HTTP/1.1 200 OK\r\n" + "Content-Type: text/html; charset=utf-8\r\n" + "\r\n", "Content-Type", "text", nullptr)); + AFTER; } |
