Source code for openquake.baselib.general

# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4

# Copyright (C) 2014-2016 GEM Foundation
#
# OpenQuake is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# OpenQuake is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with OpenQuake.  If not, see <http://www.gnu.org/licenses/>.


"""
Utility functions of general interest.
"""
from __future__ import division, print_function
import os
import sys
import imp
import math
import operator
import warnings
import tempfile
import importlib
import itertools
import subprocess
import collections
try:  # Python 3.3+: the aliases in `collections` were removed in 3.10
    from collections.abc import Mapping, MutableSequence
except ImportError:  # Python 2
    from collections import Mapping, MutableSequence

import numpy
from decorator import decorator

F64 = numpy.float64


class WeightedSequence(MutableSequence):
    """
    A wrapper over a sequence of weighted items with a total weight attribute.
    Adding items automatically increases the weight.

    NB: the weights of the single items are not stored, therefore replacing
    or deleting an item leaves the total weight untouched.
    """
    @classmethod
    def merge(cls, ws_list):
        """
        Merge a set of WeightedSequence objects.

        :param ws_list:
            a sequence of :class:
            `openquake.baselib.general.WeightedSequence` instances
        :returns:
            a `openquake.baselib.general.WeightedSequence` instance
        """
        return sum(ws_list, cls())

    def __init__(self, seq=()):
        """
        :param seq: a finite sequence of pairs (item, weight)
        """
        self._seq = []
        self.weight = 0
        # extend -> append -> insert, which updates the weight
        self.extend(seq)

    def __getitem__(self, sliceobj):
        """
        Return an item or a slice
        """
        return self._seq[sliceobj]

    def __setitem__(self, i, v):
        """
        Modify the sequence (the weight is not changed)
        """
        self._seq[i] = v

    def __delitem__(self, sliceobj):
        """
        Remove an item from the sequence (the weight is not changed)
        """
        del self._seq[sliceobj]

    def __len__(self):
        """
        The length of the sequence
        """
        return len(self._seq)

    def __add__(self, other):
        """
        Add two weighted sequences and return a new WeightedSequence
        with weight equal to the sum of the weights.
        """
        new = self.__class__()
        new._seq.extend(self._seq)
        new._seq.extend(other._seq)
        new.weight = self.weight + other.weight
        return new

    def insert(self, i, item_weight):
        """
        Insert an item with the given weight in the sequence

        :param i: the position where to insert the item
        :param item_weight: a pair (item, weight)
        """
        item, weight = item_weight
        self._seq.insert(i, item)
        self.weight += weight

    def __lt__(self, other):
        """
        Ensure ordering by weight
        """
        return self.weight < other.weight

    def __eq__(self, other):
        """
        Compare for equality the items contained in self; sequences of
        different lengths are never equal.
        """
        try:
            # zip alone would stop at the shortest sequence, so that an
            # empty sequence would compare equal to everything
            if len(self) != len(other):
                return False
        except TypeError:  # other has no length
            return NotImplemented
        return all(x == y for x, y in zip(self, other))

    def __repr__(self):
        """
        String representation of the sequence, including the weight
        """
        return '<%s %s, weight=%s>' % (self.__class__.__name__,
                                       self._seq, self.weight)
def distinct(keys):
    """
    Remove the duplicates from an iterable of hashable keys.

    :param keys: an iterable of hashable objects
    :returns: a list with the keys in order of first appearance
    """
    seen = set()
    unique = []
    for key in keys:
        if key in seen:
            continue
        unique.append(key)
        seen.add(key)
    return unique
def ceil(a, b):
    """
    Divide a / b and round the quotient up to the next integer.

    :param a: a number
    :param b: a positive number
    :returns: the smallest integer greater or equal than a / b
    """
    assert b > 0, b
    quotient = float(a) / b
    return int(math.ceil(quotient))
