Commit 0391fd5

MixtureTable example2 + bug fix
1 parent 8cbcb77 commit 0391fd5

2 files changed: +40 -1 lines changed

MixtureTable.lua

+1 -1

@@ -25,7 +25,7 @@ function MixtureTable:updateOutput(input)
    if self.batchSize ~= expertInputs:size(1) then
       self.size:resize(expertInputs:dim()):fill(1)
       self.size[1] = gaterInput:size(1)
-      self.size[2] = gaterInput:size(2)
+      self.size[self.dim] = gaterInput:size(2)
       self.output:resizeAs(expertInputs:select(self.dim, 1))
       self.batchSize = expertInputs:size(1)
       self.backwardSetup = false
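
The effect of the one-line fix can be sketched in plain Lua, without torch. The helper below, `gaterViewSize`, is hypothetical and only mirrors the three assignments visible in the hunk. How `self.size` is consumed is not shown here; presumably it is the shape the gater output takes before being expanded against `expertInputs`, which is an assumption on my part.

```lua
-- Hypothetical helper (not part of nn): builds the size vector that the
-- patched lines compute for a batched gater output of size batch x nExperts.
local function gaterViewSize(expertNDim, expertDim, batch, nExperts)
   local size = {}
   for i = 1, expertNDim do
      size[i] = 1                -- mirrors self.size:resize(...):fill(1)
   end
   size[1] = batch               -- mirrors self.size[1] = gaterInput:size(1)
   size[expertDim] = nExperts    -- mirrors self.size[self.dim] = gaterInput:size(2)
   return size
end

-- Experts concatenated along dimension 5, as in nn.MixtureTable(5),
-- with a batch of 2 and 3 experts (both values chosen for illustration).
local fixed = gaterViewSize(5, 5, 2, 3)
print(table.concat(fixed, "x"))   -- 2x1x1x1x3

-- The pre-fix code always wrote to index 2, which only matches dim == 2.
local old = gaterViewSize(5, 2, 2, 3)
print(table.concat(old, "x"))     -- 2x3x1x1x1
```

With the hardcoded index, the expert count lands on dimension 2 regardless of `self.dim`, so the two size vectors agree only when the experts are stacked along dimension 2.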

doc/table.md

+39

@@ -434,6 +434,45 @@ Forwarding a batch of 2 examples gives us something like this:
 [torch.DoubleTensor of dimension 2x5]
 ```
 
+Example 2:
+In the following, the MixtureTable expects `experts` to be a Tensor of
+`size = {1,4,2,5,n}`:
+```lua
+experts = nn.Concat(5)
+for i=1,n do
+   local expert = nn.Sequential()
+   expert:add(nn.Linear(3,4))
+   expert:add(nn.Tanh())
+   expert:add(nn.Linear(4,2*5))
+   expert:add(nn.Tanh())
+   expert:add(nn.Reshape(4,2,5,1))
+   experts:add(expert)
+end
+
+gater = nn.Sequential()
+gater:add(nn.Linear(3,7))
+gater:add(nn.Tanh())
+gater:add(nn.Linear(7,n))
+gater:add(nn.SoftMax())
+
+trunk = nn.ConcatTable()
+trunk:add(gater)
+trunk:add(experts)
+
+moe = nn.Sequential()
+moe:add(trunk)
+moe:add(nn.MixtureTable(5))
+```
+Forwarding a batch of 2 examples gives us something like this:
+```lua
+> =moe:forward(torch.randn(2,3)):size()
+ 2
+ 4
+ 2
+ 5
+[torch.LongStorage of size 4]
+
+```
 
 
 <a name="nn.SelectTable"/>
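
One detail in the added example may be worth a second look: the last `nn.Linear(4,2*5)` emits 10 values per example, while `nn.Reshape(4,2,5,1)` describes 40 elements, and the documented `2x4x2x5` output likewise needs 40 values per example from each expert. A minimal element-count check in plain Lua is sketched below. The `numel` helper is hypothetical, and the suggested `4*2*5` width is my assumption rather than something this commit contains.

```lua
-- Hypothetical helper: number of elements described by a list of sizes.
local function numel(sizes)
   local count = 1
   for _, s in ipairs(sizes) do
      count = count * s
   end
   return count
end

local linearOut = 2 * 5                  -- output width of nn.Linear(4, 2*5)
local reshapeOut = numel({4, 2, 5, 1})   -- elements per example for nn.Reshape(4,2,5,1)

print(linearOut, reshapeOut)             -- 10 and 40
print(linearOut == reshapeOut)           -- false

-- Possible correction (an assumption, not part of this commit):
local suggested = 4 * 2 * 5              -- i.e. nn.Linear(4, 4*2*5)
print(suggested == reshapeOut)           -- true
```

If the counts have to match, as I would expect for a reshape, widening the final linear layer to `4*2*5` would make the snippet consistent with the size printout.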
