Skip to content

Commit

Permalink
HttpTask support auth when has userinfo (sogou#1554)
Browse files Browse the repository at this point in the history
* HttpTask support auth when has userinfo

* clear auth when redirect to other host

* HttpProxyTask support auth when has userinfo

* remove one StringUtil::url_decode as requested in sogou#1554
  • Loading branch information
kedixa authored May 28, 2024
1 parent 16686ba commit d54f06e
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 59 deletions.
175 changes: 134 additions & 41 deletions src/factory/HttpTaskImpl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,23 @@ using namespace protocol;

/**********Client**********/

static int __encode_auth(const char *p, std::string& auth)
{
size_t len = strlen(p);
size_t base64_len = (len + 2) / 3 * 4;
char *base64 = (char *)malloc(base64_len + 1);

if (!base64)
return -1;

EVP_EncodeBlock((unsigned char *)base64, (const unsigned char *)p, len);
auth.append("Basic ");
auth.append(base64, base64_len);

free(base64);
return 0;
}

class ComplexHttpTask : public WFComplexClientTask<HttpRequest, HttpResponse>
{
public:
Expand All @@ -64,8 +81,9 @@ class ComplexHttpTask : public WFComplexClientTask<HttpRequest, HttpResponse>
virtual bool finish_once();

protected:
bool need_redirect(ParsedURI& uri);
bool redirect_url(HttpResponse *client_resp, ParsedURI& uri);
bool need_redirect(const ParsedURI& uri, ParsedURI& new_uri);
bool redirect_url(HttpResponse *client_resp,
const ParsedURI& uri, ParsedURI& new_uri);
void set_empty_request();
void check_response();

Expand Down Expand Up @@ -184,6 +202,10 @@ void ComplexHttpTask::set_empty_request()

client_req->set_request_uri("/");
cursor.find_and_erase(&header);

header.name = "Authorization";
header.name_len = strlen("Authorization");
cursor.find_and_erase(&header);
}

void ComplexHttpTask::init_failed()
Expand All @@ -206,7 +228,6 @@ bool ComplexHttpTask::init_success()
{
this->state = WFT_STATE_TASK_ERROR;
this->error = WFT_ERR_URI_SCHEME_INVALID;
this->set_empty_request();
return false;
}

Expand Down Expand Up @@ -253,10 +274,29 @@ bool ComplexHttpTask::init_success()
this->WFComplexClientTask::set_transport_type(is_ssl ? TT_TCP_SSL : TT_TCP);
client_req->set_request_uri(request_uri.c_str());
client_req->set_header_pair("Host", header_host.c_str());

if (uri_.userinfo && uri_.userinfo[0])
{
std::string userinfo(uri_.userinfo);
std::string http_auth;

StringUtil::url_decode(userinfo);

if (__encode_auth(userinfo.c_str(), http_auth) < 0)
{
this->state = WFT_STATE_SYS_ERROR;
this->error = errno;
return false;
}

client_req->set_header_pair("Authorization", http_auth.c_str());
}

return true;
}

bool ComplexHttpTask::redirect_url(HttpResponse *client_resp, ParsedURI& uri)
bool ComplexHttpTask::redirect_url(HttpResponse *client_resp,
const ParsedURI& uri, ParsedURI& new_uri)
{
if (redirect_count_ < redirect_max_)
{
Expand Down Expand Up @@ -284,14 +324,14 @@ bool ComplexHttpTask::redirect_url(HttpResponse *client_resp, ParsedURI& uri)
url = uri.scheme + (':' + url);
}

URIParser::parse(url, uri);
URIParser::parse(url, new_uri);
return true;
}

return false;
}