def block_splitter(items, max_weight, weight=lambda item: 1,
                   kind=lambda item: 'Unspecified'):
    """
    :param items: an iterator over items
    :param max_weight: the max weight to split on
    :param weight: a function returning the weight of a given item
    :param kind: a function returning the kind of a given item

    Group together items of the same kind until the total weight exceeds the
    `max_weight` and yield `WeightedSequence` instances. Items
    with weight zero are ignored.

    For instance

     >>> items = 'ABCDE'
     >>> list(block_splitter(items, 3))
     [<WeightedSequence ['A', 'B', 'C'], weight=3>, <WeightedSequence ['D', 'E'], weight=2>]

    The default weight is 1 for all items.
    """
    if max_weight <= 0:
        raise ValueError('max_weight=%s' % max_weight)
    block = WeightedSequence([])
    last_kind = 'Unspecified'
    for item in items:
        item_weight = weight(item)
        item_kind = kind(item)
        if item_weight < 0:  # error
            raise ValueError('The item %r got a negative weight %s!' %
                             (item, item_weight))
        if item_weight != 0:  # items with 0 weight are not stored
            too_heavy = block.weight + item_weight > max_weight
            if too_heavy or item_kind != last_kind:
                # close the current block (if any) and start a new one
                fresh = WeightedSequence([(item, item_weight)])
                if block:
                    yield block
                block = fresh
            else:
                block.append((item, item_weight))
        # the kind is remembered even for the items with 0 weight
        last_kind = item_kind
    if block:
        yield block
def split_in_blocks(sequence, hint, weight=lambda item: 1,
                    key=lambda item: 'Unspecified'):
    """
    Split the `sequence` in a number of WeightedSequences close to `hint`.

    :param sequence: a finite sequence of items
    :param hint: an integer suggesting the number of subsequences to generate
    :param weight: a function returning the weight of a given item
    :param key: a function returning the key of a given item

    The WeightedSequences are of homogeneous key and they try to be
    balanced in weight. If `hint` is zero the sequence is returned
    unchanged. For instance

     >>> items = 'ABCDE'
     >>> list(split_in_blocks(items, 3))
     [<WeightedSequence ['A', 'B'], weight=2>, <WeightedSequence ['C', 'D'], weight=2>, <WeightedSequence ['E'], weight=1>]
    """
    if hint == 0:  # do not split
        return sequence
    items = list(sequence)
    assert hint > 0, hint
    assert len(items) > 0, len(items)
    total_weight = float(sum(map(weight, items)))
    # each block should carry at most total_weight / hint, rounded up
    max_weight = math.ceil(total_weight / hint)
    return block_splitter(items, max_weight, weight, key)
def assert_close_seq(seq1, seq2, rtol, atol, context=None):
    """
    Compare element by element two sequences of the same length.

    :param seq1: a sequence
    :param seq2: another sequence
    :param rtol: relative tolerance
    :param atol: absolute tolerance
    :param context: passed unchanged to `assert_close` for each pair
    :raises AssertionError: if the lengths or the elements differ
    """
    len1, len2 = len(seq1), len(seq2)
    assert len1 == len2, 'Lists of different lenghts: %d != %d' % (
        len1, len2)
    for pair in zip(seq1, seq2):
        assert_close(pair[0], pair[1], rtol, atol, context)
def assert_close(a, b, rtol=1e-07, atol=0, context=None):
    """
    Compare for equality up to a given precision two composite objects
    which may contain floats. NB: if the objects are or contain generators,
    they are exhausted.

    :param a: an object
    :param b: another object
    :param rtol: relative tolerance
    :param atol: absolute tolerance
    :param context: object mentioned in the error message, if any
    :raises AssertionError: if the two objects are not close
    """
    if isinstance(a, float) or isinstance(a, numpy.ndarray) and a.shape:
        # shortcut
        numpy.testing.assert_allclose(a, b, rtol, atol)
        return
    if a == b:  # another shortcut
        return
    if hasattr(a, '_slots_'):  # record-like objects
        assert_close_seq(a._slots_, b._slots_, rtol, atol, a)
        for x, y in zip(a._slots_, b._slots_):
            assert_close(getattr(a, x), getattr(b, y), rtol, atol, x)
        return
    if isinstance(a, Mapping):  # dict-like objects
        assert_close_seq(a.keys(), b.keys(), rtol, atol, a)
        assert_close_seq(a.values(), b.values(), rtol, atol, a)
        return
    # strings are excluded: in Python 3 they are iterables of strings,
    # so two different strings would cause an infinite recursion
    if hasattr(a, '__iter__') and not isinstance(a, (str, bytes)):
        assert_close_seq(list(a), list(b), rtol, atol, a)
        return
    if hasattr(a, '__dict__'):  # objects with an attribute dictionary
        # the tolerances must be propagated to the attributes too
        assert_close(vars(a), vars(b), rtol, atol, context=a)
        return
    ctx = '' if context is None else 'in context ' + repr(context)
    raise AssertionError('%r != %r %s' % (a, b, ctx))
