| | import os |
| | import torch |
| | import torchaudio |
| | import psutil |
| | import time |
| | import sys |
| | import numpy as np |
| | import gc |
| | import gradio as gr |
| | from pydub import AudioSegment |
| | import soundfile as sf |
| | import pyloudnorm as pyln |
| | from audiocraft.models import MusicGen |
| | from torch.amp import autocast |
| | import json |
| | import configparser |
| | import random |
| | import string |
| | import uvicorn |
| | from fastapi import FastAPI, HTTPException |
| | from fastapi.responses import FileResponse |
| | from pydantic import BaseModel |
| | import multiprocessing |
| | import re |
| | import datetime |
| | import warnings |
| |
|
| | |
| | |
| | |
# Silence library UserWarnings and force the "spawn" start method so any
# worker processes begin with a fresh interpreter.
warnings.filterwarnings("ignore", category=UserWarning)
multiprocessing.set_start_method('spawn', force=True)

# CUDA runtime / PyTorch allocator knobs. They are set after `import torch`
# but before the first CUDA call further down in this module.
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes every kernel launch synchronous,
# which is normally a debugging setting — confirm it is wanted in production.
os.environ.update({
    "TORCH_NN_UTILS_LOG_LEVEL": "0",
    "CUDA_LAUNCH_BLOCKING": "1",
    "CUDA_MODULE_LOADING": "LAZY",
    "TORCH_USE_CUDA_DSA": "1",
    "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:128,garbage_collection_threshold:0.8,expandable_segments:True",
    "TORCH_CUDA_ARCH_LIST": "7.5;8.0;8.6;8.9",
})

# Best effort: enable TF32 matmuls and cuDNN autotuning; any failure to set
# these switches is ignored.
try:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
except Exception:
    pass
| |
|
| | |
| | |
| | |
| | def _parse_version_triplet(s: str): |
| | m = re.findall(r"\d+", s) |
| | m = [int(x) for x in m[:3]] |
| | while len(m) < 3: |
| | m.append(0) |
| | return tuple(m) |
| |
|
# Environment sanity checks. PyTorch >= 2.0 and a CUDA GPU with compute
# capability >= 7.0 are required; otherwise the process exits with status 1.
if _parse_version_triplet(torch.__version__) < (2, 0, 0):
    print(f"ERROR: PyTorch {torch.__version__} incompatible. Need >=2.0.0.")
    sys.exit(1)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
    print("ERROR: CUDA required. CPU disabled.")
    sys.exit(1)

cc_major, cc_minor = torch.cuda.get_device_capability(0)
if cc_major < 7:
    print(f"ERROR: GPU Compute Capability {(cc_major, cc_minor)} unsupported. Need >=7.0.")
    sys.exit(1)

gpu_name = torch.cuda.get_device_name(0)
print(f"Using GPU: {gpu_name} (CUDA {torch.version.cuda}, Compute Capability {(cc_major, cc_minor)})")

# Autocast uses bfloat16 only when the GPU reports bf16 support AND has
# compute capability >= 8; otherwise float16. A failing support query counts
# as "unsupported".
try:
    bf16_supported = torch.cuda.is_bf16_supported()
except Exception:
    bf16_supported = False
if bf16_supported and cc_major >= 8:
    AUTOCAST_DTYPE = torch.bfloat16
else:
    AUTOCAST_DTYPE = torch.float16
| |
|
| | |
| | |
| | |
def print_resource_usage(stage: str):
    """Print a short resource report labelled with *stage*.

    GPU figures come from ``torch.cuda.memory_allocated`` and
    ``torch.cuda.memory_reserved``, shown in GiB. If either query raises,
    both are shown as 0.00. CPU and RAM percentages come from psutil.
    """
    gib = float(1024 ** 3)
    try:
        alloc_gb = torch.cuda.memory_allocated() / gib
        reserved_gb = torch.cuda.memory_reserved() / gib
    except Exception:
        alloc_gb = reserved_gb = 0.0

    print(f"--- {stage} ---")
    print(f"GPU Memory: {alloc_gb:.2f} GB allocated, {reserved_gb:.2f} GB reserved")
    cpu_pct = psutil.cpu_percent()
    mem_pct = psutil.virtual_memory().percent
    print(f"CPU: {cpu_pct}% | Memory: {mem_pct}%")
    print("---------------")
