Skip to content

Commit

Permalink
Update resunet.py
Browse files Browse the repository at this point in the history
  • Loading branch information
liuxubo717 authored Oct 17, 2023
1 parent e00c895 commit 77a0e8b
Showing 1 changed file with 61 additions and 0 deletions.
61 changes: 61 additions & 0 deletions models/resunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,4 +652,65 @@ def forward(self, input_dict):

return output_dict

@torch.no_grad()
def chunk_inference(self, input_dict):
chunk_config = {
'NL': 1.0,
'NC': 3.0,
'NR': 1.0,
'RATE': self.sampling_rate
}

mixtures = input_dict['mixture']
conditions = input_dict['condition']

film_dict = self.film(
conditions=conditions,
)

NL = int(chunk_config['NL'] * chunk_config['RATE'])
NC = int(chunk_config['NC'] * chunk_config['RATE'])
NR = int(chunk_config['NR'] * chunk_config['RATE'])

L = mixtures.shape[2]

out_np = np.zeros([1, L])

WINDOW = NL + NC + NR
current_idx = 0

while current_idx + WINDOW < L:
chunk_in = mixtures[:, :, current_idx:current_idx + WINDOW]

chunk_out = self.base(
mixtures=chunk_in,
film_dict=film_dict,
)['waveform']

chunk_out_np = chunk_out.squeeze(0).cpu().data.numpy()

if current_idx == 0:
out_np[:, current_idx:current_idx+WINDOW-NR] = \
chunk_out_np[:, :-NR] if NR != 0 else chunk_out_np
else:
out_np[:, current_idx+NL:current_idx+WINDOW-NR] = \
chunk_out_np[:, NL:-NR] if NR != 0 else chunk_out_np[:, NL:]

current_idx += NC

if current_idx < L:
chunk_in = mixtures[:, :, current_idx:current_idx + WINDOW]
chunk_out = self.base(
mixtures=chunk_in,
film_dict=film_dict,
)['waveform']

chunk_out_np = chunk_out.squeeze(0).cpu().data.numpy()

seg_len = chunk_out_np.shape[1]
out_np[:, current_idx + NL:current_idx + seg_len] = \
chunk_out_np[:, NL:]

return out_np


0 comments on commit 77a0e8b

Please sign in to comment.