 14 |  14 |     ENABLE_ATTENTION_SLICING,
 15 |  15 |     ENABLE_FLASH_ATTENTION,
 16 |  16 |     ENABLE_BETTERTRANSFORMER,
 17 |     | -    ENABLE_CPU_OFFLOADING
    |  17 | +    ENABLE_CPU_OFFLOADING,
    |  18 | +    NGROK_TOKEN_ENV,
    |  19 | +    HF_TOKEN_ENV,
    |  20 | +    get_env_var,
    |  21 | +    set_env_var
 18 |  22 | )
 19 |  23 |
 20 |  24 | def is_in_colab() -> bool:
@@ -147,167 +151,53 @@ def prompt_for_config(use_ngrok: bool = None, port: int = None, ngrok_auth_token
147 | 151 |     config["use_ngrok"] = use_ngrok
148 | 152 |
149 | 153 |     if use_ngrok:
    | 154 | +        # Show current token if exists
    | 155 | +        current_token = config.get("ngrok_auth_token") or get_env_var(NGROK_TOKEN_ENV)
    | 156 | +        if current_token:
    | 157 | +            click.echo(f"\nCurrent ngrok token: {current_token}")
    | 158 | +
150 | 159 |         ngrok_auth_token = click.prompt(
151 |     | -            "🔑 Please enter your ngrok auth token (get one at https://dashboard.ngrok.com/get-started/your-authtoken)",
152 |     | -            default=config.get("ngrok_auth_token", ""),
153 |     | -            hide_input=True
154 |     | -        )
155 |     | -        if ngrok_auth_token:
156 |     | -            os.environ["NGROK_AUTH_TOKEN"] = ngrok_auth_token
157 |     | -            config["ngrok_auth_token"] = ngrok_auth_token
158 |     | -
159 |     | -    # Ask about optimizations
160 |     | -    setup_optimizations = click.confirm(
161 |     | -        "⚡ Would you like to configure optimizations for better performance?",
162 |     | -        default=True
163 |     | -    )
164 |     | -
165 |     | -    if setup_optimizations:
166 |     | -        # Quantization
167 |     | -        enable_quantization = click.confirm(
168 |     | -            "📊 Enable quantization for reduced memory usage?",
169 |     | -            default=config.get("enable_quantization", ENABLE_QUANTIZATION)
170 |     | -        )
171 |     | -        os.environ["LOCALLAB_ENABLE_QUANTIZATION"] = str(enable_quantization).lower()
172 |     | -        config["enable_quantization"] = enable_quantization
173 |     | -
174 |     | -        if enable_quantization:
175 |     | -            quant_type = click.prompt(
176 |     | -                "📊 Quantization type",
177 |     | -                type=click.Choice(["int8", "int4"]),
178 |     | -                default=config.get("quantization_type", QUANTIZATION_TYPE or "int8")
179 |     | -            )
180 |     | -            os.environ["LOCALLAB_QUANTIZATION_TYPE"] = quant_type
181 |     | -            config["quantization_type"] = quant_type
182 |     | -
183 |     | -        # Attention slicing
184 |     | -        enable_attn_slicing = click.confirm(
185 |     | -            "🔪 Enable attention slicing for reduced memory usage?",
186 |     | -            default=config.get("enable_attention_slicing", ENABLE_ATTENTION_SLICING)
187 |     | -        )
188 |     | -        os.environ["LOCALLAB_ENABLE_ATTENTION_SLICING"] = str(enable_attn_slicing).lower()
189 |     | -        config["enable_attention_slicing"] = enable_attn_slicing
190 |     | -
191 |     | -        # Flash attention
192 |     | -        enable_flash_attn = click.confirm(
193 |     | -            "⚡ Enable flash attention for faster inference?",
194 |     | -            default=config.get("enable_flash_attention", ENABLE_FLASH_ATTENTION)
195 |     | -        )
196 |     | -        os.environ["LOCALLAB_ENABLE_FLASH_ATTENTION"] = str(enable_flash_attn).lower()
197 |     | -        config["enable_flash_attention"] = enable_flash_attn
198 |     | -
199 |     | -        # BetterTransformer
200 |     | -        enable_better_transformer = click.confirm(
201 |     | -            "🔄 Enable BetterTransformer for optimized inference?",
202 |     | -            default=config.get("enable_better_transformer", ENABLE_BETTERTRANSFORMER)
203 |     | -        )
204 |     | -        os.environ["LOCALLAB_ENABLE_BETTERTRANSFORMER"] = str(enable_better_transformer).lower()
205 |     | -        config["enable_better_transformer"] = enable_better_transformer
206 |     | -
207 |     | -    # Ask about advanced options
208 |     | -    setup_advanced = click.confirm(
209 |     | -        "🔧 Would you like to configure advanced options?",
210 |     | -        default=False
211 |     | -    )
212 |     | -
213 |     | -    if setup_advanced:
214 |     | -        # CPU offloading
215 |     | -        enable_cpu_offloading = click.confirm(
216 |     | -            "💻 Enable CPU offloading for large models?",
217 |     | -            default=config.get("enable_cpu_offloading", ENABLE_CPU_OFFLOADING)
    | 160 | +            "🔑 Enter your ngrok auth token (get one at https://dashboard.ngrok.com/get-started/your-authtoken)",
    | 161 | +            default=current_token,
    | 162 | +            type=str,
    | 163 | +            show_default=True
218 | 164 |         )
219 |     | -        os.environ["LOCALLAB_ENABLE_CPU_OFFLOADING"] = str(enable_cpu_offloading).lower()
220 |     | -        config["enable_cpu_offloading"] = enable_cpu_offloading
221 | 165 |
222 |     | -        # Model timeout
223 |     | -        model_timeout = click.prompt(
224 |     | -            "⏱️ Model unloading timeout in seconds (0 to disable)",
225 |     | -            default=config.get("model_timeout", 3600),
226 |     | -            type=int
227 |     | -        )
228 |     | -        os.environ["LOCALLAB_MODEL_TIMEOUT"] = str(model_timeout)
229 |     | -        config["model_timeout"] = model_timeout
230 |     | -
231 |     | -        # Cache settings
232 |     | -        enable_cache = click.confirm(
233 |     | -            "🔄 Enable response caching?",
234 |     | -            default=config.get("enable_cache", True)
235 |     | -        )
236 |     | -        os.environ["LOCALLAB_ENABLE_CACHE"] = str(enable_cache).lower()
237 |     | -        config["enable_cache"] = enable_cache
238 |     | -
239 |     | -        if enable_cache:
240 |     | -            cache_ttl = click.prompt(
241 |     | -                "⏱️ Cache TTL in seconds",
242 |     | -                default=config.get("cache_ttl", 3600),
243 |     | -                type=int
244 |     | -            )
245 |     | -            os.environ["LOCALLAB_CACHE_TTL"] = str(cache_ttl)
246 |     | -            config["cache_ttl"] = cache_ttl
247 |     | -
248 |     | -        # Logging settings
249 |     | -        log_level = click.prompt(
250 |     | -            "📝 Log level",
251 |     | -            type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR"]),
252 |     | -            default=config.get("log_level", "INFO")
253 |     | -        )
254 |     | -        os.environ["LOCALLAB_LOG_LEVEL"] = log_level
255 |     | -        config["log_level"] = log_level
256 |     | -
257 |     | -        enable_file_logging = click.confirm(
258 |     | -            "📄 Enable file logging?",
259 |     | -            default=config.get("enable_file_logging", False)
260 |     | -        )
261 |     | -        os.environ["LOCALLAB_ENABLE_FILE_LOGGING"] = str(enable_file_logging).lower()
262 |     | -        config["enable_file_logging"] = enable_file_logging
263 |     | -
264 |     | -        if enable_file_logging:
265 |     | -            log_file = click.prompt(
266 |     | -                "📄 Log file path",
267 |     | -                default=config.get("log_file", "locallab.log")
268 |     | -            )
269 |     | -            os.environ["LOCALLAB_LOG_FILE"] = log_file
270 |     | -            config["log_file"] = log_file
271 |     | -
272 |     | -    # Ask about HuggingFace token with improved UX
273 |     | -    hf_token = config.get("huggingface_token") or os.environ.get("HUGGINGFACE_TOKEN")
274 |     | -    if not hf_token or force_reconfigure:
275 |     | -        click.echo("\n🔑 HuggingFace Token Configuration")
    | 166 | +        if ngrok_auth_token:
    | 167 | +            token_str = str(ngrok_auth_token).strip()
    | 168 | +            config["ngrok_auth_token"] = token_str
    | 169 | +            set_env_var(NGROK_TOKEN_ENV, token_str)
    | 170 | +            click.echo(f"✅ Ngrok token saved: {token_str}")
    | 171 | +
    | 172 | +    # Ask about HuggingFace token
    | 173 | +    current_hf_token = config.get("huggingface_token") or get_env_var(HF_TOKEN_ENV)
    | 174 | +    if current_hf_token:
    | 175 | +        click.echo(f"\nCurrent HuggingFace token: {current_hf_token}")
    | 176 | +
    | 177 | +    if not current_hf_token or force_reconfigure:
    | 178 | +        click.echo("\n🔑 HuggingFace Token Configuration")
276 | 179 |         click.echo("───────────────────────────────")
277 | 180 |         click.echo("A token is required to download models like microsoft/phi-2")
278 | 181 |         click.echo("Get your token from: https://huggingface.co/settings/tokens")
279 | 182 |
280 |     | -        if hf_token:
281 |     | -            click.echo(f"\nCurrent token: {hf_token[:4]}...{hf_token[-4:]}")
282 |     | -            if not click.confirm("Would you like to update your token?", default=False):
283 |     | -                click.echo("Keeping existing token...")
284 |     | -                return config
285 |     | -
286 |     | -        click.echo("\nEnter your HuggingFace token (press Enter to skip): ", nl=False)
287 |     | -
288 |     | -        # Read token character by character for secure input
289 |     | -        chars = []
290 |     | -        while True:
291 |     | -            char = click.getchar()
292 |     | -            if char in ('\r', '\n'):
293 |     | -                break
294 |     | -            chars.append(char)
295 |     | -            click.echo('*', nl=False)
296 |     | -
297 |     | -        hf_token = ''.join(chars)
    | 183 | +        hf_token = click.prompt(
    | 184 | +            "Enter your HuggingFace token",
    | 185 | +            default=current_hf_token,
    | 186 | +            type=str,
    | 187 | +            show_default=True
    | 188 | +        )
298 | 189 |
299 | 190 |         if hf_token:
300 |     | -            # Validate token format
301 |     | -            if len(hf_token) < 20:  # Basic validation
302 |     | -                click.echo("\n❌ Invalid token format. Token should be longer than 20 characters.")
303 |     | -                click.echo("Please check your token and try again.")
    | 191 | +            if len(hf_token) < 20:
    | 192 | +                click.echo("❌ Invalid token format. Token should be longer than 20 characters.")
304 | 193 |                 return config
305 |     | -
306 |     | -            click.echo("\n✅ Token saved successfully!")
307 |     | -            os.environ["HUGGINGFACE_TOKEN"] = hf_token
308 |     | -            config["huggingface_token"] = hf_token
309 | 194 |
310 |     | -            # Save immediately to ensure it's persisted
    | 195 | +            token_str = str(hf_token).strip()
    | 196 | +            config["huggingface_token"] = token_str
    | 197 | +            set_env_var(HF_TOKEN_ENV, token_str)
    | 198 | +            click.echo(f"✅ HuggingFace token saved: {token_str}")
    | 199 | +
    | 200 | +            # Save immediately
311 | 201 |             from .config import save_config
312 | 202 |             save_config(config)
313 | 203 |         else:
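
Note: the `get_env_var` / `set_env_var` helpers and the `NGROK_TOKEN_ENV` / `HF_TOKEN_ENV` constants that this diff imports from `.config` are not shown here. Below is a minimal sketch of what they might look like, assuming thin wrappers around `os.environ`; the constant values are inferred from the `os.environ["NGROK_AUTH_TOKEN"]` and `os.environ["HUGGINGFACE_TOKEN"]` assignments the diff removes, and any file-based persistence `set_env_var` might additionally perform is an assumption.

```python
import os

# Assumed values -- the real constants live in locallab's .config module,
# which is not part of this diff.
NGROK_TOKEN_ENV = "NGROK_AUTH_TOKEN"
HF_TOKEN_ENV = "HUGGINGFACE_TOKEN"


def get_env_var(name: str, default: str = None) -> str:
    """Read a configuration value from the process environment."""
    return os.environ.get(name, default)


def set_env_var(name: str, value: str) -> None:
    """Store a configuration value in the process environment so that
    later code in the same process (e.g. a model download that needs
    the HuggingFace token) can read it."""
    os.environ[name] = str(value)
```

Under that assumption, `set_env_var(NGROK_TOKEN_ENV, token_str)` behaves like the removed `os.environ["NGROK_AUTH_TOKEN"] = ngrok_auth_token` line, with the variable name centralized in one constant shared by the prompt and lookup paths.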