upload android base code part6

This commit is contained in:
August 2018-08-08 17:48:24 +08:00
parent 421e214c7d
commit 4e516ec6ed
35396 changed files with 9188716 additions and 0 deletions

View file

@ -0,0 +1,45 @@
#
# Copyright (C) 2016 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
LOCAL_PATH := $(call my-dir)
# TODO describe library here
include $(CLEAR_VARS)
LOCAL_MODULE := libnetd_test_dnsresponder
LOCAL_CFLAGS := -Wall -Werror -Wunused-parameter
# Bug: http://b/29823425 Disable -Wvarargs for Clang update to r271374
LOCAL_CFLAGS += -Wno-varargs
EXTRA_LDLIBS := -lpthread
LOCAL_SHARED_LIBRARIES += libbase libbinder libcrypto liblog libnetd_client libssl
LOCAL_STATIC_LIBRARIES += libutils
LOCAL_AIDL_INCLUDES += system/netd/server/binder
LOCAL_C_INCLUDES += system/netd/include \
system/netd/server \
system/netd/server/binder \
system/netd/tests/dns_responder \
bionic/libc/dns/include
LOCAL_SRC_FILES := dns_responder.cpp \
dns_responder_client.cpp \
dns_tls_frontend.cpp \
../../server/binder/android/net/INetd.aidl \
../../server/binder/android/net/UidRange.cpp
LOCAL_MODULE_TAGS := eng tests
include $(BUILD_STATIC_LIBRARY)

View file

