Skip to content

Commit

Permalink
implement Q2.2
Browse files Browse the repository at this point in the history
  • Loading branch information
jaklvinc committed Dec 12, 2023
1 parent afff9b3 commit 465442b
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions HW1/hw1-q2.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,11 @@ def __init__(
includes modules for several activation functions and dropout as well.
"""
super().__init__()
# Implement me!
raise NotImplementedError
self.layers = nn.Sequential(nn.Linear(n_features, hidden_size),
nn.Dropout(dropout),
(nn.ReLU() if activation_type == 'relu' else nn.Tanh()),
nn.Linear(hidden_size, n_classes),
nn.Softmax(dim=-1))

def forward(self, x, **kwargs):
"""
Expand All @@ -76,7 +79,7 @@ def forward(self, x, **kwargs):
the output logits from x. This will include using various hidden
layers, pointwise nonlinear functions, and dropout.
"""
raise NotImplementedError
return self.layers(x)


def train_batch(X, y, model, optimizer, criterion, **kwargs):
Expand Down

0 comments on commit 465442b

Please sign in to comment.