def writetmp(content=None, dir=None, prefix="tmp", suffix="tmp"):
    """Create temporary file with the given content.

    Please note: the temporary file must be deleted by the caller.

    :param string content: the content to write to the temporary file.
    :param string dir: directory where the file should be created
    :param string prefix: file name prefix
    :param string suffix: file name suffix
    :returns: a string with the path to the temporary file
    """
    if content and hasattr(content, 'encode'):
        # text is stored as UTF-8; done first, so that a failure in the
        # encoding does not leave a file and an open descriptor around
        content = content.encode('utf8')
    if dir is not None and not os.path.exists(dir):
        os.makedirs(dir)
    fd, path = tempfile.mkstemp(dir=dir, prefix=prefix, suffix=suffix)
    if not content:
        # mkstemp returns an open descriptor: close it, nothing to write
        os.close(fd)
        return path
    with os.fdopen(fd, "wb") as fh:  # closes the descriptor on exit
        fh.write(content)
    return path
def git_suffix(fname):
    """
    :param fname: path to a file; git is run in the directory containing it
    :returns: `-git<short git hash>` if a Git repository is found, else ''
    """
    try:
        with open(os.devnull, 'w') as devnull:  # discard the git errors
            gh = subprocess.check_output(
                ['git', 'rev-parse', '--short', 'HEAD'],
                stderr=devnull, cwd=os.path.dirname(fname)).strip()
        if not isinstance(gh, str):
            # Python 3: check_output returns bytes, which cannot be
            # concatenated to a string
            gh = gh.decode('utf8')
        return "-git" + gh if gh else ''
    except Exception:
        # trapping everything on purpose; git may not be installed or it
        # may not work properly
        return ''
def run_in_process(code, *args):
    """
    Execute some Python code in a separate interpreter and convert what
    it prints on stdout into a Python object. When arguments are passed
    the code is treated as a template and interpolated with the `%`
    operator.

    :param code: string or template describing Python code
    :param args: arguments to be used for interpolation
    :returns: the output of the process, as a Python object; None if the
              process printed nothing
    """
    source = code % args if args else code
    command = [sys.executable, '-c', source]
    try:
        output = subprocess.check_output(command)
    except subprocess.CalledProcessError as exc:
        # show the offending code before propagating the error
        print(exc.cmd[-1], file=sys.stderr)
        raise
    if not output:
        return None
    # NOTE(review): the output is evaluated as a Python expression, so the
    # code passed to this function must come from a trusted source
    return eval(output, {}, {})
class CodeDependencyError(Exception):
    """
    Raised by `assert_independent` when a package depends on a package
    it should be independent from
    """
def import_all(module_or_package):
    """
    Import a module or, if `module_or_package` is a package, the package
    and recursively all the modules found in its directory. Modules that
    cannot be imported are reported on stderr and skipped.

    :param module_or_package: dotted name of a module or package
    :returns: the set of the names added to sys.modules by the call; it
              can be empty if the modules were already imported
    """
    before = set(sys.modules)
    root = importlib.import_module(module_or_package)
    if hasattr(root, '__path__'):  # is a package, look at its content
        [pkg_dir] = root.__path__
        prefix_len = len(pkg_dir)
        for cwd, dirs, files in os.walk(pkg_dir):
            if '__init__.py' not in files:
                # the current working directory is not a subpackage
                continue
            # convert PKGPATH/subpackage -> .subpackage
            # works at any level of nesting
            subpkg = cwd[prefix_len:].replace(os.sep, '.')
            for fname in files:
                if not fname.endswith('.py'):
                    continue
                modname = '%s%s.%s' % (module_or_package, subpkg,
                                       os.path.basename(fname[:-3]))
                try:
                    importlib.import_module(modname)
                except Exception as exc:
                    print('Could not import %s: %s: %s' % (
                        modname, exc.__class__.__name__, exc),
                        file=sys.stderr)
    return set(sys.modules) - before