@ -0,0 +1,842 @@
/*
* Copyright (C) 2016 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dns_responder.h"
#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include <iostream>
#include <vector>
#define LOG_TAG "DNSResponder"
#include <log/log.h>
namespace test {
std::string errno2str() {
char error_msg[512] = { 0 };
if (strerror_r(errno, error_msg, sizeof(error_msg)))
return std::string();
return std::string(error_msg);
}
#define APLOGI(fmt, ...) ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str())
std::string str2hex(const char* buffer, size_t len) {
std::string str(len*2, '\0');
for (size_t i = 0 ; i < len ; ++i) {
static const char* hex = "0123456789ABCDEF";
uint8_t c = buffer[i];
str[i*2] = hex[c >> 4];
str[i*2 + 1] = hex[c & 0x0F];
}
return str;
}
std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
char host_str[NI_MAXHOST] = { 0 };
int rv = getnameinfo(sa, sa_len, host_str, sizeof(host_str), nullptr, 0,
NI_NUMERICHOST);
if (rv == 0) return std::string(host_str);
return std::string();
}
/* DNS struct helpers */
const char* dnstype2str(unsigned dnstype) {
static std::unordered_map<unsigned, const char*> kTypeStrs = {
{ ns_type::ns_t_a, "A" },
{ ns_type::ns_t_ns, "NS" },
{ ns_type::ns_t_md, "MD" },
{ ns_type::ns_t_mf, "MF" },
{ ns_type::ns_t_cname, "CNAME" },
{ ns_type::ns_t_soa, "SOA" },
{ ns_type::ns_t_mb, "MB" },
{ ns_type::ns_t_mb, "MG" },
{ ns_type::ns_t_mr, "MR" },
{ ns_type::ns_t_null, "NULL" },
{ ns_type::ns_t_wks, "WKS" },
{ ns_type::ns_t_ptr, "PTR" },
{ ns_type::ns_t_hinfo, "HINFO" },
{ ns_type::ns_t_minfo, "MINFO" },
{ ns_type::ns_t_mx, "MX" },
{ ns_type::ns_t_txt, "TXT" },
{ ns_type::ns_t_rp, "RP" },
{ ns_type::ns_t_afsdb, "AFSDB" },
{ ns_type::ns_t_x25, "X25" },
{ ns_type::ns_t_isdn, "ISDN" },
{ ns_type::ns_t_rt, "RT" },
{ ns_type::ns_t_nsap, "NSAP" },
{ ns_type::ns_t_nsap_ptr, "NSAP-PTR" },
{ ns_type::ns_t_sig, "SIG" },
{ ns_type::ns_t_key, "KEY" },
{ ns_type::ns_t_px, "PX" },
{ ns_type::ns_t_gpos, "GPOS" },
{ ns_type::ns_t_aaaa, "AAAA" },
{ ns_type::ns_t_loc, "LOC" },
{ ns_type::ns_t_nxt, "NXT" },
{ ns_type::ns_t_eid, "EID" },
{ ns_type::ns_t_nimloc, "NIMLOC" },
{ ns_type::ns_t_srv, "SRV" },
{ ns_type::ns_t_naptr, "NAPTR" },
{ ns_type::ns_t_kx, "KX" },
{ ns_type::ns_t_cert, "CERT" },
{ ns_type::ns_t_a6, "A6" },
{ ns_type::ns_t_dname, "DNAME" },
{ ns_type::ns_t_sink, "SINK" },
{ ns_type::ns_t_opt, "OPT" },
{ ns_type::ns_t_apl, "APL" },
{ ns_type::ns_t_tkey, "TKEY" },
{ ns_type::ns_t_tsig, "TSIG" },
{ ns_type::ns_t_ixfr, "IXFR" },
{ ns_type::ns_t_axfr, "AXFR" },
{ ns_type::ns_t_mailb, "MAILB" },
{ ns_type::ns_t_maila, "MAILA" },
{ ns_type::ns_t_any, "ANY" },
{ ns_type::ns_t_zxfr, "ZXFR" },
};
auto it = kTypeStrs.find(dnstype);
static const char* kUnknownStr{ "UNKNOWN" };
if (it == kTypeStrs.end()) return kUnknownStr;
return it->second;
}
const char* dnsclass2str(unsigned dnsclass) {
static std::unordered_map<unsigned, const char*> kClassStrs = {
{ ns_class::ns_c_in , "Internet" },
{ 2, "CSNet" },
{ ns_class::ns_c_chaos, "ChaosNet" },
{ ns_class::ns_c_hs, "Hesiod" },
{ ns_class::ns_c_none, "none" },
{ ns_class::ns_c_any, "any" }
};
auto it = kClassStrs.find(dnsclass);
static const char* kUnknownStr{ "UNKNOWN" };
if (it == kClassStrs.end()) return kUnknownStr;
return it->second;
return "unknown";
}
struct DNSName {
std::string name;
const char* read(const char* buffer, const char* buffer_end);
char* write(char* buffer, const char* buffer_end) const;
const char* toString() const;
private:
const char* parseField(const char* buffer, const char* buffer_end,
bool* last);
};
const char* DNSName::toString() const {
return name.c_str();
}
const char* DNSName::read(const char* buffer, const char* buffer_end) {
const char* cur = buffer;
bool last = false;
do {
cur = parseField(cur, buffer_end, &last);
if (cur == nullptr) {
ALOGI("parsing failed at line %d", __LINE__);
return nullptr;
}
} while (!last);
return cur;
}
char* DNSName::write(char* buffer, const char* buffer_end) const {
char* buffer_cur = buffer;
for (size_t pos = 0 ; pos < name.size() ; ) {
size_t dot_pos = name.find('.', pos);
if (dot_pos == std::string::npos) {
// Sanity check, should never happen unless parseField is broken.
ALOGI("logic error: all names are expected to end with a '.'");
return nullptr;
}
size_t len = dot_pos - pos;
if (len >= 256) {
ALOGI("name component '%s' is %zu long, but max is 255",
name.substr(pos, dot_pos - pos).c_str(), len);
return nullptr;
}
if (buffer_cur + sizeof(uint8_t) + len > buffer_end) {
ALOGI("buffer overflow at line %d", __LINE__);
return nullptr;
}
*buffer_cur++ = len;
buffer_cur = std::copy(std::next(name.begin(), pos),
std::next(name.begin(), dot_pos),
buffer_cur);
pos = dot_pos + 1;
}
// Write final zero.
*buffer_cur++ = 0;
return buffer_cur;
}
const char* DNSName::parseField(const char* buffer, const char* buffer_end,
bool* last) {
if (buffer + sizeof(uint8_t) > buffer_end) {
ALOGI("parsing failed at line %d", __LINE__);
return nullptr;
}
unsigned field_type = *buffer >> 6;
unsigned ofs = *buffer & 0x3F;
const char* cur = buffer + sizeof(uint8_t);
if (field_type == 0) {
// length + name component
if (ofs == 0) {
*last = true;
return cur;
}
if (cur + ofs > buffer_end) {
ALOGI("parsing failed at line %d", __LINE__);
return nullptr;
}
name.append(cur, ofs);
name.push_back('.');
return cur + ofs;
} else if (field_type == 3) {
ALOGI("name compression not implemented");
return nullptr;
}
ALOGI("invalid name field type");
return nullptr;
}
struct DNSQuestion {
DNSName qname;
unsigned qtype;
unsigned qclass;
const char* read(const char* buffer, const char* buffer_end);
char* write(char* buffer, const char* buffer_end) const;
std::string toString() const;
};
const char* DNSQuestion::read(const char* buffer, const char* buffer_end) {
const char* cur = qname.read(buffer, buffer_end);
if (cur == nullptr) {
ALOGI("parsing failed at line %d", __LINE__);
return nullptr;
}
if (cur + 2*sizeof(uint16_t) > buffer_end) {
ALOGI("parsing failed at line %d", __LINE__);
return nullptr;
}
qtype = ntohs(*reinterpret_cast<const uint16_t*>(cur));
qclass = ntohs(*reinterpret_cast<const uint16_t*>(cur + sizeof(uint16_t)));
return cur + 2*sizeof(uint16_t);
}
char* DNSQuestion::write(char* buffer, const char* buffer_end) const {
char* buffer_cur = qname.write(buffer, buffer_end);
if (buffer_cur == nullptr) return nullptr;
if (buffer_cur + 2*sizeof(uint16_t) > buffer_end) {
ALOGI("buffer overflow on line %d", __LINE__);
return nullptr;
}
*reinterpret_cast<uint16_t*>(buffer_cur) = htons(qtype);
*reinterpret_cast<uint16_t*>(buffer_cur + sizeof(uint16_t)) =
htons(qclass);
return buffer_cur + 2*sizeof(uint16_t);
}
std::string DNSQuestion::toString() const {
char buffer[4096];
int len = snprintf(buffer, sizeof(buffer), "Q<%s,%s,%s>", qname.toString(),
dnstype2str(qtype), dnsclass2str(qclass));
return std::string(buffer, len);
}
struct DNSRecord {
DNSName name;
unsigned rtype;
unsigned rclass;
unsigned ttl;
std::vector<char> rdata;
const char* read(const char* buffer, const char* buffer_end);
char* write(char* buffer, const char* buffer_end) const;
std::string toString() const;
private:
struct IntFields {
uint16_t rtype;
uint16_t rclass;
uint32_t ttl;
uint16_t rdlen;
} __attribute__((__packed__));
const char* readIntFields(const char* buffer, const char* buffer_end,
unsigned* rdlen);
char* writeIntFields(unsigned rdlen, char* buffer,
const char* buffer_end) const;
};
const char* DNSRecord::read(const char* buffer, const char* buffer_end) {
const char* cur = name.read(buffer, buffer_end);
if (cur == nullptr) {
ALOGI("parsing failed at line %d", __LINE__);
return nullptr;
}
unsigned rdlen = 0;
cur = readIntFields(cur, buffer_end, &rdlen);
if (cur == nullptr) {
ALOGI("parsing failed at line %d", __LINE__);
return nullptr;
}
if (cur + rdlen > buffer_end) {
ALOGI("parsing failed at line %d", __LINE__);
return nullptr;
}
rdata.assign(cur, cur + rdlen);
return cur + rdlen;
}
char* DNSRecord::write(char* buffer, const char* buffer_end) const {
char* buffer_cur = name.write(buffer, buffer_end);
if (buffer_cur == nullptr) return nullptr;
buffer_cur = writeIntFields(rdata.size(), buffer_cur, buffer_end);
if (buffer_cur == nullptr) return nullptr;
if (buffer_cur + rdata.size() > buffer_end) {
ALOGI("buffer overflow on line %d", __LINE__);
return nullptr;
}
return std::copy(rdata.begin(), rdata.end(), buffer_cur);
}
std::string DNSRecord::toString() const {
char buffer[4096];
int len = snprintf(buffer, sizeof(buffer), "R<%s,%s,%s>", name.toString(),
dnstype2str(rtype), dnsclass2str(rclass));
return std::string(buffer, len);
}
const char* DNSRecord::readIntFields(const char* buffer, const char* buffer_end,
unsigned* rdlen) {
if (buffer + sizeof(IntFields) > buffer_end ) {
ALOGI("parsing failed at line %d", __LINE__);
return nullptr;
}
const auto& intfields = *reinterpret_cast<const IntFields*>(buffer);
rtype = ntohs(intfields.rtype);
rclass = ntohs(intfields.rclass);
ttl = ntohl(intfields.ttl);
*rdlen = ntohs(intfields.rdlen);
return buffer + sizeof(IntFields);
}
char* DNSRecord::writeIntFields(unsigned rdlen, char* buffer,
const char* buffer_end) const {
if (buffer + sizeof(IntFields) > buffer_end ) {
ALOGI("buffer overflow on line %d", __LINE__);
return nullptr;
}
auto& intfields = *reinterpret_cast<IntFields*>(buffer);
intfields.rtype = htons(rtype);
intfields.rclass = htons(rclass);
intfields.ttl = htonl(ttl);
intfields.rdlen = htons(rdlen);
return buffer + sizeof(IntFields);
}
struct DNSHeader {
unsigned id;
bool ra;
uint8_t rcode;
bool qr;
uint8_t opcode;
bool aa;
bool tr;
bool rd;
std::vector<DNSQuestion> questions;
std::vector<DNSRecord> answers;
std::vector<DNSRecord> authorities;
std::vector<DNSRecord> additionals;
const char* read(const char* buffer, const char* buffer_end);
char* write(char* buffer, const char* buffer_end) const;
std::string toString() const;
private:
struct Header {
uint16_t id;
uint8_t flags0;
uint8_t flags1;
uint16_t qdcount;
uint16_t ancount;
uint16_t nscount;
uint16_t arcount;
} __attribute__((__packed__));
const char* readHeader(const char* buffer, const char* buffer_end,
unsigned* qdcount, unsigned* ancount,
unsigned* nscount, unsigned* arcount);
};
const char* DNSHeader::read(const char* buffer, const char* buffer_end) {
unsigned qdcount;
unsigned ancount;
unsigned nscount;
unsigned arcount;
const char* cur = readHeader(buffer, buffer_end, &qdcount, &ancount,
&nscount, &arcount);
if (cur == nullptr) {
ALOGI("parsing failed at line %d", __LINE__);
return nullptr;
}
if (qdcount) {
questions.resize(qdcount);
for (unsigned i = 0 ; i < qdcount ; ++i) {
cur = questions[i].read(cur, buffer_end);
if (cur == nullptr) {
ALOGI("parsing failed at line %d", __LINE__);
return nullptr;
}
}
}
if (ancount) {
answers.resize(ancount);
for (unsigned i = 0 ; i < ancount ; ++i) {
cur = answers[i].read(cur, buffer_end);
if (cur == nullptr) {
ALOGI("parsing failed at line %d", __LINE__);
return nullptr;
}
}
}
if (nscount) {
authorities.resize(nscount);
for (unsigned i = 0 ; i < nscount ; ++i) {
cur = authorities[i].read(cur, buffer_end);
if (cur == nullptr) {
ALOGI("parsing failed at line %d", __LINE__);
return nullptr;
}
}
}
if (arcount) {
additionals.resize(arcount);
for (unsigned i = 0 ; i < arcount ; ++i) {
cur = additionals[i].read(cur, buffer_end);
if (cur == nullptr) {
ALOGI("parsing failed at line %d", __LINE__);
return nullptr;
}
}
}
return cur;
}
char* DNSHeader::write(char* buffer, const char* buffer_end) const {
if (buffer + sizeof(Header) > buffer_end) {
ALOGI("buffer overflow on line %d", __LINE__);
return nullptr;
}
Header& header = *reinterpret_cast<Header*>(buffer);
// bytes 0-1
header.id = htons(id);
// byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
header.flags0 = (qr << 7) | (opcode << 3) | (aa << 2) | (tr << 1) | rd;
// byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
header.flags1 = rcode;
// rest of header
header.qdcount = htons(questions.size());
header.ancount = htons(answers.size());
header.nscount = htons(authorities.size());
header.arcount = htons(additionals.size());
char* buffer_cur = buffer + sizeof(Header);
for (const DNSQuestion& question : questions) {
buffer_cur = question.write(buffer_cur, buffer_end);
if (buffer_cur == nullptr) return nullptr;
}
for (const DNSRecord& answer : answers) {
buffer_cur = answer.write(buffer_cur, buffer_end);
if (buffer_cur == nullptr) return nullptr;
}
for (const DNSRecord& authority : authorities) {
buffer_cur = authority.write(buffer_cur, buffer_end);
if (buffer_cur == nullptr) return nullptr;
}
for (const DNSRecord& additional : additionals) {
buffer_cur = additional.write(buffer_cur, buffer_end);
if (buffer_cur == nullptr) return nullptr;
}
return buffer_cur;
}
std::string DNSHeader::toString() const {
// TODO
return std::string();
}
const char* DNSHeader::readHeader(const char* buffer, const char* buffer_end,
unsigned* qdcount, unsigned* ancount,
unsigned* nscount, unsigned* arcount) {
if (buffer + sizeof(Header) > buffer_end)
return 0;
const auto& header = *reinterpret_cast<const Header*>(buffer);
// bytes 0-1
id = ntohs(header.id);
// byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
qr = header.flags0 >> 7;
opcode = (header.flags0 >> 3) & 0x0F;
aa = (header.flags0 >> 2) & 1;
tr = (header.flags0 >> 1) & 1;
rd = header.flags0 & 1;
// byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
ra = header.flags1 >> 7;
rcode = header.flags1 & 0xF;
// rest of header
*qdcount = ntohs(header.qdcount);
*ancount = ntohs(header.ancount);
*nscount = ntohs(header.nscount);
*arcount = ntohs(header.arcount);
return buffer + sizeof(Header);
}
/* DNS responder */
DNSResponder::DNSResponder(std::string listen_address,
std::string listen_service, int poll_timeout_ms,
uint16_t error_rcode, double response_probability) :
listen_address_(std::move(listen_address)), listen_service_(std::move(listen_service)),
poll_timeout_ms_(poll_timeout_ms), error_rcode_(error_rcode),
response_probability_(response_probability),
socket_(-1), epoll_fd_(-1), terminate_(false) { }
DNSResponder::~DNSResponder() {
stopServer();
}
void DNSResponder::addMapping(const char* name, ns_type type,
const char* addr) {
std::lock_guard<std::mutex> lock(mappings_mutex_);
auto it = mappings_.find(QueryKey(name, type));
if (it != mappings_.end()) {
ALOGI("Overwriting mapping for (%s, %s), previous address %s, new "
"address %s", name, dnstype2str(type), it->second.c_str(),
addr);
it->second = addr;
return;
}
mappings_.emplace(std::piecewise_construct,
std::forward_as_tuple(name, type),
std::forward_as_tuple(addr));
}
void DNSResponder::removeMapping(const char* name, ns_type type) {
std::lock_guard<std::mutex> lock(mappings_mutex_);
auto it = mappings_.find(QueryKey(name, type));
if (it != mappings_.end()) {
ALOGI("Cannot remove mapping mapping from (%s, %s), not present", name,
dnstype2str(type));
return;
}
mappings_.erase(it);
}
void DNSResponder::setResponseProbability(double response_probability) {
response_probability_ = response_probability;
}
bool DNSResponder::running() const {
return socket_ != -1;
}
bool DNSResponder::startServer() {
if (running()) {
ALOGI("server already running");
return false;
}
addrinfo ai_hints{
.ai_family = AF_UNSPEC,
.ai_socktype = SOCK_DGRAM,
.ai_flags = AI_PASSIVE
};
addrinfo* ai_res;
int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(),
&ai_hints, &ai_res);
if (rv) {
ALOGI("getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
listen_service_.c_str(), gai_strerror(rv));
return false;
}
int s = -1;
for (const addrinfo* ai = ai_res ; ai ; ai = ai->ai_next) {
s = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
if (s < 0) continue;
const int one = 1;
setsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
if (bind(s, ai->ai_addr, ai->ai_addrlen)) {
APLOGI("bind failed for socket %d", s);
close(s);
s = -1;
continue;
}
std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
ALOGI("bound to UDP %s:%s", host_str.c_str(), listen_service_.c_str());
break;
}
freeaddrinfo(ai_res);
if (s < 0) {
ALOGI("bind() failed");
return false;
}
int flags = fcntl(s, F_GETFL, 0);
if (flags < 0) flags = 0;
if (fcntl(s, F_SETFL, flags | O_NONBLOCK) < 0) {
APLOGI("fcntl(F_SETFL) failed for socket %d", s);
close(s);
return false;
}
int ep_fd = epoll_create(1);
if (ep_fd < 0) {
char error_msg[512] = { 0 };
if (strerror_r(errno, error_msg, sizeof(error_msg)))
strncpy(error_msg, "UNKNOWN", sizeof(error_msg));
APLOGI("epoll_create() failed: %s", error_msg);
close(s);
return false;
}
epoll_event ev;
ev.events = EPOLLIN;
ev.data.fd = s;
if (epoll_ctl(ep_fd, EPOLL_CTL_ADD, s, &ev) < 0) {
APLOGI("epoll_ctl() failed for socket %d", s);
close(ep_fd);
close(s);
return false;
}
epoll_fd_ = ep_fd;
socket_ = s;
{
std::lock_guard<std::mutex> lock(update_mutex_);
handler_thread_ = std::thread(&DNSResponder::requestHandler, this);
}
ALOGI("server started successfully");
return true;
}
bool DNSResponder::stopServer() {
std::lock_guard<std::mutex> lock(update_mutex_);
if (!running()) {
ALOGI("server not running");
return false;
}
if (terminate_) {
ALOGI("LOGIC ERROR");
return false;
}
ALOGI("stopping server");
terminate_ = true;
handler_thread_.join();
close(epoll_fd_);
close(socket_);
terminate_ = false;
socket_ = -1;
ALOGI("server stopped successfully");
return true;
}
std::vector<std::pair<std::string, ns_type >> DNSResponder::queries() const {
std::lock_guard<std::mutex> lock(queries_mutex_);
return queries_;
}
void DNSResponder::clearQueries() {
std::lock_guard<std::mutex> lock(queries_mutex_);
queries_.clear();
}
void DNSResponder::requestHandler() {
epoll_event evs[1];
while (!terminate_) {
int n = epoll_wait(epoll_fd_, evs, 1, poll_timeout_ms_);
if (n == 0) continue;
if (n < 0) {
ALOGI("epoll_wait() failed");
// TODO(imaipi): terminate on error.
return;
}
char buffer[4096];
sockaddr_storage sa;
socklen_t sa_len = sizeof(sa);
ssize_t len;
do {
len = recvfrom(socket_, buffer, sizeof(buffer), 0,
(sockaddr*) &sa, &sa_len);
} while (len < 0 && (errno == EAGAIN || errno == EINTR));
if (len <= 0) {
ALOGI("recvfrom() failed");
continue;
}
ALOGI("read %zd bytes", len);
char response[4096];
size_t response_len = sizeof(response);
if (handleDNSRequest(buffer, len, response, &response_len) &&
response_len > 0) {
len = sendto(socket_, response, response_len, 0,
reinterpret_cast<const sockaddr*>(&sa), sa_len);
std::string host_str =
addr2str(reinterpret_cast<const sockaddr*>(&sa), sa_len);
if (len > 0) {
ALOGI("sent %zu bytes to %s", len, host_str.c_str());
} else {
APLOGI("sendto() failed for %s", host_str.c_str());
}
// Test that the response is actually a correct DNS message.
const char* response_end = response + len;
DNSHeader header;
const char* cur = header.read(response, response_end);
if (cur == nullptr) ALOGI("response is flawed");
} else {
ALOGI("not responding");
}
}
}
bool DNSResponder::handleDNSRequest(const char* buffer, ssize_t len,
char* response, size_t* response_len)
const {
ALOGI("request: '%s'", str2hex(buffer, len).c_str());
const char* buffer_end = buffer + len;
DNSHeader header;
const char* cur = header.read(buffer, buffer_end);
// TODO(imaipi): for now, unparsable messages are silently dropped, fix.
if (cur == nullptr) {
ALOGI("failed to parse query");
return false;
}
if (header.qr) {
ALOGI("response received instead of a query");
return false;
}
if (header.opcode != ns_opcode::ns_o_query) {
ALOGI("unsupported request opcode received");
return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
response_len);
}
if (header.questions.empty()) {
ALOGI("no questions present");
return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
response_len);
}
if (!header.answers.empty()) {
ALOGI("already %zu answers present in query", header.answers.size());
return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
response_len);
}
{
std::lock_guard<std::mutex> lock(queries_mutex_);
for (const DNSQuestion& question : header.questions) {
queries_.push_back(make_pair(question.qname.name,
ns_type(question.qtype)));
}
}
// Ignore requests with the preset probability.
auto constexpr bound = std::numeric_limits<unsigned>::max();
if (arc4random_uniform(bound) > bound*response_probability_) {
ALOGI("returning SRVFAIL in accordance with probability distribution");
return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response,
response_len);
}
for (const DNSQuestion& question : header.questions) {
if (question.qclass != ns_class::ns_c_in &&
question.qclass != ns_class::ns_c_any) {
ALOGI("unsupported question class %u", question.qclass);
return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
response_len);
}
if (!addAnswerRecords(question, &header.answers)) {
return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response,
response_len);
}
}
header.qr = true;
char* response_cur = header.write(response, response + *response_len);
if (response_cur == nullptr) {
return false;
}
*response_len = response_cur - response;
return true;
}
bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
std::vector<DNSRecord>* answers) const {
auto it = mappings_.find(QueryKey(question.qname.name, question.qtype));
if (it == mappings_.end()) {
// TODO(imaipi): handle correctly
ALOGI("no mapping found for %s %s, lazily refusing to add an answer",
question.qname.name.c_str(), dnstype2str(question.qtype));
return true;
}
ALOGI("mapping found for %s %s: %s", question.qname.name.c_str(),
dnstype2str(question.qtype), it->second.c_str());
DNSRecord record;
record.name = question.qname;
record.rtype = question.qtype;
record.rclass = ns_class::ns_c_in;
record.ttl = 5; // seconds
if (question.qtype == ns_type::ns_t_a) {
record.rdata.resize(4);
if (inet_pton(AF_INET, it->second.c_str(), record.rdata.data()) != 1) {
ALOGI("inet_pton(AF_INET, %s) failed", it->second.c_str());
return false;
}
} else if (question.qtype == ns_type::ns_t_aaaa) {
record.rdata.resize(16);
if (inet_pton(AF_INET6, it->second.c_str(), record.rdata.data()) != 1) {
ALOGI("inet_pton(AF_INET6, %s) failed", it->second.c_str());
return false;
}
} else {
ALOGI("unhandled qtype %s", dnstype2str(question.qtype));
return false;
}
answers->push_back(std::move(record));
return true;
}
bool DNSResponder::makeErrorResponse(DNSHeader* header, ns_rcode rcode,
char* response, size_t* response_len)
const {
header->answers.clear();
header->authorities.clear();
header->additionals.clear();
header->rcode = rcode;
header->qr = true;
char* response_cur = header->write(response, response + *response_len);
if (response_cur == nullptr) return false;
*response_len = response_cur - response;
return true;
}
} // namespace test

