Skip to content

Commit

Permalink
添加部分注释
Browse files Browse the repository at this point in the history
  • Loading branch information
username committed Nov 12, 2024
1 parent c10c8ea commit 62c6176
Show file tree
Hide file tree
Showing 7 changed files with 1,546 additions and 37 deletions.
216 changes: 216 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# File created using '.gitignore Generator' for Visual Studio Code: https://bit.ly/vscode-gig
# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,linux,python
# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,linux,python

### Linux ###
*~
results
checkpoints

# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*

# KDE directory preferences
.directory

# Linux trash folder which might appear on any partition or disk
.Trash-*

# .nfs files are created when an open file is removed but is still being accessed
.nfs*

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
# *.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml

# ruff
.ruff_cache/

# LSP config files
pyrightconfig.json

### VisualStudioCode ###
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets

# Local History for Visual Studio Code
.history/

# Built Visual Studio Code Extensions
*.vsix

### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide

# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,linux,python

# Custom rules (everything added below won't be overriden by 'Generate .gitignore File' if you use 'Update' option)

9 changes: 6 additions & 3 deletions cross_models/attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,20 @@ def __init__(self, seg_num, factor, d_model, n_heads, d_ff = None, dropout=0.1):
nn.GELU(),
nn.Linear(d_ff, d_model))

def forward(self, x):
def forward(self, x): # (B, C, patch_dim, d_model)
#Cross Time Stage: Directly apply MSA to each dimension
batch = x.shape[0]
time_in = rearrange(x, 'b ts_d seg_num d_model -> (b ts_d) seg_num d_model')
# 对不同的patch进行自注意力
time_in = rearrange(x, 'b ts_d seg_num d_model -> (b ts_d) seg_num d_model') # (B * C, patch_dim, d_model)
time_enc = self.time_attention(
time_in, time_in, time_in
)

# 进行一系列dropout和正则化
dim_in = time_in + self.dropout(time_enc)
dim_in = self.norm1(dim_in)
dim_in = dim_in + self.dropout(self.MLP1(dim_in))
dim_in = self.norm2(dim_in)
dim_in = self.norm2(dim_in) # (B * C, patch_dim, d_model)

#Cross Dimension Stage: use a small set of learnable vectors to aggregate and distribute messages to build the D-to-D connection
dim_send = rearrange(dim_in, '(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model', b = batch)
Expand Down
6 changes: 3 additions & 3 deletions cross_models/cross_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ def __init__(self, seg_len, d_model):

def forward(self, x):
batch, ts_len, ts_dim = x.shape

# 分段 (32, 168, 7)->(6272, 6)
x_segment = rearrange(x, 'b (seg_num seg_len) d -> (b d seg_num) seg_len', seg_len = self.seg_len)
x_embed = self.linear(x_segment)
x_embed = rearrange(x_embed, '(b d seg_num) d_model -> b d seg_num d_model', b = batch, d = ts_dim)
x_embed = self.linear(x_segment) #(6272, d_model)
x_embed = rearrange(x_embed, '(b d seg_num) d_model -> b d seg_num d_model', b = batch, d = ts_dim) # (32, 7, 28, 256)

return x_embed
4 changes: 2 additions & 2 deletions cross_models/cross_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(self, win_size, d_model, n_heads, d_ff, depth, dropout, \
self.encode_layers.append(TwoStageAttentionLayer(seg_num, factor, d_model, n_heads, \
d_ff, dropout))

def forward(self, x):
def forward(self, x): # (B, C, patch_dim, d_model)
_, ts_dim, _, _ = x.shape

if self.merge_layer is not None:
Expand All @@ -86,7 +86,7 @@ def __init__(self, e_blocks, win_size, d_model, n_heads, d_ff, block_depth, drop
self.encode_blocks.append(scale_block(win_size, d_model, n_heads, d_ff, block_depth, dropout,\
ceil(in_seg_num/win_size**i), factor))

def forward(self, x):
def forward(self, x): # (B, C, patch_dim, d_model)
encode_x = []
encode_x.append(x)

Expand Down
13 changes: 8 additions & 5 deletions cross_models/cross_former.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,24 @@ def __init__(self, data_dim, in_len, out_len, seg_len, win_size = 4,
self.decoder = Decoder(seg_len, e_layers + 1, d_model, n_heads, d_ff, dropout, \
out_seg_num = (self.pad_out_len // seg_len), factor = factor)

def forward(self, x_seq):
def forward(self, x_seq): # (B,L,C)
if (self.baseline):
base = x_seq.mean(dim = 1, keepdim = True)
else:
base = 0
batch_size = x_seq.shape[0]
if (self.in_len_add != 0):
# 如果无法平均分段(因为有多余的),则取开头第一个元素复制in_len_add次然后填充,具体看论文附录D1
if (self.in_len_add != 0):
x_seq = torch.cat((x_seq[:, :1, :].expand(-1, self.in_len_add, -1), x_seq), dim = 1)

x_seq = self.enc_value_embedding(x_seq)
x_seq += self.enc_pos_embedding
# embedding和patchtst很像
x_seq = self.enc_value_embedding(x_seq) # (B, C, patch_dim, d_model)
x_seq += self.enc_pos_embedding # + (1, C, patch_dim, d_model)
x_seq = self.pre_norm(x_seq)

enc_out = self.encoder(x_seq)


# 在batch_size纬度重复
dec_in = repeat(self.dec_pos_embedding, 'b ts_d l d -> (repeat b) ts_d l d', repeat = batch_size)
predict_y = self.decoder(dec_in, enc_out)

Expand Down
Loading

0 comments on commit 62c6176

Please sign in to comment.