Skip to content

Commit

Permalink
add dual expr/non-expr streaming agent (facebookresearch#229)
Browse files Browse the repository at this point in the history
  • Loading branch information
annasun28 authored Dec 1, 2023
1 parent a452079 commit 2d27163
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 4 deletions.
116 changes: 116 additions & 0 deletions src/seamless_communication/streaming/agents/dual_vocoder_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import copy

import logging
from argparse import ArgumentParser, Namespace
from typing import Dict, Any

from simuleval.agents import TextToSpeechAgent
from seamless_communication.streaming.agents.common import AgentStates
from simuleval.data.segments import Segment
from simuleval.agents.actions import Action

from seamless_communication.streaming.agents.pretssel_vocoder import (
PretsselVocoderAgent,
)
from seamless_communication.streaming.agents.online_vocoder import VocoderAgent

# Configure logging when this module is imported: INFO threshold and a
# "<timestamp> <level> -- <logger name>: <message>" line format.
# NOTE(review): this runs as an import side effect; logging.basicConfig is
# documented to do nothing if the root logger already has handlers.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
)

# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(__name__)


class DualVocoderStates(AgentStates):
    """Composite states wrapping one states object per vocoder.

    Keeps the states of the regular vocoder and of the expressive vocoder
    side by side and forwards resets and source/target updates to both,
    always the regular vocoder first.
    """

    def __init__(
        self, vocoder_states: AgentStates, expr_vocoder_states: AgentStates
    ) -> None:
        """Store the two wrapped states objects and start with an empty config.

        Args:
            vocoder_states: states object of the regular vocoder.
            expr_vocoder_states: states object of the expressive vocoder.
        """
        # NOTE(review): the parent initializer is intentionally not invoked
        # here in the original code; this class only sets these attributes.
        self.vocoder_states = vocoder_states
        self.expr_vocoder_states = expr_vocoder_states
        self.config: Dict[str, Any] = {}

    @property
    def target_finished(self):  # type: ignore
        """Truthy as soon as either wrapped states object reports completion."""
        regular_done = self.vocoder_states.target_finished
        return regular_done or self.expr_vocoder_states.target_finished

    def reset(self) -> None:
        """Reset both wrapped states objects, then clear the config."""
        for wrapped in (self.vocoder_states, self.expr_vocoder_states):
            wrapped.reset()
        self.config = {}

    def update_source(self, segment: Segment) -> None:
        """Push the segment's config, then the segment itself, to both states."""
        for wrapped in (self.vocoder_states, self.expr_vocoder_states):
            wrapped.update_config(segment.config)
            wrapped.update_source(segment)

    def update_target(self, segment: Segment) -> None:
        """Forward a target segment to both wrapped states objects."""
        for wrapped in (self.vocoder_states, self.expr_vocoder_states):
            wrapped.update_target(segment)