View file

@ -0,0 +1,138 @@
/*
* Copyright (C) 2016 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless requied by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
#ifndef DNS_RESPONDER_H
#define DNS_RESPONDER_H
#include <arpa/nameser.h>
#include <atomic>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>
#include <android-base/thread_annotations.h>
namespace test {
struct DNSHeader;
struct DNSQuestion;
struct DNSRecord;
/*
* Simple DNS responder, which replies to queries with the registered response
* for that type. Class is assumed to be IN. If no response is registered, the
* default error response code is returned.
*/
class DNSResponder {
public:
DNSResponder(std::string listen_address, std::string listen_service,
int poll_timeout_ms, uint16_t error_rcode,
double response_probability);
~DNSResponder();
void addMapping(const char* name, ns_type type, const char* addr);
void removeMapping(const char* name, ns_type type);
void setResponseProbability(double response_probability);
bool running() const;
bool startServer();
bool stopServer();
const std::string& listen_address() const {
return listen_address_;
}
const std::string& listen_service() const {
return listen_service_;
}
std::vector<std::pair<std::string, ns_type>> queries() const;
void clearQueries();
private:
// Key used for accessing mappings.
struct QueryKey {
std::string name;
unsigned type;
QueryKey(std::string n, unsigned t) : name(n), type(t) {}
bool operator == (const QueryKey& o) const {
return name == o.name && type == o.type;
}
bool operator < (const QueryKey& o) const {
if (name < o.name) return true;
if (name > o.name) return false;
return type < o.type;
}
};
struct QueryKeyHash {
size_t operator() (const QueryKey& key) const {
return std::hash<std::string>()(key.name) +
static_cast<size_t>(key.type);
}
};
// DNS request handler.
void requestHandler();
// Parses and generates a response message for incoming DNS requests.
// Returns false on parsing errors.
bool handleDNSRequest(const char* buffer, ssize_t buffer_len,
char* response, size_t* response_len) const;
bool addAnswerRecords(const DNSQuestion& question,
std::vector<DNSRecord>* answers) const;
bool generateErrorResponse(DNSHeader* header, ns_rcode rcode,
char* response, size_t* response_len) const;
bool makeErrorResponse(DNSHeader* header, ns_rcode rcode, char* response,
size_t* response_len) const;
// Address and service to listen on, currently limited to UDP.
const std::string listen_address_;
const std::string listen_service_;
// epoll_wait() timeout in ms.
const int poll_timeout_ms_;
// Error code to return for requests for an unknown name.
const uint16_t error_rcode_;
// Probability that a valid response is being sent instead of being sent
// instead of returning error_rcode_.
std::atomic<double> response_probability_;
// Mappings from (name, type) to registered response and the
// mutex protecting them.
std::unordered_map<QueryKey, std::string, QueryKeyHash> mappings_
GUARDED_BY(mappings_mutex_);
// TODO(imaipi): enable GUARDED_BY(mappings_mutex_);
std::mutex mappings_mutex_;
// Query names received so far and the corresponding mutex.
mutable std::vector<std::pair<std::string, ns_type>> queries_
GUARDED_BY(queries_mutex_);
mutable std::mutex queries_mutex_;
// Socket on which the server is listening.
int socket_;
// File descriptor for epoll.
int epoll_fd_;
// Signal for request handler termination.
std::atomic<bool> terminate_ GUARDED_BY(update_mutex_);
// Thread for handling incoming threads.
std::thread handler_thread_ GUARDED_BY(update_mutex_);
std::mutex update_mutex_;
};
} // namespace test
#endif // DNS_RESPONDER_H