| |
|
| | |
| | |
| | |
# All renders (per-chunk / combined MP3s) and the JSON index live under ./mp3.
output_dir = "mp3"
metadata_file = os.path.join(output_dir, "songs_metadata.json")
# Set to "rendering" while generate_music runs and back to "idle" afterwards.
api_status = "idle"
os.makedirs(output_dir, exist_ok=True)
| |
|
| | |
| | |
| | |
# Vocabulary used to fill the `{placeholder}` slots of genre prompt templates.
# Every value is a list; get_genre_prompt picks one entry per key at random.
# Keys ending in "_bars" hold bar counts (ints); "bpm" holds tempos (ints).
prompt_variables = dict(
    style=['epic', 'gritty', 'smooth', 'lush', 'raw', 'intimate', 'driving', 'moody',
           'psychedelic', 'uplifting', 'melancholic', 'aggressive', 'dreamy', 'retro',
           'futuristic', 'energetic', 'brooding', 'euphoric', 'jazzy', 'cinematic',
           'somber', 'triumphant', 'mystical', 'grunge', 'ethereal'],
    key=['C major', 'D major', 'E minor', 'F minor', 'G major', 'A minor',
         'B-flat major', 'G minor', 'D minor', 'F major'],
    bpm=[80, 90, 100, 110, 120, 124, 128, 130, 140, 150, 160, 170, 180],
    time_signature=['4/4', '3/4', '6/8'],
    guitar_style=['raw distorted', 'melodic', 'fuzzy', 'crisp', 'jangly', 'clean', 'twangy',
                  'shimmering', 'grunge', 'bluesy', 'slide', 'wah-infused', 'chunky'],
    bass_style=['punchy', 'deep', 'groovy', 'melodic', 'throbbing', 'slappy', 'funky',
                'walking', 'booming', 'resonant', 'subtle'],
    drum_style=['dynamic', 'minimal', 'hard-hitting', 'swinging', 'polyrhythmic', 'brushed',
                'tight', 'loose', 'electronic', 'acoustic', 'retro', 'punchy'],
    drum_feature=['heavy snare', 'crisp cymbals', 'tight kicks', 'syncopated hits',
                  'rolling toms', 'ghost notes', 'blast beats'],
    organ_style=['subtle Hammond', 'swirling', 'warm Leslie', 'church', 'gritty', 'vintage',
                 'moody'],
    synth_style=['atmospheric', 'bright', 'eerie', 'soaring', 'chopped', 'arpeggiated',
                 'pulsing', 'glitchy', 'analog', 'digital', 'layered'],
    vocal_style=['chopped', 'soulful', 'haunting', 'melodic', 'harmonized', 'layered',
                 'ethereal', 'gruff', 'breathy'],
    hihat_style=['crisp', 'swinging', 'rapid', 'shuffling', 'open', 'tight', 'stuttered'],
    pad_style=['evolving', 'ambient', 'lush', 'dark', 'shimmering', 'warm', 'icy'],
    kick_style=['deep', 'four-on-the-floor', 'subtle', 'punchy', 'booming', 'clicky'],
    lead_style=['fluid', 'intricate', 'soaring', 'expressive', 'virtuosic', 'minimalist',
                'bluesy', 'lyrical'],
    lead_instrument=['saxophone', 'trumpet', 'guitar', 'flute', 'violin', 'clarinet',
                     'trombone'],
    piano_style=['expressive Rhodes', 'rapid', 'smooth', 'dramatic', 'stride', 'ambient',
                 'classical', 'jazzy', 'sparse'],
    keyboard_style=['ornate', 'delicate', 'virtuosic', 'minimal', 'retro', 'spacey'],
    string_style=['sweeping', 'delicate', 'dramatic', 'lush', 'pizzicato', 'staccato',
                  'sustained'],
    brass_style=['bold', 'heroic', 'muted', 'fanfare', 'jazzy', 'smooth'],
    woodwind_style=['subtle', 'fluttering', 'melodic', 'airy', 'reedy', 'expressive'],
    flute_style=['fluttering', 'ornate', 'airy', 'breathy', 'trilling'],
    horn_style=['heroic', 'bold', 'soaring', 'mellow', 'stinging'],
    choir_style=['mystical', 'ethereal', 'dramatic', 'angelic', 'epic', 'somber'],
    sample_style=['jazzy', 'soulful', 'gritty', 'cinematic', 'vinyl', 'lo-fi', 'retro'],
    scratch_style=['crackling vinyl', 'sharp', 'rhythmic', 'chopped', 'transform'],
    snare_style=['crisp', 'booming', 'tight', 'snappy', 'rimshot', 'layered'],
    breakdown_style=['euphoric', 'stripped-down', 'intense', 'ambient', 'glitchy', 'dramatic'],
    intro_bars=[4, 8, 16],
    verse_bars=[8, 16, 32],
    chorus_bars=[8, 16],
    bridge_bars=[4, 8, 16],
    outro_bars=[8, 16],
    build_bars=[8, 16, 32],
    drop_bars=[16, 32],
    main_bars=[16, 32],
    breakdown_bars=[8, 16],
    head_bars=[16, 32],
    solo_bars=[8, 16, 32],
    fugue_bars=[16, 32],
    coda_bars=[8, 16],
    theme_bars=[16, 32],
    development_bars=[16, 32],
    climax_bars=[8, 16],
    groove_bars=[16, 32],
    vibe=['raw', 'energetic', 'melancholic', 'hypnotic', 'soulful', 'intimate',
          'virtuosic', 'elegant', 'cinematic', 'gritty', 'nostalgic', 'dark',
          'uplifting', 'bittersweet', 'heroic', 'dreamy', 'aggressive', 'relaxed',
          'futuristic', 'retro', 'mystical', 'triumphant'],
    production_style=['lo-fi', 'warm analog', 'clean digital', 'lush', 'crisp acoustic',
                      'polished pop', 'grand orchestral', 'grunge', 'minimalist', 'industrial',
                      'vintage'],
)
| |
|
| | |
| | |
| | |
def create_default_genre_prompts_ini(ini_path):
    """Write the built-in genre configuration to *ini_path* (overwriting it).

    The file gets two sections keyed by the same genre names:

    - ``[Prompts]`` holds prompt templates with ``{placeholder}`` slots.
    - ``[BandNames]`` holds comma-separated reference artists.

    A confirmation line is printed afterwards.
    """
    # (genre key, prompt template, band names) — one row per genre, in file order.
    genre_table = [
        ('nirvana', '{style} grunge with {guitar_style} guitar, {bass_style} bass, {drum_style} drums, {vibe} vibe in {key} at {bpm} BPM', 'Nirvana, Soundgarden'),
        ('classic_rock', '{style} classic rock with {guitar_style} guitar, {bass_style} bass, {drum_style} drums, {vibe} vibe in {key} at {bpm} BPM', 'Led Zeppelin, The Rolling Stones'),
        ('detroit_techno', '{style} techno with {synth_style} synths, {kick_style} kick, {hihat_style} hi-hats, {vibe} vibe at {bpm} BPM', 'Underground Resistance, Jeff Mills'),
        ('smooth_jazz', '{style} jazz with {piano_style} piano, {bass_style} bass, {drum_style} drums, {vibe} vibe in {key} at {bpm} BPM', 'Pat Metheny, George Benson'),
        ('alternative_rock', '{style} alternative rock with {guitar_style} guitar, {bass_style} bass, {drum_style} drums in {key} at {bpm} BPM', 'Radiohead, Smashing Pumpkins'),
        ('deep_house', '{style} deep house with {synth_style} synths, {kick_style} kick, {vibe} vibe at {bpm} BPM', 'Moodymann, Theo Parrish'),
        ('bebop_jazz', '{style} bebop jazz with {piano_style} piano, {bass_style} bass, {drum_style} drums in {key} at {bpm} BPM', 'Charlie Parker, Dizzy Gillespie'),
        ('baroque_classical', '{style} baroque classical with {string_style} strings, {keyboard_style} harpsichord in {key} at {bpm} BPM', 'Bach, Vivaldi'),
        ('romantic_classical', '{style} romantic classical with {string_style} strings, {piano_style} piano in {key} at {bpm} BPM', 'Chopin, Liszt'),
        ('boom_bap_hiphop', '{style} boom bap hip-hop with {sample_style} samples, {drum_style} drums, {scratch_style} scratches at {bpm} BPM', 'A Tribe Called Quest, Pete Rock'),
        ('trap_hiphop', '{style} trap hip-hop with {synth_style} synths, {kick_style} kick, {snare_style} snare at {bpm} BPM', 'Future, Metro Boomin'),
        ('pop_rock', '{style} pop rock with {guitar_style} guitar, {bass_style} bass, {drum_style} drums in {key} at {bpm} BPM', 'Coldplay, The Killers'),
        ('fusion_jazz', '{style} fusion jazz with {piano_style} piano, {guitar_style} guitar, {drum_style} drums in {key} at {bpm} BPM', 'Weather Report, Herbie Hancock'),
        ('edm', '{style} EDM with {synth_style} synths, {kick_style} kick, {vibe} vibe at {bpm} BPM', 'Deadmau5, Skrillex'),
        ('indie_folk', '{style} indie folk with {guitar_style} guitar, {vocal_style} vocals, {drum_style} drums in {key} at {bpm} BPM', 'Fleet Foxes, Bon Iver'),
        ('star_wars', '{style} epic orchestral with {brass_style} brass, {string_style} strings, {vibe} vibe in {key} at {bpm} BPM', 'John Williams'),
        ('star_wars_classical', '{style} classical orchestral with {string_style} strings, {horn_style} horns in {key} at {bpm} BPM', 'John Williams'),
        ('wutang', '{style} hip-hop with {sample_style} samples, {drum_style} drums, {scratch_style} scratches at {bpm} BPM', 'Wu-Tang Clan'),
        ('milesdavis', '{style} jazz with {lead_instrument} lead, {piano_style} piano, {bass_style} bass in {key} at {bpm} BPM', 'Miles Davis'),
    ]
    parser = configparser.ConfigParser()
    parser['Prompts'] = {genre: template for genre, template, _ in genre_table}
    parser['BandNames'] = {genre: bands for genre, _, bands in genre_table}
    with open(ini_path, 'w') as handle:
        parser.write(handle)
    print(f"Created default {ini_path}")
| |
|
| | |
| | |
| | |
# The Gradio UI stylesheet is mandatory. Exit with status 1 if it is missing
# or cannot be read.
css_path = "style.css"
try:
    if os.path.exists(css_path):
        with open(css_path, 'r') as f:
            css = f.read()
    else:
        print(f"ERROR: {css_path} not found. Please create style.css with the required CSS content.")
        sys.exit(1)
except Exception as e:
    print(f"ERROR: Failed to read {css_path}: {e}. Please ensure style.css exists and is readable.")
    sys.exit(1)
| |
|
| | |
| | |
| | |
# Load genre prompt templates and band names. A default INI is (re)written when:
# - the file is missing,
# - it lacks either required section, or
# - reading it raises.
config = configparser.ConfigParser()
ini_path = "genre_prompts.ini"
try:
    if not os.path.exists(ini_path):
        print(f"WARNING: {ini_path} not found. Creating default INI file.")
        create_default_genre_prompts_ini(ini_path)
    config.read(ini_path)
    if not (config.has_section('Prompts') and config.has_section('BandNames')):
        print(f"WARNING: Invalid {ini_path}. Creating default INI file.")
        create_default_genre_prompts_ini(ini_path)
        config.read(ini_path)
except Exception as e:
    print(f"ERROR: Failed to read {ini_path}: {e}. Creating default INI file.")
    create_default_genre_prompts_ini(ini_path)
    config.read(ini_path)
