Source code for openquake.engine.utils.tasks

# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4

# Copyright (c) 2010-2014, GEM Foundation.
#
# OpenQuake is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# OpenQuake is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with OpenQuake.  If not, see <http://www.gnu.org/licenses/>.

"""Utility functions related to splitting work into tasks."""

import operator

from celery.result import ResultSet
from celery.app import current_app
from celery.task import task

from openquake.baselib.general import split_in_blocks, AccumDict
from openquake.commonlib.parallel import \
    TaskManager, safely_call, check_mem_usage, pickle_sequence, no_distribute
from openquake.engine import logs
from openquake.engine.db import models
from openquake.engine.utils import config
from openquake.engine.writer import CacheInserter
from openquake.engine.performance import EnginePerformanceMonitor
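
# The limits below come from the engine configuration file (typically
# openquake.cfg). A minimal sketch of the relevant sections, with
# illustrative values (assumed here, not the shipped defaults):
#
#   [celery]
#   concurrent_tasks = 64
#
#   [memory]
#   soft_mem_limit = 80
#   hard_mem_limit = 100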

CONCURRENT_TASKS = int(config.get('celery', 'concurrent_tasks'))
SOFT_MEM_LIMIT = int(config.get('memory', 'soft_mem_limit'))
HARD_MEM_LIMIT = int(config.get('memory', 'hard_mem_limit'))


class JobNotRunning(Exception):
    pass


class OqTaskManager(TaskManager):
    """
    A celery-based task manager. The usage is::

        oqm = OqTaskManager(do_something, logs.LOG.progress)
        oqm.submit(arg1, arg2)
        oqm.submit(arg3, arg4)
        print oqm.aggregate_results(agg, acc)

    Progress report is built-in.
    """

    def submit(self, *args):
        """
        Submit an oqtask with the given arguments to celery and append
        the resulting AsyncResult to ``self.results``. If the variable
        OQ_NO_DISTRIBUTE is set, the task function is run in process
        and its result is stored instead.
        """
        # log a warning if too much memory is used
        check_mem_usage(SOFT_MEM_LIMIT, HARD_MEM_LIMIT)
        if no_distribute():
            res = safely_call(self.oqtask.task_func, args)
        else:
            piks = pickle_sequence(args)
            self.sent += sum(len(p) for p in piks)
            res = self.oqtask.delay(*piks)
        self.results.append(res)

    def aggregate_result_set(self, agg, acc):
        """
        Loop on a set of celery AsyncResults and update the accumulator
        by using the aggregation function.

        :param agg: the aggregation function, (acc, val) -> new acc
        :param acc: the initial value of the accumulator
        :returns: the final value of the accumulator
        """
        if not self.results:
            return acc
        backend = current_app().backend
        rset = ResultSet(self.results)
        for task_id, result_dict in rset.iter_native():
            # log a warning if too much memory is used
            check_mem_usage(SOFT_MEM_LIMIT, HARD_MEM_LIMIT)
            result = result_dict['result']
            if isinstance(result, BaseException):
                raise result
            self.received += len(result)
            acc = agg(acc, result.unpickle())
            del backend._cache[task_id]  # work around a celery bug
        return acc

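# A hedged usage sketch of the manager above (``my_oqtask``, ``job_id``
# and the blocks are hypothetical names): each ``submit`` queues one
# task, and ``aggregate_result_set`` folds the results with an
# aggregation function of signature (acc, val) -> new acc.
#
#   oqm = OqTaskManager(my_oqtask, logs.LOG.progress)
#   oqm.submit(job_id, block1)
#   oqm.submit(job_id, block2)
#   total = oqm.aggregate_result_set(operator.add, 0)
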
# a convenient alias
starmap = OqTaskManager.starmap

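# ``starmap`` is not defined in this module; judging from its use in
# ``apply_reduce`` below, it builds a manager from a list of argument
# tuples, so that the results can be reduced in one go. A hedged sketch
# (``my_oqtask`` and the blocks are hypothetical):
#
#   oqm = starmap(my_oqtask, [(job_id, block) for block in blocks],
#                 logs.LOG.progress)
#   acc = oqm.reduce(operator.add, 0)
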
def apply_reduce(task, task_args,
                 agg=operator.add,
                 acc=None,
                 concurrent_tasks=CONCURRENT_TASKS,
                 weight=lambda item: 1,
                 key=lambda item: 'Unspecified',
                 name=None):
    """
    Apply a task to a tuple of the form (job_id, data, *args)
    by splitting the data in chunks and reducing the results with an
    aggregation function.

    :param task: an oqtask
    :param task_args: the arguments to be passed to the task function
    :param agg: the aggregation function
    :param acc: initial value of the accumulator
    :param concurrent_tasks: hint about how many tasks to generate
    :param weight: function to extract the weight of an item in data
    :param key: function to extract the kind of an item in data
    """
    if acc is None:
        acc = AccumDict()
    job_id = task_args[0]
    data = task_args[1]
    args = task_args[2:]
    if not data:
        return acc
    elif len(data) == 1 or not concurrent_tasks:
        return agg(acc, task.task_func(job_id, data, *args))
    blocks = split_in_blocks(data, concurrent_tasks, weight, key)
    task_args = [(job_id, block) + args for block in blocks]
    return starmap(task, task_args, logs.LOG.progress, name).reduce(agg, acc)

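# A hedged example of ``apply_reduce`` (``count_ruptures`` stands for a
# hypothetical oqtask taking (job_id, sources)): the sources are split
# into blocks weighted by their rupture count, and the partial counts
# are summed.
#
#   num_ruptures = apply_reduce(
#       count_ruptures, (job_id, sources),
#       agg=operator.add, acc=0,
#       weight=lambda src: src.num_ruptures)
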
def oqtask(task_func):
    """
    Task function decorator which sets up logging and catches (and logs)
    any errors which occur inside the task. Also checks to make sure the
    job is actually still running. If it is not running, the task doesn't
    get executed, so we don't do useless computation.

    :param task_func: the function to decorate
    """
    def wrapped(*args):
        """
        Initialize logs, make sure the job is still running, and run the
        task code surrounded by a try-except. If any error occurs, log it
        as a critical failure.
        """
        # job_id is always assumed to be the first argument
        job_id = args[0]
        job = models.OqJob.objects.get(id=job_id)
        if job.is_running is False:
            # the job was killed, it is useless to run the task
            raise JobNotRunning(job_id)

        # it is important to save the task id soon, so that
        # the revoke functionality can work
        EnginePerformanceMonitor.store_task_id(job_id, tsk)

        with EnginePerformanceMonitor(
                'total ' + task_func.__name__, job_id, tsk, flush=True):
            # tasks write on the celery log file
            logs.set_level(job.log_level)
            try:
                # log a warning if too much memory is used
                check_mem_usage(SOFT_MEM_LIMIT, HARD_MEM_LIMIT)
                # run the task
                return task_func(*args)
            finally:
                # save on the db
                CacheInserter.flushall()
                # the task finished, we can remove from the performance
                # table the associated row 'storing task id'
                models.Performance.objects.filter(
                    oq_job=job,
                    operation='storing task id',
                    task_id=tsk.request.id).delete()

    celery_queue = config.get('amqp', 'celery_queue')
    f = lambda *args: safely_call(wrapped, args, pickle=True)
    f.__name__ = task_func.__name__
    tsk = task(f, queue=celery_queue)
    tsk.task_func = task_func
    return tsk
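
# A minimal sketch of how a task is typically declared with the
# decorator above (``count_data`` is a hypothetical example): the
# decorated function must take the job_id as its first argument, since
# ``wrapped`` relies on that convention.
#
#   @oqtask
#   def count_data(job_id, data):
#       return len(data)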