mirror of
https://github.com/dani-garcia/vaultwarden.git
synced 2024-12-23 11:29:04 +00:00
Update to rocket 0.5 and made code async, missing updating all db calls, that are currently blocking
This commit is contained in:
parent
89fe05b6cc
commit
0b7d6bf6df
874
Cargo.lock
generated
874
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
27
Cargo.toml
27
Cargo.toml
@ -3,7 +3,7 @@ name = "vaultwarden"
|
|||||||
version = "1.0.0"
|
version = "1.0.0"
|
||||||
authors = ["Daniel García <dani-garcia@users.noreply.github.com>"]
|
authors = ["Daniel García <dani-garcia@users.noreply.github.com>"]
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
rust-version = "1.60"
|
rust-version = "1.56"
|
||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
repository = "https://github.com/dani-garcia/vaultwarden"
|
repository = "https://github.com/dani-garcia/vaultwarden"
|
||||||
@ -13,6 +13,7 @@ publish = false
|
|||||||
build = "build.rs"
|
build = "build.rs"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
|
# default = ["sqlite"]
|
||||||
# Empty to keep compatibility, prefer to set USE_SYSLOG=true
|
# Empty to keep compatibility, prefer to set USE_SYSLOG=true
|
||||||
enable_syslog = []
|
enable_syslog = []
|
||||||
mysql = ["diesel/mysql", "diesel_migrations/mysql"]
|
mysql = ["diesel/mysql", "diesel_migrations/mysql"]
|
||||||
@ -29,22 +30,22 @@ unstable = []
|
|||||||
syslog = "4.0.1"
|
syslog = "4.0.1"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
# Web framework for nightly with a focus on ease-of-use, expressibility, and speed.
|
# Web framework
|
||||||
rocket = { version = "=0.5.0-dev", features = ["tls"], default-features = false }
|
rocket = { version = "0.5.0-rc.1", features = ["tls", "json"], default-features = false }
|
||||||
rocket_contrib = "=0.5.0-dev"
|
|
||||||
|
|
||||||
# HTTP client
|
# Async futures
|
||||||
reqwest = { version = "0.11.9", features = ["blocking", "json", "gzip", "brotli", "socks", "cookies", "trust-dns"] }
|
futures = "0.3.19"
|
||||||
|
tokio = { version = "1.16.1", features = ["rt-multi-thread", "fs", "io-util", "parking_lot"] }
|
||||||
|
|
||||||
|
# HTTP client
|
||||||
|
reqwest = { version = "0.11.9", features = ["stream", "json", "gzip", "brotli", "socks", "cookies", "trust-dns"] }
|
||||||
|
bytes = "1.1.0"
|
||||||
|
|
||||||
# Used for custom short lived cookie jar
|
# Used for custom short lived cookie jar
|
||||||
cookie = "0.15.1"
|
cookie = "0.15.1"
|
||||||
cookie_store = "0.15.1"
|
cookie_store = "0.15.1"
|
||||||
bytes = "1.1.0"
|
|
||||||
url = "2.2.2"
|
url = "2.2.2"
|
||||||
|
|
||||||
# multipart/form-data support
|
|
||||||
multipart = { version = "0.18.0", features = ["server"], default-features = false }
|
|
||||||
|
|
||||||
# WebSockets library
|
# WebSockets library
|
||||||
ws = { version = "0.11.1", package = "parity-ws" }
|
ws = { version = "0.11.1", package = "parity-ws" }
|
||||||
|
|
||||||
@ -141,10 +142,10 @@ backtrace = "0.3.64"
|
|||||||
paste = "1.0.6"
|
paste = "1.0.6"
|
||||||
governor = "0.4.1"
|
governor = "0.4.1"
|
||||||
|
|
||||||
|
ctrlc = { version = "3.2.1", features = ["termination"] }
|
||||||
|
|
||||||
[patch.crates-io]
|
[patch.crates-io]
|
||||||
# Use newest ring
|
rocket = { git = 'https://github.com/SergioBenitez/Rocket', rev = '8cae077ba1d54b92cdef3e171a730b819d5eeb8e' }
|
||||||
rocket = { git = 'https://github.com/SergioBenitez/Rocket', rev = '263e39b5b429de1913ce7e3036575a7b4d88b6d7' }
|
|
||||||
rocket_contrib = { git = 'https://github.com/SergioBenitez/Rocket', rev = '263e39b5b429de1913ce7e3036575a7b4d88b6d7' }
|
|
||||||
|
|
||||||
# The maintainer of the `job_scheduler` crate doesn't seem to have responded
|
# The maintainer of the `job_scheduler` crate doesn't seem to have responded
|
||||||
# to any issues or PRs for almost a year (as of April 2021). This hopefully
|
# to any issues or PRs for almost a year (as of April 2021). This hopefully
|
||||||
|
@ -1,2 +0,0 @@
|
|||||||
[global.limits]
|
|
||||||
json = 10485760 # 10 MiB
|
|
@ -1 +1 @@
|
|||||||
nightly-2022-01-23
|
stable
|
||||||
|
@ -3,13 +3,14 @@ use serde::de::DeserializeOwned;
|
|||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::env;
|
use std::env;
|
||||||
|
|
||||||
|
use rocket::serde::json::Json;
|
||||||
use rocket::{
|
use rocket::{
|
||||||
http::{Cookie, Cookies, SameSite, Status},
|
form::Form,
|
||||||
request::{self, FlashMessage, Form, FromRequest, Outcome, Request},
|
http::{Cookie, CookieJar, SameSite, Status},
|
||||||
response::{content::Html, Flash, Redirect},
|
request::{self, FlashMessage, FromRequest, Outcome, Request},
|
||||||
|
response::{content::RawHtml as Html, Flash, Redirect},
|
||||||
Route,
|
Route,
|
||||||
};
|
};
|
||||||
use rocket_contrib::json::Json;
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
api::{ApiResult, EmptyResult, JsonResult, NumberOrString},
|
api::{ApiResult, EmptyResult, JsonResult, NumberOrString},
|
||||||
@ -85,10 +86,11 @@ fn admin_path() -> String {
|
|||||||
|
|
||||||
struct Referer(Option<String>);
|
struct Referer(Option<String>);
|
||||||
|
|
||||||
impl<'a, 'r> FromRequest<'a, 'r> for Referer {
|
#[rocket::async_trait]
|
||||||
|
impl<'r> FromRequest<'r> for Referer {
|
||||||
type Error = ();
|
type Error = ();
|
||||||
|
|
||||||
fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
|
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
|
||||||
Outcome::Success(Referer(request.headers().get_one("Referer").map(str::to_string)))
|
Outcome::Success(Referer(request.headers().get_one("Referer").map(str::to_string)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -96,10 +98,11 @@ impl<'a, 'r> FromRequest<'a, 'r> for Referer {
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct IpHeader(Option<String>);
|
struct IpHeader(Option<String>);
|
||||||
|
|
||||||
impl<'a, 'r> FromRequest<'a, 'r> for IpHeader {
|
#[rocket::async_trait]
|
||||||
|
impl<'r> FromRequest<'r> for IpHeader {
|
||||||
type Error = ();
|
type Error = ();
|
||||||
|
|
||||||
fn from_request(req: &'a Request<'r>) -> Outcome<Self, Self::Error> {
|
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||||
if req.headers().get_one(&CONFIG.ip_header()).is_some() {
|
if req.headers().get_one(&CONFIG.ip_header()).is_some() {
|
||||||
Outcome::Success(IpHeader(Some(CONFIG.ip_header())))
|
Outcome::Success(IpHeader(Some(CONFIG.ip_header())))
|
||||||
} else if req.headers().get_one("X-Client-IP").is_some() {
|
} else if req.headers().get_one("X-Client-IP").is_some() {
|
||||||
@ -138,7 +141,7 @@ fn admin_url(referer: Referer) -> String {
|
|||||||
#[get("/", rank = 2)]
|
#[get("/", rank = 2)]
|
||||||
fn admin_login(flash: Option<FlashMessage>) -> ApiResult<Html<String>> {
|
fn admin_login(flash: Option<FlashMessage>) -> ApiResult<Html<String>> {
|
||||||
// If there is an error, show it
|
// If there is an error, show it
|
||||||
let msg = flash.map(|msg| format!("{}: {}", msg.name(), msg.msg()));
|
let msg = flash.map(|msg| format!("{}: {}", msg.kind(), msg.message()));
|
||||||
let json = json!({
|
let json = json!({
|
||||||
"page_content": "admin/login",
|
"page_content": "admin/login",
|
||||||
"version": VERSION,
|
"version": VERSION,
|
||||||
@ -159,7 +162,7 @@ struct LoginForm {
|
|||||||
#[post("/", data = "<data>")]
|
#[post("/", data = "<data>")]
|
||||||
fn post_admin_login(
|
fn post_admin_login(
|
||||||
data: Form<LoginForm>,
|
data: Form<LoginForm>,
|
||||||
mut cookies: Cookies,
|
cookies: &CookieJar,
|
||||||
ip: ClientIp,
|
ip: ClientIp,
|
||||||
referer: Referer,
|
referer: Referer,
|
||||||
) -> Result<Redirect, Flash<Redirect>> {
|
) -> Result<Redirect, Flash<Redirect>> {
|
||||||
@ -180,7 +183,7 @@ fn post_admin_login(
|
|||||||
|
|
||||||
let cookie = Cookie::build(COOKIE_NAME, jwt)
|
let cookie = Cookie::build(COOKIE_NAME, jwt)
|
||||||
.path(admin_path())
|
.path(admin_path())
|
||||||
.max_age(time::Duration::minutes(20))
|
.max_age(rocket::time::Duration::minutes(20))
|
||||||
.same_site(SameSite::Strict)
|
.same_site(SameSite::Strict)
|
||||||
.http_only(true)
|
.http_only(true)
|
||||||
.finish();
|
.finish();
|
||||||
@ -297,7 +300,7 @@ fn test_smtp(data: Json<InviteData>, _token: AdminToken) -> EmptyResult {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[get("/logout")]
|
#[get("/logout")]
|
||||||
fn logout(mut cookies: Cookies, referer: Referer) -> Redirect {
|
fn logout(cookies: &CookieJar, referer: Referer) -> Redirect {
|
||||||
cookies.remove(Cookie::named(COOKIE_NAME));
|
cookies.remove(Cookie::named(COOKIE_NAME));
|
||||||
Redirect::to(admin_url(referer))
|
Redirect::to(admin_url(referer))
|
||||||
}
|
}
|
||||||
@ -462,23 +465,23 @@ struct GitCommit {
|
|||||||
sha: String,
|
sha: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_github_api<T: DeserializeOwned>(url: &str) -> Result<T, Error> {
|
async fn get_github_api<T: DeserializeOwned>(url: &str) -> Result<T, Error> {
|
||||||
let github_api = get_reqwest_client();
|
let github_api = get_reqwest_client();
|
||||||
|
|
||||||
Ok(github_api.get(url).send()?.error_for_status()?.json::<T>()?)
|
Ok(github_api.get(url).send().await?.error_for_status()?.json::<T>().await?)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn has_http_access() -> bool {
|
async fn has_http_access() -> bool {
|
||||||
let http_access = get_reqwest_client();
|
let http_access = get_reqwest_client();
|
||||||
|
|
||||||
match http_access.head("https://github.com/dani-garcia/vaultwarden").send() {
|
match http_access.head("https://github.com/dani-garcia/vaultwarden").send().await {
|
||||||
Ok(r) => r.status().is_success(),
|
Ok(r) => r.status().is_success(),
|
||||||
_ => false,
|
_ => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/diagnostics")]
|
#[get("/diagnostics")]
|
||||||
fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResult<Html<String>> {
|
async fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResult<Html<String>> {
|
||||||
use crate::util::read_file_string;
|
use crate::util::read_file_string;
|
||||||
use chrono::prelude::*;
|
use chrono::prelude::*;
|
||||||
use std::net::ToSocketAddrs;
|
use std::net::ToSocketAddrs;
|
||||||
@ -497,7 +500,7 @@ fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResu
|
|||||||
|
|
||||||
// Execute some environment checks
|
// Execute some environment checks
|
||||||
let running_within_docker = is_running_in_docker();
|
let running_within_docker = is_running_in_docker();
|
||||||
let has_http_access = has_http_access();
|
let has_http_access = has_http_access().await;
|
||||||
let uses_proxy = env::var_os("HTTP_PROXY").is_some()
|
let uses_proxy = env::var_os("HTTP_PROXY").is_some()
|
||||||
|| env::var_os("http_proxy").is_some()
|
|| env::var_os("http_proxy").is_some()
|
||||||
|| env::var_os("HTTPS_PROXY").is_some()
|
|| env::var_os("HTTPS_PROXY").is_some()
|
||||||
@ -513,11 +516,14 @@ fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResu
|
|||||||
// TODO: Maybe we need to cache this using a LazyStatic or something. Github only allows 60 requests per hour, and we use 3 here already.
|
// TODO: Maybe we need to cache this using a LazyStatic or something. Github only allows 60 requests per hour, and we use 3 here already.
|
||||||
let (latest_release, latest_commit, latest_web_build) = if has_http_access {
|
let (latest_release, latest_commit, latest_web_build) = if has_http_access {
|
||||||
(
|
(
|
||||||
match get_github_api::<GitRelease>("https://api.github.com/repos/dani-garcia/vaultwarden/releases/latest") {
|
match get_github_api::<GitRelease>("https://api.github.com/repos/dani-garcia/vaultwarden/releases/latest")
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(r) => r.tag_name,
|
Ok(r) => r.tag_name,
|
||||||
_ => "-".to_string(),
|
_ => "-".to_string(),
|
||||||
},
|
},
|
||||||
match get_github_api::<GitCommit>("https://api.github.com/repos/dani-garcia/vaultwarden/commits/main") {
|
match get_github_api::<GitCommit>("https://api.github.com/repos/dani-garcia/vaultwarden/commits/main").await
|
||||||
|
{
|
||||||
Ok(mut c) => {
|
Ok(mut c) => {
|
||||||
c.sha.truncate(8);
|
c.sha.truncate(8);
|
||||||
c.sha
|
c.sha
|
||||||
@ -531,7 +537,9 @@ fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResu
|
|||||||
} else {
|
} else {
|
||||||
match get_github_api::<GitRelease>(
|
match get_github_api::<GitRelease>(
|
||||||
"https://api.github.com/repos/dani-garcia/bw_web_builds/releases/latest",
|
"https://api.github.com/repos/dani-garcia/bw_web_builds/releases/latest",
|
||||||
) {
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(r) => r.tag_name.trim_start_matches('v').to_string(),
|
Ok(r) => r.tag_name.trim_start_matches('v').to_string(),
|
||||||
_ => "-".to_string(),
|
_ => "-".to_string(),
|
||||||
}
|
}
|
||||||
@ -562,7 +570,7 @@ fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResu
|
|||||||
"ip_header_config": &CONFIG.ip_header(),
|
"ip_header_config": &CONFIG.ip_header(),
|
||||||
"uses_proxy": uses_proxy,
|
"uses_proxy": uses_proxy,
|
||||||
"db_type": *DB_TYPE,
|
"db_type": *DB_TYPE,
|
||||||
"db_version": get_sql_server_version(&conn),
|
"db_version": get_sql_server_version(&conn).await,
|
||||||
"admin_url": format!("{}/diagnostics", admin_url(Referer(None))),
|
"admin_url": format!("{}/diagnostics", admin_url(Referer(None))),
|
||||||
"overrides": &CONFIG.get_overrides().join(", "),
|
"overrides": &CONFIG.get_overrides().join(", "),
|
||||||
"server_time_local": Local::now().format("%Y-%m-%d %H:%M:%S %Z").to_string(),
|
"server_time_local": Local::now().format("%Y-%m-%d %H:%M:%S %Z").to_string(),
|
||||||
@ -591,9 +599,9 @@ fn delete_config(_token: AdminToken) -> EmptyResult {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[post("/config/backup_db")]
|
#[post("/config/backup_db")]
|
||||||
fn backup_db(_token: AdminToken, conn: DbConn) -> EmptyResult {
|
async fn backup_db(_token: AdminToken, conn: DbConn) -> EmptyResult {
|
||||||
if *CAN_BACKUP {
|
if *CAN_BACKUP {
|
||||||
backup_database(&conn)
|
backup_database(&conn).await
|
||||||
} else {
|
} else {
|
||||||
err!("Can't back up current DB (Only SQLite supports this feature)");
|
err!("Can't back up current DB (Only SQLite supports this feature)");
|
||||||
}
|
}
|
||||||
@ -601,21 +609,22 @@ fn backup_db(_token: AdminToken, conn: DbConn) -> EmptyResult {
|
|||||||
|
|
||||||
pub struct AdminToken {}
|
pub struct AdminToken {}
|
||||||
|
|
||||||
impl<'a, 'r> FromRequest<'a, 'r> for AdminToken {
|
#[rocket::async_trait]
|
||||||
|
impl<'r> FromRequest<'r> for AdminToken {
|
||||||
type Error = &'static str;
|
type Error = &'static str;
|
||||||
|
|
||||||
fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
|
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
|
||||||
if CONFIG.disable_admin_token() {
|
if CONFIG.disable_admin_token() {
|
||||||
Outcome::Success(AdminToken {})
|
Outcome::Success(AdminToken {})
|
||||||
} else {
|
} else {
|
||||||
let mut cookies = request.cookies();
|
let cookies = request.cookies();
|
||||||
|
|
||||||
let access_token = match cookies.get(COOKIE_NAME) {
|
let access_token = match cookies.get(COOKIE_NAME) {
|
||||||
Some(cookie) => cookie.value(),
|
Some(cookie) => cookie.value(),
|
||||||
None => return Outcome::Forward(()), // If there is no cookie, redirect to login
|
None => return Outcome::Forward(()), // If there is no cookie, redirect to login
|
||||||
};
|
};
|
||||||
|
|
||||||
let ip = match request.guard::<ClientIp>() {
|
let ip = match ClientIp::from_request(request).await {
|
||||||
Outcome::Success(ip) => ip.ip,
|
Outcome::Success(ip) => ip.ip,
|
||||||
_ => err_handler!("Error getting Client IP"),
|
_ => err_handler!("Error getting Client IP"),
|
||||||
};
|
};
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use rocket_contrib::json::Json;
|
use rocket::serde::json::Json;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -1,13 +1,14 @@
|
|||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::path::{Path, PathBuf};
|
|
||||||
|
|
||||||
use chrono::{NaiveDateTime, Utc};
|
use chrono::{NaiveDateTime, Utc};
|
||||||
use rocket::{http::ContentType, request::Form, Data, Route};
|
use rocket::fs::TempFile;
|
||||||
use rocket_contrib::json::Json;
|
use rocket::serde::json::Json;
|
||||||
|
use rocket::{
|
||||||
|
form::{Form, FromForm},
|
||||||
|
Route,
|
||||||
|
};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
use multipart::server::{save::SavedData, Multipart, SaveResult};
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
api::{self, EmptyResult, JsonResult, JsonUpcase, Notify, PasswordData, UpdateType},
|
api::{self, EmptyResult, JsonResult, JsonUpcase, Notify, PasswordData, UpdateType},
|
||||||
auth::Headers,
|
auth::Headers,
|
||||||
@ -79,9 +80,9 @@ pub fn routes() -> Vec<Route> {
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn purge_trashed_ciphers(pool: DbPool) {
|
pub async fn purge_trashed_ciphers(pool: DbPool) {
|
||||||
debug!("Purging trashed ciphers");
|
debug!("Purging trashed ciphers");
|
||||||
if let Ok(conn) = pool.get() {
|
if let Ok(conn) = pool.get().await {
|
||||||
Cipher::purge_trash(&conn);
|
Cipher::purge_trash(&conn);
|
||||||
} else {
|
} else {
|
||||||
error!("Failed to get DB connection while purging trashed ciphers")
|
error!("Failed to get DB connection while purging trashed ciphers")
|
||||||
@ -90,12 +91,12 @@ pub fn purge_trashed_ciphers(pool: DbPool) {
|
|||||||
|
|
||||||
#[derive(FromForm, Default)]
|
#[derive(FromForm, Default)]
|
||||||
struct SyncData {
|
struct SyncData {
|
||||||
#[form(field = "excludeDomains")]
|
#[field(name = "excludeDomains")]
|
||||||
exclude_domains: bool, // Default: 'false'
|
exclude_domains: bool, // Default: 'false'
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/sync?<data..>")]
|
#[get("/sync?<data..>")]
|
||||||
fn sync(data: Form<SyncData>, headers: Headers, conn: DbConn) -> Json<Value> {
|
fn sync(data: SyncData, headers: Headers, conn: DbConn) -> Json<Value> {
|
||||||
let user_json = headers.user.to_json(&conn);
|
let user_json = headers.user.to_json(&conn);
|
||||||
|
|
||||||
let folders = Folder::find_by_user(&headers.user.uuid, &conn);
|
let folders = Folder::find_by_user(&headers.user.uuid, &conn);
|
||||||
@ -828,6 +829,12 @@ fn post_attachment_v2(
|
|||||||
})))
|
})))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(FromForm)]
|
||||||
|
struct UploadData<'f> {
|
||||||
|
key: Option<String>,
|
||||||
|
data: TempFile<'f>,
|
||||||
|
}
|
||||||
|
|
||||||
/// Saves the data content of an attachment to a file. This is common code
|
/// Saves the data content of an attachment to a file. This is common code
|
||||||
/// shared between the v2 and legacy attachment APIs.
|
/// shared between the v2 and legacy attachment APIs.
|
||||||
///
|
///
|
||||||
@ -836,22 +843,21 @@ fn post_attachment_v2(
|
|||||||
///
|
///
|
||||||
/// When used with the v2 API, post_attachment_v2() has already created the
|
/// When used with the v2 API, post_attachment_v2() has already created the
|
||||||
/// database record, which is passed in as `attachment`.
|
/// database record, which is passed in as `attachment`.
|
||||||
fn save_attachment(
|
async fn save_attachment(
|
||||||
mut attachment: Option<Attachment>,
|
mut attachment: Option<Attachment>,
|
||||||
cipher_uuid: String,
|
cipher_uuid: String,
|
||||||
data: Data,
|
data: Form<UploadData<'_>>,
|
||||||
content_type: &ContentType,
|
|
||||||
headers: &Headers,
|
headers: &Headers,
|
||||||
conn: &DbConn,
|
conn: DbConn,
|
||||||
nt: Notify,
|
nt: Notify<'_>,
|
||||||
) -> Result<Cipher, crate::error::Error> {
|
) -> Result<(Cipher, DbConn), crate::error::Error> {
|
||||||
let cipher = match Cipher::find_by_uuid(&cipher_uuid, conn) {
|
let cipher = match Cipher::find_by_uuid(&cipher_uuid, &conn) {
|
||||||
Some(cipher) => cipher,
|
Some(cipher) => cipher,
|
||||||
None => err_discard!("Cipher doesn't exist", data),
|
None => err!("Cipher doesn't exist"),
|
||||||
};
|
};
|
||||||
|
|
||||||
if !cipher.is_write_accessible_to_user(&headers.user.uuid, conn) {
|
if !cipher.is_write_accessible_to_user(&headers.user.uuid, &conn) {
|
||||||
err_discard!("Cipher is not write accessible", data)
|
err!("Cipher is not write accessible")
|
||||||
}
|
}
|
||||||
|
|
||||||
// In the v2 API, the attachment record has already been created,
|
// In the v2 API, the attachment record has already been created,
|
||||||
@ -863,11 +869,11 @@ fn save_attachment(
|
|||||||
|
|
||||||
let size_limit = if let Some(ref user_uuid) = cipher.user_uuid {
|
let size_limit = if let Some(ref user_uuid) = cipher.user_uuid {
|
||||||
match CONFIG.user_attachment_limit() {
|
match CONFIG.user_attachment_limit() {
|
||||||
Some(0) => err_discard!("Attachments are disabled", data),
|
Some(0) => err!("Attachments are disabled"),
|
||||||
Some(limit_kb) => {
|
Some(limit_kb) => {
|
||||||
let left = (limit_kb * 1024) - Attachment::size_by_user(user_uuid, conn) + size_adjust;
|
let left = (limit_kb * 1024) - Attachment::size_by_user(user_uuid, &conn) + size_adjust;
|
||||||
if left <= 0 {
|
if left <= 0 {
|
||||||
err_discard!("Attachment storage limit reached! Delete some attachments to free up space", data)
|
err!("Attachment storage limit reached! Delete some attachments to free up space")
|
||||||
}
|
}
|
||||||
Some(left as u64)
|
Some(left as u64)
|
||||||
}
|
}
|
||||||
@ -875,130 +881,78 @@ fn save_attachment(
|
|||||||
}
|
}
|
||||||
} else if let Some(ref org_uuid) = cipher.organization_uuid {
|
} else if let Some(ref org_uuid) = cipher.organization_uuid {
|
||||||
match CONFIG.org_attachment_limit() {
|
match CONFIG.org_attachment_limit() {
|
||||||
Some(0) => err_discard!("Attachments are disabled", data),
|
Some(0) => err!("Attachments are disabled"),
|
||||||
Some(limit_kb) => {
|
Some(limit_kb) => {
|
||||||
let left = (limit_kb * 1024) - Attachment::size_by_org(org_uuid, conn) + size_adjust;
|
let left = (limit_kb * 1024) - Attachment::size_by_org(org_uuid, &conn) + size_adjust;
|
||||||
if left <= 0 {
|
if left <= 0 {
|
||||||
err_discard!("Attachment storage limit reached! Delete some attachments to free up space", data)
|
err!("Attachment storage limit reached! Delete some attachments to free up space")
|
||||||
}
|
}
|
||||||
Some(left as u64)
|
Some(left as u64)
|
||||||
}
|
}
|
||||||
None => None,
|
None => None,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
err_discard!("Cipher is neither owned by a user nor an organization", data);
|
err!("Cipher is neither owned by a user nor an organization");
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut params = content_type.params();
|
let mut data = data.into_inner();
|
||||||
let boundary_pair = params.next().expect("No boundary provided");
|
|
||||||
let boundary = boundary_pair.1;
|
|
||||||
|
|
||||||
let base_path = Path::new(&CONFIG.attachments_folder()).join(&cipher_uuid);
|
if let Some(size_limit) = size_limit {
|
||||||
let mut path = PathBuf::new();
|
if data.data.len() > size_limit {
|
||||||
|
err!("Attachment storage limit exceeded with this file");
|
||||||
let mut attachment_key = None;
|
}
|
||||||
let mut error = None;
|
|
||||||
|
|
||||||
Multipart::with_body(data.open(), boundary)
|
|
||||||
.foreach_entry(|mut field| {
|
|
||||||
match &*field.headers.name {
|
|
||||||
"key" => {
|
|
||||||
use std::io::Read;
|
|
||||||
let mut key_buffer = String::new();
|
|
||||||
if field.data.read_to_string(&mut key_buffer).is_ok() {
|
|
||||||
attachment_key = Some(key_buffer);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"data" => {
|
|
||||||
// In the legacy API, this is the encrypted filename
|
|
||||||
// provided by the client, stored to the database as-is.
|
|
||||||
// In the v2 API, this value doesn't matter, as it was
|
|
||||||
// already provided and stored via an earlier API call.
|
|
||||||
let encrypted_filename = field.headers.filename;
|
|
||||||
|
|
||||||
// This random ID is used as the name of the file on disk.
|
|
||||||
// In the legacy API, we need to generate this value here.
|
|
||||||
// In the v2 API, we use the value from post_attachment_v2().
|
|
||||||
let file_id = match &attachment {
|
|
||||||
Some(attachment) => attachment.id.clone(), // v2 API
|
|
||||||
None => crypto::generate_attachment_id(), // Legacy API
|
|
||||||
};
|
|
||||||
path = base_path.join(&file_id);
|
|
||||||
|
|
||||||
let size =
|
|
||||||
match field.data.save().memory_threshold(0).size_limit(size_limit).with_path(path.clone()) {
|
|
||||||
SaveResult::Full(SavedData::File(_, size)) => size as i32,
|
|
||||||
SaveResult::Full(other) => {
|
|
||||||
error = Some(format!("Attachment is not a file: {:?}", other));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
SaveResult::Partial(_, reason) => {
|
|
||||||
error = Some(format!("Attachment storage limit exceeded with this file: {:?}", reason));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
SaveResult::Error(e) => {
|
|
||||||
error = Some(format!("Error: {:?}", e));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(attachment) = &mut attachment {
|
|
||||||
// v2 API
|
|
||||||
|
|
||||||
// Check the actual size against the size initially provided by
|
|
||||||
// the client. Upstream allows +/- 1 MiB deviation from this
|
|
||||||
// size, but it's not clear when or why this is needed.
|
|
||||||
const LEEWAY: i32 = 1024 * 1024; // 1 MiB
|
|
||||||
let min_size = attachment.file_size - LEEWAY;
|
|
||||||
let max_size = attachment.file_size + LEEWAY;
|
|
||||||
|
|
||||||
if min_size <= size && size <= max_size {
|
|
||||||
if size != attachment.file_size {
|
|
||||||
// Update the attachment with the actual file size.
|
|
||||||
attachment.file_size = size;
|
|
||||||
attachment.save(conn).expect("Error updating attachment");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
attachment.delete(conn).ok();
|
|
||||||
|
|
||||||
let err_msg = "Attachment size mismatch".to_string();
|
|
||||||
error!("{} (expected within [{}, {}], got {})", err_msg, min_size, max_size, size);
|
|
||||||
error = Some(err_msg);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Legacy API
|
|
||||||
|
|
||||||
if encrypted_filename.is_none() {
|
|
||||||
error = Some("No filename provided".to_string());
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if attachment_key.is_none() {
|
|
||||||
error = Some("No attachment key provided".to_string());
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
let attachment = Attachment::new(
|
|
||||||
file_id,
|
|
||||||
cipher_uuid.clone(),
|
|
||||||
encrypted_filename.unwrap(),
|
|
||||||
size,
|
|
||||||
attachment_key.clone(),
|
|
||||||
);
|
|
||||||
attachment.save(conn).expect("Error saving attachment");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => error!("Invalid multipart name"),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.expect("Error processing multipart data");
|
|
||||||
|
|
||||||
if let Some(ref e) = error {
|
|
||||||
std::fs::remove_file(path).ok();
|
|
||||||
err!(e);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
nt.send_cipher_update(UpdateType::CipherUpdate, &cipher, &cipher.update_users_revision(conn));
|
let file_id = match &attachment {
|
||||||
|
Some(attachment) => attachment.id.clone(), // v2 API
|
||||||
|
None => crypto::generate_attachment_id(), // Legacy API
|
||||||
|
};
|
||||||
|
|
||||||
Ok(cipher)
|
let folder_path = tokio::fs::canonicalize(&CONFIG.attachments_folder()).await?.join(&cipher_uuid);
|
||||||
|
let file_path = folder_path.join(&file_id);
|
||||||
|
tokio::fs::create_dir_all(&folder_path).await?;
|
||||||
|
|
||||||
|
let size = data.data.len() as i32;
|
||||||
|
if let Some(attachment) = &mut attachment {
|
||||||
|
// v2 API
|
||||||
|
|
||||||
|
// Check the actual size against the size initially provided by
|
||||||
|
// the client. Upstream allows +/- 1 MiB deviation from this
|
||||||
|
// size, but it's not clear when or why this is needed.
|
||||||
|
const LEEWAY: i32 = 1024 * 1024; // 1 MiB
|
||||||
|
let min_size = attachment.file_size - LEEWAY;
|
||||||
|
let max_size = attachment.file_size + LEEWAY;
|
||||||
|
|
||||||
|
if min_size <= size && size <= max_size {
|
||||||
|
if size != attachment.file_size {
|
||||||
|
// Update the attachment with the actual file size.
|
||||||
|
attachment.file_size = size;
|
||||||
|
attachment.save(&conn).expect("Error updating attachment");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
attachment.delete(&conn).ok();
|
||||||
|
|
||||||
|
err!(format!("Attachment size mismatch (expected within [{}, {}], got {})", min_size, max_size, size));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Legacy API
|
||||||
|
let encrypted_filename = data.data.raw_name().map(|s| s.dangerous_unsafe_unsanitized_raw().to_string());
|
||||||
|
|
||||||
|
if encrypted_filename.is_none() {
|
||||||
|
err!("No filename provided")
|
||||||
|
}
|
||||||
|
if data.key.is_none() {
|
||||||
|
err!("No attachment key provided")
|
||||||
|
}
|
||||||
|
let attachment = Attachment::new(file_id, cipher_uuid.clone(), encrypted_filename.unwrap(), size, data.key);
|
||||||
|
attachment.save(&conn).expect("Error saving attachment");
|
||||||
|
}
|
||||||
|
|
||||||
|
data.data.persist_to(file_path).await?;
|
||||||
|
|
||||||
|
nt.send_cipher_update(UpdateType::CipherUpdate, &cipher, &cipher.update_users_revision(&conn));
|
||||||
|
|
||||||
|
Ok((cipher, conn))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// v2 API for uploading the actual data content of an attachment.
|
/// v2 API for uploading the actual data content of an attachment.
|
||||||
@ -1006,14 +960,13 @@ fn save_attachment(
|
|||||||
/// /ciphers/<uuid>/attachment/v2 route, which would otherwise conflict
|
/// /ciphers/<uuid>/attachment/v2 route, which would otherwise conflict
|
||||||
/// with this one.
|
/// with this one.
|
||||||
#[post("/ciphers/<uuid>/attachment/<attachment_id>", format = "multipart/form-data", data = "<data>", rank = 1)]
|
#[post("/ciphers/<uuid>/attachment/<attachment_id>", format = "multipart/form-data", data = "<data>", rank = 1)]
|
||||||
fn post_attachment_v2_data(
|
async fn post_attachment_v2_data(
|
||||||
uuid: String,
|
uuid: String,
|
||||||
attachment_id: String,
|
attachment_id: String,
|
||||||
data: Data,
|
data: Form<UploadData<'_>>,
|
||||||
content_type: &ContentType,
|
|
||||||
headers: Headers,
|
headers: Headers,
|
||||||
conn: DbConn,
|
conn: DbConn,
|
||||||
nt: Notify,
|
nt: Notify<'_>,
|
||||||
) -> EmptyResult {
|
) -> EmptyResult {
|
||||||
let attachment = match Attachment::find_by_id(&attachment_id, &conn) {
|
let attachment = match Attachment::find_by_id(&attachment_id, &conn) {
|
||||||
Some(attachment) if uuid == attachment.cipher_uuid => Some(attachment),
|
Some(attachment) if uuid == attachment.cipher_uuid => Some(attachment),
|
||||||
@ -1021,54 +974,51 @@ fn post_attachment_v2_data(
|
|||||||
None => err!("Attachment doesn't exist"),
|
None => err!("Attachment doesn't exist"),
|
||||||
};
|
};
|
||||||
|
|
||||||
save_attachment(attachment, uuid, data, content_type, &headers, &conn, nt)?;
|
save_attachment(attachment, uuid, data, &headers, conn, nt).await?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Legacy API for creating an attachment associated with a cipher.
|
/// Legacy API for creating an attachment associated with a cipher.
|
||||||
#[post("/ciphers/<uuid>/attachment", format = "multipart/form-data", data = "<data>")]
|
#[post("/ciphers/<uuid>/attachment", format = "multipart/form-data", data = "<data>")]
|
||||||
fn post_attachment(
|
async fn post_attachment(
|
||||||
uuid: String,
|
uuid: String,
|
||||||
data: Data,
|
data: Form<UploadData<'_>>,
|
||||||
content_type: &ContentType,
|
|
||||||
headers: Headers,
|
headers: Headers,
|
||||||
conn: DbConn,
|
conn: DbConn,
|
||||||
nt: Notify,
|
nt: Notify<'_>,
|
||||||
) -> JsonResult {
|
) -> JsonResult {
|
||||||
// Setting this as None signifies to save_attachment() that it should create
|
// Setting this as None signifies to save_attachment() that it should create
|
||||||
// the attachment database record as well as saving the data to disk.
|
// the attachment database record as well as saving the data to disk.
|
||||||
let attachment = None;
|
let attachment = None;
|
||||||
|
|
||||||
let cipher = save_attachment(attachment, uuid, data, content_type, &headers, &conn, nt)?;
|
let (cipher, conn) = save_attachment(attachment, uuid, data, &headers, conn, nt).await?;
|
||||||
|
|
||||||
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, &conn)))
|
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, &conn)))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[post("/ciphers/<uuid>/attachment-admin", format = "multipart/form-data", data = "<data>")]
|
#[post("/ciphers/<uuid>/attachment-admin", format = "multipart/form-data", data = "<data>")]
|
||||||
fn post_attachment_admin(
|
async fn post_attachment_admin(
|
||||||
uuid: String,
|
uuid: String,
|
||||||
data: Data,
|
data: Form<UploadData<'_>>,
|
||||||
content_type: &ContentType,
|
|
||||||
headers: Headers,
|
headers: Headers,
|
||||||
conn: DbConn,
|
conn: DbConn,
|
||||||
nt: Notify,
|
nt: Notify<'_>,
|
||||||
) -> JsonResult {
|
) -> JsonResult {
|
||||||
post_attachment(uuid, data, content_type, headers, conn, nt)
|
post_attachment(uuid, data, headers, conn, nt).await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[post("/ciphers/<uuid>/attachment/<attachment_id>/share", format = "multipart/form-data", data = "<data>")]
|
#[post("/ciphers/<uuid>/attachment/<attachment_id>/share", format = "multipart/form-data", data = "<data>")]
|
||||||
fn post_attachment_share(
|
async fn post_attachment_share(
|
||||||
uuid: String,
|
uuid: String,
|
||||||
attachment_id: String,
|
attachment_id: String,
|
||||||
data: Data,
|
data: Form<UploadData<'_>>,
|
||||||
content_type: &ContentType,
|
|
||||||
headers: Headers,
|
headers: Headers,
|
||||||
conn: DbConn,
|
conn: DbConn,
|
||||||
nt: Notify,
|
nt: Notify<'_>,
|
||||||
) -> JsonResult {
|
) -> JsonResult {
|
||||||
_delete_cipher_attachment_by_id(&uuid, &attachment_id, &headers, &conn, &nt)?;
|
_delete_cipher_attachment_by_id(&uuid, &attachment_id, &headers, &conn, &nt)?;
|
||||||
post_attachment(uuid, data, content_type, headers, conn, nt)
|
post_attachment(uuid, data, headers, conn, nt).await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[post("/ciphers/<uuid>/attachment/<attachment_id>/delete-admin")]
|
#[post("/ciphers/<uuid>/attachment/<attachment_id>/delete-admin")]
|
||||||
@ -1248,13 +1198,13 @@ fn move_cipher_selected_put(
|
|||||||
|
|
||||||
#[derive(FromForm)]
|
#[derive(FromForm)]
|
||||||
struct OrganizationId {
|
struct OrganizationId {
|
||||||
#[form(field = "organizationId")]
|
#[field(name = "organizationId")]
|
||||||
org_id: String,
|
org_id: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[post("/ciphers/purge?<organization..>", data = "<data>")]
|
#[post("/ciphers/purge?<organization..>", data = "<data>")]
|
||||||
fn delete_all(
|
fn delete_all(
|
||||||
organization: Option<Form<OrganizationId>>,
|
organization: Option<OrganizationId>,
|
||||||
data: JsonUpcase<PasswordData>,
|
data: JsonUpcase<PasswordData>,
|
||||||
headers: Headers,
|
headers: Headers,
|
||||||
conn: DbConn,
|
conn: DbConn,
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use chrono::{Duration, Utc};
|
use chrono::{Duration, Utc};
|
||||||
|
use rocket::serde::json::Json;
|
||||||
use rocket::Route;
|
use rocket::Route;
|
||||||
use rocket_contrib::json::Json;
|
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::borrow::Borrow;
|
use std::borrow::Borrow;
|
||||||
|
|
||||||
@ -709,13 +709,13 @@ fn check_emergency_access_allowed() -> EmptyResult {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn emergency_request_timeout_job(pool: DbPool) {
|
pub async fn emergency_request_timeout_job(pool: DbPool) {
|
||||||
debug!("Start emergency_request_timeout_job");
|
debug!("Start emergency_request_timeout_job");
|
||||||
if !CONFIG.emergency_access_allowed() {
|
if !CONFIG.emergency_access_allowed() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Ok(conn) = pool.get() {
|
if let Ok(conn) = pool.get().await {
|
||||||
let emergency_access_list = EmergencyAccess::find_all_recoveries(&conn);
|
let emergency_access_list = EmergencyAccess::find_all_recoveries(&conn);
|
||||||
|
|
||||||
if emergency_access_list.is_empty() {
|
if emergency_access_list.is_empty() {
|
||||||
@ -756,13 +756,13 @@ pub fn emergency_request_timeout_job(pool: DbPool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn emergency_notification_reminder_job(pool: DbPool) {
|
pub async fn emergency_notification_reminder_job(pool: DbPool) {
|
||||||
debug!("Start emergency_notification_reminder_job");
|
debug!("Start emergency_notification_reminder_job");
|
||||||
if !CONFIG.emergency_access_allowed() {
|
if !CONFIG.emergency_access_allowed() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Ok(conn) = pool.get() {
|
if let Ok(conn) = pool.get().await {
|
||||||
let emergency_access_list = EmergencyAccess::find_all_recoveries(&conn);
|
let emergency_access_list = EmergencyAccess::find_all_recoveries(&conn);
|
||||||
|
|
||||||
if emergency_access_list.is_empty() {
|
if emergency_access_list.is_empty() {
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use rocket_contrib::json::Json;
|
use rocket::serde::json::Json;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -31,8 +31,8 @@ pub fn routes() -> Vec<Route> {
|
|||||||
//
|
//
|
||||||
// Move this somewhere else
|
// Move this somewhere else
|
||||||
//
|
//
|
||||||
|
use rocket::serde::json::Json;
|
||||||
use rocket::Route;
|
use rocket::Route;
|
||||||
use rocket_contrib::json::Json;
|
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
@ -144,7 +144,7 @@ fn put_eq_domains(data: JsonUpcase<EquivDomainData>, headers: Headers, conn: DbC
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[get("/hibp/breach?<username>")]
|
#[get("/hibp/breach?<username>")]
|
||||||
fn hibp_breach(username: String) -> JsonResult {
|
async fn hibp_breach(username: String) -> JsonResult {
|
||||||
let url = format!(
|
let url = format!(
|
||||||
"https://haveibeenpwned.com/api/v3/breachedaccount/{}?truncateResponse=false&includeUnverified=false",
|
"https://haveibeenpwned.com/api/v3/breachedaccount/{}?truncateResponse=false&includeUnverified=false",
|
||||||
username
|
username
|
||||||
@ -153,14 +153,14 @@ fn hibp_breach(username: String) -> JsonResult {
|
|||||||
if let Some(api_key) = crate::CONFIG.hibp_api_key() {
|
if let Some(api_key) = crate::CONFIG.hibp_api_key() {
|
||||||
let hibp_client = get_reqwest_client();
|
let hibp_client = get_reqwest_client();
|
||||||
|
|
||||||
let res = hibp_client.get(&url).header("hibp-api-key", api_key).send()?;
|
let res = hibp_client.get(&url).header("hibp-api-key", api_key).send().await?;
|
||||||
|
|
||||||
// If we get a 404, return a 404, it means no breached accounts
|
// If we get a 404, return a 404, it means no breached accounts
|
||||||
if res.status() == 404 {
|
if res.status() == 404 {
|
||||||
return Err(Error::empty().with_code(404));
|
return Err(Error::empty().with_code(404));
|
||||||
}
|
}
|
||||||
|
|
||||||
let value: Value = res.error_for_status()?.json()?;
|
let value: Value = res.error_for_status()?.json().await?;
|
||||||
Ok(Json(value))
|
Ok(Json(value))
|
||||||
} else {
|
} else {
|
||||||
Ok(Json(json!([{
|
Ok(Json(json!([{
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use num_traits::FromPrimitive;
|
use num_traits::FromPrimitive;
|
||||||
use rocket::{request::Form, Route};
|
use rocket::serde::json::Json;
|
||||||
use rocket_contrib::json::Json;
|
use rocket::Route;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
@ -469,12 +469,12 @@ fn put_collection_users(
|
|||||||
|
|
||||||
#[derive(FromForm)]
|
#[derive(FromForm)]
|
||||||
struct OrgIdData {
|
struct OrgIdData {
|
||||||
#[form(field = "organizationId")]
|
#[field(name = "organizationId")]
|
||||||
organization_id: String,
|
organization_id: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/ciphers/organization-details?<data..>")]
|
#[get("/ciphers/organization-details?<data..>")]
|
||||||
fn get_org_details(data: Form<OrgIdData>, headers: Headers, conn: DbConn) -> Json<Value> {
|
fn get_org_details(data: OrgIdData, headers: Headers, conn: DbConn) -> Json<Value> {
|
||||||
let ciphers = Cipher::find_by_org(&data.organization_id, &conn);
|
let ciphers = Cipher::find_by_org(&data.organization_id, &conn);
|
||||||
let ciphers_json: Vec<Value> =
|
let ciphers_json: Vec<Value> =
|
||||||
ciphers.iter().map(|c| c.to_json(&headers.host, &headers.user.uuid, &conn)).collect();
|
ciphers.iter().map(|c| c.to_json(&headers.host, &headers.user.uuid, &conn)).collect();
|
||||||
@ -1097,14 +1097,14 @@ struct RelationsData {
|
|||||||
|
|
||||||
#[post("/ciphers/import-organization?<query..>", data = "<data>")]
|
#[post("/ciphers/import-organization?<query..>", data = "<data>")]
|
||||||
fn post_org_import(
|
fn post_org_import(
|
||||||
query: Form<OrgIdData>,
|
query: OrgIdData,
|
||||||
data: JsonUpcase<ImportData>,
|
data: JsonUpcase<ImportData>,
|
||||||
headers: AdminHeaders,
|
headers: AdminHeaders,
|
||||||
conn: DbConn,
|
conn: DbConn,
|
||||||
nt: Notify,
|
nt: Notify,
|
||||||
) -> EmptyResult {
|
) -> EmptyResult {
|
||||||
let data: ImportData = data.into_inner().data;
|
let data: ImportData = data.into_inner().data;
|
||||||
let org_id = query.into_inner().organization_id;
|
let org_id = query.organization_id;
|
||||||
|
|
||||||
// Read and create the collections
|
// Read and create the collections
|
||||||
let collections: Vec<_> = data
|
let collections: Vec<_> = data
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
use std::{io::Read, path::Path};
|
use std::path::Path;
|
||||||
|
|
||||||
use chrono::{DateTime, Duration, Utc};
|
use chrono::{DateTime, Duration, Utc};
|
||||||
use multipart::server::{save::SavedData, Multipart, SaveResult};
|
use rocket::form::Form;
|
||||||
use rocket::{http::ContentType, response::NamedFile, Data};
|
use rocket::fs::NamedFile;
|
||||||
use rocket_contrib::json::Json;
|
use rocket::fs::TempFile;
|
||||||
|
use rocket::serde::json::Json;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
@ -31,9 +32,9 @@ pub fn routes() -> Vec<rocket::Route> {
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn purge_sends(pool: DbPool) {
|
pub async fn purge_sends(pool: DbPool) {
|
||||||
debug!("Purging sends");
|
debug!("Purging sends");
|
||||||
if let Ok(conn) = pool.get() {
|
if let Ok(conn) = pool.get().await {
|
||||||
Send::purge(&conn);
|
Send::purge(&conn);
|
||||||
} else {
|
} else {
|
||||||
error!("Failed to get DB connection while purging sends")
|
error!("Failed to get DB connection while purging sends")
|
||||||
@ -177,25 +178,23 @@ fn post_send(data: JsonUpcase<SendData>, headers: Headers, conn: DbConn, nt: Not
|
|||||||
Ok(Json(send.to_json()))
|
Ok(Json(send.to_json()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(FromForm)]
|
||||||
|
struct UploadData<'f> {
|
||||||
|
model: Json<crate::util::UpCase<SendData>>,
|
||||||
|
data: TempFile<'f>,
|
||||||
|
}
|
||||||
|
|
||||||
#[post("/sends/file", format = "multipart/form-data", data = "<data>")]
|
#[post("/sends/file", format = "multipart/form-data", data = "<data>")]
|
||||||
fn post_send_file(data: Data, content_type: &ContentType, headers: Headers, conn: DbConn, nt: Notify) -> JsonResult {
|
async fn post_send_file(data: Form<UploadData<'_>>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
|
||||||
enforce_disable_send_policy(&headers, &conn)?;
|
enforce_disable_send_policy(&headers, &conn)?;
|
||||||
|
|
||||||
let boundary = content_type.params().next().expect("No boundary provided").1;
|
let UploadData {
|
||||||
|
model,
|
||||||
|
mut data,
|
||||||
|
} = data.into_inner();
|
||||||
|
let model = model.into_inner().data;
|
||||||
|
|
||||||
let mut mpart = Multipart::with_body(data.open(), boundary);
|
enforce_disable_hide_email_policy(&model, &headers, &conn)?;
|
||||||
|
|
||||||
// First entry is the SendData JSON
|
|
||||||
let mut model_entry = match mpart.read_entry()? {
|
|
||||||
Some(e) if &*e.headers.name == "model" => e,
|
|
||||||
Some(_) => err!("Invalid entry name"),
|
|
||||||
None => err!("No model entry present"),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut buf = String::new();
|
|
||||||
model_entry.data.read_to_string(&mut buf)?;
|
|
||||||
let data = serde_json::from_str::<crate::util::UpCase<SendData>>(&buf)?;
|
|
||||||
enforce_disable_hide_email_policy(&data.data, &headers, &conn)?;
|
|
||||||
|
|
||||||
// Get the file length and add an extra 5% to avoid issues
|
// Get the file length and add an extra 5% to avoid issues
|
||||||
const SIZE_525_MB: u64 = 550_502_400;
|
const SIZE_525_MB: u64 = 550_502_400;
|
||||||
@ -212,45 +211,27 @@ fn post_send_file(data: Data, content_type: &ContentType, headers: Headers, conn
|
|||||||
None => SIZE_525_MB,
|
None => SIZE_525_MB,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Create the Send
|
let mut send = create_send(model, headers.user.uuid)?;
|
||||||
let mut send = create_send(data.data, headers.user.uuid)?;
|
|
||||||
let file_id = crate::crypto::generate_send_id();
|
|
||||||
|
|
||||||
if send.atype != SendType::File as i32 {
|
if send.atype != SendType::File as i32 {
|
||||||
err!("Send content is not a file");
|
err!("Send content is not a file");
|
||||||
}
|
}
|
||||||
|
|
||||||
let file_path = Path::new(&CONFIG.sends_folder()).join(&send.uuid).join(&file_id);
|
let size = data.len();
|
||||||
|
if size > size_limit {
|
||||||
|
err!("Attachment storage limit exceeded with this file");
|
||||||
|
}
|
||||||
|
|
||||||
// Read the data entry and save the file
|
let file_id = crate::crypto::generate_send_id();
|
||||||
let mut data_entry = match mpart.read_entry()? {
|
let folder_path = tokio::fs::canonicalize(&CONFIG.sends_folder()).await?.join(&send.uuid);
|
||||||
Some(e) if &*e.headers.name == "data" => e,
|
let file_path = folder_path.join(&file_id);
|
||||||
Some(_) => err!("Invalid entry name"),
|
tokio::fs::create_dir_all(&folder_path).await?;
|
||||||
None => err!("No model entry present"),
|
data.persist_to(&file_path).await?;
|
||||||
};
|
|
||||||
|
|
||||||
let size = match data_entry.data.save().memory_threshold(0).size_limit(size_limit).with_path(&file_path) {
|
|
||||||
SaveResult::Full(SavedData::File(_, size)) => size as i32,
|
|
||||||
SaveResult::Full(other) => {
|
|
||||||
std::fs::remove_file(&file_path).ok();
|
|
||||||
err!(format!("Attachment is not a file: {:?}", other));
|
|
||||||
}
|
|
||||||
SaveResult::Partial(_, reason) => {
|
|
||||||
std::fs::remove_file(&file_path).ok();
|
|
||||||
err!(format!("Attachment storage limit exceeded with this file: {:?}", reason));
|
|
||||||
}
|
|
||||||
SaveResult::Error(e) => {
|
|
||||||
std::fs::remove_file(&file_path).ok();
|
|
||||||
err!(format!("Error: {:?}", e));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Set ID and sizes
|
|
||||||
let mut data_value: Value = serde_json::from_str(&send.data)?;
|
let mut data_value: Value = serde_json::from_str(&send.data)?;
|
||||||
if let Some(o) = data_value.as_object_mut() {
|
if let Some(o) = data_value.as_object_mut() {
|
||||||
o.insert(String::from("Id"), Value::String(file_id));
|
o.insert(String::from("Id"), Value::String(file_id));
|
||||||
o.insert(String::from("Size"), Value::Number(size.into()));
|
o.insert(String::from("Size"), Value::Number(size.into()));
|
||||||
o.insert(String::from("SizeName"), Value::String(crate::util::get_display_size(size)));
|
o.insert(String::from("SizeName"), Value::String(crate::util::get_display_size(size as i32)));
|
||||||
}
|
}
|
||||||
send.data = serde_json::to_string(&data_value)?;
|
send.data = serde_json::to_string(&data_value)?;
|
||||||
|
|
||||||
@ -367,10 +348,10 @@ fn post_access_file(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[get("/sends/<send_id>/<file_id>?<t>")]
|
#[get("/sends/<send_id>/<file_id>?<t>")]
|
||||||
fn download_send(send_id: SafeString, file_id: SafeString, t: String) -> Option<NamedFile> {
|
async fn download_send(send_id: SafeString, file_id: SafeString, t: String) -> Option<NamedFile> {
|
||||||
if let Ok(claims) = crate::auth::decode_send(&t) {
|
if let Ok(claims) = crate::auth::decode_send(&t) {
|
||||||
if claims.sub == format!("{}/{}", send_id, file_id) {
|
if claims.sub == format!("{}/{}", send_id, file_id) {
|
||||||
return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(send_id).join(file_id)).ok();
|
return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(send_id).join(file_id)).await.ok();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None
|
None
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use data_encoding::BASE32;
|
use data_encoding::BASE32;
|
||||||
|
use rocket::serde::json::Json;
|
||||||
use rocket::Route;
|
use rocket::Route;
|
||||||
use rocket_contrib::json::Json;
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
api::{
|
api::{
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use data_encoding::BASE64;
|
use data_encoding::BASE64;
|
||||||
|
use rocket::serde::json::Json;
|
||||||
use rocket::Route;
|
use rocket::Route;
|
||||||
use rocket_contrib::json::Json;
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
api::{core::two_factor::_generate_recover_code, ApiResult, EmptyResult, JsonResult, JsonUpcase, PasswordData},
|
api::{core::two_factor::_generate_recover_code, ApiResult, EmptyResult, JsonResult, JsonUpcase, PasswordData},
|
||||||
@ -152,7 +152,7 @@ fn check_duo_fields_custom(data: &EnableDuoData) -> bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[post("/two-factor/duo", data = "<data>")]
|
#[post("/two-factor/duo", data = "<data>")]
|
||||||
fn activate_duo(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult {
|
async fn activate_duo(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult {
|
||||||
let data: EnableDuoData = data.into_inner().data;
|
let data: EnableDuoData = data.into_inner().data;
|
||||||
let mut user = headers.user;
|
let mut user = headers.user;
|
||||||
|
|
||||||
@ -163,7 +163,7 @@ fn activate_duo(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn)
|
|||||||
let (data, data_str) = if check_duo_fields_custom(&data) {
|
let (data, data_str) = if check_duo_fields_custom(&data) {
|
||||||
let data_req: DuoData = data.into();
|
let data_req: DuoData = data.into();
|
||||||
let data_str = serde_json::to_string(&data_req)?;
|
let data_str = serde_json::to_string(&data_req)?;
|
||||||
duo_api_request("GET", "/auth/v2/check", "", &data_req).map_res("Failed to validate Duo credentials")?;
|
duo_api_request("GET", "/auth/v2/check", "", &data_req).await.map_res("Failed to validate Duo credentials")?;
|
||||||
(data_req.obscure(), data_str)
|
(data_req.obscure(), data_str)
|
||||||
} else {
|
} else {
|
||||||
(DuoData::secret(), String::new())
|
(DuoData::secret(), String::new())
|
||||||
@ -185,11 +185,11 @@ fn activate_duo(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn)
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[put("/two-factor/duo", data = "<data>")]
|
#[put("/two-factor/duo", data = "<data>")]
|
||||||
fn activate_duo_put(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult {
|
async fn activate_duo_put(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult {
|
||||||
activate_duo(data, headers, conn)
|
activate_duo(data, headers, conn).await
|
||||||
}
|
}
|
||||||
|
|
||||||
fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData) -> EmptyResult {
|
async fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData) -> EmptyResult {
|
||||||
use reqwest::{header, Method};
|
use reqwest::{header, Method};
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
|
||||||
@ -209,7 +209,8 @@ fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData) -> Em
|
|||||||
.basic_auth(username, Some(password))
|
.basic_auth(username, Some(password))
|
||||||
.header(header::USER_AGENT, "vaultwarden:Duo/1.0 (Rust)")
|
.header(header::USER_AGENT, "vaultwarden:Duo/1.0 (Rust)")
|
||||||
.header(header::DATE, date)
|
.header(header::DATE, date)
|
||||||
.send()?
|
.send()
|
||||||
|
.await?
|
||||||
.error_for_status()?;
|
.error_for_status()?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use chrono::{Duration, NaiveDateTime, Utc};
|
use chrono::{Duration, NaiveDateTime, Utc};
|
||||||
|
use rocket::serde::json::Json;
|
||||||
use rocket::Route;
|
use rocket::Route;
|
||||||
use rocket_contrib::json::Json;
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
api::{core::two_factor::_generate_recover_code, EmptyResult, JsonResult, JsonUpcase, PasswordData},
|
api::{core::two_factor::_generate_recover_code, EmptyResult, JsonResult, JsonUpcase, PasswordData},
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use chrono::{Duration, Utc};
|
use chrono::{Duration, Utc};
|
||||||
use data_encoding::BASE32;
|
use data_encoding::BASE32;
|
||||||
|
use rocket::serde::json::Json;
|
||||||
use rocket::Route;
|
use rocket::Route;
|
||||||
use rocket_contrib::json::Json;
|
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
@ -158,14 +158,14 @@ fn disable_twofactor_put(data: JsonUpcase<DisableTwoFactorData>, headers: Header
|
|||||||
disable_twofactor(data, headers, conn)
|
disable_twofactor(data, headers, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn send_incomplete_2fa_notifications(pool: DbPool) {
|
pub async fn send_incomplete_2fa_notifications(pool: DbPool) {
|
||||||
debug!("Sending notifications for incomplete 2FA logins");
|
debug!("Sending notifications for incomplete 2FA logins");
|
||||||
|
|
||||||
if CONFIG.incomplete_2fa_time_limit() <= 0 || !CONFIG.mail_enabled() {
|
if CONFIG.incomplete_2fa_time_limit() <= 0 || !CONFIG.mail_enabled() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let conn = match pool.get() {
|
let conn = match pool.get().await {
|
||||||
Ok(conn) => conn,
|
Ok(conn) => conn,
|
||||||
_ => {
|
_ => {
|
||||||
error!("Failed to get DB connection in send_incomplete_2fa_notifications()");
|
error!("Failed to get DB connection in send_incomplete_2fa_notifications()");
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
|
use rocket::serde::json::Json;
|
||||||
use rocket::Route;
|
use rocket::Route;
|
||||||
use rocket_contrib::json::Json;
|
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use u2f::{
|
use u2f::{
|
||||||
messages::{RegisterResponse, SignResponse, U2fSignRequest},
|
messages::{RegisterResponse, SignResponse, U2fSignRequest},
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
|
use rocket::serde::json::Json;
|
||||||
use rocket::Route;
|
use rocket::Route;
|
||||||
use rocket_contrib::json::Json;
|
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use url::Url;
|
use url::Url;
|
||||||
use webauthn_rs::{base64_data::Base64UrlSafeData, proto::*, AuthenticationState, RegistrationState, Webauthn};
|
use webauthn_rs::{base64_data::Base64UrlSafeData, proto::*, AuthenticationState, RegistrationState, Webauthn};
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
|
use rocket::serde::json::Json;
|
||||||
use rocket::Route;
|
use rocket::Route;
|
||||||
use rocket_contrib::json::Json;
|
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use yubico::{config::Config, verify};
|
use yubico::{config::Config, verify};
|
||||||
|
|
||||||
|
125
src/api/icons.rs
125
src/api/icons.rs
@ -1,19 +1,19 @@
|
|||||||
use std::{
|
use std::{
|
||||||
collections::HashMap,
|
collections::HashMap,
|
||||||
fs::{create_dir_all, remove_file, symlink_metadata, File},
|
|
||||||
io::prelude::*,
|
|
||||||
net::{IpAddr, ToSocketAddrs},
|
net::{IpAddr, ToSocketAddrs},
|
||||||
sync::{Arc, RwLock},
|
sync::{Arc, RwLock},
|
||||||
time::{Duration, SystemTime},
|
time::{Duration, SystemTime},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use bytes::{Buf, Bytes, BytesMut};
|
||||||
|
use futures::{stream::StreamExt, TryFutureExt};
|
||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
use reqwest::{blocking::Client, blocking::Response, header};
|
use reqwest::{header, Client, Response};
|
||||||
use rocket::{
|
use rocket::{http::ContentType, response::Redirect, Route};
|
||||||
http::ContentType,
|
use tokio::{
|
||||||
response::{Content, Redirect},
|
fs::{create_dir_all, remove_file, symlink_metadata, File},
|
||||||
Route,
|
io::{AsyncReadExt, AsyncWriteExt},
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
@ -104,27 +104,23 @@ fn icon_google(domain: String) -> Option<Redirect> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[get("/<domain>/icon.png")]
|
#[get("/<domain>/icon.png")]
|
||||||
fn icon_internal(domain: String) -> Cached<Content<Vec<u8>>> {
|
async fn icon_internal(domain: String) -> Cached<(ContentType, Vec<u8>)> {
|
||||||
const FALLBACK_ICON: &[u8] = include_bytes!("../static/images/fallback-icon.png");
|
const FALLBACK_ICON: &[u8] = include_bytes!("../static/images/fallback-icon.png");
|
||||||
|
|
||||||
if !is_valid_domain(&domain) {
|
if !is_valid_domain(&domain) {
|
||||||
warn!("Invalid domain: {}", domain);
|
warn!("Invalid domain: {}", domain);
|
||||||
return Cached::ttl(
|
return Cached::ttl(
|
||||||
Content(ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
|
(ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
|
||||||
CONFIG.icon_cache_negttl(),
|
CONFIG.icon_cache_negttl(),
|
||||||
true,
|
true,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
match get_icon(&domain) {
|
match get_icon(&domain).await {
|
||||||
Some((icon, icon_type)) => {
|
Some((icon, icon_type)) => {
|
||||||
Cached::ttl(Content(ContentType::new("image", icon_type), icon), CONFIG.icon_cache_ttl(), true)
|
Cached::ttl((ContentType::new("image", icon_type), icon), CONFIG.icon_cache_ttl(), true)
|
||||||
}
|
}
|
||||||
_ => Cached::ttl(
|
_ => Cached::ttl((ContentType::new("image", "png"), FALLBACK_ICON.to_vec()), CONFIG.icon_cache_negttl(), true),
|
||||||
Content(ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
|
|
||||||
CONFIG.icon_cache_negttl(),
|
|
||||||
true,
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -317,15 +313,15 @@ fn is_domain_blacklisted(domain: &str) -> bool {
|
|||||||
is_blacklisted
|
is_blacklisted
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
|
async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
|
||||||
let path = format!("{}/{}.png", CONFIG.icon_cache_folder(), domain);
|
let path = format!("{}/{}.png", CONFIG.icon_cache_folder(), domain);
|
||||||
|
|
||||||
// Check for expiration of negatively cached copy
|
// Check for expiration of negatively cached copy
|
||||||
if icon_is_negcached(&path) {
|
if icon_is_negcached(&path).await {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(icon) = get_cached_icon(&path) {
|
if let Some(icon) = get_cached_icon(&path).await {
|
||||||
let icon_type = match get_icon_type(&icon) {
|
let icon_type = match get_icon_type(&icon) {
|
||||||
Some(x) => x,
|
Some(x) => x,
|
||||||
_ => "x-icon",
|
_ => "x-icon",
|
||||||
@ -338,31 +334,31 @@ fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get the icon, or None in case of error
|
// Get the icon, or None in case of error
|
||||||
match download_icon(domain) {
|
match download_icon(domain).await {
|
||||||
Ok((icon, icon_type)) => {
|
Ok((icon, icon_type)) => {
|
||||||
save_icon(&path, &icon);
|
save_icon(&path, &icon).await;
|
||||||
Some((icon, icon_type.unwrap_or("x-icon").to_string()))
|
Some((icon.to_vec(), icon_type.unwrap_or("x-icon").to_string()))
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!("Unable to download icon: {:?}", e);
|
warn!("Unable to download icon: {:?}", e);
|
||||||
let miss_indicator = path + ".miss";
|
let miss_indicator = path + ".miss";
|
||||||
save_icon(&miss_indicator, &[]);
|
save_icon(&miss_indicator, &[]).await;
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_cached_icon(path: &str) -> Option<Vec<u8>> {
|
async fn get_cached_icon(path: &str) -> Option<Vec<u8>> {
|
||||||
// Check for expiration of successfully cached copy
|
// Check for expiration of successfully cached copy
|
||||||
if icon_is_expired(path) {
|
if icon_is_expired(path).await {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to read the cached icon, and return it if it exists
|
// Try to read the cached icon, and return it if it exists
|
||||||
if let Ok(mut f) = File::open(path) {
|
if let Ok(mut f) = File::open(path).await {
|
||||||
let mut buffer = Vec::new();
|
let mut buffer = Vec::new();
|
||||||
|
|
||||||
if f.read_to_end(&mut buffer).is_ok() {
|
if f.read_to_end(&mut buffer).await.is_ok() {
|
||||||
return Some(buffer);
|
return Some(buffer);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -370,22 +366,22 @@ fn get_cached_icon(path: &str) -> Option<Vec<u8>> {
|
|||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
fn file_is_expired(path: &str, ttl: u64) -> Result<bool, Error> {
|
async fn file_is_expired(path: &str, ttl: u64) -> Result<bool, Error> {
|
||||||
let meta = symlink_metadata(path)?;
|
let meta = symlink_metadata(path).await?;
|
||||||
let modified = meta.modified()?;
|
let modified = meta.modified()?;
|
||||||
let age = SystemTime::now().duration_since(modified)?;
|
let age = SystemTime::now().duration_since(modified)?;
|
||||||
|
|
||||||
Ok(ttl > 0 && ttl <= age.as_secs())
|
Ok(ttl > 0 && ttl <= age.as_secs())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn icon_is_negcached(path: &str) -> bool {
|
async fn icon_is_negcached(path: &str) -> bool {
|
||||||
let miss_indicator = path.to_owned() + ".miss";
|
let miss_indicator = path.to_owned() + ".miss";
|
||||||
let expired = file_is_expired(&miss_indicator, CONFIG.icon_cache_negttl());
|
let expired = file_is_expired(&miss_indicator, CONFIG.icon_cache_negttl()).await;
|
||||||
|
|
||||||
match expired {
|
match expired {
|
||||||
// No longer negatively cached, drop the marker
|
// No longer negatively cached, drop the marker
|
||||||
Ok(true) => {
|
Ok(true) => {
|
||||||
if let Err(e) = remove_file(&miss_indicator) {
|
if let Err(e) = remove_file(&miss_indicator).await {
|
||||||
error!("Could not remove negative cache indicator for icon {:?}: {:?}", path, e);
|
error!("Could not remove negative cache indicator for icon {:?}: {:?}", path, e);
|
||||||
}
|
}
|
||||||
false
|
false
|
||||||
@ -397,8 +393,8 @@ fn icon_is_negcached(path: &str) -> bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn icon_is_expired(path: &str) -> bool {
|
async fn icon_is_expired(path: &str) -> bool {
|
||||||
let expired = file_is_expired(path, CONFIG.icon_cache_ttl());
|
let expired = file_is_expired(path, CONFIG.icon_cache_ttl()).await;
|
||||||
expired.unwrap_or(true)
|
expired.unwrap_or(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -521,13 +517,13 @@ struct IconUrlResult {
|
|||||||
/// let icon_result = get_icon_url("github.com")?;
|
/// let icon_result = get_icon_url("github.com")?;
|
||||||
/// let icon_result = get_icon_url("vaultwarden.discourse.group")?;
|
/// let icon_result = get_icon_url("vaultwarden.discourse.group")?;
|
||||||
/// ```
|
/// ```
|
||||||
fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
|
async fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
|
||||||
// Default URL with secure and insecure schemes
|
// Default URL with secure and insecure schemes
|
||||||
let ssldomain = format!("https://{}", domain);
|
let ssldomain = format!("https://{}", domain);
|
||||||
let httpdomain = format!("http://{}", domain);
|
let httpdomain = format!("http://{}", domain);
|
||||||
|
|
||||||
// First check the domain as given during the request for both HTTPS and HTTP.
|
// First check the domain as given during the request for both HTTPS and HTTP.
|
||||||
let resp = match get_page(&ssldomain).or_else(|_| get_page(&httpdomain)) {
|
let resp = match get_page(&ssldomain).or_else(|_| get_page(&httpdomain)).await {
|
||||||
Ok(c) => Ok(c),
|
Ok(c) => Ok(c),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
let mut sub_resp = Err(e);
|
let mut sub_resp = Err(e);
|
||||||
@ -546,7 +542,7 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
|
|||||||
let httpbase = format!("http://{}", base_domain);
|
let httpbase = format!("http://{}", base_domain);
|
||||||
debug!("[get_icon_url]: Trying without subdomains '{}'", base_domain);
|
debug!("[get_icon_url]: Trying without subdomains '{}'", base_domain);
|
||||||
|
|
||||||
sub_resp = get_page(&sslbase).or_else(|_| get_page(&httpbase));
|
sub_resp = get_page(&sslbase).or_else(|_| get_page(&httpbase)).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
// When the domain is not an IP, and has less then 2 dots, try to add www. infront of it.
|
// When the domain is not an IP, and has less then 2 dots, try to add www. infront of it.
|
||||||
@ -557,7 +553,7 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
|
|||||||
let httpwww = format!("http://{}", www_domain);
|
let httpwww = format!("http://{}", www_domain);
|
||||||
debug!("[get_icon_url]: Trying with www. prefix '{}'", www_domain);
|
debug!("[get_icon_url]: Trying with www. prefix '{}'", www_domain);
|
||||||
|
|
||||||
sub_resp = get_page(&sslwww).or_else(|_| get_page(&httpwww));
|
sub_resp = get_page(&sslwww).or_else(|_| get_page(&httpwww)).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -581,7 +577,7 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
|
|||||||
iconlist.push(Icon::new(35, String::from(url.join("/favicon.ico").unwrap())));
|
iconlist.push(Icon::new(35, String::from(url.join("/favicon.ico").unwrap())));
|
||||||
|
|
||||||
// 384KB should be more than enough for the HTML, though as we only really need the HTML header.
|
// 384KB should be more than enough for the HTML, though as we only really need the HTML header.
|
||||||
let mut limited_reader = content.take(384 * 1024);
|
let mut limited_reader = stream_to_bytes_limit(content, 384 * 1024).await?.reader();
|
||||||
|
|
||||||
use html5ever::tendril::TendrilSink;
|
use html5ever::tendril::TendrilSink;
|
||||||
let dom = html5ever::parse_document(markup5ever_rcdom::RcDom::default(), Default::default())
|
let dom = html5ever::parse_document(markup5ever_rcdom::RcDom::default(), Default::default())
|
||||||
@ -607,11 +603,11 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_page(url: &str) -> Result<Response, Error> {
|
async fn get_page(url: &str) -> Result<Response, Error> {
|
||||||
get_page_with_referer(url, "")
|
get_page_with_referer(url, "").await
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
|
async fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
|
||||||
if is_domain_blacklisted(url::Url::parse(url).unwrap().host_str().unwrap_or_default()) {
|
if is_domain_blacklisted(url::Url::parse(url).unwrap().host_str().unwrap_or_default()) {
|
||||||
warn!("Favicon '{}' resolves to a blacklisted domain or IP!", url);
|
warn!("Favicon '{}' resolves to a blacklisted domain or IP!", url);
|
||||||
}
|
}
|
||||||
@ -621,7 +617,7 @@ fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
|
|||||||
client = client.header("Referer", referer)
|
client = client.header("Referer", referer)
|
||||||
}
|
}
|
||||||
|
|
||||||
match client.send() {
|
match client.send().await {
|
||||||
Ok(c) => c.error_for_status().map_err(Into::into),
|
Ok(c) => c.error_for_status().map_err(Into::into),
|
||||||
Err(e) => err_silent!(format!("{}", e)),
|
Err(e) => err_silent!(format!("{}", e)),
|
||||||
}
|
}
|
||||||
@ -706,14 +702,14 @@ fn parse_sizes(sizes: Option<&str>) -> (u16, u16) {
|
|||||||
(width, height)
|
(width, height)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
|
async fn download_icon(domain: &str) -> Result<(Bytes, Option<&str>), Error> {
|
||||||
if is_domain_blacklisted(domain) {
|
if is_domain_blacklisted(domain) {
|
||||||
err_silent!("Domain is blacklisted", domain)
|
err_silent!("Domain is blacklisted", domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
let icon_result = get_icon_url(domain)?;
|
let icon_result = get_icon_url(domain).await?;
|
||||||
|
|
||||||
let mut buffer = Vec::new();
|
let mut buffer = Bytes::new();
|
||||||
let mut icon_type: Option<&str> = None;
|
let mut icon_type: Option<&str> = None;
|
||||||
|
|
||||||
use data_url::DataUrl;
|
use data_url::DataUrl;
|
||||||
@ -722,8 +718,12 @@ fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
|
|||||||
if icon.href.starts_with("data:image") {
|
if icon.href.starts_with("data:image") {
|
||||||
let datauri = DataUrl::process(&icon.href).unwrap();
|
let datauri = DataUrl::process(&icon.href).unwrap();
|
||||||
// Check if we are able to decode the data uri
|
// Check if we are able to decode the data uri
|
||||||
match datauri.decode_to_vec() {
|
let mut body = BytesMut::new();
|
||||||
Ok((body, _fragment)) => {
|
match datauri.decode::<_, ()>(|bytes| {
|
||||||
|
body.extend_from_slice(bytes);
|
||||||
|
Ok(())
|
||||||
|
}) {
|
||||||
|
Ok(_) => {
|
||||||
// Also check if the size is atleast 67 bytes, which seems to be the smallest png i could create
|
// Also check if the size is atleast 67 bytes, which seems to be the smallest png i could create
|
||||||
if body.len() >= 67 {
|
if body.len() >= 67 {
|
||||||
// Check if the icon type is allowed, else try an icon from the list.
|
// Check if the icon type is allowed, else try an icon from the list.
|
||||||
@ -733,17 +733,17 @@ fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
info!("Extracted icon from data:image uri for {}", domain);
|
info!("Extracted icon from data:image uri for {}", domain);
|
||||||
buffer = body;
|
buffer = body.freeze();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => debug!("Extracted icon from data:image uri is invalid"),
|
_ => debug!("Extracted icon from data:image uri is invalid"),
|
||||||
};
|
};
|
||||||
} else {
|
} else {
|
||||||
match get_page_with_referer(&icon.href, &icon_result.referer) {
|
match get_page_with_referer(&icon.href, &icon_result.referer).await {
|
||||||
Ok(mut res) => {
|
Ok(res) => {
|
||||||
res.copy_to(&mut buffer)?;
|
buffer = stream_to_bytes_limit(res, 512 * 1024).await?; // 512 KB for each icon max
|
||||||
// Check if the icon type is allowed, else try an icon from the list.
|
// Check if the icon type is allowed, else try an icon from the list.
|
||||||
icon_type = get_icon_type(&buffer);
|
icon_type = get_icon_type(&buffer);
|
||||||
if icon_type.is_none() {
|
if icon_type.is_none() {
|
||||||
buffer.clear();
|
buffer.clear();
|
||||||
@ -765,13 +765,13 @@ fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
|
|||||||
Ok((buffer, icon_type))
|
Ok((buffer, icon_type))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn save_icon(path: &str, icon: &[u8]) {
|
async fn save_icon(path: &str, icon: &[u8]) {
|
||||||
match File::create(path) {
|
match File::create(path).await {
|
||||||
Ok(mut f) => {
|
Ok(mut f) => {
|
||||||
f.write_all(icon).expect("Error writing icon file");
|
f.write_all(icon).await.expect("Error writing icon file");
|
||||||
}
|
}
|
||||||
Err(ref e) if e.kind() == std::io::ErrorKind::NotFound => {
|
Err(ref e) if e.kind() == std::io::ErrorKind::NotFound => {
|
||||||
create_dir_all(&CONFIG.icon_cache_folder()).expect("Error creating icon cache folder");
|
create_dir_all(&CONFIG.icon_cache_folder()).await.expect("Error creating icon cache folder");
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!("Unable to save icon: {:?}", e);
|
warn!("Unable to save icon: {:?}", e);
|
||||||
@ -820,8 +820,6 @@ impl reqwest::cookie::CookieStore for Jar {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn cookies(&self, url: &url::Url) -> Option<header::HeaderValue> {
|
fn cookies(&self, url: &url::Url) -> Option<header::HeaderValue> {
|
||||||
use bytes::Bytes;
|
|
||||||
|
|
||||||
let cookie_store = self.0.read().unwrap();
|
let cookie_store = self.0.read().unwrap();
|
||||||
let s = cookie_store
|
let s = cookie_store
|
||||||
.get_request_values(url)
|
.get_request_values(url)
|
||||||
@ -836,3 +834,12 @@ impl reqwest::cookie::CookieStore for Jar {
|
|||||||
header::HeaderValue::from_maybe_shared(Bytes::from(s)).ok()
|
header::HeaderValue::from_maybe_shared(Bytes::from(s)).ok()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn stream_to_bytes_limit(res: Response, max_size: usize) -> Result<Bytes, reqwest::Error> {
|
||||||
|
let mut stream = res.bytes_stream().take(max_size);
|
||||||
|
let mut buf = BytesMut::new();
|
||||||
|
while let Some(chunk) = stream.next().await {
|
||||||
|
buf.extend(chunk?);
|
||||||
|
}
|
||||||
|
Ok(buf.freeze())
|
||||||
|
}
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use num_traits::FromPrimitive;
|
use num_traits::FromPrimitive;
|
||||||
|
use rocket::serde::json::Json;
|
||||||
use rocket::{
|
use rocket::{
|
||||||
request::{Form, FormItems, FromForm},
|
form::{Form, FromForm},
|
||||||
Route,
|
Route,
|
||||||
};
|
};
|
||||||
use rocket_contrib::json::Json;
|
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
@ -455,66 +455,57 @@ fn _json_err_twofactor(providers: &[i32], user_uuid: &str, conn: &DbConn) -> Api
|
|||||||
|
|
||||||
// https://github.com/bitwarden/jslib/blob/master/common/src/models/request/tokenRequest.ts
|
// https://github.com/bitwarden/jslib/blob/master/common/src/models/request/tokenRequest.ts
|
||||||
// https://github.com/bitwarden/mobile/blob/master/src/Core/Models/Request/TokenRequest.cs
|
// https://github.com/bitwarden/mobile/blob/master/src/Core/Models/Request/TokenRequest.cs
|
||||||
#[derive(Debug, Clone, Default)]
|
#[derive(Debug, Clone, Default, FromForm)]
|
||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
struct ConnectData {
|
struct ConnectData {
|
||||||
// refresh_token, password, client_credentials (API key)
|
#[field(name = uncased("grant_type"))]
|
||||||
grant_type: String,
|
#[field(name = uncased("granttype"))]
|
||||||
|
grant_type: String, // refresh_token, password, client_credentials (API key)
|
||||||
|
|
||||||
// Needed for grant_type="refresh_token"
|
// Needed for grant_type="refresh_token"
|
||||||
|
#[field(name = uncased("refresh_token"))]
|
||||||
|
#[field(name = uncased("refreshtoken"))]
|
||||||
refresh_token: Option<String>,
|
refresh_token: Option<String>,
|
||||||
|
|
||||||
// Needed for grant_type = "password" | "client_credentials"
|
// Needed for grant_type = "password" | "client_credentials"
|
||||||
client_id: Option<String>, // web, cli, desktop, browser, mobile
|
#[field(name = uncased("client_id"))]
|
||||||
client_secret: Option<String>, // API key login (cli only)
|
#[field(name = uncased("clientid"))]
|
||||||
|
client_id: Option<String>, // web, cli, desktop, browser, mobile
|
||||||
|
#[field(name = uncased("client_secret"))]
|
||||||
|
#[field(name = uncased("clientsecret"))]
|
||||||
|
client_secret: Option<String>,
|
||||||
|
#[field(name = uncased("password"))]
|
||||||
password: Option<String>,
|
password: Option<String>,
|
||||||
|
#[field(name = uncased("scope"))]
|
||||||
scope: Option<String>,
|
scope: Option<String>,
|
||||||
|
#[field(name = uncased("username"))]
|
||||||
username: Option<String>,
|
username: Option<String>,
|
||||||
|
|
||||||
|
#[field(name = uncased("device_identifier"))]
|
||||||
|
#[field(name = uncased("deviceidentifier"))]
|
||||||
device_identifier: Option<String>,
|
device_identifier: Option<String>,
|
||||||
|
#[field(name = uncased("device_name"))]
|
||||||
|
#[field(name = uncased("devicename"))]
|
||||||
device_name: Option<String>,
|
device_name: Option<String>,
|
||||||
|
#[field(name = uncased("device_type"))]
|
||||||
|
#[field(name = uncased("devicetype"))]
|
||||||
device_type: Option<String>,
|
device_type: Option<String>,
|
||||||
|
#[field(name = uncased("device_push_token"))]
|
||||||
|
#[field(name = uncased("devicepushtoken"))]
|
||||||
device_push_token: Option<String>, // Unused; mobile device push not yet supported.
|
device_push_token: Option<String>, // Unused; mobile device push not yet supported.
|
||||||
|
|
||||||
// Needed for two-factor auth
|
// Needed for two-factor auth
|
||||||
|
#[field(name = uncased("two_factor_provider"))]
|
||||||
|
#[field(name = uncased("twofactorprovider"))]
|
||||||
two_factor_provider: Option<i32>,
|
two_factor_provider: Option<i32>,
|
||||||
|
#[field(name = uncased("two_factor_token"))]
|
||||||
|
#[field(name = uncased("twofactortoken"))]
|
||||||
two_factor_token: Option<String>,
|
two_factor_token: Option<String>,
|
||||||
|
#[field(name = uncased("two_factor_remember"))]
|
||||||
|
#[field(name = uncased("twofactorremember"))]
|
||||||
two_factor_remember: Option<i32>,
|
two_factor_remember: Option<i32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'f> FromForm<'f> for ConnectData {
|
|
||||||
type Error = String;
|
|
||||||
|
|
||||||
fn from_form(items: &mut FormItems<'f>, _strict: bool) -> Result<Self, Self::Error> {
|
|
||||||
let mut form = Self::default();
|
|
||||||
for item in items {
|
|
||||||
let (key, value) = item.key_value_decoded();
|
|
||||||
let mut normalized_key = key.to_lowercase();
|
|
||||||
normalized_key.retain(|c| c != '_'); // Remove '_'
|
|
||||||
|
|
||||||
match normalized_key.as_ref() {
|
|
||||||
"granttype" => form.grant_type = value,
|
|
||||||
"refreshtoken" => form.refresh_token = Some(value),
|
|
||||||
"clientid" => form.client_id = Some(value),
|
|
||||||
"clientsecret" => form.client_secret = Some(value),
|
|
||||||
"password" => form.password = Some(value),
|
|
||||||
"scope" => form.scope = Some(value),
|
|
||||||
"username" => form.username = Some(value),
|
|
||||||
"deviceidentifier" => form.device_identifier = Some(value),
|
|
||||||
"devicename" => form.device_name = Some(value),
|
|
||||||
"devicetype" => form.device_type = Some(value),
|
|
||||||
"devicepushtoken" => form.device_push_token = Some(value),
|
|
||||||
"twofactorprovider" => form.two_factor_provider = value.parse().ok(),
|
|
||||||
"twofactortoken" => form.two_factor_token = Some(value),
|
|
||||||
"twofactorremember" => form.two_factor_remember = value.parse().ok(),
|
|
||||||
key => warn!("Detected unexpected parameter during login: {}", key),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(form)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn _check_is_some<T>(value: &Option<T>, msg: &str) -> EmptyResult {
|
fn _check_is_some<T>(value: &Option<T>, msg: &str) -> EmptyResult {
|
||||||
if value.is_none() {
|
if value.is_none() {
|
||||||
err!(msg)
|
err!(msg)
|
||||||
|
@ -5,7 +5,7 @@ mod identity;
|
|||||||
mod notifications;
|
mod notifications;
|
||||||
mod web;
|
mod web;
|
||||||
|
|
||||||
use rocket_contrib::json::Json;
|
use rocket::serde::json::Json;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
pub use crate::api::{
|
pub use crate::api::{
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
|
||||||
|
use rocket::serde::json::Json;
|
||||||
use rocket::Route;
|
use rocket::Route;
|
||||||
use rocket_contrib::json::Json;
|
|
||||||
use serde_json::Value as JsonValue;
|
use serde_json::Value as JsonValue;
|
||||||
|
|
||||||
use crate::{api::EmptyResult, auth::Headers, Error, CONFIG};
|
use crate::{api::EmptyResult, auth::Headers, Error, CONFIG};
|
||||||
@ -417,7 +417,7 @@ pub enum UpdateType {
|
|||||||
}
|
}
|
||||||
|
|
||||||
use rocket::State;
|
use rocket::State;
|
||||||
pub type Notify<'a> = State<'a, WebSocketUsers>;
|
pub type Notify<'a> = &'a State<WebSocketUsers>;
|
||||||
|
|
||||||
pub fn start_notification_server() -> WebSocketUsers {
|
pub fn start_notification_server() -> WebSocketUsers {
|
||||||
let factory = WsFactory::init();
|
let factory = WsFactory::init();
|
||||||
@ -430,12 +430,11 @@ pub fn start_notification_server() -> WebSocketUsers {
|
|||||||
settings.queue_size = 2;
|
settings.queue_size = 2;
|
||||||
settings.panic_on_internal = false;
|
settings.panic_on_internal = false;
|
||||||
|
|
||||||
ws::Builder::new()
|
let ws = ws::Builder::new().with_settings(settings).build(factory).unwrap();
|
||||||
.with_settings(settings)
|
CONFIG.set_ws_shutdown_handle(ws.broadcaster());
|
||||||
.build(factory)
|
ws.listen((CONFIG.websocket_address().as_str(), CONFIG.websocket_port())).unwrap();
|
||||||
.unwrap()
|
|
||||||
.listen((CONFIG.websocket_address().as_str(), CONFIG.websocket_port()))
|
warn!("WS Server stopped!");
|
||||||
.unwrap();
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
use rocket::{http::ContentType, response::content::Content, response::NamedFile, Route};
|
use rocket::serde::json::Json;
|
||||||
use rocket_contrib::json::Json;
|
use rocket::{fs::NamedFile, http::ContentType, Route};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
@ -21,16 +21,16 @@ pub fn routes() -> Vec<Route> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[get("/")]
|
#[get("/")]
|
||||||
fn web_index() -> Cached<Option<NamedFile>> {
|
async fn web_index() -> Cached<Option<NamedFile>> {
|
||||||
Cached::short(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join("index.html")).ok(), false)
|
Cached::short(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join("index.html")).await.ok(), false)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/app-id.json")]
|
#[get("/app-id.json")]
|
||||||
fn app_id() -> Cached<Content<Json<Value>>> {
|
fn app_id() -> Cached<(ContentType, Json<Value>)> {
|
||||||
let content_type = ContentType::new("application", "fido.trusted-apps+json");
|
let content_type = ContentType::new("application", "fido.trusted-apps+json");
|
||||||
|
|
||||||
Cached::long(
|
Cached::long(
|
||||||
Content(
|
(
|
||||||
content_type,
|
content_type,
|
||||||
Json(json!({
|
Json(json!({
|
||||||
"trustedFacets": [
|
"trustedFacets": [
|
||||||
@ -58,13 +58,13 @@ fn app_id() -> Cached<Content<Json<Value>>> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[get("/<p..>", rank = 10)] // Only match this if the other routes don't match
|
#[get("/<p..>", rank = 10)] // Only match this if the other routes don't match
|
||||||
fn web_files(p: PathBuf) -> Cached<Option<NamedFile>> {
|
async fn web_files(p: PathBuf) -> Cached<Option<NamedFile>> {
|
||||||
Cached::long(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join(p)).ok(), true)
|
Cached::long(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join(p)).await.ok(), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/attachments/<uuid>/<file_id>")]
|
#[get("/attachments/<uuid>/<file_id>")]
|
||||||
fn attachments(uuid: SafeString, file_id: SafeString) -> Option<NamedFile> {
|
async fn attachments(uuid: SafeString, file_id: SafeString) -> Option<NamedFile> {
|
||||||
NamedFile::open(Path::new(&CONFIG.attachments_folder()).join(uuid).join(file_id)).ok()
|
NamedFile::open(Path::new(&CONFIG.attachments_folder()).join(uuid).join(file_id)).await.ok()
|
||||||
}
|
}
|
||||||
|
|
||||||
// We use DbConn here to let the alive healthcheck also verify the database connection.
|
// We use DbConn here to let the alive healthcheck also verify the database connection.
|
||||||
@ -78,25 +78,20 @@ fn alive(_conn: DbConn) -> Json<String> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[get("/vw_static/<filename>")]
|
#[get("/vw_static/<filename>")]
|
||||||
fn static_files(filename: String) -> Result<Content<&'static [u8]>, Error> {
|
fn static_files(filename: String) -> Result<(ContentType, &'static [u8]), Error> {
|
||||||
match filename.as_ref() {
|
match filename.as_ref() {
|
||||||
"mail-github.png" => Ok(Content(ContentType::PNG, include_bytes!("../static/images/mail-github.png"))),
|
"mail-github.png" => Ok((ContentType::PNG, include_bytes!("../static/images/mail-github.png"))),
|
||||||
"logo-gray.png" => Ok(Content(ContentType::PNG, include_bytes!("../static/images/logo-gray.png"))),
|
"logo-gray.png" => Ok((ContentType::PNG, include_bytes!("../static/images/logo-gray.png"))),
|
||||||
"error-x.svg" => Ok(Content(ContentType::SVG, include_bytes!("../static/images/error-x.svg"))),
|
"error-x.svg" => Ok((ContentType::SVG, include_bytes!("../static/images/error-x.svg"))),
|
||||||
"hibp.png" => Ok(Content(ContentType::PNG, include_bytes!("../static/images/hibp.png"))),
|
"hibp.png" => Ok((ContentType::PNG, include_bytes!("../static/images/hibp.png"))),
|
||||||
"vaultwarden-icon.png" => {
|
"vaultwarden-icon.png" => Ok((ContentType::PNG, include_bytes!("../static/images/vaultwarden-icon.png"))),
|
||||||
Ok(Content(ContentType::PNG, include_bytes!("../static/images/vaultwarden-icon.png")))
|
"bootstrap.css" => Ok((ContentType::CSS, include_bytes!("../static/scripts/bootstrap.css"))),
|
||||||
}
|
"bootstrap-native.js" => Ok((ContentType::JavaScript, include_bytes!("../static/scripts/bootstrap-native.js"))),
|
||||||
|
"identicon.js" => Ok((ContentType::JavaScript, include_bytes!("../static/scripts/identicon.js"))),
|
||||||
"bootstrap.css" => Ok(Content(ContentType::CSS, include_bytes!("../static/scripts/bootstrap.css"))),
|
"datatables.js" => Ok((ContentType::JavaScript, include_bytes!("../static/scripts/datatables.js"))),
|
||||||
"bootstrap-native.js" => {
|
"datatables.css" => Ok((ContentType::CSS, include_bytes!("../static/scripts/datatables.css"))),
|
||||||
Ok(Content(ContentType::JavaScript, include_bytes!("../static/scripts/bootstrap-native.js")))
|
|
||||||
}
|
|
||||||
"identicon.js" => Ok(Content(ContentType::JavaScript, include_bytes!("../static/scripts/identicon.js"))),
|
|
||||||
"datatables.js" => Ok(Content(ContentType::JavaScript, include_bytes!("../static/scripts/datatables.js"))),
|
|
||||||
"datatables.css" => Ok(Content(ContentType::CSS, include_bytes!("../static/scripts/datatables.css"))),
|
|
||||||
"jquery-3.6.0.slim.js" => {
|
"jquery-3.6.0.slim.js" => {
|
||||||
Ok(Content(ContentType::JavaScript, include_bytes!("../static/scripts/jquery-3.6.0.slim.js")))
|
Ok((ContentType::JavaScript, include_bytes!("../static/scripts/jquery-3.6.0.slim.js")))
|
||||||
}
|
}
|
||||||
_ => err!(format!("Static file not found: {}", filename)),
|
_ => err!(format!("Static file not found: {}", filename)),
|
||||||
}
|
}
|
||||||
|
262
src/auth.rs
262
src/auth.rs
@ -257,7 +257,10 @@ pub fn generate_send_claims(send_id: &str, file_id: &str) -> BasicJwtClaims {
|
|||||||
//
|
//
|
||||||
// Bearer token authentication
|
// Bearer token authentication
|
||||||
//
|
//
|
||||||
use rocket::request::{FromRequest, Outcome, Request};
|
use rocket::{
|
||||||
|
outcome::try_outcome,
|
||||||
|
request::{FromRequest, Outcome, Request},
|
||||||
|
};
|
||||||
|
|
||||||
use crate::db::{
|
use crate::db::{
|
||||||
models::{CollectionUser, Device, User, UserOrgStatus, UserOrgType, UserOrganization, UserStampException},
|
models::{CollectionUser, Device, User, UserOrgStatus, UserOrgType, UserOrganization, UserStampException},
|
||||||
@ -268,10 +271,11 @@ pub struct Host {
|
|||||||
pub host: String,
|
pub host: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'r> FromRequest<'a, 'r> for Host {
|
#[rocket::async_trait]
|
||||||
|
impl<'r> FromRequest<'r> for Host {
|
||||||
type Error = &'static str;
|
type Error = &'static str;
|
||||||
|
|
||||||
fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
|
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||||
let headers = request.headers();
|
let headers = request.headers();
|
||||||
|
|
||||||
// Get host
|
// Get host
|
||||||
@ -314,17 +318,14 @@ pub struct Headers {
|
|||||||
pub user: User,
|
pub user: User,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'r> FromRequest<'a, 'r> for Headers {
|
#[rocket::async_trait]
|
||||||
|
impl<'r> FromRequest<'r> for Headers {
|
||||||
type Error = &'static str;
|
type Error = &'static str;
|
||||||
|
|
||||||
fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
|
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||||
let headers = request.headers();
|
let headers = request.headers();
|
||||||
|
|
||||||
let host = match Host::from_request(request) {
|
let host = try_outcome!(Host::from_request(request).await).host;
|
||||||
Outcome::Forward(_) => return Outcome::Forward(()),
|
|
||||||
Outcome::Failure(f) => return Outcome::Failure(f),
|
|
||||||
Outcome::Success(host) => host.host,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Get access_token
|
// Get access_token
|
||||||
let access_token: &str = match headers.get_one("Authorization") {
|
let access_token: &str = match headers.get_one("Authorization") {
|
||||||
@ -344,7 +345,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for Headers {
|
|||||||
let device_uuid = claims.device;
|
let device_uuid = claims.device;
|
||||||
let user_uuid = claims.sub;
|
let user_uuid = claims.sub;
|
||||||
|
|
||||||
let conn = match request.guard::<DbConn>() {
|
let conn = match DbConn::from_request(request).await {
|
||||||
Outcome::Success(conn) => conn,
|
Outcome::Success(conn) => conn,
|
||||||
_ => err_handler!("Error getting DB"),
|
_ => err_handler!("Error getting DB"),
|
||||||
};
|
};
|
||||||
@ -363,7 +364,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for Headers {
|
|||||||
if let Some(stamp_exception) =
|
if let Some(stamp_exception) =
|
||||||
user.stamp_exception.as_deref().and_then(|s| serde_json::from_str::<UserStampException>(s).ok())
|
user.stamp_exception.as_deref().and_then(|s| serde_json::from_str::<UserStampException>(s).ok())
|
||||||
{
|
{
|
||||||
let current_route = match request.route().and_then(|r| r.name) {
|
let current_route = match request.route().and_then(|r| r.name.as_deref()) {
|
||||||
Some(name) => name,
|
Some(name) => name,
|
||||||
_ => err_handler!("Error getting current route for stamp exception"),
|
_ => err_handler!("Error getting current route for stamp exception"),
|
||||||
};
|
};
|
||||||
@ -411,13 +412,13 @@ pub struct OrgHeaders {
|
|||||||
// but there are cases where it is a query value.
|
// but there are cases where it is a query value.
|
||||||
// First check the path, if this is not a valid uuid, try the query values.
|
// First check the path, if this is not a valid uuid, try the query values.
|
||||||
fn get_org_id(request: &Request) -> Option<String> {
|
fn get_org_id(request: &Request) -> Option<String> {
|
||||||
if let Some(Ok(org_id)) = request.get_param::<String>(1) {
|
if let Some(Ok(org_id)) = request.param::<String>(1) {
|
||||||
if uuid::Uuid::parse_str(&org_id).is_ok() {
|
if uuid::Uuid::parse_str(&org_id).is_ok() {
|
||||||
return Some(org_id);
|
return Some(org_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(Ok(org_id)) = request.get_query_value::<String>("organizationId") {
|
if let Some(Ok(org_id)) = request.query_value::<String>("organizationId") {
|
||||||
if uuid::Uuid::parse_str(&org_id).is_ok() {
|
if uuid::Uuid::parse_str(&org_id).is_ok() {
|
||||||
return Some(org_id);
|
return Some(org_id);
|
||||||
}
|
}
|
||||||
@ -426,52 +427,48 @@ fn get_org_id(request: &Request) -> Option<String> {
|
|||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'r> FromRequest<'a, 'r> for OrgHeaders {
|
#[rocket::async_trait]
|
||||||
|
impl<'r> FromRequest<'r> for OrgHeaders {
|
||||||
type Error = &'static str;
|
type Error = &'static str;
|
||||||
|
|
||||||
fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
|
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||||
match request.guard::<Headers>() {
|
let headers = try_outcome!(Headers::from_request(request).await);
|
||||||
Outcome::Forward(_) => Outcome::Forward(()),
|
match get_org_id(request) {
|
||||||
Outcome::Failure(f) => Outcome::Failure(f),
|
Some(org_id) => {
|
||||||
Outcome::Success(headers) => {
|
let conn = match DbConn::from_request(request).await {
|
||||||
match get_org_id(request) {
|
Outcome::Success(conn) => conn,
|
||||||
Some(org_id) => {
|
_ => err_handler!("Error getting DB"),
|
||||||
let conn = match request.guard::<DbConn>() {
|
};
|
||||||
Outcome::Success(conn) => conn,
|
|
||||||
_ => err_handler!("Error getting DB"),
|
|
||||||
};
|
|
||||||
|
|
||||||
let user = headers.user;
|
let user = headers.user;
|
||||||
let org_user = match UserOrganization::find_by_user_and_org(&user.uuid, &org_id, &conn) {
|
let org_user = match UserOrganization::find_by_user_and_org(&user.uuid, &org_id, &conn) {
|
||||||
Some(user) => {
|
Some(user) => {
|
||||||
if user.status == UserOrgStatus::Confirmed as i32 {
|
if user.status == UserOrgStatus::Confirmed as i32 {
|
||||||
user
|
user
|
||||||
} else {
|
} else {
|
||||||
err_handler!("The current user isn't confirmed member of the organization")
|
err_handler!("The current user isn't confirmed member of the organization")
|
||||||
}
|
}
|
||||||
}
|
|
||||||
None => err_handler!("The current user isn't member of the organization"),
|
|
||||||
};
|
|
||||||
|
|
||||||
Outcome::Success(Self {
|
|
||||||
host: headers.host,
|
|
||||||
device: headers.device,
|
|
||||||
user,
|
|
||||||
org_user_type: {
|
|
||||||
if let Some(org_usr_type) = UserOrgType::from_i32(org_user.atype) {
|
|
||||||
org_usr_type
|
|
||||||
} else {
|
|
||||||
// This should only happen if the DB is corrupted
|
|
||||||
err_handler!("Unknown user type in the database")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
org_user,
|
|
||||||
org_id,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
_ => err_handler!("Error getting the organization id"),
|
None => err_handler!("The current user isn't member of the organization"),
|
||||||
}
|
};
|
||||||
|
|
||||||
|
Outcome::Success(Self {
|
||||||
|
host: headers.host,
|
||||||
|
device: headers.device,
|
||||||
|
user,
|
||||||
|
org_user_type: {
|
||||||
|
if let Some(org_usr_type) = UserOrgType::from_i32(org_user.atype) {
|
||||||
|
org_usr_type
|
||||||
|
} else {
|
||||||
|
// This should only happen if the DB is corrupted
|
||||||
|
err_handler!("Unknown user type in the database")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
org_user,
|
||||||
|
org_id,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
_ => err_handler!("Error getting the organization id"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -483,25 +480,21 @@ pub struct AdminHeaders {
|
|||||||
pub org_user_type: UserOrgType,
|
pub org_user_type: UserOrgType,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'r> FromRequest<'a, 'r> for AdminHeaders {
|
#[rocket::async_trait]
|
||||||
|
impl<'r> FromRequest<'r> for AdminHeaders {
|
||||||
type Error = &'static str;
|
type Error = &'static str;
|
||||||
|
|
||||||
fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
|
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||||
match request.guard::<OrgHeaders>() {
|
let headers = try_outcome!(OrgHeaders::from_request(request).await);
|
||||||
Outcome::Forward(_) => Outcome::Forward(()),
|
if headers.org_user_type >= UserOrgType::Admin {
|
||||||
Outcome::Failure(f) => Outcome::Failure(f),
|
Outcome::Success(Self {
|
||||||
Outcome::Success(headers) => {
|
host: headers.host,
|
||||||
if headers.org_user_type >= UserOrgType::Admin {
|
device: headers.device,
|
||||||
Outcome::Success(Self {
|
user: headers.user,
|
||||||
host: headers.host,
|
org_user_type: headers.org_user_type,
|
||||||
device: headers.device,
|
})
|
||||||
user: headers.user,
|
} else {
|
||||||
org_user_type: headers.org_user_type,
|
err_handler!("You need to be Admin or Owner to call this endpoint")
|
||||||
})
|
|
||||||
} else {
|
|
||||||
err_handler!("You need to be Admin or Owner to call this endpoint")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -520,13 +513,13 @@ impl From<AdminHeaders> for Headers {
|
|||||||
// but there could be cases where it is a query value.
|
// but there could be cases where it is a query value.
|
||||||
// First check the path, if this is not a valid uuid, try the query values.
|
// First check the path, if this is not a valid uuid, try the query values.
|
||||||
fn get_col_id(request: &Request) -> Option<String> {
|
fn get_col_id(request: &Request) -> Option<String> {
|
||||||
if let Some(Ok(col_id)) = request.get_param::<String>(3) {
|
if let Some(Ok(col_id)) = request.param::<String>(3) {
|
||||||
if uuid::Uuid::parse_str(&col_id).is_ok() {
|
if uuid::Uuid::parse_str(&col_id).is_ok() {
|
||||||
return Some(col_id);
|
return Some(col_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(Ok(col_id)) = request.get_query_value::<String>("collectionId") {
|
if let Some(Ok(col_id)) = request.query_value::<String>("collectionId") {
|
||||||
if uuid::Uuid::parse_str(&col_id).is_ok() {
|
if uuid::Uuid::parse_str(&col_id).is_ok() {
|
||||||
return Some(col_id);
|
return Some(col_id);
|
||||||
}
|
}
|
||||||
@ -545,46 +538,38 @@ pub struct ManagerHeaders {
|
|||||||
pub org_user_type: UserOrgType,
|
pub org_user_type: UserOrgType,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'r> FromRequest<'a, 'r> for ManagerHeaders {
|
#[rocket::async_trait]
|
||||||
|
impl<'r> FromRequest<'r> for ManagerHeaders {
|
||||||
type Error = &'static str;
|
type Error = &'static str;
|
||||||
|
|
||||||
fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
|
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||||
match request.guard::<OrgHeaders>() {
|
let headers = try_outcome!(OrgHeaders::from_request(request).await);
|
||||||
Outcome::Forward(_) => Outcome::Forward(()),
|
if headers.org_user_type >= UserOrgType::Manager {
|
||||||
Outcome::Failure(f) => Outcome::Failure(f),
|
match get_col_id(request) {
|
||||||
Outcome::Success(headers) => {
|
Some(col_id) => {
|
||||||
if headers.org_user_type >= UserOrgType::Manager {
|
let conn = match DbConn::from_request(request).await {
|
||||||
match get_col_id(request) {
|
Outcome::Success(conn) => conn,
|
||||||
Some(col_id) => {
|
_ => err_handler!("Error getting DB"),
|
||||||
let conn = match request.guard::<DbConn>() {
|
};
|
||||||
Outcome::Success(conn) => conn,
|
|
||||||
_ => err_handler!("Error getting DB"),
|
|
||||||
};
|
|
||||||
|
|
||||||
if !headers.org_user.has_full_access() {
|
if !headers.org_user.has_full_access() {
|
||||||
match CollectionUser::find_by_collection_and_user(
|
match CollectionUser::find_by_collection_and_user(&col_id, &headers.org_user.user_uuid, &conn) {
|
||||||
&col_id,
|
Some(_) => (),
|
||||||
&headers.org_user.user_uuid,
|
None => err_handler!("The current user isn't a manager for this collection"),
|
||||||
&conn,
|
|
||||||
) {
|
|
||||||
Some(_) => (),
|
|
||||||
None => err_handler!("The current user isn't a manager for this collection"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
_ => err_handler!("Error getting the collection id"),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Outcome::Success(Self {
|
|
||||||
host: headers.host,
|
|
||||||
device: headers.device,
|
|
||||||
user: headers.user,
|
|
||||||
org_user_type: headers.org_user_type,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
err_handler!("You need to be a Manager, Admin or Owner to call this endpoint")
|
|
||||||
}
|
}
|
||||||
|
_ => err_handler!("Error getting the collection id"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Outcome::Success(Self {
|
||||||
|
host: headers.host,
|
||||||
|
device: headers.device,
|
||||||
|
user: headers.user,
|
||||||
|
org_user_type: headers.org_user_type,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
err_handler!("You need to be a Manager, Admin or Owner to call this endpoint")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -608,25 +593,21 @@ pub struct ManagerHeadersLoose {
|
|||||||
pub org_user_type: UserOrgType,
|
pub org_user_type: UserOrgType,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'r> FromRequest<'a, 'r> for ManagerHeadersLoose {
|
#[rocket::async_trait]
|
||||||
|
impl<'r> FromRequest<'r> for ManagerHeadersLoose {
|
||||||
type Error = &'static str;
|
type Error = &'static str;
|
||||||
|
|
||||||
fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
|
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||||
match request.guard::<OrgHeaders>() {
|
let headers = try_outcome!(OrgHeaders::from_request(request).await);
|
||||||
Outcome::Forward(_) => Outcome::Forward(()),
|
if headers.org_user_type >= UserOrgType::Manager {
|
||||||
Outcome::Failure(f) => Outcome::Failure(f),
|
Outcome::Success(Self {
|
||||||
Outcome::Success(headers) => {
|
host: headers.host,
|
||||||
if headers.org_user_type >= UserOrgType::Manager {
|
device: headers.device,
|
||||||
Outcome::Success(Self {
|
user: headers.user,
|
||||||
host: headers.host,
|
org_user_type: headers.org_user_type,
|
||||||
device: headers.device,
|
})
|
||||||
user: headers.user,
|
} else {
|
||||||
org_user_type: headers.org_user_type,
|
err_handler!("You need to be a Manager, Admin or Owner to call this endpoint")
|
||||||
})
|
|
||||||
} else {
|
|
||||||
err_handler!("You need to be a Manager, Admin or Owner to call this endpoint")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -647,24 +628,20 @@ pub struct OwnerHeaders {
|
|||||||
pub user: User,
|
pub user: User,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'r> FromRequest<'a, 'r> for OwnerHeaders {
|
#[rocket::async_trait]
|
||||||
|
impl<'r> FromRequest<'r> for OwnerHeaders {
|
||||||
type Error = &'static str;
|
type Error = &'static str;
|
||||||
|
|
||||||
fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
|
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||||
match request.guard::<OrgHeaders>() {
|
let headers = try_outcome!(OrgHeaders::from_request(request).await);
|
||||||
Outcome::Forward(_) => Outcome::Forward(()),
|
if headers.org_user_type == UserOrgType::Owner {
|
||||||
Outcome::Failure(f) => Outcome::Failure(f),
|
Outcome::Success(Self {
|
||||||
Outcome::Success(headers) => {
|
host: headers.host,
|
||||||
if headers.org_user_type == UserOrgType::Owner {
|
device: headers.device,
|
||||||
Outcome::Success(Self {
|
user: headers.user,
|
||||||
host: headers.host,
|
})
|
||||||
device: headers.device,
|
} else {
|
||||||
user: headers.user,
|
err_handler!("You need to be Owner to call this endpoint")
|
||||||
})
|
|
||||||
} else {
|
|
||||||
err_handler!("You need to be Owner to call this endpoint")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -678,10 +655,11 @@ pub struct ClientIp {
|
|||||||
pub ip: IpAddr,
|
pub ip: IpAddr,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'r> FromRequest<'a, 'r> for ClientIp {
|
#[rocket::async_trait]
|
||||||
|
impl<'r> FromRequest<'r> for ClientIp {
|
||||||
type Error = ();
|
type Error = ();
|
||||||
|
|
||||||
fn from_request(req: &'a Request<'r>) -> Outcome<Self, Self::Error> {
|
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||||
let ip = if CONFIG._ip_header_enabled() {
|
let ip = if CONFIG._ip_header_enabled() {
|
||||||
req.headers().get_one(&CONFIG.ip_header()).and_then(|ip| {
|
req.headers().get_one(&CONFIG.ip_header()).and_then(|ip| {
|
||||||
match ip.find(',') {
|
match ip.find(',') {
|
||||||
|
@ -36,6 +36,9 @@ macro_rules! make_config {
|
|||||||
pub struct Config { inner: RwLock<Inner> }
|
pub struct Config { inner: RwLock<Inner> }
|
||||||
|
|
||||||
struct Inner {
|
struct Inner {
|
||||||
|
rocket_shutdown_handle: Option<rocket::Shutdown>,
|
||||||
|
ws_shutdown_handle: Option<ws::Sender>,
|
||||||
|
|
||||||
templates: Handlebars<'static>,
|
templates: Handlebars<'static>,
|
||||||
config: ConfigItems,
|
config: ConfigItems,
|
||||||
|
|
||||||
@ -332,6 +335,8 @@ make_config! {
|
|||||||
attachments_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "attachments");
|
attachments_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "attachments");
|
||||||
/// Sends folder
|
/// Sends folder
|
||||||
sends_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "sends");
|
sends_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "sends");
|
||||||
|
/// Temp folder |> Used for storing temporary file uploads
|
||||||
|
tmp_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "tmp");
|
||||||
/// Templates folder
|
/// Templates folder
|
||||||
templates_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "templates");
|
templates_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "templates");
|
||||||
/// Session JWT key
|
/// Session JWT key
|
||||||
@ -509,6 +514,9 @@ make_config! {
|
|||||||
/// Max database connection retries |> Number of times to retry the database connection during startup, with 1 second between each retry, set to 0 to retry indefinitely
|
/// Max database connection retries |> Number of times to retry the database connection during startup, with 1 second between each retry, set to 0 to retry indefinitely
|
||||||
db_connection_retries: u32, false, def, 15;
|
db_connection_retries: u32, false, def, 15;
|
||||||
|
|
||||||
|
/// Timeout when aquiring database connection
|
||||||
|
database_timeout: u64, false, def, 30;
|
||||||
|
|
||||||
/// Database connection pool size
|
/// Database connection pool size
|
||||||
database_max_conns: u32, false, def, 10;
|
database_max_conns: u32, false, def, 10;
|
||||||
|
|
||||||
@ -743,6 +751,8 @@ impl Config {
|
|||||||
|
|
||||||
Ok(Config {
|
Ok(Config {
|
||||||
inner: RwLock::new(Inner {
|
inner: RwLock::new(Inner {
|
||||||
|
rocket_shutdown_handle: None,
|
||||||
|
ws_shutdown_handle: None,
|
||||||
templates: load_templates(&config.templates_folder),
|
templates: load_templates(&config.templates_folder),
|
||||||
config,
|
config,
|
||||||
_env,
|
_env,
|
||||||
@ -907,6 +917,27 @@ impl Config {
|
|||||||
hb.render(name, data).map_err(Into::into)
|
hb.render(name, data).map_err(Into::into)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_rocket_shutdown_handle(&self, handle: rocket::Shutdown) {
|
||||||
|
self.inner.write().unwrap().rocket_shutdown_handle = Some(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_ws_shutdown_handle(&self, handle: ws::Sender) {
|
||||||
|
self.inner.write().unwrap().ws_shutdown_handle = Some(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn shutdown(&self) {
|
||||||
|
if let Ok(c) = self.inner.read() {
|
||||||
|
if let Some(handle) = c.ws_shutdown_handle.clone() {
|
||||||
|
handle.shutdown().ok();
|
||||||
|
}
|
||||||
|
// Wait a bit before stopping the web server
|
||||||
|
std::thread::sleep(std::time::Duration::from_secs(1));
|
||||||
|
if let Some(handle) = c.rocket_shutdown_handle.clone() {
|
||||||
|
handle.notify();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
use handlebars::{Context, Handlebars, Helper, HelperResult, Output, RenderContext, RenderError, Renderable};
|
use handlebars::{Context, Handlebars, Helper, HelperResult, Output, RenderContext, RenderError, Renderable};
|
||||||
|
231
src/db/mod.rs
231
src/db/mod.rs
@ -1,8 +1,16 @@
|
|||||||
|
use std::{sync::Arc, time::Duration};
|
||||||
|
|
||||||
use diesel::r2d2::{ConnectionManager, Pool, PooledConnection};
|
use diesel::r2d2::{ConnectionManager, Pool, PooledConnection};
|
||||||
use rocket::{
|
use rocket::{
|
||||||
http::Status,
|
http::Status,
|
||||||
|
outcome::IntoOutcome,
|
||||||
request::{FromRequest, Outcome},
|
request::{FromRequest, Outcome},
|
||||||
Request, State,
|
Request,
|
||||||
|
};
|
||||||
|
|
||||||
|
use tokio::{
|
||||||
|
sync::{Mutex, OwnedSemaphorePermit, Semaphore},
|
||||||
|
time::timeout,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
@ -22,6 +30,23 @@ pub mod __mysql_schema;
|
|||||||
#[path = "schemas/postgresql/schema.rs"]
|
#[path = "schemas/postgresql/schema.rs"]
|
||||||
pub mod __postgresql_schema;
|
pub mod __postgresql_schema;
|
||||||
|
|
||||||
|
// There changes are based on Rocket 0.5-rc wrapper of Diesel: https://github.com/SergioBenitez/Rocket/blob/v0.5-rc/contrib/sync_db_pools
|
||||||
|
|
||||||
|
// A wrapper around spawn_blocking that propagates panics to the calling code.
|
||||||
|
pub async fn run_blocking<F, R>(job: F) -> R
|
||||||
|
where
|
||||||
|
F: FnOnce() -> R + Send + 'static,
|
||||||
|
R: Send + 'static,
|
||||||
|
{
|
||||||
|
match tokio::task::spawn_blocking(job).await {
|
||||||
|
Ok(ret) => ret,
|
||||||
|
Err(e) => match e.try_into_panic() {
|
||||||
|
Ok(panic) => std::panic::resume_unwind(panic),
|
||||||
|
Err(_) => unreachable!("spawn_blocking tasks are never cancelled"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// This is used to generate the main DbConn and DbPool enums, which contain one variant for each database supported
|
// This is used to generate the main DbConn and DbPool enums, which contain one variant for each database supported
|
||||||
macro_rules! generate_connections {
|
macro_rules! generate_connections {
|
||||||
( $( $name:ident: $ty:ty ),+ ) => {
|
( $( $name:ident: $ty:ty ),+ ) => {
|
||||||
@ -29,12 +54,53 @@ macro_rules! generate_connections {
|
|||||||
#[derive(Eq, PartialEq)]
|
#[derive(Eq, PartialEq)]
|
||||||
pub enum DbConnType { $( $name, )+ }
|
pub enum DbConnType { $( $name, )+ }
|
||||||
|
|
||||||
|
pub struct DbConn {
|
||||||
|
conn: Arc<Mutex<Option<DbConnInner>>>,
|
||||||
|
permit: Option<OwnedSemaphorePermit>,
|
||||||
|
}
|
||||||
|
|
||||||
#[allow(non_camel_case_types)]
|
#[allow(non_camel_case_types)]
|
||||||
pub enum DbConn { $( #[cfg($name)] $name(PooledConnection<ConnectionManager< $ty >>), )+ }
|
pub enum DbConnInner { $( #[cfg($name)] $name(PooledConnection<ConnectionManager< $ty >>), )+ }
|
||||||
|
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct DbPool {
|
||||||
|
// This is an 'Option' so that we can drop the pool in a 'spawn_blocking'.
|
||||||
|
pool: Option<DbPoolInner>,
|
||||||
|
semaphore: Arc<Semaphore>
|
||||||
|
}
|
||||||
|
|
||||||
#[allow(non_camel_case_types)]
|
#[allow(non_camel_case_types)]
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub enum DbPool { $( #[cfg($name)] $name(Pool<ConnectionManager< $ty >>), )+ }
|
pub enum DbPoolInner { $( #[cfg($name)] $name(Pool<ConnectionManager< $ty >>), )+ }
|
||||||
|
|
||||||
|
impl Drop for DbConn {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
let conn = self.conn.clone();
|
||||||
|
let permit = self.permit.take();
|
||||||
|
|
||||||
|
// Since connection can't be on the stack in an async fn during an
|
||||||
|
// await, we have to spawn a new blocking-safe thread...
|
||||||
|
tokio::task::spawn_blocking(move || {
|
||||||
|
// And then re-enter the runtime to wait on the async mutex, but in a blocking fashion.
|
||||||
|
let mut conn = tokio::runtime::Handle::current().block_on(conn.lock_owned());
|
||||||
|
|
||||||
|
if let Some(conn) = conn.take() {
|
||||||
|
drop(conn);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drop permit after the connection is dropped
|
||||||
|
drop(permit);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for DbPool {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
let pool = self.pool.take();
|
||||||
|
tokio::task::spawn_blocking(move || drop(pool));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl DbPool {
|
impl DbPool {
|
||||||
// For the given database URL, guess it's type, run migrations create pool and return it
|
// For the given database URL, guess it's type, run migrations create pool and return it
|
||||||
@ -50,9 +116,13 @@ macro_rules! generate_connections {
|
|||||||
let manager = ConnectionManager::new(&url);
|
let manager = ConnectionManager::new(&url);
|
||||||
let pool = Pool::builder()
|
let pool = Pool::builder()
|
||||||
.max_size(CONFIG.database_max_conns())
|
.max_size(CONFIG.database_max_conns())
|
||||||
|
.connection_timeout(Duration::from_secs(CONFIG.database_timeout()))
|
||||||
.build(manager)
|
.build(manager)
|
||||||
.map_res("Failed to create pool")?;
|
.map_res("Failed to create pool")?;
|
||||||
return Ok(Self::$name(pool));
|
return Ok(DbPool {
|
||||||
|
pool: Some(DbPoolInner::$name(pool)),
|
||||||
|
semaphore: Arc::new(Semaphore::new(CONFIG.database_max_conns() as usize)),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
#[cfg(not($name))]
|
#[cfg(not($name))]
|
||||||
#[allow(unreachable_code)]
|
#[allow(unreachable_code)]
|
||||||
@ -61,10 +131,26 @@ macro_rules! generate_connections {
|
|||||||
)+ }
|
)+ }
|
||||||
}
|
}
|
||||||
// Get a connection from the pool
|
// Get a connection from the pool
|
||||||
pub fn get(&self) -> Result<DbConn, Error> {
|
pub async fn get(&self) -> Result<DbConn, Error> {
|
||||||
match self { $(
|
let duration = Duration::from_secs(CONFIG.database_timeout());
|
||||||
|
let permit = match timeout(duration, self.semaphore.clone().acquire_owned()).await {
|
||||||
|
Ok(p) => p.expect("Semaphore should be open"),
|
||||||
|
Err(_) => {
|
||||||
|
err!("Timeout waiting for database connection");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match self.pool.as_ref().expect("DbPool.pool should always be Some()") { $(
|
||||||
#[cfg($name)]
|
#[cfg($name)]
|
||||||
Self::$name(p) => Ok(DbConn::$name(p.get().map_res("Error retrieving connection from pool")?)),
|
DbPoolInner::$name(p) => {
|
||||||
|
let pool = p.clone();
|
||||||
|
let c = run_blocking(move || pool.get_timeout(duration)).await.map_res("Error retrieving connection from pool")?;
|
||||||
|
|
||||||
|
return Ok(DbConn {
|
||||||
|
conn: Arc::new(Mutex::new(Some(DbConnInner::$name(c)))),
|
||||||
|
permit: Some(permit)
|
||||||
|
});
|
||||||
|
},
|
||||||
)+ }
|
)+ }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -113,42 +199,95 @@ macro_rules! db_run {
|
|||||||
db_run! { $conn: sqlite, mysql, postgresql $body }
|
db_run! { $conn: sqlite, mysql, postgresql $body }
|
||||||
};
|
};
|
||||||
|
|
||||||
// Different code for each db
|
|
||||||
( $conn:ident: $( $($db:ident),+ $body:block )+ ) => {{
|
|
||||||
#[allow(unused)] use diesel::prelude::*;
|
|
||||||
match $conn {
|
|
||||||
$($(
|
|
||||||
#[cfg($db)]
|
|
||||||
crate::db::DbConn::$db(ref $conn) => {
|
|
||||||
paste::paste! {
|
|
||||||
#[allow(unused)] use crate::db::[<__ $db _schema>]::{self as schema, *};
|
|
||||||
#[allow(unused)] use [<__ $db _model>]::*;
|
|
||||||
#[allow(unused)] use crate::db::FromDb;
|
|
||||||
}
|
|
||||||
$body
|
|
||||||
},
|
|
||||||
)+)+
|
|
||||||
}}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Same for all dbs
|
|
||||||
( @raw $conn:ident: $body:block ) => {
|
( @raw $conn:ident: $body:block ) => {
|
||||||
db_run! { @raw $conn: sqlite, mysql, postgresql $body }
|
db_run! { @raw $conn: sqlite, mysql, postgresql $body }
|
||||||
};
|
};
|
||||||
|
|
||||||
// Different code for each db
|
// Different code for each db
|
||||||
( @raw $conn:ident: $( $($db:ident),+ $body:block )+ ) => {
|
( $conn:ident: $( $($db:ident),+ $body:block )+ ) => {{
|
||||||
#[allow(unused)] use diesel::prelude::*;
|
#[allow(unused)] use diesel::prelude::*;
|
||||||
#[allow(unused_variables)]
|
|
||||||
match $conn {
|
// It is important that this inner Arc<Mutex<>> (or the OwnedMutexGuard
|
||||||
$($(
|
// derived from it) never be a variable on the stack at an await point,
|
||||||
#[cfg($db)]
|
// where Drop might be called at any time. This causes (synchronous)
|
||||||
crate::db::DbConn::$db(ref $conn) => {
|
// Drop to be called from asynchronous code, which some database
|
||||||
$body
|
// wrappers do not or can not handle.
|
||||||
},
|
let conn = $conn.conn.clone();
|
||||||
)+)+
|
|
||||||
}
|
// Since connection can't be on the stack in an async fn during an
|
||||||
};
|
// await, we have to spawn a new blocking-safe thread...
|
||||||
|
/*
|
||||||
|
run_blocking(move || {
|
||||||
|
// And then re-enter the runtime to wait on the async mutex, but in
|
||||||
|
// a blocking fashion.
|
||||||
|
let mut conn = tokio::runtime::Handle::current().block_on(conn.lock_owned());
|
||||||
|
let conn = conn.as_mut().expect("internal invariant broken: self.connection is Some");
|
||||||
|
*/
|
||||||
|
let mut __conn_mutex = conn.try_lock_owned().unwrap();
|
||||||
|
let conn = __conn_mutex.as_mut().unwrap();
|
||||||
|
match conn {
|
||||||
|
$($(
|
||||||
|
#[cfg($db)]
|
||||||
|
crate::db::DbConnInner::$db($conn) => {
|
||||||
|
paste::paste! {
|
||||||
|
#[allow(unused)] use crate::db::[<__ $db _schema>]::{self as schema, *};
|
||||||
|
#[allow(unused)] use [<__ $db _model>]::*;
|
||||||
|
#[allow(unused)] use crate::db::FromDb;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
// Since connection can't be on the stack in an async fn during an
|
||||||
|
// await, we have to spawn a new blocking-safe thread...
|
||||||
|
run_blocking(move || {
|
||||||
|
// And then re-enter the runtime to wait on the async mutex, but in
|
||||||
|
// a blocking fashion.
|
||||||
|
let mut conn = tokio::runtime::Handle::current().block_on(async {
|
||||||
|
conn.lock_owned().await
|
||||||
|
});
|
||||||
|
|
||||||
|
let conn = conn.as_mut().expect("internal invariant broken: self.connection is Some");
|
||||||
|
f(conn)
|
||||||
|
}).await;*/
|
||||||
|
|
||||||
|
$body
|
||||||
|
},
|
||||||
|
)+)+
|
||||||
|
}
|
||||||
|
// }).await
|
||||||
|
}};
|
||||||
|
|
||||||
|
( @raw $conn:ident: $( $($db:ident),+ $body:block )+ ) => {{
|
||||||
|
#[allow(unused)] use diesel::prelude::*;
|
||||||
|
|
||||||
|
// It is important that this inner Arc<Mutex<>> (or the OwnedMutexGuard
|
||||||
|
// derived from it) never be a variable on the stack at an await point,
|
||||||
|
// where Drop might be called at any time. This causes (synchronous)
|
||||||
|
// Drop to be called from asynchronous code, which some database
|
||||||
|
// wrappers do not or can not handle.
|
||||||
|
let conn = $conn.conn.clone();
|
||||||
|
|
||||||
|
// Since connection can't be on the stack in an async fn during an
|
||||||
|
// await, we have to spawn a new blocking-safe thread...
|
||||||
|
run_blocking(move || {
|
||||||
|
// And then re-enter the runtime to wait on the async mutex, but in
|
||||||
|
// a blocking fashion.
|
||||||
|
let mut conn = tokio::runtime::Handle::current().block_on(conn.lock_owned());
|
||||||
|
match conn.as_mut().expect("internal invariant broken: self.connection is Some") {
|
||||||
|
$($(
|
||||||
|
#[cfg($db)]
|
||||||
|
crate::db::DbConnInner::$db($conn) => {
|
||||||
|
paste::paste! {
|
||||||
|
#[allow(unused)] use crate::db::[<__ $db _schema>]::{self as schema, *};
|
||||||
|
// @RAW: #[allow(unused)] use [<__ $db _model>]::*;
|
||||||
|
#[allow(unused)] use crate::db::FromDb;
|
||||||
|
}
|
||||||
|
|
||||||
|
$body
|
||||||
|
},
|
||||||
|
)+)+
|
||||||
|
}
|
||||||
|
}).await
|
||||||
|
}};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait FromDb {
|
pub trait FromDb {
|
||||||
@ -227,9 +366,10 @@ pub mod models;
|
|||||||
|
|
||||||
/// Creates a back-up of the sqlite database
|
/// Creates a back-up of the sqlite database
|
||||||
/// MySQL/MariaDB and PostgreSQL are not supported.
|
/// MySQL/MariaDB and PostgreSQL are not supported.
|
||||||
pub fn backup_database(conn: &DbConn) -> Result<(), Error> {
|
pub async fn backup_database(conn: &DbConn) -> Result<(), Error> {
|
||||||
db_run! {@raw conn:
|
db_run! {@raw conn:
|
||||||
postgresql, mysql {
|
postgresql, mysql {
|
||||||
|
let _ = conn;
|
||||||
err!("PostgreSQL and MySQL/MariaDB do not support this backup feature");
|
err!("PostgreSQL and MySQL/MariaDB do not support this backup feature");
|
||||||
}
|
}
|
||||||
sqlite {
|
sqlite {
|
||||||
@ -244,7 +384,7 @@ pub fn backup_database(conn: &DbConn) -> Result<(), Error> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Get the SQL Server version
|
/// Get the SQL Server version
|
||||||
pub fn get_sql_server_version(conn: &DbConn) -> String {
|
pub async fn get_sql_server_version(conn: &DbConn) -> String {
|
||||||
db_run! {@raw conn:
|
db_run! {@raw conn:
|
||||||
postgresql, mysql {
|
postgresql, mysql {
|
||||||
no_arg_sql_function!(version, diesel::sql_types::Text);
|
no_arg_sql_function!(version, diesel::sql_types::Text);
|
||||||
@ -260,15 +400,14 @@ pub fn get_sql_server_version(conn: &DbConn) -> String {
|
|||||||
/// Attempts to retrieve a single connection from the managed database pool. If
|
/// Attempts to retrieve a single connection from the managed database pool. If
|
||||||
/// no pool is currently managed, fails with an `InternalServerError` status. If
|
/// no pool is currently managed, fails with an `InternalServerError` status. If
|
||||||
/// no connections are available, fails with a `ServiceUnavailable` status.
|
/// no connections are available, fails with a `ServiceUnavailable` status.
|
||||||
impl<'a, 'r> FromRequest<'a, 'r> for DbConn {
|
#[rocket::async_trait]
|
||||||
|
impl<'r> FromRequest<'r> for DbConn {
|
||||||
type Error = ();
|
type Error = ();
|
||||||
|
|
||||||
fn from_request(request: &'a Request<'r>) -> Outcome<DbConn, ()> {
|
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||||
// https://github.com/SergioBenitez/Rocket/commit/e3c1a4ad3ab9b840482ec6de4200d30df43e357c
|
match request.rocket().state::<DbPool>() {
|
||||||
let pool = try_outcome!(request.guard::<State<DbPool>>());
|
Some(p) => p.get().await.map_err(|_| ()).into_outcome(Status::ServiceUnavailable),
|
||||||
match pool.get() {
|
None => Outcome::Failure((Status::InternalServerError, ())),
|
||||||
Ok(conn) => Outcome::Success(conn),
|
|
||||||
Err(_) => Outcome::Failure((Status::ServiceUnavailable, ())),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
10
src/error.rs
10
src/error.rs
@ -45,6 +45,7 @@ use lettre::transport::smtp::Error as SmtpErr;
|
|||||||
use openssl::error::ErrorStack as SSLErr;
|
use openssl::error::ErrorStack as SSLErr;
|
||||||
use regex::Error as RegexErr;
|
use regex::Error as RegexErr;
|
||||||
use reqwest::Error as ReqErr;
|
use reqwest::Error as ReqErr;
|
||||||
|
use rocket::error::Error as RocketErr;
|
||||||
use serde_json::{Error as SerdeErr, Value};
|
use serde_json::{Error as SerdeErr, Value};
|
||||||
use std::io::Error as IoErr;
|
use std::io::Error as IoErr;
|
||||||
use std::time::SystemTimeError as TimeErr;
|
use std::time::SystemTimeError as TimeErr;
|
||||||
@ -84,6 +85,7 @@ make_error! {
|
|||||||
Address(AddrErr): _has_source, _api_error,
|
Address(AddrErr): _has_source, _api_error,
|
||||||
Smtp(SmtpErr): _has_source, _api_error,
|
Smtp(SmtpErr): _has_source, _api_error,
|
||||||
OpenSSL(SSLErr): _has_source, _api_error,
|
OpenSSL(SSLErr): _has_source, _api_error,
|
||||||
|
Rocket(RocketErr): _has_source, _api_error,
|
||||||
|
|
||||||
DieselCon(DieselConErr): _has_source, _api_error,
|
DieselCon(DieselConErr): _has_source, _api_error,
|
||||||
DieselMig(DieselMigErr): _has_source, _api_error,
|
DieselMig(DieselMigErr): _has_source, _api_error,
|
||||||
@ -193,8 +195,8 @@ use rocket::http::{ContentType, Status};
|
|||||||
use rocket::request::Request;
|
use rocket::request::Request;
|
||||||
use rocket::response::{self, Responder, Response};
|
use rocket::response::{self, Responder, Response};
|
||||||
|
|
||||||
impl<'r> Responder<'r> for Error {
|
impl<'r> Responder<'r, 'static> for Error {
|
||||||
fn respond_to(self, _: &Request) -> response::Result<'r> {
|
fn respond_to(self, _: &Request) -> response::Result<'static> {
|
||||||
match self.error {
|
match self.error {
|
||||||
ErrorKind::Empty(_) => {} // Don't print the error in this situation
|
ErrorKind::Empty(_) => {} // Don't print the error in this situation
|
||||||
ErrorKind::Simple(_) => {} // Don't print the error in this situation
|
ErrorKind::Simple(_) => {} // Don't print the error in this situation
|
||||||
@ -202,8 +204,8 @@ impl<'r> Responder<'r> for Error {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let code = Status::from_code(self.error_code).unwrap_or(Status::BadRequest);
|
let code = Status::from_code(self.error_code).unwrap_or(Status::BadRequest);
|
||||||
|
let body = self.to_string();
|
||||||
Response::build().status(code).header(ContentType::JSON).sized_body(Cursor::new(format!("{}", self))).ok()
|
Response::build().status(code).header(ContentType::JSON).sized_body(Some(body.len()), Cursor::new(body)).ok()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
98
src/main.rs
98
src/main.rs
@ -20,8 +20,15 @@ extern crate diesel;
|
|||||||
#[macro_use]
|
#[macro_use]
|
||||||
extern crate diesel_migrations;
|
extern crate diesel_migrations;
|
||||||
|
|
||||||
use job_scheduler::{Job, JobScheduler};
|
use std::{
|
||||||
use std::{fs::create_dir_all, panic, path::Path, process::exit, str::FromStr, thread, time::Duration};
|
fs::{canonicalize, create_dir_all},
|
||||||
|
panic,
|
||||||
|
path::Path,
|
||||||
|
process::exit,
|
||||||
|
str::FromStr,
|
||||||
|
thread,
|
||||||
|
time::Duration,
|
||||||
|
};
|
||||||
|
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
mod error;
|
mod error;
|
||||||
@ -37,9 +44,11 @@ mod util;
|
|||||||
|
|
||||||
pub use config::CONFIG;
|
pub use config::CONFIG;
|
||||||
pub use error::{Error, MapResult};
|
pub use error::{Error, MapResult};
|
||||||
|
use rocket::data::{Limits, ToByteUnit};
|
||||||
pub use util::is_running_in_docker;
|
pub use util::is_running_in_docker;
|
||||||
|
|
||||||
fn main() {
|
#[rocket::main]
|
||||||
|
async fn main() -> Result<(), Error> {
|
||||||
parse_args();
|
parse_args();
|
||||||
launch_info();
|
launch_info();
|
||||||
|
|
||||||
@ -56,13 +65,16 @@ fn main() {
|
|||||||
});
|
});
|
||||||
check_web_vault();
|
check_web_vault();
|
||||||
|
|
||||||
create_icon_cache_folder();
|
create_dir(&CONFIG.icon_cache_folder(), "icon cache");
|
||||||
|
create_dir(&CONFIG.tmp_folder(), "tmp folder");
|
||||||
|
create_dir(&CONFIG.sends_folder(), "sends folder");
|
||||||
|
create_dir(&CONFIG.attachments_folder(), "attachments folder");
|
||||||
|
|
||||||
let pool = create_db_pool();
|
let pool = create_db_pool();
|
||||||
schedule_jobs(pool.clone());
|
schedule_jobs(pool.clone()).await;
|
||||||
crate::db::models::TwoFactor::migrate_u2f_to_webauthn(&pool.get().unwrap()).unwrap();
|
crate::db::models::TwoFactor::migrate_u2f_to_webauthn(&pool.get().await.unwrap()).unwrap();
|
||||||
|
|
||||||
launch_rocket(pool, extra_debug); // Blocks until program termination.
|
launch_rocket(pool, extra_debug).await // Blocks until program termination.
|
||||||
}
|
}
|
||||||
|
|
||||||
const HELP: &str = "\
|
const HELP: &str = "\
|
||||||
@ -127,10 +139,12 @@ fn init_logging(level: log::LevelFilter) -> Result<(), fern::InitError> {
|
|||||||
.level_for("hyper::server", log::LevelFilter::Warn)
|
.level_for("hyper::server", log::LevelFilter::Warn)
|
||||||
// Silence rocket logs
|
// Silence rocket logs
|
||||||
.level_for("_", log::LevelFilter::Off)
|
.level_for("_", log::LevelFilter::Off)
|
||||||
.level_for("launch", log::LevelFilter::Off)
|
.level_for("rocket::launch", log::LevelFilter::Error)
|
||||||
.level_for("launch_", log::LevelFilter::Off)
|
.level_for("rocket::launch_", log::LevelFilter::Error)
|
||||||
.level_for("rocket::rocket", log::LevelFilter::Off)
|
.level_for("rocket::rocket", log::LevelFilter::Warn)
|
||||||
.level_for("rocket::fairing", log::LevelFilter::Off)
|
.level_for("rocket::server", log::LevelFilter::Warn)
|
||||||
|
.level_for("rocket::fairing::fairings", log::LevelFilter::Warn)
|
||||||
|
.level_for("rocket::shield::shield", log::LevelFilter::Warn)
|
||||||
// Never show html5ever and hyper::proto logs, too noisy
|
// Never show html5ever and hyper::proto logs, too noisy
|
||||||
.level_for("html5ever", log::LevelFilter::Off)
|
.level_for("html5ever", log::LevelFilter::Off)
|
||||||
.level_for("hyper::proto", log::LevelFilter::Off)
|
.level_for("hyper::proto", log::LevelFilter::Off)
|
||||||
@ -243,10 +257,6 @@ fn create_dir(path: &str, description: &str) {
|
|||||||
create_dir_all(path).expect(&err_msg);
|
create_dir_all(path).expect(&err_msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_icon_cache_folder() {
|
|
||||||
create_dir(&CONFIG.icon_cache_folder(), "icon cache");
|
|
||||||
}
|
|
||||||
|
|
||||||
fn check_data_folder() {
|
fn check_data_folder() {
|
||||||
let data_folder = &CONFIG.data_folder();
|
let data_folder = &CONFIG.data_folder();
|
||||||
let path = Path::new(data_folder);
|
let path = Path::new(data_folder);
|
||||||
@ -314,51 +324,73 @@ fn create_db_pool() -> db::DbPool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn launch_rocket(pool: db::DbPool, extra_debug: bool) {
|
async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error> {
|
||||||
let basepath = &CONFIG.domain_path();
|
let basepath = &CONFIG.domain_path();
|
||||||
|
|
||||||
|
let mut config = rocket::Config::from(rocket::Config::figment());
|
||||||
|
config.address = std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED); // TODO: Allow this to be changed, keep ROCKET_ADDRESS for compat
|
||||||
|
config.temp_dir = canonicalize(CONFIG.tmp_folder()).unwrap().into();
|
||||||
|
config.limits = Limits::new() //
|
||||||
|
.limit("json", 10.megabytes())
|
||||||
|
.limit("data-form", 150.megabytes())
|
||||||
|
.limit("file", 150.megabytes());
|
||||||
|
|
||||||
// If adding more paths here, consider also adding them to
|
// If adding more paths here, consider also adding them to
|
||||||
// crate::utils::LOGGED_ROUTES to make sure they appear in the log
|
// crate::utils::LOGGED_ROUTES to make sure they appear in the log
|
||||||
let result = rocket::ignite()
|
let instance = rocket::custom(config)
|
||||||
.mount(&[basepath, "/"].concat(), api::web_routes())
|
.mount([basepath, "/"].concat(), api::web_routes())
|
||||||
.mount(&[basepath, "/api"].concat(), api::core_routes())
|
.mount([basepath, "/api"].concat(), api::core_routes())
|
||||||
.mount(&[basepath, "/admin"].concat(), api::admin_routes())
|
.mount([basepath, "/admin"].concat(), api::admin_routes())
|
||||||
.mount(&[basepath, "/identity"].concat(), api::identity_routes())
|
.mount([basepath, "/identity"].concat(), api::identity_routes())
|
||||||
.mount(&[basepath, "/icons"].concat(), api::icons_routes())
|
.mount([basepath, "/icons"].concat(), api::icons_routes())
|
||||||
.mount(&[basepath, "/notifications"].concat(), api::notifications_routes())
|
.mount([basepath, "/notifications"].concat(), api::notifications_routes())
|
||||||
.manage(pool)
|
.manage(pool)
|
||||||
.manage(api::start_notification_server())
|
.manage(api::start_notification_server())
|
||||||
.attach(util::AppHeaders())
|
.attach(util::AppHeaders())
|
||||||
.attach(util::Cors())
|
.attach(util::Cors())
|
||||||
.attach(util::BetterLogging(extra_debug))
|
.attach(util::BetterLogging(extra_debug))
|
||||||
.launch();
|
.ignite()
|
||||||
|
.await?;
|
||||||
|
|
||||||
// Launch and print error if there is one
|
CONFIG.set_rocket_shutdown_handle(instance.shutdown());
|
||||||
// The launch will restore the original logging level
|
ctrlc::set_handler(move || {
|
||||||
error!("Launch error {:#?}", result);
|
info!("Exiting vaultwarden!");
|
||||||
|
CONFIG.shutdown();
|
||||||
|
})
|
||||||
|
.expect("Error setting Ctrl-C handler");
|
||||||
|
|
||||||
|
instance.launch().await?;
|
||||||
|
|
||||||
|
info!("Vaultwarden process exited!");
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn schedule_jobs(pool: db::DbPool) {
|
async fn schedule_jobs(pool: db::DbPool) {
|
||||||
if CONFIG.job_poll_interval_ms() == 0 {
|
if CONFIG.job_poll_interval_ms() == 0 {
|
||||||
info!("Job scheduler disabled.");
|
info!("Job scheduler disabled.");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let runtime = tokio::runtime::Handle::current();
|
||||||
|
|
||||||
thread::Builder::new()
|
thread::Builder::new()
|
||||||
.name("job-scheduler".to_string())
|
.name("job-scheduler".to_string())
|
||||||
.spawn(move || {
|
.spawn(move || {
|
||||||
|
use job_scheduler::{Job, JobScheduler};
|
||||||
|
|
||||||
let mut sched = JobScheduler::new();
|
let mut sched = JobScheduler::new();
|
||||||
|
|
||||||
// Purge sends that are past their deletion date.
|
// Purge sends that are past their deletion date.
|
||||||
if !CONFIG.send_purge_schedule().is_empty() {
|
if !CONFIG.send_purge_schedule().is_empty() {
|
||||||
sched.add(Job::new(CONFIG.send_purge_schedule().parse().unwrap(), || {
|
sched.add(Job::new(CONFIG.send_purge_schedule().parse().unwrap(), || {
|
||||||
api::purge_sends(pool.clone());
|
runtime.spawn(api::purge_sends(pool.clone()));
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Purge trashed items that are old enough to be auto-deleted.
|
// Purge trashed items that are old enough to be auto-deleted.
|
||||||
if !CONFIG.trash_purge_schedule().is_empty() {
|
if !CONFIG.trash_purge_schedule().is_empty() {
|
||||||
sched.add(Job::new(CONFIG.trash_purge_schedule().parse().unwrap(), || {
|
sched.add(Job::new(CONFIG.trash_purge_schedule().parse().unwrap(), || {
|
||||||
api::purge_trashed_ciphers(pool.clone());
|
runtime.spawn(api::purge_trashed_ciphers(pool.clone()));
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -366,7 +398,7 @@ fn schedule_jobs(pool: db::DbPool) {
|
|||||||
// indicates that a user's master password has been compromised.
|
// indicates that a user's master password has been compromised.
|
||||||
if !CONFIG.incomplete_2fa_schedule().is_empty() {
|
if !CONFIG.incomplete_2fa_schedule().is_empty() {
|
||||||
sched.add(Job::new(CONFIG.incomplete_2fa_schedule().parse().unwrap(), || {
|
sched.add(Job::new(CONFIG.incomplete_2fa_schedule().parse().unwrap(), || {
|
||||||
api::send_incomplete_2fa_notifications(pool.clone());
|
runtime.spawn(api::send_incomplete_2fa_notifications(pool.clone()));
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -375,7 +407,7 @@ fn schedule_jobs(pool: db::DbPool) {
|
|||||||
// sending reminders for requests that are about to be granted anyway.
|
// sending reminders for requests that are about to be granted anyway.
|
||||||
if !CONFIG.emergency_request_timeout_schedule().is_empty() {
|
if !CONFIG.emergency_request_timeout_schedule().is_empty() {
|
||||||
sched.add(Job::new(CONFIG.emergency_request_timeout_schedule().parse().unwrap(), || {
|
sched.add(Job::new(CONFIG.emergency_request_timeout_schedule().parse().unwrap(), || {
|
||||||
api::emergency_request_timeout_job(pool.clone());
|
runtime.spawn(api::emergency_request_timeout_job(pool.clone()));
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -383,7 +415,7 @@ fn schedule_jobs(pool: db::DbPool) {
|
|||||||
// emergency access requests.
|
// emergency access requests.
|
||||||
if !CONFIG.emergency_notification_reminder_schedule().is_empty() {
|
if !CONFIG.emergency_notification_reminder_schedule().is_empty() {
|
||||||
sched.add(Job::new(CONFIG.emergency_notification_reminder_schedule().parse().unwrap(), || {
|
sched.add(Job::new(CONFIG.emergency_notification_reminder_schedule().parse().unwrap(), || {
|
||||||
api::emergency_notification_reminder_job(pool.clone());
|
runtime.spawn(api::emergency_notification_reminder_job(pool.clone()));
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
70
src/util.rs
70
src/util.rs
@ -5,10 +5,10 @@ use std::io::Cursor;
|
|||||||
|
|
||||||
use rocket::{
|
use rocket::{
|
||||||
fairing::{Fairing, Info, Kind},
|
fairing::{Fairing, Info, Kind},
|
||||||
http::{ContentType, Header, HeaderMap, Method, RawStr, Status},
|
http::{ContentType, Header, HeaderMap, Method, Status},
|
||||||
request::FromParam,
|
request::FromParam,
|
||||||
response::{self, Responder},
|
response::{self, Responder},
|
||||||
Data, Request, Response, Rocket,
|
Data, Orbit, Request, Response, Rocket,
|
||||||
};
|
};
|
||||||
|
|
||||||
use std::thread::sleep;
|
use std::thread::sleep;
|
||||||
@ -18,6 +18,7 @@ use crate::CONFIG;
|
|||||||
|
|
||||||
pub struct AppHeaders();
|
pub struct AppHeaders();
|
||||||
|
|
||||||
|
#[rocket::async_trait]
|
||||||
impl Fairing for AppHeaders {
|
impl Fairing for AppHeaders {
|
||||||
fn info(&self) -> Info {
|
fn info(&self) -> Info {
|
||||||
Info {
|
Info {
|
||||||
@ -26,7 +27,7 @@ impl Fairing for AppHeaders {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn on_response(&self, _req: &Request, res: &mut Response) {
|
async fn on_response<'r>(&self, _req: &'r Request<'_>, res: &mut Response<'r>) {
|
||||||
res.set_raw_header("Permissions-Policy", "accelerometer=(), ambient-light-sensor=(), autoplay=(), camera=(), encrypted-media=(), fullscreen=(), geolocation=(), gyroscope=(), magnetometer=(), microphone=(), midi=(), payment=(), picture-in-picture=(), sync-xhr=(self \"https://haveibeenpwned.com\" \"https://2fa.directory\"), usb=(), vr=()");
|
res.set_raw_header("Permissions-Policy", "accelerometer=(), ambient-light-sensor=(), autoplay=(), camera=(), encrypted-media=(), fullscreen=(), geolocation=(), gyroscope=(), magnetometer=(), microphone=(), midi=(), payment=(), picture-in-picture=(), sync-xhr=(self \"https://haveibeenpwned.com\" \"https://2fa.directory\"), usb=(), vr=()");
|
||||||
res.set_raw_header("Referrer-Policy", "same-origin");
|
res.set_raw_header("Referrer-Policy", "same-origin");
|
||||||
res.set_raw_header("X-Frame-Options", "SAMEORIGIN");
|
res.set_raw_header("X-Frame-Options", "SAMEORIGIN");
|
||||||
@ -72,6 +73,7 @@ impl Cors {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[rocket::async_trait]
|
||||||
impl Fairing for Cors {
|
impl Fairing for Cors {
|
||||||
fn info(&self) -> Info {
|
fn info(&self) -> Info {
|
||||||
Info {
|
Info {
|
||||||
@ -80,7 +82,7 @@ impl Fairing for Cors {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn on_response(&self, request: &Request, response: &mut Response) {
|
async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
|
||||||
let req_headers = request.headers();
|
let req_headers = request.headers();
|
||||||
|
|
||||||
if let Some(origin) = Cors::get_allowed_origin(req_headers) {
|
if let Some(origin) = Cors::get_allowed_origin(req_headers) {
|
||||||
@ -97,7 +99,7 @@ impl Fairing for Cors {
|
|||||||
response.set_header(Header::new("Access-Control-Allow-Credentials", "true"));
|
response.set_header(Header::new("Access-Control-Allow-Credentials", "true"));
|
||||||
response.set_status(Status::Ok);
|
response.set_status(Status::Ok);
|
||||||
response.set_header(ContentType::Plain);
|
response.set_header(ContentType::Plain);
|
||||||
response.set_sized_body(Cursor::new(""));
|
response.set_sized_body(Some(0), Cursor::new(""));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -134,25 +136,21 @@ impl<R> Cached<R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'r, R: Responder<'r>> Responder<'r> for Cached<R> {
|
impl<'r, R: 'r + Responder<'r, 'static> + Send> Responder<'r, 'static> for Cached<R> {
|
||||||
fn respond_to(self, req: &Request) -> response::Result<'r> {
|
fn respond_to(self, request: &'r Request<'_>) -> response::Result<'static> {
|
||||||
|
let mut res = self.response.respond_to(request)?;
|
||||||
|
|
||||||
let cache_control_header = if self.is_immutable {
|
let cache_control_header = if self.is_immutable {
|
||||||
format!("public, immutable, max-age={}", self.ttl)
|
format!("public, immutable, max-age={}", self.ttl)
|
||||||
} else {
|
} else {
|
||||||
format!("public, max-age={}", self.ttl)
|
format!("public, max-age={}", self.ttl)
|
||||||
};
|
};
|
||||||
|
res.set_raw_header("Cache-Control", cache_control_header);
|
||||||
|
|
||||||
let time_now = chrono::Local::now();
|
let time_now = chrono::Local::now();
|
||||||
|
let expiry_time = time_now + chrono::Duration::seconds(self.ttl.try_into().unwrap());
|
||||||
match self.response.respond_to(req) {
|
res.set_raw_header("Expires", format_datetime_http(&expiry_time));
|
||||||
Ok(mut res) => {
|
Ok(res)
|
||||||
res.set_raw_header("Cache-Control", cache_control_header);
|
|
||||||
let expiry_time = time_now + chrono::Duration::seconds(self.ttl.try_into().unwrap());
|
|
||||||
res.set_raw_header("Expires", format_datetime_http(&expiry_time));
|
|
||||||
Ok(res)
|
|
||||||
}
|
|
||||||
e @ Err(_) => e,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -175,11 +173,9 @@ impl<'r> FromParam<'r> for SafeString {
|
|||||||
type Error = ();
|
type Error = ();
|
||||||
|
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
fn from_param(param: &'r RawStr) -> Result<Self, Self::Error> {
|
fn from_param(param: &'r str) -> Result<Self, Self::Error> {
|
||||||
let s = param.percent_decode().map(|cow| cow.into_owned()).map_err(|_| ())?;
|
if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) {
|
||||||
|
Ok(SafeString(param.to_string()))
|
||||||
if s.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) {
|
|
||||||
Ok(SafeString(s))
|
|
||||||
} else {
|
} else {
|
||||||
Err(())
|
Err(())
|
||||||
}
|
}
|
||||||
@ -193,15 +189,16 @@ const LOGGED_ROUTES: [&str; 6] =
|
|||||||
|
|
||||||
// Boolean is extra debug, when true, we ignore the whitelist above and also print the mounts
|
// Boolean is extra debug, when true, we ignore the whitelist above and also print the mounts
|
||||||
pub struct BetterLogging(pub bool);
|
pub struct BetterLogging(pub bool);
|
||||||
|
#[rocket::async_trait]
|
||||||
impl Fairing for BetterLogging {
|
impl Fairing for BetterLogging {
|
||||||
fn info(&self) -> Info {
|
fn info(&self) -> Info {
|
||||||
Info {
|
Info {
|
||||||
name: "Better Logging",
|
name: "Better Logging",
|
||||||
kind: Kind::Launch | Kind::Request | Kind::Response,
|
kind: Kind::Liftoff | Kind::Request | Kind::Response,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn on_launch(&self, rocket: &Rocket) {
|
async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
|
||||||
if self.0 {
|
if self.0 {
|
||||||
info!(target: "routes", "Routes loaded:");
|
info!(target: "routes", "Routes loaded:");
|
||||||
let mut routes: Vec<_> = rocket.routes().collect();
|
let mut routes: Vec<_> = rocket.routes().collect();
|
||||||
@ -225,34 +222,36 @@ impl Fairing for BetterLogging {
|
|||||||
info!(target: "start", "Rocket has launched from {}", addr);
|
info!(target: "start", "Rocket has launched from {}", addr);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn on_request(&self, request: &mut Request<'_>, _data: &Data) {
|
async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
|
||||||
let method = request.method();
|
let method = request.method();
|
||||||
if !self.0 && method == Method::Options {
|
if !self.0 && method == Method::Options {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
let uri = request.uri();
|
let uri = request.uri();
|
||||||
let uri_path = uri.path();
|
let uri_path = uri.path();
|
||||||
let uri_subpath = uri_path.strip_prefix(&CONFIG.domain_path()).unwrap_or(uri_path);
|
let uri_path_str = uri_path.url_decode_lossy();
|
||||||
|
let uri_subpath = uri_path_str.strip_prefix(&CONFIG.domain_path()).unwrap_or(&uri_path_str);
|
||||||
if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) {
|
if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) {
|
||||||
match uri.query() {
|
match uri.query() {
|
||||||
Some(q) => info!(target: "request", "{} {}?{}", method, uri_path, &q[..q.len().min(30)]),
|
Some(q) => info!(target: "request", "{} {}?{}", method, uri_path_str, &q[..q.len().min(30)]),
|
||||||
None => info!(target: "request", "{} {}", method, uri_path),
|
None => info!(target: "request", "{} {}", method, uri_path_str),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn on_response(&self, request: &Request, response: &mut Response) {
|
async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
|
||||||
if !self.0 && request.method() == Method::Options {
|
if !self.0 && request.method() == Method::Options {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
let uri_path = request.uri().path();
|
let uri_path = request.uri().path();
|
||||||
let uri_subpath = uri_path.strip_prefix(&CONFIG.domain_path()).unwrap_or(uri_path);
|
let uri_path_str = uri_path.url_decode_lossy();
|
||||||
|
let uri_subpath = uri_path_str.strip_prefix(&CONFIG.domain_path()).unwrap_or(&uri_path_str);
|
||||||
if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) {
|
if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) {
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
if let Some(route) = request.route() {
|
if let Some(ref route) = request.route() {
|
||||||
info!(target: "response", "{} => {} {}", route, status.code, status.reason)
|
info!(target: "response", "{} => {}", route, status)
|
||||||
} else {
|
} else {
|
||||||
info!(target: "response", "{} {}", status.code, status.reason)
|
info!(target: "response", "{}", status)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -614,10 +613,7 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
use reqwest::{
|
use reqwest::{header, Client, ClientBuilder};
|
||||||
blocking::{Client, ClientBuilder},
|
|
||||||
header,
|
|
||||||
};
|
|
||||||
|
|
||||||
pub fn get_reqwest_client() -> Client {
|
pub fn get_reqwest_client() -> Client {
|
||||||
get_reqwest_client_builder().build().expect("Failed to build client")
|
get_reqwest_client_builder().build().expect("Failed to build client")
|
||||||
|
Loading…
Reference in New Issue
Block a user