435 lines
13 KiB
C++
435 lines
13 KiB
C++
/*
|
|
* Copyright (C) 2017 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "dns/DnsTlsTransport.h"
|
|
|
|
#include <arpa/inet.h>
|
|
#include <arpa/nameser.h>
|
|
#include <errno.h>
|
|
#include <openssl/err.h>
|
|
#include <openssl/ssl.h>
|
|
#include <stdlib.h>
|
|
|
|
#define LOG_TAG "DnsTlsTransport"
|
|
#define DBG 0
|
|
|
|
#include "log/log.h"
|
|
#include "Fwmark.h"
|
|
#undef ADD // already defined in nameser.h
|
|
#include "NetdConstants.h"
|
|
#include "Permission.h"
|
|
|
|
|
|
namespace android {
|
|
namespace net {
|
|
|
|
namespace {
|
|
|
|
bool setNonBlocking(int fd, bool enabled) {
|
|
int flags = fcntl(fd, F_GETFL);
|
|
if (flags < 0) return false;
|
|
|
|
if (enabled) {
|
|
flags |= O_NONBLOCK;
|
|
} else {
|
|
flags &= ~O_NONBLOCK;
|
|
}
|
|
return (fcntl(fd, F_SETFL, flags) == 0);
|
|
}
|
|
|
|
int waitForReading(int fd) {
|
|
fd_set fds;
|
|
FD_ZERO(&fds);
|
|
FD_SET(fd, &fds);
|
|
const int ret = TEMP_FAILURE_RETRY(select(fd + 1, &fds, nullptr, nullptr, nullptr));
|
|
if (DBG && ret <= 0) {
|
|
ALOGD("select");
|
|
}
|
|
return ret;
|
|
}
|
|
|
|
int waitForWriting(int fd) {
|
|
fd_set fds;
|
|
FD_ZERO(&fds);
|
|
FD_SET(fd, &fds);
|
|
const int ret = TEMP_FAILURE_RETRY(select(fd + 1, nullptr, &fds, nullptr, nullptr));
|
|
if (DBG && ret <= 0) {
|
|
ALOGD("select");
|
|
}
|
|
return ret;
|
|
}
|
|
|
|
} // namespace
|
|
|
|
android::base::unique_fd DnsTlsTransport::makeConnectedSocket() const {
|
|
android::base::unique_fd fd;
|
|
int type = SOCK_NONBLOCK | SOCK_CLOEXEC;
|
|
switch (mProtocol) {
|
|
case IPPROTO_TCP:
|
|
type |= SOCK_STREAM;
|
|
break;
|
|
default:
|
|
errno = EPROTONOSUPPORT;
|
|
return fd;
|
|
}
|
|
|
|
fd.reset(socket(mAddr.ss_family, type, mProtocol));
|
|
if (fd.get() == -1) {
|
|
return fd;
|
|
}
|
|
|
|
const socklen_t len = sizeof(mMark);
|
|
if (setsockopt(fd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) {
|
|
fd.reset();
|
|
} else if (connect(fd.get(),
|
|
reinterpret_cast<const struct sockaddr *>(&mAddr), sizeof(mAddr)) != 0
|
|
&& errno != EINPROGRESS) {
|
|
fd.reset();
|
|
}
|
|
|
|
return fd;
|
|
}
|
|
|
|
bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
|
|
int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), NULL);
|
|
unsigned char spki[spki_len];
|
|
unsigned char* temp = spki;
|
|
if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
|
|
ALOGW("SPKI length mismatch");
|
|
return false;
|
|
}
|
|
out->resize(SHA256_SIZE);
|
|
unsigned int digest_len = 0;
|
|
int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), NULL);
|
|
if (ret != 1) {
|
|
ALOGW("Server cert digest extraction failed");
|
|
return false;
|
|
}
|
|
if (digest_len != out->size()) {
|
|
ALOGW("Wrong digest length: %d", digest_len);
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
SSL* DnsTlsTransport::sslConnect(int fd) {
|
|
if (fd < 0) {
|
|
ALOGD("%u makeConnectedSocket() failed with: %s", mMark, strerror(errno));
|
|
return nullptr;
|
|
}
|
|
|
|
// Set up TLS context.
|
|
bssl::UniquePtr<SSL_CTX> ssl_ctx(SSL_CTX_new(TLS_method()));
|
|
if (!SSL_CTX_set_max_proto_version(ssl_ctx.get(), TLS1_3_VERSION) ||
|
|
!SSL_CTX_set_min_proto_version(ssl_ctx.get(), TLS1_1_VERSION)) {
|
|
ALOGD("failed to min/max TLS versions");
|
|
return nullptr;
|
|
}
|
|
|
|
bssl::UniquePtr<SSL> ssl(SSL_new(ssl_ctx.get()));
|
|
bssl::UniquePtr<BIO> bio(BIO_new_socket(fd, BIO_CLOSE));
|
|
SSL_set_bio(ssl.get(), bio.get(), bio.get());
|
|
bio.release();
|
|
|
|
if (!setNonBlocking(fd, false)) {
|
|
ALOGE("Failed to disable nonblocking status on DNS-over-TLS fd");
|
|
return nullptr;
|
|
}
|
|
|
|
for (;;) {
|
|
if (DBG) {
|
|
ALOGD("%u Calling SSL_connect", mMark);
|
|
}
|
|
int ret = SSL_connect(ssl.get());
|
|
if (DBG) {
|
|
ALOGD("%u SSL_connect returned %d", mMark, ret);
|
|
}
|
|
if (ret == 1) break; // SSL handshake complete;
|
|
|
|
const int ssl_err = SSL_get_error(ssl.get(), ret);
|
|
switch (ssl_err) {
|
|
case SSL_ERROR_WANT_READ:
|
|
if (waitForReading(fd) != 1) {
|
|
ALOGW("SSL_connect read error");
|
|
return nullptr;
|
|
}
|
|
break;
|
|
case SSL_ERROR_WANT_WRITE:
|
|
if (waitForWriting(fd) != 1) {
|
|
ALOGW("SSL_connect write error");
|
|
return nullptr;
|
|
}
|
|
break;
|
|
default:
|
|
ALOGW("SSL_connect error %d, errno=%d", ssl_err, errno);
|
|
return nullptr;
|
|
}
|
|
}
|
|
|
|
if (!mFingerprints.empty()) {
|
|
if (DBG) {
|
|
ALOGD("Checking DNS over TLS fingerprint");
|
|
}
|
|
// TODO: Follow the cert chain and check all the way up.
|
|
bssl::UniquePtr<X509> cert(SSL_get_peer_certificate(ssl.get()));
|
|
if (!cert) {
|
|
ALOGW("Server has null certificate");
|
|
return nullptr;
|
|
}
|
|
std::vector<uint8_t> digest;
|
|
if (!getSPKIDigest(cert.get(), &digest)) {
|
|
ALOGE("Digest computation failed");
|
|
return nullptr;
|
|
}
|
|
|
|
if (mFingerprints.count(digest) == 0) {
|
|
ALOGW("No matching fingerprint");
|
|
return nullptr;
|
|
}
|
|
if (DBG) {
|
|
ALOGD("DNS over TLS fingerprint is correct");
|
|
}
|
|
}
|
|
|
|
if (DBG) {
|
|
ALOGD("%u handshake complete", mMark);
|
|
}
|
|
return ssl.release();
|
|
}
|
|
|
|
bool DnsTlsTransport::sslWrite(int fd, SSL *ssl, const uint8_t *buffer, int len) {
|
|
if (DBG) {
|
|
ALOGD("%u Writing %d bytes", mMark, len);
|
|
}
|
|
for (;;) {
|
|
int ret = SSL_write(ssl, buffer, len);
|
|
if (ret == len) break; // SSL write complete;
|
|
|
|
if (ret < 1) {
|
|
const int ssl_err = SSL_get_error(ssl, ret);
|
|
switch (ssl_err) {
|
|
case SSL_ERROR_WANT_WRITE:
|
|
if (waitForWriting(fd) != 1) {
|
|
if (DBG) {
|
|
ALOGW("SSL_write error");
|
|
}
|
|
return false;
|
|
}
|
|
continue;
|
|
case 0:
|
|
break; // SSL write complete;
|
|
default:
|
|
if (DBG) {
|
|
ALOGW("SSL_write error %d", ssl_err);
|
|
}
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
if (DBG) {
|
|
ALOGD("%u Wrote %d bytes", mMark, len);
|
|
}
|
|
return true;
|
|
}
|
|
|
|
// Read exactly len bytes into buffer or fail
|
|
bool DnsTlsTransport::sslRead(int fd, SSL *ssl, uint8_t *buffer, int len) {
|
|
int remaining = len;
|
|
while (remaining > 0) {
|
|
int ret = SSL_read(ssl, buffer + (len - remaining), remaining);
|
|
if (ret == 0) {
|
|
ALOGE("SSL socket closed with %i of %i bytes remaining", remaining, len);
|
|
return false;
|
|
}
|
|
|
|
if (ret < 0) {
|
|
const int ssl_err = SSL_get_error(ssl, ret);
|
|
if (ssl_err == SSL_ERROR_WANT_READ) {
|
|
if (waitForReading(fd) != 1) {
|
|
if (DBG) {
|
|
ALOGW("SSL_read error");
|
|
}
|
|
return false;
|
|
}
|
|
continue;
|
|
} else {
|
|
if (DBG) {
|
|
ALOGW("SSL_read error %d", ssl_err);
|
|
}
|
|
return false;
|
|
}
|
|
}
|
|
|
|
remaining -= ret;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
DnsTlsTransport::Response DnsTlsTransport::doQuery(const uint8_t *query, size_t qlen,
|
|
uint8_t *response, size_t limit, int *resplen) {
|
|
*resplen = 0; // Zero indicates an error.
|
|
|
|
if (DBG) {
|
|
ALOGD("%u connecting TCP socket", mMark);
|
|
}
|
|
android::base::unique_fd fd(makeConnectedSocket());
|
|
if (DBG) {
|
|
ALOGD("%u connecting SSL", mMark);
|
|
}
|
|
bssl::UniquePtr<SSL> ssl(sslConnect(fd));
|
|
if (ssl == nullptr) {
|
|
if (DBG) {
|
|
ALOGW("%u SSL connection failed", mMark);
|
|
}
|
|
return Response::network_error;
|
|
}
|
|
|
|
uint8_t queryHeader[2];
|
|
queryHeader[0] = qlen >> 8;
|
|
queryHeader[1] = qlen;
|
|
if (!sslWrite(fd.get(), ssl.get(), queryHeader, 2)) {
|
|
return Response::network_error;
|
|
}
|
|
if (!sslWrite(fd.get(), ssl.get(), query, qlen)) {
|
|
return Response::network_error;
|
|
}
|
|
if (DBG) {
|
|
ALOGD("%u SSL_write complete", mMark);
|
|
}
|
|
|
|
uint8_t responseHeader[2];
|
|
if (!sslRead(fd.get(), ssl.get(), responseHeader, 2)) {
|
|
if (DBG) {
|
|
ALOGW("%u Failed to read 2-byte length header", mMark);
|
|
}
|
|
return Response::network_error;
|
|
}
|
|
const uint16_t responseSize = (responseHeader[0] << 8) | responseHeader[1];
|
|
if (DBG) {
|
|
ALOGD("%u Expecting response of size %i", mMark, responseSize);
|
|
}
|
|
if (responseSize > limit) {
|
|
ALOGE("%u Response doesn't fit in output buffer: %i", mMark, responseSize);
|
|
return Response::limit_error;
|
|
}
|
|
if (!sslRead(fd.get(), ssl.get(), response, responseSize)) {
|
|
if (DBG) {
|
|
ALOGW("%u Failed to read %i bytes", mMark, responseSize);
|
|
}
|
|
return Response::network_error;
|
|
}
|
|
if (DBG) {
|
|
ALOGD("%u SSL_read complete", mMark);
|
|
}
|
|
|
|
if (response[0] != query[0] || response[1] != query[1]) {
|
|
ALOGE("reply query ID != query ID");
|
|
return Response::internal_error;
|
|
}
|
|
|
|
SSL_shutdown(ssl.get());
|
|
|
|
*resplen = responseSize;
|
|
return Response::success;
|
|
}
|
|
|
|
bool validateDnsTlsServer(unsigned netid, const struct sockaddr_storage& ss,
|
|
const std::set<std::vector<uint8_t>>& fingerprints) {
|
|
if (DBG) {
|
|
ALOGD("Beginning validation on %u", netid);
|
|
}
|
|
// Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in
|
|
// order to prove that it is actually a working DNS over TLS server.
|
|
static const char kDnsSafeChars[] =
|
|
"abcdefhijklmnopqrstuvwxyz"
|
|
"ABCDEFHIJKLMNOPQRSTUVWXYZ"
|
|
"0123456789";
|
|
const auto c = [](uint8_t rnd) -> uint8_t {
|
|
return kDnsSafeChars[(rnd % ARRAY_SIZE(kDnsSafeChars))];
|
|
};
|
|
uint8_t rnd[8];
|
|
arc4random_buf(rnd, ARRAY_SIZE(rnd));
|
|
// We could try to use res_mkquery() here, but it's basically the same.
|
|
uint8_t query[] = {
|
|
rnd[6], rnd[7], // [0-1] query ID
|
|
1, 0, // [2-3] flags; query[2] = 1 for recursion desired (RD).
|
|
0, 1, // [4-5] QDCOUNT (number of queries)
|
|
0, 0, // [6-7] ANCOUNT (number of answers)
|
|
0, 0, // [8-9] NSCOUNT (number of name server records)
|
|
0, 0, // [10-11] ARCOUNT (number of additional records)
|
|
17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]),
|
|
'-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's',
|
|
6, 'm', 'e', 't', 'r', 'i', 'c',
|
|
7, 'g', 's', 't', 'a', 't', 'i', 'c',
|
|
3, 'c', 'o', 'm',
|
|
0, // null terminator of FQDN (root TLD)
|
|
0, ns_t_aaaa, // QTYPE
|
|
0, ns_c_in // QCLASS
|
|
};
|
|
const int qlen = ARRAY_SIZE(query);
|
|
|
|
const int kRecvBufSize = 4 * 1024;
|
|
uint8_t recvbuf[kRecvBufSize];
|
|
|
|
// At validation time, we only know the netId, so we have to guess/compute the
|
|
// corresponding socket mark.
|
|
Fwmark fwmark;
|
|
fwmark.permission = PERMISSION_SYSTEM;
|
|
fwmark.explicitlySelected = true;
|
|
fwmark.protectedFromVpn = true;
|
|
fwmark.netId = netid;
|
|
unsigned mark = fwmark.intValue;
|
|
DnsTlsTransport xport(mark, IPPROTO_TCP, ss, fingerprints);
|
|
int replylen = 0;
|
|
xport.doQuery(query, qlen, recvbuf, kRecvBufSize, &replylen);
|
|
if (replylen == 0) {
|
|
if (DBG) {
|
|
ALOGD("doQuery failed");
|
|
}
|
|
return false;
|
|
}
|
|
|
|
if (replylen < NS_HFIXEDSZ) {
|
|
if (DBG) {
|
|
ALOGW("short response: %d", replylen);
|
|
}
|
|
return false;
|
|
}
|
|
|
|
const int qdcount = (recvbuf[4] << 8) | recvbuf[5];
|
|
if (qdcount != 1) {
|
|
ALOGW("reply query count != 1: %d", qdcount);
|
|
return false;
|
|
}
|
|
|
|
const int ancount = (recvbuf[6] << 8) | recvbuf[7];
|
|
if (DBG) {
|
|
ALOGD("%u answer count: %d", netid, ancount);
|
|
}
|
|
|
|
// TODO: Further validate the response contents (check for valid AAAA record, ...).
|
|
// Note that currently, integration tests rely on this function accepting a
|
|
// response with zero records.
|
|
#if 0
|
|
for (int i = 0; i < resplen; i++) {
|
|
ALOGD("recvbuf[%d] = %d %c", i, recvbuf[i], recvbuf[i]);
|
|
}
|
|
#endif
|
|
return true;
|
|
}
|
|
|
|
} // namespace net
|
|
} // namespace android
|