1use std::{fmt::Debug, future::Future};
2
3use adw::{prelude::*, subclass::prelude::*};
4use gettextrs::gettext;
5use gtk::{glib, glib::clone};
6use matrix_sdk::{Error, encryption::CrossSigningResetAuthType};
7use ruma::{
8 api::{
9 MatrixVersion, OutgoingRequest, SupportedVersions,
10 auth_scheme::SendAccessToken,
11 client::uiaa::{
12 AuthData, AuthType, Dummy, FallbackAcknowledgement, Password, UiaaInfo, UserIdentifier,
13 get_uiaa_fallback_page,
14 },
15 },
16 assign,
17};
18use thiserror::Error;
19use tracing::{error, warn};
20
21mod in_browser_page;
22mod password_page;
23
24use self::{in_browser_page::AuthDialogInBrowserPage, password_page::AuthDialogPasswordPage};
25use crate::{
26 components::ToastableDialog, prelude::*, session::Session, spawn_tokio, toast,
27 utils::OneshotNotifier,
28};
29
30mod imp {
31 use std::{
32 borrow::Cow,
33 cell::{Cell, OnceCell, RefCell},
34 rc::Rc,
35 sync::Arc,
36 };
37
38 use glib::subclass::InitializingObject;
39 use tokio::task::JoinHandle;
40
41 use super::*;
42
    /// Implementation data of the authentication dialog.
    #[derive(Debug, Default, gtk::CompositeTemplate, glib::Properties)]
    #[template(resource = "/org/gnome/Fractal/ui/components/dialogs/auth/mod.ui")]
    #[properties(wrapper_type = super::AuthDialog)]
    pub struct AuthDialog {
        /// The stack holding the pages of the dialog.
        #[template_child]
        stack: TemplateChild<gtk::Stack>,
        /// The session to authenticate with.
        #[property(get, set, construct_only)]
        session: glib::WeakRef<Session>,
        /// Whether the dialog was already presented.
        is_presented: Cell<bool>,
        /// The state of the stage that is currently performed, if any.
        state: RefCell<Option<AuthState>>,
        /// The page that is currently visible in the stack, if any.
        current_page: RefCell<Option<gtk::Widget>>,
        /// The notifier used to wait for the user.
        ///
        /// The "continue" action sends `Some(())`, while closing the dialog
        /// calls `notify()`, which the waiting code treats as a cancellation.
        notifier: OnceCell<OneshotNotifier<Option<()>>>,
        /// The handle to abort the tokio task that is currently awaited, if any.
        abort_handle: RefCell<Option<tokio::task::AbortHandle>>,
    }
65
66 #[glib::object_subclass]
67 impl ObjectSubclass for AuthDialog {
68 const NAME: &'static str = "AuthDialog";
69 type Type = super::AuthDialog;
70 type ParentType = ToastableDialog;
71
72 fn class_init(klass: &mut Self::Class) {
73 Self::bind_template(klass);
74
75 klass.install_action("auth-dialog.continue", None, |obj, _, _| {
76 obj.imp().notifier().notify_value(Some(()));
77 });
78
79 klass.install_action("auth-dialog.close", None, |obj, _, _| {
80 obj.imp().close();
81 });
82 }
83
84 fn instance_init(obj: &InitializingObject<Self>) {
85 obj.init_template();
86 }
87 }
88
89 #[glib::derived_properties]
90 impl ObjectImpl for AuthDialog {
91 fn dispose(&self) {
92 if let Some(abort_handle) = self.abort_handle.take() {
93 abort_handle.abort();
94 }
95 }
96 }
97
    // The dialog does not override any virtual method of its parent classes.
    impl WidgetImpl for AuthDialog {}
    impl AdwDialogImpl for AuthDialog {}
    impl ToastableDialogImpl for AuthDialog {}
101
102 impl AuthDialog {
103 fn notifier(&self) -> &OneshotNotifier<Option<()>> {
105 self.notifier
106 .get_or_init(|| OneshotNotifier::new("AuthDialog"))
107 }
108
109 pub(super) async fn authenticate<Response, Fut, FN>(
115 &self,
116 parent: >k::Widget,
117 callback: FN,
118 ) -> Result<Response, AuthError>
119 where
120 Response: Send + 'static,
121 Fut: Future<Output = Result<Response, Error>> + Send + 'static,
122 FN: Fn(matrix_sdk::Client, Option<AuthData>) -> Fut + Send + Sync + 'static + Clone,
123 {
124 let Some(client) = self.session.upgrade().map(|s| s.client()) else {
125 return Err(AuthError::Unknown);
126 };
127
128 let callback_clone = callback.clone();
130 let client_clone = client.clone();
131 let handle = spawn_tokio!(async move { callback_clone(client_clone, None).await });
132 let result = self.await_tokio_task(handle).await;
133
134 let Some(uiaa_info) = result.uiaa_info() else {
136 return result;
137 };
138
139 let result = self
140 .perform_uiaa(uiaa_info.clone(), parent, move |auth_data| {
141 let client = client.clone();
142 let callback = callback.clone();
143 async move { callback(client, Some(auth_data)).await }
144 })
145 .await;
146
147 self.close();
148
149 result
150 }
151
152 pub(super) async fn reset_cross_signing(
163 &self,
164 parent: >k::Widget,
165 ) -> Result<(), AuthError> {
166 let Some(encryption) = self.session.upgrade().map(|s| s.client().encryption()) else {
167 return Err(AuthError::Unknown);
168 };
169
170 let handle = spawn_tokio!(async move { encryption.reset_cross_signing().await });
171 let result = self.await_tokio_task(handle).await?;
172
173 let Some(cross_signing_reset_handle) = result else {
174 return Ok(());
176 };
177
178 let result = match cross_signing_reset_handle.auth_type().clone() {
179 CrossSigningResetAuthType::Uiaa(uiaa_info) => {
180 let cross_signing_reset_handle = Arc::new(cross_signing_reset_handle);
181
182 self.perform_uiaa(uiaa_info, parent, move |auth_data| {
183 let cross_signing_reset_handle = cross_signing_reset_handle.clone();
184 async move { cross_signing_reset_handle.auth(Some(auth_data)).await }
185 })
186 .await
187 }
188 CrossSigningResetAuthType::OAuth(info) => {
189 let page = AuthDialogInBrowserPage::new(info.approval_url.to_string());
191 let default_widget = page.default_widget().clone();
192
193 self.show_page(page.upcast(), &default_widget, parent);
194
195 let handle =
197 spawn_tokio!(async move { cross_signing_reset_handle.auth(None).await });
198 self.await_tokio_task(handle).await
199 }
200 };
201
202 self.close();
203
204 result
205 }
206
207 async fn await_tokio_task<Response>(
209 &self,
210 handle: JoinHandle<Result<Response, Error>>,
211 ) -> Result<Response, AuthError>
212 where
213 Response: Send + 'static,
214 {
215 self.abort_handle.replace(Some(handle.abort_handle()));
216
217 let Ok(result) = handle.await else {
218 return Err(AuthError::UserCancelled);
220 };
221
222 self.abort_handle.take();
223
224 Ok(result?)
225 }
226
227 async fn perform_uiaa<Response, Fut, FN>(
230 &self,
231 mut uiaa_info: UiaaInfo,
232 parent: >k::Widget,
233 callback: FN,
234 ) -> Result<Response, AuthError>
235 where
236 Response: Send + 'static,
237 Fut: Future<Output = Result<Response, Error>> + Send + 'static,
238 FN: Fn(AuthData) -> Fut + Send + Sync + 'static + Clone,
239 {
240 loop {
241 let callback = callback.clone();
242
243 let auth_data = self.perform_next_stage(&uiaa_info, parent).await?;
244
245 let handle = spawn_tokio!(async move { callback(auth_data).await });
247 let result = self.await_tokio_task(handle).await;
248
249 let Some(next_uiaa_info) = result.uiaa_info() else {
251 return result;
252 };
253
254 uiaa_info = next_uiaa_info.clone();
255 }
256 }
257
258 async fn perform_next_stage(
265 &self,
266 uiaa_info: &UiaaInfo,
267 parent: >k::Widget,
268 ) -> Result<AuthData, AuthError> {
269 let Some(next_state) = AuthState::next(uiaa_info) else {
270 error!("Cannot perform next stage when flow is complete");
272 return Err(AuthError::Unknown);
273 };
274
275 if matches!(next_state.stage, AuthType::Dummy) {
276 self.state.replace(Some(next_state));
278 return self.current_stage_auth_data();
279 }
280
281 let receiver = self.notifier().listen();
282
283 let is_same_state = self
285 .state
286 .borrow()
287 .as_ref()
288 .is_some_and(|state| *state == next_state);
289
290 if is_same_state {
291 self.retry_current_stage(&next_state.stage, uiaa_info);
292 } else {
293 let (next_page, default_widget) = self.page(&next_state).await?;
294 self.show_page(next_page, &default_widget, parent);
295 self.state.replace(Some(next_state));
296 }
297
298 if receiver.await.is_none() {
299 return Err(AuthError::UserCancelled);
301 }
302
303 self.current_stage_auth_data()
304 }
305
306 fn retry_current_stage(&self, stage: &AuthType, uiaa_info: &UiaaInfo) {
308 if let Some(error) = &uiaa_info.auth_error {
310 warn!("Could not perform authentication stage: {}", error.message);
311
312 if matches!(stage, AuthType::Password) {
313 toast!(self.stack, gettext("The password is invalid."));
314 } else {
315 toast!(self.stack, gettext("An unexpected error occurred."));
316 }
317 }
318
319 if let Some(page) = self.current_page.borrow().as_ref() {
321 if let Some(password_page) = page.downcast_ref::<AuthDialogPasswordPage>() {
322 password_page.retry();
323 } else if let Some(in_browser_page) = page.downcast_ref::<AuthDialogInBrowserPage>()
324 {
325 in_browser_page.retry();
326 }
327 }
328 }
329
330 fn show_page(&self, page: gtk::Widget, default_widget: >k::Widget, parent: >k::Widget) {
332 self.stack.add_child(&page);
333 self.stack.set_visible_child(&page);
334 self.obj().set_default_widget(Some(default_widget));
335
336 let prev_page = self.current_page.replace(Some(page));
337
338 if let Some(page) = prev_page {
340 let cell = Rc::new(RefCell::new(None));
341
342 let handler = self.stack.connect_transition_running_notify(clone!(
343 #[strong]
344 cell,
345 #[strong]
346 page,
347 move |stack| {
348 if !stack.is_transition_running()
349 && stack.visible_child().is_some_and(|child| child != page)
350 {
351 stack.remove(&page);
352
353 if let Some(handler) = cell.take() {
354 stack.disconnect(handler);
355 }
356 }
357 }
358 ));
359
360 cell.replace(Some(handler));
361 }
362
363 if !self.is_presented.get() {
365 self.obj().present(Some(parent));
366 self.is_presented.set(true);
367 }
368 }
369
370 async fn page(&self, state: &AuthState) -> Result<(gtk::Widget, gtk::Widget), AuthError> {
374 if state.stage == AuthType::Password {
375 let page = AuthDialogPasswordPage::new();
376 let default_widget = page.default_widget().clone();
377 Ok((page.upcast(), default_widget))
378 } else {
379 let fallback_url = self.fallback_url(state).await?;
380 let page = AuthDialogInBrowserPage::new(fallback_url);
381 let default_widget = page.default_widget().clone();
382 Ok((page.upcast(), default_widget))
383 }
384 }
385
386 async fn fallback_url(&self, state: &AuthState) -> Result<String, AuthError> {
388 let Some(session) = self.session.upgrade() else {
389 return Err(AuthError::Unknown);
390 };
391
392 let uiaa_session = state.session.clone().ok_or(AuthError::MissingSessionId)?;
393
394 let request =
395 get_uiaa_fallback_page::v3::Request::new(state.stage.clone(), uiaa_session);
396
397 let client = session.client();
398 let homeserver = client.homeserver();
399
400 let handle =
401 spawn_tokio!(async move { client.supported_versions().await.map_err(Into::into) });
402 let result = self.await_tokio_task(handle).await;
403
404 let supported_versions = match result {
405 Ok(supported_versions) => supported_versions,
406 Err(AuthError::ServerResponse(server_error)) => {
407 warn!("Could not get Matrix versions supported by homeserver: {server_error}");
408 SupportedVersions {
410 versions: [MatrixVersion::V1_1].into(),
411 features: Default::default(),
412 }
413 }
414 Err(error) => {
415 return Err(error);
416 }
417 };
418
419 let http_request = match request.try_into_http_request::<Vec<u8>>(
420 homeserver.as_ref(),
421 SendAccessToken::None,
422 Cow::Owned(supported_versions),
423 ) {
424 Ok(http_request) => http_request,
425 Err(error) => {
426 error!("Could not construct fallback UIAA URL: {error}");
427 return Err(AuthError::Unknown);
428 }
429 };
430
431 Ok(http_request.uri().to_string())
432 }
433
434 fn current_stage_auth_data(&self) -> Result<AuthData, AuthError> {
436 let Some(state) = self.state.borrow().clone() else {
437 error!("Could not get current authentication state");
438 return Err(AuthError::Unknown);
439 };
440
441 let auth_data = match state.stage {
442 AuthType::Password => {
443 let password = self
444 .current_page
445 .borrow()
446 .as_ref()
447 .and_then(|page| page.downcast_ref::<AuthDialogPasswordPage>())
448 .ok_or_else(|| {
449 error!(
450 "Could not get password because current page is not password page"
451 );
452 AuthError::Unknown
453 })?
454 .password();
455
456 let user_id = self
457 .session
458 .upgrade()
459 .ok_or(AuthError::Unknown)?
460 .user_id()
461 .to_string();
462
463 AuthData::Password(assign!(
464 Password::new(UserIdentifier::UserIdOrLocalpart(user_id), password),
465 { session: state.session }
466 ))
467 }
468 AuthType::Dummy => AuthData::Dummy(assign!(Dummy::new(), {
469 session: state.session
470 })),
471 _ => {
472 let uiaa_session = state.session.ok_or(AuthError::MissingSessionId)?;
473
474 AuthData::FallbackAcknowledgement(FallbackAcknowledgement::new(uiaa_session))
475 }
476 };
477
478 Ok(auth_data)
479 }
480
481 fn close(&self) {
483 if self.is_presented.get() {
484 self.obj().close();
485 }
486
487 if let Some(abort_handle) = self.abort_handle.take() {
488 abort_handle.abort();
489 }
490
491 self.notifier().notify();
492 }
493 }
494}
495
glib::wrapper! {
    /// Dialog to guide the user through the authentication required by the
    /// homeserver.
    pub struct AuthDialog(ObjectSubclass<imp::AuthDialog>)
        @extends gtk::Widget, adw::Dialog, ToastableDialog,
        @implements gtk::Accessible, gtk::Buildable, gtk::ConstraintTarget, gtk::ShortcutManager;
}
504
505impl AuthDialog {
506 pub fn new(session: &Session) -> Self {
507 glib::Object::builder().property("session", session).build()
508 }
509
510 pub(crate) async fn authenticate<Response, Fut, FN>(
516 &self,
517 parent: &impl IsA<gtk::Widget>,
518 callback: FN,
519 ) -> Result<Response, AuthError>
520 where
521 Response: Send + 'static,
522 Fut: Future<Output = Result<Response, Error>> + Send + 'static,
523 FN: Fn(matrix_sdk::Client, Option<AuthData>) -> Fut + Send + Sync + 'static + Clone,
524 {
525 self.imp().authenticate(parent.upcast_ref(), callback).await
526 }
527
528 pub(crate) async fn reset_cross_signing(
534 &self,
535 parent: &impl IsA<gtk::Widget>,
536 ) -> Result<(), AuthError> {
537 self.imp().reset_cross_signing(parent.upcast_ref()).await
538 }
539}
540
/// The state of an authentication flow at a given stage.
#[derive(Debug, Clone, PartialEq, Eq)]
struct AuthState {
    /// The stages that were already completed, as reported by the homeserver.
    completed: Vec<AuthType>,

