"""Download F5-TTS model files from Hugging Face to the local models/ directory."""
from pathlib import Path
from huggingface_hub import hf_hub_download
from loguru import logger
MODELS_DIR = Path("models")
REPO_ID = "SWivid/F5-TTS"
# Checkpoint file, relative to the repository root, for each supported variant.
# NOTE(review): confirm every variant (checkpoint and vocab.txt) is actually
# hosted under the default REPO_ID; callers can override via `repo_id`.
_CHECKPOINT_FILES = {
    "F5TTS_v1_Base": "F5TTS_v1_Base/model_1250000.safetensors",
    "F5TTS_Base": "F5TTS_Base/model_1200000.safetensors",
    "E2TTS_Base": "E2TTS_Base/model_1200000.safetensors",
}


def download_model(
    model_name: str = "F5TTS_v1_Base",
    repo_id: str = REPO_ID,
    local_dir: Path = MODELS_DIR,
) -> None:
    """Download the model checkpoint and vocab for the requested F5-TTS variant.

    Two files are fetched with ``hf_hub_download``: the variant's
    ``.safetensors`` checkpoint and ``<model_name>/vocab.txt``. The path
    returned by each download is logged.

    Args:
        model_name: Variant to fetch; one of ``F5TTS_v1_Base``,
            ``F5TTS_Base`` or ``E2TTS_Base``.
        repo_id: Hugging Face repository to download from.
        local_dir: Destination directory; created (with parents) if missing.

    Raises:
        ValueError: If ``model_name`` is not a supported variant. Raised
            before anything is written to disk.
    """
    # Validate first so an unsupported name never leaves an empty directory behind.
    filename = _CHECKPOINT_FILES.get(model_name)
    if filename is None:
        raise ValueError(f"Unsupported model: {model_name}")

    local_dir.mkdir(parents=True, exist_ok=True)

    logger.info("Downloading {} from {} ...", filename, repo_id)
    # NOTE(review): `local_dir_use_symlinks` appears to be deprecated in newer
    # huggingface_hub releases — confirm against the pinned version.
    path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        local_dir=local_dir,
        local_dir_use_symlinks=False,
    )
    logger.info("Model saved to {}", path)

    # The vocab file sits next to the checkpoint, in the variant's own folder.
    vocab_filename = f"{model_name}/vocab.txt"
    logger.info("Downloading vocab {} from {} ...", vocab_filename, repo_id)
    vocab_path = hf_hub_download(
        repo_id=repo_id,
        filename=vocab_filename,
        local_dir=local_dir,
        local_dir_use_symlinks=False,
    )
    logger.info("Vocab saved to {}", vocab_path)
if __name__ == "__main__":
    # Command-line entry point: choose a model variant and a destination
    # directory, then hand both to download_model().
    import argparse

    supported_models = ["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"]

    cli = argparse.ArgumentParser(description="Download F5-TTS model files")
    cli.add_argument("--model", default="F5TTS_v1_Base", choices=supported_models)
    cli.add_argument("--local-dir", default=str(MODELS_DIR))
    options = cli.parse_args()

    # argparse yields a plain string; download_model() expects a Path.
    target_dir = Path(options.local_dir)
    download_model(options.model, local_dir=target_dir)