diff --git a/src/seamless_communication/streaming/agents/dual_vocoder_agent.py b/src/seamless_communication/streaming/agents/dual_vocoder_agent.py
new file mode 100644
index 00000000..d05a702c
--- /dev/null
+++ b/src/seamless_communication/streaming/agents/dual_vocoder_agent.py
@@ -0,0 +1,175 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Streaming agent that switches between a regular and an expressive vocoder."""
+from __future__ import annotations
+
+import copy
+import logging
+from argparse import ArgumentParser, Namespace
+from typing import Any, Dict
+
+from seamless_communication.streaming.agents.common import AgentStates
+from seamless_communication.streaming.agents.online_vocoder import VocoderAgent
+from seamless_communication.streaming.agents.pretssel_vocoder import (
+    PretsselVocoderAgent,
+)
+from simuleval.agents import TextToSpeechAgent
+from simuleval.agents.actions import Action
+from simuleval.data.segments import Segment
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+
+
+class DualVocoderStates(AgentStates):
+    """Composite states that wrap one states object per vocoder.
+
+    Source segments, target segments and resets are forwarded to both wrapped
+    states objects; ``target_finished`` is derived from them.
+    """
+
+    def __init__(
+        self, vocoder_states: AgentStates, expr_vocoder_states: AgentStates
+    ) -> None:
+        """Store the two wrapped states objects and start with an empty config.
+
+        Args:
+            vocoder_states: States passed to the regular vocoder's policy.
+            expr_vocoder_states: States passed to the expressive vocoder's
+                policy.
+        """
+        # NOTE(review): super().__init__() is not called here -- presumably
+        # because the base initializer assigns ``target_finished``, which is a
+        # setter-less property on this class. Confirm before adding the call.
+        self.vocoder_states = vocoder_states
+        self.expr_vocoder_states = expr_vocoder_states
+        self.config: Dict[str, Any] = {}
+
+    @property
+    def target_finished(self):  # type: ignore
+        """Whether either wrapped states object reports a finished target."""
+        return (
+            self.vocoder_states.target_finished
+            or self.expr_vocoder_states.target_finished
+        )
+
+    def reset(self) -> None:
+        """Reset both wrapped states objects and clear the config."""
+        self.vocoder_states.reset()
+        self.expr_vocoder_states.reset()
+        self.config = {}
+
+    def update_source(self, segment: Segment) -> None:
+        """Give ``segment`` and its config to both wrapped states objects."""
+        for sub_states in (self.vocoder_states, self.expr_vocoder_states):
+            sub_states.update_config(segment.config)
+            sub_states.update_source(segment)
+
+    def update_target(self, segment: Segment) -> None:
+        """Forward the emitted ``segment`` to both wrapped states objects."""
+        self.vocoder_states.update_target(segment)
+        self.expr_vocoder_states.update_target(segment)
+
+
+class DualVocoderAgent(TextToSpeechAgent):  # type: ignore
+    """Text-to-speech agent that delegates each step to one of two vocoders.
+
+    The expressive vocoder is selected by the ``--expressive`` flag unless the
+    states' config carries an ``"expressive"`` entry, which takes precedence.
+    """
+
+    def __init__(
+        self,
+        args: Namespace,
+        vocoder: VocoderAgent,
+        expr_vocoder: PretsselVocoderAgent,
+    ) -> None:
+        """Keep both vocoder agents and read the default mode from ``args``.
+
+        Args:
+            args: Parsed arguments; ``args.expressive`` is the default mode.
+            vocoder: Agent used when expressive mode is off.
+            expr_vocoder: Agent used when expressive mode is on.
+        """
+        # Both agents are assigned before the base initializer runs, as
+        # build_states() needs them (presumably the base class calls it from
+        # __init__ -- TODO confirm).
+        self.vocoder = vocoder
+        self.expr_vocoder = expr_vocoder
+        super().__init__(args)
+        self.expressive = args.expressive
+
+    def build_states(self) -> DualVocoderStates:
+        """Build composite states from each vocoder's own ``build_states()``."""
+        return DualVocoderStates(
+            self.vocoder.build_states(), self.expr_vocoder.build_states()
+        )
+
+    @classmethod
+    def add_args(cls, parser: ArgumentParser) -> None:
+        """Register both vocoders' arguments plus the dual-vocoder options."""
+        PretsselVocoderAgent.add_args(parser)
+        VocoderAgent.add_args(parser)
+        parser.add_argument(
+            "--expr-vocoder-name",
+            type=str,
+            required=True,
+            help="expressive vocoder name - vocoder_pretssel or vocoder_pretssel_16khz",
+        )
+        parser.add_argument(
+            "--expressive",
+            action="store_true",
+            help="Whether to use expressive vocoder (overridable in segment.config)",
+        )
+
+    @classmethod
+    def from_args(cls, args: Namespace, **kwargs: Any) -> DualVocoderAgent:
+        """Create both vocoder agents from ``args`` and wrap them.
+
+        The expressive vocoder is built from a deep copy of ``args`` whose
+        ``vocoder_name`` is replaced by ``args.expr_vocoder_name``, so this
+        method does not reassign ``args.vocoder_name``. ``kwargs`` is unused.
+        """
+        vocoder = VocoderAgent.from_args(args)
+        expr_args = copy.deepcopy(args)
+        expr_args.vocoder_name = args.expr_vocoder_name
+        expr_vocoder = PretsselVocoderAgent.from_args(expr_args)
+        return cls(args, vocoder, expr_vocoder)
+
+    def policy(self, states: AgentStates) -> Action:
+        """Run one step of the selected vocoder and return its action.
+
+        Args:
+            states: Expected to be a ``DualVocoderStates``; its
+                ``vocoder_states`` and ``expr_vocoder_states`` are used.
+
+        Returns:
+            The action produced by the selected vocoder's policy.
+        """
+        use_expressive = self.expressive
+        config = states.config
+        if config is not None and "expressive" in config:
+            # A per-segment setting takes precedence over the CLI default.
+            use_expressive = config["expressive"]
+
+        if use_expressive:
+            active, idle = states.expr_vocoder_states, states.vocoder_states
+            # Only the expressive vocoder is given the upstream states.
+            active.upstream_states = states.upstream_states
+            action = self.expr_vocoder.policy(active)
+        else:
+            active, idle = states.vocoder_states, states.expr_vocoder_states
+            action = self.vocoder.policy(active)
+
+        # When the selected vocoder's source is empty after its step, drop
+        # whatever the unselected one has buffered.
+        if len(active.source) == 0:
+            idle.source = []
+        return action
diff --git 
a/src/seamless_communication/streaming/agents/online_unit_decoder.py b/src/seamless_communication/streaming/agents/online_unit_decoder.py
index a5195e01..ac96cf85 100644
--- a/src/seamless_communication/streaming/agents/online_unit_decoder.py
+++ b/src/seamless_communication/streaming/agents/online_unit_decoder.py
@@ -14,9 +14,9 @@
 from seamless_communication.streaming.agents.online_text_decoder import (
     UnitYTextDecoderOutput,
 )
