Skip to content

Commit

Permalink
Fix some things that broke in the transformers (LLNL#2106)
Browse files Browse the repository at this point in the history
* Fix some things that broke from the recent API changes

* Clean up some scaling_factors=str() issues

* Turns out, list.extend() doesn't return the list :/
  • Loading branch information
benson31 authored May 2, 2022
1 parent 01cc59f commit e9a31b5
Show file tree
Hide file tree
Showing 8 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion applications/nlp/transformer/subgraph/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def make_model(
)
embeddings = lbann.WeightedSum(
embeddings,
scaling_factors=str(math.sqrt(embed_dim)),
scaling_factors=math.sqrt(embed_dim),
)
embeddings_slice = lbann.Slice(
embeddings,
Expand Down
2 changes: 1 addition & 1 deletion applications/nlp/transformer/subgraph/train_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def make_model(
)
embeddings = lbann.WeightedSum(
embeddings,
scaling_factors=str(math.sqrt(embed_dim)),
scaling_factors=math.sqrt(embed_dim),
)
embeddings_slice = lbann.Slice(
embeddings,
Expand Down
2 changes: 1 addition & 1 deletion applications/nlp/transformer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def make_model(
)
embeddings = lbann.WeightedSum(
embeddings,
scaling_factors=str(math.sqrt(embed_dim)),
scaling_factors=math.sqrt(embed_dim),
)
embeddings_slice = lbann.Slice(
embeddings,
Expand Down
4 changes: 2 additions & 2 deletions applications/physics/cosmology/ExaGAN/ExaGAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def inv_transform(self,y):
lbann.SafeDivide(
lbann.Add(lbann.Constant(value=1.0, hint_layer=y),lbann.Identity(y)),
lbann.Subtract(lbann.Constant(value=1.0, hint_layer=y),lbann.Identity(y))),
scaling_factors=str(self.datascale))
scaling_factors=self.datascale)
linear_scale = 1/self.linear_scaler
CH2 = lbann.Tanh(lbann.WeightedSum(inv_transform,scaling_factors=str(linear_scale)))
CH2 = lbann.Tanh(lbann.WeightedSum(inv_transform,scaling_factors=linear_scale))
return CH2
2 changes: 1 addition & 1 deletion python/lbann/models/subgraph/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1416,7 +1416,7 @@ def _subsequent_mask(self, size):
if size not in self._subsequent_mask_cache:
vals = np.triu(np.full((size,size), -1e9), k=1)
weights = lbann.Weights(
initializer=lbann.ValueInitializer(values=np.nditer(vals)),
initializer=lbann.ValueInitializer(values=vals.flat),
optimizer=None,
name=f'{self.name}_mask{size}_weights',
)
Expand Down
6 changes: 3 additions & 3 deletions python/lbann/models/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ def forward(self, x):
# Affine transform
s = lbann.WeightsLayer(
weights=self.weight,
dims=f'1 {self.normalized_shape}',
dims=[1] + list(make_iterable(self.normalized_shape)),
)
s = lbann.Tessellate(s, hint_layer=x)
b = lbann.WeightsLayer(
weights=self.bias,
dims=f'1 {self.normalized_shape}',
dims=[1] + list(make_iterable(self.normalized_shape)),
)
b = lbann.Tessellate(b, hint_layer=x)
x = lbann.Add(lbann.Multiply(s,x), b)
Expand Down Expand Up @@ -441,7 +441,7 @@ def _subsequent_mask(self, size):
if size not in self._subsequent_mask_cache:
vals = np.triu(np.full((size,size), -1e9), k=1)
weights = lbann.Weights(
initializer=lbann.ValueInitializer(values=np.nditer(vals)),
initializer=lbann.ValueInitializer(values=vals.flat),
optimizer=None,
name=f'{self.name}_mask{size}_weights',
)
Expand Down
6 changes: 3 additions & 3 deletions python/lbann/modules/subgraph/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def forward(self, queries, keys, values, mask=None):
)
y = lbann.WeightedSum(
y,
scaling_factors=str(1 / math.sqrt(self.head_dim)),
scaling_factors=1 / math.sqrt(self.head_dim),
name=f'{head_name}_scale',
)

Expand Down Expand Up @@ -432,7 +432,7 @@ def forward(self, queries, keys, values, mask=None):
)
y = lbann.WeightedSum(
y,
scaling_factors=str(1 / math.sqrt(self.head_dim)),
scaling_factors=1 / math.sqrt(self.head_dim),
name=f'{head_name}_scale',
)

Expand Down Expand Up @@ -766,7 +766,7 @@ def forward(self, queries, keys, values, mask=None):
)
y = lbann.WeightedSum(
y,
scaling_factors=str(1 / math.sqrt(self.head_dim)),
scaling_factors=1 / math.sqrt(self.head_dim),
name=f'{head_name}_scale',
)

Expand Down
2 changes: 1 addition & 1 deletion python/lbann/modules/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def forward(self, queries, keys, values, mask=None):
)
y = lbann.WeightedSum(
y,
scaling_factors=str(1 / math.sqrt(self.head_dim)),
scaling_factors=1 / math.sqrt(self.head_dim),
name=f'{head_name}_scale',
)
if mask:
Expand Down

0 comments on commit e9a31b5

Please sign in to comment.