Skip to content

Commit

Permalink
Merge pull request minitorch#37 from minitorch/2024
Browse files Browse the repository at this point in the history
2024
  • Loading branch information
srush authored Aug 30, 2024
2 parents 298cc96 + 5d58448 commit 0e47013
Show file tree
Hide file tree
Showing 10 changed files with 151 additions and 70 deletions.
Empty file added files_to_sync.txt
Empty file.
46 changes: 18 additions & 28 deletions minitorch/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import List, Tuple


def make_pts(N: int) -> List[Tuple[float, float]]:
def make_pts(N):
X = []
for i in range(N):
x_1 = random.random()
Expand All @@ -20,7 +20,7 @@ class Graph:
y: List[int]


def simple(N: int) -> Graph:
def simple(N):
X = make_pts(N)
y = []
for x_1, x_2 in X:
Expand All @@ -29,7 +29,7 @@ def simple(N: int) -> Graph:
return Graph(N, X, y)


def diag(N: int) -> Graph:
def diag(N):
X = make_pts(N)
y = []
for x_1, x_2 in X:
Expand All @@ -38,7 +38,7 @@ def diag(N: int) -> Graph:
return Graph(N, X, y)


def split(N: int) -> Graph:
def split(N):
X = make_pts(N)
y = []
for x_1, x_2 in X:
Expand All @@ -47,49 +47,39 @@ def split(N: int) -> Graph:
return Graph(N, X, y)


def xor(N: int) -> Graph:
def xor(N):
X = make_pts(N)
y = []
for x_1, x_2 in X:
y1 = 1 if ((x_1 < 0.5 and x_2 > 0.5) or (x_1 > 0.5 and x_2 < 0.5)) else 0
y1 = 1 if x_1 < 0.5 and x_2 > 0.5 or x_1 > 0.5 and x_2 < 0.5 else 0
y.append(y1)
return Graph(N, X, y)


def circle(N: int) -> Graph:
def circle(N):
X = make_pts(N)
y = []
for x_1, x_2 in X:
x1, x2 = (x_1 - 0.5, x_2 - 0.5)
x1, x2 = x_1 - 0.5, x_2 - 0.5
y1 = 1 if x1 * x1 + x2 * x2 > 0.1 else 0
y.append(y1)
return Graph(N, X, y)


def spiral(N: int) -> Graph:
def x(t: float) -> float:
def spiral(N):

def x(t):
return t * math.cos(t) / 20.0

def y(t: float) -> float:
def y(t):
return t * math.sin(t) / 20.0