def assert_independent(package, *packages):
    """
    Make sure the `package` does not depend from the `packages`.
    The package is imported in a separate process and the names of
    the modules imported there are checked against the given prefixes.

    :param package: Python name of a module/package
    :param packages: Python names of modules/packages
    :raises CodeDependencyError: if an imported module starts with one
                                 of the names in `packages`
    """
    assert packages, 'At least one package must be specified'
    snippet = ('from openquake.baselib.general import import_all\n'
               'print(import_all("%s"))' % package)
    imported_modules = run_in_process(snippet)
    for mod, pkg in itertools.product(imported_modules, packages):
        if mod.startswith(pkg):
            raise CodeDependencyError('%s depends on %s' % (package, pkg))
def search_module(module, syspath=sys.path):
    """
    Given a module name (possibly with dots) returns the corresponding
    filepath, or None, if the module cannot be found.

    :param module: (dotted) name of the Python module to look for
    :param syspath: a list of directories to search (default sys.path)
    :returns: the path of the module (the directory, for a package) or None
    """
    lst = module.split(".")
    pkg, submodule = lst[0], ".".join(lst[1:])
    try:
        fileobj, filepath, descr = imp.find_module(pkg, syspath)
    except ImportError:
        return
    if fileobj is not None:
        # imp.find_module returns an open file for plain modules and it
        # is up to the caller to close it; only the path is needed here
        fileobj.close()
    if submodule:  # recursive search
        return search_module(submodule, [filepath])
    return filepath
class CallableDict(collections.OrderedDict):
    r"""
    An ordered dictionary of functions which is also callable: calling it
    dispatches to the function registered for the key computed from the
    first argument. It can be used as a smart registry or as a poor man
    generic function and it is typically used to implement converters.
    Here is an example:

    >>> format_attrs = CallableDict()  # dict of functions (fmt, obj) -> str

    >>> @format_attrs.add('csv')  # implementation for csv
    ... def format_attrs_csv(fmt, obj):
    ...     items = sorted(vars(obj).items())
    ...     return '\n'.join('%s,%s' % item for item in items)

    >>> @format_attrs.add('json')  # implementation for json
    ... def format_attrs_json(fmt, obj):
    ...     return json.dumps(vars(obj))

    `format_attrs(fmt, obj)` calls the correct underlying function
    depending on the `fmt` key. If the format is unknown a `KeyError` is
    raised. It is also possible to set a `keymissing` function to specify
    what to return if the key is missing.

    For a more practical example see the implementation of the exporters
    in openquake.commonlib.export
    """
    def __init__(self, keyfunc=lambda key: key, keymissing=None):
        """
        :param keyfunc: function extracting the key from the first argument
        :param keymissing: callable used when the key is not registered
        """
        super(CallableDict, self).__init__()
        self.keyfunc = keyfunc
        self.keymissing = keymissing

    def add(self, *keys):
        """
        Return a decorator registering a new implementation for the
        CallableDict for the given keys.
        """
        def register(func):
            # the same function is stored under each key, in order
            for key in keys:
                self[key] = func
            return func
        return register

    def __call__(self, obj, *args, **kw):
        """
        Call the function registered for the key extracted from `obj`
        """
        implementation = self[self.keyfunc(obj)]
        return implementation(obj, *args, **kw)

    def __missing__(self, key):
        """
        Return the `keymissing` function, if callable, else raise a KeyError
        """
        if not callable(self.keymissing):
            raise KeyError(key)
        return self.keymissing
