diff --git a/Cargo.lock b/Cargo.lock index 0ded37a9..bdae9630 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1122,18 +1122,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.116" +version = "1.0.117" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96fe57af81d28386a513cbc6858332abc6117cfdb5999647c6444b8f43a370a5" +checksum = "b88fa983de7720629c9387e9f517353ed404164b1e482c970a90c1a4aaf7dc1a" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.116" +version = "1.0.117" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f630a6370fd8e457873b4bd2ffdae75408bc291ba72be773772a4c2a065d9ae8" +checksum = "cbd1ae72adb44aab48f325a02444a5fc079349a8d804c1fc922aed3f7454c74e" dependencies = [ "proc-macro2", "quote 1.0.7", @@ -1210,7 +1210,7 @@ dependencies = [ "rayon", "regex", "rust-ini", - "serde_derive", + "serde", "serde_json", "shell-words", "starship_module_config_derive", diff --git a/Cargo.toml b/Cargo.toml index 23facdb8..5e348a9b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,7 +33,7 @@ ansi_term = "0.12.1" dirs-next = "2.0.0" git2 = { version = "0.13.12", default-features = false } toml = { version = "0.5.7", features = ["preserve_order"] } -rust-ini = "0.16" +rust-ini = "0.16" serde_json = "1.0.59" rayon = "1.5.0" log = { version = "0.4.11", features = ["std"] } @@ -59,7 +59,7 @@ unicode-width = "0.1.8" term_size = "0.3.2" quick-xml = "0.20.0" rand = "0.7.3" -serde_derive = "1.0.115" +serde = { version = "1.0.117", features = ["derive"] } indexmap = "1.6.0" notify-rust = { version = "4.0.0", optional = true } diff --git a/src/modules/rust.rs b/src/modules/rust.rs index c72be0ec..00b029a8 100644 --- a/src/modules/rust.rs +++ b/src/modules/rust.rs @@ -2,6 +2,8 @@ use std::fs; use std::path::Path; use std::process::{Command, Output}; +use serde::Deserialize; + use super::{Context, Module, RootModuleConfig}; use crate::configs::rust::RustConfig; @@ -125,26 +127,47 @@ fn extract_toolchain_from_rustup_override_list(stdout: &str, cwd: &Path) -> Opti fn find_rust_toolchain_file(context: &Context) -> Option { // Look for 'rust-toolchain' as rustup does. - // https://github.com/rust-lang/rustup.rs/blob/d84e6e50126bccd84649e42482fc35a11d019401/src/config.rs#L320-L358 + // https://github.com/rust-lang/rustup/blob/89912c4cf51645b9c152ab7380fd07574fec43a3/src/config.rs#L546-L616 - fn read_first_line(path: &Path) -> Option { - let content = fs::read_to_string(path).ok()?; - let line = content.lines().next()?; - Some(line.trim().to_owned()) + #[derive(Deserialize)] + struct OverrideFile { + toolchain: ToolchainSection, + } + + #[derive(Deserialize)] + struct ToolchainSection { + channel: Option, + } + + fn read_channel(path: &Path) -> Option { + let contents = fs::read_to_string(path).ok()?; + + match contents.lines().count() { + 0 => None, + 1 => Some(contents), + _ => { + toml::from_str::(&contents) + .ok()? + .toolchain + .channel + } + } + .filter(|c| !c.trim().is_empty()) + .map(|c| c.trim().to_owned()) } if let Ok(true) = context .dir_contents() .map(|dir| dir.has_file("rust-toolchain")) { - if let Some(toolchain) = read_first_line(Path::new("rust-toolchain")) { + if let Some(toolchain) = read_channel(Path::new("rust-toolchain")) { return Some(toolchain); } } let mut dir = &*context.current_dir; loop { - if let Some(toolchain) = read_first_line(&dir.join("rust-toolchain")) { + if let Some(toolchain) = read_channel(&dir.join("rust-toolchain")) { return Some(toolchain); } dir = dir.parent()?; @@ -200,6 +223,7 @@ enum RustupRunRustcVersionOutcome { #[cfg(test)] mod tests { use once_cell::sync::Lazy; + use std::io; use std::process::{ExitStatus, Output}; use super::*; @@ -309,4 +333,46 @@ mod tests { let version_without_hash = String::from("rustc 1.34.0"); assert_eq!(format_rustc_version(version_without_hash), "v1.34.0"); } + + #[test] + fn test_find_rust_toolchain_file() -> io::Result<()> { + let dir = tempfile::tempdir()?; + fs::write(dir.path().join("rust-toolchain"), "1.34.0")?; + + let context = Context::new_with_dir(Default::default(), dir.path()); + + assert_eq!( + find_rust_toolchain_file(&context), + Some("1.34.0".to_owned()) + ); + dir.close()?; + + let dir = tempfile::tempdir()?; + fs::write( + dir.path().join("rust-toolchain"), + "[toolchain]\nchannel = \"1.34.0\"", + )?; + + let context = Context::new_with_dir(Default::default(), dir.path()); + + assert_eq!( + find_rust_toolchain_file(&context), + Some("1.34.0".to_owned()) + ); + dir.close()?; + + let dir = tempfile::tempdir()?; + fs::write( + dir.path().join("rust-toolchain"), + "\n\n[toolchain]\n\n\nchannel = \"1.34.0\"", + )?; + + let context = Context::new_with_dir(Default::default(), dir.path()); + + assert_eq!( + find_rust_toolchain_file(&context), + Some("1.34.0".to_owned()) + ); + dir.close() + } }