| |
|
| | |
| | |
| | |
def load_musicgen_with_fallback():
    """Load the largest MusicGen checkpoint that is present locally.

    Candidates are tried in the order large -> medium -> small. Each path can
    be overridden through ``MUSICGEN_MODEL_PATH_LARGE`` / ``_MEDIUM`` /
    ``_SMALL``. An empty override skips that size, and a path that does not
    exist is reported and skipped.

    Returns:
        ``(model, size_name)`` for the first checkpoint that loads.

    Raises:
        SystemExit(1): when no checkpoint could be loaded. This includes the
            case where none of the paths exist; previously the function fell
            through and returned None, which crashed the caller's tuple
            unpacking with a TypeError.
    """
    candidates = (
        ("large", os.getenv("MUSICGEN_MODEL_PATH_LARGE", "/home/ubuntu/musicpack/models/musicgen-large")),
        ("medium", os.getenv("MUSICGEN_MODEL_PATH_MEDIUM", "/home/ubuntu/musicpack/models/musicgen-medium")),
        ("small", os.getenv("MUSICGEN_MODEL_PATH_SMALL", "/home/ubuntu/musicpack/models/musicgen-small")),
    )

    # Keep only the message: holding the exception object would keep its
    # traceback frames (and whatever they reference) alive during the next try.
    last_error = None
    for name, path in candidates:
        if not path:
            continue
        if not os.path.exists(path):
            print(f"NOTE: Model path not found: {path} (skipping {name})")
            continue
        try:
            print(f"Loading MusicGen {name} model from {path} ...")
            torch.cuda.empty_cache()
            gc.collect()
            with autocast('cuda', dtype=AUTOCAST_DTYPE):
                mdl = MusicGen.get_pretrained(path, device=device)
            print(f"Loaded MusicGen {name}. Sample rate: {mdl.sample_rate}Hz")
            return mdl, name
        except Exception as e:  # RuntimeError (e.g. CUDA OOM) or any other load failure
            last_error = str(e)
            print(f"WARNING: Failed to load {name} model due to: {e}")
        # Release memory only after leaving the except clause, once the
        # exception and its frames are no longer referenced.
        gc.collect()
        torch.cuda.empty_cache()

    if last_error is not None:
        print(f"ERROR: All model loads failed. Last error: {last_error}")
    else:
        print("ERROR: No MusicGen model found at any configured path.")
    raise SystemExit(1)
| |
|
# Load the MusicGen checkpoint and set initial sampling defaults.
# generate_chunk_oom_safe re-applies its own parameters for every chunk.
try:
    _loaded = load_musicgen_with_fallback()
    if _loaded is None:
        # Defensive: the loader can fall off the end of its loop without
        # returning anything. Unpacking None would raise an unhandled
        # TypeError instead of exiting cleanly.
        print("ERROR: No MusicGen model could be loaded.")
        raise SystemExit(1)
    musicgen_model, loaded_model_name = _loaded
    del _loaded

    musicgen_model.set_generation_params(
        duration=10,
        use_sampling=True,
        top_k=50,
        top_p=0.0,
        temperature=0.8,
        cfg_coef=3.0,
        two_step_cfg=False
    )
    sample_rate = musicgen_model.sample_rate
    print(f"Model active: {loaded_model_name}. Sample rate: {sample_rate}Hz")
except SystemExit:
    sys.exit(1)
| |
|
| | |
| | |
| | |
def apply_eq(segment):
    """Apply the fixed mastering EQ to a pydub ``AudioSegment``.

    The chain is: 60 Hz high-pass, then 12 kHz low-pass, then 2 dB gain
    reduction. A new segment is returned.
    """
    band_limited = segment.high_pass_filter(60).low_pass_filter(12000)
    return band_limited - 2.0
| |
|
def apply_limiter(segment, max_db=-6.0, target_lufs=-16.0):
    """Loudness-normalise *segment* and cap its peak level.

    Steps:

    1. Convert the PCM samples to float in [-1, 1) and measure integrated
       loudness with pyloudnorm.
    2. Gain-normalise to ``target_lufs``.
    3. Scale down so the absolute peak does not exceed ``max_db`` dBFS.
       This is a plain gain change, not a look-ahead limiter.

    Args:
        segment: pydub ``AudioSegment``. It is assumed to be 16-bit
            (sample_width 2), because the ±2**15 scaling relies on it —
            TODO confirm for new callers.
        max_db: Peak ceiling in dBFS.
        target_lufs: Integrated loudness target in LUFS.

    Returns:
        A new 16-bit ``AudioSegment`` with the input's frame rate and
        channel count.
    """
    full_scale = float(2 ** 15)
    samples = np.array(segment.get_array_of_samples(), dtype=np.float32) / full_scale
    if segment.channels == 2:
        # pydub interleaves channels; pyloudnorm expects (samples, channels).
        samples = samples.reshape(-1, 2)

    meter = pyln.Meter(segment.frame_rate)
    loudness = meter.integrated_loudness(samples)
    # Silent (fully gated) audio measures -inf LUFS. Normalising it would apply
    # an infinite gain and turn the samples into NaN/inf, so leave it as is.
    if np.isfinite(loudness):
        samples = pyln.normalize.loudness(samples, loudness, target_lufs)

    ceiling = 10 ** (max_db / 20)
    peak = float(np.max(np.abs(samples))) if samples.size else 0.0
    if peak > ceiling:
        samples = samples * (ceiling / peak)

    # Clip before the int16 cast. With max_db >= 0 a peak of exactly 1.0 maps
    # to 32768, which would wrap around to -32768.
    pcm = np.clip(samples * full_scale, -full_scale, full_scale - 1).astype(np.int16)
    return AudioSegment(
        pcm.tobytes(),
        frame_rate=segment.frame_rate,
        sample_width=2,
        channels=segment.channels
    )
| |
|
def apply_fade(segment, fade_in_duration=1000, fade_out_duration=1000):
    """Return *segment* with a fade-in followed by a fade-out (durations in ms)."""
    faded_in = segment.fade_in(fade_in_duration)
    return faded_in.fade_out(fade_out_duration)
| |
|
| | |
| | |
| | |
# Fallback song keywords, used when a prompt is empty or yields no usable word.
made_up_names = """
    blazepulse shadowrift neonquest thunderclash stargroove
    mysticvibe ironspark ghostsurge velvetstorm crimsonrush
    duskblitz solarflame nightdrift frostsaga emberwave
    coolriff wildpulse echoslash moontide skydive
""".split()
| |
|
def extract_song_keyword(prompt):
    """Return the first short alphanumeric word of *prompt*, lower-cased.

    A word qualifies when it is at most 15 characters long and
    ``str.isalnum()`` is true, so regex words containing an underscore are
    skipped. If *prompt* is empty/None, or no word qualifies, a random entry
    of ``made_up_names`` is returned instead.
    """
    if prompt:
        tokens = re.findall(r'\b\w+\b', prompt.lower())
        usable = (tok for tok in tokens if len(tok) <= 15 and tok.isalnum())
        keyword = next(usable, None)
        if keyword is not None:
            return keyword
    return random.choice(made_up_names)
| |
|
def generate_unique_title(existing_titles, genre, song_keyword, style):
    """Pick a random title base (e.g. ``"K7"``) not used by any earlier render.

    Stored metadata titles have one of two forms:

    - a bare title base (the combined track), or
    - ``"<base>_Chunk<n>"`` (an individual chunk).

    The ``_Chunk<n>`` suffix is stripped before comparing. When a drawn base
    is already taken, a lowercase letter+digit suffix is appended
    (``"K7_b3"``) and that candidate is checked as well.

    Args:
        existing_titles: Titles already present in the metadata file.
        genre: Key into ``config['BandNames']``. Unknown genres fall back to
            the literal band name "nirvana".
        song_keyword, style: Not used for uniqueness; kept for interface
            compatibility.

    Returns:
        ``(title_base, band_name)``. ``band_name`` is drawn at random from the
        genre's comma-separated band list.

    Raises:
        ValueError: if no free title is found within 100 attempts.
    """
    letters = string.ascii_uppercase
    digits = string.digits
    max_attempts = 100
    band_names = [name.strip() for name in config['BandNames'].get(genre, "nirvana").split(',')]
    # The previous check required keyword/style/band inside the stored title,
    # which never matches the "<base>[_Chunk<n>]" format, so duplicates were
    # never detected. Compare the title bases directly instead.
    taken = {re.sub(r"_Chunk\d+$", "", title) for title in existing_titles}
    for _ in range(max_attempts):
        title_base = f"{random.choice(letters)}{random.choice(digits)}"
        if title_base in taken:
            suffix = f"{random.choice(letters)}{random.choice(digits)}".lower()
            title_base = f"{title_base}_{suffix}"
        if title_base not in taken:
            return title_base, random.choice(band_names)
    raise ValueError("Failed to generate unique title after maximum attempts")