View file

@ -0,0 +1,188 @@
/*
* Copyright (C) 2016 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dns_responder_client.h"
#include <android-base/stringprintf.h>
// TODO: make this dynamic and stop depending on implementation details.
#define TEST_OEM_NETWORK "oem29"
#define TEST_NETID 30
// TODO: move this somewhere shared.
static const char* ANDROID_DNS_MODE = "ANDROID_DNS_MODE";
// The only response code used in this class. See
// frameworks/base/services/java/com/android/server/NetworkManagementService.java
// for others.
static constexpr int ResponseCodeOK = 200;
using android::base::StringPrintf;
static int netdCommand(const char* sockname, const char* command) {
int sock = socket_local_client(sockname,
ANDROID_SOCKET_NAMESPACE_RESERVED,
SOCK_STREAM);
if (sock < 0) {
perror("Error connecting");
return -1;
}
// FrameworkListener expects the whole command in one read.
char buffer[256];
int nwritten = snprintf(buffer, sizeof(buffer), "0 %s", command);
if (write(sock, buffer, nwritten + 1) < 0) {
perror("Error sending netd command");
close(sock);
return -1;
}
int nread = read(sock, buffer, sizeof(buffer));
if (nread < 0) {
perror("Error reading response");
close(sock);
return -1;
}
close(sock);
return atoi(buffer);
}
static bool expectNetdResult(int expected, const char* sockname, const char* format, ...) {
char command[256];
va_list args;
va_start(args, format);
vsnprintf(command, sizeof(command), format, args);
va_end(args);
int result = netdCommand(sockname, command);
if (expected != result) {
return false;
}
return (200 <= expected && expected < 300);
}
void DnsResponderClient::SetupMappings(unsigned num_hosts, const std::vector<std::string>& domains,
std::vector<Mapping>* mappings) {
mappings->resize(num_hosts * domains.size());
auto mappings_it = mappings->begin();
for (unsigned i = 0 ; i < num_hosts ; ++i) {
for (const auto& domain : domains) {
mappings_it->host = StringPrintf("host%u", i);
mappings_it->entry = StringPrintf("%s.%s.", mappings_it->host.c_str(),
domain.c_str());
mappings_it->ip4 = StringPrintf("192.0.2.%u", i%253 + 1);
mappings_it->ip6 = StringPrintf("2001:db8::%x", i%65534 + 1);
++mappings_it;
}
}
}
bool DnsResponderClient::SetResolversForNetwork(const std::vector<std::string>& servers,
const std::vector<std::string>& domains, const std::vector<int>& params) {
auto rv = mNetdSrv->setResolverConfiguration(TEST_NETID, servers, domains, params);
return rv.isOk();
}
bool DnsResponderClient::SetResolversForNetwork(const std::vector<std::string>& searchDomains,
const std::vector<std::string>& servers, const std::string& params) {
std::string cmd = StringPrintf("resolver setnetdns %d \"", mOemNetId);
if (!searchDomains.empty()) {
cmd += searchDomains[0].c_str();
for (size_t i = 1 ; i < searchDomains.size() ; ++i) {
cmd += " ";
cmd += searchDomains[i];
}
}
cmd += "\"";
for (const auto& str : servers) {
cmd += " ";
cmd += str;
}
if (!params.empty()) {
cmd += " --params \"";
cmd += params;
cmd += "\"";
}
int rv = netdCommand("netd", cmd.c_str());
if (rv != ResponseCodeOK) {
return false;
}
return true;
}
void DnsResponderClient::SetupDNSServers(unsigned num_servers, const std::vector<Mapping>& mappings,
std::vector<std::unique_ptr<test::DNSResponder>>* dns,
std::vector<std::string>* servers) {
const char* listen_srv = "53";
dns->resize(num_servers);
servers->resize(num_servers);
for (unsigned i = 0 ; i < num_servers ; ++i) {
auto& server = (*servers)[i];
auto& d = (*dns)[i];
server = StringPrintf("127.0.0.%u", i + 100);
d = std::make_unique<test::DNSResponder>(server, listen_srv, 250,
ns_rcode::ns_r_servfail, 1.0);
for (const auto& mapping : mappings) {
d->addMapping(mapping.entry.c_str(), ns_type::ns_t_a, mapping.ip4.c_str());
d->addMapping(mapping.entry.c_str(), ns_type::ns_t_aaaa, mapping.ip6.c_str());
}
d->startServer();
}
}
void DnsResponderClient::ShutdownDNSServers(std::vector<std::unique_ptr<test::DNSResponder>>* dns) {
for (const auto& d : *dns) {
d->stopServer();
}
dns->clear();
}
int DnsResponderClient::SetupOemNetwork() {
netdCommand("netd", "network destroy " TEST_OEM_NETWORK);
if (!expectNetdResult(ResponseCodeOK, "netd",
"network create %s", TEST_OEM_NETWORK)) {
return -1;
}
int oemNetId = TEST_NETID;
setNetworkForProcess(oemNetId);
if ((unsigned) oemNetId != getNetworkForProcess()) {
return -1;
}
return oemNetId;
}
void DnsResponderClient::TearDownOemNetwork(int oemNetId) {
if (oemNetId != -1) {
expectNetdResult(ResponseCodeOK, "netd",
"network destroy %s", TEST_OEM_NETWORK);
}
}
void DnsResponderClient::SetUp() {
// Ensure resolutions go via proxy.
setenv(ANDROID_DNS_MODE, "", 1);
mOemNetId = SetupOemNetwork();
// binder setup
auto binder = android::defaultServiceManager()->getService(android::String16("netd"));
mNetdSrv = android::interface_cast<android::net::INetd>(binder);
}
void DnsResponderClient::TearDown() {
TearDownOemNetwork(mOemNetId);
}

View file

@ -0,0 +1,54 @@
#ifndef DNS_RESPONDER_CLIENT_H
#define DNS_RESPONDER_CLIENT_H
#include <cutils/sockets.h>
#include <private/android_filesystem_config.h>
#include <utils/StrongPointer.h>
#include "android/net/INetd.h"
#include "binder/IServiceManager.h"
#include "NetdClient.h"
#include "dns_responder.h"
#include "resolv_params.h"
class DnsResponderClient {
public:
struct Mapping {
std::string host;
std::string entry;
std::string ip4;
std::string ip6;
};
virtual ~DnsResponderClient() = default;
void SetupMappings(unsigned num_hosts, const std::vector<std::string>& domains,
std::vector<Mapping>* mappings);
bool SetResolversForNetwork(const std::vector<std::string>& servers,
const std::vector<std::string>& domains, const std::vector<int>& params);
bool SetResolversForNetwork(const std::vector<std::string>& searchDomains,
const std::vector<std::string>& servers, const std::string& params);
static void SetupDNSServers(unsigned num_servers, const std::vector<Mapping>& mappings,
std::vector<std::unique_ptr<test::DNSResponder>>* dns,
std::vector<std::string>* servers);
static void ShutdownDNSServers(std::vector<std::unique_ptr<test::DNSResponder>>* dns);
static int SetupOemNetwork();
static void TearDownOemNetwork(int oemNetId);
virtual void SetUp();
virtual void TearDown();
public:
android::sp<android::net::INetd> mNetdSrv = nullptr;
int mOemNetId = -1;
};
#endif // DNS_RESPONDER_CLIENT_H

View file

@ -0,0 +1,374 @@
/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dns_tls_frontend.h"
#include <netdb.h>
#include <stdio.h>
#include <unistd.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <arpa/inet.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/ssl.h>
#define LOG_TAG "DnsTlsFrontend"
#include <log/log.h>
#include <unistd.h>
namespace {
const int SHA256_SIZE = 32;
// Copied from DnsTlsTransport.
bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), NULL);
unsigned char spki[spki_len];
unsigned char* temp = spki;
if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
ALOGE("SPKI length mismatch");
return false;
}
out->resize(SHA256_SIZE);
unsigned int digest_len = 0;
int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), NULL);
if (ret != 1) {
ALOGE("Server cert digest extraction failed");
return false;
}
if (digest_len != out->size()) {
ALOGE("Wrong digest length: %d", digest_len);
return false;
}
return true;
}
std::string errno2str() {
char error_msg[512] = { 0 };
if (strerror_r(errno, error_msg, sizeof(error_msg)))
return std::string();
return std::string(error_msg);
}
#define APLOGI(fmt, ...) ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str())
std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
char host_str[NI_MAXHOST] = { 0 };
int rv = getnameinfo(sa, sa_len, host_str, sizeof(host_str), nullptr, 0,
NI_NUMERICHOST);
if (rv == 0) return std::string(host_str);
return std::string();
}
bssl::UniquePtr<EVP_PKEY> make_private_key() {
bssl::UniquePtr<BIGNUM> e(BN_new());
if (!e) {
ALOGE("BN_new failed");
return nullptr;
}
if (!BN_set_word(e.get(), RSA_F4)) {
ALOGE("BN_set_word failed");
return nullptr;
}
bssl::UniquePtr<RSA> rsa(RSA_new());
if (!rsa) {
ALOGE("RSA_new failed");
return nullptr;
}
if (!RSA_generate_key_ex(rsa.get(), 2048, e.get(), NULL)) {
ALOGE("RSA_generate_key_ex failed");
return nullptr;
}
bssl::UniquePtr<EVP_PKEY> privkey(EVP_PKEY_new());
if (!privkey) {
ALOGE("EVP_PKEY_new failed");
return nullptr;
}
if(!EVP_PKEY_assign_RSA(privkey.get(), rsa.get())) {
ALOGE("EVP_PKEY_assign_RSA failed");
return nullptr;
}
// |rsa| is now owned by |privkey|, so no need to free it.
rsa.release();
return privkey;
}
bssl::UniquePtr<X509> make_cert(EVP_PKEY* privkey) {
bssl::UniquePtr<X509> cert(X509_new());
if (!cert) {
ALOGE("X509_new failed");
return nullptr;
}
ASN1_INTEGER_set(X509_get_serialNumber(cert.get()), 1);
// Set one hour expiration.
X509_gmtime_adj(X509_get_notBefore(cert.get()), 0);
X509_gmtime_adj(X509_get_notAfter(cert.get()), 60 * 60);
X509_set_pubkey(cert.get(), privkey);
if (!X509_sign(cert.get(), privkey, EVP_sha256())) {
ALOGE("X509_sign failed");
return nullptr;
}
return cert;
}
}
namespace test {
bool DnsTlsFrontend::startServer() {
SSL_load_error_strings();
OpenSSL_add_ssl_algorithms();
ctx_.reset(SSL_CTX_new(TLS_server_method()));
if (!ctx_) {
ALOGE("SSL context creation failed");
return false;
}
SSL_CTX_set_ecdh_auto(ctx_.get(), 1);
bssl::UniquePtr<EVP_PKEY> key(make_private_key());
bssl::UniquePtr<X509> cert(make_cert(key.get()));
if (SSL_CTX_use_certificate(ctx_.get(), cert.get()) <= 0) {
ALOGE("SSL_CTX_use_certificate failed");
return false;
}
if (!getSPKIDigest(cert.get(), &fingerprint_)) {
ALOGE("getSPKIDigest failed");
return false;
}
if (SSL_CTX_use_PrivateKey(ctx_.get(), key.get()) <= 0 ) {
ALOGE("SSL_CTX_use_PrivateKey failed");
return false;
}
// Set up TCP server socket for clients.
addrinfo frontend_ai_hints{
.ai_family = AF_UNSPEC,
.ai_socktype = SOCK_STREAM,
.ai_flags = AI_PASSIVE
};
addrinfo* frontend_ai_res;
int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(),
&frontend_ai_hints, &frontend_ai_res);
if (rv) {
ALOGE("frontend getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
listen_service_.c_str(), gai_strerror(rv));
return false;
}
int s = -1;
for (const addrinfo* ai = frontend_ai_res ; ai ; ai = ai->ai_next) {
s = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
if (s < 0) continue;
const int one = 1;
setsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
if (bind(s, ai->ai_addr, ai->ai_addrlen)) {
APLOGI("bind failed for socket %d", s);
close(s);
s = -1;
continue;
}
std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
ALOGI("bound to TCP %s:%s", host_str.c_str(), listen_service_.c_str());
break;
}
freeaddrinfo(frontend_ai_res);
if (s < 0) {
ALOGE("server socket creation failed");
return false;
}
if (listen(s, 1) < 0) {
ALOGE("listen failed");
return false;
}
socket_ = s;
// Set up UDP client socket to backend.
addrinfo backend_ai_hints{
.ai_family = AF_UNSPEC,
.ai_socktype = SOCK_DGRAM
};
addrinfo* backend_ai_res;
rv = getaddrinfo(backend_address_.c_str(), backend_service_.c_str(),
&backend_ai_hints, &backend_ai_res);
if (rv) {
ALOGE("backend getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
listen_service_.c_str(), gai_strerror(rv));
return false;
}
backend_socket_ = socket(backend_ai_res->ai_family, backend_ai_res->ai_socktype,
backend_ai_res->ai_protocol);
if (backend_socket_ < 0) {
ALOGE("backend socket creation failed");
return false;
}
connect(backend_socket_, backend_ai_res->ai_addr, backend_ai_res->ai_addrlen);
freeaddrinfo(backend_ai_res);
{
std::lock_guard<std::mutex> lock(update_mutex_);
handler_thread_ = std::thread(&DnsTlsFrontend::requestHandler, this);
}
ALOGI("server started successfully");
return true;
}
void DnsTlsFrontend::requestHandler() {
ALOGD("Request handler started");
struct pollfd fds[1] = {{ .fd = socket_, .events = POLLIN }};
while (!terminate_) {
int poll_code = poll(fds, 1, 10 /* ms */);
if (poll_code == 0) {
// Timeout. Poll again.
continue;
} else if (poll_code < 0) {
ALOGW("Poll failed with error %d", poll_code);
// Error.
break;
}
sockaddr_storage addr;
socklen_t len = sizeof(addr);
ALOGD("Trying to accept a client");
int client = accept(socket_, reinterpret_cast<sockaddr*>(&addr), &len);
ALOGD("Got client socket %d", client);
if (client < 0) {
// Stop
break;
}
bssl::UniquePtr<SSL> ssl(SSL_new(ctx_.get()));
SSL_set_fd(ssl.get(), client);
ALOGD("Doing SSL handshake");
bool success = false;
if (SSL_accept(ssl.get()) <= 0) {
ALOGI("SSL negotiation failure");
} else {
ALOGD("SSL handshake complete");
success = handleOneRequest(ssl.get());
}
close(client);
if (success) {
// Increment queries_ as late as possible, because it represents
// a query that is fully processed, and the response returned to the
// client, including cleanup actions.
++queries_;
}
}
ALOGD("Request handler terminating");
}
bool DnsTlsFrontend::handleOneRequest(SSL* ssl) {
uint8_t queryHeader[2];
if (SSL_read(ssl, &queryHeader, 2) != 2) {
ALOGI("Not enough header bytes");
return false;
}
const uint16_t qlen = (queryHeader[0] << 8) | queryHeader[1];
uint8_t query[qlen];
if (SSL_read(ssl, &query, qlen) != qlen) {
ALOGI("Not enough query bytes");
return false;
}
int sent = send(backend_socket_, query, qlen, 0);
if (sent != qlen) {
ALOGI("Failed to send query");
return false;
}
const int max_size = 4096;
uint8_t recv_buffer[max_size];
int rlen = recv(backend_socket_, recv_buffer, max_size, 0);
if (rlen <= 0) {
ALOGI("Failed to receive response");
return false;
}
uint8_t responseHeader[2];
responseHeader[0] = rlen >> 8;
responseHeader[1] = rlen;
if (SSL_write(ssl, responseHeader, 2) != 2) {
ALOGI("Failed to write response header");
return false;
}
if (SSL_write(ssl, recv_buffer, rlen) != rlen) {
ALOGI("Failed to write response body");
return false;
}
return true;
}
bool DnsTlsFrontend::stopServer() {
std::lock_guard<std::mutex> lock(update_mutex_);
if (!running()) {
ALOGI("server not running");
return false;
}
if (terminate_) {
ALOGI("LOGIC ERROR");
return false;
}
ALOGI("stopping frontend");
terminate_ = true;
handler_thread_.join();
close(socket_);
close(backend_socket_);
terminate_ = false;
socket_ = -1;
backend_socket_ = -1;
ctx_.reset();
fingerprint_.clear();
ALOGI("frontend stopped successfully");
return true;
}
bool DnsTlsFrontend::waitForQueries(int number, int timeoutMs) const {
constexpr int intervalMs = 20;
int limit = timeoutMs / intervalMs;
for (int count = 0; count <= limit; ++count) {
bool done = queries_ >= number;
// Always sleep at least one more interval after we are done, to wait for
// any immediate post-query actions that the client may take (such as
// marking this server as reachable during validation).
usleep(intervalMs * 1000);
if (done) {
// For ensuring that calls have sufficient headroom for slow machines
ALOGD("Query arrived in %d/%d of allotted time", count, limit);
return true;
}
}
return false;
}
} // namespace test