bool ComplexHttpTask::need_redirect(ParsedURI& uri)
bool ComplexHttpTask::need_redirect(const ParsedURI& uri, ParsedURI& new_uri)
{
HttpRequest *client_req = this->get_req();
HttpResponse *client_resp = this->get_resp();
Expand All @@ -308,7 +348,7 @@ bool ComplexHttpTask::need_redirect(ParsedURI& uri)
case 301:
case 302:
case 303:
if (redirect_url(client_resp, uri))
if (redirect_url(client_resp, uri, new_uri))
{
if (strcasecmp(method, HttpMethodGet) != 0 &&
strcasecmp(method, HttpMethodHead) != 0)
Expand All @@ -323,7 +363,7 @@ bool ComplexHttpTask::need_redirect(ParsedURI& uri)

case 307:
case 308:
if (redirect_url(client_resp, uri))
if (redirect_url(client_resp, uri, new_uri))
return true;
else
break;
Expand Down Expand Up @@ -359,8 +399,31 @@ bool ComplexHttpTask::finish_once()

if (this->state == WFT_STATE_SUCCESS)
{
if (this->need_redirect(uri_))
this->set_redirect(uri_);
ParsedURI new_uri;
if (this->need_redirect(uri_, new_uri))
{
if (uri_.userinfo && strcasecmp(uri_.host, new_uri.host) == 0)
{
if (!new_uri.userinfo)
{
new_uri.userinfo = uri_.userinfo;
uri_.userinfo = NULL;
}
}
else if (uri_.userinfo)
{
HttpRequest *client_req = this->get_req();
HttpHeaderCursor cursor(client_req);
struct HttpMessageHeader header = {
.name = "Authorization",
.name_len = strlen("Authorization")
};

cursor.find_and_erase(&header);
}

this->set_redirect(new_uri);
}
else if (this->state != WFT_STATE_SUCCESS)
this->disable_retry();
}
Expand All @@ -370,23 +433,6 @@ bool ComplexHttpTask::finish_once()

/*******Proxy Client*******/

static int __encode_auth(const char *p, std::string& auth)
{
size_t len = strlen(p);
size_t base64_len = (len + 2) / 3 * 4;
char *base64 = (char *)malloc(base64_len + 1);

if (!base64)
return -1;

EVP_EncodeBlock((unsigned char *)base64, (const unsigned char *)p, len);
auth.append("Basic ");
auth.append(base64, base64_len);

free(base64);
return 0;
}

static SSL *__create_ssl(SSL_CTX *ssl_ctx)
{
BIO *wbio;
Expand Down Expand Up @@ -635,7 +681,6 @@ bool ComplexHttpProxyTask::init_success()
{
this->state = WFT_STATE_TASK_ERROR;
this->error = WFT_ERR_URI_SCHEME_INVALID;
this->set_empty_request();
return false;
}

Expand All @@ -653,27 +698,32 @@ bool ComplexHttpProxyTask::init_success()
else
user_port = is_ssl_ ? 443 : 80;

std::string info("http-proxy|remote:");
info += is_ssl_ ? "https://" : "http://";
info += user_uri_.host;
info += ":";
if (user_uri_.port)
info += user_uri_.port;
else
info += is_ssl_ ? "443" : "80";

if (uri_.userinfo && uri_.userinfo[0])
{
std::string userinfo(uri_.userinfo);

StringUtil::url_decode(userinfo);
proxy_auth_.clear();
if (__encode_auth(uri_.userinfo, proxy_auth_) < 0)

if (__encode_auth(userinfo.c_str(), proxy_auth_) < 0)
{
this->state = WFT_STATE_SYS_ERROR;
this->error = errno;
return false;
}
}

std::string info("http-proxy|remote:");
info += is_ssl_ ? "https://" : "http://";
info += user_uri_.host;
info += ":";
if (user_uri_.port)
info += user_uri_.port;
else
info += is_ssl_ ? "443" : "80";
info += "|auth:";
info += proxy_auth_;
info += "|auth:";
info += proxy_auth_;
}

this->WFComplexClientTask::set_info(info);

Expand Down Expand Up @@ -704,6 +754,24 @@ bool ComplexHttpProxyTask::init_success()
client_req->set_request_uri(request_uri.c_str());
client_req->set_header_pair("Host", header_host.c_str());
this->WFComplexClientTask::set_transport_type(TT_TCP);

if (user_uri_.userinfo && user_uri_.userinfo[0])
{
std::string userinfo(user_uri_.userinfo);
std::string http_auth;

StringUtil::url_decode(userinfo);

if (__encode_auth(userinfo.c_str(), http_auth) < 0)
{
this->state = WFT_STATE_SYS_ERROR;
this->error = errno;
return false;
}

client_req->set_header_pair("Authorization", http_auth.c_str());
}

return true;
}

Expand Down Expand Up @@ -732,8 +800,33 @@ bool ComplexHttpProxyTask::finish_once()

if (this->state == WFT_STATE_SUCCESS)
{
if (this->need_redirect(user_uri_))
ParsedURI new_uri;
if (this->need_redirect(user_uri_, new_uri))
{
if (user_uri_.userinfo &&
strcasecmp(user_uri_.host, new_uri.host) == 0)
{
if (!new_uri.userinfo)
{
new_uri.userinfo = user_uri_.userinfo;
user_uri_.userinfo = NULL;
}
}
else if (user_uri_.userinfo)
{
HttpRequest *client_req = this->get_req();
HttpHeaderCursor cursor(client_req);
struct HttpMessageHeader header = {
.name = "Authorization",
.name_len = strlen("Authorization")
};

cursor.find_and_erase(&header);
}

user_uri_ = std::move(new_uri);
this->set_redirect(uri_);
}
else if (this->state != WFT_STATE_SUCCESS)
this->disable_retry();
}
Expand Down
31 changes: 14 additions & 17 deletions src/util/StringUtil.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,17 @@ static inline char __itoh(int n)
return n + '0';
}

