From c8ac8777a593358868813254c662da5fcb9fe6c8 Mon Sep 17 00:00:00 2001 From: Chris Rose Date: Sun, 27 Nov 2022 06:06:05 -0800 Subject: [PATCH] fix(aws): enable when using .aws/credentials (#4604) --- src/modules/aws.rs | 80 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 68 insertions(+), 12 deletions(-) diff --git a/src/modules/aws.rs b/src/modules/aws.rs index 5083217b..eb9566e1 100644 --- a/src/modules/aws.rs +++ b/src/modules/aws.rs @@ -63,7 +63,7 @@ fn get_creds<'a>(context: &Context, config: &'a OnceCell>) -> Option // Get the section for a given profile name in the config file. fn get_profile_config<'a>( config: &'a Ini, - profile: &Option, + profile: Option<&Profile>, ) -> Option<&'a ini::Properties> { match profile { Some(profile) => config.section(Some(format!("profile {profile}"))), @@ -74,11 +74,11 @@ fn get_profile_config<'a>( // Get the section for a given profile name in the credentials file. fn get_profile_creds<'a>( config: &'a Ini, - profile: &Option, + profile: Option<&Profile>, ) -> Option<&'a ini::Properties> { match profile { None => config.section(Some("default")), - _ => config.section(profile.as_ref()), + _ => config.section(profile), } } @@ -88,7 +88,7 @@ fn get_aws_region_from_config( aws_config: &AwsConfigFile, ) -> Option { let config = get_config(context, aws_config)?; - let section = get_profile_config(config, aws_profile)?; + let section = get_profile_config(config, aws_profile.as_ref())?; section.get("region").map(std::borrow::ToOwned::to_owned) } @@ -118,7 +118,7 @@ fn get_aws_profile_and_region( fn get_credentials_duration( context: &Context, - aws_profile: &Option, + aws_profile: Option<&Profile>, aws_creds: &AwsCredsFile, ) -> Option { let expiration_env_vars = ["AWS_SESSION_EXPIRATION", "AWSUME_EXPIRATION"]; @@ -150,18 +150,35 @@ fn alias_name(name: Option, aliases: &HashMap) -> Option, + aws_profile: Option<&Profile>, aws_config: &AwsConfigFile, + aws_creds: &AwsCredsFile, ) -> Option { let config = get_config(context, aws_config)?; + let credentials = get_creds(context, aws_creds); - let section = get_profile_config(config, aws_profile)?; - Some(section.contains_key("credential_process") || section.contains_key("sso_start_url")) + let empty_section = ini::Properties::new(); + // We use the aws_profile here because `get_profile_config()` treats None + // as "special" and falls back to the "[default]"; otherwise this tries + // to look up "[profile default]" which doesn't exist + let config_section = get_profile_config(config, aws_profile).or(Some(&empty_section))?; + + let credential_section = match credentials { + Some(credentials) => get_profile_creds(credentials, aws_profile), + None => None, + }; + + Some( + config_section.contains_key("credential_process") + || config_section.contains_key("sso_start_url") + || credential_section?.contains_key("credential_process") + || credential_section?.contains_key("sso_start_url"), + ) } fn has_defined_credentials( context: &Context, - aws_profile: &Option, + aws_profile: Option<&Profile>, aws_creds: &AwsCredsFile, ) -> Option { let valid_env_vars = [ @@ -197,14 +214,15 @@ pub fn module<'a>(context: &'a Context) -> Option> { // only display if credential_process is defined or has valid credentials if !config.force_display - && !has_credential_process_or_sso(context, &aws_profile, &aws_config).unwrap_or(false) - && !has_defined_credentials(context, &aws_profile, &aws_creds).unwrap_or(false) + && !has_credential_process_or_sso(context, aws_profile.as_ref(), &aws_config, &aws_creds) + .unwrap_or(false) + && !has_defined_credentials(context, aws_profile.as_ref(), &aws_creds).unwrap_or(false) { return None; } let duration = { - get_credentials_duration(context, &aws_profile, &aws_creds).map(|duration| { + get_credentials_duration(context, aws_profile.as_ref(), &aws_creds).map(|duration| { if duration > 0 { render_time((duration * 1000) as u128, false) } else { @@ -899,6 +917,44 @@ credential_process = /opt/bin/awscreds-retriever dir.close() } + #[test] + fn credential_process_set_in_credentials() -> io::Result<()> { + let dir = tempfile::tempdir()?; + let config_path = dir.path().join("config"); + let credential_path = dir.path().join("credentials"); + let mut file = File::create(&config_path)?; + + file.write_all( + "[default] +region = ap-northeast-2 +" + .as_bytes(), + )?; + + let mut file = File::create(&credential_path)?; + + file.write_all( + "[default] +credential_process = /opt/bin/awscreds-for-tests +" + .as_bytes(), + )?; + let actual = ModuleRenderer::new("aws") + .env("AWS_CONFIG_FILE", config_path.to_string_lossy().as_ref()) + .env( + "AWS_CREDENTIALS_FILE", + credential_path.to_string_lossy().as_ref(), + ) + .collect(); + let expected = Some(format!( + "on {}", + Color::Yellow.bold().paint("☁️ (ap-northeast-2) ") + )); + + assert_eq!(expected, actual); + dir.close() + } + #[test] fn sso_set() -> io::Result<()> { let dir = tempfile::tempdir()?;