| |
|
def update_metadata_storage(metadata):
    """Append one render's record to the JSON list stored in ``metadata_file``.

    *metadata* must contain "title" and "filename". Other fields default to:

    - duration 30
    - volume_db -24.0
    - target_lufs -16.0
    - timestamp: now, formatted "%Y%m%d_%H%M%S"
    - file_path ""
    - sample_rate: the loaded model's rate
    - style "" and band_name ""
    - chunk_index 0

    The updated list is written to a sibling ``.tmp`` file and moved into
    place with ``os.replace``. An interrupted write therefore cannot leave a
    truncated JSON index behind, which would make every later read of it fail.

    Best effort: any error is printed and otherwise swallowed.
    """
    try:
        songs_metadata = []
        if os.path.exists(metadata_file):
            with open(metadata_file, 'r') as f:
                songs_metadata = json.load(f)

        if "sample_rate" in metadata:
            sample_rate = metadata["sample_rate"]
        else:
            # Only consult the global model when the caller gave no rate.
            sample_rate = musicgen_model.sample_rate

        songs_metadata.append({
            "title": metadata["title"],
            "filename": metadata["filename"],
            "prompt": metadata.get("prompt", ""),
            "duration": metadata.get("duration", 30),
            "volume_db": metadata.get("volume_db", -24.0),
            "target_lufs": metadata.get("target_lufs", -16.0),
            "timestamp": metadata.get("timestamp", datetime.datetime.now().strftime("%Y%m%d_%H%M%S")),
            "file_path": metadata.get("file_path", ""),
            "sample_rate": sample_rate,
            "style": metadata.get("style", ""),
            "band_name": metadata.get("band_name", ""),
            "chunk_index": metadata.get("chunk_index", 0)
        })

        tmp_path = f"{metadata_file}.tmp"
        with open(tmp_path, 'w') as f:
            json.dump(songs_metadata, f, indent=4)
        os.replace(tmp_path, metadata_file)
    except Exception as e:
        print(f"ERROR: Failed to update metadata storage: {e}")
| |
|
def load_renders():
    """Read ``metadata_file`` and build the rows of the renders table.

    Returns:
        ``(rows, status)``.

        ``rows`` holds one dict per stored record, with the keys "Title",
        "Filename", "Prompt", "Duration (s)", "Timestamp", "Audio",
        "Download" and "Chunk". "Download" is an HTML link to
        ``/get-song/<filename>``.

        ``status`` is a message. A missing file gives
        ``([], "No renders found.")``; any read or parse error gives
        ``([], "Error loading renders: ...")``.
    """
    from html import escape
    from urllib.parse import quote

    if not os.path.exists(metadata_file):
        return [], "No renders found."
    try:
        with open(metadata_file, 'r') as f:
            songs_metadata = json.load(f)
        renders = []
        for entry in songs_metadata:
            # Filenames embed band names containing spaces (e.g. "Led Zeppelin"),
            # so percent-encode the URL path segment and HTML-escape the label.
            href = f"/get-song/{quote(str(entry['filename']), safe='')}"
            label = escape(f"Download {entry['title']}", quote=True)
            renders.append({
                "Title": entry["title"],
                "Filename": entry["filename"],
                "Prompt": entry["prompt"],
                "Duration (s)": entry["duration"],
                "Timestamp": entry["timestamp"],
                "Audio": entry["file_path"],
                "Download": f'<a href="{href}" download><button class="download-btn" aria-label="{label}">⬇️</button></a>',
                "Chunk": entry["chunk_index"]
            })
        return renders, "Renders loaded successfully."
    except Exception as e:
        return [], f"Error loading renders: {e}"
| |
|
| | |
| | |
| | |
def get_genre_prompt(genre):
    """Build a randomised text prompt for *genre*.

    The template comes from the ``[Prompts]`` section of ``config``. A grunge
    template is used when the genre is missing or empty. The template is
    filled with one random choice per entry of ``prompt_variables``.

    A generic "<style> music with <guitar> guitar, <bass> bass, <drums>
    drums in <key> at <bpm> BPM" prompt is returned instead when:

    - the template cannot be formatted — an unknown ``{placeholder}``
      (KeyError), a positional ``{}`` (IndexError) or unbalanced braces
      (ValueError), e.g. in a hand-edited INI; or
    - the formatted prompt contains none of the known variable values.

    Returns:
        ``(prompt, style)``, where ``style`` is the randomly chosen style.
    """
    base_prompt = config['Prompts'].get(genre, "")
    if not base_prompt:
        base_prompt = "{style} grunge with {guitar_style} guitar, {bass_style} bass, {drum_style} drums, {vibe} vibe in {key} at {bpm} BPM"

    # One random pick per variable, drawn in prompt_variables' declaration order.
    prompt_dict = {name: random.choice(choices) for name, choices in prompt_variables.items()}
    fallback_prompt = f"{prompt_dict['style']} music with {prompt_dict['guitar_style']} guitar, {prompt_dict['bass_style']} bass, {prompt_dict['drum_style']} drums in {prompt_dict['key']} at {prompt_dict['bpm']} BPM"

    try:
        formatted_prompt = base_prompt.format(**prompt_dict)
    except (KeyError, IndexError, ValueError):
        return fallback_prompt, prompt_dict['style']

    known_values = {value for choices in prompt_variables.values() if isinstance(choices, list) for value in choices}
    words = re.findall(r'\b\w+\b', formatted_prompt.lower())
    if not any(word in known_values for word in words):
        formatted_prompt = fallback_prompt
    return formatted_prompt, prompt_dict['style']
| |
|
| | |
| | |
| | |
def generate_chunk_oom_safe(model, text_prompt, continuation_prompt, cfg_scale, top_k, top_p, temperature, target_duration):
    """Generate one audio chunk, retrying with shorter durations on CUDA OOM.

    ``target_duration`` is tried first. After a RuntimeError whose message
    mentions "out of memory" or "cuda error", the call is retried with the
    fallback durations 20, 15, 12, 10, 8, 6, 4, 3 and 2 s. Only fallbacks
    strictly shorter than ``target_duration`` are used, so an OOM can never
    lead to a *longer* chunk than requested.

    Args:
        model: MusicGen model. ``set_generation_params`` and ``generate`` /
            ``generate_continuation`` are called on it.
        text_prompt: Text description used for conditioning.
        continuation_prompt: Audio tensor to continue from, or None for a
            fresh generation. It is passed to ``generate_continuation`` with
            ``model.sample_rate``.
        cfg_scale, top_k, top_p, temperature: Sampling parameters.
        target_duration: Preferred chunk length in seconds.

    Returns:
        ``(audio, duration)``: the first item of the generated batch, and the
        duration in seconds that succeeded.

    Raises:
        RuntimeError: Non-OOM runtime errors are re-raised. If every duration
            runs out of memory, "Failed to generate audio chunk without CUDA
            OOM." is raised.
    """
    fallback_durations = (20, 15, 12, 10, 8, 6, 4, 3, 2)
    durations_to_try = [target_duration] + [d for d in fallback_durations if d < target_duration]
    for dur in durations_to_try:
        try:
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            model.set_generation_params(
                duration=dur,
                use_sampling=True,
                top_k=int(top_k),
                top_p=float(top_p),
                temperature=float(temperature),
                cfg_coef=float(cfg_scale),
                two_step_cfg=False
            )
            with torch.no_grad(), autocast('cuda', dtype=AUTOCAST_DTYPE):
                if continuation_prompt is None:
                    audio_chunk = model.generate([text_prompt], progress=False)[0]
                else:
                    audio_chunk = model.generate_continuation(
                        continuation_prompt, model.sample_rate, [text_prompt], progress=False
                    )[0]
            return audio_chunk, dur
        except RuntimeError as e:
            msg = str(e).lower()
            if "out of memory" not in msg and "cuda error" not in msg:
                raise
            print(f"OOM at duration {dur}s — retrying with smaller chunk...")
        # Free memory only after leaving the except clause. While the exception
        # is being handled, its traceback keeps the failed attempt's frames
        # (and any tensors they reference) alive.
        gc.collect()
        torch.cuda.empty_cache()
    raise RuntimeError("Failed to generate audio chunk without CUDA OOM.")
