#pylint: disable-msg=C0111
"""\
Utility functions for rpc_interface.py. We keep them in a separate file so that
only RPC interface functions go into that file.
"""

__author__ = 'showard@google.com (Steve Howard)'

import datetime
from functools import wraps
import inspect
import os
import sys
import django.db.utils
import django.http

from autotest_lib.frontend import thread_local
from autotest_lib.frontend.afe import models, model_logic
from autotest_lib.client.common_lib import control_data, error
from autotest_lib.client.common_lib import global_config, priorities
from autotest_lib.client.common_lib import time_utils
from autotest_lib.client.common_lib.cros.graphite import autotest_stats
from autotest_lib.server import utils as server_utils
from autotest_lib.server.cros import provision
from autotest_lib.server.cros.dynamic_suite import frontend_wrappers

NULL_DATETIME = datetime.datetime.max
NULL_DATE = datetime.date.max
DUPLICATE_KEY_MSG = 'Duplicate entry'


def prepare_for_serialization(objects):
    """
    Prepare Python objects to be returned via RPC.
    @param objects: objects to be prepared.
    """
    if (isinstance(objects, list) and len(objects) and
            isinstance(objects[0], dict) and 'id' in objects[0]):
        objects = gather_unique_dicts(objects)
    return _prepare_data(objects)


def prepare_rows_as_nested_dicts(query, nested_dict_column_names):
    """
    Prepare a Django query to be returned via RPC as a sequence of nested
    dictionaries.

    @param query - A Django model query object with a select_related() method.
    @param nested_dict_column_names - A list of column/attribute names for the
            rows returned by query to expand into nested dictionaries using
            their get_object_dict() method when not None.

    @returns A list suitable to be returned in an RPC.
    """
    all_dicts = []
    for row in query.select_related():
        row_dict = row.get_object_dict()
        for column in nested_dict_column_names:
            if row_dict[column] is not None:
                row_dict[column] = getattr(row, column).get_object_dict()
        all_dicts.append(row_dict)
    return prepare_for_serialization(all_dicts)


def _prepare_data(data):
    """
    Recursively process data structures, performing necessary type
    conversions to values in data to allow for RPC serialization:
    -convert datetimes to strings
    -convert tuples and sets to lists
    """
    if isinstance(data, dict):
        new_data = {}
        for key, value in data.iteritems():
            new_data[key] = _prepare_data(value)
        return new_data
    elif (isinstance(data, list) or isinstance(data, tuple) or
          isinstance(data, set)):
        return [_prepare_data(item) for item in data]
    elif isinstance(data, datetime.date):
        if data is NULL_DATETIME or data is NULL_DATE:
            return None
        return str(data)
    else:
        return data


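# Illustrative sketch (not part of the original module): how _prepare_data
# transforms a nested structure. The literal values below are made up.
#
#   _prepare_data({'created_on': datetime.datetime(2015, 1, 2, 3, 4, 5),
#                  'labels': set(['board:lumpy'])})
#   # -> {'created_on': '2015-01-02 03:04:05', 'labels': ['board:lumpy']}
#
#   _prepare_data({'deleted_on': NULL_DATETIME})
#   # -> {'deleted_on': None}
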
def fetchall_as_list_of_dicts(cursor):
    """
    Converts each row in the cursor to a dictionary so that values can be read
    by using the column name.
    @param cursor: The database cursor to read from.
    @returns: A list of each row in the cursor as a dictionary.
    """
    desc = cursor.description
    return [ dict(zip([col[0] for col in desc], row))
             for row in cursor.fetchall() ]


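# Illustrative sketch (not part of the original module): typical usage with a
# raw Django cursor; the query and the returned values are made up.
#
#   from django.db import connection
#   cursor = connection.cursor()
#   cursor.execute('SELECT id, hostname FROM afe_hosts')
#   rows = fetchall_as_list_of_dicts(cursor)
#   # rows -> [{'id': 1, 'hostname': 'host1'}, {'id': 2, 'hostname': 'host2'}]
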
def raw_http_response(response_data, content_type=None):
    response = django.http.HttpResponse(response_data, mimetype=content_type)
    response['Content-length'] = str(len(response.content))
    return response


def gather_unique_dicts(dict_iterable):
    """\
    Pick out unique objects (by ID) from an iterable of object dicts.
    """
    id_set = set()
    result = []
    for obj in dict_iterable:
        if obj['id'] not in id_set:
            id_set.add(obj['id'])
            result.append(obj)
    return result


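# Illustrative sketch (not part of the original module): duplicates are dropped
# and the first occurrence of each id wins, preserving order.
#
#   gather_unique_dicts([{'id': 3}, {'id': 3}, {'id': 7}])
#   # -> [{'id': 3}, {'id': 7}]
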
def extra_job_status_filters(not_yet_run=False, running=False, finished=False):
    """\
    Generate a SQL WHERE clause for job status filtering, and return it in
    a dict of keyword args to pass to query.extra().
    * not_yet_run: all HQEs are Queued
    * finished: all HQEs are complete
    * running: everything else
    """
    if not (not_yet_run or running or finished):
        return {}
    not_queued = ('(SELECT job_id FROM afe_host_queue_entries '
                  'WHERE status != "%s")'
                  % models.HostQueueEntry.Status.QUEUED)
    not_finished = ('(SELECT job_id FROM afe_host_queue_entries '
                    'WHERE not complete)')

    where = []
    if not_yet_run:
        where.append('id NOT IN ' + not_queued)
    if running:
        where.append('(id IN %s) AND (id IN %s)' % (not_queued, not_finished))
    if finished:
        where.append('id NOT IN ' + not_finished)
    return {'where': [' OR '.join(['(%s)' % x for x in where])]}


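# Illustrative sketch (not part of the original module): the returned dict is
# meant to be splatted into QuerySet.extra(); the query below is hypothetical.
#
#   extra = extra_job_status_filters(finished=True)
#   # extra -> {'where': ['(id NOT IN (SELECT job_id FROM '
#   #                     'afe_host_queue_entries WHERE not complete))']}
#   jobs = models.Job.objects.extra(**extra)
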
def extra_job_type_filters(extra_args, suite=False,
                           sub=False, standalone=False):
    """\
    Generate a SQL WHERE clause for job type filtering, and return it in
    a dict of keyword args to pass to query.extra().

    @param extra_args: a dict of existing extra_args.

    No more than one of the parameters should be passed as True:
    * suite: job which is parent of other jobs
    * sub: job with a parent job
    * standalone: job with no child or parent jobs
    """
    assert not ((suite and sub) or
                (suite and standalone) or
                (sub and standalone)), ('Cannot specify more than one '
                                        'filter to this function')

    where = extra_args.get('where', [])
    parent_job_id = ('DISTINCT parent_job_id')
    child_job_id = ('id')
    filter_common = ('(SELECT %s FROM afe_jobs '
                     'WHERE parent_job_id IS NOT NULL)')

    if suite:
        where.append('id IN ' + filter_common % parent_job_id)
    elif sub:
        where.append('id IN ' + filter_common % child_job_id)
    elif standalone:
        where.append('NOT EXISTS (SELECT 1 from afe_jobs AS sub_query '
                     'WHERE parent_job_id IS NOT NULL'
                     ' AND (sub_query.parent_job_id=afe_jobs.id'
                     ' OR sub_query.id=afe_jobs.id))')
    else:
        return extra_args

    extra_args['where'] = where
    return extra_args


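# Illustrative sketch (not part of the original module): selecting only suite
# (parent) jobs. The starting extra_args dict here is made up.
#
#   extra = extra_job_type_filters({}, suite=True)
#   # extra['where'] -> ['id IN (SELECT DISTINCT parent_job_id FROM afe_jobs '
#   #                    'WHERE parent_job_id IS NOT NULL)']
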
def extra_host_filters(multiple_labels=()):
    """\
    Generate SQL WHERE clauses for matching hosts in an intersection of
    labels.
    """
    extra_args = {}
    where_str = ('afe_hosts.id in (select host_id from afe_hosts_labels '
                 'where label_id=%s)')
    extra_args['where'] = [where_str] * len(multiple_labels)
    extra_args['params'] = [models.Label.smart_get(label).id
                            for label in multiple_labels]
    return extra_args


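# Illustrative sketch (not part of the original module): with two labels the
# result repeats the subquery once per label, and 'params' carries the label
# ids that fill in the %s placeholders. The label names and ids are made up.
#
#   extra_host_filters(['board:lumpy', 'pool:bvt'])
#   # -> {'where': [where_str, where_str], 'params': [12, 34]}
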
def get_host_query(multiple_labels, exclude_only_if_needed_labels,
                   exclude_atomic_group_hosts, valid_only, filter_data):
    if valid_only:
        query = models.Host.valid_objects.all()
    else:
        query = models.Host.objects.all()

    if exclude_only_if_needed_labels:
        only_if_needed_labels = models.Label.valid_objects.filter(
            only_if_needed=True)
        if only_if_needed_labels.count() > 0:
            only_if_needed_ids = ','.join(
                str(label['id'])
                for label in only_if_needed_labels.values('id'))
            query = models.Host.objects.add_join(
                query, 'afe_hosts_labels', join_key='host_id',
                join_condition=('afe_hosts_labels_exclude_OIN.label_id IN (%s)'
                                % only_if_needed_ids),
                suffix='_exclude_OIN', exclude=True)

    if exclude_atomic_group_hosts:
        atomic_group_labels = models.Label.valid_objects.filter(
            atomic_group__isnull=False)
        if atomic_group_labels.count() > 0:
            atomic_group_label_ids = ','.join(
                str(atomic_group['id'])
                for atomic_group in atomic_group_labels.values('id'))
            query = models.Host.objects.add_join(
                query, 'afe_hosts_labels', join_key='host_id',
                join_condition=(
                    'afe_hosts_labels_exclude_AG.label_id IN (%s)'
                    % atomic_group_label_ids),
                suffix='_exclude_AG', exclude=True)
    try:
        assert 'extra_args' not in filter_data
        filter_data['extra_args'] = extra_host_filters(multiple_labels)
        return models.Host.query_objects(filter_data, initial_query=query)
    except models.Label.DoesNotExist as e:
        return models.Host.objects.none()


class InconsistencyException(Exception):
    'Raised when a list of objects does not have a consistent value'


def get_consistent_value(objects, field):
    if not objects:
        # well a list of nothing is consistent
        return None

    value = getattr(objects[0], field)
    for obj in objects:
        this_value = getattr(obj, field)
        if this_value != value:
            raise InconsistencyException(objects[0], obj)
    return value


def afe_test_dict_to_test_object(test_dict):
    if not isinstance(test_dict, dict):
        return test_dict

    numerized_dict = {}
    for key, value in test_dict.iteritems():
        try:
            numerized_dict[key] = int(value)
        except (ValueError, TypeError):
            numerized_dict[key] = value

    return type('TestObject', (object,), numerized_dict)


def prepare_generate_control_file(tests, kernel, label, profilers,
                                  db_tests=True):
    if db_tests:
        test_objects = [models.Test.smart_get(test) for test in tests]
    else:
        test_objects = [afe_test_dict_to_test_object(test) for test in tests]

    profiler_objects = [models.Profiler.smart_get(profiler)
                        for profiler in profilers]
    # ensure tests are all the same type
    try:
        test_type = get_consistent_value(test_objects, 'test_type')
    except InconsistencyException as exc:
        test1, test2 = exc.args
        raise model_logic.ValidationError(
            {'tests' : 'You cannot run both test_suites and server-side '
                       'tests together (tests %s and %s differ).' % (
                           test1.name, test2.name)})

    is_server = (test_type == control_data.CONTROL_TYPE.SERVER)
    if test_objects:
        synch_count = max(test.sync_count for test in test_objects)
    else:
        synch_count = 1
    if label:
        label = models.Label.smart_get(label)

    if db_tests:
        dependencies = set(label.name for label
                           in models.Label.objects.filter(test__in=test_objects))
    else:
        dependencies = reduce(
            set.union, [set(test.dependencies) for test in test_objects])

    cf_info = dict(is_server=is_server, synch_count=synch_count,
                   dependencies=list(dependencies))
    return cf_info, test_objects, profiler_objects, label


def check_job_dependencies(host_objects, job_dependencies):
    """
    Check that a set of machines satisfies a job's dependencies.
    host_objects: list of models.Host objects
    job_dependencies: list of names of labels
    """
    # check that hosts satisfy dependencies
    host_ids = [host.id for host in host_objects]
    hosts_in_job = models.Host.objects.filter(id__in=host_ids)
    ok_hosts = hosts_in_job
    for index, dependency in enumerate(job_dependencies):
        if not provision.is_for_special_action(dependency):
            ok_hosts = ok_hosts.filter(labels__name=dependency)
    failing_hosts = (set(host.hostname for host in host_objects) -
                     set(host.hostname for host in ok_hosts))
    if failing_hosts:
        raise model_logic.ValidationError(
            {'hosts' : 'Host(s) failed to meet job dependencies (' +
                       (', '.join(job_dependencies)) + '): ' +
                       (', '.join(failing_hosts))})


def check_job_metahost_dependencies(metahost_objects, job_dependencies):
    """
    Check that at least one machine within the metahost spec satisfies the job's
    dependencies.

    @param metahost_objects A list of label objects representing the metahosts.
    @param job_dependencies A list of strings of the required label names.
    @raises NoEligibleHostException If a metahost cannot run the job.
    """
    for metahost in metahost_objects:
        hosts = models.Host.objects.filter(labels=metahost)
        for label_name in job_dependencies:
            if not provision.is_for_special_action(label_name):
                hosts = hosts.filter(labels__name=label_name)
        if not any(hosts):
            raise error.NoEligibleHostException("No hosts within %s satisfy %s."
                    % (metahost.name, ', '.join(job_dependencies)))


def _execution_key_for(host_queue_entry):
    return (host_queue_entry.job.id, host_queue_entry.execution_subdir)


def check_abort_synchronous_jobs(host_queue_entries):
    # ensure user isn't aborting part of a synchronous autoserv execution
    count_per_execution = {}
    for queue_entry in host_queue_entries:
        key = _execution_key_for(queue_entry)
        count_per_execution.setdefault(key, 0)
        count_per_execution[key] += 1

    for queue_entry in host_queue_entries:
        if not queue_entry.execution_subdir:
            continue
        execution_count = count_per_execution[_execution_key_for(queue_entry)]
        if execution_count < queue_entry.job.synch_count:
            raise model_logic.ValidationError(
                {'' : 'You cannot abort part of a synchronous job execution '
                      '(%d/%s), %d included, %d expected'
                      % (queue_entry.job.id, queue_entry.execution_subdir,
                         execution_count, queue_entry.job.synch_count)})


def check_atomic_group_create_job(synch_count, host_objects, metahost_objects,
                                  dependencies, atomic_group):
    """
    Attempt to reject create_job requests with an atomic group that
    will be impossible to schedule. The checks are not perfect but
    should catch the most obvious issues.

    @param synch_count - The job's minimum synch count.
    @param host_objects - A list of models.Host instances.
    @param metahost_objects - A list of models.Label instances.
    @param dependencies - A list of job dependency label names.
    @param atomic_group - The models.AtomicGroup instance supplied for
            this job.

    @raises model_logic.ValidationError - When an issue is found.
    """
    # If specific host objects were supplied with an atomic group, verify
    # that there are enough to satisfy the synch_count.
    minimum_required = synch_count or 1
    if (host_objects and not metahost_objects and
            len(host_objects) < minimum_required):
        raise model_logic.ValidationError(
            {'hosts':
             'only %d hosts provided for job with synch_count = %d' %
             (len(host_objects), synch_count)})

    # Check that the atomic group has a hope of running this job
    # given any supplied metahosts and dependencies that may limit it.

    # Get a set of hostnames in the atomic group.
    possible_hosts = set()
    for label in atomic_group.label_set.all():
        possible_hosts.update(h.hostname for h in label.host_set.all())

    # Filter out hosts that don't match all of the job dependency labels.
    for label in models.Label.objects.filter(name__in=dependencies):
        hosts_in_label = (h.hostname for h in label.host_set.all())
        possible_hosts.intersection_update(hosts_in_label)

    if not host_objects and not metahost_objects:
        # No hosts or metahosts are required to queue an atomic group Job.
        # However, if they are given, we respect them below.
        host_set = possible_hosts
    else:
        host_set = set(host.hostname for host in host_objects)
        unusable_host_set = host_set.difference(possible_hosts)
        if unusable_host_set:
            raise model_logic.ValidationError(
                {'hosts': 'Hosts "%s" are not in Atomic Group "%s"' %
                 (', '.join(sorted(unusable_host_set)), atomic_group.name)})

    # Lookup hosts provided by each meta host and merge them into the
    # host_set for final counting.
    for meta_host in metahost_objects:
        meta_possible = possible_hosts.copy()
        hosts_in_meta_host = (h.hostname for h in meta_host.host_set.all())
        meta_possible.intersection_update(hosts_in_meta_host)

        # Count all hosts that this meta_host will provide.
        host_set.update(meta_possible)

    if len(host_set) < minimum_required:
        raise model_logic.ValidationError(
            {'atomic_group_name':
             'Insufficient hosts in Atomic Group "%s" with the'
             ' supplied dependencies and meta_hosts.' %
             (atomic_group.name,)})


def check_modify_host(update_data):
    """
    Sanity check modify_host* requests.

    @param update_data: A dictionary with the changes to make to a host
            or hosts.
    """
    # Only the scheduler (monitor_db) is allowed to modify Host status.
    # Otherwise race conditions happen as a host's state is changed out from
    # beneath tasks being run on a host.
    if 'status' in update_data:
        raise model_logic.ValidationError({
            'status': 'Host status can not be modified by the frontend.'})


def check_modify_host_locking(host, update_data):
    """
    Checks when locking/unlocking has been requested if the host is already
    locked/unlocked.

    @param host: models.Host object to be modified
    @param update_data: A dictionary with the changes to make to the host.
    """
    locked = update_data.get('locked', None)
    lock_reason = update_data.get('lock_reason', None)
    if locked is not None:
        if locked and host.locked:
            raise model_logic.ValidationError({
                'locked': 'Host %s already locked by %s on %s.' %
                          (host.hostname, host.locked_by, host.lock_time)})
        if not locked and not host.locked:
            raise model_logic.ValidationError({
                'locked': 'Host %s already unlocked.' % host.hostname})
        if locked and not lock_reason and not host.locked:
            raise model_logic.ValidationError({
                'locked': 'Please provide a reason for locking Host %s' %
                          host.hostname})


def get_motd():
    dirname = os.path.dirname(__file__)
    filename = os.path.join(dirname, "..", "..", "motd.txt")
    text = ''
    try:
        fp = open(filename, "r")
        try:
            text = fp.read()
        finally:
            fp.close()
    except:
        pass

    return text


def _get_metahost_counts(metahost_objects):
    metahost_counts = {}
    for metahost in metahost_objects:
        metahost_counts.setdefault(metahost, 0)
        metahost_counts[metahost] += 1
    return metahost_counts


def get_job_info(job, preserve_metahosts=False, queue_entry_filter_data=None):
    hosts = []
    one_time_hosts = []
    meta_hosts = []
    atomic_group = None
    hostless = False

    queue_entries = job.hostqueueentry_set.all()
    if queue_entry_filter_data:
        queue_entries = models.HostQueueEntry.query_objects(
            queue_entry_filter_data, initial_query=queue_entries)

    for queue_entry in queue_entries:
        if (queue_entry.host and (preserve_metahosts or
                                  not queue_entry.meta_host)):
            if queue_entry.deleted:
                continue
            if queue_entry.host.invalid:
                one_time_hosts.append(queue_entry.host)
            else:
                hosts.append(queue_entry.host)
        elif queue_entry.meta_host:
            meta_hosts.append(queue_entry.meta_host)
        else:
            hostless = True

        if atomic_group is None:
            if queue_entry.atomic_group is not None:
                atomic_group = queue_entry.atomic_group
        else:
            assert atomic_group.name == queue_entry.atomic_group.name, (
                'DB inconsistency. HostQueueEntries with multiple atomic'
                ' groups on job %s: %s != %s' % (
                    job.id, atomic_group.name, queue_entry.atomic_group.name))

    meta_host_counts = _get_metahost_counts(meta_hosts)

    info = dict(dependencies=[label.name for label
                              in job.dependency_labels.all()],
                hosts=hosts,
                meta_hosts=meta_hosts,
                meta_host_counts=meta_host_counts,
                one_time_hosts=one_time_hosts,
                atomic_group=atomic_group,
                hostless=hostless)
    return info


def check_for_duplicate_hosts(host_objects):
    host_ids = set()
    duplicate_hostnames = set()
    for host in host_objects:
        if host.id in host_ids:
            duplicate_hostnames.add(host.hostname)
        host_ids.add(host.id)

    if duplicate_hostnames:
        raise model_logic.ValidationError(
            {'hosts' : 'Duplicate hosts: %s'
             % ', '.join(duplicate_hostnames)})


def create_new_job(owner, options, host_objects, metahost_objects,
                   atomic_group=None):
    all_host_objects = host_objects + metahost_objects
    dependencies = options.get('dependencies', [])
    synch_count = options.get('synch_count')

    if atomic_group:
        check_atomic_group_create_job(
            synch_count, host_objects, metahost_objects,
            dependencies, atomic_group)
    else:
        if synch_count is not None and synch_count > len(all_host_objects):
            raise model_logic.ValidationError(
                {'hosts':
                 'only %d hosts provided for job with synch_count = %d' %
                 (len(all_host_objects), synch_count)})
        atomic_hosts = models.Host.objects.filter(
            id__in=[host.id for host in host_objects],
            labels__atomic_group=True)
        unusable_host_names = [host.hostname for host in atomic_hosts]
        if unusable_host_names:
            raise model_logic.ValidationError(
                {'hosts':
                 'Host(s) "%s" are atomic group hosts but no '
                 'atomic group was specified for this job.' %
                 (', '.join(unusable_host_names),)})

    check_for_duplicate_hosts(host_objects)

    for label_name in dependencies:
        if provision.is_for_special_action(label_name):
            # TODO: We could save a few queries
            # if we had a bulk ensure-label-exists function, which used
            # a bulk .get() call. The win is probably very small.
            _ensure_label_exists(label_name)

    # This only checks targeted hosts, not hosts eligible due to the metahost
    check_job_dependencies(host_objects, dependencies)
    check_job_metahost_dependencies(metahost_objects, dependencies)

    options['dependencies'] = list(
        models.Label.objects.filter(name__in=dependencies))

    for label in metahost_objects + options['dependencies']:
        if label.atomic_group and not atomic_group:
            raise model_logic.ValidationError(
                {'atomic_group_name':
                 'Dependency %r requires an atomic group but no '
                 'atomic_group_name or meta_host in an atomic group was '
                 'specified for this job.' % label.name})
        elif (label.atomic_group and
              label.atomic_group.name != atomic_group.name):
            raise model_logic.ValidationError(
                {'atomic_group_name':
                 'meta_hosts or dependency %r requires atomic group '
                 '%r instead of the supplied atomic_group_name=%r.' %
                 (label.name, label.atomic_group.name, atomic_group.name)})

    job = models.Job.create(owner=owner, options=options,
                            hosts=all_host_objects)
    job.queue(all_host_objects, atomic_group=atomic_group,
              is_template=options.get('is_template', False))
    return job.id


def _ensure_label_exists(name):
    """
    Ensure that a label called |name| exists in the Django models.

    This function is to be called from within afe rpcs only, as an
    alternative to server.cros.provision.ensure_label_exists(...). It works
    by Django model manipulation, rather than by making another create_label
    rpc call.

    @param name: the label to check for/create.
    @raises ValidationError: There was an error in the response that was
            not because the label already existed.
    @returns True if a label was created, False otherwise.
    """
    # Make sure this function is not called on shards but only on master.
    assert not server_utils.is_shard()
    try:
        models.Label.objects.get(name=name)
    except models.Label.DoesNotExist:
        try:
            new_label = models.Label.objects.create(name=name)
            new_label.save()
            return True
        except django.db.utils.IntegrityError as e:
            # It is possible that another suite/test already
            # created the label between the check and save.
            if DUPLICATE_KEY_MSG in str(e):
                return False
            else:
                raise
    return False


def find_platform_and_atomic_group(host):
    """
    Figure out the platform name and atomic group name for the given host
    object. If none, the return value for either will be None.

    @returns (platform name, atomic group name) for the given host.
    """
    platforms = [label.name for label in host.label_list if label.platform]
    if not platforms:
        platform = None
    else:
        platform = platforms[0]
        if len(platforms) > 1:
            raise ValueError('Host %s has more than one platform: %s' %
                             (host.hostname, ', '.join(platforms)))
    for label in host.label_list:
        if label.atomic_group:
            atomic_group_name = label.atomic_group.name
            break
    else:
        atomic_group_name = None
    # Don't check for multiple atomic groups on a host here. That is an
    # error but should not trip up the RPC interface. monitor_db_cleanup
    # deals with it. This just returns the first one found.
    return platform, atomic_group_name


# support for get_host_queue_entries_and_special_tasks()

def _common_entry_to_dict(entry, type, job_dict, exec_path, status, started_on):
    return dict(type=type,
                host=entry['host'],
                job=job_dict,
                execution_path=exec_path,
                status=status,
                started_on=started_on,
                id=str(entry['id']) + type,
                oid=entry['id'])


def _special_task_to_dict(task, queue_entries):
    """Transforms a special task dictionary to another form of dictionary.

    @param task Special task as a dictionary type
    @param queue_entries Host queue entries as a list of dictionaries.

    @return Transformed dictionary for a special task.
    """
    job_dict = None
    if task['queue_entry']:
        # Scan queue_entries to get the job detail info.
        for qentry in queue_entries:
            if task['queue_entry']['id'] == qentry['id']:
                job_dict = qentry['job']
                break
        # If not found, get it from DB.
        if job_dict is None:
            job = models.Job.objects.get(id=task['queue_entry']['job'])
            job_dict = job.get_object_dict()

    exec_path = server_utils.get_special_task_exec_path(
        task['host']['hostname'], task['id'], task['task'],
        time_utils.time_string_to_datetime(task['time_requested']))
    status = server_utils.get_special_task_status(
        task['is_complete'], task['success'], task['is_active'])
    return _common_entry_to_dict(task, task['task'], job_dict,
                                 exec_path, status, task['time_started'])


def _queue_entry_to_dict(queue_entry):
    job_dict = queue_entry['job']
    tag = server_utils.get_job_tag(job_dict['id'], job_dict['owner'])
    exec_path = server_utils.get_hqe_exec_path(tag,
                                               queue_entry['execution_subdir'])
    return _common_entry_to_dict(queue_entry, 'Job', job_dict, exec_path,
                                 queue_entry['status'],
                                 queue_entry['started_on'])


def prepare_host_queue_entries_and_special_tasks(interleaved_entries,
                                                 queue_entries):
    """
    Prepare for serialization the interleaved entries of host queue entries
    and special tasks.
    Each element in the entries is a dictionary type.
    The special task dictionary has only a job id for a job and lacks the
    job details that the host queue entry dictionary carries;
    queue_entries is used to look up the job detail info.

    @param interleaved_entries Host queue entries and special tasks as a list
            of dictionaries.
    @param queue_entries Host queue entries as a list of dictionaries.

    @return A post-processed list of dictionaries that is to be serialized.
    """
    dict_list = []
    for e in interleaved_entries:
        # Distinguish the two mixed entry types based on the existence of
        # the key "task". If an entry has the key, it is a special task;
        # otherwise, it is a host queue entry.
        if 'task' in e:
            dict_list.append(_special_task_to_dict(e, queue_entries))
        else:
            dict_list.append(_queue_entry_to_dict(e))
    return prepare_for_serialization(dict_list)


def _compute_next_job_for_tasks(queue_entries, special_tasks):
    """
    For each task, try to figure out the next job that ran after that task.
    This is done using two pieces of information:
    * if the task has a queue entry, we can use that entry's job ID.
    * if the task has a time_started, we can try to compare that against the
      started_on field of queue_entries. this isn't guaranteed to work perfectly
      since queue_entries may also have null started_on values.
    * if the task has neither, or if use of time_started fails, just use the
      last computed job ID.

    @param queue_entries Host queue entries as a list of dictionaries.
    @param special_tasks Special tasks as a list of dictionaries.
    """
    next_job_id = None # most recently computed next job
    hqe_index = 0 # index for scanning by started_on times
    for task in special_tasks:
        if task['queue_entry']:
            next_job_id = task['queue_entry']['job']
        elif task['time_started'] is not None:
            for queue_entry in queue_entries[hqe_index:]:
                if queue_entry['started_on'] is None:
                    continue
                t1 = time_utils.time_string_to_datetime(
                    queue_entry['started_on'])
                t2 = time_utils.time_string_to_datetime(task['time_started'])
                if t1 < t2:
                    break
                next_job_id = queue_entry['job']['id']

        task['next_job_id'] = next_job_id

        # advance hqe_index to just after next_job_id
        if next_job_id is not None:
            for queue_entry in queue_entries[hqe_index:]:
                if queue_entry['job']['id'] < next_job_id:
                    break
                hqe_index += 1


def interleave_entries(queue_entries, special_tasks):
    """
    Both lists should be ordered by descending ID.
    """
    _compute_next_job_for_tasks(queue_entries, special_tasks)

    # start with all special tasks that've run since the last job
    interleaved_entries = []
    for task in special_tasks:
        if task['next_job_id'] is not None:
            break
        interleaved_entries.append(task)

    # now interleave queue entries with the remaining special tasks
    special_task_index = len(interleaved_entries)
    for queue_entry in queue_entries:
        interleaved_entries.append(queue_entry)
        # add all tasks that ran between this job and the previous one
        for task in special_tasks[special_task_index:]:
            if task['next_job_id'] < queue_entry['job']['id']:
                break
            interleaved_entries.append(task)
            special_task_index += 1

    return interleaved_entries


def bucket_hosts_by_shard(host_objs, rpc_hostnames=False):
    """Figure out which hosts are on which shards.

    @param host_objs: A list of host objects.
    @param rpc_hostnames: If True, the rpc_hostnames of a shard are returned
        instead of the 'real' shard hostnames. This only matters for testing
        environments.

    @return: A map of shard hostname: list of hosts on the shard.
    """
    shard_host_map = {}
    for host in host_objs:
        if host.shard:
            shard_name = (host.shard.rpc_hostname() if rpc_hostnames
                          else host.shard.hostname)
            shard_host_map.setdefault(shard_name, []).append(host.hostname)
    return shard_host_map


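# Illustrative sketch (not part of the original module): hosts without a shard
# are simply left out of the returned map. The shard and host names are made up.
#
#   bucket_hosts_by_shard(models.Host.objects.all())
#   # -> {'shard1.example.com': ['host1', 'host2'],
#   #     'shard2.example.com': ['host3']}
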
def get_create_job_common_args(local_args):
    """
    Returns a dict containing only the args that apply for create_job_common

    Returns a subset of local_args, which contains only the arguments that can
    be passed in to create_job_common().
    """
    # This code is only here to not kill suites scheduling tests when priority
    # becomes an int instead of a string.
    if isinstance(local_args['priority'], str):
        local_args['priority'] = priorities.Priority.DEFAULT
    # </migration hack>
    arg_names, _, _, _ = inspect.getargspec(create_job_common)
    return dict(item for item in local_args.iteritems() if item[0] in arg_names)


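# Illustrative sketch (not part of the original module): an RPC entry point can
# forward its own locals() and let this helper drop anything that
# create_job_common() does not accept. The RPC signature below is hypothetical.
#
#   def create_job(name, priority, control_file, control_type, hosts=(),
#                  some_rpc_only_flag=False):
#       ...
#       return create_job_common(**get_create_job_common_args(locals()))
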
def create_job_common(name, priority, control_type, control_file=None,
                      hosts=(), meta_hosts=(), one_time_hosts=(),
                      atomic_group_name=None, synch_count=None,
                      is_template=False, timeout=None, timeout_mins=None,
                      max_runtime_mins=None, run_verify=True, email_list='',
                      dependencies=(), reboot_before=None, reboot_after=None,
                      parse_failed_repair=None, hostless=False, keyvals=None,
                      drone_set=None, parameterized_job=None,
                      parent_job_id=None, test_retry=0, run_reset=True,
                      require_ssp=None):
    #pylint: disable-msg=C0111
    """
    Common code between creating "standard" jobs and creating parameterized jobs
    """
    user = models.User.current_user()
    owner = user.login

    # input validation
    if not (hosts or meta_hosts or one_time_hosts or atomic_group_name
            or hostless):
        raise model_logic.ValidationError({
            'arguments' : "You must pass at least one of 'hosts', "
                          "'meta_hosts', 'one_time_hosts', "
                          "'atomic_group_name', or 'hostless'"
            })

    if hostless:
        if hosts or meta_hosts or one_time_hosts or atomic_group_name:
            raise model_logic.ValidationError({
                'hostless': 'Hostless jobs cannot include any hosts!'})
        server_type = control_data.CONTROL_TYPE_NAMES.SERVER
        if control_type != server_type:
            raise model_logic.ValidationError({
                'control_type': 'Hostless jobs cannot use client-side '
                                'control files'})

    atomic_groups_by_name = dict((ag.name, ag)
                                 for ag in models.AtomicGroup.objects.all())
    label_objects = list(models.Label.objects.filter(name__in=meta_hosts))

    # Schedule on an atomic group automagically if one of the labels given
    # is an atomic group label and no explicit atomic_group_name was supplied.
    if not atomic_group_name:
        for label in label_objects:
            if label and label.atomic_group:
                atomic_group_name = label.atomic_group.name
                break
    # convert hostnames & meta hosts to host/label objects
    host_objects = models.Host.smart_get_bulk(hosts)
    if not server_utils.is_shard():
        shard_host_map = bucket_hosts_by_shard(host_objects)
        num_shards = len(shard_host_map)
        if (num_shards > 1 or (num_shards == 1 and
                len(shard_host_map.values()[0]) != len(host_objects))):
            # We disallow the following jobs on master:
            #   num_shards > 1: this is a job spanning across multiple shards.
            #   num_shards == 1 but number of hosts on shard is less
            #   than total number of hosts: this is a job that spans across
            #   one shard and the master.
            raise ValueError(
                'The following hosts are on shard(s), please create '
                'separate jobs for hosts on each shard: %s ' %
                shard_host_map)
    metahost_objects = []
    meta_host_labels_by_name = {label.name: label for label in label_objects}
    for label_name in meta_hosts or []:
        if label_name in meta_host_labels_by_name:
            metahost_objects.append(meta_host_labels_by_name[label_name])
        elif label_name in atomic_groups_by_name:
            # If given a metahost name that isn't a Label, check to
            # see if the user was specifying an Atomic Group instead.
            atomic_group = atomic_groups_by_name[label_name]
            if atomic_group_name and atomic_group_name != atomic_group.name:
                raise model_logic.ValidationError({
                    'meta_hosts': (
                        'Label "%s" not found. If assumed to be an '
                        'atomic group it would conflict with the '
                        'supplied atomic group "%s".' % (
                            label_name, atomic_group_name))})
            atomic_group_name = atomic_group.name
        else:
            raise model_logic.ValidationError(
                {'meta_hosts' : 'Label "%s" not found' % label_name})

    # Create and sanity check an AtomicGroup object if requested.
    if atomic_group_name:
        if one_time_hosts:
            raise model_logic.ValidationError(
                {'one_time_hosts':
                 'One time hosts cannot be used with an Atomic Group.'})
        atomic_group = models.AtomicGroup.smart_get(atomic_group_name)
        if synch_count and synch_count > atomic_group.max_number_of_machines:
            raise model_logic.ValidationError(
                {'atomic_group_name' :
                 'You have requested a synch_count (%d) greater than the '
                 'maximum machines in the requested Atomic Group (%d).' %
                 (synch_count, atomic_group.max_number_of_machines)})
    else:
        atomic_group = None

    for host in one_time_hosts or []:
        this_host = models.Host.create_one_time_host(host)
        host_objects.append(this_host)

    options = dict(name=name,
                   priority=priority,
                   control_file=control_file,
                   control_type=control_type,
                   is_template=is_template,
                   timeout=timeout,
                   timeout_mins=timeout_mins,
                   max_runtime_mins=max_runtime_mins,
                   synch_count=synch_count,
                   run_verify=run_verify,
                   email_list=email_list,
                   dependencies=dependencies,
                   reboot_before=reboot_before,
                   reboot_after=reboot_after,
                   parse_failed_repair=parse_failed_repair,
                   keyvals=keyvals,
                   drone_set=drone_set,
                   parameterized_job=parameterized_job,
                   parent_job_id=parent_job_id,
                   test_retry=test_retry,
                   run_reset=run_reset,
                   require_ssp=require_ssp)
    return create_new_job(owner=owner,
                          options=options,
                          host_objects=host_objects,
                          metahost_objects=metahost_objects,
                          atomic_group=atomic_group)


def encode_ascii(control_file):
    """Force a control file to only contain ascii characters.

    @param control_file: Control file to encode.

    @returns the control file in an ascii encoding.

    @raises error.ControlFileMalformed: if encoding fails.
    """
    try:
        return control_file.encode('ascii')
    except UnicodeDecodeError as e:
        raise error.ControlFileMalformed(str(e))


def get_wmatrix_url():
    """Get wmatrix url from config file.

    @returns the wmatrix url or an empty string.
    """
    return global_config.global_config.get_config_value('AUTOTEST_WEB',
                                                        'wmatrix_url',
                                                        default='')


def inject_times_to_filter(start_time_key=None, end_time_key=None,
                           start_time_value=None, end_time_value=None,
                           **filter_data):
    """Inject the key value pairs of start and end time if provided.

    @param start_time_key: A string representing the filter key of start_time.
    @param end_time_key: A string representing the filter key of end_time.
    @param start_time_value: Start_time value.
    @param end_time_value: End_time value.

    @returns the injected filter_data.
    """
    if start_time_value:
        filter_data[start_time_key] = start_time_value
    if end_time_value:
        filter_data[end_time_key] = end_time_value
    return filter_data


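# Illustrative sketch (not part of the original module): only the bounds that
# are actually provided end up in the filter. The owner value is made up.
#
#   inject_times_to_filter('started_on__gte', 'started_on__lte',
#                          '2015-01-01 00:00:00', None, owner='showard')
#   # -> {'owner': 'showard', 'started_on__gte': '2015-01-01 00:00:00'}
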
def inject_times_to_hqe_special_tasks_filters(filter_data_common,
                                              start_time, end_time):
    """Inject start and end time to hqe and special tasks filters.

    @param filter_data_common: Common filter for hqe and special tasks.
    @param start_time: Start time value to inject.
    @param end_time: End time value to inject.

    @returns a pair of hqe and special tasks filters.
    """
    filter_data_special_tasks = filter_data_common.copy()
    return (inject_times_to_filter('started_on__gte', 'started_on__lte',
                                   start_time, end_time, **filter_data_common),
            inject_times_to_filter('time_started__gte', 'time_started__lte',
                                   start_time, end_time,
                                   **filter_data_special_tasks))


def retrieve_shard(shard_hostname):
    """
    Retrieves the shard with the given hostname from the database.

    @param shard_hostname: Hostname of the shard to retrieve

    @raises models.Shard.DoesNotExist, if no shard with this hostname was found.

    @returns: Shard object
    """
    timer = autotest_stats.Timer('shard_heartbeat.retrieve_shard')
    with timer:
        return models.Shard.smart_get(shard_hostname)


def find_records_for_shard(shard, known_job_ids, known_host_ids):
    """Find records that should be sent to a shard.

    @param shard: Shard to find records for.
    @param known_job_ids: List of ids of jobs the shard already has.
    @param known_host_ids: List of ids of hosts the shard already has.

    @returns: Tuple of three lists for hosts, jobs, and suite job keyvals:
            (hosts, jobs, suite_job_keyvals).
    """
    timer = autotest_stats.Timer('shard_heartbeat')
    with timer.get_client('find_hosts'):
        hosts = models.Host.assign_to_shard(shard, known_host_ids)
    with timer.get_client('find_jobs'):
        jobs = models.Job.assign_to_shard(shard, known_job_ids)
    with timer.get_client('find_suite_job_keyvals'):
        parent_job_ids = [job.parent_job_id for job in jobs]
        suite_job_keyvals = models.JobKeyval.objects.filter(
            job_id__in=parent_job_ids)
    return hosts, jobs, suite_job_keyvals


def _persist_records_with_type_sent_from_shard(
        shard, records, record_type, *args, **kwargs):
    """
    Handle records of a specified type that were sent to the shard master.

    @param shard: The shard the records were sent from.
    @param records: The records sent in their serialized format.
    @param record_type: Type of the objects represented by records.
    @param args: Additional arguments that will be passed on to the sanity
            checks.
    @param kwargs: Additional arguments that will be passed on to the sanity
            checks.

    @raises error.UnallowedRecordsSentToMaster if any of the sanity checks fail.

    @returns: List of primary keys of the processed records.
    """
    pks = []
    for serialized_record in records:
        pk = serialized_record['id']
        try:
            current_record = record_type.objects.get(pk=pk)
        except record_type.DoesNotExist:
            raise error.UnallowedRecordsSentToMaster(
                'Object with pk %s of type %s does not exist on master.' % (
                    pk, record_type))

        current_record.sanity_check_update_from_shard(
            shard, serialized_record, *args, **kwargs)

        current_record.update_from_serialized(serialized_record)
        pks.append(pk)
    return pks


def persist_records_sent_from_shard(shard, jobs, hqes):
    """
    Sanity checking then saving serialized records sent to master from shard.

    During heartbeats shards upload jobs and hostqueueentries. This performs
    some sanity checks on these and then updates the existing records for those
    entries with the updated ones from the heartbeat.

    The sanity checks include:
    - Checking if the objects sent already exist on the master.
    - Checking if the objects sent were assigned to this shard.
    - hostqueueentries must be sent together with their jobs.

    @param shard: The shard the records were sent from.
    @param jobs: The jobs the shard sent.
    @param hqes: The hostqueueentries the shard sent.

    @raises error.UnallowedRecordsSentToMaster if any of the sanity checks fail.
    """
    timer = autotest_stats.Timer('shard_heartbeat')
    with timer.get_client('persist_jobs'):
        job_ids_sent = _persist_records_with_type_sent_from_shard(
            shard, jobs, models.Job)

    with timer.get_client('persist_hqes'):
        _persist_records_with_type_sent_from_shard(
            shard, hqes, models.HostQueueEntry, job_ids_sent=job_ids_sent)


def forward_single_host_rpc_to_shard(func):
    """This decorator forwards rpc calls that modify a host to a shard.

    If a host is assigned to a shard, rpcs that change its attributes should be
    forwarded to the shard.

    This assumes the first argument of the function represents a host id.

    @param func: The function to decorate

    @returns: The function to replace func with.
    """
    def replacement(**kwargs):
        # Only keyword arguments can be accepted here, as we need the argument
        # names to send the rpc. serviceHandler always provides arguments with
        # their keywords, so this is not a problem.

        # A host record (identified by kwargs['id']) can be deleted in
        # func(). Therefore, we should save the data that can be needed later
        # before func() is called.
        shard_hostname = None
        host = models.Host.smart_get(kwargs['id'])
        if host and host.shard:
            shard_hostname = host.shard.rpc_hostname()
        ret = func(**kwargs)
        if shard_hostname and not server_utils.is_shard():
            run_rpc_on_multiple_hostnames(func.func_name,
                                          [shard_hostname],
                                          **kwargs)
        return ret

    return replacement


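# Illustrative sketch (not part of the original module): a host-modifying RPC
# wrapped by the decorator. The RPC below is hypothetical; note that callers
# must pass the host id as a keyword argument.
#
#   @forward_single_host_rpc_to_shard
#   def modify_host(id, **update_data):
#       models.Host.smart_get(id).update_object(update_data)
#
#   modify_host(id=42, locked=True, lock_reason='repairs')
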
def fanout_rpc(host_objs, rpc_name, include_hostnames=True, **kwargs):
    """Fanout the given rpc to shards of given hosts.

    @param host_objs: Host objects for the rpc.
    @param rpc_name: The name of the rpc.
    @param include_hostnames: If True, include the hostnames in the kwargs.
        Hostnames are not always necessary; this function is designed to
        send rpcs to the shard a host is on, and the rpcs themselves could be
        related to labels, acls etc.
    @param kwargs: The kwargs for the rpc.
    """
    # Figure out which hosts are on which shards.
    shard_host_map = bucket_hosts_by_shard(
            host_objs, rpc_hostnames=True)

    # Execute the rpc against the appropriate shards.
    for shard, hostnames in shard_host_map.iteritems():
        if include_hostnames:
            kwargs['hosts'] = hostnames
        try:
            run_rpc_on_multiple_hostnames(rpc_name, [shard], **kwargs)
        except:
            ei = sys.exc_info()
            new_exc = error.RPCException('RPC %s failed on shard %s due to '
                    '%s: %s' % (rpc_name, shard, ei[0].__name__, ei[1]))
            raise new_exc.__class__, new_exc, ei[2]


def run_rpc_on_multiple_hostnames(rpc_call, shard_hostnames, **kwargs):
    """Runs an rpc to multiple AFEs.

    This is used, for example, to propagate changes made to hosts after they
    are assigned to a shard.

    @param rpc_call: Name of the rpc endpoint to call.
    @param shard_hostnames: List of hostnames to run the rpcs on.
    @param **kwargs: Keyword arguments to pass in the rpcs.
    """
    # Make sure this function is not called on shards but only on master.
    assert not server_utils.is_shard()
    for shard_hostname in shard_hostnames:
        afe = frontend_wrappers.RetryingAFE(server=shard_hostname,
                                            user=thread_local.get_user())
        afe.run(rpc_call, **kwargs)


def get_label(name):
    """Gets a label object using a given name.

    @param name: Label name.
    @return: a label object matching the given name, or None if no such
            label exists.
    """
    try:
        label = models.Label.smart_get(name)
    except models.Label.DoesNotExist:
        return None
    return label


def route_rpc_to_master(func):
    """Route RPC to master AFE.

    When a shard receives an RPC decorated by this, the RPC is just
    forwarded to the master.
    When the master gets the RPC, the RPC function is executed.

    @param func: An RPC function to decorate

    @returns: A function replacing the RPC func.
    """
    @wraps(func)
    def replacement(*args, **kwargs):
        """
        We need special care when decorating an RPC that can be called
        directly using positional arguments. One example is
        rpc_interface.create_job().
        rpc_interface.create_job_page_handler() calls the function using
        positional and keyword arguments.
        Since frontend.RpcClient.run() takes only keyword arguments for
        an RPC, positional arguments of the RPC function need to be
        transformed into key-value pairs (a dictionary).

        inspect.getcallargs() is a useful utility to achieve the goal,
        however, additional effort is needed when an RPC function has
        a **kwargs argument.
        Let's say we have an RPC function of the following form.

            def rpcfunc(a, b, **kwargs)

        When we call the function like "rpcfunc(1, 2, id=3, name='mk')",
        inspect.getcallargs() returns a dictionary like below.

            {'a':1, 'b':2, 'kwargs': {'id':3, 'name':'mk'}}

        This is an incorrect form of arguments to pass to the rpc function.
        Instead, the dictionary should be like this.

            {'a':1, 'b':2, 'id':3, 'name':'mk'}
        """
        argspec = inspect.getargspec(func)
        if argspec.varargs is not None:
            raise Exception('RPC function must not have *args.')
        funcargs = inspect.getcallargs(func, *args, **kwargs)
        kwargs = dict()
        for k, v in funcargs.iteritems():
            if argspec.keywords and k == argspec.keywords:
                kwargs.update(v)
            else:
                kwargs[k] = v

        if server_utils.is_shard():
            afe = frontend_wrappers.RetryingAFE(
                    server=server_utils.get_global_afe_hostname(),
                    user=thread_local.get_user())
            return afe.run(func.func_name, **kwargs)
        return func(**kwargs)
    return replacement
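

# Illustrative sketch (not part of the original module): a master-only RPC
# wrapped by the decorator. The RPC below is hypothetical; on a shard the
# decorated call is forwarded to the global AFE, on the master it simply runs.
#
#   @route_rpc_to_master
#   def delete_label(id):
#       models.Label.smart_get(id).delete()
#
#   delete_label(id=7)   # keyword call
#   delete_label(7)      # positional call is also flattened into keywords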