Source code for openquake.commonlib.parallel

# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
#
# Copyright (C) 2010-2016 GEM Foundation
#
# OpenQuake is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# OpenQuake is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with OpenQuake. If not, see <http://www.gnu.org/licenses/>.

"""
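Utilities to run tasks in parallel. The distribution mechanism is selected
by the environment variable OQ_DISTRIBUTE: "futures" (the default) uses a
concurrent.futures ProcessPoolExecutor, "celery" uses celery workers and
"no" runs the tasks in the current process.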
"""
from __future__ import print_function
from functools import reduce
import os
import sys
import socket
import inspect
import logging
import operator
import traceback
from concurrent.futures import as_completed, ProcessPoolExecutor
from decorator import FunctionMaker

from openquake.baselib.python3compat import pickle
from openquake.baselib.performance import Monitor, virtual_memory
from openquake.baselib.general import split_in_blocks, AccumDict, humansize
from openquake.hazardlib.gsim.base import GroundShakingIntensityModel

executor = ProcessPoolExecutor()
# the num_tasks_hint is chosen to be 2 times bigger than the number of
# cores; it is a heuristic number to get a good distribution;
# it has no more significance than that
executor.num_tasks_hint = executor._max_workers * 2

# distribution mechanism, read once at import time; oq_distribute() below
# re-reads the environment variable at call time instead
OQ_DISTRIBUTE = os.environ.get('OQ_DISTRIBUTE', 'futures').lower()

if OQ_DISTRIBUTE == 'celery':
    # celery is imported only when it is actually used
    from celery.result import ResultSet
    from celery import Celery
    from celery.task import task
    from openquake.engine.celeryconfig import BROKER_URL
    app = Celery('openquake', backend='amqp://', broker=BROKER_URL)


def oq_distribute():
    """
    Read the environment variable OQ_DISTRIBUTE at call time and return
    it in lowercase; when the variable is not set, return 'futures'.
    """
    value = os.environ.get('OQ_DISTRIBUTE', 'futures')
    return value.lower()
def no_distribute():
    """
    Return True when the variable OQ_DISTRIBUTE is "no" (compared after
    lowercasing), i.e. when the tasks must run in the current process.
    """
    return 'no' == oq_distribute()
def check_mem_usage(monitor=None, soft_percent=90, hard_percent=100):
    """
    Display a warning if we are running out of memory and raise an error
    if the memory usage exceeds the hard limit.

    :param monitor:
        the Monitor used to send the warning; if None, a fresh Monitor
        is created when a warning has to be sent
    :param int soft_percent:
        memory usage percentage above which a warning is sent
    :param int hard_percent:
        memory usage percentage above which a MemoryError is raised
    :raises MemoryError: if the used memory exceeds `hard_percent`
    """
    used_mem_percent = virtual_memory().percent
    if used_mem_percent > hard_percent:
        raise MemoryError('Using more memory than allowed by configuration '
                          '(Used: %d%% / Allowed: %d%%)! Shutting down.' %
                          (used_mem_percent, hard_percent))
    elif used_mem_percent > soft_percent:
        if monitor is None:
            # built lazily: a Monitor() default argument would be created
            # at import time and shared by all the calls
            monitor = Monitor()
        hostname = socket.gethostname()
        monitor.send('warn', 'Using over %d%% of the memory in %s!',
                     used_mem_percent, hostname)
def safely_call(func, args, pickle=False):
    """
    Call the given function with the given arguments safely, i.e.
    by trapping the exceptions. Return a triple (result, exc_type, monitor)
    where exc_type is None if no exceptions occur, otherwise it is the
    exception class and the result is a string containing error message
    and traceback. The monitor is the last argument if it is a Monitor
    instance, otherwise a fresh Monitor. If the function returns a
    generator, it is consumed and the result is a list.

    :param func: the function to call
    :param args: the arguments
    :param pickle:
        if set, the input arguments are unpickled and the return value
        is pickled; otherwise they are left unchanged
    :returns: the triple described above, wrapped in a Pickled instance
        if `pickle` is true
    """
    if pickle:
        args = [a.unpickle() for a in args]
    ismon = args and isinstance(args[-1], Monitor)
    mon = args[-1] if ismon else Monitor()
    try:
        got = func(*args)
        if inspect.isgenerator(got):
            got = list(got)
        res = got, None, mon
    except BaseException:  # deliberately trap everything, as before
        etype, exc, tb = sys.exc_info()
        tb_str = ''.join(traceback.format_tb(tb))
        res = ('\n%s%s: %s' % (tb_str, etype.__name__, exc), etype, mon)
    if pickle:
        return Pickled(res)
    return res
def log_percent_gen(taskname, todo, progress):
    """
    Generator factory for progress reporting. Every time the generator
    is advanced, a message is logged only if the integer percentage of
    completed calls has increased since the last logged one; the last
    step always logs 100%. The values yielded are the number of calls
    done so far, starting from 0.

    :param str taskname: the name of the task
    :param int todo: the number of times the generator object will be called
    :param progress: a logging function for the progress report
    """
    yield 0
    ncalls, last_percent = 1, 0
    while ncalls < todo:
        current = int(float(ncalls) / todo * 100)
        if current > last_percent:
            progress('%s %3d%%', taskname, current)
            last_percent = current
        yield ncalls
        ncalls += 1
    progress('%s 100%%', taskname)
    yield ncalls
class Pickled(object):
    """
    A utility to pickle and unpickle objects manually. The point is that
    celery does not use the HIGHEST_PROTOCOL, so relying on celery for the
    serialization is slower. Moreover, Pickled instances have a readable
    string representation, and their length is the size of the pickled
    bytestring.

    :param obj: the object to pickle
    """
    def __init__(self, obj):
        # class name and calc_id are kept only for the string representation
        self.clsname = obj.__class__.__name__
        self.calc_id = str(getattr(obj, 'calc_id', ''))  # for monitors
        self.pik = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)

    def __repr__(self):
        """String representation of the pickled object"""
        size = humansize(len(self))
        return '<Pickled %s %s %s>' % (self.clsname, self.calc_id, size)

    def __len__(self):
        """Size in bytes of the pickled bytestring"""
        pickled_bytes = self.pik
        return len(pickled_bytes)

    def unpickle(self):
        """Rebuild and return the original object from the bytestring"""
        pickled_bytes = self.pik
        return pickle.loads(pickled_bytes)
def get_pickled_sizes(obj):
    """
    Return a pair (total_size, partial_sizes): the pickled size of the
    object and the list of (attribute name, pickled size) pairs for its
    direct attributes, ordered by decreasing size. Objects without a
    __dict__ have no partial sizes. Here is an example:

    >> total_size, partial_sizes = get_pickled_sizes(Monitor(''))
    >> total_size
    345
    >> partial_sizes
    [('_procs', 214), ('exc', 4), ('mem', 4), ('start_time', 4),
    ('_start_time', 4), ('duration', 4)]

    Notice that the sizes depend on the operating system and the machine.
    """
    own_attrs = getattr(obj, '__dict__', {})
    partial = [(attr, len(Pickled(val))) for attr, val in own_attrs.items()]
    total = len(Pickled(obj))
    partial.sort(key=lambda pair: pair[1], reverse=True)
    return total, partial
def pickle_sequence(objects):
    """
    Convert an iterable of objects into a list of pickled objects.
    If the iterable contains the same object more than once (by identity),
    the pickling will be done only once. If the iterable contains objects
    already pickled, they will not be pickled again.

    :param objects: an iterable of objects to pickle
    :returns: a list of Pickled instances, one per input object
    """
    cache = {}  # id(obj) -> (obj, pickled)
    out = []
    for obj in objects:
        obj_id = id(obj)
        if obj_id not in cache:
            pickled = obj if isinstance(obj, Pickled) else Pickled(obj)
            # keep a reference to obj: otherwise, when `objects` is a lazy
            # iterator, obj could be garbage collected and its id reused
            # by a different object, which would get the wrong pickle
            cache[obj_id] = (obj, pickled)
        out.append(cache[obj_id][1])
    return out
class TaskManager(object):
    """
    A manager to submit several tasks of the same type.
    The usage is::

      tm = TaskManager(do_something)
      tm.submit(arg1, arg2)
      tm.submit(arg3, arg4)
      print(tm.reduce())

    Progress report is built-in.
    """
    executor = executor
    progress = staticmethod(logging.info)
    # NB: class-level list shared by all the instances; it holds the ids
    # of the celery tasks submitted and not yet received
    task_ids = []

    @classmethod
    def restart(cls):
        """
        Shut down the current process pool and replace it with a new one.
        """
        cls.executor.shutdown()
        cls.executor = ProcessPoolExecutor()
        # same heuristic as for the module-level executor
        cls.executor.num_tasks_hint = cls.executor._max_workers * 2

    @classmethod
    def starmap(cls, task, task_args, name=None):
        """
        Spawn a bunch of tasks with the given list of arguments

        :param task: the task to run
        :param task_args: an iterable of argument tuples, one per task
        :param name: the name of the task (default the task function name)
        :returns: a TaskManager object with a .reduce method.
        """
        self = cls(task, name)
        for i, a in enumerate(task_args, 1):
            cls.progress('Submitting task %s #%d', self.name, i)
            if isinstance(a[-1], Monitor):  # add incremental task number
                a[-1].task_no = i
            self.submit(*a)
        return self

    @classmethod
    def apply_reduce(cls, task, task_args, agg=operator.add, acc=None,
                     concurrent_tasks=executor._max_workers,
                     weight=lambda item: 1,
                     key=lambda item: 'Unspecified',
                     name=None):
        r"""
        Apply a task to a tuple of the form (sequence, \*other_args)
        by first splitting the sequence in chunks, according to the weight
        of the elements and possibly to a key (see :func:
        `openquake.baselib.general.split_in_blocks`).
        Then reduce the results with an aggregation function.
        The chunks which are generated internally can be seen directly (
        useful for debugging purposes) by looking at the attribute `._chunks`,
        right after the `apply_reduce` function has been called.

        :param task: a task to run in parallel
        :param task_args: the arguments to be passed to the task function
        :param agg: the aggregation function
        :param acc: initial value of the accumulator (default empty AccumDict)
        :param concurrent_tasks: hint about how many tasks to generate
        :param weight: function to extract the weight of an item in arg0
        :param key: function to extract the kind of an item in arg0
        :param name: the name of the task (default the task function name)
        :returns: the final value of the accumulator
        """
        arg0 = task_args[0]  # this is assumed to be a sequence
        args = task_args[1:]
        task_func = getattr(task, 'task_func', task)
        if acc is None:
            acc = AccumDict()
        if len(arg0) == 0:  # nothing to do
            return acc
        chunks = list(split_in_blocks(
            arg0, concurrent_tasks or 1, weight, key))
        cls.apply_reduce.__func__._chunks = chunks  # for debugging
        if not concurrent_tasks or no_distribute() or len(chunks) == 1:
            # apply the function in the master process
            for i, chunk in enumerate(chunks):
                if args and hasattr(args[-1], 'flush'):  # is monitor
                    args[-1].task_no = i
                acc = agg(acc, task_func(chunk, *args))
            return acc
        logging.info('Starting %d tasks', len(chunks))
        self = cls.starmap(task, [(chunk,) + args for chunk in chunks], name)
        return self.reduce(agg, acc)

    def __init__(self, oqtask, name=None):
        self.oqtask = oqtask
        self.task_func = getattr(oqtask, 'task_func', oqtask)
        self.name = name or oqtask.__name__
        self.results = []
        self.sent = AccumDict()
        self.received = []
        self.no_distribute = no_distribute()
        # inspect.getargspec is deprecated in Python 3 and removed in 3.11;
        # getfullargspec also supports annotations and keyword-only args
        getargspec = (getattr(inspect, 'getfullargspec', None) or
                      inspect.getargspec)
        self.argnames = getargspec(self.task_func).args

    def submit(self, *args):
        """
        Submit a function with the given arguments to the process pool
        and add a Future to the list `.results`. If the variable
        OQ_DISTRIBUTE is "no", the function is run in process and the
        triple (result, None, last argument) is stored instead.

        :returns: a dictionary argument name -> size of the pickled
            argument (empty when running in process)
        """
        check_mem_usage()  # log a warning if too much memory is used
        if self.no_distribute:
            sent = {}
            res = (self.task_func(*args), None, args[-1])
        else:
            piks = pickle_sequence(args)
            sent = {arg: len(p) for arg, p in zip(self.argnames, piks)}
            res = self._submit(piks)
        self.sent += sent
        self.results.append(res)
        return sent

    def _submit(self, piks):
        # submit tasks by using the ProcessPoolExecutor or celery
        if self.oqtask is self.task_func:
            # undecorated function: wrap it with safely_call
            return self.executor.submit(
                safely_call, self.task_func, piks, True)
        elif OQ_DISTRIBUTE == 'futures':  # call the decorated task
            return self.executor.submit(self.oqtask, *piks)
        elif OQ_DISTRIBUTE == 'celery':
            res = self.oqtask.delay(*piks)
            self.task_ids.append(res.task_id)
            return res
        # previously None was silently stored, failing later in obscure ways
        raise ValueError('Unknown OQ_DISTRIBUTE=%r' % OQ_DISTRIBUTE)

    def aggregate_result_set(self, agg, acc):
        """
        Loop on a set results and update the accumulator
        by using the aggregation function.

        :param agg: the aggregation function, (acc, val) -> new acc
        :param acc: the initial value of the accumulator
        :returns: the final value of the accumulator
        :raises ValueError: for an unknown distribution mechanism
        """
        if not self.results:
            return acc

        distribute = oq_distribute()  # not called for distribute == 'no'

        if distribute == 'celery':
            rset = ResultSet(self.results)
            for task_id, result_dict in rset.iter_native():
                self.task_ids.remove(task_id)
                check_mem_usage()  # warn if too much memory is used
                result = result_dict['result']
                if isinstance(result, BaseException):
                    raise result
                self.received.append(len(result))
                acc = agg(acc, result.unpickle())
                # work around a celery bug
                del app.backend._cache[task_id]
            return acc

        elif distribute == 'futures':
            for future in as_completed(self.results):
                check_mem_usage()  # log a warning if too much memory is used
                result = future.result()
                if isinstance(result, BaseException):
                    raise result
                self.received.append(len(result))
                acc = agg(acc, result.unpickle())
            return acc

        raise ValueError('Unknown OQ_DISTRIBUTE=%r' % distribute)

    def reduce(self, agg=operator.add, acc=None):
        """
        Loop on a set of results and update the accumulator
        by using the aggregation function.

        :param agg: the aggregation function, (acc, val) -> new acc
        :param acc: the initial value of the accumulator
        :returns: the final value of the accumulator
        :raises RuntimeError: if a task failed; the message contains the
            traceback of the task
        """
        if acc is None:
            acc = AccumDict()
        num_tasks = len(self.results)
        if num_tasks == 0:
            logging.warning('No tasks were submitted')
            return acc

        log_percent = log_percent_gen(self.name, num_tasks, self.progress)
        next(log_percent)

        def agg_and_percent(acc, triple):
            (val, exc, mon) = triple
            if exc:
                raise RuntimeError(val)
            res = agg(acc, val)
            next(log_percent)
            mon.flush()
            return res

        if self.no_distribute:
            agg_result = reduce(agg_and_percent, self.results, acc)
        else:
            self.progress('Sent %s of data in %d task(s)',
                          humansize(sum(self.sent.values())), num_tasks)
            agg_result = self.aggregate_result_set(agg_and_percent, acc)
            self.progress('Received %s of data, maximum per task %s',
                          humansize(sum(self.received)),
                          humansize(max(self.received)))
        self.results = []
        return agg_result

    def wait(self):
        """
        Wait until all the tasks terminate. Discard the results.

        :returns: the total number of tasks that were spawned
        """
        # the accumulator counts the results, ignoring their values
        return self.reduce(lambda acc, res: acc + 1, 0)

    def __iter__(self):
        """
        An iterator over the results
        """
        return iter(self.results)
# convenient module-level aliases of the TaskManager classmethods
starmap, apply_reduce = TaskManager.starmap, TaskManager.apply_reduce
def do_not_aggregate(acc, value):
    """
    Aggregation function that ignores the incoming value; pass it to
    :class:`openquake.commonlib.parallel.apply_reduce` when no aggregation
    is needed.

    :param acc: the accumulator
    :param value: the value to accumulate (ignored)
    :returns: the same accumulator object, untouched
    """
    del value  # intentionally unused
    return acc
class NoFlush(object):
    """
    Callable placeholder that the litetask decorator installs as
    `monitor.flush` inside the tasks: calling it raises a RuntimeError,
    since the monitors must not be flushed in the workers.

    :param monitor: the monitor whose flush is forbidden
    :param taskname: the name of the running task
    """
    def __init__(self, monitor, taskname):
        self.monitor = monitor
        self.taskname = taskname

    def __call__(self):
        msg = 'Monitor(%r).flush() must not be called by %s!' % (
            self.monitor.operation, self.taskname)
        raise RuntimeError(msg)
def rec_delattr(mon, name):
    """
    Recursively remove the instance attribute `name` from a monitor and
    from all of its descendants (children first, then the monitor itself).
    Monitors whose instance dictionary lacks the attribute are left as
    they are.
    """
    for submonitor in mon.children:
        rec_delattr(submonitor, name)
    own_attrs = vars(mon)
    if name not in own_attrs:
        return
    delattr(mon, name)
def litetask_futures(func):
    """
    Add monitoring support to the decorated function. The last argument
    must be a monitor object.

    The returned function has the same signature of `func`; it calls
    :func:`safely_call` with pickle=True on an internal wrapper, so that
    its return value is a Pickled triple (result, exc_type, monitor).
    The original function is attached as `.task_func`.
    """
    def wrapper(*args):
        monitor = args[-1]
        check_mem_usage(monitor)  # check if too much memory is used
        monitor.flush = NoFlush(monitor, func.__name__)
        try:
            with monitor('total ' + func.__name__, measuremem=True), \
                    GroundShakingIntensityModel.forbid_instantiation():
                return func(*args)
        finally:
            # NB: flush must not be called in the workers - they must not
            # have access to the datastore - so we remove it, also when
            # func raises (previously the NoFlush placeholder was left on
            # the monitor sent back to the master)
            rec_delattr(monitor, 'flush')
    # NB: the returned function must have the same signature of func;
    # we need pickle=True because celery is using the worst possible
    # protocol; once we remove celery we can try to remove pickle=True
    return FunctionMaker.create(
        func, 'return _s_(_w_, (%(shortsignature)s,), pickle=True)',
        dict(_s_=safely_call, _w_=wrapper), task_func=func)
if OQ_DISTRIBUTE != 'celery':
    litetask = litetask_futures
else:
    def litetask_celery(task_func):
        """
        Wrap `task_func` with litetask_futures and register the result
        as a celery task on the "celery" queue; the original function is
        attached as `.task_func`.
        """
        celery_task = task(litetask_futures(task_func), queue='celery')
        # NOTE(review): __func__ points back to the task itself; presumably
        # some caller relies on it - confirm before removing
        celery_task.__func__ = celery_task
        celery_task.task_func = task_func
        return celery_task
    litetask = litetask_celery