| |
|
| | |
| | |
| | |
def generate_music(instrumental_prompt: str, cfg_scale: float, top_k: int, top_p: float, temperature: float, total_duration: int, volume_db: float, genre: str = None):
    """Render an instrumental track in chunks of up to 30 s and save it as MP3(s).

    Chunks are generated one after another with ``generate_chunk_oom_safe``.
    From the second chunk on, the last ~2 s of the previous chunk is passed
    as the continuation prompt. Each chunk is then:

    - EQ'd with ``apply_eq``;
    - loudness-normalised with ``apply_limiter`` (``max_db=volume_db``,
      ``target_lufs=-16.0``);
    - faded in (first chunk) and/or faded out (last chunk);
    - exported as its own 64 kbps MP3 and recorded in the metadata file.

    When more than one chunk was produced, they are also joined with a 500 ms
    crossfade into ``*_combined.mp3``, which gets its own metadata record.

    Args:
        instrumental_prompt: Text prompt. A new random prompt is built instead
            when it is blank — from ``genre``, or from "nirvana" when no genre
            is given. A random "nirvana" prompt also replaces it when none of
            its lower-cased words matches a value in ``prompt_variables``.
        cfg_scale, top_k, top_p, temperature: MusicGen sampling parameters.
        total_duration: Requested length in seconds; raised to at least 30.
        volume_db: Peak ceiling (dB) passed to ``apply_limiter`` as ``max_db``.
        genre: Optional key into the INI ``[Prompts]``/``[BandNames]`` sections.

    Returns:
        A 4-tuple ``(mp3_path, status_message, False, gr.update(...))``:
        - ``mp3_path`` is the combined MP3 path, or the single chunk's path.
        - The constant ``False`` presumably hides a render indicator — confirm
          against the UI wiring.
        - The ``gr.update`` refreshes the renders table from ``load_renders()``.
        On failure the path is None and the message starts with "❌ Failed:".

    Side effects:
        - Sets the module-level ``api_status`` to "rendering" and back to "idle".
        - Writes MP3 files into ``output_dir`` and appends to ``metadata_file``.
    """
    global musicgen_model
    global api_status
    api_status = "rendering"

    # NOTE(review): this prompt-selection step runs before the try below. If it
    # raises (e.g. instrumental_prompt is None, so .strip() fails), api_status
    # stays "rendering".
    if not instrumental_prompt.strip() and not genre:
        instrumental_prompt, style = get_genre_prompt("nirvana")
    elif not instrumental_prompt.strip():
        instrumental_prompt, style = get_genre_prompt(genre)
    else:
        # Accept a custom prompt only if at least one word is a known variable value.
        words = re.findall(r'\b\w+\b', instrumental_prompt.lower())
        val_list = []
        for k, v in prompt_variables.items():
            if isinstance(v, list):
                val_list.extend(v)
        if not any(word in val_list for word in words):
            instrumental_prompt, style = get_genre_prompt("nirvana")
        else:
            # Use the prompt's first keyword as the style when it is a known
            # style; otherwise pick a random style.
            ek = extract_song_keyword(instrumental_prompt)
            style = ek if ek in prompt_variables['style'] else random.choice(prompt_variables['style'])

    try:
        start_time = time.time()  # NOTE(review): never read below.
        base_chunk_target = 30
        # Requests shorter than 30 s are rounded up to 30 s.
        total_duration = max(total_duration, 30)
        remaining = total_duration
        audio_chunks = []
        chunk_paths = []
        continuation_prompt = None
        chunk_index = 0

        # Titles already in the index, so the new title base can avoid them.
        # NOTE(review): a corrupted/truncated JSON file makes json.load raise
        # here, so every render fails until the file is fixed.
        existing_titles = []
        if os.path.exists(metadata_file):
            with open(metadata_file, 'r') as f:
                songs_metadata = json.load(f)
            existing_titles = [entry["title"] for entry in songs_metadata]
        song_keyword = extract_song_keyword(instrumental_prompt)
        title_base, band_name = generate_unique_title(existing_titles, genre if genre else "nirvana", song_keyword, style)

        # Generate until the requested duration is covered. `remaining`
        # decreases by the duration actually generated for each chunk.
        while remaining > 0:
            target = min(base_chunk_target, remaining)
            print_resource_usage(f"Before Chunk {chunk_index + 1}")
            try:
                # NOTE(review): audiocraft's generate_continuation appears to
                # return audio that *includes* the prompt. If so, each chunk
                # after the first starts by repeating the previous chunk's
                # last ~2 s, which the 500 ms crossfade only partly hides.
                # Confirm, then trim tail_samples from continuation chunks.
                audio_chunk, actual_dur = generate_chunk_oom_safe(
                    musicgen_model, instrumental_prompt, continuation_prompt, cfg_scale, top_k, top_p, temperature, target
                )
                # Normalise the output to a float32 CPU tensor of shape (2, samples).
                audio_chunk = audio_chunk.cpu().to(dtype=torch.float32)
                if audio_chunk.dim() == 1:
                    audio_chunk = torch.stack([audio_chunk, audio_chunk], dim=0)
                elif audio_chunk.dim() == 2 and audio_chunk.shape[0] == 1:
                    audio_chunk = torch.cat([audio_chunk, audio_chunk], dim=0)
                elif audio_chunk.dim() == 2 and audio_chunk.shape[0] != 2:
                    # More than two channels: keep the first and duplicate it.
                    audio_chunk = audio_chunk[:1, :]
                    audio_chunk = torch.cat([audio_chunk, audio_chunk], dim=0)
                elif audio_chunk.dim() > 2:
                    # NOTE(review): for e.g. a (1, 1, T) tensor this splits the
                    # time axis in half instead of duplicating a channel.
                    audio_chunk = audio_chunk.view(2, -1)
                if audio_chunk.shape[0] != 2:
                    raise ValueError(f"Expected stereo audio with shape (2, samples), got {audio_chunk.shape}")

                # Keep the last ~2 s (never the whole chunk) as the next
                # chunk's continuation prompt.
                samples_per_second = musicgen_model.sample_rate
                tail_sec = 2
                tail_samples = min(int(tail_sec * samples_per_second), audio_chunk.shape[1] - 1 if audio_chunk.shape[1] > 1 else 1)
                if tail_samples > 0:
                    continuation_prompt = audio_chunk[:, -tail_samples:].cpu()
                else:
                    continuation_prompt = None

                # Round-trip through a temporary 16-bit WAV to get a pydub
                # AudioSegment. The file is removed even if loading fails.
                # NOTE(review): the 3-digit random name can collide if two
                # renders run concurrently.
                temp_wav_path = os.path.join(output_dir, f"temp_{random.randint(100, 999)}_{chunk_index}.wav")
                try:
                    torchaudio.save(temp_wav_path, audio_chunk, musicgen_model.sample_rate, bits_per_sample=16)
                    final_segment = AudioSegment.from_wav(temp_wav_path)
                finally:
                    if os.path.exists(temp_wav_path):
                        os.remove(temp_wav_path)
                del audio_chunk
                gc.collect()

                # Per-chunk mastering. The fade-in is applied only to the first
                # chunk, and the fade-out only to the chunk that exhausts
                # `remaining`.
                print(f"Post-processing chunk {chunk_index + 1} (duration ~{actual_dur}s)...")
                final_segment = apply_eq(final_segment)
                final_segment = apply_limiter(final_segment, max_db=volume_db, target_lufs=-16.0)
                if chunk_index == 0:
                    final_segment = final_segment.fade_in(1000)

                if remaining - actual_dur <= 0:
                    final_segment = final_segment.fade_out(1000)

                # Export this chunk as its own MP3.
                mp3_filename = f"{title_base.lower()}_{song_keyword}_{style}_{band_name}_chunk{chunk_index + 1}.mp3"
                mp3_path = os.path.join(output_dir, mp3_filename)
                final_segment.export(
                    mp3_path,
                    format="mp3",
                    bitrate="64k",
                    tags={"title": f"{title_base}_Chunk{chunk_index + 1}", "artist": "GhostAI"}
                )
                print(f"Saved chunk {chunk_index + 1} to {mp3_path}")
                audio_chunks.append(final_segment)
                chunk_paths.append(mp3_path)

                # One metadata record per chunk; chunk_index in the record is 1-based.
                metadata = {
                    "title": f"{title_base}_Chunk{chunk_index + 1}",
                    "filename": mp3_filename,
                    "prompt": instrumental_prompt,
                    "duration": actual_dur,
                    "volume_db": volume_db,
                    "target_lufs": -16.0,
                    "timestamp": datetime.datetime.now().strftime("%Y%m%d_%H%M%S"),
                    "file_path": mp3_path,
                    "sample_rate": musicgen_model.sample_rate,
                    "style": style,
                    "band_name": band_name,
                    "chunk_index": chunk_index + 1
                }
                update_metadata_storage(metadata)

                chunk_index += 1
                remaining -= actual_dur
                torch.cuda.empty_cache()
                gc.collect()
                print_resource_usage(f"After Chunk {chunk_index}")
            except Exception as e:
                # Re-raised to the outer handler, which returns the failure tuple.
                print(f"ERROR: Failed to process chunk {chunk_index + 1}: {e}")
                api_status = "idle"
                raise

        # Several chunks: also export a crossfaded combined MP3 and return its path.
        if len(audio_chunks) > 1:
            combined_segment = audio_chunks[0]
            for segment in audio_chunks[1:]:
                combined_segment = combined_segment.append(segment, crossfade=500)
            combined_mp3_filename = f"{title_base.lower()}_{song_keyword}_{style}_{band_name}_combined.mp3"
            combined_mp3_path = os.path.join(output_dir, combined_mp3_filename)
            combined_segment.export(
                combined_mp3_path,
                format="mp3",
                bitrate="64k",
                tags={"title": title_base, "artist": "GhostAI"}
            )
            print(f"Saved combined audio to {combined_mp3_path}")
            # The combined record stores the requested total_duration (after
            # the 30 s minimum), not the sum of the chunk durations, and uses
            # chunk_index 0.
            metadata = {
                "title": title_base,
                "filename": combined_mp3_filename,
                "prompt": instrumental_prompt,
                "duration": total_duration,
                "volume_db": volume_db,
                "target_lufs": -16.0,
                "timestamp": datetime.datetime.now().strftime("%Y%m%d_%H%M%S"),
                "file_path": combined_mp3_path,
                "sample_rate": musicgen_model.sample_rate,
                "style": style,
                "band_name": band_name,
                "chunk_index": 0
            }
            update_metadata_storage(metadata)
            del combined_segment, audio_chunks
            gc.collect()
            api_status = "idle"
            return combined_mp3_path, "✅ Done!", False, gr.update(value=load_renders()[0])
        else:
            # Single chunk: its metadata was already written inside the loop.
            print(f"Saved metadata to {metadata_file}")
            del audio_chunks
            gc.collect()
            api_status = "idle"
            return chunk_paths[0], "✅ Done!", False, gr.update(value=load_renders()[0])

    except Exception as e:
        print(f"❌ Failed: {e}")
        api_status = "idle"
        return None, f"❌ Failed: {e}", False, gr.update(value=load_renders()[0])
    finally:
        # Release cached GPU memory after every render, successful or not.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        gc.collect()
