Commit: Add test files

sooftware committed Jan 30, 2021
1 parent 4e2734e commit 542aa79
Showing 13 changed files with 364 additions and 308 deletions.
3 changes: 1 addition & 2 deletions kospeech/models/__init__.py

@@ -29,8 +29,7 @@ class ModelConfig:
 
 from kospeech.models.deepspeech2.model import DeepSpeech2
 from kospeech.models.las.encoder import EncoderRNN
-from kospeech.models.las.decoder import DecoderRNN
-from kospeech.models.las.topk_decoder import TopKDecoder
+from kospeech.models.las.decoder import DecoderRNN, BeamDecoderRNN
 from kospeech.models.las.model import ListenAttendSpell
 from kospeech.models.transformer.model import SpeechTransformer
 from kospeech.models.jasper.model import Jasper
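The LAS beam-search decoder moves out of the separate kospeech.models.las.topk_decoder module and is now exported from decoder.py as BeamDecoderRNN. Downstream code only needs the new import path; a minimal sketch (assuming the package layout in this commit):

    from kospeech.models.las.decoder import DecoderRNN, BeamDecoderRNN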
4 changes: 2 additions & 2 deletions kospeech/models/conformer/model.py

@@ -86,10 +86,10 @@ def __init__(
         )
         self.fc = nn.Sequential(
             LayerNorm(encoder_dim),
-            Linear(encoder_dim, num_classes, bias=False)
+            Linear(encoder_dim, num_classes, bias=False),
         )
 
     def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
         outputs, output_lengths = self.encoder(inputs, input_lengths)
-        outputs = self.fc(outputs).log_softmax(dim=-1)
+        outputs = self.fc(outputs).log_softmax(dim=2)
         return outputs, output_lengths
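The dim change in forward is cosmetic for the usual 3-D (batch, time, num_classes) output: dim=2 and dim=-1 address the same axis, so the commit only makes the class dimension explicit. A minimal self-contained check (shapes chosen arbitrarily for illustration):

    import torch

    x = torch.randn(4, 50, 10)  # (batch, time, num_classes)
    assert torch.allclose(x.log_softmax(dim=2), x.log_softmax(dim=-1))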
183 changes: 180 additions & 3 deletions kospeech/models/las/decoder.py

@@ -48,7 +48,7 @@ class DecoderRNN(DecoderInterface):
     Inputs: inputs, encoder_outputs, teacher_forcing_ratio
         - **inputs** (batch, seq_len, input_size): list of sequences, whose length is the batch size and within which
           each sequence is a list of token IDs. It is used for teacher forcing when provided. (default `None`)
-        - **encoder_outputs** (batch, seq_len, hidden_dim): tensor containing the outputs of the encoder.
+        - **encoder_outputs** (batch, seq_len, hidden_state_dim): tensor containing the outputs of the encoder.
           Used for attention mechanism (default is `None`).
         - **teacher_forcing_ratio** (float): The probability that teacher forcing will be used. A random number is
           drawn uniformly from 0-1 for every decoding token, and if the sample is smaller than the given value,
@@ -263,9 +263,28 @@ def _validate_args(
 
 
 class BeamDecoderRNN(DecoderInterface):
-    def __init__(self, decoder: DecoderRNN):
+    """ Beam Search Decoder RNN """
+    def __init__(self, decoder: DecoderRNN, beam_size: int, batch_size: int):
         super(BeamDecoderRNN, self).__init__()
         self.decoder = decoder
+        self.beam_size = beam_size
+        self.batch_size = batch_size
+        self.hidden_state_dim = decoder.hidden_state_dim
+        self.pad_id = decoder.pad_id
+        self.eos_id = decoder.eos_id
+        self.device = decoder.device
+        self.num_layers = decoder.num_layers
+        self.ongoing_beams = None
+        self.cumulative_ps = None
+        self.finished = [[] for _ in range(batch_size)]
+        self.finished_ps = [[] for _ in range(batch_size)]
+        self.validate_args = decoder.validate_args
+
+    def _inflate(self, tensor: Tensor, n_repeat: int, dim: int) -> Tensor:
+        repeat_dims = [1] * len(tensor.size())
+        repeat_dims[dim] *= n_repeat
+
+        return tensor.repeat(*repeat_dims)
 
     def forward_step(
             self,
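The new _inflate helper simply tiles a tensor along one dimension so that encoder outputs and hidden states can be shared across beams. A toy illustration of its effect (values arbitrary):

    import torch

    t = torch.tensor([[1, 2], [3, 4]])  # (2, 2): two batch elements
    repeat_dims = [1, 1]
    repeat_dims[0] *= 3                 # inflate dim=0 by beam_size = 3
    inflated = t.repeat(*repeat_dims)   # (6, 2): the whole tensor tiled three times
    # inflated: [[1, 2], [3, 4], [1, 2], [3, 4], [1, 2], [3, 4]]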
@@ -282,4 +301,162 @@ def forward(self, encoder_outputs: Tensor) -> list:
     @torch.no_grad()
     def decode(self, encoder_outputs: Tensor, encoder_output_lengths: Tensor) -> Tensor:
         """ Applies beam search decoding (top-k decoding) """
-        pass
+        batch_size, hidden_states = encoder_outputs.size(0), None
+        inputs, batch_size, max_length = self.validate_args(None, encoder_outputs, teacher_forcing_ratio=0.0)
+
+        step_outputs, hidden_states, attn = self.forward_step(inputs, hidden_states, encoder_outputs)
+        self.cumulative_ps, self.ongoing_beams = step_outputs.topk(self.beam_size)
+
+        self.ongoing_beams = self.ongoing_beams.view(batch_size * self.beam_size, 1)
+        self.cumulative_ps = self.cumulative_ps.view(batch_size * self.beam_size, 1)
+
+        input_var = self.ongoing_beams
+
+        encoder_dim = encoder_outputs.size(2)
+        encoder_outputs = self._inflate(encoder_outputs, self.beam_size, dim=0)
+        encoder_outputs = encoder_outputs.view(self.beam_size, batch_size, -1, encoder_dim)
+        encoder_outputs = encoder_outputs.transpose(0, 1)
+        encoder_outputs = encoder_outputs.reshape(batch_size * self.beam_size, -1, encoder_dim)
+        hidden_states = self._inflate(hidden_states, self.beam_size, dim=1)
+
+        for di in range(max_length - 1):
+            if self._is_all_finished(self.beam_size):
+                break
+
+            hidden_states = hidden_states.view(self.num_layers, batch_size * self.beam_size, self.hidden_state_dim)
+            step_outputs, hidden_states, attn = self.forward_step(input_var, hidden_states, encoder_outputs, attn)
+
+            step_outputs = step_outputs.view(batch_size, self.beam_size, -1)
+            current_ps, current_vs = step_outputs.topk(self.beam_size)
+
+            self.cumulative_ps = self.cumulative_ps.view(batch_size, self.beam_size)
+            self.ongoing_beams = self.ongoing_beams.view(batch_size, self.beam_size, -1)
+
+            current_ps = (current_ps.permute(0, 2, 1) + self.cumulative_ps.unsqueeze(1)).permute(0, 2, 1)
+            current_ps = current_ps.view(batch_size, self.beam_size ** 2)
+            current_vs = current_vs.view(batch_size, self.beam_size ** 2)
+
+            self.cumulative_ps = self.cumulative_ps.view(batch_size, self.beam_size)
+            self.ongoing_beams = self.ongoing_beams.view(batch_size, self.beam_size, -1)
+
+            topk_current_ps, topk_status_ids = current_ps.topk(self.beam_size)
+            prev_status_ids = (topk_status_ids // self.beam_size)
+
+            topk_current_vs = torch.zeros((batch_size, self.beam_size), dtype=torch.long)
+            prev_status = torch.zeros(self.ongoing_beams.size(), dtype=torch.long)
+
+            for batch_idx, batch in enumerate(topk_status_ids):
+                for idx, topk_status_idx in enumerate(batch):
+                    topk_current_vs[batch_idx, idx] = current_vs[batch_idx, topk_status_idx]
+                    prev_status[batch_idx, idx] = self.ongoing_beams[batch_idx, prev_status_ids[batch_idx, idx]]
+
+            self.ongoing_beams = torch.cat([prev_status, topk_current_vs.unsqueeze(2)], dim=2).to(self.device)
+            self.cumulative_ps = topk_current_ps.to(self.device)
+
+            if torch.any(topk_current_vs == self.eos_id):
+                finished_ids = torch.where(topk_current_vs == self.eos_id)
+                num_successors = [1] * batch_size
+
+                for (batch_idx, idx) in zip(*finished_ids):
+                    self.finished[batch_idx].append(self.ongoing_beams[batch_idx, idx])
+                    self.finished_ps[batch_idx].append(self.cumulative_ps[batch_idx, idx])
+
+                    if self.beam_size != 1:
+                        eos_count = self._get_successor(
+                            current_ps=current_ps,
+                            current_vs=current_vs,
+                            finished_ids=(batch_idx, idx),
+                            num_successor=num_successors[batch_idx],
+                            eos_count=1,
+                            k=self.beam_size,
+                        )
+                        num_successors[batch_idx] += eos_count
+
+            input_var = self.ongoing_beams[:, :, -1]
+            input_var = input_var.view(batch_size * self.beam_size, -1)
+
+        predictions = self._get_hypothesis()
+        predictions = torch.stack(predictions, dim=1)
+        return predictions

+    def _get_successor(
+            self,
+            current_ps: Tensor,
+            current_vs: Tensor,
+            finished_ids: tuple,
+            num_successor: int,
+            eos_count: int,
+            k: int
+    ) -> int:
+        finished_batch_idx, finished_idx = finished_ids
+
+        successor_ids = current_ps.topk(k + num_successor)[1]
+        successor_idx = successor_ids[finished_batch_idx, -1]
+
+        successor_p = current_ps[finished_batch_idx, successor_idx].to(self.device)
+        successor_v = current_vs[finished_batch_idx, successor_idx].to(self.device)
+
+        prev_status_idx = (successor_idx // k)
+        prev_status = self.ongoing_beams[finished_batch_idx, prev_status_idx]
+        prev_status = prev_status.view(-1)[:-1].to(self.device)
+
+        successor = torch.cat([prev_status, successor_v.view(1)])
+
+        if int(successor_v) == self.eos_id:
+            self.finished[finished_batch_idx].append(successor)
+            self.finished_ps[finished_batch_idx].append(successor_p)
+            eos_count = self._get_successor(
+                current_ps=current_ps,
+                current_vs=current_vs,
+                finished_ids=finished_ids,
+                num_successor=num_successor + eos_count,
+                eos_count=eos_count + 1,
+                k=k,
+            )
+
+        else:
+            self.ongoing_beams[finished_batch_idx, finished_idx] = successor
+            self.cumulative_ps[finished_batch_idx, finished_idx] = successor_p
+
+        return eos_count
+
+    def _get_hypothesis(self):
+        y_hats = list()
+
+        for batch_idx, batch in enumerate(self.finished):
+            # if there are no finished sentences, take the ongoing sentence with the highest probability instead
+            if len(batch) == 0:
+                prob_batch = self.cumulative_ps[batch_idx].to(self.device)
+                top_beam_idx = int(prob_batch.topk(1)[1])
+                y_hats.append(self.ongoing_beams[batch_idx, top_beam_idx])
+
+            # otherwise, take the finished sentence with the highest probability
+            else:
+                top_beam_idx = int(torch.FloatTensor(self.finished_ps[batch_idx]).topk(1)[1])
+                y_hats.append(self.finished[batch_idx][top_beam_idx])
+
+        y_hats = self._fill_sequence(y_hats).to(self.device)
+        return y_hats
+
+    def _is_all_finished(self, k: int) -> bool:
+        for done in self.finished:
+            if len(done) < k:
+                return False
+
+        return True
+
+    def _fill_sequence(self, y_hats: list) -> Tensor:
+        batch_size = len(y_hats)
+        max_length = -1
+
+        for y_hat in y_hats:
+            if len(y_hat) > max_length:
+                max_length = len(y_hat)
+
+        matched = torch.zeros((batch_size, max_length), dtype=torch.long).to(self.device)
+
+        for batch_idx, y_hat in enumerate(y_hats):
+            matched[batch_idx, :len(y_hat)] = y_hat
+            matched[batch_idx, len(y_hat):] = int(self.pad_id)
+
+        return matched
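_fill_sequence right-pads the variable-length hypotheses into one (batch, max_length) LongTensor. A toy equivalent (token IDs arbitrary, pad_id assumed to be 0):

    import torch

    y_hats = [torch.tensor([5, 7, 2]), torch.tensor([9, 2])]
    pad_id = 0  # assumed; the real value comes from decoder.pad_id
    max_length = max(len(y) for y in y_hats)

    matched = torch.full((len(y_hats), max_length), pad_id, dtype=torch.long)
    for batch_idx, y_hat in enumerate(y_hats):
        matched[batch_idx, :len(y_hat)] = y_hat
    # matched: [[5, 7, 2], [9, 2, 0]]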