From 43421f6e5441635728cefd6041707dc3b9c220f0 Mon Sep 17 00:00:00 2001 From: Lonami Exo Date: Thu, 27 Oct 2022 20:12:41 +0200 Subject: [PATCH] Use UDL interface to avoid passing around pointers --- .../lonami/talaria/ui/screens/LoginScreen.kt | 9 +- native/src/lib.rs | 110 ++++++++++-------- native/src/talaria.udl | 12 +- 3 files changed, 76 insertions(+), 55 deletions(-) diff --git a/app/src/main/java/dev/lonami/talaria/ui/screens/LoginScreen.kt b/app/src/main/java/dev/lonami/talaria/ui/screens/LoginScreen.kt index edcf7b8..4c87381 100644 --- a/app/src/main/java/dev/lonami/talaria/ui/screens/LoginScreen.kt +++ b/app/src/main/java/dev/lonami/talaria/ui/screens/LoginScreen.kt @@ -19,8 +19,7 @@ import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.sp import dev.lonami.talaria.R import dev.lonami.talaria.ui.theme.TalariaTheme -import uniffi.talaria.requestLoginCode -import uniffi.talaria.signIn +import uniffi.talaria.LoginProcedure enum class LoginStage { ASK_PHONE, @@ -106,7 +105,7 @@ fun LoginScreen(onConfirmOtp: () -> Unit, modifier: Modifier = Modifier) { var phone by remember { mutableStateOf("") } var otp by remember { mutableStateOf("") } - var tokenPtr by remember { mutableStateOf(0UL) } + val loginProcedure by remember { mutableStateOf(LoginProcedure()) } Column( modifier = modifier @@ -128,7 +127,7 @@ fun LoginScreen(onConfirmOtp: () -> Unit, modifier: Modifier = Modifier) { phone, onPhoneChanged = { phone = it }, onSendCode = { - tokenPtr = requestLoginCode(phone) + loginProcedure.requestLoginCode(phone) stage = LoginStage.ASK_CODE } ) @@ -136,7 +135,7 @@ fun LoginScreen(onConfirmOtp: () -> Unit, modifier: Modifier = Modifier) { otp, onOtpChanged = { otp = it }, onConfirmOtp = { - signIn(tokenPtr, otp) + loginProcedure.signIn(otp) onConfirmOtp() } ) diff --git a/native/src/lib.rs b/native/src/lib.rs index ec71b51..d228272 100644 --- a/native/src/lib.rs +++ b/native/src/lib.rs @@ -64,6 +64,10 @@ impl fmt::Display for NativeError { impl std::error::Error for NativeError {} +struct LoginProcedure { + token: Mutex>, +} + #[derive(Debug, Clone, Copy)] pub enum MessageAck { Received, @@ -199,61 +203,75 @@ pub fn need_login() -> Result { block_on(client.is_authorized()).map_err(|_| NativeError::Network) } -pub fn request_login_code(phone: String) -> Result { - let client = CLIENT.get().ok_or(NativeError::Initialization)?; - block_on(client.request_login_code(&phone, API_ID, API_HASH)) - .map(|token| Box::into_raw(Box::new(token)) as u64) - .map_err(|_| NativeError::Network) -} - -pub fn sign_in(token_ptr: u64, code: String) -> Result<()> { - let token = unsafe { *Box::from_raw(token_ptr as *mut LoginToken) }; - let client = CLIENT.get().ok_or(NativeError::Initialization)?; - - block_on(client.sign_in(&token, &code)).map_err(|_| NativeError::Network)?; - - let guard = DATABASE.lock().unwrap(); - let conn = match guard.as_ref() { - Some(c) => c, - None => { - error!("Database was not initialized"); - return Err(NativeError::Initialization); +impl LoginProcedure { + fn new() -> Self { + Self { + token: Mutex::new(None), } - }; - - let mut session = db::create_session(conn).map_err(|_| NativeError::Database)?; - let s = client.session(); - if let Some(user) = s.get_user() { - session.user_id = Some(user.id); - session.dc_id = Some(user.dc); - session.bot = Some(user.bot); } - if let Some(state) = s.get_state() { - session.pts = Some(state.pts); - session.qts = Some(state.qts); - session.seq = Some(state.seq); - session.date = Some(state.date); + fn request_login_code(&self, phone: String) -> Result<()> { + let client = CLIENT.get().ok_or(NativeError::Initialization)?; + let token = block_on(client.request_login_code(&phone, API_ID, API_HASH)) + .map_err(|_| NativeError::Network)?; + *self.token.lock().unwrap() = Some(token); + Ok(()) } - if let Some(dc_id) = session.dc_id { - for dc in s.get_dcs() { - if dc.id == dc_id { - if let Some(ipv4) = dc.ipv4 { - session.dc_addr = Some(Ipv4Addr::from(ipv4.to_le_bytes()).to_string()) - } else if let Some(ipv6) = dc.ipv6 { - session.dc_addr = Some(Ipv6Addr::from(ipv6).to_string()) + fn sign_in(&self, code: String) -> Result<()> { + let token = self + .token + .lock() + .unwrap() + .take() + .ok_or(NativeError::Initialization)?; + let client = CLIENT.get().ok_or(NativeError::Initialization)?; + + block_on(client.sign_in(&token, &code)).map_err(|_| NativeError::Network)?; + + let guard = DATABASE.lock().unwrap(); + let conn = match guard.as_ref() { + Some(c) => c, + None => { + error!("Database was not initialized"); + return Err(NativeError::Initialization); + } + }; + + let mut session = db::create_session(conn).map_err(|_| NativeError::Database)?; + let s = client.session(); + if let Some(user) = s.get_user() { + session.user_id = Some(user.id); + session.dc_id = Some(user.dc); + session.bot = Some(user.bot); + } + + if let Some(state) = s.get_state() { + session.pts = Some(state.pts); + session.qts = Some(state.qts); + session.seq = Some(state.seq); + session.date = Some(state.date); + } + + if let Some(dc_id) = session.dc_id { + for dc in s.get_dcs() { + if dc.id == dc_id { + if let Some(ipv4) = dc.ipv4 { + session.dc_addr = Some(Ipv4Addr::from(ipv4.to_le_bytes()).to_string()) + } else if let Some(ipv6) = dc.ipv6 { + session.dc_addr = Some(Ipv6Addr::from(ipv6).to_string()) + } + session.dc_port = Some(dc.port as u16); + session.dc_auth = dc.auth.map(|b| b.try_into().unwrap()); + break; } - session.dc_port = Some(dc.port as u16); - session.dc_auth = dc.auth.map(|b| b.try_into().unwrap()); - break; } } + + db::update_session(conn, &session).map_err(|_| NativeError::Database)?; + + Ok(()) } - - db::update_session(conn, &session).map_err(|_| NativeError::Database)?; - - Ok(()) } pub fn get_dialogs() -> Result> { diff --git a/native/src/talaria.udl b/native/src/talaria.udl index d267cef..9532124 100644 --- a/native/src/talaria.udl +++ b/native/src/talaria.udl @@ -25,6 +25,14 @@ dictionary Dialog { boolean pinned; }; +interface LoginProcedure { + constructor(); + [Throws=NativeError] + void request_login_code(string phone); + [Throws=NativeError] + void sign_in(string code); +}; + namespace talaria { [Throws=NativeError] void init_database(string path); @@ -33,10 +41,6 @@ namespace talaria { [Throws=NativeError] boolean need_login(); [Throws=NativeError] - u64 request_login_code(string phone); - [Throws=NativeError] - void sign_in(u64 tokenPtr, string code); - [Throws=NativeError] sequence get_dialogs(); [Throws=NativeError] void send_message(string packed, string text);