| |
|
def clear_inputs():
    """Return the reset values for the generator's inputs.

    The values are: empty prompt, 3.0, 50, 0.0, 0.8, 30, -24.0, False.
    These presumably map to cfg scale, top-k, top-p, temperature, duration,
    volume dB and a render flag — confirm against the UI wiring.
    """
    defaults = ("", 3.0, 50, 0.0, 0.8, 30, -24.0, False)
    return defaults
| |
|
def show_render_wheel():
    """Return True, presumably to show a render-progress indicator (UI wiring not visible here)."""
    visible = True
    return visible
| |
|
def set_genre_prompt(genre: str):
    """Return a freshly randomised prompt for *genre*; the chosen style is discarded."""
    return get_genre_prompt(genre)[0]
| |
|
| | |
| | |
| | |
with gr.Blocks(css=css) as demo:
    gr.Markdown("""
<div class="header-container" role="banner" aria-label="GhostAI Music Generator">
<h1>GhostAI Music Generator</h1>
<p>Create Professional Instrumental Tracks</p>
</div>
""")
    with gr.Tabs():
        with gr.Tab("Generate", id="generate"):
            with gr.Column(elem_classes="input-container"):
                gr.Markdown("### Instrumental Prompt")
                instrumental_prompt = gr.Textbox(
                    label="Instrumental Prompt",
                    placeholder="Select a genre or enter a custom prompt (e.g., 'coolriff grunge')",
                    lines=4,
                    elem_classes="textbox"
                )
                with gr.Row(elem_classes="genre-buttons"):
                    classic_rock_btn = gr.Button("Classic Rock", elem_classes="genre-btn")
                    alternative_rock_btn = gr.Button("Alternative Rock", elem_classes="genre-btn")
                    detroit_techno_btn = gr.Button("Detroit Techno", elem_classes="genre-btn")
                    deep_house_btn = gr.Button("Deep House", elem_classes="genre-btn")
                    smooth_jazz_btn = gr.Button("Smooth Jazz", elem_classes="genre-btn")
                    bebop_jazz_btn = gr.Button("Bebop Jazz", elem_classes="genre-btn")
                    baroque_classical_btn = gr.Button("Baroque Classical", elem_classes="genre-btn")
                    romantic_classical_btn = gr.Button("Romantic Classical", elem_classes="genre-btn")
                    boom_bap_hiphop_btn = gr.Button("Boom Bap Hip-Hop", elem_classes="genre-btn")
                    trap_hiphop_btn = gr.Button("Trap Hip-Hop", elem_classes="genre-btn")
                    pop_rock_btn = gr.Button("Pop Rock", elem_classes="genre-btn")
                    fusion_jazz_btn = gr.Button("Fusion Jazz", elem_classes="genre-btn")
                    edm_btn = gr.Button("EDM", elem_classes="genre-btn")
                    indie_folk_btn = gr.Button("Indie Folk", elem_classes="genre-btn")
                    star_wars_btn = gr.Button("Star Wars Epic", elem_classes="genre-btn")
                    star_wars_classical_btn = gr.Button("Star Wars Classical", elem_classes="genre-btn")
                    nirvana_btn = gr.Button("Nirvana", elem_classes="genre-btn")
                    wutang_btn = gr.Button("Wu-Tang", elem_classes="genre-btn")
                    milesdavis_btn = gr.Button("Miles Davis", elem_classes="genre-btn")
            with gr.Column(elem_classes="settings-container"):
                gr.Markdown("### Generation Settings")
                cfg_scale = gr.Slider(
                    label="Guidance Scale (CFG)",
                    minimum=1.0,
                    maximum=10.0,
                    value=3.0,
                    step=0.1
                )
                top_k = gr.Slider(
                    label="Top-K Sampling",
                    minimum=10,
                    maximum=500,
                    value=50,
                    step=10
                )
                top_p = gr.Slider(
                    label="Top-P Sampling",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.0,
                    step=0.1
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=2.0,
                    value=0.8,
                    step=0.1
                )
                total_duration = gr.Slider(
                    label="Duration (seconds)",
                    minimum=30,
                    maximum=300,
                    value=30,
                    step=10
                )
                volume_db = gr.Slider(
                    label="Output Volume (dBFS)",
                    minimum=-30.0,
                    maximum=0.0,
                    value=-24.0,
                    step=0.1
                )
            with gr.Row(elem_classes="action-buttons"):
                gen_btn = gr.Button("Generate Music")
                clr_btn = gr.Button("Clear Inputs")
            with gr.Column(elem_classes="output-container"):
                gr.Markdown("### Output")
                render_wheel = gr.HTML('<div class="render-wheel" aria-live="polite">Generating...</div>', label="Rendering Status")
                render_state = gr.State(value=False)
                out_audio = gr.Audio(label="Generated Track", type="filepath", interactive=True, elem_classes="audio-container")
                status = gr.Textbox(label="Status", interactive=False)
        with gr.Tab("Renders", id="renders"):
            with gr.Column(elem_classes="renders-container"):
                gr.Markdown("### Browse Renders")
                # Call load_renders() once so the table rows and the status text
                # come from the same snapshot (it was previously called twice).
                _initial_renders = load_renders()
                renders_table = gr.DataFrame(
                    headers=["Title", "Filename", "Prompt", "Duration (s)", "Timestamp", "Audio", "Download", "Chunk"],
                    datatype=["str", "str", "str", "number", "str", "audio", "html", "number"],
                    interactive=False,
                    value=_initial_renders[0],
                    elem_classes="renders-table"
                )
                renders_status = gr.Textbox(label="Renders Status", interactive=False, value=_initial_renders[1])

    # Each genre button writes its genre's prompt into the prompt textbox.
    # The genre key is passed as a constant gr.State input, one per button.
    _genre_button_keys = [
        (classic_rock_btn, "classic_rock"),
        (alternative_rock_btn, "alternative_rock"),
        (detroit_techno_btn, "detroit_techno"),
        (deep_house_btn, "deep_house"),
        (smooth_jazz_btn, "smooth_jazz"),
        (bebop_jazz_btn, "bebop_jazz"),
        (baroque_classical_btn, "baroque_classical"),
        (romantic_classical_btn, "romantic_classical"),
        (boom_bap_hiphop_btn, "boom_bap_hiphop"),
        (trap_hiphop_btn, "trap_hiphop"),
        (pop_rock_btn, "pop_rock"),
        (fusion_jazz_btn, "fusion_jazz"),
        (edm_btn, "edm"),
        (indie_folk_btn, "indie_folk"),
        (star_wars_btn, "star_wars"),
        (star_wars_classical_btn, "star_wars_classical"),
        (nirvana_btn, "nirvana"),
        (wutang_btn, "wutang"),
        (milesdavis_btn, "milesdavis"),
    ]
    for _genre_btn, _genre_key in _genre_button_keys:
        _genre_btn.click(set_genre_prompt, inputs=[gr.State(value=_genre_key)], outputs=[instrumental_prompt])

    # First flip render_state on, then run the (long) generation. generate_music
    # returns (audio path, status text, render_state, refreshed renders table).
    gen_btn.click(
        fn=show_render_wheel,
        inputs=None,
        outputs=[render_state],
    ).then(
        fn=generate_music,
        inputs=[instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, volume_db, gr.State(None)],
        outputs=[out_audio, status, render_state, renders_table],
        show_progress="full"
    )
    clr_btn.click(
        fn=clear_inputs,
        inputs=None,
        outputs=[instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, volume_db, render_state]
    )
