forked from minitorch/Module-0
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
25 changed files
with
2,110 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
pip-wheel-metadata/ | ||
share/python-wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.nox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
*.py,cover | ||
.hypothesis/ | ||
.pytest_cache/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
db.sqlite3 | ||
db.sqlite3-journal | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# IPython | ||
profile_default/ | ||
ipython_config.py | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# pipenv | ||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
# However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
# having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
# install all needed dependencies. | ||
#Pipfile.lock | ||
|
||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow | ||
__pypackages__/ | ||
|
||
# Celery stuff | ||
celerybeat-schedule | ||
celerybeat.pid | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ | ||
.dmypy.json | ||
dmypy.json | ||
|
||
# Pyre type checker | ||
.pyre/ | ||
*.\#* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# MiniTorch Module 0 | ||
|
||
<img src="https://minitorch.github.io/_images/match.png" width="100px"> | ||
|
||
* Docs: https://minitorch.github.io/ | ||
|
||
* Overview: https://minitorch.github.io/module0.html |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .module import * # noqa: F401,F403 | ||
from .testing import * # noqa: F401,F403 | ||
from .datasets import * # noqa: F401,F403 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from dataclasses import dataclass | ||
import random | ||
|
||
|
||
def make_pts(N): | ||
X = [] | ||
for i in range(N): | ||
x_1 = random.random() | ||
x_2 = random.random() | ||
X.append((x_1, x_2)) | ||
return X | ||
|
||
|
||
@dataclass | ||
class Graph: | ||
N: int | ||
X: list | ||
y: list | ||
|
||
|
||
def simple(N): | ||
X = make_pts(N) | ||
y = [] | ||
for x_1, x_2 in X: | ||
y1 = 1 if x_1 < 0.5 else 0 | ||
y.append(y1) | ||
return Graph(N, X, y) | ||
|
||
|
||
def diag(N): | ||
X = make_pts(N) | ||
y = [] | ||
for x_1, x_2 in X: | ||
y1 = 1 if x_1 + x_2 < 0.5 else 0 | ||
y.append(y1) | ||
return Graph(N, X, y) | ||
|
||
|
||
def split(N): | ||
X = make_pts(N) | ||
y = [] | ||
for x_1, x_2 in X: | ||
y1 = 1 if x_1 < 0.2 or x_1 > 0.8 else 0 | ||
y.append(y1) | ||
return Graph(N, X, y) | ||
|
||
|
||
def xor(N): | ||
X = make_pts(N) | ||
y = [] | ||
for x_1, x_2 in X: | ||
y1 = 1 if ((x_1 < 0.5 and x_2 > 0.5) or (x_1 > 0.5 and x_2 < 0.5)) else 0 | ||
y.append(y1) | ||
return Graph(N, X, y) | ||
|
||
|
||
datasets = {"Simple": simple, "Diag": diag, "Split": split, "Xor": xor} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
class Module: | ||
""" | ||
Modules form a tree that store parameters and other | ||
submodules. They make up the basis of neural network stacks. | ||
Attributes: | ||
_modules (dict of name x :class:`Module`): Storage of the child modules | ||
_parameters (dict of name x :class:`Parameter`): Storage of the module's parameters | ||
training (bool): Whether the module is in training mode or evaluation mode | ||
""" | ||
|
||
def __init__(self): | ||
self._modules = {} | ||
self._parameters = {} | ||
self.training = True | ||
|
||
def modules(self): | ||
"Return the direct child modules of this module." | ||
return self.__dict__["_modules"].values() | ||
|
||
def train(self): | ||
"Set the mode of this module and all descendent modules to `train`." | ||
# TODO: Implement for Task 0.4. | ||
raise NotImplementedError('Need to implement for Task 0.4') | ||
|
||
def eval(self): | ||
"Set the mode of this module and all descendent modules to `eval`." | ||
# TODO: Implement for Task 0.4. | ||
raise NotImplementedError('Need to implement for Task 0.4') | ||
|
||
def named_parameters(self): | ||
""" | ||
Collect all the parameters of this module and its descendents. | ||
Returns: | ||
list of pairs: Contains the name and :class:`Parameter` of each ancestor parameter. | ||
""" | ||
# TODO: Implement for Task 0.4. | ||
raise NotImplementedError('Need to implement for Task 0.4') | ||
|
||
def parameters(self): | ||
"Enumerate over all the parameters of this module and its descendents." | ||
# TODO: Implement for Task 0.4. | ||
raise NotImplementedError('Need to implement for Task 0.4') | ||
|
||
def add_parameter(self, k, v): | ||
""" | ||
Manually add a parameter. Useful helper for scalar parameters. | ||
Args: | ||
k (str): Local name of the parameter. | ||
v (value): Value for the parameter. | ||
Returns: | ||
Parameter: Newly created parameter. | ||
""" | ||
val = Parameter(v, k) | ||
self.__dict__["_parameters"][k] = val | ||
return val | ||
|
||
def __setattr__(self, key, val): | ||
if isinstance(val, Parameter): | ||
self.__dict__["_parameters"][key] = val | ||
elif isinstance(val, Module): | ||
self.__dict__["_modules"][key] = val | ||
else: | ||
super().__setattr__(key, val) | ||
|
||
def __getattr__(self, key): | ||
if key in self.__dict__["_parameters"]: | ||
return self.__dict__["_parameters"][key] | ||
|
||
if key in self.__dict__["_modules"]: | ||
return self.__dict__["_modules"][key] | ||
|
||
def __call__(self, *args, **kwargs): | ||
return self.forward(*args, **kwargs) | ||
|
||
def forward(self): | ||
assert False, "Not Implemented" | ||
|
||
def __repr__(self): | ||
def _addindent(s_, numSpaces): | ||
s = s_.split("\n") | ||
if len(s) == 1: | ||
return s_ | ||
first = s.pop(0) | ||
s = [(numSpaces * " ") + line for line in s] | ||
s = "\n".join(s) | ||
s = first + "\n" + s | ||
return s | ||
|
||
child_lines = [] | ||
|
||
for key, module in self._modules.items(): | ||
mod_str = repr(module) | ||
mod_str = _addindent(mod_str, 2) | ||
child_lines.append("(" + key + "): " + mod_str) | ||
lines = child_lines | ||
|
||
main_str = self.__class__.__name__ + "(" | ||
if lines: | ||
# simple one-liner info, which most builtin Modules will use | ||
main_str += "\n " + "\n ".join(lines) + "\n" | ||
|
||
main_str += ")" | ||
return main_str | ||
|
||
|
||
class Parameter: | ||
""" | ||
A Parameter is a special container stored in a :class:`Module`. | ||
It is designed to hold a :class:`Variable`, but we allow it to hold | ||
any value for testing. | ||
""" | ||
|
||
def __init__(self, x=None, name=None): | ||
self.value = x | ||
self.name = name | ||
if hasattr(x, "requires_grad_"): | ||
self.value.requires_grad_(True) | ||
if self.name: | ||
self.value.name = self.name | ||
|
||
def update(self, x): | ||
"Update the parameter value." | ||
self.value = x | ||
if hasattr(x, "requires_grad_"): | ||
self.value.requires_grad_(True) | ||
if self.name: | ||
self.value.name = self.name | ||
|
||
def __repr__(self): | ||
return repr(self.value) | ||
|
||
def __str__(self): | ||
return str(self.value) |
Oops, something went wrong.