/* * Copyright (C) 2017 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "dns/DnsTlsTransport.h" #include #include #include #include #include #include #define LOG_TAG "DnsTlsTransport" #define DBG 0 #include "log/log.h" #include "Fwmark.h" #undef ADD // already defined in nameser.h #include "NetdConstants.h" #include "Permission.h" namespace android { namespace net { namespace { bool setNonBlocking(int fd, bool enabled) { int flags = fcntl(fd, F_GETFL); if (flags < 0) return false; if (enabled) { flags |= O_NONBLOCK; } else { flags &= ~O_NONBLOCK; } return (fcntl(fd, F_SETFL, flags) == 0); } int waitForReading(int fd) { fd_set fds; FD_ZERO(&fds); FD_SET(fd, &fds); const int ret = TEMP_FAILURE_RETRY(select(fd + 1, &fds, nullptr, nullptr, nullptr)); if (DBG && ret <= 0) { ALOGD("select"); } return ret; } int waitForWriting(int fd) { fd_set fds; FD_ZERO(&fds); FD_SET(fd, &fds); const int ret = TEMP_FAILURE_RETRY(select(fd + 1, nullptr, &fds, nullptr, nullptr)); if (DBG && ret <= 0) { ALOGD("select"); } return ret; } } // namespace android::base::unique_fd DnsTlsTransport::makeConnectedSocket() const { android::base::unique_fd fd; int type = SOCK_NONBLOCK | SOCK_CLOEXEC; switch (mProtocol) { case IPPROTO_TCP: type |= SOCK_STREAM; break; default: errno = EPROTONOSUPPORT; return fd; } fd.reset(socket(mAddr.ss_family, type, mProtocol)); if (fd.get() == -1) { return fd; } const socklen_t len = sizeof(mMark); if (setsockopt(fd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) { fd.reset(); } else if (connect(fd.get(), reinterpret_cast(&mAddr), sizeof(mAddr)) != 0 && errno != EINPROGRESS) { fd.reset(); } return fd; } bool getSPKIDigest(const X509* cert, std::vector* out) { int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), NULL); unsigned char spki[spki_len]; unsigned char* temp = spki; if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) { ALOGW("SPKI length mismatch"); return false; } out->resize(SHA256_SIZE); unsigned int digest_len = 0; int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), NULL); if (ret != 1) { ALOGW("Server cert digest extraction failed"); return false; } if (digest_len != out->size()) { ALOGW("Wrong digest length: %d", digest_len); return false; } return true; } SSL* DnsTlsTransport::sslConnect(int fd) { if (fd < 0) { ALOGD("%u makeConnectedSocket() failed with: %s", mMark, strerror(errno)); return nullptr; } // Set up TLS context. bssl::UniquePtr ssl_ctx(SSL_CTX_new(TLS_method())); if (!SSL_CTX_set_max_proto_version(ssl_ctx.get(), TLS1_3_VERSION) || !SSL_CTX_set_min_proto_version(ssl_ctx.get(), TLS1_1_VERSION)) { ALOGD("failed to min/max TLS versions"); return nullptr; } bssl::UniquePtr ssl(SSL_new(ssl_ctx.get())); bssl::UniquePtr bio(BIO_new_socket(fd, BIO_CLOSE)); SSL_set_bio(ssl.get(), bio.get(), bio.get()); bio.release(); if (!setNonBlocking(fd, false)) { ALOGE("Failed to disable nonblocking status on DNS-over-TLS fd"); return nullptr; } for (;;) { if (DBG) { ALOGD("%u Calling SSL_connect", mMark); } int ret = SSL_connect(ssl.get()); if (DBG) { ALOGD("%u SSL_connect returned %d", mMark, ret); } if (ret == 1) break; // SSL handshake complete; const int ssl_err = SSL_get_error(ssl.get(), ret); switch (ssl_err) { case SSL_ERROR_WANT_READ: if (waitForReading(fd) != 1) { ALOGW("SSL_connect read error"); return nullptr; } break; case SSL_ERROR_WANT_WRITE: if (waitForWriting(fd) != 1) { ALOGW("SSL_connect write error"); return nullptr; } break; default: ALOGW("SSL_connect error %d, errno=%d", ssl_err, errno); return nullptr; } } if (!mFingerprints.empty()) { if (DBG) { ALOGD("Checking DNS over TLS fingerprint"); } // TODO: Follow the cert chain and check all the way up. bssl::UniquePtr cert(SSL_get_peer_certificate(ssl.get())); if (!cert) { ALOGW("Server has null certificate"); return nullptr; } std::vector digest; if (!getSPKIDigest(cert.get(), &digest)) { ALOGE("Digest computation failed"); return nullptr; } if (mFingerprints.count(digest) == 0) { ALOGW("No matching fingerprint"); return nullptr; } if (DBG) { ALOGD("DNS over TLS fingerprint is correct"); } } if (DBG) { ALOGD("%u handshake complete", mMark); } return ssl.release(); } bool DnsTlsTransport::sslWrite(int fd, SSL *ssl, const uint8_t *buffer, int len) { if (DBG) { ALOGD("%u Writing %d bytes", mMark, len); } for (;;) { int ret = SSL_write(ssl, buffer, len); if (ret == len) break; // SSL write complete; if (ret < 1) { const int ssl_err = SSL_get_error(ssl, ret); switch (ssl_err) { case SSL_ERROR_WANT_WRITE: if (waitForWriting(fd) != 1) { if (DBG) { ALOGW("SSL_write error"); } return false; } continue; case 0: break; // SSL write complete; default: if (DBG) { ALOGW("SSL_write error %d", ssl_err); } return false; } } } if (DBG) { ALOGD("%u Wrote %d bytes", mMark, len); } return true; } // Read exactly len bytes into buffer or fail bool DnsTlsTransport::sslRead(int fd, SSL *ssl, uint8_t *buffer, int len) { int remaining = len; while (remaining > 0) { int ret = SSL_read(ssl, buffer + (len - remaining), remaining); if (ret == 0) { ALOGE("SSL socket closed with %i of %i bytes remaining", remaining, len); return false; } if (ret < 0) { const int ssl_err = SSL_get_error(ssl, ret); if (ssl_err == SSL_ERROR_WANT_READ) { if (waitForReading(fd) != 1) { if (DBG) { ALOGW("SSL_read error"); } return false; } continue; } else { if (DBG) { ALOGW("SSL_read error %d", ssl_err); } return false; } } remaining -= ret; } return true; } DnsTlsTransport::Response DnsTlsTransport::doQuery(const uint8_t *query, size_t qlen, uint8_t *response, size_t limit, int *resplen) { *resplen = 0; // Zero indicates an error. if (DBG) { ALOGD("%u connecting TCP socket", mMark); } android::base::unique_fd fd(makeConnectedSocket()); if (DBG) { ALOGD("%u connecting SSL", mMark); } bssl::UniquePtr ssl(sslConnect(fd)); if (ssl == nullptr) { if (DBG) { ALOGW("%u SSL connection failed", mMark); } return Response::network_error; } uint8_t queryHeader[2]; queryHeader[0] = qlen >> 8; queryHeader[1] = qlen; if (!sslWrite(fd.get(), ssl.get(), queryHeader, 2)) { return Response::network_error; } if (!sslWrite(fd.get(), ssl.get(), query, qlen)) { return Response::network_error; } if (DBG) { ALOGD("%u SSL_write complete", mMark); } uint8_t responseHeader[2]; if (!sslRead(fd.get(), ssl.get(), responseHeader, 2)) { if (DBG) { ALOGW("%u Failed to read 2-byte length header", mMark); } return Response::network_error; } const uint16_t responseSize = (responseHeader[0] << 8) | responseHeader[1]; if (DBG) { ALOGD("%u Expecting response of size %i", mMark, responseSize); } if (responseSize > limit) { ALOGE("%u Response doesn't fit in output buffer: %i", mMark, responseSize); return Response::limit_error; } if (!sslRead(fd.get(), ssl.get(), response, responseSize)) { if (DBG) { ALOGW("%u Failed to read %i bytes", mMark, responseSize); } return Response::network_error; } if (DBG) { ALOGD("%u SSL_read complete", mMark); } if (response[0] != query[0] || response[1] != query[1]) { ALOGE("reply query ID != query ID"); return Response::internal_error; } SSL_shutdown(ssl.get()); *resplen = responseSize; return Response::success; } bool validateDnsTlsServer(unsigned netid, const struct sockaddr_storage& ss, const std::set>& fingerprints) { if (DBG) { ALOGD("Beginning validation on %u", netid); } // Generate "-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in // order to prove that it is actually a working DNS over TLS server. static const char kDnsSafeChars[] = "abcdefhijklmnopqrstuvwxyz" "ABCDEFHIJKLMNOPQRSTUVWXYZ" "0123456789"; const auto c = [](uint8_t rnd) -> uint8_t { return kDnsSafeChars[(rnd % ARRAY_SIZE(kDnsSafeChars))]; }; uint8_t rnd[8]; arc4random_buf(rnd, ARRAY_SIZE(rnd)); // We could try to use res_mkquery() here, but it's basically the same. uint8_t query[] = { rnd[6], rnd[7], // [0-1] query ID 1, 0, // [2-3] flags; query[2] = 1 for recursion desired (RD). 0, 1, // [4-5] QDCOUNT (number of queries) 0, 0, // [6-7] ANCOUNT (number of answers) 0, 0, // [8-9] NSCOUNT (number of name server records) 0, 0, // [10-11] ARCOUNT (number of additional records) 17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]), '-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's', 6, 'm', 'e', 't', 'r', 'i', 'c', 7, 'g', 's', 't', 'a', 't', 'i', 'c', 3, 'c', 'o', 'm', 0, // null terminator of FQDN (root TLD) 0, ns_t_aaaa, // QTYPE 0, ns_c_in // QCLASS }; const int qlen = ARRAY_SIZE(query); const int kRecvBufSize = 4 * 1024; uint8_t recvbuf[kRecvBufSize]; // At validation time, we only know the netId, so we have to guess/compute the // corresponding socket mark. Fwmark fwmark; fwmark.permission = PERMISSION_SYSTEM; fwmark.explicitlySelected = true; fwmark.protectedFromVpn = true; fwmark.netId = netid; unsigned mark = fwmark.intValue; DnsTlsTransport xport(mark, IPPROTO_TCP, ss, fingerprints); int replylen = 0; xport.doQuery(query, qlen, recvbuf, kRecvBufSize, &replylen); if (replylen == 0) { if (DBG) { ALOGD("doQuery failed"); } return false; } if (replylen < NS_HFIXEDSZ) { if (DBG) { ALOGW("short response: %d", replylen); } return false; } const int qdcount = (recvbuf[4] << 8) | recvbuf[5]; if (qdcount != 1) { ALOGW("reply query count != 1: %d", qdcount); return false; } const int ancount = (recvbuf[6] << 8) | recvbuf[7]; if (DBG) { ALOGD("%u answer count: %d", netid, ancount); } // TODO: Further validate the response contents (check for valid AAAA record, ...). // Note that currently, integration tests rely on this function accepting a // response with zero records. #if 0 for (int i = 0; i < resplen; i++) { ALOGD("recvbuf[%d] = %d %c", i, recvbuf[i], recvbuf[i]); } #endif return true; } } // namespace net } // namespace android