Skip to content

Commit

Permalink
patch stream bug
Browse files Browse the repository at this point in the history
  • Loading branch information
suxinsen committed Aug 6, 2024
1 parent fa28f99 commit 17b83e2
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions TTS/tts/layers/xtts/stream_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,14 @@ def generate(
requires_attention_mask = "encoder_outputs" not in model_kwargs

if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
pad_token_tensor = torch.tensor([generation_config.pad_token_id],
device=inputs_tensor.device) if generation_config.pad_token_id is not None else None
eos_token_tensor = torch.tensor([generation_config.eos_token_id],
device=inputs_tensor.device) if generation_config.eos_token_id is not None else None
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor,
generation_config.pad_token_id,
generation_config.eos_token_id,
pad_token_tensor,
eos_token_tensor,
)

# decoder-only models should use left-padding for generation
Expand Down Expand Up @@ -409,7 +413,8 @@ def generate(
)
elif is_sample_gen_stream_mode:
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config)
# logits_warper = self._get_logits_warper(generation_config)
logits_warper = self._get_logits_warper(generation_config, device=inputs_tensor.device)

# 12. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
Expand Down

0 comments on commit 17b83e2

Please sign in to comment.