def duplicated(items):
    """
    :returns: the list of duplicated keys, possibly empty
    """
    counter = collections.Counter(items)
    return [key for key, counts in counter.items() if counts > 1]


def cached_property(method):
    """
    :param method: a method without arguments except self
    :returns: a cached property
    """
    name = method.__name__

    def newmethod(self):
        try:
            val = self.__dict__[name]
        except KeyError:
            t0 = time.time()
            val = method(self)
            cached_property.dt[name] = time.time() - t0
            self.__dict__[name] = val
        return val
    newmethod.__name__ = method.__name__
    newmethod.__doc__ = method.__doc__
    return property(newmethod)


cached_property.dt = {}  # dictionary of times

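# A minimal usage sketch of cached_property (the Point class below is
# hypothetical, not part of this module): the first access computes the
# value and stores it in the instance __dict__, later accesses reuse it,
# and the time spent is recorded in cached_property.dt.
#
#     class Point:
#         def __init__(self, x, y):
#             self.x, self.y = x, y
#         @cached_property
#         def norm(self):
#             return (self.x ** 2 + self.y ** 2) ** .5
#
#     p = Point(3, 4)
#     p.norm  # computed once -> 5.0
#     p.norm  # read back from p.__dict__['norm']
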
def nokey(item):
    """
    Dummy function to apply to items without a key
    """
    return 'Unspecified'


class WeightedSequence(MutableSequence):
    """
    A wrapper over a sequence of weighted items with a total weight attribute.
    Adding items automatically increases the weight.
    """
    @classmethod
    def merge(cls, ws_list):
        """
        Merge a set of WeightedSequence objects.

        :param ws_list:
            a sequence of :class:`openquake.baselib.general.WeightedSequence`
            instances
        :returns:
            a :class:`openquake.baselib.general.WeightedSequence` instance
        """
        return sum(ws_list, cls())

    def __init__(self, seq=()):
        """
        :param seq: a finite sequence of pairs (item, weight)
        """
        self._seq = []
        self.weight = 0
        self.extend(seq)

    def __getitem__(self, sliceobj):
        """
        Return an item or a slice
        """
        return self._seq[sliceobj]

    def __setitem__(self, i, v):
        """
        Modify the sequence
        """
        self._seq[i] = v

    def __delitem__(self, sliceobj):
        """
        Remove an item from the sequence
        """
        del self._seq[sliceobj]

    def __len__(self):
        """
        The length of the sequence
        """
        return len(self._seq)

    def __add__(self, other):
        """
        Add two weighted sequences and return a new WeightedSequence
        with weight equal to the sum of the weights.
        """
        new = self.__class__()
        new._seq.extend(self._seq)
        new._seq.extend(other._seq)
        new.weight = self.weight + other.weight
        return new

    def insert(self, i, item_weight):
        """
        Insert an item with the given weight in the sequence
        """
        item, weight = item_weight
        self._seq.insert(i, item)
        self.weight += weight

    def __lt__(self, other):
        """
        Ensure ordering by weight
        """
        return self.weight < other.weight

    def __eq__(self, other):
        """
        Compare for equality the items contained in self
        """
        return all(x == y for x, y in zip(self, other))

    def __repr__(self):
        """
        String representation of the sequence, including the weight
        """
        return '<%s %s, weight=%s>' % (self.__class__.__name__,
                                       self._seq, self.weight)


def distinct(keys):
    """
    Return the distinct keys in order.
    """
    known = set()
    outlist = []
    for key in keys:
        if key not in known:
            outlist.append(key)
            known.add(key)
    return outlist

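# Quick sketch of distinct: it keeps the first occurrence of each key while
# preserving the input order, unlike a plain set:
#
#     >>> distinct('BANANA')
#     ['B', 'A', 'N']
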
def ceil(x):
    """
    Converts the result of math.ceil into an integer
    """
    return int(math.ceil(x))


def block_splitter(items, max_weight, weight=lambda item: 1, key=nokey,
                   sort=False):
    """
    :param items: an iterator over items
    :param max_weight: the max weight to split on
    :param weight: a function returning the weight of a given item
    :param key: a function returning the kind of a given item
    :param sort: if True, sort the items by reverse weight before splitting

    Group together items of the same kind until the total weight exceeds the
    `max_weight` and yield `WeightedSequence` instances. Items
    with weight zero are ignored.

    For instance

    >>> items = 'ABCDE'
    >>> list(block_splitter(items, 3))
    [<WeightedSequence ['A', 'B', 'C'], weight=3>, <WeightedSequence ['D', 'E'], weight=2>]

    The default weight is 1 for all items. Here is an example leveraging
    the key to group together results:

    >>> items = ['A1', 'C2', 'D2', 'E2']
    >>> list(block_splitter(items, 2, key=operator.itemgetter(1)))
    [<WeightedSequence ['A1'], weight=1>, <WeightedSequence ['C2', 'D2'], weight=2>, <WeightedSequence ['E2'], weight=1>]
    """
    if max_weight <= 0:
        raise ValueError('max_weight=%s' % max_weight)
    ws = WeightedSequence([])
    prev_key = 'Unspecified'
    for item in sorted(items, key=weight, reverse=True) if sort else items:
        w = weight(item)
        k = key(item)
        if w < 0:  # error
            raise ValueError(
                'The item %r got a negative weight %s!' % (item, w))
        elif ws.weight + w > max_weight or k != prev_key:
            new_ws = WeightedSequence([(item, w)])
            if ws:
                yield ws
            ws = new_ws
        elif w > 0:  # ignore items with 0 weight
            ws.append((item, w))
        prev_key = k
    if ws:
        yield ws


def split_in_slices(number, num_slices):
    """
    :param number: a positive number to split in slices
    :param num_slices: the number of slices to return (at most)
    :returns: a list of slices

    >>> split_in_slices(4, 2)
    [slice(0, 2, None), slice(2, 4, None)]
    >>> split_in_slices(5, 1)
    [slice(0, 5, None)]
    >>> split_in_slices(5, 2)
    [slice(0, 3, None), slice(3, 5, None)]
    >>> split_in_slices(2, 4)
    [slice(0, 1, None), slice(1, 2, None)]
    """
    assert number > 0, number
    assert num_slices > 0, num_slices
    blocksize = int(math.ceil(number / num_slices))
    slices = []
    start = 0
    while True:
        stop = min(start + blocksize, number)
        slices.append(slice(start, stop))
        if stop == number:
            break
        start += blocksize
    return slices


def gen_slices(start, stop, blocksize):
    """
    Yields slices of length at most blocksize.

    >>> list(gen_slices(1, 6, 2))
    [slice(1, 3, None), slice(3, 5, None), slice(5, 6, None)]
    """
    blocksize = int(blocksize)
    assert start <= stop, (start, stop)
    assert blocksize > 0, blocksize
    while True:
        yield slice(start, min(start + blocksize, stop))
        start += blocksize
        if start >= stop:
            break


def split_in_blocks(sequence, hint, weight=lambda item: 1, key=nokey):
    """
    Split the `sequence` in a number of WeightedSequences close to `hint`.

    :param sequence: a finite sequence of items
    :param hint: an integer suggesting the number of subsequences to generate
    :param weight: a function returning the weight of a given item
    :param key: a function returning the key of a given item

    The WeightedSequences are of homogeneous key and they try to be
    balanced in weight. For instance

    >>> items = 'ABCDE'
    >>> list(split_in_blocks(items, 3))
    [<WeightedSequence ['A'], weight=1>, <WeightedSequence ['B'], weight=1>, <WeightedSequence ['C'], weight=1>, <WeightedSequence ['D'], weight=1>, <WeightedSequence ['E'], weight=1>]
    """
    if isinstance(sequence, pandas.DataFrame):
        num_elements = len(sequence)
        out = numpy.array_split(
            sequence, num_elements if num_elements < hint else hint)
        return out
    elif isinstance(sequence, int):
        return split_in_slices(sequence, hint)
    elif hint in (0, 1) and key is nokey:  # do not split
        return [sequence]
    elif hint in (0, 1):  # split by key
        blocks = []
        for k, group in groupby(sequence, key).items():
            blocks.append(group)
        return blocks
    items = sorted(sequence, key=lambda item: (key(item), weight(item)))
    assert hint > 0, hint
    assert len(items) > 0, len(items)
    total_weight = float(sum(weight(item) for item in items))
    return block_splitter(items, total_weight / hint, weight, key)


def assert_close(a, b, rtol=1e-07, atol=0, context=None):
    """
    Compare for equality up to a given precision two composite objects
    which may contain floats. NB: if the objects are or contain generators,
    they are exhausted.

    :param a: an object
    :param b: another object
    :param rtol: relative tolerance
    :param atol: absolute tolerance
    """
    if isinstance(a, float) or isinstance(a, numpy.ndarray) and a.shape:
        # shortcut
        numpy.testing.assert_allclose(a, b, rtol, atol)
        return
    if isinstance(a, (str, bytes, int)):  # another shortcut
        assert a == b, (a, b)
        return
    if hasattr(a, 'keys'):  # dict-like objects
        assert a.keys() == b.keys(), set(a).symmetric_difference(set(b))
        for x in a:
            if x != '__geom__':
                assert_close(a[x], b[x], rtol, atol, x)
        return
    if hasattr(a, '__dict__'):  # objects with an attribute dictionary
        assert_close(vars(a), vars(b), rtol, atol, context=a)
        return
    if hasattr(a, '__iter__'):  # iterable objects
        xs, ys = list(a), list(b)
        assert len(xs) == len(ys), ('Lists of different lengths: %d != %d'
                                    % (len(xs), len(ys)))
        for x, y in zip(xs, ys):
            assert_close(x, y, rtol, atol, x)
        return
    if a == b:  # last attempt to avoid raising the exception
        return
    ctx = '' if context is None else 'in context ' + repr(context)
    raise AssertionError('%r != %r %s' % (a, b, ctx))

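# Behaviour sketch of assert_close on a couple of hand-made inputs: lists
# are compared element by element within the tolerances, dicts key by key.
#
#     assert_close([1, 2.0000001], [1, 2.0000002])  # passes, within rtol
#     assert_close({'a': 1}, {'a': 2})              # raises AssertionError
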
_tmp_paths = []


def gettemp(content=None, dir=None, prefix="tmp", suffix="tmp", remove=True):
    """
    Create temporary file with the given content.

    Please note: the temporary file can be deleted by the caller or not.

    :param string content: the content to write to the temporary file
    :param string dir: directory where the file should be created
    :param string prefix: file name prefix
    :param string suffix: file name suffix
    :param bool remove: True by default, meaning the file will be
                        automatically removed at the exit of the program
    :returns: a string with the path to the temporary file
    """
    if dir is not None:
        if not os.path.exists(dir):
            os.makedirs(dir)
    fh, path = tempfile.mkstemp(
        dir=dir or config.directory.custom_tmp or None,
        prefix=prefix, suffix=suffix)
    if remove:
        _tmp_paths.append(path)
    with os.fdopen(fh, "wb") as fh:
        if content:
            if hasattr(content, 'encode'):
                content = content.encode('utf8')
            fh.write(content)
    return path

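# Usage sketch for gettemp, assuming the default temporary directory: the
# file is created and written immediately, and removed at exit by removetmp
# unless remove=False is passed.
#
#     path = gettemp('some content', suffix='.txt')
#     with open(path) as f:
#         assert f.read() == 'some content'
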
@atexit.register
def removetmp():
    """
    Remove the temporary files created by gettemp
    """
    for path in _tmp_paths:
        if os.path.exists(path):  # not removed yet
            try:
                os.remove(path)
            except PermissionError:
                pass


def check_extension(fnames):
    """
    Make sure all file names have the same extension
    """
    if not fnames:
        return
    _, extension = os.path.splitext(fnames[0])
    for fname in fnames[1:]:
        _, ext = os.path.splitext(fname)
        if ext != extension:
            raise NameError(f'{fname} does not end with {extension}')

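# For instance (hypothetical file names):
#
#     check_extension(['a.csv', 'b.csv'])  # fine, same extension
#     check_extension(['a.csv', 'b.xml'])  # raises NameError
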
def engine_version():
    """
    :returns: __version__ + `<short git hash>` if Git repository found
    """
    # we assume that the .git folder is two levels above any package
    # i.e. openquake/engine/../../.git
    git_path = os.path.join(os.path.dirname(__file__), '..', '..', '.git')

    # macOS complains if we try to execute git and it's not available.
    # Code will run, but a pop-up offering to install bloatware (Xcode)
    # is raised. This is annoying in end-users installations, so we check
    # if .git exists before trying to execute the git executable
    gh = ''
    if os.path.isdir(git_path):
        try:
            with open(os.devnull, 'w') as devnull:
                gh = subprocess.check_output(
                    ['git', 'rev-parse', '--short', 'HEAD'],
                    stderr=devnull, cwd=os.path.dirname(git_path)).strip()
            gh = "-git" + decode(gh) if gh else ''
        except Exception:
            # trapping everything on purpose; git may not be installed or it
            # may not work properly
            pass
    return __version__ + gh


def extract_dependencies(lines):
    """
    Yield pairs (importable package name, expected version) from the lines
    of a requirements file, skipping the packages that cannot be checked.
    """
    for line in lines:
        longname = line.split('/')[-1]  # i.e. urllib3-2.1.0-py3-none-any.whl
        try:
            pkg, version, _other = longname.split('-', 2)
        except ValueError:  # for instance a comment
            continue
        if pkg in ('fonttools', 'protobuf', 'pyreadline3', 'python_dateutil',
                   'python_pam', 'django_cors_headers',
                   'django_cookie_consent'):
            # not importable
            continue
        if pkg in ('alpha_shapes', 'django_pam', 'pbr', 'iniconfig',
                   'importlib_metadata', 'zipp'):
            # missing __version__
            continue
        elif pkg == 'pyzmq':
            pkg = 'zmq'
        elif pkg == 'Pillow':
            pkg = 'PIL'
        elif pkg == 'GDAL':
            pkg = 'osgeo.gdal'
        elif pkg == 'Django':
            pkg = 'django'
        elif pkg == 'pyshp':
            pkg = 'shapefile'
        elif pkg == 'django_appconf':
            pkg = 'appconf'
        yield pkg, version


def check_dependencies():
    """
    Print a warning if we forgot to update the dependencies.
    Works only for development installations.
    """
    import openquake
    if 'site-packages' in openquake.__path__[0]:
        return  # do nothing for non-devel installations
    pyver = '%d%d' % (sys.version_info[0], sys.version_info[1])
    system = sys.platform
    if system == 'linux':
        system = 'linux64'
    elif system == 'win32':
        system = 'win64'
    elif system == 'darwin':
        system = 'macos_arm64'
    else:  # unsupported OS, do not check dependencies
        return
    reqfile = 'requirements-py%s-%s.txt' % (pyver, system)
    repodir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
    with open(os.path.join(repodir, reqfile)) as f:
        lines = f.readlines()
    for pkg, expected in extract_dependencies(lines):
        try:
            installed_version = version(pkg)
        except PackageNotFoundError:
            # handling cases such as "No package metadata was found for zmq"
            # (in other cases, e.g. timezonefinder, __version__ is not defined)
            installed_version = __import__(pkg).__version__
        if installed_version != expected:
            logging.warning(
                '%s is at version %s but the requirements say %s' %
                (pkg, installed_version, expected))


def run_in_process(code, *args):
    """
    Run in an external process the given Python code and return the
    output as a Python object. If there are arguments, then code is
    taken as a template and traditional string interpolation is performed.

    :param code: string or template describing Python code
    :param args: arguments to be used for interpolation
    :returns: the output of the process, as a Python object
    """
    if args:
        code %= args
    try:
        out = subprocess.check_output([sys.executable, '-c', code])
    except subprocess.CalledProcessError as exc:
        print(exc.cmd[-1], file=sys.stderr)
        raise
    if out:
        out = out.rstrip(b'\x1b[?1034h')
        # this is absurd, but it happens: just importing a module can
        # produce escape sequences in stdout, see for instance
        # https://bugs.python.org/issue19884
        return eval(out, {}, {})


def import_all(module_or_package):
    """
    If `module_or_package` is a module, just import it; if it is a package,
    recursively imports all the modules it contains. Returns the names of
    the modules that were imported as a set. The set can be empty if the
    modules were already in sys.modules.
    """
    already_imported = set(sys.modules)
    mod_or_pkg = importlib.import_module(module_or_package)
    if not hasattr(mod_or_pkg, '__path__'):  # is a simple module
        return set(sys.modules) - already_imported
    # else import all modules contained in the package
    [pkg_path] = mod_or_pkg.__path__
    n = len(pkg_path)
    for cwd, dirs, files in os.walk(pkg_path):
        if all(os.path.basename(f) != '__init__.py' for f in files):
            # the current working directory is not a subpackage
            continue
        for f in files:
            if f.endswith('.py') and not f.startswith('__init__'):
                # convert PKGPATH/subpackage/module.py -> subpackage.module
                # works at any level of nesting
                modname = (module_or_package + cwd[n:].replace(os.sep, '.') +
                           '.' + os.path.basename(f[:-3]))
                importlib.import_module(modname)
    return set(sys.modules) - already_imported


def assert_independent(package, *packages):
    """
    :param package: Python name of a module/package
    :param packages: Python names of modules/packages

    Make sure the `package` does not depend on the `packages`.
    """
    assert packages, 'At least one package must be specified'
    import_package = 'from openquake.baselib.general import import_all\n' \
                     'print(import_all("%s"))' % package
    imported_modules = run_in_process(import_package)
    for mod in imported_modules:
        for pkg in packages:
            if mod.startswith(pkg):
                raise CodeDependencyError('%s depends on %s' % (package, pkg))


class CallableDict(dict):
    r"""
    A callable object built on top of a dictionary of functions, used
    as a smart registry or as a poor man's generic function dispatching
    on the first argument. It is typically used to implement converters.
    Here is an example:

    >>> format_attrs = CallableDict()  # dict of functions (fmt, obj) -> str

    >>> @format_attrs.add('csv')  # implementation for csv
    ... def format_attrs_csv(fmt, obj):
    ...     items = sorted(vars(obj).items())
    ...     return '\n'.join('%s,%s' % item for item in items)

    >>> @format_attrs.add('json')  # implementation for json
    ... def format_attrs_json(fmt, obj):
    ...     return json.dumps(vars(obj))

    `format_attrs(fmt, obj)` calls the correct underlying function
    depending on the `fmt` key. If the format is unknown a `KeyError` is
    raised. It is also possible to set a `keymissing` function to specify
    what to return if the key is missing.

    For a more practical example see the implementation of the exporters
    in openquake.calculators.export
    """
    def __init__(self, keyfunc=lambda key: key, keymissing=None):
        super().__init__()
        self.keyfunc = keyfunc
        self.keymissing = keymissing

    def add(self, *keys):
        """
        Return a decorator registering a new implementation for the
        CallableDict for the given keys.
        """
        def decorator(func):
            for key in keys:
                self[key] = func
            return func
        return decorator


class pack(dict):
    """
    Compact a dictionary of lists into a dictionary of arrays.
    If attrs are given, consider those keys as attributes. For instance,

    >>> p = pack(dict(x=[1], a=[0]), ['a'])
    >>> p
    {'x': array([1])}
    >>> p.a
    array([0])
    """
    def __init__(self, dic, attrs=()):
        for k, v in dic.items():
            arr = numpy.array(v)
            if k in attrs:
                setattr(self, k, arr)
            else:
                self[k] = arr


class AccumDict(dict):
    """
    An accumulating dictionary, useful to accumulate variables::

    >>> acc = AccumDict()
    >>> acc += {'a': 1}
    >>> acc += {'a': 1, 'b': 1}
    >>> acc
    {'a': 2, 'b': 1}
    >>> {'a': 1} + acc
    {'a': 3, 'b': 1}
    >>> acc + 1
    {'a': 3, 'b': 2}
    >>> 1 - acc
    {'a': -1, 'b': 0}
    >>> acc - 1
    {'a': 1, 'b': 0}

    The multiplication has been defined:

    >>> prob1 = AccumDict(dict(a=0.4, b=0.5))
    >>> prob2 = AccumDict(dict(b=0.5))
    >>> prob1 * prob2
    {'a': 0.4, 'b': 0.25}
    >>> prob1 * 1.2
    {'a': 0.48, 'b': 0.6}
    >>> 1.2 * prob1
    {'a': 0.48, 'b': 0.6}

    And even the power:

    >>> prob2 ** 2
    {'b': 0.25}

    It is very common to use an AccumDict of accumulators; here is an
    example using the empty list as accumulator:

    >>> acc = AccumDict(accum=[])
    >>> acc['a'] += [1]
    >>> acc['b'] += [2]
    >>> sorted(acc.items())
    [('a', [1]), ('b', [2])]

    The implementation is smart enough to make (deep) copies of the
    accumulator, therefore each key has a different accumulator, which
    initially is the empty list (in this case).
    """
    def __init__(self, dic=None, accum=None, keys=()):
        for key in keys:
            self[key] = copy.deepcopy(accum)
        if dic:
            self.update(dic)
        self.accum = accum

    def __iadd__(self, other):
        if hasattr(other, 'items'):
            for k, v in other.items():
                if k not in self:
                    self[k] = v
                elif isinstance(v, list):  # specialized for speed
                    self[k].extend(v)
                else:
                    self[k] += v
        else:  # add other to all elements
            for k in self:
                self[k] += other
        return self

    def __add__(self, other):
        new = self.__class__(self)
        new += other
        return new

    __radd__ = __add__

    def __isub__(self, other):
        if hasattr(other, 'items'):
            for k, v in other.items():
                try:
                    self[k] -= self[k]
                except KeyError:
                    self[k] = v
        else:  # subtract other from all elements
            for k in self:
                self[k] -= other
        return self

    def __sub__(self, other):
        new = self.__class__(self)
        new -= other
        return new

    def __rsub__(self, other):
        return - self.__sub__(other)

    def __neg__(self):
        return self.__class__({k: -v for k, v in self.items()})

    def __invert__(self):
        return self.__class__({k: ~v for k, v in self.items()})

    def __imul__(self, other):
        if hasattr(other, 'items'):
            for k, v in other.items():
                try:
                    self[k] = self[k] * v
                except KeyError:
                    self[k] = v
        else:  # multiply all elements by other
            for k in self:
                self[k] = self[k] * other
        return self

    def __mul__(self, other):
        new = self.__class__(self)
        new *= other
        return new

    __rmul__ = __mul__

    def __pow__(self, n):
        new = self.__class__(self)
        for key in new:
            new[key] **= n
        return new

    def __truediv__(self, other):
        return self * (1. / other)

    def __missing__(self, key):
        if self.accum is None:
            # no accumulator, accessing a missing key is an error
            raise KeyError(key)
        val = self[key] = copy.deepcopy(self.accum)
        return val

    def apply(self, func, *extras):
        """
        >> a = AccumDict({'a': 1, 'b': 2})
        >> a.apply(lambda x, y: 2 * x + y, 1)
        {'a': 3, 'b': 5}
        """
        return self.__class__({key: func(value, *extras)
                               for key, value in self.items()})


def copyobj(obj, **kwargs):
    """
    :returns: a shallow copy of obj with some changed attributes
    """
    new = copy.copy(obj)
    for k, v in kwargs.items():
        setattr(new, k, v)
    return new

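# copyobj in action on a hypothetical object with attributes a and b; the
# copy is shallow and the original is left untouched:
#
#     >>> class Obj: pass
#     >>> o = Obj(); o.a, o.b = 1, 2
#     >>> new = copyobj(o, b=3)
#     >>> new.a, new.b, o.b
#     (1, 3, 2)
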
class DictArray(Mapping):
    """
    A small wrapper over a dictionary of arrays with the same lengths.
    """
    def __init__(self, imtls):
        levels = imtls[next(iter(imtls))]
        self.M = len(imtls)
        self.L1 = len(levels)
        self.size = self.M * self.L1
        items = imtls.items()
        self.dt = numpy.dtype([(str(imt), F64, (self.L1,))
                               for imt, imls in items])
        self.array = numpy.zeros((self.M, self.L1), F64)
        self.slicedic = {}
        n = 0
        self.mdic = {}
        for m, (imt, imls) in enumerate(items):
            if len(imls) != self.L1:
                raise ValueError('imt=%s has %d levels, expected %d' %
                                 (imt, len(imls), self.L1))
            self.slicedic[imt] = slice(n, n + self.L1)
            self.mdic[imt] = m
            self.array[m] = imls
            n += self.L1

    def __call__(self, imt):
        return self.slicedic[imt]

    def __getitem__(self, imt):
        return self.array[self.mdic[imt]]

    def __setitem__(self, imt, array):
        self.array[self.mdic[imt]] = array

    def __iter__(self):
        for imt in self.dt.names:
            yield imt

    def __len__(self):
        return len(self.dt.names)

    def __eq__(self, other):
        arr = self.array == other.array
        if isinstance(arr, bool):
            return arr
        return arr.all()

    def __ne__(self, other):
        return not self.__eq__(other)

    def __repr__(self):
        data = ['%s: %s' % (imt, self[imt]) for imt in self]
        return '<%s\n%s>' % (self.__class__.__name__, '\n'.join(data))


def groupby(objects, key, reducegroup=list):
    """
    :param objects: a sequence of objects with a key value
    :param key: the key function to extract the key value
    :param reducegroup: the function to apply to each group
    :returns: a dict {key value: map(reducegroup, group)}

    >>> groupby(['A1', 'A2', 'B1', 'B2', 'B3'], lambda x: x[0],
    ...         lambda group: ''.join(x[1] for x in group))
    {'A': '12', 'B': '123'}
    """
    kgroups = itertools.groupby(sorted(objects, key=key), key)
    return {k: reducegroup(group) for k, group in kgroups}


def groupby2(records, kfield, vfield):
    """
    :param records: a sequence of records with positional or named fields
    :param kfield: the index/name/tuple specifying the field to use as a key
    :param vfield: the index/name/tuple specifying the field to use as a value
    :returns: a list of pairs of the form (key, [value, ...]).

    >>> groupby2(['A1', 'A2', 'B1', 'B2', 'B3'], 0, 1)
    [('A', ['1', '2']), ('B', ['1', '2', '3'])]

    Here is an example where the kfield is a tuple of integers:

    >>> groupby2(['A11', 'A12', 'B11', 'B21'], (0, 1), 2)
    [(('A', '1'), ['1', '2']), (('B', '1'), ['1']), (('B', '2'), ['1'])]
    """
    if isinstance(kfield, tuple):
        kgetter = operator.itemgetter(*kfield)
    else:
        kgetter = operator.itemgetter(kfield)
    if isinstance(vfield, tuple):
        vgetter = operator.itemgetter(*vfield)
    else:
        vgetter = operator.itemgetter(vfield)
    dic = groupby(records, kgetter, lambda rows: [vgetter(r) for r in rows])
    return list(dic.items())  # Python3 compatible


def get_bins(values, nbins, key=None, minval=None, maxval=None):
    """
    :param values: an array of N floats (or arrays)
    :returns: an array of N bin indices plus an array of B bins
    """
    assert len(values)
    if key is not None:
        values = numpy.array([key(val) for val in values])
    if minval is None:
        minval = values.min()
    if maxval is None:
        maxval = values.max()
    if minval == maxval:
        bins = [minval] * nbins
    else:
        bins = numpy.arange(minval, maxval, (maxval - minval) / nbins)
    return numpy.searchsorted(bins, values, side='right'), bins


def groupby_grid(xs, ys, deltax, deltay):
    """
    :param xs: an array of P abscissas
    :param ys: an array of P ordinates
    :param deltax: grid spacing on the x-axis
    :param deltay: grid spacing on the y-axis
    :returns:
        dictionary centroid -> indices (of the points around each centroid)
    """
    lx, ly = len(xs), len(ys)
    assert lx == ly, (lx, ly)
    assert lx > 1, lx
    assert deltax > 0, deltax
    assert deltay > 0, deltay
    xmin = xs.min()
    xmax = xs.max()
    ymin = ys.min()
    ymax = ys.max()
    nx = numpy.ceil((xmax - xmin) / deltax)
    ny = numpy.ceil((ymax - ymin) / deltay)
    assert nx > 0, nx
    assert ny > 0, ny
    xbins = get_bins(xs, nx, None, xmin, xmax)[0]
    ybins = get_bins(ys, ny, None, ymin, ymax)[0]
    acc = AccumDict(accum=[])
    for k, ij in enumerate(zip(xbins, ybins)):
        acc[ij].append(k)
    dic = {}
    for ks in acc.values():
        ks = numpy.array(ks)
        dic[xs[ks].mean(), ys[ks].mean()] = ks
    return dic


def groupby_bin(values, nbins, key=None, minval=None, maxval=None):
    """
    >>> values = numpy.arange(10)
    >>> for group in groupby_bin(values, 3):
    ...     print(group)
    [0, 1, 2]
    [3, 4, 5]
    [6, 7, 8, 9]
    """
    if len(values) == 0:  # do nothing
        return values
    idxs = get_bins(values, nbins, key, minval, maxval)[0]
    acc = AccumDict(accum=[])
    for idx, val in zip(idxs, values):
        if isinstance(idx, numpy.ndarray):
            idx = tuple(idx)  # make it hashable
        acc[idx].append(val)
    return acc.values()


def group_array(array, *kfields):
    """
    Convert an array into a dict kfields -> array
    """
    return groupby(array, operator.itemgetter(*kfields), _reducerecords)


def multi_index(shape, axis=None):
    """
    :param shape: a shape of length L
    :param axis: None or an integer in the range 0 .. L - 1
    :yields:
        tuples of indices with a slice(None) at the axis position (if any)

    >>> for slc in multi_index((2, 3), 0): print(slc)
    (slice(None, None, None), 0, 0)
    (slice(None, None, None), 0, 1)
    (slice(None, None, None), 0, 2)
    (slice(None, None, None), 1, 0)
    (slice(None, None, None), 1, 1)
    (slice(None, None, None), 1, 2)
    """
    ranges = (range(s) for s in shape)
    if axis is None:
        yield from itertools.product(*ranges)
        return  # without an axis there is no slice to insert
    for tup in itertools.product(*ranges):
        lst = list(tup)
        lst.insert(axis, slice(None))
        yield tuple(lst)


# NB: the fast_agg functions are usually faster than pandas
def fast_agg(indices, values=None, axis=0, factor=None, M=None):
    """
    :param indices: N indices in the range 0 ... M - 1 with M < N
    :param values: N values (can be arrays)
    :param factor: if given, a multiplicative factor (or weight) for the values
    :param M: maximum index; if None, use max(indices) + 1
    :returns: M aggregated values (can be arrays)

    >>> values = numpy.array([[.1, .11], [.2, .22], [.3, .33], [.4, .44]])
    >>> fast_agg([0, 1, 1, 0], values)
    array([[0.5 , 0.55],
           [0.5 , 0.55]])
    """
    if values is None:
        values = numpy.ones_like(indices)
    N = len(values)
    if len(indices) != N:
        raise ValueError('There are %d values but %d indices' %
                         (N, len(indices)))
    shp = values.shape[1:]
    if M is None:
        M = max(indices) + 1
    if not shp:
        return numpy.bincount(
            indices, values if factor is None else values * factor, M)
    lst = list(shp)
    lst.insert(axis, M)
    res = numpy.zeros(lst, values.dtype)
    for mi in multi_index(shp, axis):
        vals = values[mi] if factor is None else values[mi] * factor
        res[mi] = numpy.bincount(indices, vals, M)
    return res


# NB: the fast_agg functions are usually faster than pandas
def fast_agg2(tags, values=None, axis=0):
    """
    :param tags: N non-unique tags out of M
    :param values: N values (can be arrays)
    :returns: (M unique tags, M aggregated values)

    >>> values = numpy.array([[.1, .11], [.2, .22], [.3, .33], [.4, .44]])
    >>> fast_agg2(['A', 'B', 'B', 'A'], values)
    (array(['A', 'B'], dtype='<U1'), array([[0.5 , 0.55],
           [0.5 , 0.55]]))

    It can also be used to count the number of tags:

    >>> fast_agg2(['A', 'B', 'B', 'A', 'A'])
    (array(['A', 'B'], dtype='<U1'), array([3., 2.]))
    """
    uniq, indices = numpy.unique(tags, return_inverse=True)
    return uniq, fast_agg(indices, values, axis)


# NB: the fast_agg functions are usually faster than pandas
def fast_agg3(structured_array, kfield, vfields=None, factor=None):
    """
    Aggregate a structured array with a key field (the kfield) and some
    value fields (the vfields). If vfields is not passed, use all fields
    except the kfield.

    >>> data = numpy.array([(1, 2.4), (1, 1.6), (2, 2.5)],
    ...                    [('aid', U16), ('val', F32)])
    >>> fast_agg3(data, 'aid')
    array([(1, 4. ), (2, 2.5)], dtype=[('aid', '<u2'), ('val', '<f4')])
    """
    allnames = structured_array.dtype.names
    if vfields is None:
        vfields = [name for name in allnames if name != kfield]
    assert kfield in allnames, kfield
    for vfield in vfields:
        assert vfield in allnames, vfield
    tags = structured_array[kfield]
    uniq, indices = numpy.unique(tags, return_inverse=True)
    dic = {}
    dtlist = [(kfield, structured_array.dtype[kfield])]
    for name in vfields:
        dic[name] = fast_agg(indices, structured_array[name], factor=factor)
        dtlist.append((name, structured_array.dtype[name]))
    res = numpy.zeros(len(uniq), dtlist)
    res[kfield] = uniq
    for name in dic:
        res[name] = dic[name]
    return res


def countby(array, *kfields):
    """
    :returns: a dict kfields -> number of records with that key
    """
    return groupby(array, operator.itemgetter(*kfields), count)


def get_array(array, **kw):
    """
    Extract a subarray by filtering on the given keyword arguments
    """
    for name, value in kw.items():
        array = array[array[name] == value]
    return array

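# A small sketch of get_array on a hypothetical structured array: the
# keyword arguments become equality filters on the homonymous fields.
#
#     arr = numpy.array([(1, 0.1), (2, 0.2), (1, 0.3)],
#                       [('id', numpy.int64), ('value', numpy.float64)])
#     get_array(arr, id=1)  # -> the two records with id == 1
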
def not_equal(array_or_none1, array_or_none2):
    """
    Compare two arrays that can also be None or have different shapes
    and return a boolean.

    >>> a1 = numpy.array([1])
    >>> a2 = numpy.array([2])
    >>> a3 = numpy.array([2, 3])
    >>> not_equal(a1, a2)
    True
    >>> not_equal(a1, a3)
    True
    >>> not_equal(a1, None)
    True
    """
    if array_or_none1 is None and array_or_none2 is None:
        return False
    elif array_or_none1 is None and array_or_none2 is not None:
        return True
    elif array_or_none1 is not None and array_or_none2 is None:
        return True
    if array_or_none1.shape != array_or_none2.shape:
        return True
    return (array_or_none1 != array_or_none2).any()


def all_equals(inputs):
    """
    :param inputs: a list of arrays or strings
    :returns: True if all values are equal, False otherwise
    """
    inp0 = inputs[0]
    for inp in inputs[1:]:
        try:
            diff = inp != inp0
        except ValueError:  # Lengths must match to compare
            return False
        if isinstance(diff, numpy.ndarray):
            if diff.any():
                return False
        elif diff:
            return False
    return True


def humansize(nbytes, suffixes=('B', 'KB', 'MB', 'GB', 'TB', 'PB')):
    """
    Return file size in a human-friendly format
    """
    if nbytes == 0:
        return '0 B'
    i = 0
    while nbytes >= 1024 and i < len(suffixes) - 1:
        nbytes /= 1024.
        i += 1
    f = ('%.2f' % nbytes).rstrip('0').rstrip('.')
    return '%s %s' % (f, suffixes[i])

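# For instance, assuming the "number suffix" output format above:
#
#     humansize(4096)        # -> '4 KB'
#     humansize(1234567890)  # -> '1.15 GB'
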
# the builtin DeprecationWarning has been silenced in Python 2.7
class DeprecationWarning(UserWarning):
    """
    Raised the first time a deprecated function is called
    """


@decorator
def deprecated(func, msg='', *args, **kw):
    """
    A family of decorators to mark deprecated functions.

    :param msg:
        the message to print the first time the deprecated function is used

    Here is an example of usage:

    >>> @deprecated(msg='Use new_function instead')
    ... def old_function():
    ...     'Do something'

    Notice that if the function is called several times, the deprecation
    warning will be displayed only the first time.
    """
    msg = '%s.%s has been deprecated. %s' % (
        func.__module__, func.__name__, msg)
    if not hasattr(func, 'called'):
        warnings.warn(msg, DeprecationWarning, stacklevel=2)
        func.called = 0
    func.called += 1
    return func(*args, **kw)


def random_filter(objects, reduction_factor, seed=42):
    """
    Given a list of objects, returns a sublist by extracting randomly
    some elements. The reduction factor (< 1) tells how small the extracted
    list is compared to the original list.
    """
    assert 0 < reduction_factor <= 1, reduction_factor
    if reduction_factor == 1:  # do not reduce
        return objects
    rnd = random.Random(seed)
    if isinstance(objects, pandas.DataFrame):
        df = pandas.DataFrame({
            col: random_filter(objects[col], reduction_factor, seed)
            for col in objects.columns})
        return df
    out = []
    for obj in objects:
        if rnd.random() <= reduction_factor:
            out.append(obj)
    return out

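# For instance (hypothetical list): with reduction_factor=0.1 about one
# element in ten survives; the result is deterministic for a given seed.
#
#     objs = list(range(1000))
#     small = random_filter(objs, .1)  # roughly 100 elements
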
def random_choice(array, num_samples, offset=0, seed=42):
    """
    Extract num_samples from an array. It has the fundamental property
    of splittability, i.e. if the seed is the same and `||` means array
    concatenation:

    choice(a, N) = choice(a, n, 0) || choice(a, N-n, n)

    This property makes `random_choice` suitable to be parallelized,
    while `random.choice` is not. It is also absurdly fast.
    """
    rng = numpy.random.default_rng(seed)
    rng.bit_generator.advance(offset)
    N = len(array)
    cumsum = numpy.repeat(1. / N, N).cumsum()
    choices = numpy.searchsorted(cumsum, rng.random(num_samples))
    return array[choices]

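# The splittability property stated above, spelled out on a hypothetical
# array: extracting 10 samples at once should give the same result as
# extracting 4 and then 6 more starting from offset 4.
#
#     a = numpy.arange(100)
#     full = random_choice(a, 10)
#     part = numpy.concatenate([random_choice(a, 4), random_choice(a, 6, 4)])
#     # numpy.array_equal(full, part) is expected to be True
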
def random_histogram(counts, nbins_or_binweights, seed):
    """
    Distribute a total number of counts over a set of bins. If the weights
    of the bins are equal you can just pass the number of the bins and a
    faster algorithm will be used. Otherwise pass the weights.
    Here are a few examples:

    >>> list(random_histogram(1, 2, seed=42))
    [0, 1]
    >>> list(random_histogram(100, 5, seed=42))
    [22, 17, 21, 26, 14]
    >>> list(random_histogram(10000, 5, seed=42))
    [2034, 2000, 2014, 1998, 1954]
    >>> list(random_histogram(1000, [.3, .3, .4], seed=42))
    [308, 295, 397]
    """
    rng = numpy.random.default_rng(seed)
    try:
        nbins = len(nbins_or_binweights)
    except TypeError:  # 'int' has no len()
        nbins = nbins_or_binweights
        weights = numpy.repeat(1. / nbins, nbins)
    else:
        weights = numpy.array(nbins_or_binweights)
        weights /= weights.sum()  # normalize to 1
    bins = numpy.searchsorted(weights.cumsum(), rng.random(counts))
    return numpy.bincount(bins, minlength=len(weights))


def safeprint(*args, **kwargs):
    """
    Convert and print characters using the proper encoding
    """
    new_args = []
    # when stdout is redirected to a file, python 2 uses ascii for the writer;
    # python 3 uses what is configured in the system (i.e. 'utf-8')
    # if sys.stdout is replaced by a StringIO instance, Python 2 does not
    # have an attribute 'encoding', and we assume ascii in that case
    str_encoding = getattr(sys.stdout, 'encoding', None) or 'ascii'
    for s in args:
        new_args.append(s.encode('utf-8').decode(str_encoding, 'ignore'))
    return print(*new_args, **kwargs)


def socket_ready(hostport):
    """
    :param hostport: a pair (host, port) or a string (tcp://)host:port
    :returns: True if the socket is ready and False otherwise
    """
    if hasattr(hostport, 'startswith'):
        # string representation of the hostport combination
        if hostport.startswith('tcp://'):
            hostport = hostport[6:]  # strip tcp://
        host, port = hostport.split(':')
        hostport = (host, int(port))
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        exc = sock.connect_ex(hostport)
    finally:
        sock.close()
    return False if exc else True


port_candidates = list(range(1920, 2000))


def _get_free_port():
    # extracts a free port in the range 1920:2000 and raises a RuntimeError
    # if there are no free ports. NB: the port is free when extracted, but
    # another process may take it immediately, so this function is not safe
    # against race conditions. Moreover, once a port is taken, it is taken
    # forever and never considered free again, even if it is. These
    # restrictions are acceptable for usage in the tests, but only in that
    # case.
    while port_candidates:
        port = random.choice(port_candidates)
        port_candidates.remove(port)
        if not socket_ready(('127.0.0.1', port)):  # no server listening
            return port  # the port is free
    raise RuntimeError('No free ports in the range 1920:2000')


def zipfiles(fnames, archive, mode='w', log=lambda msg: None, cleanup=False):
    """
    Build a zip archive from the given file names.

    :param fnames: list of path names
    :param archive: path of the archive or BytesIO object
    """
    prefix = len(os.path.commonprefix([os.path.dirname(f) for f in fnames]))
    with zipfile.ZipFile(archive, mode, zipfile.ZIP_DEFLATED,
                         allowZip64=True) as z:
        for f in fnames:
            log('Archiving %s' % f)
            z.write(f, f[prefix:])
            if cleanup:  # remove the zipped file
                os.remove(f)
    return archive


def detach_process():
    """
    Detach the current process from the controlling terminal by using a
    double fork. Can be used only on platforms with fork (no Windows).
    """
    # see https://pagure.io/python-daemon/blob/master/f/daemon/daemon.py and
    # https://stackoverflow.com/questions/45911705/why-use-os-setsid-in-python
    def fork_then_exit_parent():
        pid = os.fork()
        if pid:  # in parent
            os._exit(0)
    fork_then_exit_parent()
    os.setsid()
    fork_then_exit_parent()


def println(msg):
    """
    Convenience function to print messages on a single line in the terminal
    """
    sys.stdout.write(msg)
    sys.stdout.flush()
    sys.stdout.write('\x08' * len(msg))
    sys.stdout.flush()


def debug(line):
    """
    Append a debug line to the file /tmp/debug.txt
    """
    tmp = tempfile.gettempdir()
    with open(os.path.join(tmp, 'debug.txt'), 'a', encoding='utf8') as f:
        f.write(line + '\n')


builtins.debug = debug


def warn(msg, *args):
    """
    Print a warning on stderr
    """
    if not args:
        sys.stderr.write('WARNING: ' + msg)
    else:
        sys.stderr.write('WARNING: ' + msg % args)


def getsizeof(o, ids=None):
    """
    Find the memory footprint of a Python object recursively, see
    https://code.tutsplus.com/tutorials/understand-how-much-memory-your-python-objects-use--cms-25609

    :param o: the object
    :returns: the size in bytes
    """
    ids = ids or set()
    if id(o) in ids:
        return 0
    if hasattr(o, 'nbytes'):
        return o.nbytes
    elif hasattr(o, 'array'):
        return o.array.nbytes
    nbytes = sys.getsizeof(o)
    ids.add(id(o))
    if isinstance(o, Mapping):
        return nbytes + sum(getsizeof(k, ids) + getsizeof(v, ids)
                            for k, v in o.items())
    elif isinstance(o, Container):
        return nbytes + sum(getsizeof(x, ids) for x in o)
    return nbytes

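# Rough usage sketch: a dict containing a numpy array is counted as the
# dict overhead plus the nbytes of the array (8000 bytes here).
#
#     getsizeof({'a': numpy.zeros(1000)})
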
def get_duplicates(array, *fields):
    """
    :returns: a dictionary {key: num_dupl} for duplicate records
    """
    uniq = numpy.unique(array[list(fields)])
    if len(uniq) == len(array):  # no duplicates
        return {}
    return {k: len(g) for k, g in group_array(array, *fields).items()
            if len(g) > 1}


def categorize(values, nchars=2):
    """
    Take an array with duplicate values and categorize it, i.e. replace
    the values with codes of length nchars in BASE183. With nchars=2
    33856 unique values can be encoded; if there are more, nchars must be
    increased, otherwise a ValueError will be raised.

    :param values: an array of V non-unique values
    :param nchars: number of characters in BASE183 for each code
    :returns: an array of V non-unique codes

    >>> categorize([1, 2, 2, 3, 4, 1, 1, 2])  # 8 values, 4 unique ones
    array([b'AA', b'AB', b'AB', b'AC', b'AD', b'AA', b'AA', b'AB'],
          dtype='|S2')
    """
    uvalues = numpy.unique(values)
    mvalues = 184 ** nchars  # maximum number of unique values
    if len(uvalues) > mvalues:
        raise ValueError(
            f'There are too many unique values ({len(uvalues)} > {mvalues})')
    prod = itertools.product(*[BASE183] * nchars)
    dic = {uvalue: ''.join(chars) for uvalue, chars in zip(uvalues, prod)}
    return numpy.array([dic[v] for v in values], (numpy.bytes_, nchars))


def get_nbytes_msg(sizedict, size=8):
    """
    :param sizedict: mapping name -> num_dimensions
    :returns: (size of the array in bytes, descriptive message)

    >>> get_nbytes_msg(dict(nsites=2, nbins=5))
    (80, '(nsites=2) * (nbins=5) * 8 bytes = 80 B')
    """
    nbytes = numpy.prod(list(sizedict.values())) * size
    prod = ' * '.join('({}={:_d})'.format(k, int(v))
                      for k, v in sizedict.items())
    return nbytes, '%s * %d bytes = %s' % (prod, size, humansize(nbytes))


def gen_subclasses(cls):
    """
    :returns: the subclasses of `cls`, ordered by name
    """
    for subclass in sorted(cls.__subclasses__(),
                           key=lambda cls: cls.__name__):
        yield subclass
        yield from gen_subclasses(subclass)


def pprod(p, axis=None):
    """
    Probability product 1 - prod(1-p)
    """
    return 1. - numpy.prod(1. - p, axis)


def agg_probs(*probs):
    """
    Aggregate probabilities with the usual formula 1 - (1 - P1) ... (1 - Pn)
    """
    acc = 1. - probs[0]
    for prob in probs[1:]:
        acc *= 1. - prob
    return 1. - acc

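# For instance, aggregating two independent probabilities of one half:
#
#     >>> agg_probs(.5, .5)  # 1 - (1 - .5) * (1 - .5)
#     0.75
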
class Param:
    """
    Container class for a set of parameters with defaults

    >>> p = Param(a=1, b=2)
    >>> p.a = 3
    >>> p.a, p.b
    (3, 2)
    >>> p.c = 4
    Traceback (most recent call last):
    ...
    AttributeError: Unknown parameter c
    """
    def __init__(self, **defaults):
        for k, v in defaults.items():
            self.__dict__[k] = v

    def __setattr__(self, name, value):
        if name in self.__dict__:
            object.__setattr__(self, name, value)
        else:
            raise AttributeError('Unknown parameter %s' % name)


class RecordBuilder(object):
    """
    Builder for numpy records or arrays.

    >>> rb = RecordBuilder(a=numpy.int64(0), b=1., c="2")
    >>> rb.dtype
    dtype([('a', '<i8'), ('b', '<f8'), ('c', 'S1')])
    >>> rb()
    (0, 1., b'2')
    """
    def __init__(self, **defaults):
        self.names = []
        self.values = []
        dtypes = []
        for name, value in defaults.items():
            self.names.append(name)
            self.values.append(value)
            if isinstance(value, (str, bytes)):
                tp = (numpy.bytes_, len(value) or 1)
            elif isinstance(value, numpy.ndarray):
                tp = (value.dtype, len(value))
            else:
                tp = type(value)
            dtypes.append(tp)
        self.dtype = numpy.dtype([(n, d) for n, d in zip(self.names, dtypes)])


def rmsdiff(a, b):
    """
    :param a: an array of shape (N, ...)
    :param b: an array with the same shape as a
    :returns: an array of shape (N,) with the root mean squares of a-b
    """
    assert a.shape == b.shape
    axis = tuple(range(1, len(a.shape)))
    rms = numpy.sqrt(((a - b) ** 2).mean(axis=axis))
    return rms


def sqrscale(x_min, x_max, n):
    """
    :param x_min: minimum value
    :param x_max: maximum value
    :param n: number of steps
    :returns: an array of n values from x_min to x_max in a quadratic scale
    """
    if not (isinstance(n, int) and n > 0):
        raise ValueError('n must be a positive integer, got %s' % n)
    if x_min < 0:
        raise ValueError('x_min must be positive, got %s' % x_min)
    if x_max <= x_min:
        raise ValueError('x_max (%s) must be bigger than x_min (%s)' %
                         (x_max, x_min))
    delta = numpy.sqrt(x_max - x_min) / (n - 1)
    return x_min + (delta * numpy.arange(n)) ** 2

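# A quick sketch of the quadratic scale: the steps get larger as x grows.
#
#     sqrscale(0., 100., 5)  # -> array([0., 6.25, 25., 56.25, 100.])
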
# NB: this is present in contextlib in Python 3.11, but
# we still support Python 3.9, so it cannot be removed yet
@contextmanager
def chdir(path):
    """
    Context manager to temporarily change the CWD
    """
    oldpwd = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(oldpwd)


def smart_concat(arrays):
    """
    Concatenate structured arrays by considering only the common fields
    """
    if len(arrays) == 0:
        return ()
    common = set(arrays[0].dtype.names)
    for array in arrays[1:]:
        common &= set(array.dtype.names)
    assert common, 'There are no common field names'
    common = sorted(common)
    dt = arrays[0][common].dtype
    return numpy.concatenate([arr[common] for arr in arrays], dtype=dt)


def around(vec, value, delta):
    """
    :param vec: a numpy vector or pandas column
    :param value: a float value
    :param delta: a positive float
    :returns: array of booleans for the range [value-delta, value+delta]
    """
    return (vec <= value + delta) & (vec >= value - delta)


def sum_records(array):
    """
    :returns: the sums of the composite array
    """
    res = numpy.zeros(1, array.dtype)
    for name in array.dtype.names:
        res[name] = array[name].sum(axis=0)
    return res


def compose_arrays(**kwarrays):
    """
    Compose multiple 1D and 2D arrays into a single composite array.
    For instance

    >>> mag = numpy.array([5.5, 5.6])
    >>> mea = numpy.array([[-4.5, -4.6], [-4.45, -4.55]])
    >>> compose_arrays(mag=mag, mea=mea)
    array([(5.5, -4.5 , -4.6 ), (5.6, -4.45, -4.55)],
          dtype=[('mag', '<f8'), ('mea0', '<f8'), ('mea1', '<f8')])
    """
    dic = {}
    dtlist = []
    nrows = set()
    for key, array in kwarrays.items():
        shape = array.shape
        nrows.add(shape[0])
        if len(shape) >= 2:
            for k in range(shape[1]):
                dic[f'{key}{k}'] = array[:, k]
                dtlist.append((f'{key}{k}', (array.dtype, shape[2:])))
        else:
            dic[key] = array
            dtlist.append((key, array.dtype))
    [R] = nrows  # all arrays must have the same number of rows
    array = numpy.empty(R, dtlist)
    for key, _ in dtlist:
        array[key] = dic[key]
    return array


# #################### COMPRESSION/DECOMPRESSION ##################### #

# Compressing the task outputs makes everything slower, so you should NOT
# do that, except in one case. The case is when you have a lot of workers
# (say 320) sending a lot of data (say 320 GB) to a master node which is
# not able to keep up. Then the zmq queue fills all of the available RAM
# until the master node blows up. With compression you can reduce the queue
# size a lot (say one order of magnitude).
# Therefore by losing a bit of speed (say 3%) you can convert a failing
# calculation into a successful one.

def compress(obj):
    """
    gzip a Python object
    """
    # level=1: compress the least, but fast, good choice for us
    return zlib.compress(pickle.dumps(obj, pickle.HIGHEST_PROTOCOL), level=1)


def decompress(cbytes):
    """
    gunzip compressed bytes into a Python object
    """
    return pickle.loads(zlib.decompress(cbytes))

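# Round-trip sketch, valid for any picklable object:
#
#     >>> decompress(compress({'a': [1, 2, 3]}))
#     {'a': [1, 2, 3]}
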
# ########################### dumpa/loada ############################## #

# the functions below are useful to avoid data transfer; to be used as
# smap.share(arr=dumpa(big_object))
# and then in the workers
# with monitor.shared['arr'] as arr:
#     big_object = loada(arr)

def dumpa(obj):
    """
    Dump a Python object as an array of uint8:

    >>> dumpa(23)
    array([128, 5, 75, 23, 46], dtype=uint8)
    """
    buf = memoryview(pickle.dumps(obj, pickle.HIGHEST_PROTOCOL))
    return numpy.ndarray(len(buf), dtype=numpy.uint8, buffer=buf)


def loada(arr):
    """
    Convert an array of uint8 into a Python object:

    >>> loada(numpy.array([128, 5, 75, 23, 46], numpy.uint8))
    23
    """
    return pickle.loads(bytes(arr))