summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJoel Klinghed <the_jk@yahoo.com>2017-03-01 22:54:09 +0100
committerJoel Klinghed <the_jk@yahoo.com>2017-03-01 22:54:09 +0100
commit719d90a40e83e870be19f8d46cc55caed618aa35 (patch)
tree10bc226c44aff6cfa3c53e5d837a32720c9bf836 /src
parent537ed164ae1875a8d38e06dbc214ef4f91bc4642 (diff)
Add support for CONNECT
Diffstat (limited to 'src')
-rw-r--r--src/proxy.cc205
-rw-r--r--src/url.cc44
-rw-r--r--src/url.hh2
3 files changed, 163 insertions, 88 deletions
diff --git a/src/proxy.cc b/src/proxy.cc
index 2d84208..dcc52df 100644
--- a/src/proxy.cc
+++ b/src/proxy.cc
@@ -212,6 +212,11 @@ struct Content {
std::unique_ptr<Chunked> chunked;
};
+struct Connect {
+ std::string host;
+ uint16_t port;
+};
+
enum RemoteState {
CLOSED,
RESOLVING,
@@ -222,6 +227,8 @@ enum RemoteState {
struct RemoteClient : public BaseClient {
Content content;
+ std::string host;
+ uint16_t port;
};
struct Client : public BaseClient {
@@ -229,6 +236,7 @@ struct Client : public BaseClient {
: resolve(nullptr) {
}
std::unique_ptr<HttpRequest> request;
+ std::unique_ptr<Connect> connect;
std::unique_ptr<Url> url;
Content content;
RemoteState remote_state;
@@ -455,6 +463,7 @@ void ProxyImpl::close_client(size_t index) {
auto& client = clients_[index];
client.request.reset();
client.url.reset();
+ client.connect.reset();
client.content.type = CONTENT_NONE;
client.content.chunked.reset();
client.remote_state = CLOSED;
@@ -717,6 +726,7 @@ void ProxyImpl::client_error(size_t index, uint16_t status_code,
client.read_flag = 0;
looper_->modify(client.fd.get(), client.write_flag);
client.url.reset();
+ client.connect.reset();
std::string proto;
Version version;
@@ -754,36 +764,52 @@ bool ProxyImpl::client_request(size_t index) {
client_error(index, 505, "HTTP Version Not Supported");
return false;
}
- if (client.url->scheme() != "http") {
- client_error(index, 501, "Not Implemented");
- return false;
- }
- if (client.url->userinfo()) {
- client_error(index, 400, "Bad request");
- return false;
+ std::string host;
+ uint16_t port;
+ if (client.connect) {
+ host = client.connect->host;
+ port = client.connect->port;
+ } else {
+ if (client.url->scheme() != "http") {
+ client_error(index, 501, "Not Implemented");
+ return false;
+ }
+ if (client.url->userinfo()) {
+ client_error(index, 400, "Bad request");
+ return false;
+ }
+ host = client.url->host();
+ port = client.url->port();
}
if (client.remote_state == WAITING) {
- client.remote_state = CONNECTING;
- client.remote.read_flag = Looper::EVENT_READ;
- client.remote.write_flag = 0;
- looper_->modify(client.remote.fd.get(),
- client.remote.read_flag | client.remote.write_flag);
- client_remote_event(index, client.remote.fd.get(), Looper::EVENT_WRITE);
- } else {
- assert(client.remote_state == CLOSED);
- client.remote.last = looper_->now();
- client.remote.new_connection = true;
- client.remote_state = RESOLVING;
-
- auto port = client.url->port();
- if (port == 0) port = 80;
- client.resolve = resolver_->request(
- client.url->host(), port,
- std::bind(&ProxyImpl::client_remote_resolved, this, index,
- std::placeholders::_1,
- std::placeholders::_2,
- std::placeholders::_3));
+ if (client.connect || host != client.remote.host
+ || port != client.remote.port) {
+ client.remote_state = CLOSED;
+ close_base(&client.remote);
+ } else {
+ client.remote_state = CONNECTING;
+ client.remote.read_flag = Looper::EVENT_READ;
+ client.remote.write_flag = 0;
+ looper_->modify(client.remote.fd.get(),
+ client.remote.read_flag | client.remote.write_flag);
+ client_remote_event(index, client.remote.fd.get(), Looper::EVENT_WRITE);
+ return true;
+ }
}
+ assert(client.remote_state == CLOSED);
+ client.remote.last = looper_->now();
+ client.remote.new_connection = true;
+ client.remote_state = RESOLVING;
+ client.remote.host = host;
+ client.remote.port = port;
+
+ if (port == 0) port = 80;
+ client.resolve = resolver_->request(
+ host, port,
+ std::bind(&ProxyImpl::client_remote_resolved, this, index,
+ std::placeholders::_1,
+ std::placeholders::_2,
+ std::placeholders::_3));
return true;
}
@@ -882,6 +908,13 @@ void ProxyImpl::client_empty_input(size_t index) {
assert(false);
// falltrough
case CONTENT_NONE: {
+ if (client.connect && client.remote_state == CONNECTED) {
+ if (!client_send(index, ptr, avail)) {
+ return;
+ }
+ client.in->consume(avail);
+ break;
+ }
if (client.remote_state != CLOSED && client.remote_state != WAITING) {
// Still working on the last request, wait
return;
@@ -902,21 +935,26 @@ void ProxyImpl::client_empty_input(size_t index) {
return;
}
if (client.request->method_equal("CONNECT")) {
- client_error(index, 501, "Not Implemented");
- return;
- }
- client.url.reset(Url::parse(client.request->url()));
- if (!client.url) {
- client_error(index, 400, "Bad request");
- return;
- }
- if (!setup_content(client.request.get(), &client.content)) {
- logger_->out(Logger::INFO, "%zu: Client bad content-length", index);
- client_error(index, 400, "Bad request");
- return;
- }
- if (client.content.type == CONTENT_CLOSE) {
- client.content.type = CONTENT_NONE;
+ client.connect.reset(new Connect());
+ if (!Url::parse_authority(client.request->url(), &client.connect->host,
+ &client.connect->port)) {
+ client_error(index, 400, "Bad request");
+ return;
+ }
+ } else {
+ client.url.reset(Url::parse(client.request->url()));
+ if (!client.url) {
+ client_error(index, 400, "Bad request");
+ return;
+ }
+ if (!setup_content(client.request.get(), &client.content)) {
+ logger_->out(Logger::INFO, "%zu: Client bad content-length", index);
+ client_error(index, 400, "Bad request");
+ return;
+ }
+ if (client.content.type == CONTENT_CLOSE) {
+ client.content.type = CONTENT_NONE;
+ }
}
if (!client_request(index)) {
client.content.type = CONTENT_NONE;
@@ -997,43 +1035,60 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) {
}
if (client.remote_state == CONNECTING) {
if (events & Looper::EVENT_WRITE) {
- std::string url(client.url->path_escaped());
- if (url.empty()) url.push_back('/');
- auto query = client.url->full_query_escaped();
- if (query) {
- url.push_back('?');
- url.append(query);
- }
- auto req = std::unique_ptr<HttpRequestBuilder>(
- HttpRequestBuilder::create(
- client.request->method(),
- url,
- client.request->proto(),
- client.request->proto_version()));
- auto iter = client.request->header();
- bool have_host = false;
- for (; iter->valid(); iter->next()) {
- if (!have_host && iter->name_equal("host")) have_host = true;
- if (iter->name_equal("proxy-connection") ||
- iter->name_equal("proxy-authenticate") ||
- iter->name_equal("proxy-authorization")) {
- continue;
+ if (client.connect) {
+ auto req = std::unique_ptr<HttpResponseBuilder>(
+ HttpResponseBuilder::create(
+ client.request->proto(),
+ client.request->proto_version(),
+ 200, "OK"));
+ auto data = req->build();
+ client.in->consume(client.request->size());
+ client.request.reset();
+ client.remote.content.type = CONTENT_CLOSE;
+ if (!base_send(&client, data.data(), data.size(), index, "Client")) {
+ return;
}
- req->add_header(iter->name(), iter->value());
- }
- if (!have_host &&
- (client.request->proto_version().major == 1 &&
- client.request->proto_version().minor == 1)) {
- req->add_header("host", client.url->host());
+ events &= ~Looper::EVENT_WRITE;
+ } else {
+ std::string url(client.url->path_escaped());
+ if (url.empty()) url.push_back('/');
+ auto query = client.url->full_query_escaped();
+ if (query) {
+ url.push_back('?');
+ url.append(query);
+ }
+ auto req = std::unique_ptr<HttpRequestBuilder>(
+ HttpRequestBuilder::create(
+ client.request->method(),
+ url,
+ client.request->proto(),
+ client.request->proto_version()));
+ auto iter = client.request->header();
+ bool have_host = false;
+ for (; iter->valid(); iter->next()) {
+ if (!have_host && iter->name_equal("host")) have_host = true;
+ if (iter->name_equal("proxy-connection") ||
+ iter->name_equal("proxy-authenticate") ||
+ iter->name_equal("proxy-authorization")) {
+ continue;
+ }
+ req->add_header(iter->name(), iter->value());
+ }
+ if (!have_host &&
+ (client.request->proto_version().major == 1 &&
+ client.request->proto_version().minor == 1)) {
+ req->add_header("host", client.url->host());
+ }
+ auto data = req->build();
+ client.in->consume(client.request->size());
+ client.request.reset();
+ client.url.reset();
+ client.remote.out->write(data.data(), data.size());
}
- auto data = req->build();
- client.in->consume(client.request->size());
- client.request.reset();
- client.url.reset();
- client.remote.out->write(data.data(), data.size());
client.remote_state = CONNECTED;
client.remote.read_flag = Looper::EVENT_READ;
- client.remote.write_flag = Looper::EVENT_WRITE;
+ client.remote.write_flag =
+ client.remote.out->empty() ? 0 : Looper::EVENT_WRITE;
looper_->modify(client.remote.fd.get(),
client.remote.read_flag | client.remote.write_flag);
client_empty_input(index);
diff --git a/src/url.cc b/src/url.cc
index 63419ba..2c160d2 100644
--- a/src/url.cc
+++ b/src/url.cc
@@ -364,7 +364,8 @@ bool UrlImpl::relative(std::string const& url, Url const* base) {
return true;
}
-char const* UrlImpl::parse_authority(char const* pos) {
+char const* parse_authority(char const* pos, char** userinfo,
+ std::string* host, uint16_t* port) {
/* authority = [ userinfo "@" ] host [ ":" port ]
userinfo = *( unreserved / pct-encoded / sub-delims / ":" )
host = IP-literal / IPv4address / reg-name
@@ -417,9 +418,10 @@ char const* UrlImpl::parse_authority(char const* pos) {
// userinfo?
if (at) {
+ if (!userinfo) return nullptr;
host_start = at + 1;
- userinfo_ = dup(start, at);
- if (strchr(userinfo_, '[') || strchr(userinfo_, ']')) {
+ *userinfo = dup(start, at);
+ if (strchr(*userinfo, '[') || strchr(*userinfo, ']')) {
return nullptr;
}
} else {
@@ -448,7 +450,7 @@ char const* UrlImpl::parse_authority(char const* pos) {
tmp++;
}
if (host_end == colon) {
- port_ = v;
+ *port = v;
}
} else {
host_end = pos;
@@ -467,17 +469,17 @@ char const* UrlImpl::parse_authority(char const* pos) {
if (!is_hex(host_start[1]) || host_start[2] != '.') {
return nullptr;
}
- host_.assign(host_start, host_end - host_start);
- lower(host_);
+ host->assign(host_start, host_end - host_start);
+ lower(*host);
} else {
if (!is_ipv6(host_start, host_end)) {
return nullptr;
}
- host_.assign(host_start, host_end - host_start);
- lower(host_);
+ host->assign(host_start, host_end - host_start);
+ lower(*host);
}
- if (host_.find('[') != std::string::npos ||
- host_.find(']') != std::string::npos) {
+ if (host->find('[') != std::string::npos ||
+ host->find(']') != std::string::npos) {
return nullptr;
}
} else {
@@ -490,14 +492,18 @@ char const* UrlImpl::parse_authority(char const* pos) {
}
tmp = unescape(const_cast<char*>(host_start), const_cast<char*>(host_end),
false);
- host_ = tmp;
- lower(host_);
+ host->assign(tmp);
+ lower(*host);
delete[] const_cast<char*>(tmp);
}
- if (host_.empty()) return nullptr;
+ if (host->empty()) return nullptr;
return pos;
}
+char const* UrlImpl::parse_authority(char const* pos) {
+ return ::parse_authority(pos, &userinfo_, &host_, &port_);
+}
+
char const* UrlImpl::parse_query(char const* pos) {
// query = *( pchar / "/" / "?" )
char const* start = ++pos;
@@ -892,6 +898,18 @@ Url* Url::parse(std::string const& url, Url const* base) {
return nullptr;
}
+// static
+bool Url::parse_authority(std::string const& str,
+ std::string* host, uint16_t* port) {
+ std::string tmp_host;
+ uint16_t tmp_port;
+ auto ret = ::parse_authority(str.c_str(), nullptr, &tmp_host, &tmp_port);
+ if (!ret || *ret) return false;
+ if (host) host->assign(tmp_host);
+ if (port) *port = tmp_port;
+ return true;
+}
+
bool Url::operator==(Url const& url) const {
if (scheme() != url.scheme()) return false;
if (host() != url.host()) return false;
diff --git a/src/url.hh b/src/url.hh
index d3b69b7..77027d2 100644
--- a/src/url.hh
+++ b/src/url.hh
@@ -12,6 +12,8 @@ public:
virtual ~Url() {}
static Url* parse(std::string const& url, Url const* base = nullptr);
+ static bool parse_authority(std::string const& str,
+ std::string* host, uint16_t* port);
virtual Url* copy() const = 0;