use std::collections::HashMap;
use std::collections::HashSet;
use heck::ToPascalCase;
use heck::ToSnakeCase;
use proc_macro2::Ident;
use proc_macro2::TokenStream;
use quote::format_ident;
use quote::quote;
use std::ops::Bound;
use tracing::warn;
use crate::passes::dependency_analysis::DependencyGraph;
use crate::passes::type_resolution::Type;
use crate::types::cooked::Amount;
use crate::types::cooked::CodeSet;
use crate::types::cooked::Date;
use crate::types::cooked::DateTime;
use crate::types::cooked::Definition;
use crate::types::cooked::Indicator;
use crate::types::cooked::MessageComponent;
use crate::types::cooked::MessageDefinition;
use crate::types::cooked::MessageElement;
use crate::types::cooked::Quantity;
use crate::types::cooked::Text;
use crate::types::external_code_sets::ExternalCodeSet;
/// Generates the complete e-repository module: the `Iso20022Message` root enum
/// plus one item per definition, emitted in dependency (top-down) order.
///
/// Definitions listed in `type_overrides` are skipped because their Rust types
/// are supplied manually elsewhere (see the `manual` module re-export below).
pub fn generate_erepository(
    definitions: &HashSet<Definition>,
    graph: &DependencyGraph,
    type_overrides: &HashMap<&str, Type>,
    types: &HashMap<&str, Type>,
) -> TokenStream {
    // Root nodes of the dependency graph that are message definitions become
    // the variants of the top-level message enum.
    let mut message_definitions = Vec::new();
    for id in graph.roots() {
        if let Definition::MessageDefinition(definition) = definitions.get(id).unwrap() {
            message_definitions.push(definition);
        }
    }
    let root = generate_root(message_definitions, types);
    // Walk every definition in dependency order and dispatch to the matching
    // generator; unsupported definition kinds are logged and skipped.
    let generated = graph
        .top_down()
        .filter(|id| !type_overrides.contains_key(id))
        .map(|id| {
            let definition = definitions.get(id).unwrap();
            match definition {
                Definition::Amount(amount) => generate_amount(amount, types),
                Definition::ChoiceComponent(component) => {
                    generate_choice_component(component, types)
                }
                Definition::CodeSet(code_set) => generate_code_set(code_set, types),
                Definition::Date(date) => generate_date(date, types),
                Definition::DateTime(date_time) => generate_date_time(date_time, types),
                Definition::Indicator(indicator) => generate_indicator(indicator, types),
                Definition::MessageComponent(component) => {
                    generate_message_component(component, types)
                }
                Definition::MessageDefinition(definition) => {
                    generate_message_definition(definition, types)
                }
                Definition::Quantity(quantity) => generate_quantity(quantity, types),
                Definition::Rate(rate) => generate_rate(rate, types),
                Definition::Text(text) => generate_text(text, types),
                Definition::YearMonth(year_month) => generate_year_month(year_month, types),
                _ => {
                    warn!(
                        ?definition,
                        "skipping codegen due to unsupported definition type"
                    );
                    TokenStream::new()
                }
            }
        });
    quote! {
        use core::str::FromStr as _;
        use beef::lean::Cow;
        use quick_xml::events::BytesStart;
        use rust_decimal::Decimal;
        use time::Date;
        use time::OffsetDateTime;
        use crate::parser::XmlReader;
        use crate::XmlError;
        use crate::XmlResult;
        mod external_code_sets;
        mod manual;
        pub use self::manual::*;
        #root
        #(#generated)*
    }
}
/// Generates the `external_code_sets` module body: one enum per externally
/// maintained code set, plus the imports the generated parsers need.
pub fn generate_external_code_sets(
    external_code_sets: &HashMap<&str, ExternalCodeSet>,
) -> TokenStream {
    let generated = external_code_sets
        .iter()
        .map(|(name, code_set)| generate_external_code_set(name, code_set));
    quote! {
        use quick_xml::events::BytesStart;
        use crate::parser::XmlReader;
        use crate::XmlResult;
        #(#generated)*
    }
}
/// Generates a struct wrapping a `Decimal` amount, optionally carrying the
/// currency parsed from the `Ccy` XML attribute when the amount definition
/// references a currency identifier set.
fn generate_amount(amount: &Amount, types: &HashMap<&str, Type>) -> TokenStream {
    let doc = &amount.definition;
    let ty = types[&*amount.id];
    // When the amount carries a currency, generate the extra struct field, the
    // attribute-parsing statements, and the struct-literal initializer; with
    // no currency set all three default to empty token streams.
    //
    // BUG FIX: `&currency` had been mangled into the HTML `&curren;` entity
    // (rendering as `¤cy`), which is not a valid Rust token and broke the
    // generated `from_str` call; restored to `&currency`.
    let (currency_field, currency_parser, currency_assignment) = if let Some(id) =
        &amount.currency_identifier_set
    {
        let ty = types[&**id];
        (
            quote!(pub currency: #ty,),
            quote! {
                let currency = {
                    let currency = start.try_get_attribute(b"Ccy")?.ok_or_else(|| -> XmlError { todo!() })?;
                    let currency = currency.decode_and_unescape_value(reader.inner())?;
                    #ty::from_str(&currency)?
                };
            },
            quote!(, currency),
        )
    } else {
        Default::default()
    };
    quote! {
        #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
        #[doc = #doc]
        pub struct #ty {
            amount: Decimal,
            #currency_field
        }
        impl<'a> #ty {
            pub(crate) fn parse(reader: &mut XmlReader<'a>, start: &BytesStart<'_>) -> XmlResult<Self> {
                let text = reader.expect_text()?;
                let amount = Decimal::from_str(&text).unwrap();
                #currency_parser
                Ok(Self { amount #currency_assignment })
            }
        }
    }
}
/// Generates an enum for a choice component: one variant per possible child
/// element, dispatched at parse time on the peeked XML start-tag name.
fn generate_choice_component(
    component: &MessageComponent,
    types: &HashMap<&str, Type>,
) -> TokenStream {
    let docs = &component.definition;
    let type_name = types[&*component.id];
    let variants = generate_variants(&component.elements, types);
    let arms = generate_choice_variant_parsers(&component.elements, types);
    quote! {
        #[derive(Clone, Debug)]
        #[doc = #docs]
        pub enum #type_name {
            #(#variants)*
        }
        impl<'a> #type_name {
            pub(crate) fn parse(reader: &mut XmlReader<'a>, _start: &BytesStart<'a>) -> XmlResult<Self> {
                let value = match reader.peek_start_name()?.into_inner() {
                    #(#arms),*
                    _ => todo!(),
                };
                Ok(value)
            }
        }
    }
}
fn generate_code_set(code_set: &CodeSet, types: &HashMap<&str, Type>) -> TokenStream {
let ty = types[&*code_set.id];
if code_set.is_external {
return if code_set.derivation.is_some() {
TokenStream::new()
} else {
quote! {
pub use self::external_code_sets::#ty;
}
};
}
if let Some(trace_id) = &code_set.trace {
let trace_ty = types[&**trace_id];
return quote! {
pub type #ty = #trace_ty;
};
}
if code_set.codes.is_empty() {
tracing::warn!(%code_set.name, "skipping codegen for empty code set");
return TokenStream::new();
}
let doc = &*code_set.definition;
let codes = code_set.codes.iter().map(|code| {
let doc = &code.definition;
let name = generate_variant_name(&code.name);
quote! {
#[doc = #doc]
#name,
}
});
let parsers = code_set
.codes
.iter()
.map(|code| (&code.code, generate_variant_name(&code.name)));
let parsers = parsers.map(|(tag, name)| quote!(#tag => Self::#name,));
quote! {
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[doc = #doc]
pub enum #ty {
#(#codes)*
}
impl<'a> #ty {
pub(crate) fn parse(reader: &mut XmlReader<'a>, _start: &BytesStart<'_>) -> XmlResult<Self> {
let value = match &*reader.expect_text()? {
#(#parsers)*
tag => todo!("handle code {tag:?}"),
};
Ok(value)
}
}
}
}
/// Generates a newtype over `time::Date` parsed from ISO 8601 element text.
fn generate_date(date: &Date, types: &HashMap<&str, Type>) -> TokenStream {
    let docs = &date.definition;
    let type_name = types[&*date.id];
    quote! {
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        #[doc = #docs]
        pub struct #type_name(Date);
        impl<'a> #type_name {
            pub(crate) fn parse(reader: &mut XmlReader<'a>, _start: &BytesStart<'_>) -> XmlResult<Self> {
                let text = reader.expect_text()?;
                let inner = Date::parse(&text, &time::format_description::well_known::Iso8601::PARSING).unwrap();
                Ok(Self(inner))
            }
        }
    }
}
/// Generates a newtype over `OffsetDateTime` parsed from ISO 8601 element
/// text.
fn generate_date_time(date_time: &DateTime, types: &HashMap<&str, Type>) -> TokenStream {
    let docs = &date_time.definition;
    let type_name = types[&*date_time.id];
    quote! {
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        #[doc = #docs]
        pub struct #type_name(OffsetDateTime);
        impl<'a> #type_name {
            pub(crate) fn parse(reader: &mut XmlReader<'a>, _start: &BytesStart<'_>) -> XmlResult<Self> {
                let text = reader.expect_text()?;
                let inner = OffsetDateTime::parse(&text, &time::format_description::well_known::Iso8601::PARSING).unwrap();
                Ok(Self(inner))
            }
        }
    }
}
/// Generates the enum for one externally maintained code set, where variant
/// names and XML tags both come straight from the set's value list.
fn generate_external_code_set(name: &str, code_set: &ExternalCodeSet) -> TokenStream {
    if code_set.values.is_empty() {
        tracing::warn!(name, "skipping codegen for empty external code set");
        return TokenStream::new();
    }
    let docs = &code_set.description;
    let type_name = generate_ident(name);
    let variants = code_set.values.iter().map(|code| {
        let variant = generate_variant_name(code);
        quote!(#variant,)
    });
    // Match arms mapping the raw code string onto its enum variant.
    let arms = code_set.values.iter().map(|code| {
        let variant = generate_variant_name(code);
        quote!(#code => Self::#variant,)
    });
    quote! {
        #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
        #[doc = #docs]
        pub enum #type_name {
            #(#variants)*
        }
        impl<'a> #type_name {
            pub(crate) fn parse(reader: &mut XmlReader<'a>, _start: &BytesStart<'_>) -> XmlResult<Self> {
                let value = match &*reader.expect_text()? {
                    #(#arms)*
                    tag => todo!("handle code {tag:?}"),
                };
                Ok(value)
            }
        }
    }
}
/// Generates a two-variant enum for a boolean indicator, naming the variants
/// after the indicator's documented true/false meanings.
fn generate_indicator(indicator: &Indicator, types: &HashMap<&str, Type>) -> TokenStream {
    let docs = &indicator.definition;
    let type_name = types[&*indicator.id];
    let meaning_when_false = format_ident!("{}", indicator.meaning_when_false.to_pascal_case());
    let meaning_when_true = format_ident!("{}", indicator.meaning_when_true.to_pascal_case());
    quote! {
        #[derive(Clone, Debug)]
        #[doc = #docs]
        pub enum #type_name {
            #meaning_when_false,
            #meaning_when_true,
        }
        impl<'a> #type_name {
            fn parse(reader: &mut XmlReader<'a>, _start: &BytesStart<'_>) -> XmlResult<Self> {
                let value = if reader.expect_boolean()? {
                    Self::#meaning_when_true
                } else {
                    Self::#meaning_when_false
                };
                Ok(value)
            }
        }
    }
}
/// Generates a struct for a message component: one field per child element,
/// parsed in declaration order.
fn generate_message_component(
    component: &MessageComponent,
    types: &HashMap<&str, Type>,
) -> TokenStream {
    let docs = &component.definition;
    let type_name = types[&*component.id];
    let names = component
        .elements
        .iter()
        .map(|element| generate_field_name(&element.name));
    let declarations = generate_field_declarations(&component.elements, types);
    let parsers = generate_field_parsers(&component.elements, types);
    quote! {
        #[derive(Clone, Debug)]
        #[doc = #docs]
        pub struct #type_name {
            #(#declarations)*
        }
        impl<'a> #type_name {
            fn parse(reader: &mut XmlReader<'a>, _start: &BytesStart<'_>) -> XmlResult<Self> {
                #(#parsers)*
                Ok(Self {
                    #(#names),*
                })
            }
        }
    }
}
/// Generates a struct for a top-level message definition, including a public
/// `from_str` entry point that unwraps the XML declaration, the document root
/// tag, and the message tag before delegating to the field parsers.
fn generate_message_definition(
    definition: &MessageDefinition,
    types: &HashMap<&str, Type>,
) -> TokenStream {
    let docs = &definition.definition;
    let type_name = types[&*definition.id];
    let root_tag = &definition.xml_root_tag;
    let tag = &definition.xml_tag;
    let names = definition
        .elements
        .iter()
        .map(|element| generate_field_name(&element.name));
    let declarations = generate_field_declarations(&definition.elements, types);
    let parsers = generate_field_parsers(&definition.elements, types);
    quote! {
        #[derive(Clone, Debug)]
        #[doc = #docs]
        pub struct #type_name {
            #(#declarations)*
        }
        impl<'a> #type_name {
            pub fn from_str(document: &'a str) -> XmlResult<Self> {
                let mut reader = XmlReader::from_str(document);
                reader.try_declaration()?;
                reader.parse_element(#root_tag, |reader, _| reader.parse_element(#tag, Self::parse))
            }
        }
        impl<'a> #type_name {
            fn parse(reader: &mut XmlReader<'a>, _start: &BytesStart<'_>) -> XmlResult<Self> {
                #(#parsers)*
                Ok(Self {
                    #(#names),*
                })
            }
        }
    }
}
/// Generates a newtype over `Decimal` parsed directly from element text.
fn generate_quantity(quantity: &Quantity, types: &HashMap<&str, Type>) -> TokenStream {
    let docs = &quantity.definition;
    let type_name = types[&*quantity.id];
    quote! {
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        #[doc = #docs]
        pub struct #type_name(Decimal);
        impl<'a> #type_name {
            fn parse(reader: &mut XmlReader<'a>, _start: &BytesStart<'_>) -> XmlResult<Self> {
                let text = reader.expect_text()?;
                let inner = Decimal::from_str(&text).unwrap();
                Ok(Self(inner))
            }
        }
    }
}
/// Generates a newtype over `Decimal` whose parsed value is divided by the
/// rate's base value (e.g. 100 for a percentage expressed per-cent).
fn generate_rate(rate: &crate::types::cooked::Rate, types: &HashMap<&str, Type>) -> TokenStream {
    let doc = &rate.definition;
    let ty = types[&*rate.id];
    // Emit the base value as a `const` Decimal, decomposed into the 96-bit
    // mantissa words, sign, and scale that `Decimal::from_parts` expects.
    let base = {
        let base = rate.base_value;
        // `mantissa()` is signed (i128); take the magnitude and carry the
        // sign separately so a negative base value generates a correct
        // constant instead of panicking in `u128::try_from(..).unwrap()`
        // (the previous code also hardcoded the sign to positive).
        let mantissa = base.mantissa().unsigned_abs();
        let negative = base.is_sign_negative();
        let lo = mantissa as u32;
        let mid = (mantissa >> 32) as u32;
        let high = (mantissa >> 64) as u32;
        let scale = base.scale();
        quote!(const BASE: Decimal = Decimal::from_parts(#lo, #mid, #high, #negative, #scale);)
    };
    quote! {
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        #[doc = #doc]
        pub struct #ty(Decimal);
        impl<'a> #ty {
            fn parse(reader: &mut XmlReader<'a>, _start: &BytesStart<'_>) -> XmlResult<Self> {
                #base
                let text = reader.expect_text()?;
                let inner = Decimal::from_str(&text).unwrap() / BASE;
                Ok(Self(inner))
            }
        }
    }
}
/// Generates the top-level `Iso20022Message` enum with one variant per root
/// message definition, dispatching on the (root tag, xmlns namespace) pair.
fn generate_root(
    message_definitions: Vec<&MessageDefinition>,
    types: &HashMap<&str, Type>,
) -> TokenStream {
    let variants = message_definitions.iter().map(|definition| {
        let docs = &definition.definition;
        let variant = generate_variant_name(&definition.name);
        let type_name = types[&*definition.id];
        quote! {
            #[doc = #docs]
            #variant(#type_name),
        }
    });
    // One match arm per message definition, keyed on root tag + namespace.
    let parsers = message_definitions.iter().map(move |definition| {
        let root_tag = &definition.xml_root_tag;
        let tag = &definition.xml_tag;
        let type_name = types[&*definition.id];
        let variant = generate_variant_name(&definition.name);
        let namespace = definition.identifier.xml_namespace();
        quote! {
            (#root_tag, #namespace) => {
                let inner = reader.parse_element(#tag, <#type_name>::parse)?;
                Self::#variant(inner)
            }
        }
    });
    quote! {
        #[derive(Clone, Debug)]
        pub enum Iso20022Message<'a> {
            #(#variants)*
        }
        impl<'a> TryFrom<&'a str> for Iso20022Message<'a> {
            type Error = XmlError;
            fn try_from(document: &'a str) -> XmlResult<Self> {
                let mut reader = XmlReader::from_str(document);
                reader.try_declaration()?;
                let root_start = reader.expect_start_tag()?;
                let root_tag = core::str::from_utf8(root_start.name().into_inner())?;
                let namespace = root_start.try_get_attribute("xmlns")?.ok_or_else(|| -> XmlError { todo!() })?.decode_and_unescape_value(reader.inner())?;
                let value = match (root_tag, &*namespace) {
                    #(#parsers),*
                    _ => todo!(),
                };
                reader.expect_element_end(root_start)?;
                Ok(value)
            }
        }
    }
}
/// Generates a `Cow<str>` newtype for a text type, adding a length guard when
/// the schema constrains the text length.
fn generate_text(text: &Text, types: &HashMap<&str, Type>) -> TokenStream {
    let docs = &text.definition;
    let type_name = types[&*text.id];
    let length_check = match &text.length {
        Some(range) => {
            let min = range.start();
            let max = range.end();
            Some(quote! {
                if !(#min..=#max).contains(&text.len()) {
                    todo!()
                }
            })
        }
        None => None,
    };
    quote! {
        #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        #[doc = #docs]
        pub struct #type_name(Cow<'a, str>);
        impl<'a> #type_name {
            fn parse(reader: &mut XmlReader<'a>, _start: &BytesStart<'_>) -> XmlResult<Self> {
                let text = reader.expect_text()?;
                #length_check
                Ok(Self(text))
            }
        }
    }
}
/// Generates a year-month struct parsed from `YYYY-MM` element text, with a
/// hand-written `Ord`/`PartialOrd` ordered by (year, month).
///
/// `Ord` cannot be derived here because the generated field order is
/// (month, year) — deriving would sort by month first.
fn generate_year_month(
    year_month: &crate::types::cooked::YearMonth,
    types: &HashMap<&str, Type>,
) -> TokenStream {
    let doc = &year_month.definition;
    let ty = types[&*year_month.id];
    // NOTE(review): the generated parser rejects trailing characters after
    // "[year]-[month]" via the `rest.is_empty()` check, but both that branch
    // and the parse error are still `todo!()` stubs, matching the error
    // handling elsewhere in this generator.
    quote! {
        #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
        #[doc = #doc]
        pub struct #ty {
            pub month: time::Month,
            pub year: i32,
        }
        impl<'a> #ty {
            fn parse(reader: &mut XmlReader<'a>, _start: &BytesStart<'_>) -> XmlResult<Self> {
                let text = reader.expect_text()?;
                let mut parser = time::parsing::Parsed::new();
                let rest = parser.parse_items(text.as_bytes(), time::macros::format_description!("[year]-[month]")).map_err(|_| -> XmlError { todo!() })?;
                if !rest.is_empty() {
                    todo!()
                }
                let month = parser.month().unwrap();
                let year = parser.year().unwrap();
                Ok(Self { month, year })
            }
        }
        impl Ord for #ty {
            fn cmp(&self, other: &Self) -> core::cmp::Ordering {
                self.year.cmp(&other.year).then_with(|| u8::from(self.month).cmp(&u8::from(other.month)))
            }
        }
        impl PartialOrd for #ty {
            fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
                Some(self.cmp(other))
            }
        }
    }
}
/// Generates one documented enum variant per element of a choice component.
///
/// The closure must be `move` because the returned iterator outlives this
/// function's stack frame: without `move` it would capture the local `types`
/// binding by reference (E0373), unlike its siblings
/// `generate_choice_variant_parsers`/`generate_field_parsers` which already
/// use `move`.
fn generate_variants<'a>(
    elements: &'a [MessageElement<'a>],
    types: &'a HashMap<&'a str, Type<'a>>,
) -> impl Iterator<Item = TokenStream> + 'a {
    elements.iter().map(move |element| {
        let doc = &element.definition;
        let name = generate_variant_name(&element.name);
        let ty = generate_field_ty(&element.occurs, types[&*element.typ]);
        quote! {
            #[doc = #doc]
            #name(#ty),
        }
    })
}
/// Generates the match arms that dispatch a choice component's parse on the
/// peeked XML start-tag name (the `name` binding in the generated code).
fn generate_choice_variant_parsers<'a>(
    elements: &'a [MessageElement<'a>],
    types: &'a HashMap<&str, Type<'a>>,
) -> impl Iterator<Item = TokenStream> + 'a {
    elements.iter().map(move |element| {
        let tag = &element.xml_tag;
        let variant = generate_variant_name(&element.name);
        let parser = generate_field_parser(tag, &element.occurs, types[&*element.typ]);
        quote! {
            name if name == #tag.as_bytes() => {
                let inner = #parser?;
                Self::#variant(inner)
            }
        }
    })
}
/// Converts an element/code name into a PascalCase enum-variant identifier.
fn generate_variant_name(name: &str) -> Ident {
    let pascal = name.to_pascal_case();
    generate_ident(&pascal)
}
/// Generates one documented `pub` field declaration per element of a message
/// component or definition.
///
/// The closure must be `move` because the returned iterator outlives this
/// function's stack frame: without `move` it would capture the local `types`
/// binding by reference (E0373), unlike the sibling `generate_field_parsers`
/// which already uses `move`.
fn generate_field_declarations<'a>(
    elements: &'a [MessageElement<'a>],
    types: &'a HashMap<&'a str, Type<'a>>,
) -> impl Iterator<Item = TokenStream> + 'a {
    elements.iter().map(move |element| {
        let doc = &element.definition;
        let name = generate_field_name(&element.name);
        let ty = generate_field_ty(&element.occurs, types[&*element.typ]);
        quote! {
            #[doc = #doc]
            pub #name: #ty,
        }
    })
}
/// Generates one `let <field> = <parser>?;` statement per element, in the
/// order the fields are later consumed by the struct literal.
fn generate_field_parsers<'a>(
    elements: &'a [MessageElement<'a>],
    types: &'a HashMap<&str, Type<'a>>,
) -> impl Iterator<Item = TokenStream> + 'a {
    elements.iter().map(move |element| {
        let field = generate_field_name(&element.name);
        let parser = generate_field_parser(&element.xml_tag, &element.occurs, types[&*element.typ]);
        quote!(let #field = #parser?;)
    })
}
/// Converts an element name into a snake_case struct-field identifier.
fn generate_field_name(name: &str) -> Ident {
    let snake = name.to_snake_case();
    generate_ident(&snake)
}
/// Converts an arbitrary string into a valid Rust identifier.
///
/// Fallback chain: try the string as-is; if that fails to parse as an
/// identifier (e.g. it is a Rust keyword), retry as a raw identifier
/// (`r#...`); if even that fails (e.g. the string starts with a digit, or is
/// a keyword such as `self` that cannot be raw), prefix an underscore.
///
/// Panics if none of the three forms parses — assumed unreachable for the
/// inputs this generator produces (TODO confirm against the schema data).
fn generate_ident(ident: &str) -> Ident {
    syn::parse_str::<Ident>(ident)
        .or_else(|_| syn::parse_str(&format!("r#{ident}")))
        .or_else(|_| syn::parse_str(&format!("_{ident}")))
        .unwrap()
}
/// Maps an element's occurrence bounds onto a field type: `Option<T>` for
/// `0..=1`, `T` for exactly `1..=1`, and `Vec<T>` for any other range with an
/// inclusive lower bound.
fn generate_field_ty(occurs: &(Bound<usize>, Bound<usize>), base: Type) -> TokenStream {
    // `Bound` is imported at the top of the file (and `core::ops::Bound` is
    // the same type), so the fully-qualified paths were redundant.
    match occurs {
        (Bound::Included(0), Bound::Included(1)) => quote!(Option<#base>),
        (Bound::Included(1), Bound::Included(1)) => quote!(#base),
        (Bound::Included(_), _) => quote!(Vec<#base>),
        occurs => unreachable!("unsupported occurrence bounds: {occurs:?}"),
    }
}
/// Maps an element's occurrence bounds onto the matching reader call:
/// `parse_optional_element` for `0..=1`, `parse_element` for exactly `1..=1`,
/// and `parse_list` for any other range with an inclusive lower bound.
/// Mirrors the type choices made by `generate_field_ty`.
fn generate_field_parser(
    tag: &str,
    occurs: &(Bound<usize>, Bound<usize>),
    base: Type,
) -> TokenStream {
    // `Bound` is imported at the top of the file (and `core::ops::Bound` is
    // the same type), so the fully-qualified paths were redundant.
    match occurs {
        (Bound::Included(0), Bound::Included(1)) => {
            quote!(reader.parse_optional_element(#tag, <#base>::parse))
        }
        (Bound::Included(1), Bound::Included(1)) => {
            quote!(reader.parse_element(#tag, <#base>::parse))
        }
        (Bound::Included(_), _) => {
            quote!(reader.parse_list(#tag, <#base>::parse))
        }
        occurs => unreachable!("unsupported occurrence bounds: {occurs:?}"),
    }
}