size_t StringUtil::url_decode(char *str, size_t len)
static size_t __url_decode(char *str)
{
char *dest = str;
char *data = str;

while (len--)
while (*data)
{
if (*data == '%' && len >= 2
&& isxdigit(*(data + 1))
&& isxdigit(*(data + 2)))
if (*data == '%' && isxdigit(data[1]) && isxdigit(data[2]))
{
*dest = __htoi((unsigned char *)data + 1);
data += 2;
len -= 2;
}
else if (*data == '+')
*dest = ' ';
Expand All @@ -82,25 +79,25 @@ void StringUtil::url_decode(std::string& str)
if (str.empty())
return;

size_t sz = url_decode(const_cast<char *>(str.c_str()), str.size());
size_t sz = __url_decode(const_cast<char *>(str.c_str()));

str.resize(sz);
}

std::string StringUtil::url_encode(const std::string& str)
{
std::string res;
const char *cur = str.c_str();
const char *ed = cur + str.size();
std::string res;

while (cur < ed)
{
if (*cur == ' ')
res += '+';
else if (isalnum(*cur) || *cur == '-' || *cur == '_' || *cur == '.'
|| *cur == '!' || *cur == '~' || *cur == '*' || *cur == '\''
|| *cur == '(' || *cur == ')' || *cur == ':' || *cur == '/'
|| *cur == '@' || *cur == '?' || *cur == '#' || *cur == '&')
else if (isalnum(*cur) || *cur == '-' || *cur == '_' || *cur == '.' ||
*cur == '!' || *cur == '~' || *cur == '*' || *cur == '\'' ||
*cur == '(' || *cur == ')' || *cur == ':' || *cur == '/' ||
*cur == '@' || *cur == '?' || *cur == '#' || *cur == '&')
res += *cur;
else
{
Expand All @@ -117,17 +114,17 @@ std::string StringUtil::url_encode(const std::string& str)

std::string StringUtil::url_encode_component(const std::string& str)
{
std::string res;
const char *cur = str.c_str();
const char *ed = cur + str.size();
std::string res;

while (cur < ed)
{
if (*cur == ' ')
res += '+';
else if (isalnum(*cur) || *cur == '-' || *cur == '_' || *cur == '.'
|| *cur == '!' || *cur == '~' || *cur == '*' || *cur == '\''
|| *cur == '(' || *cur == ')')
else if (isalnum(*cur) || *cur == '-' || *cur == '_' || *cur == '.' ||
*cur == '!' || *cur == '~' || *cur == '*' || *cur == '\'' ||
*cur == '(' || *cur == ')')
res += *cur;
else
{
Expand All @@ -144,10 +141,10 @@ std::string StringUtil::url_encode_component(const std::string& str)

std::vector<std::string> StringUtil::split(const std::string& str, char sep)
{
std::vector<std::string> res;
std::string::const_iterator start = str.begin();
std::string::const_iterator end = str.end();
std::string::const_iterator next = find(start, end, sep);
std::vector<std::string> res;

while (next != end)
{
Expand Down
1 change: 0 additions & 1 deletion src/util/StringUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
class StringUtil
{
public:
static size_t url_decode(char *str, size_t len);
static void url_decode(std::string& str);
static std::string url_encode(const std::string& str);
static std::string url_encode_component(const std::string& str);
Expand Down

0 comments on commit d54f06e

Please sign in to comment.