X = [
(x(10.0 * (float(i) / (N // 2))) + 0.5, y(10.0 * (float(i) / (N // 2))) + 0.5)
for i in range(5 + 0, 5 + N // 2)
]
X = X + [
(y(-10.0 * (float(i) / (N // 2))) + 0.5, x(-10.0 * (float(i) / (N // 2))) + 0.5)
for i in range(5 + 0, 5 + N // 2)
]
X = [(x(10.0 * (float(i) / (N // 2))) + 0.5, y(10.0 * (float(i) / (N //
2))) + 0.5) for i in range(5 + 0, 5 + N // 2)]
X = X + [(y(-10.0 * (float(i) / (N // 2))) + 0.5, x(-10.0 * (float(i) /
(N // 2))) + 0.5) for i in range(5 + 0, 5 + N // 2)]
y2 = [0] * (N // 2) + [1] * (N // 2)
return Graph(N, X, y2)


datasets = {
"Simple": simple,
"Diag": diag,
"Split": split,
"Xor": xor,
"Circle": circle,
"Spiral": spiral,
}
datasets = {'Simple': simple, 'Diag': diag, 'Split': split, 'Xor': xor,
'Circle': circle, 'Spiral': spiral}
33 changes: 17 additions & 16 deletions minitorch/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@


class Module:
"""
Modules form a tree that store parameters and other
"""Modules form a tree that store parameters and other
submodules. They make up the basis of neural network stacks.
Attributes:
Attributes
----------
_modules : Storage of the child modules
_parameters : Storage of the module's parameters
training : Whether the module is in training mode or evaluation mode
Expand All @@ -25,46 +25,48 @@ def __init__(self) -> None:
self.training = True

def modules(self) -> Sequence[Module]:
"Return the direct child modules of this module."
"""Return the direct child modules of this module."""
m: Dict[str, Module] = self.__dict__["_modules"]
return list(m.values())

def train(self) -> None:
"Set the mode of this module and all descendent modules to `train`."
"""Set the mode of this module and all descendent modules to `train`."""
# TODO: Implement for Task 0.4.
raise NotImplementedError("Need to implement for Task 0.4")

def eval(self) -> None:
"Set the mode of this module and all descendent modules to `eval`."
"""Set the mode of this module and all descendent modules to `eval`."""
# TODO: Implement for Task 0.4.
raise NotImplementedError("Need to implement for Task 0.4")

def named_parameters(self) -> Sequence[Tuple[str, Parameter]]:
"""
Collect all the parameters of this module and its descendents.
"""Collect all the parameters of this module and its descendents.
Returns:
Returns
-------
The name and `Parameter` of each ancestor parameter.
"""
# TODO: Implement for Task 0.4.
raise NotImplementedError("Need to implement for Task 0.4")

def parameters(self) -> Sequence[Parameter]:
"Enumerate over all the parameters of this module and its descendents."
"""Enumerate over all the parameters of this module and its descendents."""
# TODO: Implement for Task 0.4.
raise NotImplementedError("Need to implement for Task 0.4")

def add_parameter(self, k: str, v: Any) -> Parameter:
"""
Manually add a parameter. Useful helper for scalar parameters.
"""Manually add a parameter. Useful helper for scalar parameters.
Args:
----
k: Local name of the parameter.
v: Value for the parameter.
Returns:
-------
Newly created parameter.
"""
val = Parameter(v, k)
self.__dict__["_parameters"][k] = val
Expand Down Expand Up @@ -118,8 +120,7 @@ def _addindent(s_: str, numSpaces: int) -> str:


class Parameter:
"""
A Parameter is a special container stored in a `Module`.
"""A Parameter is a special container stored in a `Module`.
It is designed to hold a `Variable`, but we allow it to hold
any value for testing.
Expand All @@ -134,7 +135,7 @@ def __init__(self, x: Any, name: Optional[str] = None) -> None:
self.value.name = self.name

def update(self, x: Any) -> None:
"Update the parameter value."
"""Update the parameter value."""
self.value = x
if hasattr(x, "requires_grad_"):
self.value.requires_grad_(True)
Expand Down
4 changes: 1 addition & 3 deletions minitorch/operators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""
Collection of the core mathematical operators used throughout the code base.
"""
"""Collection of the core mathematical operators used throughout the code base."""

import math

Expand Down
2 changes: 1 addition & 1 deletion project/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
st.sidebar.markdown(
"""
<h1 style="font-size:30pt; float: left; margin-right: 20px; margin-top: 1px;">MiniTorch</h1>{}
""".format(get_img_tag("https://minitorch.github.io/_images/match.png", width="40")),
""".format(get_img_tag("https://minitorch.github.io/logo-sm.png", width="40")),
unsafe_allow_html=True,
)

Expand Down
105 changes: 100 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,16 @@ name = "minitorch"
version = "0.5"

[tool.pyright]
include = ["minitorch","tests"]
exclude = ["**/docs", "**/project", "**/mt_diagrams", "**/assignments"]
include = ["**/minitorch"]
exclude = [
"**/docs",
"**/docs/module1/**",
"**/assignments",
"**/project",
"**/mt_diagrams",
"**/.*",
"*chainrule.py*",
]
venvPath = "."
venv = ".venv"
reportUnknownMemberType = "none"
Expand All @@ -21,6 +29,8 @@ reportUnusedExpression = "none"
reportUnknownLambdaType = "none"
reportIncompatibleMethodOverride = "none"
reportPrivateUsage = "none"
reportMissingParameterType = "error"


[tool.pytest.ini_options]
markers = [
Expand Down Expand Up @@ -50,7 +60,92 @@ markers = [
"task4_3",
"task4_4",
]
[tool.ruff]

exclude = [
".git",
"__pycache__",
"**/docs/slides/*",
"old,build",
"dist",
"**/project/**/*",
"**/mt_diagrams/*",
"**/minitorch/testing.py",
"**/docs/**/*",
]

ignore = [
"ANN101",
"ANN401",
"N801",
"E203",
"E266",
"E501",
"E741",
"N803",
"N802",
"N806",
"D400",
"D401",
"D105",
"D415",
"D402",
"D205",
"D100",
"D101",
"D107",
"D213",
"ANN204",
"ANN102",
]
select = ["D", "E", "F", "N", "ANN"]
fixable = [
"A",
"B",
"C",
"D",
"E",
"F",
"G",
"I",
"N",
"Q",
"S",
"T",
"W",
"ANN",
"ARG",
"BLE",
"COM",
"DJ",
"DTZ",
"EM",
"ERA",
"EXE",
"FBT",
"ICN",
"INP",
"ISC",
"NPY",
"PD",
"PGH",
"PIE",
"PL",
"PT",
"PTH",
"PYI",
"RET",
"RSE",
"RUF",
"SIM",
"SLF",
"TCH",
"TID",
"TRY",
"UP",
"YTT",
]
unfixable = []

[tool.ruff.lint]
ignore = ["N801", "E203", "E266", "E501", "E741", "N803", "N802", "N806"]
exclude = [".git","__pycache__","docs/slides/*","old,build","dist"]
[tool.ruff.extend-per-file-ignores]
"tests/**/*.py" = ["D"]
3 changes: 2 additions & 1 deletion requirements.extra.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
datasets==2.4.0
embeddings==0.0.8
networkx==2.4
plotly==4.14.3
pydot==1.4.1
python-mnist
streamlit==1.12.0
streamlit-ace
torch
watchdog==1.0.2
altair==4.2.2
networkx==3.3
2 changes: 1 addition & 1 deletion tests/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@


def assert_close(a: float, b: float) -> None:
assert minitorch.is_close(a, b), "Failure x=%f y=%f" % (a, b)
assert minitorch.operators.is_close(a, b), "Failure x=%f y=%f" % (a, b)
6 changes: 3 additions & 3 deletions tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self) -> None:

@pytest.mark.task0_4
def test_stacked_demo() -> None:
"Check that each of the properties match"
"""Check that each of the properties match"""
mod = ModuleA1()
np = dict(mod.named_parameters())

Expand Down Expand Up @@ -95,7 +95,7 @@ def __init__(self) -> None:
@pytest.mark.task0_4
@given(med_ints, med_ints)
def test_module(size_a: int, size_b: int) -> None:
"Check the properties of a single module"
"""Check the properties of a single module"""
module = Module2()
module.eval()
assert not module.training
Expand All @@ -116,7 +116,7 @@ def test_module(size_a: int, size_b: int) -> None:
@pytest.mark.task0_4
@given(med_ints, med_ints, small_floats)
def test_stacked_module(size_a: int, size_b: int, val: float) -> None:
"Check the properties of a stacked module"
"""Check the properties of a stacked module"""
module = Module1(size_a, size_b, val)
module.eval()
assert not module.training
Expand Down
Loading

0 comments on commit 0e47013

Please sign in to comment.