use rocket::Route; use rocket_contrib::json::Json; use serde_json::Value as JsonValue; use crate::api::JsonResult; use crate::auth::Headers; use crate::db::DbConn; use crate::CONFIG; pub fn routes() -> Vec { routes![negotiate, websockets_err] } #[get("/hub")] fn websockets_err() -> JsonResult { err!("'/notifications/hub' should be proxied to the websocket server or notifications won't work. Go to the README for more info.") } #[post("/hub/negotiate")] fn negotiate(_headers: Headers, _conn: DbConn) -> JsonResult { use crate::crypto; use data_encoding::BASE64URL; let conn_id = BASE64URL.encode(&crypto::get_random(vec![0u8; 16])); let mut available_transports: Vec = Vec::new(); if CONFIG.websocket_enabled { available_transports.push(json!({"transport":"WebSockets", "transferFormats":["Text","Binary"]})); } // TODO: Implement transports // Rocket WS support: https://github.com/SergioBenitez/Rocket/issues/90 // Rocket SSE support: https://github.com/SergioBenitez/Rocket/issues/33 // {"transport":"ServerSentEvents", "transferFormats":["Text"]}, // {"transport":"LongPolling", "transferFormats":["Text","Binary"]} Ok(Json(json!({ "connectionId": conn_id, "availableTransports": available_transports }))) } // // Websockets server // use std::sync::Arc; use std::thread; use ws::{self, util::Token, Factory, Handler, Handshake, Message, Sender, WebSocket}; use chashmap::CHashMap; use chrono::NaiveDateTime; use serde_json::from_str; use crate::db::models::{Cipher, Folder, User}; use rmpv::Value; fn serialize(val: Value) -> Vec { use rmpv::encode::write_value; let mut buf = Vec::new(); write_value(&mut buf, &val).expect("Error encoding MsgPack"); // Add size bytes at the start // Extracted from BinaryMessageFormat.js let mut size: usize = buf.len(); let mut len_buf: Vec = Vec::new(); loop { let mut size_part = size & 0x7f; size >>= 7; if size > 0 { size_part |= 0x80; } len_buf.push(size_part as u8); if size == 0 { break; } } len_buf.append(&mut buf); len_buf } fn serialize_date(date: NaiveDateTime) -> Value { let seconds: i64 = date.timestamp(); let nanos: i64 = date.timestamp_subsec_nanos() as i64; let timestamp = nanos << 34 | seconds; let bs = timestamp.to_be_bytes(); // -1 is Timestamp // https://github.com/msgpack/msgpack/blob/master/spec.md#timestamp-extension-type Value::Ext(-1, bs.to_vec()) } fn convert_option>(option: Option) -> Value { match option { Some(a) => a.into(), None => Value::Nil, } } // Server WebSocket handler pub struct WSHandler { out: Sender, user_uuid: Option, users: WebSocketUsers, } const RECORD_SEPARATOR: u8 = 0x1e; const INITIAL_RESPONSE: [u8; 3] = [0x7b, 0x7d, RECORD_SEPARATOR]; // {, }, #[derive(Deserialize)] struct InitialMessage { protocol: String, version: i32, } const PING_MS: u64 = 15_000; const PING: Token = Token(1); impl Handler for WSHandler { fn on_open(&mut self, hs: Handshake) -> ws::Result<()> { // TODO: Improve this split let path = hs.request.resource(); let mut query_split: Vec<_> = path.split('?').nth(1).unwrap().split('&').collect(); query_split.sort(); let access_token = &query_split[0][13..]; let _id = &query_split[1][3..]; // Validate the user use crate::auth; let claims = match auth::decode_login(access_token) { Ok(claims) => claims, Err(_) => return Err(ws::Error::new(ws::ErrorKind::Internal, "Invalid access token provided")), }; // Assign the user to the handler let user_uuid = claims.sub; self.user_uuid = Some(user_uuid.clone()); // Add the current Sender to the user list let handler_insert = self.out.clone(); let handler_update = self.out.clone(); self.users .map .upsert(user_uuid, || vec![handler_insert], |ref mut v| v.push(handler_update)); // Schedule a ping to keep the connection alive self.out.timeout(PING_MS, PING) } fn on_message(&mut self, msg: Message) -> ws::Result<()> { info!("Server got message '{}'. ", msg); if let Message::Text(text) = msg.clone() { let json = &text[..text.len() - 1]; // Remove last char if let Ok(InitialMessage { protocol, version }) = from_str::(json) { if &protocol == "messagepack" && version == 1 { return self.out.send(&INITIAL_RESPONSE[..]); // Respond to initial message } } } // If it's not the initial message, just echo the message self.out.send(msg) } fn on_timeout(&mut self, event: Token) -> ws::Result<()> { if event == PING { // send ping self.out.send(create_ping())?; // reschedule the timeout self.out.timeout(PING_MS, PING) } else { Err(ws::Error::new( ws::ErrorKind::Internal, "Invalid timeout token provided", )) } } } struct WSFactory { pub users: WebSocketUsers, } impl WSFactory { pub fn init() -> Self { WSFactory { users: WebSocketUsers { map: Arc::new(CHashMap::new()), }, } } } impl Factory for WSFactory { type Handler = WSHandler; fn connection_made(&mut self, out: Sender) -> Self::Handler { WSHandler { out, user_uuid: None, users: self.users.clone(), } } fn connection_lost(&mut self, handler: Self::Handler) { // Remove handler if let Some(user_uuid) = &handler.user_uuid { if let Some(mut user_conn) = self.users.map.get_mut(user_uuid) { user_conn.remove_item(&handler.out); } } } } #[derive(Clone)] pub struct WebSocketUsers { map: Arc>>, } impl WebSocketUsers { fn send_update(&self, user_uuid: &String, data: &[u8]) -> ws::Result<()> { if let Some(user) = self.map.get(user_uuid) { for sender in user.iter() { sender.send(data)?; } } Ok(()) } // NOTE: The last modified date needs to be updated before calling these methods #[allow(dead_code)] pub fn send_user_update(&self, ut: UpdateType, user: &User) { let data = create_update( vec![ ("UserId".into(), user.uuid.clone().into()), ("Date".into(), serialize_date(user.updated_at)), ], ut, ); self.send_update(&user.uuid.clone(), &data).ok(); } pub fn send_folder_update(&self, ut: UpdateType, folder: &Folder) { let data = create_update( vec![ ("Id".into(), folder.uuid.clone().into()), ("UserId".into(), folder.user_uuid.clone().into()), ("RevisionDate".into(), serialize_date(folder.updated_at)), ], ut, ); self.send_update(&folder.user_uuid, &data).ok(); } pub fn send_cipher_update(&self, ut: UpdateType, cipher: &Cipher, user_uuids: &[String]) { let user_uuid = convert_option(cipher.user_uuid.clone()); let org_uuid = convert_option(cipher.organization_uuid.clone()); let data = create_update( vec![ ("Id".into(), cipher.uuid.clone().into()), ("UserId".into(), user_uuid), ("OrganizationId".into(), org_uuid), ("CollectionIds".into(), Value::Nil), ("RevisionDate".into(), serialize_date(cipher.updated_at)), ], ut, ); for uuid in user_uuids { self.send_update(&uuid, &data).ok(); } } } /* Message Structure [ 1, // MessageType.Invocation {}, // Headers null, // InvocationId "ReceiveMessage", // Target [ // Arguments { "ContextId": "app_id", "Type": ut as i32, "Payload": {} } ] ] */ fn create_update(payload: Vec<(Value, Value)>, ut: UpdateType) -> Vec { use rmpv::Value as V; let value = V::Array(vec![ 1.into(), V::Array(vec![]), V::Nil, "ReceiveMessage".into(), V::Array(vec![V::Map(vec![ ("ContextId".into(), "app_id".into()), ("Type".into(), (ut as i32).into()), ("Payload".into(), payload.into()), ])]), ]); serialize(value) } fn create_ping() -> Vec { serialize(Value::Array(vec![6.into()])) } #[allow(dead_code)] pub enum UpdateType { CipherUpdate = 0, CipherCreate = 1, LoginDelete = 2, FolderDelete = 3, Ciphers = 4, Vault = 5, OrgKeys = 6, FolderCreate = 7, FolderUpdate = 8, CipherDelete = 9, SyncSettings = 10, LogOut = 11, } use rocket::State; pub type Notify<'a> = State<'a, WebSocketUsers>; pub fn start_notification_server() -> WebSocketUsers { let factory = WSFactory::init(); let users = factory.users.clone(); if CONFIG.websocket_enabled { thread::spawn(move || { WebSocket::new(factory).unwrap().listen(&CONFIG.websocket_url).unwrap(); }); } users }