| |
|
| | |
| | |
| | |
from typing import Optional

# ASGI application; served by run_fastapi() via uvicorn.
app = FastAPI()


class MusicRequest(BaseModel):
    """JSON body accepted by ``POST /generate-music/``.

    Fields:
        prompt: Free-text instrumental prompt. Used when ``genre`` is not given.
        duration: Requested length in seconds. The handler raises anything
            below 30 up to 30.
        volume_db: Limiter ceiling forwarded to ``apply_limiter(max_db=...)``.
        genre: Preset genre key. When set, it takes precedence over ``prompt``.
    """
    # Optional[...] lets clients send an explicit JSON null. With pydantic v2,
    # a bare ``str = None`` annotation rejects null with a 422 error.
    prompt: Optional[str] = None
    duration: int = 30
    volume_db: float = -24.0
    genre: Optional[str] = None
| |
|
@app.get("/prompts/")
async def get_prompts():
    """List the option names of the ``[Prompts]`` config section.

    Returns:
        ``{"status": <current api_status>, "prompts": [<option names>]}``.

    Raises:
        HTTPException(500): any error while reading the section or building
            the response.
    """
    try:
        prompt_names = list(config['Prompts'])
        return {"status": api_status, "prompts": prompt_names}
    except Exception as err:
        print(f"Error fetching prompts: {err}")
        raise HTTPException(status_code=500, detail=f"Error fetching prompts: {err}")
| |
|
def _api_to_stereo(audio_chunk):
    """Return ``audio_chunk`` as a two-channel ``(2, samples)`` tensor.

    Mono input is duplicated. When there are more than two channels, the
    first channel is kept and duplicated.

    Raises:
        ValueError: if the result still does not have exactly two channels.
    """
    # Fold leading batch dims into the channel axis, e.g. (1, C, T) -> (C, T).
    # The former ``view(2, -1)`` split a mono (1, 1, T) tensor into two
    # half-length "channels", and it failed outright for odd T.
    if audio_chunk.dim() > 2:
        audio_chunk = audio_chunk.reshape(-1, audio_chunk.shape[-1])
    if audio_chunk.dim() == 1:
        audio_chunk = torch.stack([audio_chunk, audio_chunk], dim=0)
    elif audio_chunk.dim() == 2 and audio_chunk.shape[0] == 1:
        audio_chunk = torch.cat([audio_chunk, audio_chunk], dim=0)
    elif audio_chunk.dim() == 2 and audio_chunk.shape[0] != 2:
        first_channel = audio_chunk[:1, :]
        audio_chunk = torch.cat([first_channel, first_channel], dim=0)
    if audio_chunk.shape[0] != 2:
        raise ValueError(f"Expected stereo audio with shape (2, samples), got {audio_chunk.shape}")
    return audio_chunk


def _api_resolve_prompt_and_style(request):
    """Pick the (instrumental_prompt, style) pair for an API request.

    Precedence:
      1. ``request.genre``: both values come from one ``get_genre_prompt`` call.
      2. ``request.prompt``: used verbatim. Its extracted keyword becomes the
         style only when it is listed in ``prompt_variables['style']``.
         Otherwise the "nirvana" preset style is used.
      3. Neither is set: the "nirvana" preset prompt and style.
    """
    if request.genre:
        genre_prompt, genre_style = get_genre_prompt(request.genre)
        return genre_prompt, genre_style
    keyword = extract_song_keyword(request.prompt) if request.prompt else None
    if request.prompt and keyword in prompt_variables['style']:
        return request.prompt, keyword
    fallback_prompt, fallback_style = get_genre_prompt("nirvana")
    return (request.prompt if request.prompt else fallback_prompt), fallback_style


