Skip to content

Commit

Permalink
Added things
Browse files Browse the repository at this point in the history
  • Loading branch information
Chillee committed Feb 4, 2021
1 parent e7d3aa1 commit 745e2a8
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions fx/vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
#
# How is this feat accomplished? One observation is that to "batch" a model, it
# suffices to batch each individual operation. In other words, given an
# operation that works on the current shape, how do we make it work on another
# batch dimension? This leads us to batching rules.
# operation that works on the current shape, how do we make it work with an
# additional batch dimension? This leads us to batching rules.
#
# Batching Rules
# ---------------
Expand Down Expand Up @@ -194,4 +194,4 @@ def forward(self, a, b):
# outer product computation. ((B, N), (M,)) -> (B, N, M)

model = vmap(model, in_axes=(0, None), example_args=(x[0], y))
print(model(x, y).shape) # ((3, 5), (2,)) -> (3, 5, 2)
print(model(x, y).shape) # ((3, 5), (2,)) -> (3, 5, 2)

0 comments on commit 745e2a8

Please sign in to comment.