From 21dfe7f2fcee7181e515bb48d7244f7b05906db6 Mon Sep 17 00:00:00 2001 From: David Knaack Date: Fri, 8 Apr 2022 06:56:14 +0200 Subject: [PATCH] refactor(aws): improve parsing of config files (#3842) --- src/modules/aws.rs | 199 +++++++++++++++++++++++---------------------- 1 file changed, 103 insertions(+), 96 deletions(-) diff --git a/src/modules/aws.rs b/src/modules/aws.rs index 0a941bfa..3db65c56 100644 --- a/src/modules/aws.rs +++ b/src/modules/aws.rs @@ -3,15 +3,19 @@ use std::path::PathBuf; use std::str::FromStr; use chrono::DateTime; +use ini::Ini; +use once_cell::unsync::OnceCell; use super::{Context, Module, ModuleConfig}; use crate::configs::aws::AwsConfig; use crate::formatter::StringFormatter; -use crate::utils::{read_file, render_time}; +use crate::utils::render_time; type Profile = String; type Region = String; +type AwsConfigFile = OnceCell>; +type AwsCredsFile = OnceCell>; fn get_credentials_file_path(context: &Context) -> Option { context @@ -35,36 +39,65 @@ fn get_config_file_path(context: &Context) -> Option { }) } -fn get_aws_region_from_config(context: &Context, aws_profile: Option<&str>) -> Option { - let config_location = get_config_file_path(context)?; - - let contents = read_file(&config_location).ok()?; - - let region_line = if let Some(aws_profile) = aws_profile { - contents - .lines() - .skip_while(|line| line != &format!("[profile {}]", &aws_profile)) - .skip(1) - .take_while(|line| !line.starts_with('[')) - .find(|line| line.starts_with("region")) - } else { - contents - .lines() - .skip_while(|&line| line != "[default]") - .skip(1) - .take_while(|line| !line.starts_with('[')) - .find(|line| line.starts_with("region")) - }?; - - let region = region_line.split('=').nth(1)?; - let region = region.trim(); - - Some(region.to_string()) +// Initialize the AWS config file once +fn get_config<'a>(context: &Context, config: &'a OnceCell>) -> Option<&'a Ini> { + config + .get_or_init(|| { + let path = get_config_file_path(context)?; + Ini::load_from_file(path).ok() + }) + .as_ref() } -fn get_aws_profile_and_region(context: &Context) -> (Option, Option) { - let profile_env_vars = vec!["AWSU_PROFILE", "AWS_VAULT", "AWSUME_PROFILE", "AWS_PROFILE"]; - let region_env_vars = vec!["AWS_REGION", "AWS_DEFAULT_REGION"]; +// Initialize the AWS credentials file once +fn get_creds<'a>(context: &Context, config: &'a OnceCell>) -> Option<&'a Ini> { + config + .get_or_init(|| { + let path = get_credentials_file_path(context)?; + Ini::load_from_file(path).ok() + }) + .as_ref() +} + +// Get the section for a given profile name in the config file. +fn get_profile_config<'a>( + config: &'a Ini, + profile: &Option, +) -> Option<&'a ini::Properties> { + match profile { + Some(profile) => config.section(Some(format!("profile {}", profile))), + None => config.section(Some("default")), + } +} + +// Get the section for a given profile name in the credentials file. +fn get_profile_creds<'a>( + config: &'a Ini, + profile: &Option, +) -> Option<&'a ini::Properties> { + match profile { + None => config.section(Some("default")), + _ => config.section(profile.as_ref()), + } +} + +fn get_aws_region_from_config( + context: &Context, + aws_profile: &Option, + aws_config: &AwsConfigFile, +) -> Option { + let config = get_config(context, aws_config)?; + let section = get_profile_config(config, aws_profile)?; + + section.get("region").map(|region| region.to_owned()) +} + +fn get_aws_profile_and_region( + context: &Context, + aws_config: &AwsConfigFile, +) -> (Option, Option) { + let profile_env_vars = ["AWSU_PROFILE", "AWS_VAULT", "AWSUME_PROFILE", "AWS_PROFILE"]; + let region_env_vars = ["AWS_REGION", "AWS_DEFAULT_REGION"]; let profile = profile_env_vars .iter() .find_map(|env_var| context.get_env(env_var)); @@ -74,39 +107,32 @@ fn get_aws_profile_and_region(context: &Context) -> (Option, Option (Some(p), Some(r)), (None, Some(r)) => (None, Some(r)), - (Some(ref p), None) => ( + (Some(p), None) => ( Some(p.clone()), - get_aws_region_from_config(context, Some(p)), + get_aws_region_from_config(context, &Some(p), aws_config), ), - (None, None) => (None, get_aws_region_from_config(context, None)), + (None, None) => (None, get_aws_region_from_config(context, &None, aws_config)), } } -fn get_credentials_duration(context: &Context, aws_profile: Option<&Profile>) -> Option { - let expiration_env_vars = vec!["AWS_SESSION_EXPIRATION", "AWSUME_EXPIRATION"]; +fn get_credentials_duration( + context: &Context, + aws_profile: &Option, + aws_creds: &AwsCredsFile, +) -> Option { + let expiration_env_vars = ["AWS_SESSION_EXPIRATION", "AWSUME_EXPIRATION"]; let expiration_date = if let Some(expiration_date) = expiration_env_vars .iter() .find_map(|env_var| context.get_env(env_var)) { chrono::DateTime::parse_from_rfc3339(&expiration_date).ok() } else { - let contents = read_file(get_credentials_file_path(context)?).ok()?; + let creds = get_creds(context, aws_creds)?; + let section = get_profile_creds(creds, aws_profile)?; - let profile_line = if let Some(aws_profile) = aws_profile { - format!("[{}]", aws_profile) - } else { - "[default]".to_string() - }; - - let expiration_date_line = contents - .lines() - .skip_while(|line| line != &profile_line) - .skip(1) - .take_while(|line| !line.starts_with('[')) - .find(|line| line.starts_with("expiration"))?; - - let expiration_date = expiration_date_line.split('=').nth(1)?.trim(); - DateTime::parse_from_rfc3339(expiration_date).ok() + section + .get("expiration") + .and_then(|expiration| DateTime::parse_from_rfc3339(expiration).ok()) }?; Some(expiration_date.timestamp() - chrono::Local::now().timestamp()) @@ -119,82 +145,63 @@ fn alias_name(name: Option, aliases: &HashMap) -> Option) -> bool { - let fp = match get_config_file_path(context) { - Some(fp) => fp, - None => return false, - }; - let contents = match read_file(fp) { - Ok(contents) => contents, - Err(_) => return false, - }; +fn has_credential_process_or_sso( + context: &Context, + aws_profile: &Option, + aws_config: &AwsConfigFile, +) -> Option { + let config = get_config(context, aws_config)?; - let profile_line = if let Some(aws_profile) = aws_profile { - format!("[profile {}]", aws_profile) - } else { - "[default]".to_string() - }; - - contents - .lines() - .skip_while(|line| line != &profile_line) - .skip(1) - .take_while(|line| !line.starts_with('[')) - .any(|line| line.starts_with("credential_process") || line.starts_with("sso_start_url")) + let section = get_profile_config(config, aws_profile)?; + Some(section.contains_key("credential_process") || section.contains_key("sso_start_url")) } -fn get_defined_credentials(context: &Context, aws_profile: Option<&Profile>) -> Option { - let valid_env_vars = vec![ +fn has_defined_credentials( + context: &Context, + aws_profile: &Option, + aws_creds: &AwsCredsFile, +) -> Option { + let valid_env_vars = [ "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN", ]; // accept if set through environment variable - if let Some(aws_identity_cred) = valid_env_vars + if valid_env_vars .iter() - .find_map(|env_var| context.get_env(env_var)) + .any(|env_var| context.get_env(env_var).is_some()) { - return Some(aws_identity_cred); + return Some(true); } - let contents = read_file(get_credentials_file_path(context)?).ok()?; - - let profile_line = if let Some(aws_profile) = aws_profile { - format!("[{}]", aws_profile) - } else { - "[default]".to_string() - }; - - let aws_key_id_line = contents - .lines() - .skip_while(|line| line != &profile_line) - .skip(1) - .take_while(|line| !line.starts_with('[')) - .find(|line| line.starts_with("aws_access_key_id"))?; - let aws_key_id = aws_key_id_line.split('=').nth(1)?.trim(); - Some(aws_key_id.to_string()) + let creds = get_creds(context, aws_creds)?; + let section = get_profile_creds(creds, aws_profile)?; + Some(section.contains_key("aws_access_key_id")) } pub fn module<'a>(context: &'a Context) -> Option> { let mut module = context.new_module("aws"); let config: AwsConfig = AwsConfig::try_load(module.config); - let (aws_profile, aws_region) = get_aws_profile_and_region(context); + let aws_config = OnceCell::new(); + let aws_creds = OnceCell::new(); + + let (aws_profile, aws_region) = get_aws_profile_and_region(context, &aws_config); if aws_profile.is_none() && aws_region.is_none() { return None; } // only display if credential_process is defined or has valid credentials if !config.force_display - && !has_credential_process_or_sso(context, aws_profile.as_ref()) - && get_defined_credentials(context, aws_profile.as_ref()).is_none() + && !has_credential_process_or_sso(context, &aws_profile, &aws_config).unwrap_or(false) + && !has_defined_credentials(context, &aws_profile, &aws_creds).unwrap_or(false) { return None; } let duration = { - get_credentials_duration(context, aws_profile.as_ref()).map(|duration| { + get_credentials_duration(context, &aws_profile, &aws_creds).map(|duration| { if duration > 0 { render_time((duration * 1000) as u128, false) } else {