Refactor swift_enhance script
brenjohn committed Jan 23, 2025
1 parent aea5174 commit 19c9d25
Showing 5 changed files with 151 additions and 118 deletions.
18 changes: 18 additions & 0 deletions dmsr/dmsr_gan/dmsr_generator.py
@@ -87,6 +87,24 @@ def build_generator_components(self):

self.output_size = N - 2 * self.crop_size
self.noise_shapes = noise_shapes


def compute_input_padding(self):
"""Computes the sizes of the input inner region and padding.
The low resolution input to the generator is thought of as being made
up of two regions: an inner region to be upscaled and an outer region
of padding. The output of the generator is thought of as an upscaled
version of the inner region. This method computes the sizes of these
regions.
"""
if self.output_size % self.scale_factor != 0:
print('WARNING: inner region of generator input not an integer')
self.inner_region = self.output_size // self.scale_factor

if (self.grid_size - self.inner_region) % 2 != 0:
print('WARNING: padding of generator input not an integer')
self.padding = (self.grid_size - self.inner_region) // 2


def forward(self, x, z):
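
For intuition, a minimal sketch of the arithmetic in compute_input_padding, with illustrative sizes (chosen to match the cut_size = 16 and pad = 2 values used in the old swift_enhance.py script below; the real values come from the trained generator's architecture):

    # Illustrative sizes only, not from a real trained model.
    grid_size = 20      # low-resolution input grid per patch
    output_size = 32    # generator output grid after cropping
    scale_factor = 2    # upscaling factor

    # Inner region: the part of the input the output actually upscales.
    inner_region = output_size // scale_factor       # 32 // 2 = 16
    assert output_size % scale_factor == 0

    # Padding: the shell of context cells around the inner region.
    padding = (grid_size - inner_region) // 2        # (20 - 16) // 2 = 2
    assert (grid_size - inner_region) % 2 == 0
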
2 changes: 1 addition & 1 deletion dmsr/field_operations/resize.py
@@ -40,10 +40,10 @@ def cut_field(fields, cut_size, stride=0, pad=0):
field and n is the grid size of each subfield (i.e. cut_size + 2 * pad).
"""
grid_size = fields.shape[-1]
cuts = []
if not stride:
stride = cut_size

cuts = []
for i in range(0, grid_size, stride):
for j in range(0, grid_size, stride):
for k in range(0, grid_size, stride):
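
A hedged usage sketch of the fixed cut_field (the signature is from the hunk above; the field sizes are illustrative and the return layout is assumed from the docstring):

    import numpy as np
    from dmsr.field_operations.resize import cut_field

    # A toy displacement field: batch of 1, 3 components, 32^3 grid.
    fields = np.zeros((1, 3, 32, 32, 32))

    # Cut into 16^3 sub-fields with 2 cells of padding on each face.
    patches = cut_field(fields, 16, stride=16, pad=2)

    # Per the docstring, each sub-field spans cut_size + 2 * pad = 20 cells,
    # and a 32^3 field cut with stride 16 yields 2 * 2 * 2 = 8 patches.
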
128 changes: 12 additions & 116 deletions scripts/swift_upscaling/swift_enhance.py
@@ -4,137 +4,33 @@
Created on Fri Oct 18 14:31:25 2024
@author: john
This script uses a specified dmsr generator to enhance the dark-matter data in
a low-resolution swift snapshot.
"""

import sys
sys.path.append("..")
sys.path.append("../..")

import os
import time
import torch
import shutil
import h5py as h5
import numpy as np

from dmsr.swift_processing import get_displacement_field, get_positions
from dmsr.field_operations.resize import cut_field, stitch_fields
from swift_tools.enhance import enhance

# Check if CUDA is available and set the device
gpu_id = 0
# device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
device = "cpu"
print(f"Using device: {device}")
ti = time.time()


#%% Load the generator model
dmsr_model_dir = './dmsr_model/'
# Load the generator model
dmsr_model_dir = './trained_model_levels/level_2/current_model/'
generator = torch.load(dmsr_model_dir + 'generator.pth').to(device)

input_grid_size = generator.grid_size
scale_factor = generator.scale_factor


#%%
# Specify paths to low-resolution snapshot and where to save enhanced snapshot.
data_dir = './swift_snapshots/'
lr_snapshot = data_dir + '064/snap_0002.hdf5'
sr_snapshot = lr_snapshot.replace('.hdf5', '_sr.hdf5')

if os.path.exists(sr_snapshot):
os.remove(sr_snapshot)
shutil.copy(lr_snapshot, sr_snapshot)

# with h5.File(sr_snapshot, 'a') as sr_file:
sr_file = h5.File(sr_snapshot, 'a')

dm_data = sr_file['DMParticles']


# Update particle mass
old_mass = np.asarray(dm_data['Masses'])
new_mass = old_mass / scale_factor**3
new_mass = np.tile(new_mass, scale_factor**3)

del dm_data['Masses']
dm_data.create_dataset('Masses', data=new_mass)


# Update particle velocities
new_velocities = np.zeros_like(dm_data['Velocities'])
new_velocities = np.tile(new_velocities, (scale_factor**3, 1))

del dm_data['Velocities']
dm_data.create_dataset('Velocities', data=new_velocities)


# Update potentials
new_potentials = np.zeros_like(dm_data['Potentials'])
new_potentials = np.tile(new_potentials, scale_factor**3)

del dm_data['Potentials']
dm_data.create_dataset('Potentials', data=new_potentials)


# Update softenings
# TODO: Check correctness of this update rule
old_soft = np.asarray(dm_data['Softenings'])
new_soft = old_soft / scale_factor
new_soft = np.tile(new_soft, scale_factor**3)

del dm_data['Softenings']
dm_data.create_dataset('Softenings', data=new_soft)

# TODO: some code in swift_processing could probably be reused here with some
# refactoring.
# Update particle coordinates and IDs
grid_size = sr_file['ICs_parameters'].attrs['Grid Resolution']
box_size = sr_file['Header'].attrs['BoxSize'][0]
ids = np.asarray(dm_data['ParticleIDs'])
positions = np.asarray(dm_data['Coordinates'])
positions = positions.transpose()

displacements = get_displacement_field(positions, ids, box_size, grid_size)

# TODO: move the padding attribute from the dmsr gan to the generator.
# TODO: refactor cut_field to have the option to return patches that cover the
# given field. At the moment, if cut_size doesn't evenly divide the size of
# the field, some patches near the boundary are missing.
# TODO: a function for recombining patches into a full field would be useful.
cut_size = 16
stride = 16
pad = 2
field_patches = cut_field(displacements[None, ...], cut_size, stride, pad)


#%%
z = generator.sample_latent_space(1, device)

sr_patches = []
crop = 2 # TODO: this parameter should probably be a generator attribute

for patch in field_patches:
patch = torch.from_numpy(patch).to(torch.float)
sr_patch = generator(patch[None, ...], z)
# sr_patch = sr_patch[:, :, crop:-crop, crop:-crop, crop:-crop].detach()
sr_patch = sr_patch.detach()
sr_patches.append(sr_patch.numpy())


#%%
sr_grid_size = scale_factor * grid_size
displacement_field = stitch_fields(sr_patches, 4)
sr_positions = get_positions(displacement_field, box_size, sr_grid_size)
sr_positions = sr_positions.transpose()
sr_ids = np.arange(sr_grid_size**3)


#%%
del dm_data['Coordinates']
dm_data.create_dataset('Coordinates', data=sr_positions)

del dm_data['ParticleIDs']
dm_data.create_dataset('ParticleIDs', data=sr_ids)

sr_file['ICs_parameters'].attrs['Grid Resolution'] = sr_grid_size
sr_snapshot = lr_snapshot.replace('.hdf5', '_sr_level_2_tmp.hdf5')

sr_file.close()
# Enhance the low-resolution snapshot
enhance(lr_snapshot, sr_snapshot, generator, device)
print(f'Upscaling took {time.time() - ti}')
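
Read together, the additions above reduce the script to roughly the following sketch (assembled from lines visible in the diff; exactly which of the old imports survive is an assumption):

    import sys
    sys.path.append("..")
    sys.path.append("../..")

    import time
    import torch

    from swift_tools.enhance import enhance

    device = "cpu"
    print(f"Using device: {device}")
    ti = time.time()

    # Load the generator model
    dmsr_model_dir = './trained_model_levels/level_2/current_model/'
    generator = torch.load(dmsr_model_dir + 'generator.pth').to(device)

    # Paths to the low-resolution snapshot and the enhanced output.
    data_dir = './swift_snapshots/'
    lr_snapshot = data_dir + '064/snap_0002.hdf5'
    sr_snapshot = lr_snapshot.replace('.hdf5', '_sr_level_2_tmp.hdf5')

    # Enhance the low-resolution snapshot
    enhance(lr_snapshot, sr_snapshot, generator, device)
    print(f'Upscaling took {time.time() - ti}')
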
2 changes: 1 addition & 1 deletion swift_tools/data.py
@@ -12,7 +12,7 @@
import h5py as h5
import numpy as np

from .postions import get_displacement_field
from .positions import get_displacement_field

# TODO: in the interest of reusability, this should read a single snapshot.
def read_snapshot(snapshots):
119 changes: 119 additions & 0 deletions swift_tools/enhance.py
@@ -0,0 +1,119 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 23 11:50:49 2025
@author: brennan
"""

import os
import torch
import shutil
import h5py as h5
import numpy as np

from .positions import get_displacement_field, get_positions
from dmsr.field_operations.resize import cut_field, stitch_fields


def enhance(lr_snapshot, sr_snapshot, generator, device):
"""
Use the given generator to enhance the `lr_snapshot` and save the result
in `sr_snapshot`.
"""

if os.path.exists(sr_snapshot):
os.remove(sr_snapshot)
shutil.copy(lr_snapshot, sr_snapshot)
scale_factor = generator.scale_factor

with h5.File(sr_snapshot, 'a') as sr_file:
dm_data = sr_file['DMParticles']

update_particle_mass(dm_data, scale_factor)
update_particle_velocities(dm_data, scale_factor)
update_potentials(dm_data, scale_factor)
update_softenings(dm_data, scale_factor)
update_particle_data(sr_file, generator, device)

grid_size = sr_file['ICs_parameters'].attrs['Grid Resolution']
sr_grid_size = scale_factor * grid_size
sr_file['ICs_parameters'].attrs['Grid Resolution'] = sr_grid_size


def update_particle_data(file, generator, device):
"""
Use the given generator to upscale the particle data in the given file.
"""
dm_data = file['DMParticles']
grid_size = file['ICs_parameters'].attrs['Grid Resolution']
box_size = file['Header'].attrs['BoxSize'][0]
ids = np.asarray(dm_data['ParticleIDs'])
positions = np.asarray(dm_data['Coordinates'])
positions = positions.transpose()

generator.compute_input_padding()
cut_size = generator.inner_region
stride = cut_size
pad = generator.padding
z = generator.sample_latent_space(1, device)

displacements = get_displacement_field(positions, ids, box_size, grid_size)
field_patches = cut_field(displacements[None, ...], cut_size, stride, pad)

sr_patches = []
for patch in field_patches:
patch = torch.from_numpy(patch).to(torch.float)
sr_patch = generator(patch[None, ...], z)
sr_patch = sr_patch.detach()
sr_patches.append(sr_patch.numpy())

scale_factor = generator.scale_factor
sr_grid_size = scale_factor * grid_size
displacement_field = stitch_fields(sr_patches, 4)
sr_positions = get_positions(displacement_field, box_size, sr_grid_size)
sr_positions = sr_positions.transpose()
sr_ids = np.arange(sr_grid_size**3)

del dm_data['Coordinates']
dm_data.create_dataset('Coordinates', data=sr_positions)
del dm_data['ParticleIDs']
dm_data.create_dataset('ParticleIDs', data=sr_ids)


def update_particle_mass(dm_data, scale_factor):
"""Reduce the particle mass appropriately based on the scale factor.
"""
old_mass = np.asarray(dm_data['Masses'])
new_mass = old_mass / scale_factor**3
new_mass = np.tile(new_mass, scale_factor**3)
del dm_data['Masses']
dm_data.create_dataset('Masses', data=new_mass)


def update_particle_velocities(dm_data, scale_factor):
"""Replaces velocity data with zeros.
"""
new_velocities = np.zeros_like(dm_data['Velocities'])
new_velocities = np.tile(new_velocities, (scale_factor**3, 1))
del dm_data['Velocities']
dm_data.create_dataset('Velocities', data=new_velocities)


def update_potentials(dm_data, scale_factor):
"""Replaces potential data with zeros.
"""
new_potentials = np.zeros_like(dm_data['Potentials'])
new_potentials = np.tile(new_potentials, scale_factor**3)
del dm_data['Potentials']
dm_data.create_dataset('Potentials', data=new_potentials)


def update_softenings(dm_data, scale_factor):
"""Reduce the softening length appropriately based on the scale factor.
"""
old_soft = np.asarray(dm_data['Softenings'])
new_soft = old_soft / scale_factor
new_soft = np.tile(new_soft, scale_factor**3)
del dm_data['Softenings']
dm_data.create_dataset('Softenings', data=new_soft)
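
For intuition on the two scaling rules above: each low-resolution particle is replaced by scale_factor**3 high-resolution particles, so dividing the mass by scale_factor**3 conserves total mass, while the softening shrinks with the mean interparticle spacing, a factor of scale_factor (a rule the old script's TODO flags as still needing verification). A minimal check with made-up numbers:

    import numpy as np

    scale_factor = 2
    old_mass = np.array([8.0, 8.0])    # two low-resolution particles

    new_mass = np.tile(old_mass / scale_factor**3, scale_factor**3)
    assert new_mass.sum() == old_mass.sum()                  # total mass conserved
    assert len(new_mass) == len(old_mass) * scale_factor**3  # 2 -> 16 particles

    old_soft = np.array([1.0, 1.0])
    new_soft = np.tile(old_soft / scale_factor, scale_factor**3)
    assert np.all(new_soft == 0.5)     # softening shrinks with spacing
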