class DualVocoderAgent(TextToSpeechAgent):  # type: ignore
    """Text-to-speech agent that owns a regular and an expressive vocoder.

    On every policy call exactly one of the two vocoders is run; which one
    is decided by the ``--expressive`` flag, optionally overridden by an
    ``"expressive"`` entry in ``states.config``.
    """

    def __init__(
        self,
        args: Namespace,
        vocoder: VocoderAgent,
        expr_vocoder: PretsselVocoderAgent,
    ) -> None:
        """Keep references to both vocoders and read the default mode.

        Args:
            args: parsed command-line arguments; ``args.expressive`` is read.
            vocoder: the regular vocoder agent.
            expr_vocoder: the expressive vocoder agent.
        """
        # Both sub-agents are attached before the parent initializer runs.
        # NOTE(review): presumably the parent calls build_states(), which
        # needs them — confirm against TextToSpeechAgent.
        self.vocoder = vocoder
        self.expr_vocoder = expr_vocoder
        super().__init__(args)
        self.expressive = args.expressive

    def build_states(self) -> DualVocoderStates:
        """Build a composite states object from each vocoder's own states."""
        regular_states = self.vocoder.build_states()
        expressive_states = self.expr_vocoder.build_states()
        return DualVocoderStates(regular_states, expressive_states)

    @classmethod
    def add_args(cls, parser: ArgumentParser) -> None:
        """Register both vocoders' arguments plus the dual-agent options."""
        for agent_cls in (PretsselVocoderAgent, VocoderAgent):
            agent_cls.add_args(parser)
        parser.add_argument(
            "--expr-vocoder-name",
            type=str,
            required=True,
            help="expressive vocoder name - vocoder_pretssel or vocoder_pretssel_16khz",
        )
        parser.add_argument(
            "--expressive",
            action="store_true",
            help="Whether to use expressive vocoder (overridable in segment.config)",
        )

    @classmethod
    def from_args(cls, args: Namespace, **kwargs: Dict[str, Any]) -> DualVocoderAgent:
        """Construct both vocoders from ``args`` and wrap them in this agent.

        ``kwargs`` is accepted but not used. The expressive vocoder is built
        from a deep copy of ``args`` whose ``vocoder_name`` is replaced by
        ``args.expr_vocoder_name``, so the caller's namespace is untouched.
        """
        regular = VocoderAgent.from_args(args)

        expressive_args = copy.deepcopy(args)
        expressive_args.vocoder_name = args.expr_vocoder_name
        expressive = PretsselVocoderAgent.from_args(expressive_args)

        return cls(args, regular, expressive)

    def policy(self, states: AgentStates) -> Action:
        """Run the selected vocoder's policy and return its action.

        The agent-level ``expressive`` default is used unless
        ``states.config`` carries an ``"expressive"`` key. If the selected
        vocoder's source is empty after its policy ran, the source of the
        other vocoder's states is replaced by an empty list as well.
        """
        use_expressive = self.expressive
        overrides = states.config
        if overrides is not None and "expressive" in overrides:
            use_expressive = overrides["expressive"]

        if use_expressive:
            active = states.expr_vocoder_states
            idle = states.vocoder_states
            # Only the expressive vocoder is handed the upstream states.
            active.upstream_states = states.upstream_states
            action = self.expr_vocoder.policy(active)
        else:
            active = states.vocoder_states
            idle = states.expr_vocoder_states
            action = self.vocoder.policy(active)

        # NOTE(review): presumably this keeps the unused vocoder's buffer in
        # step once the active one has emptied its source — confirm.
        if len(active.source) == 0:
            idle.source = []
        return action
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from seamless_communication.streaming.agents.online_text_decoder import (
UnitYTextDecoderOutput,
)
from seamless_communication.streaming.agents.common import AgentStates
from simuleval.agents import GenericAgent
from simuleval.agents.actions import Action, ReadAction, WriteAction
from simuleval.agents.states import AgentStates
from simuleval.data.segments import Segment, TextSegment


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

import torch
from seamless_communication.models.vocoder.loader import load_vocoder_model
from simuleval.agents import AgentStates, TextToSpeechAgent
from seamless_communication.streaming.agents.common import AgentStates
from simuleval.agents import TextToSpeechAgent
from simuleval.agents.actions import ReadAction, WriteAction
from simuleval.data.segments import SpeechSegment

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@
from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
from seamless_communication.models.unity import load_gcmvn_stats
from seamless_communication.store import add_gated_assets
from seamless_communication.streaming.agents.common import NoUpdateTargetMixin
from simuleval.agents import AgentStates, TextToSpeechAgent
from seamless_communication.streaming.agents.common import (
AgentStates,
NoUpdateTargetMixin,
)
from simuleval.agents import TextToSpeechAgent
from simuleval.agents.actions import ReadAction, WriteAction
from simuleval.data.segments import SpeechSegment

Expand Down
15 changes: 15 additions & 0 deletions src/seamless_communication/streaming/agents/seamless_s2st.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
from seamless_communication.streaming.agents.pretssel_vocoder import (
PretsselVocoderAgent,
)
from seamless_communication.streaming.agents.dual_vocoder_agent import (
DualVocoderAgent,
)
from seamless_communication.streaming.agents.silero_vad import SileroVADAgent
from seamless_communication.streaming.agents.unity_pipeline import (
UnitYAgentPipeline,
Expand Down Expand Up @@ -48,3 +51,15 @@ class SeamlessS2STJointVADAgent(UnitYAgentTreePipeline):
NARUnitYUnitDecoderAgent: [PretsselVocoderAgent],
PretsselVocoderAgent: [],
}


class SeamlessS2STDualVocoderVADAgent(UnitYAgentTreePipeline):
    """Agent tree pipeline whose unit decoder feeds ``DualVocoderAgent``.

    ``pipeline`` maps each agent class to a list of agent classes; the
    entries for ``UnitYDetokenizerAgent`` and ``DualVocoderAgent`` are empty.
    """

    # NOTE(review): each value is presumably the list of agents that consume
    # the key agent's output, with empty lists marking leaves — confirm
    # against UnitYAgentTreePipeline.
    pipeline = {
        SileroVADAgent: [OnlineFeatureExtractorAgent],
        OnlineFeatureExtractorAgent: [OfflineWav2VecBertEncoderAgent],
        OfflineWav2VecBertEncoderAgent: [UnitYMMATextDecoderAgent],
        UnitYMMATextDecoderAgent: [UnitYDetokenizerAgent, NARUnitYUnitDecoderAgent],
        UnitYDetokenizerAgent: [],
        NARUnitYUnitDecoderAgent: [DualVocoderAgent],
        DualVocoderAgent: [],
    }

0 comments on commit 2d27163

Please sign in to comment.