diff options
Diffstat (limited to 'src/proxy.cc')
| -rw-r--r-- | src/proxy.cc | 205 |
1 files changed, 130 insertions, 75 deletions
diff --git a/src/proxy.cc b/src/proxy.cc index 2d84208..dcc52df 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -212,6 +212,11 @@ struct Content { std::unique_ptr<Chunked> chunked; }; +struct Connect { + std::string host; + uint16_t port; +}; + enum RemoteState { CLOSED, RESOLVING, @@ -222,6 +227,8 @@ enum RemoteState { struct RemoteClient : public BaseClient { Content content; + std::string host; + uint16_t port; }; struct Client : public BaseClient { @@ -229,6 +236,7 @@ struct Client : public BaseClient { : resolve(nullptr) { } std::unique_ptr<HttpRequest> request; + std::unique_ptr<Connect> connect; std::unique_ptr<Url> url; Content content; RemoteState remote_state; @@ -455,6 +463,7 @@ void ProxyImpl::close_client(size_t index) { auto& client = clients_[index]; client.request.reset(); client.url.reset(); + client.connect.reset(); client.content.type = CONTENT_NONE; client.content.chunked.reset(); client.remote_state = CLOSED; @@ -717,6 +726,7 @@ void ProxyImpl::client_error(size_t index, uint16_t status_code, client.read_flag = 0; looper_->modify(client.fd.get(), client.write_flag); client.url.reset(); + client.connect.reset(); std::string proto; Version version; @@ -754,36 +764,52 @@ bool ProxyImpl::client_request(size_t index) { client_error(index, 505, "HTTP Version Not Supported"); return false; } - if (client.url->scheme() != "http") { - client_error(index, 501, "Not Implemented"); - return false; - } - if (client.url->userinfo()) { - client_error(index, 400, "Bad request"); - return false; + std::string host; + uint16_t port; + if (client.connect) { + host = client.connect->host; + port = client.connect->port; + } else { + if (client.url->scheme() != "http") { + client_error(index, 501, "Not Implemented"); + return false; + } + if (client.url->userinfo()) { + client_error(index, 400, "Bad request"); + return false; + } + host = client.url->host(); + port = client.url->port(); } if (client.remote_state == WAITING) { - client.remote_state = CONNECTING; - client.remote.read_flag = Looper::EVENT_READ; - client.remote.write_flag = 0; - looper_->modify(client.remote.fd.get(), - client.remote.read_flag | client.remote.write_flag); - client_remote_event(index, client.remote.fd.get(), Looper::EVENT_WRITE); - } else { - assert(client.remote_state == CLOSED); - client.remote.last = looper_->now(); - client.remote.new_connection = true; - client.remote_state = RESOLVING; - - auto port = client.url->port(); - if (port == 0) port = 80; - client.resolve = resolver_->request( - client.url->host(), port, - std::bind(&ProxyImpl::client_remote_resolved, this, index, - std::placeholders::_1, - std::placeholders::_2, - std::placeholders::_3)); + if (client.connect || host != client.remote.host + || port != client.remote.port) { + client.remote_state = CLOSED; + close_base(&client.remote); + } else { + client.remote_state = CONNECTING; + client.remote.read_flag = Looper::EVENT_READ; + client.remote.write_flag = 0; + looper_->modify(client.remote.fd.get(), + client.remote.read_flag | client.remote.write_flag); + client_remote_event(index, client.remote.fd.get(), Looper::EVENT_WRITE); + return true; + } } + assert(client.remote_state == CLOSED); + client.remote.last = looper_->now(); + client.remote.new_connection = true; + client.remote_state = RESOLVING; + client.remote.host = host; + client.remote.port = port; + + if (port == 0) port = 80; + client.resolve = resolver_->request( + host, port, + std::bind(&ProxyImpl::client_remote_resolved, this, index, + std::placeholders::_1, + std::placeholders::_2, + std::placeholders::_3)); return true; } @@ -882,6 +908,13 @@ void ProxyImpl::client_empty_input(size_t index) { assert(false); // falltrough case CONTENT_NONE: { + if (client.connect && client.remote_state == CONNECTED) { + if (!client_send(index, ptr, avail)) { + return; + } + client.in->consume(avail); + break; + } if (client.remote_state != CLOSED && client.remote_state != WAITING) { // Still working on the last request, wait return; @@ -902,21 +935,26 @@ void ProxyImpl::client_empty_input(size_t index) { return; } if (client.request->method_equal("CONNECT")) { - client_error(index, 501, "Not Implemented"); - return; - } - client.url.reset(Url::parse(client.request->url())); - if (!client.url) { - client_error(index, 400, "Bad request"); - return; - } - if (!setup_content(client.request.get(), &client.content)) { - logger_->out(Logger::INFO, "%zu: Client bad content-length", index); - client_error(index, 400, "Bad request"); - return; - } - if (client.content.type == CONTENT_CLOSE) { - client.content.type = CONTENT_NONE; + client.connect.reset(new Connect()); + if (!Url::parse_authority(client.request->url(), &client.connect->host, + &client.connect->port)) { + client_error(index, 400, "Bad request"); + return; + } + } else { + client.url.reset(Url::parse(client.request->url())); + if (!client.url) { + client_error(index, 400, "Bad request"); + return; + } + if (!setup_content(client.request.get(), &client.content)) { + logger_->out(Logger::INFO, "%zu: Client bad content-length", index); + client_error(index, 400, "Bad request"); + return; + } + if (client.content.type == CONTENT_CLOSE) { + client.content.type = CONTENT_NONE; + } } if (!client_request(index)) { client.content.type = CONTENT_NONE; @@ -997,43 +1035,60 @@ void ProxyImpl::client_remote_event(size_t index, int fd, uint8_t events) { } if (client.remote_state == CONNECTING) { if (events & Looper::EVENT_WRITE) { - std::string url(client.url->path_escaped()); - if (url.empty()) url.push_back('/'); - auto query = client.url->full_query_escaped(); - if (query) { - url.push_back('?'); - url.append(query); - } - auto req = std::unique_ptr<HttpRequestBuilder>( - HttpRequestBuilder::create( - client.request->method(), - url, - client.request->proto(), - client.request->proto_version())); - auto iter = client.request->header(); - bool have_host = false; - for (; iter->valid(); iter->next()) { - if (!have_host && iter->name_equal("host")) have_host = true; - if (iter->name_equal("proxy-connection") || - iter->name_equal("proxy-authenticate") || - iter->name_equal("proxy-authorization")) { - continue; + if (client.connect) { + auto req = std::unique_ptr<HttpResponseBuilder>( + HttpResponseBuilder::create( + client.request->proto(), + client.request->proto_version(), + 200, "OK")); + auto data = req->build(); + client.in->consume(client.request->size()); + client.request.reset(); + client.remote.content.type = CONTENT_CLOSE; + if (!base_send(&client, data.data(), data.size(), index, "Client")) { + return; } - req->add_header(iter->name(), iter->value()); - } - if (!have_host && - (client.request->proto_version().major == 1 && - client.request->proto_version().minor == 1)) { - req->add_header("host", client.url->host()); + events &= ~Looper::EVENT_WRITE; + } else { + std::string url(client.url->path_escaped()); + if (url.empty()) url.push_back('/'); + auto query = client.url->full_query_escaped(); + if (query) { + url.push_back('?'); + url.append(query); + } + auto req = std::unique_ptr<HttpRequestBuilder>( + HttpRequestBuilder::create( + client.request->method(), + url, + client.request->proto(), + client.request->proto_version())); + auto iter = client.request->header(); + bool have_host = false; + for (; iter->valid(); iter->next()) { + if (!have_host && iter->name_equal("host")) have_host = true; + if (iter->name_equal("proxy-connection") || + iter->name_equal("proxy-authenticate") || + iter->name_equal("proxy-authorization")) { + continue; + } + req->add_header(iter->name(), iter->value()); + } + if (!have_host && + (client.request->proto_version().major == 1 && + client.request->proto_version().minor == 1)) { + req->add_header("host", client.url->host()); + } + auto data = req->build(); + client.in->consume(client.request->size()); + client.request.reset(); + client.url.reset(); + client.remote.out->write(data.data(), data.size()); } - auto data = req->build(); - client.in->consume(client.request->size()); - client.request.reset(); - client.url.reset(); - client.remote.out->write(data.data(), data.size()); client.remote_state = CONNECTED; client.remote.read_flag = Looper::EVENT_READ; - client.remote.write_flag = Looper::EVENT_WRITE; + client.remote.write_flag = + client.remote.out->empty() ? 0 : Looper::EVENT_WRITE; looper_->modify(client.remote.fd.get(), client.remote.read_flag | client.remote.write_flag); client_empty_input(index); |
