207 lines
6.3 KiB
C++
207 lines
6.3 KiB
C++
// Copyright 2015 The Weave Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style license that can be
|
|
// found in the LICENSE file.
|
|
|
|
#include "examples/provider/ssl_stream.h"

#include <openssl/err.h>

#include <algorithm>
#include <limits>
#include <string>

#include <base/bind.h>
#include <base/bind_helpers.h>
#include <weave/provider/task_runner.h>
|
|
|
|
namespace weave {
|
|
namespace examples {
|
|
|
|
namespace {
|
|
|
|
// Appends an error to |error| whose message names the OpenSSL library and
// reason for the packed error-queue code |ssl_error_code| (the kind of value
// returned by ERR_get_error()). |error_code| becomes the weave error code.
void AddSslError(ErrorPtr* error,
                 const tracked_objects::Location& location,
                 const std::string& error_code,
                 unsigned long ssl_error_code) {
  // Load the human-readable strings so the lookups below can find text.
  ERR_load_BIO_strings();
  SSL_load_error_strings();
  // Both lookups return NULL when no text is registered for the code (for
  // example 0, from an empty error queue); passing NULL to "%s" is undefined.
  const char* lib = ERR_lib_error_string(ssl_error_code);
  const char* reason = ERR_reason_error_string(ssl_error_code);
  Error::AddToPrintf(error, location, error_code, "%s: %s",
                     lib ? lib : "unknown library",
                     reason ? reason : "unknown reason");
}
|
|
|
|
// Posts |task| to |task_runner| after a fixed polling delay. Used to re-attempt
// a non-blocking OpenSSL/BIO operation that asked to be retried.
// |location| identifies the caller and is forwarded to the task runner; the
// previous version ignored it and always reported this helper's location.
void RetryAsyncTask(provider::TaskRunner* task_runner,
                    const tracked_objects::Location& location,
                    const base::Closure& task) {
  const int64_t kRetryDelayMs = 100;
  task_runner->PostDelayedTask(location, task,
                               base::TimeDelta::FromMilliseconds(kRetryDelayMs));
}
|
|
|
|
} // namespace
|
|
|
|
// Frees a BIO held by a std::unique_ptr. The int status returned by
// BIO_free() is intentionally ignored.
void SSLStream::SslDeleter::operator()(BIO* bio) const {
  static_cast<void>(BIO_free(bio));
}
|
|
|
|
// Frees an SSL connection object held by a std::unique_ptr.
void SSLStream::SslDeleter::operator()(SSL* ssl) const {
  ::SSL_free(ssl);
}
|
|
|
|
// Frees an SSL_CTX held by a std::unique_ptr.
void SSLStream::SslDeleter::operator()(SSL_CTX* ctx) const {
  ::SSL_CTX_free(ctx);
}
|
|
|
|
// Creates a TLS 1.2 client context and connection on top of |stream_bio|.
// Ownership of |stream_bio| is handed to the SSL object. Aborts (CHECK) if
// OpenSSL cannot allocate the context or the connection.
SSLStream::SSLStream(provider::TaskRunner* task_runner,
                     std::unique_ptr<BIO, SslDeleter> stream_bio)
    : task_runner_{task_runner} {
  ctx_.reset(SSL_CTX_new(TLSv1_2_client_method()));
  CHECK(ctx_);
  // NOTE(review): nothing here configures SSL_CTX_set_verify(), a trust
  // store, or hostname checking. Unless that happens elsewhere, the server
  // certificate is not authenticated. Confirm this is acceptable.

  ssl_.reset(SSL_new(ctx_.get()));
  // A NULL |ssl_| would otherwise be passed straight to SSL_set_bio().
  CHECK(ssl_);

  // The same BIO is used for both reading and writing.
  SSL_set_bio(ssl_.get(), stream_bio.get(), stream_bio.get());
  stream_bio.release();  // Owned by ssl now.
  SSL_set_connect_state(ssl_.get());
}
|
|
|
|
// Cancels pending work on destruction. Tasks bound to this stream through its
// weak pointers become no-ops instead of touching a destroyed object.
SSLStream::~SSLStream() {
  SSLStream::CancelPendingOperations();
}
|
|
|
|
// Runs |task|. Callers in this file bind this method to a WeakPtr from
// |weak_ptr_factory_|. As a result, |task| is skipped rather than run once the
// stream's weak pointers have been invalidated (see CancelPendingOperations()).
void SSLStream::RunTask(const base::Closure& task) {
  task.Run();
}
|
|
|
|
void SSLStream::Read(void* buffer,
|
|
size_t size_to_read,
|
|
const ReadCallback& callback) {
|
|
int res = SSL_read(ssl_.get(), buffer, size_to_read);
|
|
if (res > 0) {
|
|
task_runner_->PostDelayedTask(
|
|
FROM_HERE,
|
|
base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
|
|
base::Bind(callback, res, nullptr)),
|
|
{});
|
|
return;
|
|
}
|
|
|
|
int err = SSL_get_error(ssl_.get(), res);
|
|
|
|
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
|
|
return RetryAsyncTask(
|
|
task_runner_, FROM_HERE,
|
|
base::Bind(&SSLStream::Read, weak_ptr_factory_.GetWeakPtr(), buffer,
|
|
size_to_read, callback));
|
|
}
|
|
|
|
ErrorPtr weave_error;
|
|
AddSslError(&weave_error, FROM_HERE, "read_failed", err);
|
|
return task_runner_->PostDelayedTask(
|
|
FROM_HERE,
|
|
base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
|
|
base::Bind(callback, 0, base::Passed(&weave_error))),
|
|
{});
|
|
}
|
|
|
|
void SSLStream::Write(const void* buffer,
|
|
size_t size_to_write,
|
|
const WriteCallback& callback) {
|
|
int res = SSL_write(ssl_.get(), buffer, size_to_write);
|
|
if (res > 0) {
|
|
buffer = static_cast<const char*>(buffer) + res;
|
|
size_to_write -= res;
|
|
if (size_to_write == 0) {
|
|
return task_runner_->PostDelayedTask(
|
|
FROM_HERE,
|
|
base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
|
|
base::Bind(callback, nullptr)),
|
|
{});
|
|
}
|
|
|
|
return RetryAsyncTask(
|
|
task_runner_, FROM_HERE,
|
|
base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer,
|
|
size_to_write, callback));
|
|
}
|
|
|
|
int err = SSL_get_error(ssl_.get(), res);
|
|
|
|
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
|
|
return RetryAsyncTask(
|
|
task_runner_, FROM_HERE,
|
|
base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer,
|
|
size_to_write, callback));
|
|
}
|
|
|
|
ErrorPtr weave_error;
|
|
AddSslError(&weave_error, FROM_HERE, "write_failed", err);
|
|
task_runner_->PostDelayedTask(
|
|
FROM_HERE, base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(),
|
|
base::Bind(callback, base::Passed(&weave_error))),
|
|
{});
|
|
}
|
|
|
|
void SSLStream::CancelPendingOperations() {
|
|
weak_ptr_factory_.InvalidateWeakPtrs();
|
|
}
|
|
|
|
// Opens a non-blocking TCP connection to |host|:|port| and then performs a TLS
// client handshake over it. |callback| receives the connected stream, or
// nullptr plus an error.
void SSLStream::Connect(
    provider::TaskRunner* task_runner,
    const std::string& host,
    uint16_t port,
    const provider::Network::OpenSslSocketCallback& callback) {
  SSL_library_init();

  // "host:port" in the form BIO_new_connect() expects. A std::string is used
  // because the former 255-byte buffer silently truncated long host names
  // (a DNS name may be 253 characters, plus ":65535"). Truncation would have
  // connected to the wrong endpoint.
  std::string end_point = host + ":" + std::to_string(port);

  // Pass a mutable pointer, because some OpenSSL releases declare
  // BIO_new_connect() with a non-const char* parameter.
  std::unique_ptr<BIO, SslDeleter> stream_bio(BIO_new_connect(&end_point[0]));
  CHECK(stream_bio);
  BIO_set_nbio(stream_bio.get(), 1);

  std::unique_ptr<SSLStream> stream{
      new SSLStream{task_runner, std::move(stream_bio)}};
  ConnectBio(std::move(stream), callback);
}
|
|
|
|
void SSLStream::ConnectBio(
|
|
std::unique_ptr<SSLStream> stream,
|
|
const provider::Network::OpenSslSocketCallback& callback) {
|
|
BIO* bio = SSL_get_rbio(stream->ssl_.get());
|
|
if (BIO_do_connect(bio) == 1)
|
|
return DoHandshake(std::move(stream), callback);
|
|
|
|
auto task_runner = stream->task_runner_;
|
|
if (BIO_should_retry(bio)) {
|
|
return RetryAsyncTask(
|
|
task_runner, FROM_HERE,
|
|
base::Bind(&SSLStream::ConnectBio, base::Passed(&stream), callback));
|
|
}
|
|
|
|
ErrorPtr error;
|
|
AddSslError(&error, FROM_HERE, "connect_failed", ERR_get_error());
|
|
task_runner->PostDelayedTask(
|
|
FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {});
|
|
}
|
|
|
|
void SSLStream::DoHandshake(
|
|
std::unique_ptr<SSLStream> stream,
|
|
const provider::Network::OpenSslSocketCallback& callback) {
|
|
int res = SSL_do_handshake(stream->ssl_.get());
|
|
auto task_runner = stream->task_runner_;
|
|
if (res == 1) {
|
|
return task_runner->PostDelayedTask(
|
|
FROM_HERE, base::Bind(callback, base::Passed(&stream), nullptr), {});
|
|
}
|
|
|
|
res = SSL_get_error(stream->ssl_.get(), res);
|
|
|
|
if (res == SSL_ERROR_WANT_READ || res == SSL_ERROR_WANT_WRITE) {
|
|
return RetryAsyncTask(
|
|
task_runner, FROM_HERE,
|
|
base::Bind(&SSLStream::DoHandshake, base::Passed(&stream), callback));
|
|
}
|
|
|
|
ErrorPtr error;
|
|
AddSslError(&error, FROM_HERE, "handshake_failed", res);
|
|
task_runner->PostDelayedTask(
|
|
FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {});
|
|
}
|
|
|
|
} // namespace examples
|
|
} // namespace weave
|