|
8 | 8 |
|
9 | 9 | use anyhow::{Context, Result}; |
10 | 10 | use clap::Parser; |
| 11 | +use cortex_engine::config::find_cortex_home; |
11 | 12 | use std::collections::HashMap; |
12 | 13 | use std::path::PathBuf; |
13 | 14 |
|
@@ -192,9 +193,7 @@ impl StatsCli { |
192 | 193 |
|
193 | 194 | /// Get the cortex home directory. |
194 | 195 | fn get_cortex_home() -> PathBuf { |
195 | | - dirs::home_dir() |
196 | | - .map(|h| h.join(".cortex")) |
197 | | - .unwrap_or_else(|| PathBuf::from(".cortex")) |
| 196 | + find_cortex_home().unwrap_or_else(|_| PathBuf::from(".cortex")) |
198 | 197 | } |
199 | 198 |
|
200 | 199 | /// Get pricing for a model. |
@@ -700,6 +699,58 @@ fn format_cost(cost: f64) -> String { |
700 | 699 | #[cfg(test)] |
701 | 700 | mod tests { |
702 | 701 | use super::*; |
| 702 | + use serial_test::serial; |
| 703 | + use std::env; |
| 704 | + use std::ffi::{OsStr, OsString}; |
| 705 | + use tempfile::TempDir; |
| 706 | + |
| 707 | + struct EnvVarGuard { |
| 708 | + key: &'static str, |
| 709 | + original: Option<OsString>, |
| 710 | + } |
| 711 | + |
| 712 | + impl EnvVarGuard { |
| 713 | + fn set(key: &'static str, value: impl AsRef<OsStr>) -> Self { |
| 714 | + let original = env::var_os(key); |
| 715 | + // SAFETY: These tests are serialized and restore the environment on drop. |
| 716 | + unsafe { |
| 717 | + env::set_var(key, value); |
| 718 | + } |
| 719 | + Self { key, original } |
| 720 | + } |
| 721 | + |
| 722 | + fn remove(key: &'static str) -> Self { |
| 723 | + let original = env::var_os(key); |
| 724 | + // SAFETY: These tests are serialized and restore the environment on drop. |
| 725 | + unsafe { |
| 726 | + env::remove_var(key); |
| 727 | + } |
| 728 | + Self { key, original } |
| 729 | + } |
| 730 | + } |
| 731 | + |
| 732 | + impl Drop for EnvVarGuard { |
| 733 | + fn drop(&mut self) { |
| 734 | + // SAFETY: These tests are serialized and restore the environment before returning. |
| 735 | + unsafe { |
| 736 | + match &self.original { |
| 737 | + Some(value) => env::set_var(self.key, value), |
| 738 | + None => env::remove_var(self.key), |
| 739 | + } |
| 740 | + } |
| 741 | + } |
| 742 | + } |
| 743 | + |
| 744 | + #[test] |
| 745 | + #[serial] |
| 746 | + fn test_get_cortex_home_uses_cortex_home_env() { |
| 747 | + let temp_dir = TempDir::new().unwrap(); |
| 748 | + let cortex_home = temp_dir.path().join("custom-cortex-home"); |
| 749 | + let _config_dir = EnvVarGuard::remove("CORTEX_CONFIG_DIR"); |
| 750 | + let _cortex_home = EnvVarGuard::set("CORTEX_HOME", &cortex_home); |
| 751 | + |
| 752 | + assert_eq!(get_cortex_home(), cortex_home); |
| 753 | + } |
703 | 754 |
|
704 | 755 | #[test] |
705 | 756 | fn test_format_number() { |
|
0 commit comments