+from seamless_communication.streaming.agents.common import AgentStates
 from simuleval.agents import GenericAgent
 from simuleval.agents.actions import Action, ReadAction, WriteAction
-from simuleval.agents.states import AgentStates
 from simuleval.data.segments import Segment, TextSegment
 
 
diff --git a/src/seamless_communication/streaming/agents/online_vocoder.py b/src/seamless_communication/streaming/agents/online_vocoder.py
index e943ae3e..2ca1ea93 100644
--- a/src/seamless_communication/streaming/agents/online_vocoder.py
+++ b/src/seamless_communication/streaming/agents/online_vocoder.py
@@ -11,7 +11,8 @@
 
 import torch
 from seamless_communication.models.vocoder.loader import load_vocoder_model
-from simuleval.agents import AgentStates, TextToSpeechAgent
+from seamless_communication.streaming.agents.common import AgentStates
+from simuleval.agents import TextToSpeechAgent
 from simuleval.agents.actions import ReadAction, WriteAction
 from simuleval.data.segments import SpeechSegment
 
diff --git a/src/seamless_communication/streaming/agents/pretssel_vocoder.py b/src/seamless_communication/streaming/agents/pretssel_vocoder.py
index 50912a47..b7945d34 100644
--- a/src/seamless_communication/streaming/agents/pretssel_vocoder.py
+++ b/src/seamless_communication/streaming/agents/pretssel_vocoder.py
@@ -16,8 +16,11 @@
 from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
 from seamless_communication.models.unity import load_gcmvn_stats
 from seamless_communication.store import add_gated_assets
-from seamless_communication.streaming.agents.common import NoUpdateTargetMixin
-from simuleval.agents import AgentStates, TextToSpeechAgent
+from seamless_communication.streaming.agents.common import (
+    AgentStates,
+    NoUpdateTargetMixin,
+)
+from simuleval.agents import TextToSpeechAgent
 from simuleval.agents.actions import ReadAction, WriteAction
 from simuleval.data.segments import SpeechSegment
 
diff --git a/src/seamless_communication/streaming/agents/seamless_s2st.py b/src/seamless_communication/streaming/agents/seamless_s2st.py
index 7d756fd1..07932ea7 100644
--- a/src/seamless_communication/streaming/agents/seamless_s2st.py
+++ b/src/seamless_communication/streaming/agents/seamless_s2st.py
@@ -21,6 +21,9 @@
 from seamless_communication.streaming.agents.pretssel_vocoder import (
     PretsselVocoderAgent,
 )
+from seamless_communication.streaming.agents.dual_vocoder_agent import (
+    DualVocoderAgent,
+)
 from seamless_communication.streaming.agents.silero_vad import SileroVADAgent
 from seamless_communication.streaming.agents.unity_pipeline import (
     UnitYAgentPipeline,
@@ -48,3 +51,21 @@ class SeamlessS2STJointVADAgent(UnitYAgentTreePipeline):
         NARUnitYUnitDecoderAgent: [PretsselVocoderAgent],
         PretsselVocoderAgent: [],
     }
+
+
+class SeamlessS2STDualVocoderVADAgent(UnitYAgentTreePipeline):
+    """Pipeline from ``SileroVADAgent`` through to ``DualVocoderAgent``.
+
+    NOTE(review): each ``pipeline`` key presumably feeds the agents listed as
+    its value -- confirm against ``UnitYAgentTreePipeline``.
+    """
+
+    pipeline = {
+        SileroVADAgent: [OnlineFeatureExtractorAgent],
+        OnlineFeatureExtractorAgent: [OfflineWav2VecBertEncoderAgent],
+        OfflineWav2VecBertEncoderAgent: [UnitYMMATextDecoderAgent],
+        UnitYMMATextDecoderAgent: [UnitYDetokenizerAgent, NARUnitYUnitDecoderAgent],
+        UnitYDetokenizerAgent: [],
+        NARUnitYUnitDecoderAgent: [DualVocoderAgent],
+        DualVocoderAgent: [],
+    }