"""Module for wrapping MPI functions.

Most functions will do nothing if mpi4py is not present and will simply
replicate expected behaviour. This allows for MPI and non-MPI code to be
run without any changes (In theory...).
"""

import enum
import logging
import typing as t
from functools import lru_cache, wraps

import numpy as np

from .types import AnyValType

# Type variable used to annotate values passed through the MPI wrappers in
# this module (gather/reduce/broadcast); bound to the project's AnyValType.
T = t.TypeVar("T", bound=AnyValType)
def has_mpi() -> bool:
    """Report whether the optional ``mpi4py`` dependency can be imported.

    Returns
    -------
    bool:
        ``True`` when ``import mpi4py`` succeeds, ``False`` when it raises
        :class:`ImportError`.
    """
    try:
        import mpi4py  # noqa: F401
    except ImportError:
        available = False
    else:
        available = True
    return available
def convert_op(operation: Ops) -> t.Any:
    """Map an operation name onto the matching ``mpi4py.MPI`` attribute.

    ``operation`` is upper-cased (via its ``upper()`` method) and the result
    is looked up on ``mpi4py.MPI``, e.g. ``"sum"`` resolves to ``MPI.SUM``.
    Requires mpi4py.
    """
    from mpi4py import MPI

    op_name = str(operation.upper())
    return getattr(MPI, op_name)
@lru_cache(maxsize=10)
def shared_comm() -> t.Any:
    """Return the shared-memory communicator for this process.

    Splits ``MPI.COMM_WORLD`` with ``MPI.COMM_TYPE_SHARED`` so the resulting
    communicator only contains ranks that can share memory (i.e. ranks on
    the same node). The result is cached by ``lru_cache`` and reused on
    later calls. Requires mpi4py.
    """
    from mpi4py import MPI

    world = MPI.COMM_WORLD
    return world.Split_type(MPI.COMM_TYPE_SHARED)
@lru_cache(maxsize=10)
def nprocs() -> int:
    """Gets number of processes or returns 1 if mpi is not installed.

    The value is cached after the first call.

    Returns
    -------
    int:
        Number of processes in ``MPI.COMM_WORLD`` (its size), or 1 if
        mpi4py is not installed.
    """
    try:
        from mpi4py import MPI
    except ImportError:
        return 1
    return MPI.COMM_WORLD.Get_size()
def allgather(value: T) -> t.List[T]:
    """Collect ``value`` from every process into a list.

    Parameters
    ----------
    value:
        This process's contribution.

    Returns
    -------
    list:
        The result of ``MPI.COMM_WORLD.allgather(value)``, or ``[value]``
        if mpi4py is not installed.
    """
    try:
        from mpi4py import MPI
    except ImportError:
        return [value]
    return MPI.COMM_WORLD.allgather(value)
def allreduce(value: T, op: Ops) -> T:
    """Combine ``value`` across all processes with the reduction ``op``.

    Parameters
    ----------
    value:
        This process's contribution to the reduction.
    op: :class:`~taurex.mpi.Ops`
        Operation to perform; translated with :func:`convert_op`.

    Returns
    -------
    result:
        The result of ``MPI.COMM_WORLD.allreduce``, or ``value`` unchanged
        if mpi4py is not installed.
    """
    try:
        from mpi4py import MPI
    except ImportError:
        return value
    mpi_op = convert_op(op)
    return MPI.COMM_WORLD.allreduce(value, op=mpi_op)
def broadcast(array: T, rank: t.Optional[int] = 0) -> T:
    """Broadcasts array from rank or returns array if mpi is not installed.

    NumPy arrays are sent with the buffer-based ``Comm.Bcast``; any other
    object is sent with the lower-case ``Comm.bcast``. For NumPy arrays,
    every non-root rank must pass an array with the same shape and dtype as
    the root's, because it is used as the template for the receive buffer.

    Parameters
    ----------
    array:
        Array (or other object) to broadcast.
    rank: int, optional
        Rank to broadcast from, default is 0. ``None`` is treated as 0.

    Returns
    -------
    array:
        Broadcasted value, or ``array`` unchanged if mpi4py is not installed.
        On the root rank a NumPy input is copied, so the result never aliases
        the caller's array.
    """
    try:
        from mpi4py import MPI
    except ImportError:
        return array

    # ``rank`` is typed Optional; passing None to ``root=`` would fail.
    root = 0 if rank is None else rank
    comm = MPI.COMM_WORLD

    if not isinstance(array, np.ndarray):
        return comm.bcast(array, root=root)

    # Bcast transfers raw bytes, so the root and receiver buffers must share
    # one memory layout; force C order on both sides instead of inheriting
    # whatever layout (C or Fortran) each rank's input happened to have.
    if get_rank() == root:
        data = np.copy(array, order="C")
    else:
        data = np.zeros_like(array, order="C")
    comm.Bcast(data, root=root)
    return data