@app.post("/generate-music/")
async def api_generate_music(request: MusicRequest):
    """Render a track for an API request and return it as an MP3.

    Audio is produced in chunks of at most 30 s. Each chunk after the first is
    conditioned on the last 2 s of the previous chunk. Each chunk is processed
    with ``apply_eq`` and ``apply_limiter``, exported as a 64 kbps MP3, and
    recorded with ``update_metadata_storage``. When there is more than one
    chunk, the chunks are crossfaded into a "combined" MP3, which is returned.
    Otherwise the single chunk file is returned.

    ``api_status`` is "rendering" while the handler runs and is always reset
    to "idle" afterwards.

    NOTE(review): this is ``async def`` but never awaits, so the whole render
    blocks the event loop. ``/status/`` cannot answer until it finishes.

    Raises:
        HTTPException(400): the resolved prompt is empty or whitespace.
        HTTPException(500): any other failure during generation or export.
    """
    import tempfile

    global api_status
    api_status = "rendering"
    try:
        instrumental_prompt, style = _api_resolve_prompt_and_style(request)
        if not instrumental_prompt.strip():
            raise HTTPException(status_code=400, detail="Invalid prompt or genre")

        total_duration = max(request.duration, 30)
        remaining = total_duration
        audio_chunks = []
        chunk_paths = []
        continuation_prompt = None
        chunk_index = 0

        existing_titles = []
        if os.path.exists(metadata_file):
            with open(metadata_file, 'r') as f:
                songs_metadata = json.load(f)
            existing_titles = [entry["title"] for entry in songs_metadata]
        song_keyword = extract_song_keyword(request.prompt if request.prompt else instrumental_prompt)
        title_base, band_name = generate_unique_title(existing_titles, request.genre if request.genre else "nirvana", song_keyword, style)

        while remaining > 0:
            target = min(30, remaining)
            print_resource_usage(f"Before API Chunk {chunk_index + 1}")
            try:
                audio_chunk, actual_dur = generate_chunk_oom_safe(
                    musicgen_model, instrumental_prompt, continuation_prompt, 3.0, 50, 0.0, 0.8, target
                )
                # Without progress, `remaining` would never shrink and the loop would spin forever.
                if actual_dur <= 0:
                    raise RuntimeError(f"Chunk {chunk_index + 1} produced no audio (duration {actual_dur})")
                audio_chunk = _api_to_stereo(audio_chunk.cpu().to(dtype=torch.float32))

                sample_rate = musicgen_model.sample_rate
                tail_sec = 2
                tail_samples = min(int(tail_sec * sample_rate), audio_chunk.shape[1] - 1 if audio_chunk.shape[1] > 1 else 1)
                continuation_prompt = audio_chunk[:, -tail_samples:].cpu() if tail_samples > 0 else None

                # Unique temp file. output_dir is shared with the Gradio process,
                # so a random 3-digit suffix could collide with its temp files.
                fd, temp_wav_path = tempfile.mkstemp(prefix="temp_", suffix=f"_{chunk_index}.wav", dir=output_dir)
                os.close(fd)
                try:
                    torchaudio.save(temp_wav_path, audio_chunk, sample_rate, bits_per_sample=16)
                    final_segment = AudioSegment.from_wav(temp_wav_path)
                finally:
                    if os.path.exists(temp_wav_path):
                        os.remove(temp_wav_path)
                    del audio_chunk
                    gc.collect()

                final_segment = apply_eq(final_segment)
                final_segment = apply_limiter(final_segment, max_db=request.volume_db, target_lufs=-16.0)
                if chunk_index == 0:
                    final_segment = final_segment.fade_in(1000)
                if remaining - actual_dur <= 0:
                    final_segment = final_segment.fade_out(1000)

                mp3_filename = f"{title_base.lower()}_{song_keyword}_{style}_{band_name}_chunk{chunk_index + 1}.mp3"
                mp3_path = os.path.join(output_dir, mp3_filename)
                final_segment.export(
                    mp3_path,
                    format="mp3",
                    bitrate="64k",
                    tags={"title": f"{title_base}_Chunk{chunk_index + 1}", "artist": "GhostAI"}
                )
                print(f"Saved API chunk {chunk_index + 1} to {mp3_path}")
                audio_chunks.append(final_segment)
                chunk_paths.append(mp3_path)

                metadata = {
                    "title": f"{title_base}_Chunk{chunk_index + 1}",
                    "filename": mp3_filename,
                    "prompt": instrumental_prompt,
                    "duration": actual_dur,
                    "volume_db": request.volume_db,
                    "target_lufs": -16.0,
                    "timestamp": datetime.datetime.now().strftime("%Y%m%d_%H%M%S"),
                    "file_path": mp3_path,
                    "sample_rate": musicgen_model.sample_rate,
                    "style": style,
                    "band_name": band_name,
                    "chunk_index": chunk_index + 1
                }
                update_metadata_storage(metadata)

                chunk_index += 1
                remaining -= actual_dur
                torch.cuda.empty_cache()
                gc.collect()
                print_resource_usage(f"After API Chunk {chunk_index}")
            except Exception as e:
                print(f"ERROR: Failed to process API chunk {chunk_index + 1}: {e}")
                raise

        if len(audio_chunks) > 1:
            combined_segment = audio_chunks[0]
            for segment in audio_chunks[1:]:
                # pydub raises ValueError if the crossfade exceeds either segment's
                # length (ms). A short trailing chunk is possible when actual_dur < target.
                fade_ms = min(500, len(combined_segment), len(segment))
                combined_segment = combined_segment.append(segment, crossfade=fade_ms)
            combined_mp3_filename = f"{title_base.lower()}_{song_keyword}_{style}_{band_name}_combined.mp3"
            combined_mp3_path = os.path.join(output_dir, combined_mp3_filename)
            combined_segment.export(
                combined_mp3_path,
                format="mp3",
                bitrate="64k",
                tags={"title": title_base, "artist": "GhostAI"}
            )
            print(f"Saved combined audio to {combined_mp3_path}")
            metadata = {
                "title": title_base,
                "filename": combined_mp3_filename,
                "prompt": instrumental_prompt,
                "duration": total_duration,
                "volume_db": request.volume_db,
                "target_lufs": -16.0,
                "timestamp": datetime.datetime.now().strftime("%Y%m%d_%H%M%S"),
                "file_path": combined_mp3_path,
                "sample_rate": musicgen_model.sample_rate,
                "style": style,
                "band_name": band_name,
                "chunk_index": 0
            }
            update_metadata_storage(metadata)
            del combined_segment, audio_chunks
            gc.collect()
            return FileResponse(combined_mp3_path, media_type="audio/mpeg")

        print(f"Saved metadata to {metadata_file}")
        del audio_chunks
        gc.collect()
        return FileResponse(chunk_paths[0], media_type="audio/mpeg")
    except HTTPException:
        # Deliberate client errors (e.g. the 400 above) must not be re-wrapped as 500.
        raise
    except Exception as e:
        print(f"Error generating music: {e}")
        raise HTTPException(status_code=500, detail=f"Error generating music: {e}") from e
    finally:
        # Reset first so the status is correct even if the CUDA cleanup below raises.
        api_status = "idle"
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        gc.collect()
| |
|
@app.get("/get-song/{filename}")
async def get_song(filename: str):
    """Serve a rendered song stored in ``output_dir`` as ``audio/mpeg``.

    Args:
        filename: Client-supplied file name, taken from the URL path.

    Returns:
        ``FileResponse`` for the file, using ``filename`` as the download name.

    Raises:
        HTTPException(404): the name does not resolve to a regular file inside
            ``output_dir``. This covers a missing file, a directory, or a path
            that escapes the directory.
    """
    base_dir = os.path.realpath(output_dir)
    file_path = os.path.realpath(os.path.join(base_dir, filename))
    # `filename` is untrusted. Values like ".." or, on Windows, "..\\x" or "C:\\x"
    # would otherwise resolve outside output_dir. join(base, "") yields base plus
    # a trailing separator, which is also correct when base_dir is a filesystem root.
    inside_output_dir = file_path.startswith(os.path.join(base_dir, ""))
    if not inside_output_dir or not os.path.isfile(file_path):
        print(f"Error: Song file {filename} not found")
        raise HTTPException(status_code=404, detail="Song file not found")
    print(f"Serving file: {filename}")
    return FileResponse(file_path, media_type="audio/mpeg", filename=filename)
| |
|
@app.get("/status/")
async def get_status():
    """Report this process's current ``api_status`` value ("idle" / "rendering")."""
    current = api_status
    return dict(status=current)
| |
|
def run_fastapi():
    """Serve the FastAPI ``app`` with uvicorn on 0.0.0.0:8000. Blocks until shutdown."""
    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)
| |
|
| | |
| | |
| | |
if __name__ == "__main__":
    # The REST API runs in a child process; the Gradio UI runs in this one.
    # NOTE: with the 'spawn' start method set at import time, the child
    # re-imports this module, so all module-level setup also runs there.
    fastapi_process = multiprocessing.Process(target=run_fastapi)
    fastapi_process.start()
    try:
        demo.launch(server_name="0.0.0.0", server_port=9999, share=False, inbrowser=True, show_error=True)
    except Exception as e:
        print(f"ERROR: Failed to launch Gradio: {e}")
        sys.exit(1)
    finally:
        # Runs on normal exit, on sys.exit above, and on Ctrl-C. Terminate once,
        # then reap the child so it is not left behind as a zombie.
        fastapi_process.terminate()
        fastapi_process.join(timeout=10)
| |
|