若仔細研究UnitY Model的組成，可發現其中包含語音編碼器（含前級）、文本編碼器與文本解碼器（皆含前級）、投影層，以及文本轉unit（T2U）模型。UnitY Model的結構如下：
class UnitYModel(EncoderDecoderModel):
    """UnitY model following :cite:t:`https://doi.org/10.48550/arxiv.2212.08055`.

    On top of the speech encoder, this implementation can also hold an
    optional text encoder so that translation may start from text input.
    """

    model_dim: int
    input_modality: str
    speech_encoder_frontend: TransformerFrontend
    speech_encoder: TransformerEncoder
    text_encoder_frontend: Optional[TransformerFrontend]
    text_encoder: Optional[TransformerEncoder]
    text_decoder_frontend: TransformerFrontend
    text_decoder: TransformerDecoder
    final_proj: Projection
    t2u_model: Optional["UnitYT2UModel"]
    pad_idx: Optional[int]

    def __init__(
        self,
        speech_encoder_frontend: TransformerFrontend,
        speech_encoder: TransformerEncoder,
        text_encoder_frontend: Optional[TransformerFrontend],
        text_encoder: Optional[TransformerEncoder],
        text_decoder_frontend: TransformerFrontend,
        text_decoder: TransformerDecoder,
        final_proj: Projection,
        t2u_model: Optional["UnitYT2UModel"],
        pad_idx: Optional[int],
        input_modality: str = "speech",
    ) -> None:
        """Assemble the model from its already-built components.

        :param speech_encoder_frontend: Frontend applied before ``speech_encoder``.
        :param speech_encoder: Speech encoder; its ``model_dim`` is used as the
            model dimension passed to the base class.
        :param text_encoder_frontend: Frontend applied before ``text_encoder``;
            must be ``None`` exactly when ``text_encoder`` is ``None``.
        :param text_encoder: Optional text encoder.
        :param text_decoder_frontend: Frontend applied before ``text_decoder``.
        :param text_decoder: Text decoder.
        :param final_proj: Projection applied to the decoder output.
        :param t2u_model: Optional text-to-unit model.
        :param pad_idx: Padding index forwarded to the projected output.
        :param input_modality: ``"speech"`` or ``"text"``; selects the encoder
            used by :meth:`encode`.
        :raises ValueError: If only one of ``text_encoder`` and
            ``text_encoder_frontend`` is provided.
        """
        super().__init__(speech_encoder.model_dim)

        self.input_modality = input_modality

        self.speech_encoder_frontend = speech_encoder_frontend
        self.speech_encoder = speech_encoder

        # The text encoder and its frontend come as a pair: either both are
        # present, or both are registered as empty submodule slots.
        if text_encoder is None:
            if text_encoder_frontend is not None:
                raise ValueError(
                    "Both `text_encoder` and `text_encoder_frontend` must be specified, but `text_encoder` is `None`."
                )

            self.register_module("text_encoder_frontend", None)
            self.register_module("text_encoder", None)
        else:
            if text_encoder_frontend is None:
                raise ValueError(
                    "Both `text_encoder` and `text_encoder_frontend` must be specified, but `text_encoder_frontend` is `None`."
                )

            self.text_encoder_frontend = text_encoder_frontend
            self.text_encoder = text_encoder

        self.text_decoder_frontend = text_decoder_frontend
        self.text_decoder = text_decoder

        self.final_proj = final_proj

        # Keep the attribute defined even when no T2U model is supplied.
        if t2u_model is None:
            self.register_module("t2u_model", None)
        else:
            self.t2u_model = t2u_model

        self.pad_idx = pad_idx

        # NOTE(review): presumably validates that all submodules agree on
        # `model_dim` -- confirm against the helper's definition.
        check_model_dim(self)

    @finaloverride
    def encode(
        self, seqs: Tensor, seq_lens: Optional[Tensor]
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Encode ``seqs`` with the encoder selected by ``input_modality``.

        :raises RuntimeError: If ``input_modality`` is neither ``"speech"``
            nor ``"text"``.
        """
        modality = self.input_modality

        if modality == "speech":
            encoded = self.encode_speech(seqs, seq_lens)
        elif modality == "text":
            encoded = self.encode_text(seqs, seq_lens)
        else:
            raise RuntimeError(
                f"`input_modality` must be 'speech' or 'text', but is '{self.input_modality}' instead."
            )

        return encoded

    def encode_speech(
        self, seqs: Tensor, seq_lens: Optional[Tensor]
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Run the speech frontend followed by the speech encoder."""
        features, mask = self.speech_encoder_frontend(seqs, seq_lens)

        return self.speech_encoder(features, mask)  # type: ignore[no-any-return]

    def encode_text(
        self, seqs: Tensor, seq_lens: Optional[Tensor]
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Run the text frontend followed by the text encoder.

        :raises ValueError: If the model was built without a text encoder.
        """
        encoder = self.text_encoder
        if encoder is None:
            raise ValueError(
                "`encode_text()` requires a text encoder, but the current UnitY model does not have one."
            )

        # The constructor only accepts the encoder and frontend as a pair.
        frontend = self.text_encoder_frontend
        assert frontend is not None

        features, mask = frontend(seqs, seq_lens)

        return encoder(features, mask)  # type: ignore[no-any-return]

    @finaloverride
    def decode(
        self,
        seqs: Tensor,
        seq_lens: Optional[Tensor],
        encoder_output: Tensor,
        encoder_padding_mask: Optional[Tensor],
        state_bag: Optional[IncrementalStateBag] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Run the text decoder frontend followed by the text decoder.

        ``state_bag`` is forwarded unchanged to both the frontend and the
        decoder.
        """
        features, mask = self.text_decoder_frontend(seqs, seq_lens, state_bag)

        return self.text_decoder(  # type: ignore[no-any-return]
            features, mask, encoder_output, encoder_padding_mask, state_bag
        )

    @finaloverride
    def project(
        self, decoder_output: Tensor, decoder_padding_mask: Optional[Tensor]
    ) -> SequenceModelOutput:
        """Project the decoder output to logits and wrap them with ``pad_idx``.

        ``decoder_padding_mask`` is accepted for interface compatibility and is
        not used here.
        """
        return SequenceModelOutput(self.final_proj(decoder_output), self.pad_idx)
觀察其結構可發現，speech_encoder_frontend及speech_encoder所引用的模型分別為TransformerFrontend與TransformerEncoder（參考開源程式fairseq2.nn.transformer.TransformerEncoder及fairseq.models.transformer），因此可以將自己訓練好的Transformer模型在此處進行替換。預計以Kaggle挑戰提供的訓練音檔訓練模型，完成替換後再進行翻譯功能評估。