Skip to content

Commit

Permalink
Added comm to mpisubset
Browse files Browse the repository at this point in the history
  • Loading branch information
lsawade committed Jan 25, 2024
1 parent eacd594 commit 82a08ab
Showing 1 changed file with 36 additions and 3 deletions.
39 changes: 36 additions & 3 deletions src/gf3d/mpi_subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,15 @@ class MPISubset(object):

adjacency: np.ndarray | None = None

def __init__(self, subset) -> None:
def __init__(self, subset, comm=None, verbose=True) -> None:
"""Initializes the gfm manager"""

import mpi4py.MPI as MPI
self.comm = MPI.COMM_WORLD
if comm is not None:
self.comm = comm
else:
import mpi4py.MPI as MPI
self.comm = MPI.COMM_WORLD

self.rank = self.comm.Get_rank()
self.size = self.comm.Get_size()

Expand All @@ -91,6 +95,9 @@ def __init__(self, subset) -> None:
self.subset = True
self.headerfile = self.db

# Verbosity
self.verbose = verbose

# Load header
self.load_bcast_header()

Expand All @@ -99,6 +106,9 @@ def load_bcast_header(self):

if self.rank == 0:

if self.verbose:
print("MPISubset: --> reading header", flush=True)

self.header = dict()

with h5py.File(self.headerfile, 'r') as db:
Expand Down Expand Up @@ -167,6 +177,10 @@ def load_bcast_header(self):


# Broadcast all header variables
if self.rank == 0 and self.verbose:
print("MPISubset: --> broadcasting header", flush=True)
t0 = time()

self.header = self.comm.bcast(self.header, root=0)
self.NGLL = self.comm.bcast(self.NGLL, root=0)
self.networks = self.comm.bcast(self.networks, root=0)
Expand Down Expand Up @@ -205,9 +219,13 @@ def load_bcast_header(self):
self.header['ellipticity_spline'] = self.comm.bcast(self.header['ellipticity_spline'], root = 0)
self.header['ellipticity_spline2'] = self.comm.bcast(self.header['ellipticity_spline2'], root = 0)

if self.rank == 0 and self.verbose:
print("MPISubset: --> header broadcasted in", time()-t0, "seconds", flush=True)


def get_seismograms(self, cmt: CMTSOLUTION) -> np.ndarray:

t0 = time()
# Get moment tensor
x_target, y_target, z_target, Mx = source2xyz(
cmt.latitude, cmt.longitude, cmt.depth, M=cmt.tensor,
Expand Down Expand Up @@ -256,10 +274,20 @@ def get_seismograms(self, cmt: CMTSOLUTION) -> np.ndarray:
indeces = np.arange(len(iglobf))

# Hello
if self.verbose:
t0_read = time()
print(f"MPI Subset[{self.rank}]: --> Start reading from HDF5 subset", flush=True)

# Read displacement
with h5py.File(self.headerfile, 'r') as db:
# Get displacement for interpolation
displacement = db['displacement'][:, :, :, iglobf[sglobf], :]

if self.verbose:
t1_read = time()
print(f"MPI Subset[{self.rank}]: --> reading took {t1_read-t0_read} seconds.", flush=True)


# Loaded displacement is in weird order so we have to redo it
displacement = displacement[:, :, :, indeces[rsglob], :]

Expand Down Expand Up @@ -355,6 +383,11 @@ def get_seismograms(self, cmt: CMTSOLUTION) -> np.ndarray:
* phshift[None, None, :]
)[:, :, :self.header['nsteps']]) * self.header['dt']

t1 = time()

if self.verbose:
print(f"MPI Subset[{self.rank}]: --> Seismograms computed in {t1-t0} seconds", flush=True)

return data

def get_stream(self, cmt, data):
Expand Down

0 comments on commit 82a08ab

Please sign in to comment.