1
0
mirror of https://github.com/Llewellynvdm/starship.git synced 2024-11-28 15:56:28 +00:00

refactor(aws): improve parsing of config files (#3842)

This commit is contained in:
David Knaack 2022-04-08 06:56:14 +02:00 committed by GitHub
parent 4b7275efc4
commit 21dfe7f2fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3,15 +3,19 @@ use std::path::PathBuf;
use std::str::FromStr; use std::str::FromStr;
use chrono::DateTime; use chrono::DateTime;
use ini::Ini;
use once_cell::unsync::OnceCell;
use super::{Context, Module, ModuleConfig}; use super::{Context, Module, ModuleConfig};
use crate::configs::aws::AwsConfig; use crate::configs::aws::AwsConfig;
use crate::formatter::StringFormatter; use crate::formatter::StringFormatter;
use crate::utils::{read_file, render_time}; use crate::utils::render_time;
type Profile = String; type Profile = String;
type Region = String; type Region = String;
type AwsConfigFile = OnceCell<Option<Ini>>;
type AwsCredsFile = OnceCell<Option<Ini>>;
fn get_credentials_file_path(context: &Context) -> Option<PathBuf> { fn get_credentials_file_path(context: &Context) -> Option<PathBuf> {
context context
@ -35,36 +39,65 @@ fn get_config_file_path(context: &Context) -> Option<PathBuf> {
}) })
} }
fn get_aws_region_from_config(context: &Context, aws_profile: Option<&str>) -> Option<Region> { // Initialize the AWS config file once
let config_location = get_config_file_path(context)?; fn get_config<'a>(context: &Context, config: &'a OnceCell<Option<Ini>>) -> Option<&'a Ini> {
config
let contents = read_file(&config_location).ok()?; .get_or_init(|| {
let path = get_config_file_path(context)?;
let region_line = if let Some(aws_profile) = aws_profile { Ini::load_from_file(path).ok()
contents })
.lines() .as_ref()
.skip_while(|line| line != &format!("[profile {}]", &aws_profile))
.skip(1)
.take_while(|line| !line.starts_with('['))
.find(|line| line.starts_with("region"))
} else {
contents
.lines()
.skip_while(|&line| line != "[default]")
.skip(1)
.take_while(|line| !line.starts_with('['))
.find(|line| line.starts_with("region"))
}?;
let region = region_line.split('=').nth(1)?;
let region = region.trim();
Some(region.to_string())
} }
fn get_aws_profile_and_region(context: &Context) -> (Option<Profile>, Option<Region>) { // Initialize the AWS credentials file once
let profile_env_vars = vec!["AWSU_PROFILE", "AWS_VAULT", "AWSUME_PROFILE", "AWS_PROFILE"]; fn get_creds<'a>(context: &Context, config: &'a OnceCell<Option<Ini>>) -> Option<&'a Ini> {
let region_env_vars = vec!["AWS_REGION", "AWS_DEFAULT_REGION"]; config
.get_or_init(|| {
let path = get_credentials_file_path(context)?;
Ini::load_from_file(path).ok()
})
.as_ref()
}
// Get the section for a given profile name in the config file.
fn get_profile_config<'a>(
config: &'a Ini,
profile: &Option<Profile>,
) -> Option<&'a ini::Properties> {
match profile {
Some(profile) => config.section(Some(format!("profile {}", profile))),
None => config.section(Some("default")),
}
}
// Get the section for a given profile name in the credentials file.
fn get_profile_creds<'a>(
config: &'a Ini,
profile: &Option<Profile>,
) -> Option<&'a ini::Properties> {
match profile {
None => config.section(Some("default")),
_ => config.section(profile.as_ref()),
}
}
fn get_aws_region_from_config(
context: &Context,
aws_profile: &Option<Profile>,
aws_config: &AwsConfigFile,
) -> Option<Region> {
let config = get_config(context, aws_config)?;
let section = get_profile_config(config, aws_profile)?;
section.get("region").map(|region| region.to_owned())
}
fn get_aws_profile_and_region(
context: &Context,
aws_config: &AwsConfigFile,
) -> (Option<Profile>, Option<Region>) {
let profile_env_vars = ["AWSU_PROFILE", "AWS_VAULT", "AWSUME_PROFILE", "AWS_PROFILE"];
let region_env_vars = ["AWS_REGION", "AWS_DEFAULT_REGION"];
let profile = profile_env_vars let profile = profile_env_vars
.iter() .iter()
.find_map(|env_var| context.get_env(env_var)); .find_map(|env_var| context.get_env(env_var));
@ -74,39 +107,32 @@ fn get_aws_profile_and_region(context: &Context) -> (Option<Profile>, Option<Reg
match (profile, region) { match (profile, region) {
(Some(p), Some(r)) => (Some(p), Some(r)), (Some(p), Some(r)) => (Some(p), Some(r)),
(None, Some(r)) => (None, Some(r)), (None, Some(r)) => (None, Some(r)),
(Some(ref p), None) => ( (Some(p), None) => (
Some(p.clone()), Some(p.clone()),
get_aws_region_from_config(context, Some(p)), get_aws_region_from_config(context, &Some(p), aws_config),
), ),
(None, None) => (None, get_aws_region_from_config(context, None)), (None, None) => (None, get_aws_region_from_config(context, &None, aws_config)),
} }
} }
fn get_credentials_duration(context: &Context, aws_profile: Option<&Profile>) -> Option<i64> { fn get_credentials_duration(
let expiration_env_vars = vec!["AWS_SESSION_EXPIRATION", "AWSUME_EXPIRATION"]; context: &Context,
aws_profile: &Option<String>,
aws_creds: &AwsCredsFile,
) -> Option<i64> {
let expiration_env_vars = ["AWS_SESSION_EXPIRATION", "AWSUME_EXPIRATION"];
let expiration_date = if let Some(expiration_date) = expiration_env_vars let expiration_date = if let Some(expiration_date) = expiration_env_vars
.iter() .iter()
.find_map(|env_var| context.get_env(env_var)) .find_map(|env_var| context.get_env(env_var))
{ {
chrono::DateTime::parse_from_rfc3339(&expiration_date).ok() chrono::DateTime::parse_from_rfc3339(&expiration_date).ok()
} else { } else {
let contents = read_file(get_credentials_file_path(context)?).ok()?; let creds = get_creds(context, aws_creds)?;
let section = get_profile_creds(creds, aws_profile)?;
let profile_line = if let Some(aws_profile) = aws_profile { section
format!("[{}]", aws_profile) .get("expiration")
} else { .and_then(|expiration| DateTime::parse_from_rfc3339(expiration).ok())
"[default]".to_string()
};
let expiration_date_line = contents
.lines()
.skip_while(|line| line != &profile_line)
.skip(1)
.take_while(|line| !line.starts_with('['))
.find(|line| line.starts_with("expiration"))?;
let expiration_date = expiration_date_line.split('=').nth(1)?.trim();
DateTime::parse_from_rfc3339(expiration_date).ok()
}?; }?;
Some(expiration_date.timestamp() - chrono::Local::now().timestamp()) Some(expiration_date.timestamp() - chrono::Local::now().timestamp())
@ -119,82 +145,63 @@ fn alias_name(name: Option<String>, aliases: &HashMap<String, &str>) -> Option<S
.or(name) .or(name)
} }
fn has_credential_process_or_sso(context: &Context, aws_profile: Option<&Profile>) -> bool { fn has_credential_process_or_sso(
let fp = match get_config_file_path(context) { context: &Context,
Some(fp) => fp, aws_profile: &Option<Profile>,
None => return false, aws_config: &AwsConfigFile,
}; ) -> Option<bool> {
let contents = match read_file(fp) { let config = get_config(context, aws_config)?;
Ok(contents) => contents,
Err(_) => return false,
};
let profile_line = if let Some(aws_profile) = aws_profile { let section = get_profile_config(config, aws_profile)?;
format!("[profile {}]", aws_profile) Some(section.contains_key("credential_process") || section.contains_key("sso_start_url"))
} else {
"[default]".to_string()
};
contents
.lines()
.skip_while(|line| line != &profile_line)
.skip(1)
.take_while(|line| !line.starts_with('['))
.any(|line| line.starts_with("credential_process") || line.starts_with("sso_start_url"))
} }
fn get_defined_credentials(context: &Context, aws_profile: Option<&Profile>) -> Option<String> { fn has_defined_credentials(
let valid_env_vars = vec![ context: &Context,
aws_profile: &Option<Profile>,
aws_creds: &AwsCredsFile,
) -> Option<bool> {
let valid_env_vars = [
"AWS_ACCESS_KEY_ID", "AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY", "AWS_SECRET_ACCESS_KEY",
"AWS_SESSION_TOKEN", "AWS_SESSION_TOKEN",
]; ];
// accept if set through environment variable // accept if set through environment variable
if let Some(aws_identity_cred) = valid_env_vars if valid_env_vars
.iter() .iter()
.find_map(|env_var| context.get_env(env_var)) .any(|env_var| context.get_env(env_var).is_some())
{ {
return Some(aws_identity_cred); return Some(true);
} }
let contents = read_file(get_credentials_file_path(context)?).ok()?; let creds = get_creds(context, aws_creds)?;
let section = get_profile_creds(creds, aws_profile)?;
let profile_line = if let Some(aws_profile) = aws_profile { Some(section.contains_key("aws_access_key_id"))
format!("[{}]", aws_profile)
} else {
"[default]".to_string()
};
let aws_key_id_line = contents
.lines()
.skip_while(|line| line != &profile_line)
.skip(1)
.take_while(|line| !line.starts_with('['))
.find(|line| line.starts_with("aws_access_key_id"))?;
let aws_key_id = aws_key_id_line.split('=').nth(1)?.trim();
Some(aws_key_id.to_string())
} }
pub fn module<'a>(context: &'a Context) -> Option<Module<'a>> { pub fn module<'a>(context: &'a Context) -> Option<Module<'a>> {
let mut module = context.new_module("aws"); let mut module = context.new_module("aws");
let config: AwsConfig = AwsConfig::try_load(module.config); let config: AwsConfig = AwsConfig::try_load(module.config);
let (aws_profile, aws_region) = get_aws_profile_and_region(context); let aws_config = OnceCell::new();
let aws_creds = OnceCell::new();
let (aws_profile, aws_region) = get_aws_profile_and_region(context, &aws_config);
if aws_profile.is_none() && aws_region.is_none() { if aws_profile.is_none() && aws_region.is_none() {
return None; return None;
} }
// only display if credential_process is defined or has valid credentials // only display if credential_process is defined or has valid credentials
if !config.force_display if !config.force_display
&& !has_credential_process_or_sso(context, aws_profile.as_ref()) && !has_credential_process_or_sso(context, &aws_profile, &aws_config).unwrap_or(false)
&& get_defined_credentials(context, aws_profile.as_ref()).is_none() && !has_defined_credentials(context, &aws_profile, &aws_creds).unwrap_or(false)
{ {
return None; return None;
} }
let duration = { let duration = {
get_credentials_duration(context, aws_profile.as_ref()).map(|duration| { get_credentials_duration(context, &aws_profile, &aws_creds).map(|duration| {
if duration > 0 { if duration > 0 {
render_time((duration * 1000) as u128, false) render_time((duration * 1000) as u128, false)
} else { } else {