K1Z3M1112 commited on
Commit
c166ed3
·
verified ·
1 Parent(s): 332035a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -0
app.py CHANGED
@@ -97,9 +97,50 @@ def t2i(prompt, model, seed, steps, scale, w, h):
97
  generator=gen
98
  ).images[0]
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  # ===== Gradio UI (Minimal) =====
101
  with gr.Blocks() as demo:
102
  gr.Markdown("# 🎨 Minimal Style2Paints")
 
 
 
 
 
 
103
 
104
  with gr.Tab("🎨 Colorize"):
105
  with gr.Row():
 
97
  generator=gen
98
  ).images[0]
99
 
100
+ # ===== NEW: Function to unload all models =====
101
+ def unload_all_models():
102
+ global PIPE_STANDARD, PIPE_ANIME, LINEART_DETECTOR, LINEART_ANIME_DETECTOR
103
+ global CURRENT_T2I_PIPE, CURRENT_T2I_MODEL
104
+
105
+ print("Unloading all models from memory...")
106
+
107
+ # Unload lineart models
108
+ if PIPE_STANDARD is not None:
109
+ del PIPE_STANDARD
110
+ PIPE_STANDARD = None
111
+ if PIPE_ANIME is not None:
112
+ del PIPE_ANIME
113
+ PIPE_ANIME = None
114
+ if LINEART_DETECTOR is not None:
115
+ del LINEART_DETECTOR
116
+ LINEART_DETECTOR = None
117
+ if LINEART_ANIME_DETECTOR is not None:
118
+ del LINEART_ANIME_DETECTOR
119
+ LINEART_ANIME_DETECTOR = None
120
+
121
+ # Unload T2I model
122
+ if CURRENT_T2I_PIPE is not None:
123
+ del CURRENT_T2I_PIPE
124
+ CURRENT_T2I_PIPE = None
125
+ CURRENT_T2I_MODEL = None
126
+
127
+ # Force garbage collection
128
+ gc.collect()
129
+ if torch.cuda.is_available():
130
+ torch.cuda.empty_cache()
131
+ print(f"GPU memory cleared. Allocated: {torch.cuda.memory_allocated()/1024**3:.2f} GB")
132
+
133
+ return "✅ All models unloaded from memory!"
134
+
135
  # ===== Gradio UI (Minimal) =====
136
  with gr.Blocks() as demo:
137
  gr.Markdown("# 🎨 Minimal Style2Paints")
138
+
139
+ # Add unload button at the top
140
+ with gr.Row():
141
+ unload_btn = gr.Button("🗑️ Unload All Models", variant="stop")
142
+ status_text = gr.Textbox(label="Status", interactive=False)
143
+ unload_btn.click(unload_all_models, outputs=status_text)
144
 
145
  with gr.Tab("🎨 Colorize"):
146
  with gr.Row():