@lru_cache(maxsize=1)
def _world_rank() -> int:
    """Return (and cache) this process's rank in ``MPI.COMM_WORLD``."""
    from mpi4py import MPI

    return MPI.COMM_WORLD.Get_rank()


def get_rank(comm: t.Any = None) -> int:
    """Gets rank or returns 0 if mpi is not installed.

    The rank in the default communicator (``MPI.COMM_WORLD``) is cached.
    For an explicitly supplied communicator the rank is queried on every
    call, so communicator objects never have to be hashable (caching on the
    ``comm`` argument itself would require hashing it).

    Parameters
    ----------
    comm: MPI communicator, optional
        Communicator to query; ``MPI.COMM_WORLD`` when omitted or None.

    Returns
    -------
    int:
        Rank of process in ``comm``, or 0 if mpi4py is not installed.
    """
    try:
        from mpi4py import MPI  # noqa: F401
    except ImportError:
        return 0
    # ``is None`` rather than truthiness: a falsy communicator (e.g.
    # COMM_NULL) must not be silently replaced by COMM_WORLD.
    if comm is None:
        return _world_rank()
    return comm.Get_rank()
def barrier(comm: t.Any = None) -> None:
    """Waits for all processes in ``comm`` to reach this call.

    Does nothing if mpi4py is not present.

    Parameters
    ----------
    comm: MPI communicator, optional
        Communicator to synchronise; ``MPI.COMM_WORLD`` when omitted or None.
    """
    try:
        from mpi4py import MPI
    except ImportError:
        return
    # ``is None`` rather than truthiness so a falsy communicator is not
    # silently swapped for COMM_WORLD.
    if comm is None:
        comm = MPI.COMM_WORLD
    comm.Barrier()
def only_master_rank(f) -> t.Callable:
    """Decorate ``f`` so that it only executes on MPI rank 0.

    On any other rank the call is skipped and ``None`` is returned. Without
    mpi4py, :func:`get_rank` reports 0, so ``f`` always runs.
    """

    @wraps(f)
    def _master_only(*args, **kwargs):
        if get_rank() != 0:
            return None
        return f(*args, **kwargs)

    return _master_only
@lru_cache(maxsize=10)
def shared_rank() -> int:
    """Return this process's rank in the shared-memory communicator.

    MPI only: relies on :func:`shared_comm`. The value is cached after the
    first call.
    """
    node_comm = shared_comm()
    return node_comm.Get_rank()
def allocate_as_shared(
    arr: np.ndarray,
    logger: t.Optional[logging.Logger] = None,
    force_shared: t.Optional[bool] = False,
):
    """Converts a numpy array into MPI shared memory.

    This allows things like opacities to be loaded only once per node when
    using MPI. Only activates if mpi4py is installed and when enabled via the
    ``mpi_use_shared`` input::

        [Global]
        mpi_use_shared = True

    or ``force_shared=True``; otherwise does nothing and returns the same
    array back.

    Parameters
    ----------
    arr: numpy array
        Array to convert
    logger: :class:`~taurex.log.logger.Logger`
        Logger object to print outputs
    force_shared: bool
        Force conversion to shared memory

    Returns
    -------
    array:
        If enabled and MPI present, shared memory version of array,
        otherwise the original array

    Raises
    ------
    RuntimeError
        If the shared window's displacement unit does not match
        ``arr.itemsize``.
    """
    try:
        from mpi4py import MPI
    except ImportError:
        return arr

    from taurex.cache import GlobalCache

    if not (GlobalCache()["mpi_use_shared"] or force_shared):
        return arr

    if logger is not None:
        logger.info("Moving to shared memory")

    comm = shared_comm()
    is_owner = comm.Get_rank() == 0

    # Only rank 0 of the node communicator backs the window with memory; the
    # other ranks contribute a zero-sized segment and map rank 0's segment
    # through Shared_query(0). Allocating the full size on every rank would
    # reserve one copy per process, defeating the point of sharing.
    nbytes = arr.size * arr.itemsize if is_owner else 0
    # NOTE(review): ``window`` is not kept or freed here; presumably it must
    # stay alive as long as the returned array is used -- confirm mpi4py does
    # not free the window when this local is garbage-collected.
    window = MPI.Win.Allocate_shared(nbytes, arr.itemsize, comm=comm)
    buf, itemsize = window.Shared_query(0)
    if itemsize != arr.itemsize:
        raise RuntimeError(
            f"Shared memory size {itemsize} != array itemsize {arr.itemsize}"
        )

    shared_array = np.ndarray(buffer=buf, dtype=arr.dtype, shape=arr.shape)
    if is_owner:
        np.copyto(shared_array, arr)
    # No rank may read the buffer before the owner has finished copying.
    comm.Barrier()
    return shared_array