
Commit

cleanup
thunil committed Sep 30, 2018
1 parent b4e8009 commit 1b8f776
Showing 4 changed files with 54 additions and 46 deletions.
41 changes: 22 additions & 19 deletions README.md
@@ -1,9 +1,9 @@
# Deep-Flow-Prediction

-In this repo you can find a framework for fluid flow
+_Deep Flow Prediction_ is a framework for fluid flow
(Reynolds-averaged Navier Stokes) predictions with deep learning.
-It contains both data generation, network training, and evaluation
-scripts. Linux is highly recommended, and assumed as OS the following.
+It contains code for data generation, network training, and evaluation.
+Linux is highly recommended, and assumed as the OS in the following.

The accompanying paper can be found here:
<https://arxiv.org>
@@ -17,8 +17,8 @@ You can also check out our TUM lab website with additional physics-based deep le

## Required software

-This script requires PyTorch and numpy for the deep learning part,
-and openfoam and gmsh for data generation and meshing.
+This codebase requires _PyTorch_ and _numpy_ for the deep learning part,
+and _openfoam_ and _gmsh_ for data generation and meshing.
To install these under Linux run, e.g.:
```
sudo apt-get install openfoam5 gmsh
@@ -36,18 +36,21 @@ airfoild database should contain 1498 files afterwards.
## Generate data

Now run `python ./dataGen.py` to generate a first set of 100 airfoils.
-This script executes openfoam and runs gmsh for meshing the airfoil profiles.
+This script executes _openfoam_ and runs _gmsh_ for meshing the airfoil profiles.

Once `dataGen.py` has finished, you should find 100 .npz files in a new
-directory called `train`. You can call this script repeatedly, or adjust
-the `samples` variables to generate more date, but for a first test, 100 samples
-are sufficient.
+directory called `train`. You can call this script repeatedly to generate
+more data, or adjust the `samples` variable to generate more samples with a
+single call. For a first test, 100 samples are sufficient; for higher
+quality models, more than 10k are recommended.

Output files are saved as compressed numpy arrays. The tensor size in each
-sample file is 6x128x128. The first three channels represent the input,
-consisting (in this order) of two fields corresponding to the freestream velocities in x and y
-direction and one field containing a mask of the airfoil geometry as
-a mask. The other three channels represent the target, containing one pressure and two velocity
-fields.
+sample file is 6x128x128 with dimensions: channels, x, y.
+The first three channels represent the input, consisting (in this order) of
+two fields with the freestream velocities in x and y direction, and one field
+containing a mask of the airfoil geometry.
+The last three channels represent the target, containing one pressure and two
+velocity fields.
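
For reference, a minimal sketch of how one of these samples can be inspected (the file name and array key are assumptions; adjust to whatever `dataGen.py` actually wrote):
```
import numpy as np

sample = np.load("train/0.npz")      # hypothetical file name
arr = sample[sample.files[0]]        # tensor of shape (6, 128, 128): channels, x, y
inputs, targets = arr[:3], arr[3:]   # vx, vy, mask  /  pressure, vx, vy
print(arr.shape, inputs.shape, targets.shape)
```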

## Convolutional neural network training
@@ -70,18 +73,18 @@ are located in `../data/test`. Hence, you either have to generate data in a new
`dataGen.py` script from above, or download the test data set via the link below.

Once the test data is in place, execute `python ./runTest.py`. This script can compute accuracy
-evaluations for a range of models, it will automatically evaluate all test samples for all existing model files
-named "modelG",
-"modelGa",
-"modelGb",
-"modelGc", etc.
+evaluations for a range of models; it will automatically evaluate the test samples for all existing model files
+named `modelG`, `modelGa`, `modelGb`, `modelGc`, etc.

The text output will also be written to a file `testout.txt`. In addition, visualized reference data
and corresponding inferred outputs are written to `results_test` as PNGs.

## Further steps

-For further experiments, you can increase the `expo` parameter in `runTrain.py` and `runTest.py` (note, non-integers allowed). For large models you'll need much more data, though, to avoid overfitting.
+For further experiments, you can increase the `expo` parameter in `runTrain.py` and `runTest.py` (non-integer values are allowed). For large models, though, you'll need much more data to avoid overfitting.
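
Since `channelExponent` turns `expo` into a base channel count via `int(2 ** expo + 0.5)` (see `DfpNet.py` below), a quick sketch of how network width scales, including a non-integer value:
```
for expo in (4, 4.5, 5, 6):
    print(expo, int(2 ** expo + 0.5))   # base channels: 16, 23, 32, 64
```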

In addition, the `DfpNet.py` file is worth a look: it contains most of the non-standard code for the RANS flow prediction. E.g., here you can find the U-net setup and data normalization. Hence, this class is a good starting point for experimenting with different architectures.

@@ -91,7 +94,7 @@ This can come in handy for automated runs with varying parameters.
# Data sets

Below you can download a large-scale training data set, and the test data set
-used in the accompanying paper.
+used in the accompanying paper, as well as pre-trained models:

((will be made available here soon))

@@ -102,7 +105,7 @@ used in the accompanying paper.
# Summary

Based on this framework, you should be able to train deep learning models that yield relative errors of 2-3%
-for the RANS data sets. In addition, the network architecture should easily generalize do other types of dense
+for the RANS data sets. In addition, the network architecture should be applicable to other types of dense
PDE solutions.

Let us know if things don't work, or if you find ways to make it work even better :) !
7 changes: 6 additions & 1 deletion data/download_airfoils.sh
@@ -11,8 +11,13 @@ mkdir ./airfoil_database_test

cd ./coord_seligFmt/

-# remove airfoils with text comments, TODO - they should be fixed rather than removed
+# cleanup:
+# remove airfoils with text comments, TODO - they should be fixed, rather than removed
rm ag24.dat ag25.dat ag26.dat ag27.dat nasasc2-0714.dat goe795sm.dat
+# fix some non ascii ones
+sed -i 's/[\d128-\d255]//g' airfoil_database/goe187.dat
+sed -i 's/[\d128-\d255]//g' airfoil_database/goe188.dat
+sed -i 's/[\d128-\d255]//g' airfoil_database/goe235.dat
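
The `\dNNN` patterns are GNU sed decimal character escapes, so these calls delete all non-ASCII bytes in place. A Python equivalent, as a sketch for systems without GNU sed:
```
from pathlib import Path

# strip all bytes outside the ASCII range from the listed airfoil files, in place
for name in ("goe187.dat", "goe188.dat", "goe235.dat"):
    p = Path("airfoil_database") / name
    p.write_bytes(bytes(b for b in p.read_bytes() if b < 128))
```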

# move only selected files to make sure we have the right sets
echo moving...
46 changes: 23 additions & 23 deletions train/DfpNet.py
@@ -18,7 +18,7 @@ def weights_init(m):
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

-def blockUNet(in_c, out_c, name, transposed=False, bn=True, relu=True, size=4, pad=1, dropoutVal=0.):
+def blockUNet(in_c, out_c, name, transposed=False, bn=True, relu=True, size=4, pad=1, dropout=0.):
    block = nn.Sequential()
    if relu:
        block.add_module('%s_relu' % name, nn.ReLU(inplace=True))
@@ -31,32 +31,32 @@ def blockUNet(in_c, out_c, name, transposed=False, bn=True, relu=True, size=4, p
        block.add_module('%s_tconv' % name, nn.Conv2d(in_c, out_c, kernel_size=(size-1), stride=1, padding=pad, bias=True))
    if bn:
        block.add_module('%s_bn' % name, nn.BatchNorm2d(out_c))
-    if dropoutVal>0.:
-        block.add_module('%s_dropout' % name, nn.Dropout2d( dropoutVal, inplace=True))
+    if dropout>0.:
+        block.add_module('%s_dropout' % name, nn.Dropout2d( dropout, inplace=True))
    return block

# generator model
class TurbNetG(nn.Module):
-    def __init__(self, channelExponent=6, drv=0.):
+    def __init__(self, channelExponent=6, dropout=0.):
        super(TurbNetG, self).__init__()
        channels = int(2 ** channelExponent + 0.5)

        self.layer1 = nn.Sequential()
        self.layer1.add_module('layer1_conv', nn.Conv2d(3, channels, 4, 2, 1, bias=True))

-        self.layer2 = blockUNet(channels , channels*2, 'layer2', transposed=False, bn=True, relu=False, dropoutVal=drv )
-        self.layer2x= blockUNet(channels*2, channels*2, 'layer2x',transposed=False, bn=True, relu=False, dropoutVal=drv )
-        self.layer3 = blockUNet(channels*2, channels*4, 'layer3', transposed=False, bn=True, relu=False, dropoutVal=drv )
-        self.layer4 = blockUNet(channels*4, channels*8, 'layer4', transposed=False, bn=True, relu=False, dropoutVal=drv , size=2,pad=0)
-        self.layer5 = blockUNet(channels*8, channels*8, 'layer5', transposed=False, bn=True, relu=False, dropoutVal=drv , size=2,pad=0)
-        self.layer6 = blockUNet(channels*8, channels*8, 'layer6', transposed=False, bn=False, relu=False, dropoutVal=drv , size=2,pad=0)
+        self.layer2 = blockUNet(channels , channels*2, 'layer2', transposed=False, bn=True, relu=False, dropout=dropout )
+        self.layer2b= blockUNet(channels*2, channels*2, 'layer2b',transposed=False, bn=True, relu=False, dropout=dropout )
+        self.layer3 = blockUNet(channels*2, channels*4, 'layer3', transposed=False, bn=True, relu=False, dropout=dropout )
+        self.layer4 = blockUNet(channels*4, channels*8, 'layer4', transposed=False, bn=True, relu=False, dropout=dropout , size=2,pad=0)
+        self.layer5 = blockUNet(channels*8, channels*8, 'layer5', transposed=False, bn=True, relu=False, dropout=dropout , size=2,pad=0)
+        self.layer6 = blockUNet(channels*8, channels*8, 'layer6', transposed=False, bn=False, relu=False, dropout=dropout , size=2,pad=0)

-        self.dlayer6 = blockUNet(channels*8, channels*8, 'dlayer6', transposed=True, bn=True, relu=True, dropoutVal=drv , size=2,pad=0)
-        self.dlayer5 = blockUNet(channels*16,channels*8, 'dlayer5', transposed=True, bn=True, relu=True, dropoutVal=drv , size=2,pad=0)
-        self.dlayer4 = blockUNet(channels*16,channels*4, 'dlayer4', transposed=True, bn=True, relu=True, dropoutVal=drv )
-        self.dlayer3 = blockUNet(channels*8, channels*2, 'dlayer3', transposed=True, bn=True, relu=True, dropoutVal=drv )
-        self.dlayer2x= blockUNet(channels*4, channels*2, 'dlayer2x',transposed=True, bn=True, relu=True, dropoutVal=drv )
-        self.dlayer2 = blockUNet(channels*4, channels , 'dlayer2', transposed=True, bn=True, relu=True, dropoutVal=drv )
+        self.dlayer6 = blockUNet(channels*8, channels*8, 'dlayer6', transposed=True, bn=True, relu=True, dropout=dropout , size=2,pad=0)
+        self.dlayer5 = blockUNet(channels*16,channels*8, 'dlayer5', transposed=True, bn=True, relu=True, dropout=dropout , size=2,pad=0)
+        self.dlayer4 = blockUNet(channels*16,channels*4, 'dlayer4', transposed=True, bn=True, relu=True, dropout=dropout )
+        self.dlayer3 = blockUNet(channels*8, channels*2, 'dlayer3', transposed=True, bn=True, relu=True, dropout=dropout )
+        self.dlayer2b= blockUNet(channels*4, channels*2, 'dlayer2b',transposed=True, bn=True, relu=True, dropout=dropout )
+        self.dlayer2 = blockUNet(channels*4, channels , 'dlayer2', transposed=True, bn=True, relu=True, dropout=dropout )

        self.dlayer1 = nn.Sequential()
        self.dlayer1.add_module('dlayer1_relu', nn.ReLU(inplace=True))
@@ -65,8 +65,8 @@ def __init__(self, channelExponent=6, drv=0.):
    def forward(self, x):
        out1 = self.layer1(x)
        out2 = self.layer2(out1)
-        out2x= self.layer2x(out2)
-        out3 = self.layer3(out2x)
+        out2b= self.layer2b(out2)
+        out3 = self.layer3(out2b)
        out4 = self.layer4(out3)
        out5 = self.layer5(out4)
        out6 = self.layer6(out5)
@@ -77,15 +77,15 @@ def forward(self, x):
        dout4 = self.dlayer4(dout5_out4)
        dout4_out3 = torch.cat([dout4, out3], 1)
        dout3 = self.dlayer3(dout4_out3)
-        dout3_out2x = torch.cat([dout3, out2x], 1)
-        dout2x = self.dlayer2x(dout3_out2x)
-        dout2x_out2 = torch.cat([dout2x, out2], 1)
-        dout2 = self.dlayer2(dout2x_out2)
+        dout3_out2b = torch.cat([dout3, out2b], 1)
+        dout2b = self.dlayer2b(dout3_out2b)
+        dout2b_out2 = torch.cat([dout2b, out2], 1)
+        dout2 = self.dlayer2(dout2b_out2)
        dout2_out1 = torch.cat([dout2, out1], 1)
        dout1 = self.dlayer1(dout2_out1)
        return dout1
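
The renamed `2b` layers keep the U-Net skip connections intact: each decoder block receives its predecessor's output concatenated with the matching encoder output, which is why the `dlayer` blocks take doubled input channel counts. A minimal smoke test of the generator (a sketch; shapes follow the 3-channel, 128x128 layout described in the README above):
```
import torch

net = TurbNetG(channelExponent=5, dropout=0.)
x = torch.randn(1, 3, 128, 128)   # freestream vx, freestream vy, airfoil mask
y = net(x)
print(y.shape)                    # expected: (1, 3, 128, 128) -> pressure, vx, vy
```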

-# discriminator (only for adversarial training)
+# discriminator (only for adversarial training, currently unused)
class TurbNetD(nn.Module):
    def __init__(self, in_channels1, in_channels2, ch=64):
        super(TurbNetD, self).__init__()
6 changes: 3 additions & 3 deletions train/runTrain.py
@@ -49,13 +49,13 @@
print("Output prefix: {}".format(prefix))

autoIter = False
-dropoutVal = 0.
+dropout = 0.
doLoad = ""

print("LR: {}".format(lrG))
print("LR decay: {}".format(decayLr))
print("Iterations: {}".format(iterations))
print("Dropout: {}".format(dropoutVal))
print("Dropout: {}".format(dropout))

##########################

@@ -77,7 +77,7 @@

# setup training
epochs = int(iterations/len(trainLoader) + 0.5)
-netG = TurbNetG(channelExponent=expo, drv=dropoutVal)
+netG = TurbNetG(channelExponent=expo, dropout=dropout)
print(netG) # print full net
model_parameters = filter(lambda p: p.requires_grad, netG.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
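
The two lines above count the trainable weights of `netG`; a more compact equivalent (a sketch, reusing the `netG` defined above):
```
params = sum(p.numel() for p in netG.parameters() if p.requires_grad)
```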
