
[NeurIPS 2021] Galerkin Transformer: a linear attention without softmax for Partial Differential Equations


[NeurIPS 2021] Galerkin Transformer: linear attention without softmax

License: MIT | Python 3.8 | PyTorch 1.9

Summary

Introduction

The new attention operator for the encoder is simply Q(K^T V) (Galerkin-type, linear in the sequence length), or the quadratic-complexity (QK^T)V (Fourier-type).

  • No softmax, or any approximation thereof, at all.
  • The two latent representations that are multiplied together get layer normalization, similar to the Gram-Schmidt process, where one divides by the squared norm of the basis. In the Fourier-type attention (every position attends to every other position), Q and K are layer-normalized; in the Galerkin-type attention (every basis attends to every other basis), K and V are. No layer normalization is applied afterward.
  • Some other components are tweaked according to our Hilbertian interpretation of attention.
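The two attention types and where their normalizations sit can be sketched in NumPy as follows. This is a single-head sketch with no learned projections, and it assumes a plain row-wise layer norm without learnable scale/shift and a 1/n scaling by the sequence length; it illustrates the formulas above and is not the package's implementation.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # Normalize each row (position) over the feature dimension; no learnable scale/shift.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def fourier_attention(q, k, v):
    # (norm(Q) norm(K)^T) V / n: builds an n-by-n matrix, O(n^2 d) work.
    n = q.shape[0]
    return (layer_norm(q) @ layer_norm(k).T) @ v / n


def galerkin_attention(q, k, v):
    # Q (norm(K)^T norm(V)) / n: only a d-by-d matrix is formed, O(n d^2) work.
    n = q.shape[0]
    return q @ (layer_norm(k).T @ layer_norm(v)) / n


rng = np.random.default_rng(0)
n, d = 1024, 32  # sequence length (grid points) and latent dimension
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))

z_fourier = fourier_attention(q, k, v)    # shape (n, d)
z_galerkin = galerkin_attention(q, k, v)  # shape (n, d)

# Without softmax, matrix products are associative, which is what allows
# (QK^T)V to be regrouped as Q(K^T V); the two types differ in which factors are normalized.
same = np.allclose((q @ k.T) @ v, q @ (k.T @ v))
```

The Galerkin-type grouping never materializes an n-by-n matrix, which is why it scales linearly with the number of grid points n when d is fixed.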

For the full operator learner, the feature extractor is a simple linear layer or an interpolation-based CNN. The decoder is a real-parameter re-implementation of the spectral convolution from the Fourier Neural Operator (FNO) of Li et al. (2020), the best operator learner to date, if the target is smooth, or just a pointwise FFN otherwise. The resulting network is very effective at learning PDE-related operators (energy decay, inverse coefficient identification).
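An FNO-style spectral convolution transforms to Fourier space, keeps a fixed number of low modes, mixes channels with learned per-mode weights, and transforms back. The 1D NumPy sketch below assumes that "real parameter" means the complex weights are stored as two real arrays (real and imaginary parts); `spectral_conv1d` and its arguments are hypothetical names, not the package's API.

```python
import numpy as np


def spectral_conv1d(u, w_real, w_imag, modes):
    """FNO-style 1D spectral convolution (hypothetical helper, not the repo's API).

    u:      (n, c_in) real samples on a uniform periodic grid
    w_real: (modes, c_in, c_out) real parts of the per-mode weights
    w_imag: (modes, c_in, c_out) imaginary parts, stored as a second real array
    """
    n = u.shape[0]
    assert modes <= n // 2 + 1, "cannot keep more modes than rfft returns"
    u_hat = np.fft.rfft(u, axis=0)  # (n // 2 + 1, c_in) Fourier coefficients
    w = w_real + 1j * w_imag        # complex weights assembled from real parameters
    out_hat = np.zeros((u_hat.shape[0], w.shape[2]), dtype=complex)
    # Keep only the lowest `modes` frequencies and mix channels mode by mode;
    # all higher frequencies are set to zero.
    out_hat[:modes] = np.einsum("mi,mio->mo", u_hat[:modes], w)
    return np.fft.irfft(out_hat, n=n, axis=0)  # back to the grid, shape (n, c_out)


n = 64
x = np.arange(n) / n
modes = 8
identity_real = np.ones((modes, 1, 1))  # identity weights on the kept modes
zeros_imag = np.zeros((modes, 1, 1))

smooth = np.sin(2 * np.pi * 3 * x)[:, None]  # frequency 3 < modes: passed through
rough = np.cos(2 * np.pi * 20 * x)[:, None]  # frequency 20 >= modes: filtered out
kept = spectral_conv1d(smooth, identity_real, zeros_imag, modes)
dropped = spectral_conv1d(rough, identity_real, zeros_imag, modes)
```

Because all modes above the cutoff are discarded, such a layer suits smooth targets, which matches the choice of a pointwise FFN decoder when the target is not smooth.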

Hilbertian framework to analyze a linear attention variant

Even though everyone is Transformer'ing, the mathematics behind the attention mechanism is not well understood. We show that the Galerkin-type attention (a linear attention without softmax) has an approximation capacity on par with a Petrov-Galerkin projection under a Hilbertian setup. We use a method commonly known as the "mixed method" in the finite element analysis community, where it is used to solve fluid and electromagnetics problems. Unlike in finite element methods, in an attention-based operator learner the approximation is not tied to the discretization, in that:

  1. The latent representation is interpreted "column-wise" (each column represents a basis function), as opposed to the conventional "row-wise"/"position-wise"/"word-wise" interpretation of attention in NLP.
  2. The dimensions of the approximation spaces are not tied to the geometry as in traditional finite element analysis (or finite difference, spectral methods, radial basis functions, etc.).
  3. The approximation spaces are dynamically updated by the nonlinear universal approximator due to the presence of the positional encodings, which determine the topology of the approximation space.
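The column-wise reading can be checked numerically: in Q(K^T V)/n, the j-th output column is a linear combination of the query columns, weighted by discrete inner products between key and value columns. The sketch below omits the layer normalization and assumes the columns are samples on a uniform grid of [0, 1], so that a dot product divided by n approximates an L^2 inner product.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d = 128, 8
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))

# Matrix form of the Galerkin-type attention (normalization omitted for clarity).
z = q @ (k.T @ v) / n

# Column-wise form: output column j is a combination of the query columns q_l,
# with coefficients given by discrete inner products between key and value columns.
z_cols = np.zeros_like(z)
for j in range(d):
    for l in range(d):
        coeff = k[:, l] @ v[:, j] / n  # approximates the L^2 inner product (k_l, v_j)
        z_cols[:, j] += coeff * q[:, l]

match = np.allclose(z, z_cols)
```

Each of the d columns plays the role of a basis function sampled at n points, which is why the number of basis functions d is independent of the mesh size n.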

Interpretation of the attention mechanism

  1. Approximation capacity: an incoming "query" is a function in some Hilbert space that asks for its best representation in the latent space. To deliver the best approximator in the "value" space (trial function space), the "key" space (test function space) has to be big enough so that for every value there is a key to unlock it.
  2. Translation capacity: the attention is capable of finding latent representations that minimize a functional norm measuring the distance between the input (query) and the target (values). An ideal operator learner learns some nonlinear perturbations of the subspaces on which the input (query) and the target (values) are "close", and this closeness is measured by how they respond to a dynamically changing set of test basis functions (keys).
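The "a key to unlock every value" condition mirrors the textbook Petrov-Galerkin projection, sketched below in NumPy as an analogy only (it is not the paper's proof or the package's code). Given trial functions (values) and test functions (keys) as columns, one seeks u in the trial span whose residual f - u is orthogonal to every test function; this is uniquely solvable only when test^T trial is nonsingular. `petrov_galerkin` is a hypothetical helper.

```python
import numpy as np


def petrov_galerkin(f, trial, test):
    """Return u = trial @ c whose residual f - u is orthogonal to every test column.

    The coefficients solve (test^T trial) c = test^T f, which has a unique solution
    only if test^T trial is nonsingular, i.e. every trial direction is "seen" by the tests.
    """
    a = test.T @ trial
    b = test.T @ f
    c = np.linalg.solve(a, b)  # raises numpy.linalg.LinAlgError if a is singular
    return trial @ c


rng = np.random.default_rng(2)
n, d = 200, 5
trial = rng.standard_normal((n, d))  # "values": trial basis, one function per column
keys = rng.standard_normal((n, d))   # "keys": test basis
f = rng.standard_normal(n)           # the function to be represented

u = petrov_galerkin(f, trial, keys)
residual_seen_by_keys = keys.T @ (f - u)  # ~0: the keys cannot tell u apart from f

# A key space that is too small: with d - 1 keys for d values, the system is
# (d - 1)-by-d, so some combination of values is invisible to every key.
rank = np.linalg.matrix_rank(keys[:, : d - 1].T @ trial)  # at most d - 1
```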

For details, please refer to: https://arxiv.org/abs/2105.14995

@inproceedings{Cao2021transformer,
  author        = {Shuhao Cao},
  title         = {Choose a Transformer: {F}ourier or {G}alerkin},
  booktitle     = {Advances in Neural Information Processing Systems (NeurIPS 2021)},
  volume        = {34},
  year          = {2021},
  eprint        = {arXiv: 2105.14995},
  primaryclass  = {cs.CL},
  url           = {https://openreview.net/forum?id=ssohLcmn4-r},
}

Install

Requirements

(Updated Jun 17, 2021) The PyTorch requirement is updated to 1.9.0, because the batch_first argument introduced in that version conforms with our pipeline.

This package can be cloned locally and installed as follows:

git clone https://github.com/scaomath/galerkin-transformer.git
cd galerkin-transformer
python3 -m pip install -r requirements.txt

The requirements are:

seaborn==0.11.1
torchinfo==0.0.8
numpy==1.20.2
torch==1.9.0
plotly==4.14.3
scipy==1.6.2
psutil==5.8.0
matplotlib==3.3.4
tqdm==4.56.0
PyYAML==5.4.1
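To see whether an environment matches the pinned versions listed above, a small stdlib-only check can be run. This is a hypothetical convenience sketch, not a script shipped with the repo; it uses `importlib.metadata`, which requires Python 3.8+.

```python
from importlib.metadata import PackageNotFoundError, version  # Python 3.8+


def parse_pins(text):
    """Parse 'name==version' lines into a dict, skipping blanks and comments."""
    pins = {}
    for line in text.splitlines():
        line = line.split("#", 1)[0].strip()
        if "==" not in line:
            continue
        name, pinned = line.split("==", 1)
        pins[name.strip()] = pinned.strip()
    return pins


def check_pins(pins):
    """Return {name: (pinned, installed)} for packages that are missing or differ from the pin.

    Local build tags such as the '+cu111' in '1.9.0+cu111' are ignored when comparing.
    """
    mismatches = {}
    for name, pinned in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None
        if installed is None or installed.split("+", 1)[0] != pinned:
            mismatches[name] = (pinned, installed)
    return mismatches
```

For example, `check_pins(parse_pins(open("requirements.txt").read()))` returns an empty dict when every pinned package is installed at its pinned version.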

To use interactive mode, please also install:

jupyterthemes==0.20.0
ipython==7.23.1

Installing using pip

This package can be installed using pip.

python3 -m pip install galerkin-transformer

Example usage of the simple Fourier/Galerkin Transformer encoder layers:

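import copy

import torch
from torch import nn

from galerkin_transformer.model import FourierTransformerEncoderLayer

# One encoder layer; attention_type selects the Galerkin-type (linear) attention here
encoder_layer = FourierTransformerEncoderLayer(
                 d_model=128,
                 pos_dim=1,
                 n_head=4,
                 dim_feedforward=512,
                 attention_type='galerkin',
                 layer_norm=False,
                 attn_norm=True,
                 norm_type='layer',
                 dropout=0.05)
# Stack 6 independent copies of the layer
encoder_layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(6)])
x = torch.randn(8, 8192, 128)  # embedding: (batch, sequence length, d_model)
# Euclidean coordinates as floats, shape (8192, 1)
pos = torch.arange(0, 8192, dtype=torch.float).unsqueeze(-1)
pos = pos.repeat(8, 1, 1)  # one copy per batch element: (8, 8192, 1)
for layer in encoder_layers:
    x = layer(x, pos)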

Data

The data is courtesy of Zongyi Li (Caltech) under the MIT license. Download the following data from here:

burgers_data_R10.mat
piececonst_r421_N1024_smooth1.mat
piececonst_r421_N1024_smooth2.mat

The repo sets a quasi-environment variable $DATA_PATH in utils_ft.py. If you have a global system environment variable named DATA_PATH, please put the data in that folder; otherwise, unzip the Burgers and Darcy flow problem files into the ./data folder.
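The lookup rule described above (use DATA_PATH if it is set, otherwise ./data) amounts to the following. This hypothetical helper only mirrors the described behavior; it is not the code in utils_ft.py.

```python
import os


def resolve_data_path(default=os.path.join(".", "data")):
    """Hypothetical helper: use $DATA_PATH if it is set, otherwise the local ./data folder."""
    return os.environ.get("DATA_PATH", default)


def data_file(name):
    """Full path of one downloaded file, e.g. data_file("burgers_data_R10.mat")."""
    return os.path.join(resolve_data_path(), name)
```

A resulting path could then be passed to `scipy.io.loadmat`, which returns a dict keyed by the MATLAB variable names stored in the file. If a file turns out to be in MATLAB v7.3 (HDF5) format, `loadmat` cannot read it and an HDF5 reader such as h5py would be needed instead.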

Examples

All examples learn PDE-related operators. The settings can be found in config.yml. To fully reproduce our results, please refer to the training scripts in the Examples folder for all possible arguments.

The memory and speed profiling scripts, which use autograd.profiler, can also be found in the Examples folder.

Evaluation notebooks

Please download the pretrained models' .pt files from Releases and put them in the ./models folder.

License

This software is distributed under the MIT license, which roughly means that you can use it however and for whatever reason you want. All the information regarding support, copyright, and the license can be found in the LICENSE file.

Acknowledgement

The hardware to perform this work was provided by Andromeda Saving Fund. The first author was supported in part by the National Science Foundation under grants DMS-1913080 and DMS-2136075. No additional revenues are related to this work. We would like to thank the anonymous reviewers and the area chair of NeurIPS 2021 for their suggestions on improving this paper. We would like to thank Dr. Long Chen (Univ of California Irvine) for the inspiration for and encouragement on the initial conception of this paper, for much constructive advice on revising it, and, not least, for his persistent dedication to making publicly available tutorials on writing beautiful vectorized code. We would like to thank Dr. Ari Stern (Washington Univ in St. Louis) for the help with the relocation during the COVID-19 pandemic. We would like to thank Dr. Likai Chen (Washington Univ in St. Louis) for the invitation to the Stats and Data Sci seminar at WashU that resulted in the reboot of this study. We would like to thank Dr. Ruchi Guo (Univ of California Irvine) and Dr. Yuanzhe Xi (Emory) for the invaluable feedback on the choice of the numerical experiments. We would like to thank the Kaggle community, including but not limited to Jean-François Puget (Uncle CPMP@Kaggle) for sharing a simple Graph Transformer in Tensorflow, Murakami Akira (mrkmakr@Kaggle) for sharing a Graph Transformer with a CNN feature extractor in Tensorflow, and Cher Keng Heng (hengck23@Kaggle) for sharing a Graph Transformer in PyTorch. We would like to thank daslab@Stanford, OpenVaccine, and Eterna for hosting the COVID-19 mRNA Vaccine competition and Deng Lab (Univ of Georgia) for collaborating in this competition. We would like to thank CHAMPS (Chemistry and Mathematics in Phase Space) for hosting the J-coupling quantum chemistry competition and Corey Levinson (returnofsputnik@Kaggle, Eligo Energy, LLC) for collaborating in this competition.
We would like to thank Zongyi Li (Caltech) for sharing some early dev code in the updated PyTorch torch.fft interface. We would like to thank Ziteng Pang (Univ of Michigan) and Tianyang Lin (Fudan Univ) for updating us on various references on Transformers. We would like to thank Joel Schlosser for incorporating our change into the PyTorch transformer submodule to simplify our testing pipeline. We are grateful to the PyTorch community for selflessly sharing code, including Phil Wang (lucidrains@github) and the Harvard NLP group Klein et al. (2017). We would like to thank the chebfun team Driscoll et al. (2014) for integrating powerful tools into a simple interface to solve PDEs. We would like to thank Dr. Yannic Kilcher and Dr. Hung-yi Lee (National Taiwan Univ) for frequently covering the newest research on Transformers in video format. We would also like to thank the Python community (Van Rossum and Drake (2009); Oliphant (2007)) for sharing and developing the tools that enabled this work, including PyTorch Paszke et al. (2017), NumPy Harris et al. (2020), SciPy Virtanen et al. (2020), Seaborn Waskom (2021), Plotly Inc. (2015), Matplotlib Hunter (2007), and the Python team for Visual Studio Code. We would like to thank draw.io for providing an easy and powerful interface for producing vector format diagrams. For details, please refer to the documentation of every function that is not built from the ground up in our open-source software library.