View file

@ -0,0 +1,84 @@
/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless requied by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
#ifndef DNS_TLS_FRONTEND_H
#define DNS_TLS_FRONTEND_H
#include <arpa/nameser.h>
#include <atomic>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>
#include <android-base/thread_annotations.h>
#include <openssl/ssl.h>
namespace test {
/*
* Simple DNS over TLS reverse proxy that forwards to a UDP backend.
* Only handles a single request at a time.
*/
class DnsTlsFrontend {
public:
DnsTlsFrontend(const std::string& listen_address, const std::string& listen_service,
const std::string& backend_address, const std::string& backend_service) :
listen_address_(listen_address), listen_service_(listen_service),
backend_address_(backend_address), backend_service_(backend_service),
queries_(0), terminate_(false) { }
~DnsTlsFrontend() {
stopServer();
}
const std::string& listen_address() const {
return listen_address_;
}
const std::string& listen_service() const {
return listen_service_;
}
bool running() const {
return socket_ != -1;
}
bool startServer();
bool stopServer();
int queries() const { return queries_; }
bool waitForQueries(int number, int timeoutMs) const;
const std::vector<uint8_t>& fingerprint() const { return fingerprint_; }
private:
void requestHandler();
bool handleOneRequest(SSL* ssl);
std::string listen_address_;
std::string listen_service_;
std::string backend_address_;
std::string backend_service_;
bssl::UniquePtr<SSL_CTX> ctx_;
int socket_ = -1;
int backend_socket_ = -1;
std::atomic<int> queries_;
std::atomic<bool> terminate_ GUARDED_BY(update_mutex_);
std::thread handler_thread_ GUARDED_BY(update_mutex_);
std::mutex update_mutex_;
std::vector<uint8_t> fingerprint_;
};
} // namespace test
#endif // DNS_TLS_FRONTEND_H