diff --git a/rasql-core/src/rust/mod.rs b/rasql-core/src/rust/mod.rs index cffe61d..19cbee7 100644 --- a/rasql-core/src/rust/mod.rs +++ b/rasql-core/src/rust/mod.rs @@ -1,4 +1,5 @@ pub mod type_gen; +pub mod simple_type_gen; pub mod client_gen; use std::collections::HashMap; diff --git a/rasql-core/src/rust/simple_type_gen.rs b/rasql-core/src/rust/simple_type_gen.rs new file mode 100644 index 0000000..36ece71 --- /dev/null +++ b/rasql-core/src/rust/simple_type_gen.rs @@ -0,0 +1,208 @@ +use std::collections::HashMap; + +use crate::rust::{ + sql_ident_to_field_name, sql_ident_to_type_name, type_gen::UnsupportedDataType, + GeneratedTableStruct, IdPromoteMode, ModuleCodeGenConfig, StructCodeGenConfig, + StructFieldCodeGenConfig, TableStruct, TableStructField, +}; + +fn generate_table_struct_and_impls( + table: &crate::sql::Table, + module_config: Option<&ModuleCodeGenConfig>, +) -> GeneratedTableStruct { + let name = sql_ident_to_type_name(table.name.0.last().unwrap()); + let default_struct_config = StructCodeGenConfig { + field_configs: HashMap::new(), + deny_extra_fields: false, + }; + let struct_config = module_config + .and_then(|config| config.struct_configs.get(&name)) + .unwrap_or(&default_struct_config); + + let fields = table + .columns + .iter() + .map(|column| { + let default_field_config = StructFieldCodeGenConfig { + rename: None, + override_type: None, + attrs: vec![], + id_promote_mode: IdPromoteMode::None, + }; + + let field_config = struct_config + .field_configs + .get(&column.name) + .unwrap_or(&default_field_config); + + let (name, db_alias) = match &field_config.rename { + Some(rename) => (rename.clone(), Some(column.name.value.clone())), + None => { + let name = sql_ident_to_field_name(&column.name); + if name.to_string() == column.name.value { + (name, None) + } else { + (name, Some(column.name.value.clone())) + } + } + }; + + let r#type = match (&field_config.override_type, field_config.id_promote_mode) { + (Some(r#type), _) => r#type.clone(), + (None, IdPromoteMode::None) => { + sql_datatype_to_rust_type(&column.data_type).unwrap() + } + (None, IdPromoteMode::TrustedId) => todo!(), + (None, IdPromoteMode::Id) => todo!(), + }; + + TableStructField { + name, + r#type, + db_alias, + } + }) + .collect(); + + let table_struct = TableStruct { + name, + fields, + db_alias: todo!(), + }; + + GeneratedTableStruct(generate_table_struct(&table_struct)) +} + +fn sql_datatype_to_rust_type( + datatype: &sqlparser::ast::DataType, +) -> Result { + Ok(match datatype { + sqlparser::ast::DataType::Character(..) + | sqlparser::ast::DataType::Char(..) + | sqlparser::ast::DataType::CharacterVarying(..) + | sqlparser::ast::DataType::CharVarying(..) + | sqlparser::ast::DataType::Varchar(..) + | sqlparser::ast::DataType::Nvarchar(..) + | sqlparser::ast::DataType::Text + | sqlparser::ast::DataType::TinyText + | sqlparser::ast::DataType::MediumText + | sqlparser::ast::DataType::LongText + | sqlparser::ast::DataType::String(_) + | sqlparser::ast::DataType::FixedString(_) => syn::Type::Verbatim(quote::quote! {String}), + sqlparser::ast::DataType::Uuid => syn::Type::Verbatim(quote::quote! {uuid::Uuid}), + sqlparser::ast::DataType::Varbinary(_) + | sqlparser::ast::DataType::Blob(_) + | sqlparser::ast::DataType::TinyBlob + | sqlparser::ast::DataType::MediumBlob + | sqlparser::ast::DataType::LongBlob + | sqlparser::ast::DataType::Bytes(_) + | sqlparser::ast::DataType::Bytea + | sqlparser::ast::DataType::Binary(_) => syn::Type::Verbatim(quote::quote! {Vec}), + sqlparser::ast::DataType::Numeric(..) + | sqlparser::ast::DataType::Decimal(..) + | sqlparser::ast::DataType::Dec(..) => { + syn::Type::Verbatim(quote::quote! {rust_decimal::Decimal}) + } + sqlparser::ast::DataType::Int2(_) | sqlparser::ast::DataType::SmallInt(_) => { + syn::Type::Verbatim(quote::quote! {i16}) + } + sqlparser::ast::DataType::UnsignedInt2(_) => syn::Type::Verbatim(quote::quote! {u16}), + sqlparser::ast::DataType::Int16 => todo!(), + sqlparser::ast::DataType::Int128 => todo!(), + sqlparser::ast::DataType::Int256 => todo!(), + sqlparser::ast::DataType::Int(_) + | sqlparser::ast::DataType::Int32 + | sqlparser::ast::DataType::Int4(_) + | sqlparser::ast::DataType::Integer(_) => syn::Type::Verbatim(quote::quote! {i32}), + sqlparser::ast::DataType::UnsignedInt(_) => todo!(), + sqlparser::ast::DataType::UnsignedInt4(_) => todo!(), + sqlparser::ast::DataType::UnsignedInteger(_) => todo!(), + sqlparser::ast::DataType::UInt8 => todo!(), + sqlparser::ast::DataType::UInt16 => todo!(), + sqlparser::ast::DataType::UInt32 => todo!(), + sqlparser::ast::DataType::UInt64 => todo!(), + sqlparser::ast::DataType::UInt128 => todo!(), + sqlparser::ast::DataType::UInt256 => todo!(), + sqlparser::ast::DataType::Int8(_) + | sqlparser::ast::DataType::Int64 + | sqlparser::ast::DataType::BigInt(_) => syn::Type::Verbatim(quote::quote! {i64}), + sqlparser::ast::DataType::UnsignedBigInt(_) => todo!(), + sqlparser::ast::DataType::UnsignedInt8(_) => todo!(), + sqlparser::ast::DataType::Float(_) + | sqlparser::ast::DataType::Float4 + | sqlparser::ast::DataType::Real + | sqlparser::ast::DataType::Float32 => syn::Type::Verbatim(quote::quote! {f32}), + sqlparser::ast::DataType::Float64 + | sqlparser::ast::DataType::Float8 + | sqlparser::ast::DataType::Double(..) + | sqlparser::ast::DataType::DoublePrecision => syn::Type::Verbatim(quote::quote! {f64}), + sqlparser::ast::DataType::Bool => todo!(), + sqlparser::ast::DataType::Boolean => todo!(), + sqlparser::ast::DataType::Date => todo!(), + sqlparser::ast::DataType::Date32 => todo!(), + sqlparser::ast::DataType::Time(_, timezone_info) => todo!(), + sqlparser::ast::DataType::Datetime(_) => todo!(), + sqlparser::ast::DataType::Datetime64(_, _) => todo!(), + sqlparser::ast::DataType::Timestamp(_, timezone_info) => todo!(), + sqlparser::ast::DataType::Interval => todo!(), + sqlparser::ast::DataType::JSON => todo!(), + sqlparser::ast::DataType::JSONB => todo!(), + sqlparser::ast::DataType::Regclass => todo!(), + sqlparser::ast::DataType::Bit(_) => todo!(), + sqlparser::ast::DataType::BitVarying(_) => todo!(), + sqlparser::ast::DataType::Custom(object_name, vec) => todo!(), + sqlparser::ast::DataType::Array(array_elem_type_def) => match array_elem_type_def { + sqlparser::ast::ArrayElemTypeDef::None => { + return Err(UnsupportedDataType(datatype.clone())) + } + sqlparser::ast::ArrayElemTypeDef::AngleBracket(data_type) + | sqlparser::ast::ArrayElemTypeDef::SquareBracket(data_type, _) + | sqlparser::ast::ArrayElemTypeDef::Parenthesis(data_type) => { + let inner_type = sql_datatype_to_rust_type(&datatype)?; + syn::Type::Verbatim(quote::quote! {Vec<#inner_type>}) + } + }, + sqlparser::ast::DataType::Map(data_type, data_type1) => todo!(), + sqlparser::ast::DataType::Tuple(vec) => todo!(), + sqlparser::ast::DataType::Nested(vec) => todo!(), + sqlparser::ast::DataType::Enum(vec, _) => todo!(), + sqlparser::ast::DataType::Set(vec) => todo!(), + sqlparser::ast::DataType::Struct(vec, struct_bracket_kind) => todo!(), + sqlparser::ast::DataType::Union(vec) => todo!(), + sqlparser::ast::DataType::Nullable(data_type) => todo!(), + sqlparser::ast::DataType::LowCardinality(data_type) => todo!(), + sqlparser::ast::DataType::Trigger => todo!(), + _ => return Err(UnsupportedDataType(datatype.clone())), + }) +} + +fn generate_table_struct(table_struct: &TableStruct) -> proc_macro2::TokenStream { + let TableStruct { + name, + fields, + db_alias, + } = table_struct; + let db_alias = db_alias + .as_deref() + .map(|db_alias| quote::quote!(#[postgres(name = #db_alias)])); + + let fields = fields.iter().map( + |TableStructField { + name, + r#type, + db_alias, + }| { + let db_alias = db_alias + .as_deref() + .map(|db_alias| quote::quote!(#[postgres(name = #db_alias)])); + quote::quote!(#db_alias #name : #r#type) + }, + ); + quote::quote!( + #[derive(ToSql, FromSql)] + #db_alias + struct #name { + #(#fields,)* + } + ) +}