# Source: voice/scripts/download_f5_tts.py
"""Download F5-TTS model files from Hugging Face to the local models/ directory."""

from pathlib import Path

from huggingface_hub import hf_hub_download
from loguru import logger


# Default destination directory, resolved relative to the current working directory.
MODELS_DIR = Path("models")
# Hugging Face repository the files are downloaded from by default.
REPO_ID = "SWivid/F5-TTS"

# Checkpoint path inside the repository for every supported model variant.
# NOTE(review): upstream F5-TTS appears to publish E2TTS_Base in the separate
# "SWivid/E2-TTS" repository -- confirm this entry resolves against REPO_ID.
_CHECKPOINT_FILES = {
    "F5TTS_v1_Base": "F5TTS_v1_Base/model_1250000.safetensors",
    "F5TTS_Base": "F5TTS_Base/model_1200000.safetensors",
    "E2TTS_Base": "E2TTS_Base/model_1200000.safetensors",
}


def _fetch(repo_id: str, filename: str, local_dir: Path) -> str:
    """Download one file of ``repo_id`` into ``local_dir`` and return the value
    reported by ``hf_hub_download``."""
    # NOTE(review): ``local_dir_use_symlinks`` is deprecated in recent
    # huggingface_hub releases -- check the pinned version before removing it.
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        local_dir=local_dir,
        local_dir_use_symlinks=False,
    )


def download_model(
    model_name: str = "F5TTS_v1_Base",
    repo_id: str = REPO_ID,
    local_dir: Path = MODELS_DIR,
) -> None:
    """Download the model checkpoint and vocab for the requested F5-TTS variant.

    Args:
        model_name: Variant to download; must be a key of ``_CHECKPOINT_FILES``.
        repo_id: Hugging Face repository to download from.
        local_dir: Directory that receives the files. It is created (including
            parents) when missing. A plain string is accepted as well.

    Raises:
        ValueError: If ``model_name`` is not a supported variant. Nothing is
            created on disk in that case.
    """
    # Validate first so an unsupported name does not leave an empty directory behind.
    try:
        filename = _CHECKPOINT_FILES[model_name]
    except (KeyError, TypeError):
        raise ValueError(f"Unsupported model: {model_name}") from None

    # Accept str as well as Path; Path(Path) is a no-op.
    local_dir = Path(local_dir)
    local_dir.mkdir(parents=True, exist_ok=True)

    logger.info("Downloading {} from {} ...", filename, repo_id)
    path = _fetch(repo_id, filename, local_dir)
    logger.info("Model saved to {}", path)

    # The vocab file is looked up in the same repository sub-folder as the checkpoint.
    vocab_filename = f"{model_name}/vocab.txt"
    logger.info("Downloading vocab {} from {} ...", vocab_filename, repo_id)
    vocab_path = _fetch(repo_id, vocab_filename, local_dir)
    logger.info("Vocab saved to {}", vocab_path)


if __name__ == "__main__":
    # Command-line entry point: pick the model variant and the target directory,
    # then delegate to download_model().
    import argparse

    cli = argparse.ArgumentParser(description="Download F5-TTS model files")
    # Model variant to fetch; limited to the names download_model() handles.
    cli.add_argument(
        "--model",
        choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"],
        default="F5TTS_v1_Base",
    )
    # Destination directory; parsed as a string and converted to Path below.
    cli.add_argument("--local-dir", default=str(MODELS_DIR))
    options = cli.parse_args()

    download_model(model_name=options.model, local_dir=Path(options.local_dir))