From 0919c4a13cc113edb654535054f6d35d0d438a1d Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Thu, 20 Oct 2022 19:59:11 +0200 Subject: [PATCH] Persist and load login details to a local database --- .../java/dev/lonami/talaria/MainActivity.kt | 1 + .../dev/lonami/talaria/bindings/Native.kt | 1 + native/Cargo.toml | 5 +- native/src/db/mod.rs | 174 ++++++++++++++++++ native/src/db/model.rs | 14 ++ native/src/db/utils.rs | 27 +++ native/src/lib.rs | 106 ++++++++++- 7 files changed, 324 insertions(+), 4 deletions(-) create mode 100644 native/src/db/mod.rs create mode 100644 native/src/db/model.rs create mode 100644 native/src/db/utils.rs diff --git a/app/src/main/java/dev/lonami/talaria/MainActivity.kt b/app/src/main/java/dev/lonami/talaria/MainActivity.kt index 53f8b58..c17146c 100644 --- a/app/src/main/java/dev/lonami/talaria/MainActivity.kt +++ b/app/src/main/java/dev/lonami/talaria/MainActivity.kt @@ -25,6 +25,7 @@ class MainActivity : ComponentActivity() { } } + Native.initDatabase(getDatabasePath("talaria.db").path) Native.initClient() } } diff --git a/app/src/main/java/dev/lonami/talaria/bindings/Native.kt b/app/src/main/java/dev/lonami/talaria/bindings/Native.kt index 607fdaf..e666f72 100644 --- a/app/src/main/java/dev/lonami/talaria/bindings/Native.kt +++ b/app/src/main/java/dev/lonami/talaria/bindings/Native.kt @@ -5,6 +5,7 @@ object Native { System.loadLibrary("talaria") } + external fun initDatabase(path: String) external fun initClient() external fun needLogin(): Boolean external fun requestLoginCode(phone: String): Long diff --git a/native/Cargo.toml b/native/Cargo.toml index 9741819..684c92d 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -10,13 +10,14 @@ crate-type = ["cdylib"] [dependencies] jni = { version = "0.10.2", default-features = false } # v0.4 of grammers-* is currently unreleased; clone the project and use path dependencies -grammers-client = { version = "0.4.0" } +grammers-client = { version = "0.4.1" } grammers-tl-types = { version = "0.4.0" } -grammers-session = { version = "0.4.0" } +grammers-session = { version = "0.4.1" } tokio = { version = "1.5.0", features = ["full"] } log = "0.4.14" android_logger = "0.11.1" once_cell = "1.15.0" +sqlite = "0.27.0" [profile.release] lto = true diff --git a/native/src/db/mod.rs b/native/src/db/mod.rs new file mode 100644 index 0000000..f70fb74 --- /dev/null +++ b/native/src/db/mod.rs @@ -0,0 +1,174 @@ +use std::net::IpAddr; + +use sqlite::{Connection, Error, State}; + +use model::Session; +use utils::{fetch_many, fetch_one}; + +mod model; +mod utils; + +fn init_schema(conn: &Connection) -> Result<(), Error> { + let version = match fetch_one(&conn, "SELECT version FROM version LIMIT 1", |stmt| { + stmt.read::(0) + }) { + Ok(Some(version)) => version, + _ => 0, + }; + + if version == 0 { + conn.execute( + " + BEGIN TRANSACTION; + CREATE TABLE version (version INTEGER NOT NULL); + CREATE TABLE session ( + user_id INTEGER, + dc_id INTEGER, + bot INTEGER, + pts INTEGER, + qts INTEGER, + seq INTEGER, + date INTEGER + ); + CREATE TABLE channel ( + session_id INTEGER NOT NULL REFERENCES session (rowid), + id INTEGER NOT NULL, + hash INTEGER NOT NULL, + pts INTEGER NOT NULL + ); + CREATE TABLE datacenter ( + session_id INTEGER NOT NULL REFERENCES session (rowid), + id INTEGER NOT NULL, + ipv4 TEXT, + ipv6 TEXT, + port INTEGER NOT NULL, + auth BLOB, + CONSTRAINT SingleIp CHECK( + (ipv4 IS NOT NULL AND ipv6 IS NULL) OR + (ipv6 IS NOT NULL AND ipv4 IS NULL)) + ); + INSERT INTO version VALUES (1); + COMMIT; + ", + )?; + } + + Ok(()) +} + +pub fn init_connection(db_path: &str) -> Result { + let conn = sqlite::open(db_path)?; + init_schema(&conn)?; + Ok(conn) +} + +pub fn get_sessions(conn: &Connection) -> Result, Error> { + let query = " + SELECT s.rowid, s.*, COALESCE(d.ipv4, d.ipv6), d.port, d.auth + FROM session s + LEFT JOIN datacenter d ON d.session_id = s.rowid AND d.id = s.dc_id + "; + fetch_many(conn, query, |stmt| { + Ok(Session { + id: stmt.read(0)?, + user_id: stmt.read(1)?, + dc_id: stmt.read::>(2)?.map(|x| x as _), + bot: stmt.read::>(3)?.map(|x| x != 0), + pts: stmt.read::>(4)?.map(|x| x as _), + qts: stmt.read::>(5)?.map(|x| x as _), + seq: stmt.read::>(6)?.map(|x| x as _), + date: stmt.read::>(7)?.map(|x| x as _), + dc_addr: stmt.read::>(8)?, + dc_port: stmt.read::>(9)?.map(|x| x as _), + dc_auth: stmt + .read::>>(10)? + .map(|x| x.try_into().unwrap()), + }) + }) +} + +pub fn create_session(conn: &Connection) -> Result { + conn.execute("INSERT INTO session DEFAULT VALUES;")?; + let id = fetch_one(conn, "SELECT LAST_INSERT_ROWID()", |stmt| { + stmt.read::(0) + })? + .unwrap(); + Ok(Session { + id, + user_id: None, + dc_id: None, + bot: None, + pts: None, + qts: None, + seq: None, + date: None, + dc_addr: None, + dc_port: None, + dc_auth: None, + }) +} + +pub fn update_session(conn: &Connection, session: &Session) -> Result<(), Error> { + let mut stmt = conn + .prepare( + " + UPDATE session SET + user_id = ?, + dc_id = ?, + bot = ?, + pts = ?, + qts = ?, + seq = ?, + date = ? + WHERE rowid = ? + ", + )? + .bind(1, session.user_id)? + .bind(2, session.dc_id.map(|x| x as i64))? + .bind(3, session.bot.map(|x| x as i64))? + .bind(4, session.pts.map(|x| x as i64))? + .bind(5, session.qts.map(|x| x as i64))? + .bind(6, session.seq.map(|x| x as i64))? + .bind(7, session.date.map(|x| x as i64))? + .bind(8, session.id)?; + while let State::Row = stmt.next()? {} + + match ( + session.dc_id, + session.dc_addr.as_ref(), + session.dc_port, + session.dc_auth, + ) { + (Some(id), Some(addr), Some(port), Some(auth)) => { + let (ipv4, ipv6) = match addr.parse().unwrap() { + IpAddr::V4(ipv4) => (Some(ipv4.to_string()), None), + IpAddr::V6(ipv6) => (None, Some(ipv6.to_string())), + }; + + let mut stmt = conn + .prepare( + " + DELETE FROM datacenter WHERE session_id = ? AND id = ? + ", + )? + .bind(1, session.id)? + .bind(2, id as i64)?; + + while let State::Row = stmt.next()? {} + + let mut stmt = conn + .prepare("INSERT INTO datacenter VALUES (?, ?, ?, ?, ?, ?)")? + .bind(1, session.id)? + .bind(2, id as i64)? + .bind(3, ipv4.as_deref())? + .bind(4, ipv6.as_deref())? + .bind(5, port as i64)? + .bind(6, auth.as_ref())?; + + while let State::Row = stmt.next()? {} + } + _ => {} + } + + Ok(()) +} diff --git a/native/src/db/model.rs b/native/src/db/model.rs new file mode 100644 index 0000000..4ae182b --- /dev/null +++ b/native/src/db/model.rs @@ -0,0 +1,14 @@ +#[derive(Debug)] +pub struct Session { + pub id: i64, + pub user_id: Option, + pub dc_id: Option, + pub bot: Option, + pub pts: Option, + pub qts: Option, + pub seq: Option, + pub date: Option, + pub dc_addr: Option, + pub dc_port: Option, + pub dc_auth: Option<[u8; 256]>, +} diff --git a/native/src/db/utils.rs b/native/src/db/utils.rs new file mode 100644 index 0000000..13ec09b --- /dev/null +++ b/native/src/db/utils.rs @@ -0,0 +1,27 @@ +use sqlite::{Connection, Error, State, Statement}; + +pub fn fetch_one Result>( + conn: &Connection, + query: &str, + adaptor: F, +) -> Result, Error> { + let mut stmt = conn.prepare(query)?; + if let State::Row = stmt.next()? { + adaptor(&stmt).map(Some) + } else { + Ok(None) + } +} + +pub fn fetch_many Result>( + conn: &Connection, + query: &str, + mut adaptor: F, +) -> Result, Error> { + let mut result = Vec::new(); + let mut stmt = conn.prepare(query)?; + while let State::Row = stmt.next()? { + result.push(adaptor(&stmt)?); + } + Ok(result) +} diff --git a/native/src/lib.rs b/native/src/lib.rs index 9e18a39..afd7442 100644 --- a/native/src/lib.rs +++ b/native/src/lib.rs @@ -1,12 +1,16 @@ #![cfg(target_os = "android")] #![allow(non_snake_case)] +use std::collections::HashMap; use std::ffi::{CStr, CString}; use std::future::Future; +use std::net::{Ipv4Addr, Ipv6Addr}; +use std::net::SocketAddr; +use std::sync::Mutex; use grammers_client::{Client, Config}; use grammers_client::types::{Dialog, LoginToken}; -use grammers_session::{PackedChat, Session}; +use grammers_session::{PackedChat, Session, UpdateState}; use jni::JNIEnv; use jni::objects::{JObject, JString}; use jni::sys::{jboolean, jint, jlong, jstring}; @@ -16,6 +20,8 @@ use once_cell::sync::OnceCell; use tokio::runtime; use tokio::runtime::Runtime; +mod db; + const LOG_MIN_LEVEL: Level = Level::Trace; const LOG_TAG: &str = ".native.talari"; const API_ID: i32 = 0; @@ -25,6 +31,7 @@ type Result = std::result::Result>; static RUNTIME: OnceCell = OnceCell::new(); static CLIENT: OnceCell = OnceCell::new(); +static DATABASE: Mutex> = Mutex::new(None); fn block_on(future: F) -> F::Output { if RUNTIME.get().is_none() { @@ -53,10 +60,46 @@ async fn init_client() -> Result<()> { return Ok(()); } + let guard = DATABASE.lock().unwrap(); + let conn = match guard.as_ref() { + Some(c) => c, + None => { + return Err("Database was not initialized".into()); + } + }; + info!("Connecting to Telegram..."); + let session = Session::new(); + + let sessions = db::get_sessions(conn)?; + if let Some(s) = sessions.get(0) { + match (s.user_id, s.dc_id, s.bot) { + (Some(id), Some(dc), Some(bot)) => session.set_user(id, dc, bot), + _ => {} + } + + match (s.pts, s.qts, s.seq, s.date) { + (Some(pts), Some(qts), Some(seq), Some(date)) => session.set_state(UpdateState { + pts, + qts, + seq, + date, + channels: HashMap::new(), + }), + _ => {} + } + + match (s.dc_id, s.dc_addr.as_ref(), s.dc_port, s.dc_auth) { + (Some(id), Some(addr), Some(port), Some(auth)) => { + session.insert_dc(id, SocketAddr::new(addr.parse().unwrap(), port), auth) + } + _ => {} + } + } + let client = Client::connect(Config { - session: Session::new(), + session, api_id: API_ID, api_hash: API_HASH.to_string(), params: Default::default(), @@ -86,6 +129,47 @@ async fn request_login_code(phone: &str) -> Result { async fn sign_in(token: LoginToken, code: &str) -> Result<()> { let client = CLIENT.get().ok_or("Client not initialized")?; client.sign_in(&token, &code).await?; + + let guard = DATABASE.lock().unwrap(); + let conn = match guard.as_ref() { + Some(c) => c, + None => { + return Err("Database was not initialized".into()); + } + }; + + let mut session = db::create_session(conn)?; + let s = client.session(); + if let Some(user) = s.get_user() { + session.user_id = Some(user.id); + session.dc_id = Some(user.dc); + session.bot = Some(user.bot); + } + + if let Some(state) = s.get_state() { + session.pts = Some(state.pts); + session.qts = Some(state.qts); + session.seq = Some(state.seq); + session.date = Some(state.date); + } + + if let Some(dc_id) = session.dc_id { + for dc in s.get_dcs() { + if dc.id == dc_id { + if let Some(ipv4) = dc.ipv4 { + session.dc_addr = Some(Ipv4Addr::from(ipv4.to_le_bytes()).to_string()) + } else if let Some(ipv6) = dc.ipv6 { + session.dc_addr = Some(Ipv6Addr::from(ipv6).to_string()) + } + session.dc_port = Some(dc.port as u16); + session.dc_auth = dc.auth.map(|b| b.try_into().unwrap()); + break; + } + } + } + + db::update_session(conn, &session)?; + Ok(()) } @@ -106,6 +190,24 @@ async fn send_message(chat: PackedChat, text: &str) -> Result<()> { Ok(()) } +#[no_mangle] +pub unsafe extern "C" fn Java_dev_lonami_talaria_bindings_Native_initDatabase( + env: JNIEnv, + _: JObject, + path: JString, +) { + let mut guard = DATABASE.lock().unwrap(); + if guard.is_some() { + info!("Database is already initialized"); + } + + let path = CString::from(CStr::from_ptr(env.get_string(path).unwrap().as_ptr())); + match db::init_connection(path.to_str().unwrap()) { + Ok(conn) => *guard = Some(conn), + Err(e) => error!("Failed to initialize database: {}", e), + } +} + #[no_mangle] pub unsafe extern "C" fn Java_dev_lonami_talaria_bindings_Native_initClient(_: JNIEnv, _: JObject) { match block_on(init_client()) {