class AccumDict(dict):
    """
    An accumulating dictionary, useful to accumulate variables.
    The arithmetic operators accept both dictionaries (the operation
    is performed key by key) and scalars (the operation is performed
    on all the values).

    >> acc = AccumDict()
    >> acc += {'a': 1}
    >> acc += {'a': 1, 'b': 1}
    >> acc
    {'a': 2, 'b': 1}
    >> {'a': 1} + acc
    {'a': 3, 'b': 1}
    >> acc + 1
    {'a': 3, 'b': 2}
    >> 1 - acc
    {'a': -1, 'b': 0}
    >> acc - 1
    {'a': 1, 'b': 0}

    Keys missing in the left operand are treated as zeros by the
    addition and subtraction:

    >> AccumDict(a=1) - {'b': 2}
    {'a': 1, 'b': -2}

    Also the multiplication has been defined; the missing keys are
    treated as ones:

    >> prob1 = AccumDict(a=0.4, b=0.5)
    >> prob2 = AccumDict(b=0.5)
    >> prob1 * prob2
    {'a': 0.4, 'b': 0.25}
    >> prob1 * 1.2
    {'a': 0.48, 'b': 0.6}
    >> 1.2 * prob1
    {'a': 0.48, 'b': 0.6}
    """

    def __iadd__(self, other):
        if hasattr(other, 'items'):
            for k, v in other.items():
                try:
                    self[k] = self[k] + v
                except KeyError:
                    self[k] = v
        else:  # add other to all elements
            for k in self:
                self[k] = self[k] + other
        return self

    def __add__(self, other):
        new = self.__class__(self)
        new += other
        return new

    __radd__ = __add__

    def __isub__(self, other):
        if hasattr(other, 'items'):
            for k, v in other.items():
                try:
                    self[k] = self[k] - v
                except KeyError:
                    # a missing key counts as zero, hence 0 - v
                    self[k] = -v
        else:  # subtract other to all elements
            for k in self:
                self[k] = self[k] - other
        return self

    def __sub__(self, other):
        new = self.__class__(self)
        new -= other
        return new

    def __rsub__(self, other):
        # other - self == -(self - other)
        return - self.__sub__(other)

    def __neg__(self):
        return self.__class__({k: -v for k, v in self.items()})

    def __imul__(self, other):
        if hasattr(other, 'items'):
            for k, v in other.items():
                try:
                    self[k] = self[k] * v
                except KeyError:
                    self[k] = v
        else:  # multiply all elements by other
            for k in self:
                self[k] = self[k] * other
        return self

    def __mul__(self, other):
        new = self.__class__(self)
        new *= other
        return new

    __rmul__ = __mul__

    def __truediv__(self, other):
        return self * (1. / other)

    def apply(self, func, *extras):
        """
        Return a new AccumDict with the same keys and values
        `func(value, *extras)`.

        >> a = AccumDict({'a': 1,  'b': 2})
        >> a.apply(lambda x, y: 2 * x + y, 1)
        {'a': 3, 'b': 5}
        """
        return self.__class__({key: func(value, *extras)
                               for key, value in self.items()})
# return a dict imt -> slice and the total number of levels def _slicedict_n(imt_dt): n = 0 slicedic = {} for imt in imt_dt.names: shp = imt_dt[imt].shape n1 = n + shp[0] if shp else 1 slicedic[imt] = slice(n, n1) n = n1 return slicedic, n
class DictArray(Mapping):
    """
    A small wrapper over a dictionary of arrays: the values are stored
    one after the other in a single flat array of floats.

    >>> DictArray({'PGA': [0.01, 0.02, 0.04], 'PGV': [0.1, 0.2]})
    <DictArray
    PGA: [ 0.01  0.02  0.04]
    PGV: [ 0.1  0.2]>

    The DictArray maintains the lexicographic order of the keys.
    """
    def __init__(self, imtls):
        """
        :param imtls: a dictionary key -> sequence of floats (or scalar)
        """
        # one field per key, sorted by key, with the length of the values
        self.imt_dt = dt = numpy.dtype(
            [(imt, F64, len(imls) if hasattr(imls, '__len__') else 1)
             for imt, imls in sorted(imtls.items())])
        self.slicedic, num_levels = _slicedict_n(dt)
        self.array = numpy.zeros(num_levels, F64)
        for imt, imls in imtls.items():
            self[imt] = imls

    def __getitem__(self, imt):
        return self.array[self.slicedic[imt]]

    def __setitem__(self, imt, array):
        self.array[self.slicedic[imt]] = array

    def __iter__(self):
        for imt in self.imt_dt.names:
            yield imt

    def __len__(self):
        return len(self.imt_dt.names)

    def __repr__(self):
        data = ['%s: %s' % (imt, self[imt]) for imt in self]
        return '<%s\n%s>' % (self.__class__.__name__, '\n'.join(data))
