Skip to content

Commit 357c51b

Browse files
authored
Merge pull request #45 from audiohacking/thinking-lm
thinking-lm: Quality presets, wire Thinking/LM params, LM planner fro…
2 parents b547f9d + 59f10f8 commit 357c51b

19 files changed

Lines changed: 1831 additions & 115 deletions

CDMF.spec

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,9 @@ a = Analysis(
246246
'cdmf_pipeline_ace_step',
247247
# Trainer CLI parser (--train --help path; avoids loading full cdmf_trainer in frozen app)
248248
'cdmf_trainer_parser',
249+
# ACE-Step 1.5 model downloader (bundled so acestep-download is always available)
250+
'acestep15_downloader',
251+
'acestep15_downloader.model_downloader',
249252
# Lyrics prompt model (lazily imported in cdmf_generation.py)
250253
'lyrics_prompt_model',
251254
# ACE-Step package and all its submodules (critical for frozen app)

acestep15_downloader/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# ACE-Step 1.5 model downloader (vendored from github.com/ace-step/ACE-Step-1.5).
2+
# Use: python -m acestep15_downloader.model_downloader --dir <path> [--model <name>]
3+
# Or call acestep15_downloader.model_downloader.main() programmatically.
Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
"""
2+
ACE-Step 1.5 Model Downloader (vendored from ACE-Step-1.5 for AceForge bundle).
3+
Downloads models from HuggingFace Hub or ModelScope (ModelScope optional).
4+
Source: https://github.com/ace-step/ACE-Step-1.5/blob/main/acestep/model_downloader.py
5+
6+
AceForge extensions: progress callback and cancel support for UI (tqdm_class + callbacks).
7+
"""
8+
9+
import argparse
10+
import os
11+
import socket
12+
import sys
13+
from pathlib import Path
14+
from typing import Callable, Dict, List, Optional, Tuple
15+
16+
try:
17+
from loguru import logger
18+
except ImportError:
19+
import logging
20+
logger = logging.getLogger(__name__)
21+
22+
23+
class DownloadCancelled(Exception):
24+
"""Raised when the user cancels an in-progress model download."""
25+
pass
26+
27+
28+
# =============================================================================
29+
# Network & Download
30+
# =============================================================================
31+
32+
def _can_access_google(timeout: float = 3.0) -> bool:
33+
try:
34+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
35+
sock.settimeout(timeout)
36+
sock.connect(("www.google.com", 443))
37+
sock.close()
38+
return True
39+
except (socket.timeout, socket.error, OSError):
40+
return False
41+
42+
def _make_progress_tqdm(
43+
progress_callback: Optional[Callable[..., None]],
44+
cancel_check: Optional[Callable[[], bool]],
45+
):
46+
"""Build a tqdm subclass that reports progress and respects cancel_check."""
47+
try:
48+
from tqdm.auto import tqdm as base_tqdm
49+
except ImportError:
50+
base_tqdm = None
51+
52+
if base_tqdm is None:
53+
return None
54+
55+
progress_cb = progress_callback
56+
cancel_fn = cancel_check
57+
58+
class ProgressTqdm(base_tqdm):
59+
def update(self, n: int = 1) -> Optional[bool]:
60+
if cancel_fn and cancel_fn():
61+
raise DownloadCancelled("Download cancelled by user")
62+
result = super().update(n)
63+
if progress_cb and self.total:
64+
try:
65+
progress_cb(
66+
file_index=int(self.n),
67+
total_files=int(self.total),
68+
current_file=str(self.desc) if self.desc else None,
69+
fraction=self.n / self.total,
70+
)
71+
except Exception:
72+
pass
73+
return result
74+
75+
return ProgressTqdm
76+
77+
78+
def _download_from_huggingface(
79+
repo_id: str,
80+
local_dir: Path,
81+
token: Optional[str] = None,
82+
progress_callback: Optional[Callable[..., None]] = None,
83+
cancel_check: Optional[Callable[[], bool]] = None,
84+
) -> None:
85+
from huggingface_hub import snapshot_download
86+
logger.info(f"[Model Download] Downloading from HuggingFace: {repo_id} -> {local_dir}")
87+
tqdm_class = _make_progress_tqdm(progress_callback, cancel_check)
88+
kwargs = dict(
89+
repo_id=repo_id,
90+
local_dir=str(local_dir),
91+
local_dir_use_symlinks=False,
92+
token=token,
93+
max_workers=4,
94+
)
95+
if tqdm_class is not None:
96+
kwargs["tqdm_class"] = tqdm_class
97+
snapshot_download(**kwargs)
98+
99+
def _download_from_modelscope(repo_id: str, local_dir: Path) -> None:
100+
try:
101+
from modelscope import snapshot_download
102+
except ImportError:
103+
raise RuntimeError("ModelScope not installed. Install with: pip install modelscope")
104+
logger.info(f"[Model Download] Downloading from ModelScope: {repo_id} -> {local_dir}")
105+
snapshot_download(model_id=repo_id, local_dir=str(local_dir))
106+
107+
def _smart_download(
108+
repo_id: str,
109+
local_dir: Path,
110+
token: Optional[str] = None,
111+
prefer_source: Optional[str] = None,
112+
progress_callback: Optional[Callable[..., None]] = None,
113+
cancel_check: Optional[Callable[[], bool]] = None,
114+
) -> Tuple[bool, str]:
115+
local_dir.mkdir(parents=True, exist_ok=True)
116+
use_hf_first = prefer_source != "modelscope" if prefer_source else _can_access_google()
117+
hf_kw = {"progress_callback": progress_callback, "cancel_check": cancel_check}
118+
if use_hf_first:
119+
try:
120+
_download_from_huggingface(repo_id, local_dir, token, **hf_kw)
121+
return True, f"Successfully downloaded from HuggingFace: {repo_id}"
122+
except DownloadCancelled:
123+
raise
124+
except Exception as e:
125+
logger.warning(f"[Model Download] HuggingFace failed: {e}")
126+
try:
127+
_download_from_modelscope(repo_id, local_dir)
128+
return True, f"Successfully downloaded from ModelScope: {repo_id}"
129+
except Exception as e2:
130+
return False, f"Both sources failed. HF: {e}, MS: {e2}"
131+
else:
132+
try:
133+
_download_from_modelscope(repo_id, local_dir)
134+
return True, f"Successfully downloaded from ModelScope: {repo_id}"
135+
except Exception as e:
136+
logger.warning(f"[Model Download] ModelScope failed: {e}")
137+
try:
138+
_download_from_huggingface(repo_id, local_dir, token, **hf_kw)
139+
return True, f"Successfully downloaded from HuggingFace: {repo_id}"
140+
except DownloadCancelled:
141+
raise
142+
except Exception as e2:
143+
return False, f"Both sources failed. MS: {e}, HF: {e2}"
144+
145+
# =============================================================================
146+
# Model Registry (ACE-Step 1.5)
147+
# =============================================================================
148+
MAIN_MODEL_REPO = "ACE-Step/Ace-Step1.5"
149+
SUBMODEL_REGISTRY: Dict[str, str] = {
150+
"acestep-5Hz-lm-0.6B": "ACE-Step/acestep-5Hz-lm-0.6B",
151+
"acestep-5Hz-lm-4B": "ACE-Step/acestep-5Hz-lm-4B",
152+
"acestep-v15-turbo-shift3": "ACE-Step/acestep-v15-turbo-shift3",
153+
"acestep-v15-sft": "ACE-Step/acestep-v15-sft",
154+
"acestep-v15-base": "ACE-Step/acestep-v15-base",
155+
"acestep-v15-turbo-shift1": "ACE-Step/acestep-v15-turbo-shift1",
156+
"acestep-v15-turbo-continuous": "ACE-Step/acestep-v15-turbo-continuous",
157+
}
158+
MAIN_MODEL_COMPONENTS = [
159+
"acestep-v15-turbo",
160+
"vae",
161+
"Qwen3-Embedding-0.6B",
162+
"acestep-5Hz-lm-1.7B",
163+
]
164+
DEFAULT_LM_MODEL = "acestep-5Hz-lm-1.7B"
165+
166+
def get_checkpoints_dir(custom_dir: Optional[str] = None) -> Path:
167+
if custom_dir:
168+
return Path(custom_dir).resolve()
169+
return Path.cwd() / "checkpoints"
170+
171+
def check_main_model_exists(checkpoints_dir: Optional[Path] = None) -> bool:
172+
if checkpoints_dir is None:
173+
checkpoints_dir = get_checkpoints_dir()
174+
for component in MAIN_MODEL_COMPONENTS:
175+
if not (checkpoints_dir / component).exists():
176+
return False
177+
return True
178+
179+
def check_model_exists(model_name: str, checkpoints_dir: Optional[Path] = None) -> bool:
180+
if checkpoints_dir is None:
181+
checkpoints_dir = get_checkpoints_dir()
182+
return (checkpoints_dir / model_name).exists()
183+
184+
def download_main_model(
185+
checkpoints_dir: Optional[Path] = None,
186+
force: bool = False,
187+
token: Optional[str] = None,
188+
prefer_source: Optional[str] = None,
189+
progress_callback: Optional[Callable[..., None]] = None,
190+
cancel_check: Optional[Callable[[], bool]] = None,
191+
) -> Tuple[bool, str]:
192+
if checkpoints_dir is None:
193+
checkpoints_dir = get_checkpoints_dir()
194+
checkpoints_dir.mkdir(parents=True, exist_ok=True)
195+
if not force and check_main_model_exists(checkpoints_dir):
196+
return True, f"Main model already exists at {checkpoints_dir}"
197+
return _smart_download(
198+
MAIN_MODEL_REPO, checkpoints_dir, token, prefer_source,
199+
progress_callback=progress_callback, cancel_check=cancel_check,
200+
)
201+
202+
def download_submodel(
203+
model_name: str,
204+
checkpoints_dir: Optional[Path] = None,
205+
force: bool = False,
206+
token: Optional[str] = None,
207+
prefer_source: Optional[str] = None,
208+
progress_callback: Optional[Callable[..., None]] = None,
209+
cancel_check: Optional[Callable[[], bool]] = None,
210+
) -> Tuple[bool, str]:
211+
if model_name not in SUBMODEL_REGISTRY:
212+
return False, f"Unknown model '{model_name}'. Available: {', '.join(SUBMODEL_REGISTRY.keys())}"
213+
if checkpoints_dir is None:
214+
checkpoints_dir = get_checkpoints_dir()
215+
checkpoints_dir.mkdir(parents=True, exist_ok=True)
216+
model_path = checkpoints_dir / model_name
217+
if not force and model_path.exists():
218+
return True, f"Model '{model_name}' already exists at {model_path}"
219+
repo_id = SUBMODEL_REGISTRY[model_name]
220+
return _smart_download(
221+
repo_id, model_path, token, prefer_source,
222+
progress_callback=progress_callback, cancel_check=cancel_check,
223+
)
224+
225+
def main() -> int:
226+
parser = argparse.ArgumentParser(description="Download ACE-Step 1.5 models (HuggingFace / ModelScope)")
227+
parser.add_argument("--model", "-m", type=str, help="Model to download (use --list to see available)")
228+
parser.add_argument("--all", "-a", action="store_true", help="Download all models")
229+
parser.add_argument("--list", "-l", action="store_true", help="List available models")
230+
parser.add_argument("--dir", "-d", type=str, default=None, help="Checkpoints directory (required when run from AceForge)")
231+
parser.add_argument("--force", "-f", action="store_true", help="Force re-download")
232+
parser.add_argument("--token", "-t", type=str, default=None, help="HuggingFace token")
233+
parser.add_argument("--skip-main", action="store_true", help="Skip main model when downloading a sub-model")
234+
args = parser.parse_args()
235+
236+
if args.list:
237+
print("\nAvailable models:")
238+
print(" main ->", MAIN_MODEL_REPO, "(vae, turbo DiT, 1.7B LM)")
239+
for name, repo in SUBMODEL_REGISTRY.items():
240+
print(f" {name} -> {repo}")
241+
return 0
242+
243+
checkpoints_dir = get_checkpoints_dir(args.dir)
244+
if not args.dir:
245+
print(f"Checkpoints directory: {checkpoints_dir} (use --dir to override)")
246+
checkpoints_dir.mkdir(parents=True, exist_ok=True)
247+
248+
if args.all:
249+
success, msg = download_main_model(checkpoints_dir, args.force, args.token)
250+
print(msg)
251+
if not success:
252+
return 1
253+
for name in SUBMODEL_REGISTRY:
254+
ok, m = download_submodel(name, checkpoints_dir, args.force, args.token)
255+
print(m)
256+
if not ok:
257+
success = False
258+
return 0 if success else 1
259+
260+
if args.model:
261+
if args.model == "main":
262+
success, msg = download_main_model(checkpoints_dir, args.force, args.token)
263+
elif args.model in SUBMODEL_REGISTRY:
264+
if not args.skip_main and not check_main_model_exists(checkpoints_dir):
265+
print("Main model not found. Downloading main model first...")
266+
ok, m = download_main_model(checkpoints_dir, args.force, args.token)
267+
print(m)
268+
if not ok:
269+
return 1
270+
success, msg = download_submodel(args.model, checkpoints_dir, args.force, args.token)
271+
else:
272+
print(f"Unknown model: {args.model}. Use --list to see available models.")
273+
return 1
274+
print(msg)
275+
return 0 if success else 1
276+
277+
# Default: main model
278+
print("Downloading main model (vae, turbo DiT, 1.7B LM)...")
279+
success, msg = download_main_model(checkpoints_dir, args.force, args.token)
280+
print(msg)
281+
return 0 if success else 1
282+
283+
if __name__ == "__main__":
284+
sys.exit(main())

api/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from api.reference_tracks import bp as reference_tracks_bp
1212
from api.search import bp as search_bp
1313
from api.preferences import bp as preferences_bp
14+
from api.ace_step_models import bp as ace_step_models_bp
1415

1516
__all__ = [
1617
"auth_bp",
@@ -22,4 +23,5 @@
2223
"reference_tracks_bp",
2324
"search_bp",
2425
"preferences_bp",
26+
"ace_step_models_bp",
2527
]

0 commit comments

Comments
 (0)