    /// The stage to perform.
    stage: AuthType,

    /// The ID of the UIAA session, if the homeserver provided one.
    session: Option<String>,
}
553
554impl AuthState {
555 fn next(uiaa_info: &UiaaInfo) -> Option<Self> {
559 let stages = uiaa_info
563 .flows
564 .iter()
565 .filter_map(|flow| flow.stages.strip_prefix(uiaa_info.completed.as_slice()))
566 .filter_map(|stages_left| stages_left.first());
567
568 let mut next_stage = None;
570 for stage in stages {
571 if matches!(stage, AuthType::Password | AuthType::Sso | AuthType::Dummy) {
572 next_stage = Some(stage);
574 break;
575 } else if next_stage.is_none() {
576 next_stage = Some(stage);
578 }
579 }
580
581 let stage = next_stage?.clone();
582
583 Some(Self {
584 completed: uiaa_info.completed.clone(),
585 stage,
586 session: uiaa_info.session.clone(),
587 })
588 }
589}
590
/// An error that can occur during authentication.
#[derive(Debug, Error)]
pub enum AuthError {
    /// An error returned by the Matrix SDK.
    #[error(transparent)]
    ServerResponse(#[from] Error),

    /// The ID of the UIAA session is missing, but it is required for the
    /// current stage.
    #[error("The ID of the session is missing")]
    MissingSessionId,

    /// The user cancelled the authentication, for example by closing the
    /// dialog.
    #[error("The user cancelled the authentication")]
    UserCancelled,

    /// An unexpected error occurred.
    #[error("An unexpected error occurred")]
    Unknown,
}
610
/// Helper trait to extract the UIAA info from an error, if any.
trait ExtractUiaa {
    /// The UIAA info contained in this value, if any.
    ///
    /// When present, the homeserver requires (more) user-interactive
    /// authentication.
    fn uiaa_info(&self) -> Option<&UiaaInfo>;
}
616
617impl ExtractUiaa for AuthError {
618 fn uiaa_info(&self) -> Option<&UiaaInfo> {
619 if let Self::ServerResponse(server_error) = self {
620 server_error.as_uiaa_response()
621 } else {
622 None
623 }
624 }
625}
626
627impl ExtractUiaa for Error {
628 fn uiaa_info(&self) -> Option<&UiaaInfo> {
629 self.as_uiaa_response()
630 }
631}
632
633impl<T, Err> ExtractUiaa for Result<T, Err>
634where
635 Err: ExtractUiaa,
636{
637 fn uiaa_info(&self) -> Option<&UiaaInfo> {
638 match self {
639 Ok(_) => None,
640 Err(error) => error.uiaa_info(),
641 }
642 }
643}