def groupby(objects, key, reducegroup=list):
    """
    Sort the objects by key, group them and reduce each group.

    :param objects: a sequence of objects with a key value
    :param key: the key function to extract the key value
    :param reducegroup: the function to apply to each group
    :returns: an OrderedDict {key value: reducegroup(group)}

    >>> groupby(['A1', 'A2', 'B1', 'B2', 'B3'], lambda x: x[0],
    ...         lambda group: ''.join(x[1] for x in group))
    OrderedDict([('A', '12'), ('B', '123')])
    """
    grouped = collections.OrderedDict()
    # itertools.groupby requires the objects to be sorted by the same key
    for kvalue, members in itertools.groupby(sorted(objects, key=key), key):
        grouped[kvalue] = reducegroup(members)
    return grouped
def groupby2(records, kfield, vfield):
    """
    Group the values of a field by the values of another field.

    :param records: a sequence of records with positional or named fields
    :param kfield: the index/name/tuple specifying the field to use as a key
    :param vfield: the index/name/tuple specifying the field to use as a value
    :returns: an list of pairs of the form (key, [value, ...]).

    >>> groupby2(['A1', 'A2', 'B1', 'B2', 'B3'], 0, 1)
    [('A', ['1', '2']), ('B', ['1', '2', '3'])]

    Here is an example where the keyfield is a tuple of integers:

    >>> groupby2(['A11', 'A12', 'B11', 'B21'], (0, 1), 2)
    [(('A', '1'), ['1', '2']), (('B', '1'), ['1']), (('B', '2'), ['1'])]
    """
    def getter(field):
        # a tuple specifies several fields to be extracted together
        if isinstance(field, tuple):
            return operator.itemgetter(*field)
        return operator.itemgetter(field)

    kgetter = getter(kfield)
    vgetter = getter(vfield)

    def values(rows):
        return [vgetter(row) for row in rows]

    grouped = groupby(records, kgetter, values)
    return list(grouped.items())  # Python3 compatible
def _reducerecords(group): records = list(group) return numpy.array(records, records[0].dtype)
def group_array(array, *kfields):
    """
    Group the records of an array by the values of the given fields.

    :param array: an array of records
    :param kfields: names of the fields to use as key
    :returns: an OrderedDict kfields -> array
    """
    keyfunc = operator.itemgetter(*kfields)
    return groupby(array, keyfunc, _reducerecords)
def get_array(array, **kw):
    """
    Filter an array by keeping only the records whose fields are equal
    to the given values.

    :param array: an array with named fields
    :param kw: pairs field name -> required value
    :returns: the filtered subarray (the array itself if `kw` is empty)
    """
    selected = array
    for field in kw:
        selected = selected[selected[field] == kw[field]]
    return selected
def humansize(nbytes, suffixes=('B', 'KB', 'MB', 'GB', 'TB', 'PB')):
    """
    Convert a size in bytes into a string with at most two decimal
    digits and the biggest appropriate unit of measure.

    :param nbytes: the size in bytes
    :param suffixes: the units of measure, each 1024 times the previous one
    :returns: a string like '1.5 KB'
    """
    if nbytes == 0:
        return '0 B'
    size = nbytes
    for unit in suffixes[:-1]:
        if not size >= 1024:
            break
        size /= 1024.
    else:  # the size is expressed in the biggest unit available
        unit = suffixes[-1]
    number = '%.2f' % size
    # strip the trailing zeros and then the decimal point, if it is the last
    return '%s %s' % (number.rstrip('0').rstrip('.'), unit)
# the builtin DeprecationWarning has been silenced in Python 2.7, so a
# replacement deriving from UserWarning is defined here
class DeprecationWarning(UserWarning):
    """
    Warning issued the first time a deprecated function is called
    """
def deprecated(message):
    """
    Build a decorator marking a function as deprecated.

    :param message:
        the message to print the first time the
        deprecated function is used.

    Here is an example of usage:

    >>> @deprecated('Use new_function instead')
    ... def old_function():
    ...     'Do something'

    Notice that if the function is called several time, the deprecation
    warning will be displayed only the first time.
    """
    def _deprecated(func, *args, **kw):
        msg = '%s.%s has been deprecated. %s' % (
            func.__module__, func.__name__, message)
        ncalls = getattr(func, 'called', None)
        if ncalls is None:  # first call, the counter is not there yet
            warnings.warn(msg, DeprecationWarning, stacklevel=2)
            ncalls = 0
        # the number of calls is stored on the function itself
        func.called = ncalls + 1
        return func(*args, **kw)
    return decorator(_deprecated)