Skip to content

Commit

Permalink
allowed reloaded model to have multi-dimensional outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
Seanny123 committed Oct 19, 2018
1 parent 2787ec2 commit 105e2e1
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 21 deletions.
7 changes: 2 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def prep_train_data(batch_idx: np.ndarray, t_cfg: TrainConfig, train_data: Train
y_target = train_data.targs[batch_idx + t_cfg.T]

for b_i, b_idx in enumerate(batch_idx):
b_slc = slice(b_idx, (b_idx + t_cfg.T - 1))
b_slc = slice(b_idx, b_idx + t_cfg.T - 1)
feats[b_i, :, :] = train_data.feats[b_slc, :]
y_history[b_i, :] = train_data.targs[b_slc]

Expand Down Expand Up @@ -150,9 +150,6 @@ def train_iteration(t_net: DaRnnNet, loss_func: typing.Callable, X, y_history, y
t_net.enc_opt.step()
t_net.dec_opt.step()

# if loss.data[0] < 10:
# self.logger.info("MSE: %s, loss: %s.", loss.data, (y_pred[:, 0] - y_true).pow(2).mean())

return loss.item()


Expand Down Expand Up @@ -189,7 +186,7 @@ def predict(t_net: DaRnnNet, t_dat: TrainData, train_size: int, batch_size: int,
save_plots = True
debug = False

raw_data = pd.read_csv(os.path.join('data/nasdaq100_padding.csv'), nrows=100 if debug else None)
raw_data = pd.read_csv(os.path.join("data", "nasdaq100_padding.csv"), nrows=100 if debug else None)
logger.info(f"Shape of data: {raw_data.shape}.\nMissing in data: {raw_data.isnull().sum().sum()}.")
targ_cols = ("NDX",)
data, scaler = preprocess_data(raw_data, targ_cols)
Expand Down
19 changes: 11 additions & 8 deletions main_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,29 @@
from constants import device


def preprocess_data(dat, scale) -> TrainData:
def preprocess_data(dat, col_names, scale) -> TrainData:
proc_dat = scale.transform(dat)

col_idx = list(dat.columns).index("NDX")
mask = np.ones(proc_dat.shape[1], dtype=bool)
mask[col_idx] = False
dat_cols = list(dat.columns)
for col_name in col_names:
mask[dat_cols.index(col_name)] = False

feats = proc_dat[:, mask]
targs = proc_dat[:, ~mask]

return TrainData(feats, targs.squeeze())
return TrainData(feats, targs)


def predict(encoder, decoder, t_dat, batch_size: int, T: int) -> np.ndarray:
y_pred = np.zeros(t_dat.feats.shape[0] - T + 1)
y_pred = np.zeros((t_dat.feats.shape[0] - T + 1, t_dat.targs.shape[0]))

for y_i in range(0, len(y_pred), batch_size):
y_slc = slice(y_i, y_i + batch_size)
batch_idx = range(len(y_pred))[y_slc]
b_len = len(batch_idx)
X = np.zeros((b_len, T - 1, t_dat.feats.shape[1]))
y_history = np.zeros((b_len, T - 1))
y_history = np.zeros((b_len, T - 1, t_dat.targs.shape[0]))

for b_i, b_idx in enumerate(batch_idx):
idx = range(b_idx, b_idx + T - 1)
Expand All @@ -44,7 +46,7 @@ def predict(encoder, decoder, t_dat, batch_size: int, T: int) -> np.ndarray:

y_history = numpy_to_tvar(y_history)
_, input_encoded = encoder(numpy_to_tvar(X))
y_pred[y_slc] = decoder(input_encoded, y_history).cpu().data.numpy()[:, 0]
y_pred[y_slc] = decoder(input_encoded, y_history).cpu().data.numpy()

return y_pred

Expand All @@ -64,7 +66,8 @@ def predict(encoder, decoder, t_dat, batch_size: int, T: int) -> np.ndarray:

scaler = joblib.load(os.path.join("data", "scaler.pkl"))
raw_data = pd.read_csv(os.path.join("data", "nasdaq100_padding.csv"), nrows=100 if debug else None)
data = preprocess_data(raw_data, scaler)
targ_cols = ("NDX",)
data = preprocess_data(raw_data, targ_cols, scaler)

with open(os.path.join("data", "da_rnn_kwargs.json"), "r") as fi:
da_rnn_kwargs = json.load(fi)
Expand Down
11 changes: 3 additions & 8 deletions modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,6 @@ def forward(self, input_data):
class Decoder(nn.Module):

def __init__(self, encoder_hidden_size: int, decoder_hidden_size: int, T: int, out_feats=1):
"""
encoder_hidden_size: hidden_size of encoder layer
decoder_hidden_size:
T:
"""
super(Decoder, self).__init__()

self.T = T
Expand All @@ -93,7 +88,7 @@ def forward(self, input_encoded, y_history):
context = Variable(torch.zeros(input_encoded.size(0), self.encoder_hidden_size))

for t in range(self.T - 1):
# (batch_size, T, (2*decoder_hidden_size + encoder_hidden_size))
# (batch_size, T - 1, (2 * decoder_hidden_size + encoder_hidden_size))
x = torch.cat((hidden.repeat(self.T - 1, 1, 1).permute(1, 0, 2),
cell.repeat(self.T - 1, 1, 1).permute(1, 0, 2),
input_encoded), dim=2)
Expand All @@ -102,13 +97,13 @@ def forward(self, input_encoded, y_history):
self.attn_layer(
x.view(-1, 2 * self.decoder_hidden_size + self.encoder_hidden_size)
).view(-1, self.T - 1),
dim=1) # (batch_size, T - 1), attention weights sum up to 1
dim=1) # (batch_size, T - 1)

# Eqn. 14: compute context vector
context = torch.bmm(x.unsqueeze(1), input_encoded)[:, 0, :] # (batch_size, encoder_hidden_size)

# Eqn. 15
y_tilde = self.fc(torch.cat((context, y_history[:, t]), dim=1)) # batch_size * 1
y_tilde = self.fc(torch.cat((context, y_history[:, t]), dim=1)) # (batch_size, out_size)
# Eqn. 16: LSTM
self.lstm_layer.flatten_parameters()
_, lstm_output = self.lstm_layer(y_tilde.unsqueeze(0), (hidden, cell))
Expand Down

0 comments on commit 105e2e1

Please sign in to comment.