//! Framework for embedding localizations into Rust types.
use crate::fluent;
use crate::macro_impl::derive::ReferenceKind;

use icu_locale::locale;
use quote::quote;
use syn::{parse_quote_spanned, spanned::Spanned};
use thiserror::Error;

mod attribute;
pub mod derive;
pub mod error;

#[derive(Debug, Error)]
pub enum UnsupportedError {
    #[error("Unions are not supported")]
    Union { span: syn::Ident },
    #[error("Only 1 unnamed field is supported")]
    UnnamedFields {
        span: syn::Ident,
        field_count: usize,
    },
}

#[derive(Debug, Error)]
#[error("Failed to parse macro input")]
pub enum ParseError {
    #[error("invalid attribute")]
    InvalidAttribute(syn::Error),
    #[error("invalid item")]
    InvalidDeriveInput(syn::Error),
}

/// Top-level error returned by [`localize`].
///
/// Every variant is transparent: its `Display` and `source` forward to the wrapped error.
/// The `#[from]` conversions let `?` lift each component error into this type.
#[derive(Debug, Error)]
pub enum MacroError {
    /// An error from the `attribute` module (e.g. while resolving the locales).
    #[error(transparent)]
    Attribute(#[from] attribute::Error),

    /// An error from building the Fluent group.
    #[error(transparent)]
    Group(#[from] fluent::GroupError),

    /// The annotated item has a shape the macro does not support.
    #[error(transparent)]
    Unsupported(#[from] UnsupportedError),

    /// One of the macro's token streams could not be parsed.
    #[error(transparent)]
    ParseError(#[from] ParseError),
}

pub fn localize(
    attribute_stream: proc_macro2::TokenStream,
    derive_input_stream: proc_macro2::TokenStream,
) -> Result<proc_macro2::TokenStream, MacroError> {
    // Set up a global Miette report handler to ensure consistent non-Rustc diagnostics
    // If this returns an error, it just means the hook has already been set.
    let _result = miette::set_hook(Box::new(|_| {
        Box::new(
            miette::MietteHandlerOpts::new()
                // Force color output, even when printing using the debug formatter
                .color(true)
                .build(),
        )
    }));
    miette::set_panic_hook();

    // Parse the token streams
    let attribute: syn::LitStr =
        syn::parse2(attribute_stream).map_err(ParseError::InvalidAttribute)?;
    let derive_input: syn::DeriveInput =
        syn::parse2(derive_input_stream).map_err(ParseError::InvalidDeriveInput)?;

    let locales = attribute::locales(&attribute)?;

    // Keep track of all the Fluent files
    let tracked_paths = locales.clone();
    let tracked_paths = tracked_paths.values().map(|path| path.to_string());
    let path_count = locales.len();

    // TODO: user-controlled canonical locale
    let group = fluent::Group::new(locale!("en-US"), locales)?;
    let canonical_locale = group.canonical_locale().id.clone().to_string();

    let available_locales = match &derive_input.data {
        syn::Data::Struct(struct_data) => derive::locales_for_ident(
            &group,
            &struct_data.fields,
            ReferenceKind::StructField,
            &derive_input.ident,
        ),
        syn::Data::Enum(enum_data) => derive::locales_for_enum(&group, &enum_data.variants),
        syn::Data::Union(_) => {
            return Err(MacroError::Unsupported(UnsupportedError::Union {
                span: derive_input.ident.clone(),
            }));
        }
    };

    let message_body = match &derive_input.data {
        syn::Data::Struct(struct_data) => {
            derive::message_for_struct(group, &derive_input.ident, &struct_data.fields)
        }
        syn::Data::Enum(enum_data) => derive::messages_for_enum(group, &enum_data.variants),
        syn::Data::Union(_) => {
            return Err(MacroError::Unsupported(UnsupportedError::Union {
                span: derive_input.ident.clone(),
            }));
        }
    }?;

    let ident = &derive_input.ident;

    // Get the original generics for the derived item
    let (initial_impl_generics, initial_type_generics, initial_where_clause) =
        derive_input.generics.split_for_impl();

    // Get the types of each named field
    let field_types: Vec<&syn::Type> = match &derive_input.data {
        syn::Data::Struct(struct_data) => {
            types_for_fields(&struct_data.fields, derive_input.ident.clone())?
        }
        syn::Data::Enum(enum_data) => enum_data
            .variants
            .iter()
            .map(|variant| types_for_fields(&variant.fields, variant.ident.clone()))
            .collect::<Result<Vec<Vec<&syn::Type>>, UnsupportedError>>()?
            .into_iter()
            .flatten()
            .collect(),
        syn::Data::Union(_union_data) => {
            return Err(MacroError::Unsupported(UnsupportedError::Union {
                span: derive_input.ident.clone(),
            }));
        }
    };

    // Add a bound on `Localize` for each field's type
    let mut generics = derive_input.generics.clone();
    let additional_bounds = field_types.into_iter().map(|field| -> syn::WherePredicate {
        // Attribute this bound to the original source code
        let span = field.span();
        parse_quote_spanned!(span=> #field: ::l10n_embed::Localize)
    });
    generics
        .make_where_clause()
        .predicates
        .extend(additional_bounds);

    let (impl_generics, _type_generics, where_clause) = generics.split_for_impl();

    Ok(quote! {
        impl #initial_impl_generics #ident #initial_type_generics #initial_where_clause {
            // Call the `include_str!` macro to make sure the Fluent files are tracked
            // so when Fluent code changes, the generated code should be rebuild
            // TODO: This is a hack that should be replaced with https://github.com/rust-lang/rust/issues/99515 once stable
            const _TRACKED_PATHS: [&'static str; #path_count] = [#(include_str!(#tracked_paths)),*];
        }

        impl #impl_generics ::l10n_embed::Localize for #ident #initial_type_generics #where_clause {
            fn canonical_locale(&self) -> ::l10n_embed::macro_prelude::icu_locale::Locale {
                ::l10n_embed::macro_prelude::icu_locale::locale!(#canonical_locale)
            }

            fn available_locales(&self) -> Vec<::l10n_embed::macro_prelude::icu_locale::Locale> {
                #available_locales
            }

            fn localize_for(
                &self,
                locale: &::l10n_embed::macro_prelude::icu_locale::Locale,
            ) -> String {
                #message_body
            }
        }
    })
}

/// Returns the types of `fields` in declaration order.
///
/// Named fields are always accepted, and unit items yield an empty list. Unnamed (tuple)
/// fields are accepted only when there is exactly one of them. Otherwise an
/// [`UnsupportedError::UnnamedFields`] is returned, carrying `span` and the number of
/// fields actually found.
fn types_for_fields(
    fields: &syn::Fields,
    span: syn::Ident,
) -> Result<Vec<&syn::Type>, UnsupportedError> {
    let types: Vec<&syn::Type> = fields.iter().map(|field| &field.ty).collect();

    // Only tuple-style fields are arity-restricted.
    let is_tuple = matches!(fields, syn::Fields::Unnamed(_));
    if is_tuple && types.len() != 1 {
        return Err(UnsupportedError::UnnamedFields {
            span,
            field_count: types.len(),
        });
    }

    Ok(types)
}