Commit 8568425
Merge pull request zhouhaoyi#32 from cookieminions/dev
Update time features embed
cookieminions authored Feb 23, 2021
2 parents 9adafc4 + 4a4b8a4 commit 8568425
Showing 4 changed files with 15 additions and 15 deletions.
2 changes: 1 addition & 1 deletion exp/exp_informer.py
@@ -44,7 +44,7 @@ def _build_model(self):
             self.args.dropout,
             self.args.attn,
             self.args.embed,
-            self.args.data[:-1],
+            self.args.freq,
             self.args.activation,
             self.args.output_attention,
             self.args.distil,
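For context, `args.data` is a dataset name such as 'ETTh1' or 'ETTm1', so the old argument derived a frequency switch by stripping the trailing character, while the new one forwards the user-chosen `freq` flag directly. A minimal sketch of the difference, with assumed demo values:

```python
from types import SimpleNamespace

# Hypothetical args object; the values are assumptions for illustration.
args = SimpleNamespace(data='ETTm1', freq='t')

print(args.data[:-1])  # old behavior: 'ETTm' -- dataset name doubling as the frequency switch
print(args.freq)       # new behavior: 't'   -- minutely features chosen directly by the user
```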
13 changes: 7 additions & 6 deletions models/embed.py
@@ -57,14 +57,14 @@ def forward(self, x):
         return self.emb(x).detach()

 class TemporalEmbedding(nn.Module):
-    def __init__(self, d_model, embed_type='fixed', data='ETTh'):
+    def __init__(self, d_model, embed_type='fixed', freq='h'):
         super(TemporalEmbedding, self).__init__()

         minute_size = 4; hour_size = 24
         weekday_size = 7; day_size = 32; month_size = 13

         Embed = FixedEmbedding if embed_type=='fixed' else nn.Embedding
-        if data=='ETTm':
+        if freq=='t':
             self.minute_embed = Embed(minute_size, d_model)
         self.hour_embed = Embed(hour_size, d_model)
         self.weekday_embed = Embed(weekday_size, d_model)
@@ -83,22 +83,23 @@ def forward(self, x):
         return hour_x + weekday_x + day_x + month_x + minute_x

 class TimeFeatureEmbedding(nn.Module):
-    def __init__(self, d_model, embed_type='timeF', data='ETTh'):
+    def __init__(self, d_model, embed_type='timeF', freq='h'):
         super(TimeFeatureEmbedding, self).__init__()

-        d_inp = 4 if data=='ETTh' else 5
+        freq_map = {'h':4, 't':5}
+        d_inp = freq_map[freq]
         self.embed = nn.Linear(d_inp, d_model)

     def forward(self, x):
         return self.embed(x)

 class DataEmbedding(nn.Module):
-    def __init__(self, c_in, d_model, embed_type='fixed', data='ETTh', dropout=0.1):
+    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
         super(DataEmbedding, self).__init__()

         self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
         self.position_embedding = PositionalEmbedding(d_model=d_model)
-        self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, data=data) if embed_type!='timeF' else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, data=data)
+        self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) if embed_type!='timeF' else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)

         self.dropout = nn.Dropout(p=dropout)
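The new `freq_map` makes the input width of TimeFeatureEmbedding explicit: 4 time features per step for hourly data ('h'), 5 for minutely ('t'). A standalone sketch of that linear path, with assumed batch size, sequence length, and d_model:

```python
import torch
import torch.nn as nn

# Sketch of the freq-keyed linear embedding above; the demo shapes
# (batch=32, seq_len=96, d_model=512) are assumptions, not from the diff.
freq_map = {'h': 4, 't': 5}   # time-feature width per frequency
freq, d_model = 't', 512
embed = nn.Linear(freq_map[freq], d_model)

x_mark = torch.randn(32, 96, freq_map[freq])  # (batch, seq_len, time features)
print(embed(x_mark).shape)                    # torch.Size([32, 96, 512])
```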
12 changes: 6 additions & 6 deletions models/model.py
@@ -11,7 +11,7 @@
 class Informer(nn.Module):
     def __init__(self, enc_in, dec_in, c_out, seq_len, label_len, out_len,
                 factor=5, d_model=512, n_heads=8, e_layers=3, d_layers=2, d_ff=512,
-                dropout=0.0, attn='prob', embed='fixed', data='ETTh', activation='gelu',
+                dropout=0.0, attn='prob', embed='fixed', freq='h', activation='gelu',
                 output_attention = False, distil=True,
                 device=torch.device('cuda:0')):
         super(Informer, self).__init__()
@@ -20,8 +20,8 @@ def __init__(self, enc_in, dec_in, c_out, seq_len, label_len, out_len,
         self.output_attention = output_attention

         # Encoding
-        self.enc_embedding = DataEmbedding(enc_in, d_model, embed, data, dropout)
-        self.dec_embedding = DataEmbedding(dec_in, d_model, embed, data, dropout)
+        self.enc_embedding = DataEmbedding(enc_in, d_model, embed, freq, dropout)
+        self.dec_embedding = DataEmbedding(dec_in, d_model, embed, freq, dropout)
         # Attention
         Attn = ProbAttention if attn=='prob' else FullAttention
         # Encoder
@@ -84,7 +84,7 @@ def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
 class InformerStack(nn.Module):
     def __init__(self, enc_in, dec_in, c_out, seq_len, label_len, out_len,
                 factor=5, d_model=512, n_heads=8, e_layers=3, d_layers=2, d_ff=512,
-                dropout=0.0, attn='prob', embed='fixed', data='ETTh', activation='gelu',
+                dropout=0.0, attn='prob', embed='fixed', freq='h', activation='gelu',
                 output_attention = False, distil=True,
                 device=torch.device('cuda:0')):
         super(InformerStack, self).__init__()
@@ -93,8 +93,8 @@ def __init__(self, enc_in, dec_in, c_out, seq_len, label_len, out_len,
         self.output_attention = output_attention

         # Encoding
-        self.enc_embedding = DataEmbedding(enc_in, d_model, embed, data, dropout)
-        self.dec_embedding = DataEmbedding(dec_in, d_model, embed, data, dropout)
+        self.enc_embedding = DataEmbedding(enc_in, d_model, embed, freq, dropout)
+        self.dec_embedding = DataEmbedding(dec_in, d_model, embed, freq, dropout)
         # Attention
         Attn = ProbAttention if attn=='prob' else FullAttention
         # Encoder
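With the renamed keyword, callers pick the time granularity independently of the dataset name. A hypothetical smoke test of the new signature; every hyperparameter value below is an assumed demo choice, not something the diff prescribes:

```python
import torch
from models.model import Informer

# Assumed demo hyperparameters: 7 input/output variables, 96-step encoder
# input, 48-step decoder start token, 24-step prediction horizon.
model = Informer(enc_in=7, dec_in=7, c_out=7,
                 seq_len=96, label_len=48, out_len=24,
                 embed='timeF', freq='h',   # freq replaces the old data='ETTh' switch
                 device=torch.device('cpu'))
print(sum(p.numel() for p in model.parameters()))  # parameter count sanity check
```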
3 changes: 1 addition & 2 deletions utils/timefeatures.py
@@ -98,5 +98,4 @@ def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
     raise RuntimeError(supported_freq_msg)

 def time_features(dates, freq='h'):
-    return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)])
-
+    return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)])
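Assuming the TimeFeature classes defined earlier in this file accept a pandas DatetimeIndex (as in the gluonts-style implementation), a hypothetical call looks like this; the date range and frequency are demo values:

```python
import pandas as pd
from utils.timefeatures import time_features

# Assumed demo input: 96 timestamps at 15-minute spacing.
dates = pd.date_range('2021-02-23', periods=96, freq='15min')
feats = time_features(dates, freq='t')  # 't' selects the minutely feature set
print(feats.shape)  # expected (5, 96): one row per feature, matching freq_map['t'] == 5
```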
