Skip to content

Commit

Permalink
Merge pull request jax-ml#8633 from shawwn:2021-11-19/autodidax-fix-j…
Browse files Browse the repository at this point in the history
…axpr-subcomp-return-type

PiperOrigin-RevId: 519745476
  • Loading branch information
jax authors committed Mar 27, 2023
2 parents 10d51c7 + 2d61a5f commit af4d494
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion docs/autodidax.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2038,7 +2038,7 @@
"outputs": [],
"source": [
"def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp]\n",
" ) -> xe.XlaOp:\n",
" ) -> List[xe.XlaOp]:\n",
" env: Dict[Var, xe.XlaOp] = {}\n",
"\n",
" def read(x: Atom) -> xe.XlaOp:\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/autodidax.md
Original file line number Diff line number Diff line change
Expand Up @@ -1598,7 +1598,7 @@ compiled program:

```{code-cell}
def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp]
) -> xe.XlaOp:
) -> List[xe.XlaOp]:
env: Dict[Var, xe.XlaOp] = {}
def read(x: Atom) -> xe.XlaOp:
Expand Down
2 changes: 1 addition & 1 deletion docs/autodidax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1592,7 +1592,7 @@ def _xla_shape(aval: ShapedArray) -> xe.Shape:

# +
def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp]
) -> xe.XlaOp:
) -> List[xe.XlaOp]:
env: Dict[Var, xe.XlaOp] = {}

def read(x: Atom) -> xe.XlaOp:
Expand Down

0 comments on commit af4d494

Please sign in to comment.