# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
#
# Copyright (C) 2010-2025 GEM Foundation
#
# OpenQuake is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# OpenQuake is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with OpenQuake. If not, see <http://www.gnu.org/licenses/>.
"""\
The Starmap API
====================================

There are several good libraries to manage parallel programming in Python,
both in the standard library and in third party packages. Since we are not
interested in reinventing the wheel, OpenQuake does not provide any new
parallel library; however, it does offer some glue code so that you
can use the library of your choice. Currently threading, multiprocessing,
and zmq are supported. Moreover,
:mod:`openquake.baselib.parallel` offers some additional facilities
that make it easier to parallelize scientific computations,
i.e. embarrassingly parallel problems.

Typically one wants to apply a callable to a list of arguments in
parallel, and then combine together the results. This is known as a
`MapReduce` problem. As a simple example, we will consider the problem
of counting the letters in a text, by using the following `count`
function:

.. code-block:: python

   def count(word):
       return collections.Counter(word)

The `collections.Counter` class works sequentially; the problem can be
solved in parallel by using
:class:`openquake.baselib.parallel.Starmap`:

>>> arglist = [('hello',), ('world',)]  # list of arguments
>>> smap = Starmap(count, arglist)  # Starmap instance, nothing started yet
>>> sorted(smap.reduce().items())  # build the counts per letter
[('d', 1), ('e', 1), ('h', 1), ('l', 3), ('o', 2), ('r', 1), ('w', 1)]

A `Starmap` object is an iterable: when iterated over, it produces
task results. It also has a `reduce` method similar to `functools.reduce`
with sensible defaults:

1. the default aggregation function is `add`, so there is no need to
   specify it
2. the default accumulator is an empty accumulation dictionary (see
   :class:`openquake.baselib.AccumDict`) working as a `Counter`, so there
   is no need to specify it.

You can of course override the defaults, so if you really want to
return a `Counter` you can do

>>> res = Starmap(count, arglist).reduce(acc=collections.Counter())

In the engine we almost always use callables that return dictionaries
and we almost always aggregate with the addition operator, so such
defaults are very convenient. You are encouraged to do the same, since we
found that approach to be very flexible. Typically in a scientific
application you will return a dictionary of numpy arrays.
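For instance, here is a minimal sketch (the `make_hist` task and the
`chunks` variable are hypothetical, not part of the engine) showing the
dictionary-of-arrays convention together with the default aggregation:

.. code-block:: python

   import numpy

   def make_hist(values):
       # return a dictionary key -> numpy array; the default AccumDict
       # accumulator sums such dictionaries key by key
       hist, _edges = numpy.histogram(values, bins=10, range=(0, 1))
       return {'hist': hist}

   # given a list of chunks of values (not shown here), the reduction
   # produces the total histogram as the elementwise sum of the partials
   acc = Starmap(make_hist, [(chunk,) for chunk in chunks]).reduce()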
The parallelization algorithm used by `Starmap` will depend on the
environment variable `OQ_DISTRIBUTE`. Here are the possibilities
available at the moment:

`OQ_DISTRIBUTE` not set or set to "processpool":
  use multiprocessing
`OQ_DISTRIBUTE` set to "no":
  disable the parallelization, useful for debugging
`OQ_DISTRIBUTE` set to "zmq":
  use the zmq concurrency mechanism (experimental)

There is also an `OQ_DISTRIBUTE` = "threadpool"; however the
performance of using threads instead of processes is normally bad for the
kind of applications we are interested in (CPU-dominated, with large
tasks such that the time to spawn a new process is negligible with
respect to the time to perform the task), so it is not recommended.

If you are using a pool, it is always a good idea to clean up resources at
the end with

>>> Starmap.shutdown()

`Starmap.shutdown` is always defined. It does nothing if there is
no pool, but it is still better to call it: in the future, you may change
your mind and use another parallelization strategy requiring cleanup. In
this way your code is future-proof.

Monitoring
=============================

A major feature of the Starmap API is the ability to monitor the time spent
in each task and the memory allocated. Such information is written into an
HDF5 file that can be provided by the user or autogenerated. To autogenerate
the file you can use :func:`openquake.commonlib.datastore.create_job_dstore`
which will create a file named ``calc_XXX.hdf5`` in your $OQ_DATA directory
(if the environment variable is not set, the engine will use $HOME/oqdata).
An associated record will be added to the job table in the database
db.sqlite3 in OQ_DATA.

Here is an example of usage:

>>> from openquake.commonlib.datastore import create_job_dstore
>>> log, h5 = create_job_dstore()
>>> smap = Starmap(count, [['hello'], ['world']], h5=h5)
>>> print(sorted(smap.reduce().items()))
[('d', 1), ('e', 1), ('h', 1), ('l', 3), ('o', 2), ('r', 1), ('w', 1)]

After the calculation, or even while the calculation is running, you can
open the calculation file for reading and extract the performance information
from it. The engine provides a command to do that, `oq show performance`,
but you can also get it manually, with a call to
`openquake.baselib.performance.performance_view(h5)` which will return
the performance information as a numpy array:

>>> from openquake.baselib.performance import performance_view
>>> performance_view(h5).dtype.names[1:]
('time_sec', 'memory_mb', 'counts')
>>> h5.close()

The four columns are as follows:

operation:
  the name of the function running in parallel (in this case 'count')
time_sec:
  the cumulative time in seconds spent running the function
memory_mb:
  the maximum allocated memory per core
counts:
  the number of times the function was called (in this case 2)

The Starmap.apply API
====================================

The `Starmap` class has a very convenient classmethod `Starmap.apply`
which is used in several places in the engine. `Starmap.apply` is useful
when you have a sequence of objects that you want to split in homogeneous
chunks and then apply a callable to each chunk (in parallel). For instance,
in the letter counting example discussed before, `Starmap.apply` could
be used as follows:

>>> text = 'helloworld'  # sequence of characters
>>> res3 = Starmap.apply(count, (text,)).reduce()
>>> assert res3 == res

The API of `Starmap.apply` is designed to extend the one of `apply`,
a builtin of Python 2; the second argument is the tuple of arguments
passed to the first argument. The difference with `apply` is that
`Starmap.apply` returns a :class:`Starmap` object so that nothing is
actually done until you iterate on it (`reduce` is doing that).

How many chunks will be produced? That depends on the parameter
`concurrent_tasks`; if it is not passed, it defaults to `Starmap.CT`,
i.e. twice the number of available cores, and `Starmap.apply` will try
to produce a number of chunks close to that number. The nice thing is
that it is also possible to pass a `weight` function. Suppose for
instance that instead of a list of letters you have a list of seismic
sources: some sources require a long computation time (such as
`ComplexFaultSources`), some require a short computation time (such as
`PointSources`). By giving a heuristic weight to the different sources
it is possible to produce chunks with nearly homogeneous weight; in
particular `PointSource` tasks will contain a lot more sources than
tasks with `ComplexFaultSources`.
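As an illustration only (the task function, the `sources` list and the
`gsims` argument here are made up, not taken from the engine), the
pattern looks like this:

.. code-block:: python

   # each source is assumed to have a .weight attribute, heavier for
   # ComplexFaultSources and lighter for PointSources
   smap = Starmap.apply(
       count_ruptures,            # hypothetical task function
       (sources, gsims),          # sources is split in chunks, gsims is fixed
       weight=lambda src: src.weight)
   acc = smap.reduce()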
It is *essential* in large computations to have a homogeneous task
distribution, otherwise you will end up having a big task dominating
the computation time (i.e. you may have 1000 cores of which 999 are free,
having finished all the short tasks, but you have to wait for days for
the single core processing the slow task). The OpenQuake engine does
a great deal of work trying to split slow sources into more manageable
fast sources.
"""
import os
import re
import ast
import sys
import time
import socket
import signal
import pickle
import getpass
import inspect
import logging
import operator
import tempfile
import traceback
import collections
from unittest import mock
import multiprocessing.dummy
from multiprocessing.connection import wait
import multiprocessing.shared_memory as shmem
import psutil
import numpy

from openquake.baselib import config, hdf5
from openquake.baselib.python3compat import decode
from openquake.baselib.zeromq import zmq, Socket
from openquake.baselib.performance import (
    Monitor, memory_gb, init_performance)
from openquake.baselib.general import (
    split_in_blocks, block_splitter, AccumDict, humansize, CallableDict,
    gettemp, engine_version, shortlist, compress, decompress,
    mp as mp_context)

sys.setrecursionlimit(2000)  # raised to make pickle happier
# see https://github.com/gem/oq-engine/issues/5230
submit = CallableDict()
MB = 1024 ** 2
GB = 1024 ** 3
host_cores = config.zworkers.host_cores.split(',')
def scratch_dir(job_id):
    """
    :returns: scratch directory associated to the given job_id
    """
    tmp = config.directory.custom_tmp or tempfile.gettempdir()
    dirname = os.path.join(tmp, getpass.getuser(), f'calc_{job_id}')
    try:
        os.makedirs(dirname)
    except FileExistsError:  # already created
        pass
    return dirname
@submit.add('zmq', 'slurm')
def zmq_submit(self, func, args, monitor):
    idx = self.task_no % len(host_cores)
    host = host_cores[idx].split()[0]
    port = int(config.zworkers.ctrl_port)
    dest = 'tcp://%s:%d' % (host, port)
    logging.debug('Sending to %s', dest)
    with Socket(dest, zmq.REQ, 'connect', timeout=300) as sock:
        sub = sock.send((func, args, self.task_no, monitor))
        assert sub == 'submitted', sub
def oq_distribute(task=None):
    """
    :returns: the value of OQ_DISTRIBUTE or config.distribution.oq_distribute
    """
    dist = os.environ.get('OQ_DISTRIBUTE', config.distribution.oq_distribute)
    if dist not in ('no', 'processpool', 'threadpool', 'zmq', 'slurm'):
        raise ValueError('Invalid oq_distribute=%s' % dist)
    return dist
def init_workers():
    """Used to initialize the process pool"""
    try:
        from setproctitle import setproctitle
    except ImportError:
        pass
    else:
        setproctitle('oq-worker')
class Pickled(object):
    """
    A utility to manually pickle/unpickle objects.
    Pickled instances have a nice string representation and length giving
    the size of the pickled bytestring.

    :param obj: the object to pickle
    """
    compressed = False

    def __init__(self, obj):
        self.clsname = obj.__class__.__name__
        self.calc_id = str(getattr(obj, 'calc_id', ''))  # for monitors
        try:
            self.pik = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
        except TypeError as exc:  # can't pickle, show the obj in the message
            raise TypeError('%s: %s' % (exc, obj))
        self.compressed = len(self.pik) > MB and config.distribution.compress
        if self.compressed:
            self.pik = compress(self.pik)

    def __repr__(self):
        """String representation of the pickled object"""
        return '<Pickled %s #%s%s>' % (
            self.clsname, self.calc_id, humansize(len(self)))

    def __len__(self):
        """Length of the pickled bytestring"""
        return len(self.pik)
    def unpickle(self):
        """Unpickle the underlying object"""
        pik = decompress(self.pik) if self.compressed else self.pik
        return pickle.loads(pik)
def get_pickled_sizes(obj):
    """
    Return the pickled sizes of an object and its direct attributes,
    ordered by decreasing size. Here is an example:

    >> total_size, partial_sizes = get_pickled_sizes(Monitor(''))
    >> total_size
    345
    >> partial_sizes
    [('_procs', 214), ('exc', 4), ('mem', 4), ('start_time', 4),
    ('_start_time', 4), ('duration', 4)]

    Notice that the sizes depend on the operating system and the machine.
    """
    sizes = []
    attrs = getattr(obj, '__dict__', {})
    for name, value in attrs.items():
        sizes.append((name, len(Pickled(value))))
    return len(Pickled(obj)), sorted(
        sizes, key=lambda pair: pair[1], reverse=True)
def pickle_sequence(objects):
    """
    Convert an iterable of objects into a list of pickled objects.
    If the iterable contains copies, the pickling will be done only once.
    If the iterable contains objects already pickled, they will not be
    pickled again.

    :param objects: a sequence of objects to pickle
    """
    cache = {}
    out = []
    for obj in objects:
        obj_id = id(obj)
        if obj_id not in cache:
            if isinstance(obj, Pickled):  # already pickled
                cache[obj_id] = obj
            else:  # pickle the object
                cache[obj_id] = Pickled(obj)
        out.append(cache[obj_id])
    return out
class Result(object):
    """
    :param val: value to return or exception instance
    :param mon: Monitor instance
    :param tb_str: traceback string (empty if there was no exception)
    :param msg: message string (default empty)
    """
    func = None

    def __init__(self, val, mon, tb_str='', msg=''):
        if isinstance(val, dict):
            self.pik = Pickled(val)
            self.nbytes = {k: len(Pickled(v)) for k, v in val.items()}
        elif isinstance(val, tuple) and callable(val[0]):
            self.func = val[0]
            self.pik = pickle_sequence(val[1:])
            self.nbytes = {'args': sum(len(p) for p in self.pik)}
        elif msg == 'TASK_ENDED':
            self.pik = Pickled(None)
            self.nbytes = {}
        else:
            self.pik = Pickled(val)
            self.nbytes = {'tot': len(self.pik)}
        self.mon = mon
        self.tb_str = tb_str
        self.msg = msg
        self.workerid = (socket.gethostname(), os.getpid())
    def get(self):
        """
        Returns the underlying value or raises the underlying exception
        """
        t0 = time.time()
        val = self.pik.unpickle()
        self.dt = time.time() - t0
        if self.tb_str:
            etype = val.__class__
            msg = '\n%s%s: %s' % (self.tb_str, etype.__name__, val)
            if issubclass(etype, KeyError):
                raise RuntimeError(msg)  # nicer message
            else:
                raise etype(msg)
        return val
    @classmethod
    def new(cls, func, args, mon, sentbytes=0):
        """
        :returns: a new Result instance
        """
        try:
            if mon.version and mon.version != engine_version():
                raise RuntimeError(
                    'The master is at version %s while the worker %s is at '
                    'version %s' % (mon.version, socket.gethostname(),
                                    engine_version()))
            if mon.dbserver_host != config.dbserver.host:
                raise RuntimeError(
                    'The worker has dbserver.host=%s while the master has %s'
                    % (mon.dbserver_host, config.dbserver.host))
            with mon:
                val = func(*args)
        except StopIteration:
            mon.counts -= 1  # StopIteration does not count
            res = Result(None, mon, msg='TASK_ENDED')
            res.pik = FakePickle(sentbytes)
        except Exception:
            _etype, exc, tb = sys.exc_info()
            res = Result(exc, mon, ''.join(traceback.format_tb(tb)))
        else:
            res = Result(val, mon)
        return res
def check_mem_usage(soft_percent=None, hard_percent=None):
    """
    Display a warning if we are running out of memory
    """
    soft_percent = soft_percent or config.memory.soft_mem_limit
    hard_percent = hard_percent or config.memory.hard_mem_limit
    used_mem_percent = psutil.virtual_memory().percent
    if used_mem_percent > hard_percent:
        raise MemoryError('Using more memory than allowed by configuration '
                          '(Used: %d%% / Allowed: %d%%)! Shutting down.' %
                          (used_mem_percent, hard_percent))
    elif used_mem_percent > soft_percent:
        msg = 'Using over %d%% of the memory in %s!'
        return msg % (used_mem_percent, socket.gethostname())
def sendback(res, zsocket):
    """
    Send back to the master node the result by using the zsocket.

    :returns: the accumulated number of bytes sent
    """
    calc_id = res.mon.calc_id
    task_no = res.mon.task_no
    nbytes = len(res.pik)
    try:
        zsocket.send(res)
        if DEBUG:
            from openquake.commonlib.logs import dblog
            if calc_id:  # None when building the png maps
                msg = 'sent back %s' % humansize(nbytes)
                dblog('DEBUG', calc_id, task_no, msg)
    except Exception:  # like OverflowError
        _etype, exc, tb = sys.exc_info()
        tb_str = ''.join(traceback.format_tb(tb))
        if DEBUG and calc_id:
            dblog('ERROR', calc_id, task_no, tb_str)
        res = Result(exc, res.mon, tb_str)
        zsocket.send(res)
    return nbytes
def safely_call(func, args, task_no=0, mon=dummy_mon):
    """
    Call the given function with the given arguments safely, i.e.
    by trapping the exceptions. Return a pair (result, exc_type)
    where exc_type is None if no exceptions occur, otherwise it
    is the exception class and the result is a string containing
    error message and traceback.

    :param func: the function to call
    :param args: the arguments
    :param task_no: the task number
    :param mon: a monitor
    """
    isgenfunc = inspect.isgeneratorfunction(func)
    if hasattr(args[0], 'unpickle'):
        # args is a list of Pickled objects
        args = [a.unpickle() for a in args]
    if mon is dummy_mon:  # in the DbServer
        assert not isgenfunc, func
        return Result.new(func, args, mon)
    # debug(f'{mon.backurl=}, {task_no=}')
    if mon.operation.endswith('_'):
        name = mon.operation[:-1]
    elif func is split_task:
        name = args[1].__name__
    else:
        name = func.__name__
    mon = mon.new(operation='total ' + name, measuremem=True)
    mon.weight = getattr(args[0], 'weight', 1.)  # used in task_info
    mon.task_no = task_no
    if mon.inject:
        args += (mon,)
    sentbytes = 0
    if isgenfunc:
        with Socket(mon.backurl, zmq.PUSH, 'connect') as zsocket:
            it = func(*args)
            while True:
                res = Result.new(next, (it,), mon, sentbytes)
                # StopIteration -> TASK_ENDED
                if res.msg == 'TASK_ENDED':
                    zsocket.send(res)
                    break
                sentbytes += sendback(res, zsocket)
    else:
        res = Result.new(func, args, mon)
        # send back a single result and a TASK_ENDED
        with Socket(mon.backurl, zmq.PUSH, 'connect') as zsocket:
            sentbytes += sendback(res, zsocket)
            end = Result(None, mon, msg='TASK_ENDED')
            end.pik = FakePickle(sentbytes)
            zsocket.send(end)
class IterResult(object):
    """
    :param iresults: an iterator over Result objects
    :param taskname: the name of the task
    :param argnames: the names of the arguments of the task function
    :param sent: a nested dictionary name -> {argname: number of bytes sent}
    :param h5: an open hdf5.File where to store persistently the
        performance info
    """
    def __init__(self, iresults, taskname, argnames, sent, h5):
        self.iresults = iresults
        self.name = taskname
        self.argnames = ' '.join(argnames)
        self.sent = sent
        self.h5 = h5

    def _iter(self):
        first_time = True
        self.counts = 0
        self.dt = 0
        for result in self.iresults:
            msg = check_mem_usage()
            # log a warning if too much memory is used
            if msg and first_time:
                logging.warning(msg)
                first_time = False  # warn only once
            self.nbytes += result.nbytes
            self.counts += 1
            out = result.get()
            self.dt += result.dt
            yield out

    def __iter__(self):
        if self.iresults == ():
            return ()
        t0 = time.time()
        self.nbytes = AccumDict()
        try:
            yield from self._iter()
        finally:
            items = sorted(self.nbytes.items(), key=operator.itemgetter(1))
            nb = {k: humansize(v) for k, v in list(reversed(items))[:3]}
            recv = sum(self.nbytes.values())
            mean = recv / (self.counts or 1)
            pu = '[unpik=%.2fs]' % self.dt
            logging.info('Received %d * %s in %d seconds %s from %s\n%s',
                         self.counts, humansize(mean), time.time() - t0, pu,
                         self.name, nb)
    @classmethod
    def sum(cls, iresults):
        """
        Sum the data transfer information of a set of results
        """
        res = object.__new__(cls)
        res.sent = 0
        for iresult in iresults:
            res.sent += iresult.sent
            name = iresult.name.split('#', 1)[0]
            if hasattr(res, 'name'):
                assert res.name.split('#', 1)[0] == name, (res.name, name)
            else:
                res.name = iresult.name.split('#')[0]
        return res
def getargnames(task_func):
    # a task can be a function, a method, a class or a callable instance
    if inspect.isfunction(task_func):
        return inspect.getfullargspec(task_func).args
    elif inspect.ismethod(task_func):
        return inspect.getfullargspec(task_func).args[1:]
    elif inspect.isclass(task_func):
        return inspect.getfullargspec(task_func.__init__).args[1:]
    else:  # instance with a __call__ method
        return inspect.getfullargspec(task_func.__call__).args[1:]
    def __init__(self, shape, dtype, value):
        nbytes = numpy.zeros(1, dtype).nbytes * numpy.prod(shape)
        # NOTE: on Windows size wants an int and not a numpy.int
        self.sm = shmem.SharedMemory(create=True, size=int(nbytes))
        self.shape = shape
        self.dtype = dtype
        # fill the SharedMemory buffer with the value
        arr = numpy.ndarray(shape, dtype, buffer=self.sm.buf)
        arr[:] = value

    def __enter__(self):
        # this is called in the workers
        self._sm = shmem.SharedMemory(name=self.sm.name)
        return numpy.ndarray(self.shape, self.dtype, buffer=self._sm.buf)

    def __exit__(self, etype, exc, tb):
        # this is called in the workers
        self._sm.close()
# determine the number of cores to use
cpu_count = psutil.cpu_count()
if sys.platform == 'win32':
    # assume hyperthreading is on; use half the threads to save memory
    tot_cores = cpu_count // 2 or 1
elif sys.platform == 'linux':
    # use only the "visible" cores, not the total system cores
    # if the underlying OS supports it (macOS does not)
    tot_cores = len(psutil.Process().cpu_affinity())
else:
    tot_cores = cpu_count
class Starmap(object):
    pids = ()
    running_tasks = []  # currently running tasks
    maxtasksperchild = None  # with 1 it hangs on the EUR calculation!
    num_cores = int(config.distribution.get('num_cores', '0')) or tot_cores
    CT = num_cores * 2
    expected_outputs = 0  # unknown
    @classmethod
    def init(cls, distribute=None):
        cls.distribute = distribute or oq_distribute()
        if cls.distribute == 'processpool' and not hasattr(cls, 'pool'):
            # unregister custom handlers before starting the processpool
            term_handler = signal.signal(signal.SIGTERM, signal.SIG_DFL)
            int_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
            # we use spawn here to avoid deadlocks with logging, see
            # https://github.com/gem/oq-engine/pull/3923 and
            # https://codewithoutrules.com/2018/09/04/python-multiprocessing/
            cls.pool = mp_context.Pool(
                cls.num_cores if cls.num_cores <= tot_cores else tot_cores,
                init_workers, maxtasksperchild=cls.maxtasksperchild)
            cls.pids = [proc.pid for proc in cls.pool._pool]
            # after spawning the processes restore the original handlers
            # i.e. the ones defined in openquake.engine.engine
            signal.signal(signal.SIGTERM, term_handler)
            signal.signal(signal.SIGINT, int_handler)
        elif cls.distribute == 'threadpool' and not hasattr(cls, 'pool'):
            cls.pool = multiprocessing.dummy.Pool(cls.num_cores)
    @classmethod
    def shutdown(cls):
        # shutting down the pool during the runtime causes mysterious
        # race conditions with errors inside atexit._run_exitfuncs
        if hasattr(cls, 'pool'):
            cls.pool.close()
            cls.pool.terminate()
            cls.pool.join()
            del cls.pool
            cls.pids = []
        elif hasattr(cls, 'executor'):
            cls.executor.shutdown()
    @classmethod
    def apply(cls, task, allargs, concurrent_tasks=None,
              maxweight=None, weight=lambda item: 1,
              key=lambda item: 'Unspecified', distribute=None,
              progress=logging.info, h5=None):
        r"""
        Apply a task to a tuple of the form (sequence, \*other_args)
        by first splitting the sequence in chunks, according to the weight
        of the elements and possibly to a key (see :func:
        `openquake.baselib.general.split_in_blocks`).

        :param task: a task to run in parallel
        :param allargs: the arguments to be passed to the task function
        :param concurrent_tasks: hint about how many tasks to generate
        :param maxweight: if not None, used to split the tasks
        :param weight: function to extract the weight of an item in arg0
        :param key: function to extract the kind of an item in arg0
        :param distribute: if not given, inferred from OQ_DISTRIBUTE
        :param progress: logging function to use (default logging.info)
        :param h5: an open hdf5.File where to store the performance info
        :returns: an :class:`IterResult` object
        """
        arg0, *args = allargs
        if maxweight:  # block_splitter is lazy
            taskargs = ([blk] + args for blk in block_splitter(
                arg0, maxweight, weight, key))
        else:  # split_in_blocks is eager
            if concurrent_tasks is None:
                concurrent_tasks = cls.CT
            taskargs = [[blk] + args for blk in split_in_blocks(
                arg0, concurrent_tasks or 1, weight, key)]
        return cls(task, taskargs, distribute, progress, h5)
    @classmethod
    def apply_split(cls, task, allargs, concurrent_tasks=None,
                    maxweight=None, weight=lambda item: 1,
                    key=lambda item: 'Unspecified', distribute=None,
                    progress=logging.info, h5=None, duration=300,
                    outs_per_task=5):
        """
        Same as Starmap.apply, but possibly produces subtasks
        """
        args = (allargs[0], task, allargs[1:], duration, outs_per_task)
        return cls.apply(split_task, args,
                         concurrent_tasks or 2 * cls.num_cores,
                         maxweight, weight, key, distribute, progress, h5)
    def __init__(self, task_func, task_args=(), distribute=None,
                 progress=logging.info, h5=None):
        self.__class__.init(distribute=distribute)
        self.task_func = task_func
        if h5:
            match = re.search(r'(\d+)', os.path.basename(h5.filename))
            self.calc_id = int(match.group(1))
        else:
            # TODO: see if we can forbid this case
            self.calc_id = None
            h5 = hdf5.File(gettemp(suffix='.hdf5'), 'w')
            init_performance(h5)
        if task_func is split_task:
            self.name = task_args[0][1].__name__
        else:
            self.name = task_func.__name__
        self.monitor = Monitor(self.name, dbserver_host=config.dbserver.host)
        self.monitor.filename = h5.filename
        self.monitor.calc_id = self.calc_id
        self.task_args = task_args
        self.progress = progress
        self.h5 = h5
        self.task_queue = []
        try:
            self.num_tasks = len(self.task_args)
        except TypeError:  # generators have no len
            self.num_tasks = None
        self.argnames = getargnames(task_func)
        self.sent = AccumDict(accum=AccumDict())  # fname -> argname -> nbytes
        self.monitor.inject = (self.argnames[-1].startswith('mon') or
                               self.argnames[-1].endswith('mon'))
        self.receiver = 'tcp://0.0.0.0:%s' % config.dbserver.receiver_ports
        if self.distribute in ('no', 'processpool') or \
                sys.platform != 'linux':
            self.return_ip = '127.0.0.1'  # zmq returns data to localhost
        else:  # zmq returns data to the receiver_host
            self.return_ip = get_return_ip(config.dbserver.receiver_host)
        logging.debug(f'{self.return_ip=}')
        self.monitor.backurl = None  # overridden later
        self.tasks = []  # populated by .submit
        self.task_no = 0
        self._shared = {}
        self.n_out = 0
    def log_percent(self):
        """
        Log the progress of the computation in percentage
        """
        done = self.task_no - len(self.tasks)
        if not hasattr(self, 'prev_percent'):  # first time
            self.prev_percent = 0
        if self.expected_outputs:
            percent = int(self.n_out / self.expected_outputs * 100)
        else:
            percent = int(done / self.task_no * 100)
        if percent > self.prev_percent:
            queued = len(self.task_queue)
            self.progress('%s%3d%% [%d submitted, %d queued]',
                          self.name, percent, self.task_no, queued)
            self.prev_percent = percent
            assert percent <= 100, percent  # sanity check
        return done
    def init_slurm(self):
        """
        Initialize the list host_cores by reading the file with the
        hostcores generated by the worker nodes
        """
        scr = scratch_dir(self.monitor.calc_id)
        with open(os.path.join(scr, 'hostcores')) as f:
            host_cores[:] = [ln.strip() for ln in f.readlines()]
    def submit(self, args, func=None):
        """
        Submit the given arguments to the underlying task
        """
        func = func or self.task_func
        if not hasattr(self, 'socket'):
            # setup the PULL socket the first time
            self.__class__.running_tasks = self.tasks
            self.socket = Socket(self.receiver, zmq.PULL, 'bind').__enter__()
            self.monitor.shared = self._shared
            self.monitor.backurl = 'tcp://%s:%s' % (
                self.return_ip, self.socket.port)
            if self.distribute == 'slurm':
                self.init_slurm()
        OQ_TASK_NO = os.environ.get('OQ_TASK_NO')
        if OQ_TASK_NO is not None and self.task_no != int(OQ_TASK_NO):
            self.task_no += 1
            return
        dist = 'no' if self.num_tasks == 1 or OQ_TASK_NO else self.distribute
        if dist != 'no':
            pickled = isinstance(args[0], Pickled)
            if not pickled:
                assert not isinstance(args[-1], Monitor)  # sanity check
                args = pickle_sequence(args)
            if func is None:
                fname = self.task_func.__name__
                argnames = self.argnames[:-1]
            else:
                fname = func.__name__
                argnames = getargnames(func)[:-1]
            self.sent[fname] += {a: len(p) for a, p in zip(argnames, args)}
        submit[dist](self, func, args, self.monitor)
        self.tasks.append(self.task_no)
        self.task_no += 1
    def submit_split(self, args, duration, outs_per_task):
        """
        Submit the given arguments to the underlying task
        """
        self.monitor.operation = self.task_func.__name__ + '_'
        self.submit((args[0], self.task_func, args[1:], duration,
                     outs_per_task), split_task)
    def submit_all(self):
        """
        :returns: an IterResult object
        """
        if self.num_tasks is None:  # loop on the iterator
            for args in self.task_args:
                self.submit(args)
        else:  # build a task queue in advance
            self.task_queue = [(self.task_func, args)
                               for args in self.task_args]
            dist = 'no' if self.num_tasks == 1 else self.distribute
            if dist == 'slurm':
                # submit the tasks via zmq
                for func, args in self.task_queue:
                    self.submit(args, func=func)
                self.task_queue.clear()
        return self.get_results()
    def get_results(self):
        """
        :returns: an :class:`IterResult` instance
        """
        return IterResult(self._loop(), self.name, self.argnames,
                          self.sent, self.h5)
    def reduce(self, agg=operator.add, acc=None):
        """
        Submit all tasks and reduce the results
        """
        return self.submit_all().reduce(agg, acc)
    def __iter__(self):
        return iter(self.submit_all())

    def _submit_many(self, howmany):
        for _ in range(howmany):
            if self.task_queue:  # remove in FIFO order
                func, args = self.task_queue[0]
                del self.task_queue[0]
                self.submit(args, func=func)

    # NB: the shared dictionary will be attached to the monitor
    # and used in the workers; to see an example of usage, look at
    # the event_based calculator
    def share(self, **dictarray):
        """
        Apply SharedArray.new to a dictionary of arrays
        """
        self._shared = {k: SharedArray.new(a) for k, a in dictarray.items()}
    def unlink(self):
        """
        Unlink the shared arrays, if any
        """
        for name, shr in self._shared.items():
            logging.debug('Unlinking %s', name)
            shr.unlink()
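    # Illustration only (hypothetical usage sketch, not engine code): the
    # shared arrays become visible to the workers through monitor.shared:
    #
    #   smap = Starmap(task, taskargs, h5=h5)
    #   smap.share(gmfs=big_array)   # put big_array in shared memory
    #   acc = smap.reduce()
    #   smap.unlink()                # release the shared memory
    #
    # and inside the task, assuming the monitor is injected as last argument:
    #
    #   with monitor.shared['gmfs'] as gmfs:
    #       ...  # read-only access to the array in the worker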
    def _loop(self):
        self.busytime = AccumDict(accum=[])  # pid -> time
        dist = 'no' if self.num_tasks == 1 else self.distribute
        if dist == 'slurm':
            self.monitor.task_no = self.task_no  # total number of tasks
            # sbatch(self.monitor)
        elif self.task_queue:
            first_args = self.task_queue[:self.CT]
            self.task_queue[:] = self.task_queue[self.CT:]
            for func, args in first_args:
                self.submit(args, func=func)
        if not hasattr(self, 'socket'):  # no submit was ever made
            return ()
        nbytes = sum(self.sent[self.task_func.__name__].values())
        logging.warning('Sent %d%s tasks, %s', len(self.tasks), self.name,
                        humansize(nbytes))
        isocket = iter(self.socket)  # read from the PULL socket
        finished = set()
        while self.tasks:
            res = next(isocket)
            self.log_percent()
            if self.calc_id != res.mon.calc_id:
                logging.warning('Discarding a result from job %s, since this '
                                'is job %s', res.mon.calc_id, self.calc_id)
            elif res.msg == 'TASK_ENDED':
                finished.add(res.mon.task_no)
                self.busytime += {res.workerid: res.mon.duration}
                self.tasks.remove(res.mon.task_no)
                self._submit_many(1)
                todo = set(range(self.task_no)) - finished
                logging.debug('%d tasks todo %s', len(todo),
                              shortlist(sorted(todo)))
                task_sent = ast.literal_eval(decode(self.h5['task_sent'][()]))
                task_sent.update(self.sent)
                del self.h5['task_sent']
                self.h5['task_sent'] = str(task_sent)
                name = res.mon.operation[6:]  # strip 'total '
                n = self.name + ':' + name if name == 'split_task' else name
                if self.distribute in ('zmq', 'slurm'):
                    mem_gb = 0
                    if res.mon.task_no % 10 == 0:
                        # measure the memory only for 1 task out of 10, to be
                        # fast; with 8 nodes the time to get the memory is
                        # 0.01 secs
                        for line in host_cores:
                            host, _cores = line.split()
                            addr = 'tcp://%s:%s' % (
                                host, config.zworkers.ctrl_port)
                            with Socket(addr, zmq.REQ, 'connect') as sock:
                                mem_gb += sock.send('memory_gb')
                elif self._shared:
                    # do not measure the memory on the workers, only in the
                    # master, otherwise memory_rss would double count the
                    # shared memory
                    mem_gb = memory_gb()
                else:
                    mem_gb = memory_gb(Starmap.pids)
                res.mon.save_task_info(self.h5, res, n, mem_gb)
                res.mon.flush(self.h5)
            elif res.func:  # add subtask
                self.task_queue.append((res.func, res.pik))
                self._submit_many(1)
            else:
                self.n_out += 1
                yield res
        self.log_percent()
        self.socket.__exit__(None, None, None)
        self.tasks.clear()
        self.unlink()
        if len(self.busytime) > 1:
            times = numpy.array(list(self.busytime.values()))
            logging.info(
                'Mean time per core=%ds, std=%.1fs, min=%ds, max=%ds',
                times.mean(), times.std(), times.min(), times.max())
def sequential_apply(task, args, concurrent_tasks=Starmap.CT,
                     maxweight=None, weight=lambda item: 1,
                     key=lambda item: 'Unspecified', progress=logging.info):
    """
    Apply sequentially task to args by splitting args[0] in blocks
    """
    with mock.patch.dict('os.environ', {'OQ_DISTRIBUTE': 'no'}):
        return Starmap.apply(task, args, concurrent_tasks, maxweight,
                             weight, key, progress=progress)
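# Illustration only (hypothetical call, useful e.g. in tests): run the
# `count` task below on the blocks of a string without any parallelism:
#
#   acc = sequential_apply(count, ('helloworld',)).reduce()
#   assert acc['l'] == 3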
def count(word):
    """
    Used as example in the documentation
    """
    return collections.Counter(word)
def split_task(elements, func, args, duration, outs_per_task, monitor):
    """
    :param func: a task function with a monitor as last argument
    :param args: arguments of the task function, with args[0] being a
        sequence
    :param duration: split the task if it exceeds the duration
    :param outs_per_task: number of splits to try (ex. 5)
    :yields: a partial result, 0 or more task objects
    """
    n = len(elements)
    if outs_per_task > n:  # too many splits
        outs_per_task = n
    elements = numpy.array(elements)  # from WeightedSequence to array
    idxs = numpy.arange(n)
    split_elems = [elements[idxs % outs_per_task == i]
                   for i in range(outs_per_task)]
    # see how long it takes to run the first slice
    t0 = time.time()
    for i, elems in enumerate(split_elems):
        monitor.out_no = monitor.task_no + i * 65536
        res = func(elems, *args, monitor=monitor)
        dt = time.time() - t0
        yield res
        if dt > duration:
            # spawn subtasks for the rest and exit, used in classical/case_14
            for els in split_elems[i + 1:]:
                ls = List(els)
                ls.weight = sum(getattr(el, 'weight', 1.) for el in els)
                yield (func, ls) + args
            break
def logfinish(n, tot):
    logging.info('Finished %d of %d jobs', n, tot)
    return n + 1
def multispawn(func, allargs, nprocs=Starmap.num_cores, logfinish=True):
    """
    Spawn processes with the given arguments
    """
    if oq_distribute() == 'no':
        for args in allargs:
            func(*args)
        return
    tot = len(allargs)
    allargs = allargs[::-1]  # so that the first argument is submitted first
    procs = {}  # sentinel -> process
    n = 1
    while allargs:
        args = allargs.pop()
        proc = mp_context.Process(target=func, args=args)
        proc.start()
        procs[proc.sentinel] = proc
        while len(procs) >= nprocs:  # wait for something to finish
            for finished in wait(procs):
                procs[finished].join()
                del procs[finished]
                if logfinish:
                    logging.info('Finished %d of %d jobs', n, tot)
                n += 1
    while procs:
        for finished in wait(procs):
            procs[finished].join()
            del procs[finished]
            if logfinish:
                logging.info('Finished %d of %d jobs', n, tot)
            n += 1
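# Illustration only (hypothetical arguments): run a function over a list of
# argument tuples, with at most Starmap.num_cores processes alive at a time:
#
#   multispawn(print, [('hello',), ('world',)])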