diff --git a/rasql-core/Cargo.toml b/rasql-core/Cargo.toml index fe4e437..2b636fd 100644 --- a/rasql-core/Cargo.toml +++ b/rasql-core/Cargo.toml @@ -5,4 +5,14 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[features] +tokio-postgres = ["rasql-traits/tokio-postgres", "dep:tokio-postgres"] + [dependencies] +sqlparser = "0.54.0" +rasql-traits = { version = "0.1.0", path = "../rasql-traits" } +quote = "1.0.35" +proc-macro2 = "1.0.93" +syn = { version = "2.0.96", features = ["full"] } +tokio-postgres = { version = "0.7.12", optional = true } +convert_case = "0.7.1" diff --git a/rasql-core/src/lib.rs b/rasql-core/src/lib.rs index e69de29..b09ce40 100644 --- a/rasql-core/src/lib.rs +++ b/rasql-core/src/lib.rs @@ -0,0 +1,2 @@ +pub mod sql; +pub mod rust; \ No newline at end of file diff --git a/rasql-core/src/rust/client_gen.rs b/rasql-core/src/rust/client_gen.rs new file mode 100644 index 0000000..b9dac1f --- /dev/null +++ b/rasql-core/src/rust/client_gen.rs @@ -0,0 +1,145 @@ + +pub trait AsyncClientCodeGenerator { + /// Create a token stream for usage of `client` to prepare `statement_str` for + /// later execution, evaluating to a value of type + /// `Result` + /// + /// - `client` is an expr of type `&Client` + /// - `statement_str` is an expr of type `&str` + fn generate_prepare_statement( + client: &syn::Expr, + statement_str: &syn::Expr, + ) -> proc_macro2::TokenStream; + + /// Create a token stream for usage of `client` to execute the `prepared_statement` + /// with the provided `parameters`, evaluating to a value of type + /// `Result`. + /// + /// - `client` is an expr of type `&Client` + /// - `prepared_statement` is an expr of type `&Client::PreparedStatement` + /// - exprs in `parameters` are references of types that can be assumed to be compatible with the client + fn generate_query_many_with_statement( + client: &syn::Expr, + prepared_statement: &syn::Expr, + parameters: &[&syn::Expr], + ) -> proc_macro2::TokenStream; + + /// Create a token stream for usage of `client` to execute the `prepared_statement` + /// with the provided `parameters`, evaluating to a value of type + /// `Result, Client::QueryError>`. + /// + /// - `client` is an expr of type `&Client` + /// - `prepared_statement` is an expr of type `&Client::PreparedStatement` + /// - exprs in `parameters` are references of types that can be assumed to be compatible with the client + fn generate_query_one_or_none_with_statement( + client: &syn::Expr, + prepared_statement: &syn::Expr, + parameters: &[&syn::Expr], + ) -> proc_macro2::TokenStream; + + /// Create a token stream for usage of `row` to read a column, evaluating to a value of type + /// `Result` where `T` is any type compatible with the database client. + /// + /// - `row` is an expr of type `&Client::Row` + /// - `column_name` is an expr of type `&str` + fn generate_row_read_column( + row: &syn::Expr, + column_name: &syn::Expr, + ) -> proc_macro2::TokenStream; + + /// Create a token stream for usage of `client` to execute the `prepared_statement` + /// with the provided `parameters`, evaluating to a value of type + /// `Result`. + /// + /// - `client` is an expr of type `&Client` + /// - `prepared_statement` is an expr of type `&Client::PreparedStatement` + /// - exprs in `parameters` are references of types that can be assumed to be compatible with the client + fn generate_insert_with_statement( + client: &syn::Expr, + prepared_statement: &syn::Expr, + parameters: &[&syn::Expr], + ) -> proc_macro2::TokenStream; + + /// Create a token stream for usage of `client` to execute the `prepared_statement` + /// with the provided `parameters`, evaluating to a value of type + /// `Result`. + /// + /// - `client` is an expr of type `&Client` + /// - `prepared_statement` is an expr of type `&Client::PreparedStatement` + /// - exprs in `parameters` are references of types that can be assumed to be compatible with the client + fn generate_update_with_statement( + client: &syn::Expr, + prepared_statement: &syn::Expr, + parameters: &[&syn::Expr], + ) -> proc_macro2::TokenStream; + + /// Create a token stream for usage of `client` to execute the `prepared_statement` + /// with the provided `parameters`, evaluating to a value of type + /// `Result`. + /// + /// - `client` is an expr of type `&Client` + /// - `prepared_statement` is an expr of type `&Client::PreparedStatement` + /// - exprs in `parameters` are references of types that can be assumed to be compatible with the client + fn generate_delete_with_statement( + client: &syn::Expr, + prepared_statement: &syn::Expr, + parameters: &[&syn::Expr], + ) -> proc_macro2::TokenStream; +} + +#[cfg(feature = "tokio-postgres")] +impl AsyncClientCodeGenerator for super::type_gen::TokioPostgresGenerator { + fn generate_prepare_statement( + client: &syn::Expr, + statement_str: &syn::Expr, + ) -> proc_macro2::TokenStream { + quote::quote!(#client.prepare(#statement_str).await) + } + + fn generate_query_many_with_statement( + client: &syn::Expr, + prepared_statement: &syn::Expr, + parameters: &[&syn::Expr], + ) -> proc_macro2::TokenStream { + quote::quote!(#client.query(#prepared_statement, &[#(#parameters,)*]).await) + } + + fn generate_query_one_or_none_with_statement( + client: &syn::Expr, + prepared_statement: &syn::Expr, + parameters: &[&syn::Expr], + ) -> proc_macro2::TokenStream { + quote::quote!(#client.query_opt(#prepared_statement, &[#(#parameters,)*]).await) + } + + fn generate_row_read_column( + row: &syn::Expr, + column_name: &syn::Expr, + ) -> proc_macro2::TokenStream { + quote::quote!(#row.try_get(#column_name)) + } + + fn generate_insert_with_statement( + client: &syn::Expr, + prepared_statement: &syn::Expr, + parameters: &[&syn::Expr], + ) -> proc_macro2::TokenStream { + Self::generate_execute(client, prepared_statement, parameters) + } + + fn generate_update_with_statement( + client: &syn::Expr, + prepared_statement: &syn::Expr, + parameters: &[&syn::Expr], + ) -> proc_macro2::TokenStream { + Self::generate_execute(client, prepared_statement, parameters) + } + + fn generate_delete_with_statement( + client: &syn::Expr, + prepared_statement: &syn::Expr, + parameters: &[&syn::Expr], + ) -> proc_macro2::TokenStream { + Self::generate_execute(client, prepared_statement, parameters) + } +} diff --git a/rasql-core/src/rust/mod.rs b/rasql-core/src/rust/mod.rs new file mode 100644 index 0000000..204c446 --- /dev/null +++ b/rasql-core/src/rust/mod.rs @@ -0,0 +1,257 @@ +pub mod type_gen; +pub mod client_gen; + +use std::collections::HashMap; + +use client_gen::AsyncClientCodeGenerator; +use convert_case::Casing; +use type_gen::TypeGenerator; + +pub struct TableStruct { + pub name: syn::Ident, + pub fields: Vec, + pub db_alias: Option, +} + +pub struct TableStructField { + pub name: syn::Ident, + pub r#type: syn::Type, + pub db_alias: Option, +} + +pub struct GeneratedTableStruct(pub proc_macro2::TokenStream); + +pub struct TableStructImpls { + pub base_table_impl: proc_macro2::TokenStream, + pub table_with_pk_impl: Option, +} + +fn sql_ident_to_type_name(ident: &sqlparser::ast::Ident) -> syn::Ident { + let mut ident = ident.value.to_case(convert_case::Case::Pascal); + if ident.chars().next().unwrap().is_ascii_digit() { + ident.insert(0, '_'); + } + syn::Ident::new(&ident, proc_macro2::Span::call_site()) +} + +fn sql_ident_to_field_name(ident: &sqlparser::ast::Ident) -> syn::Ident { + let mut ident = ident.value.to_case(convert_case::Case::Snake); + if ident.chars().next().unwrap().is_ascii_digit() { + ident.insert(0, '_'); + } + syn::Ident::new(&ident, proc_macro2::Span::call_site()) +} + +#[inline] +fn sql_ident_to_module_name(ident: &sqlparser::ast::Ident) -> syn::Ident { + sql_ident_to_field_name(ident) +} + +fn sql_datatype_to_rust_type(datatype: &sqlparser::ast::DataType) -> syn::Type { + match datatype { + sqlparser::ast::DataType::Character(..) + | sqlparser::ast::DataType::Char(..) + | sqlparser::ast::DataType::CharacterVarying(..) + | sqlparser::ast::DataType::CharVarying(..) + | sqlparser::ast::DataType::Varchar(..) + | sqlparser::ast::DataType::Nvarchar(..) + | sqlparser::ast::DataType::Text + | sqlparser::ast::DataType::TinyText + | sqlparser::ast::DataType::MediumText + | sqlparser::ast::DataType::LongText + | sqlparser::ast::DataType::String(_) + | sqlparser::ast::DataType::FixedString(_) => syn::Type::Verbatim(quote::quote! {String}), + sqlparser::ast::DataType::Uuid => todo!(), + sqlparser::ast::DataType::CharacterLargeObject(_) => todo!(), + sqlparser::ast::DataType::CharLargeObject(_) => todo!(), + sqlparser::ast::DataType::Clob(_) => todo!(), + sqlparser::ast::DataType::Binary(_) => todo!(), + sqlparser::ast::DataType::Varbinary(_) => todo!(), + sqlparser::ast::DataType::Blob(_) => todo!(), + sqlparser::ast::DataType::TinyBlob => todo!(), + sqlparser::ast::DataType::MediumBlob => todo!(), + sqlparser::ast::DataType::LongBlob => todo!(), + sqlparser::ast::DataType::Bytes(_) => todo!(), + sqlparser::ast::DataType::Numeric(exact_number_info) => todo!(), + sqlparser::ast::DataType::Decimal(exact_number_info) => todo!(), + sqlparser::ast::DataType::BigNumeric(exact_number_info) => todo!(), + sqlparser::ast::DataType::BigDecimal(exact_number_info) => todo!(), + sqlparser::ast::DataType::Dec(exact_number_info) => todo!(), + sqlparser::ast::DataType::Float(_) => todo!(), + sqlparser::ast::DataType::TinyInt(_) => todo!(), + sqlparser::ast::DataType::UnsignedTinyInt(_) => todo!(), + sqlparser::ast::DataType::Int2(_) => todo!(), + sqlparser::ast::DataType::UnsignedInt2(_) => todo!(), + sqlparser::ast::DataType::SmallInt(_) => todo!(), + sqlparser::ast::DataType::UnsignedSmallInt(_) => todo!(), + sqlparser::ast::DataType::MediumInt(_) => todo!(), + sqlparser::ast::DataType::UnsignedMediumInt(_) => todo!(), + sqlparser::ast::DataType::Int(_) => todo!(), + sqlparser::ast::DataType::Int16 => todo!(), + sqlparser::ast::DataType::Int128 => todo!(), + sqlparser::ast::DataType::Int256 => todo!(), + sqlparser::ast::DataType::Int32 + | sqlparser::ast::DataType::Int4(_) + | sqlparser::ast::DataType::Integer(_) => syn::Type::Verbatim(quote::quote! {i32}), + sqlparser::ast::DataType::UnsignedInt(_) => todo!(), + sqlparser::ast::DataType::UnsignedInt4(_) => todo!(), + sqlparser::ast::DataType::UnsignedInteger(_) => todo!(), + sqlparser::ast::DataType::UInt8 => todo!(), + sqlparser::ast::DataType::UInt16 => todo!(), + sqlparser::ast::DataType::UInt32 => todo!(), + sqlparser::ast::DataType::UInt64 => todo!(), + sqlparser::ast::DataType::UInt128 => todo!(), + sqlparser::ast::DataType::UInt256 => todo!(), + sqlparser::ast::DataType::Int8(_) + | sqlparser::ast::DataType::Int64 + | sqlparser::ast::DataType::BigInt(_) => syn::Type::Verbatim(quote::quote! {i64}), + sqlparser::ast::DataType::UnsignedBigInt(_) => todo!(), + sqlparser::ast::DataType::UnsignedInt8(_) => todo!(), + sqlparser::ast::DataType::Float4 + | sqlparser::ast::DataType::Real + | sqlparser::ast::DataType::Float32 => syn::Type::Verbatim(quote::quote! {f32}), + sqlparser::ast::DataType::Float64 + | sqlparser::ast::DataType::Float8 + | sqlparser::ast::DataType::Double(..) + | sqlparser::ast::DataType::DoublePrecision => syn::Type::Verbatim(quote::quote! {f64}), + sqlparser::ast::DataType::Bool => todo!(), + sqlparser::ast::DataType::Boolean => todo!(), + sqlparser::ast::DataType::Date => todo!(), + sqlparser::ast::DataType::Date32 => todo!(), + sqlparser::ast::DataType::Time(_, timezone_info) => todo!(), + sqlparser::ast::DataType::Datetime(_) => todo!(), + sqlparser::ast::DataType::Datetime64(_, _) => todo!(), + sqlparser::ast::DataType::Timestamp(_, timezone_info) => todo!(), + sqlparser::ast::DataType::Interval => todo!(), + sqlparser::ast::DataType::JSON => todo!(), + sqlparser::ast::DataType::JSONB => todo!(), + sqlparser::ast::DataType::Regclass => todo!(), + sqlparser::ast::DataType::Bytea => todo!(), + sqlparser::ast::DataType::Bit(_) => todo!(), + sqlparser::ast::DataType::BitVarying(_) => todo!(), + sqlparser::ast::DataType::Custom(object_name, vec) => todo!(), + sqlparser::ast::DataType::Array(array_elem_type_def) => match array_elem_type_def { + sqlparser::ast::ArrayElemTypeDef::None => unimplemented!(), + sqlparser::ast::ArrayElemTypeDef::AngleBracket(data_type) + | sqlparser::ast::ArrayElemTypeDef::SquareBracket(data_type, _) + | sqlparser::ast::ArrayElemTypeDef::Parenthesis(data_type) => todo!(), + }, + sqlparser::ast::DataType::Map(data_type, data_type1) => todo!(), + sqlparser::ast::DataType::Tuple(vec) => todo!(), + sqlparser::ast::DataType::Nested(vec) => todo!(), + sqlparser::ast::DataType::Enum(vec, _) => todo!(), + sqlparser::ast::DataType::Set(vec) => todo!(), + sqlparser::ast::DataType::Struct(vec, struct_bracket_kind) => todo!(), + sqlparser::ast::DataType::Union(vec) => todo!(), + sqlparser::ast::DataType::Nullable(data_type) => todo!(), + sqlparser::ast::DataType::LowCardinality(data_type) => todo!(), + sqlparser::ast::DataType::Unspecified => todo!(), + sqlparser::ast::DataType::Trigger => todo!(), + sqlparser::ast::DataType::AnyType => todo!(), + } +} + +fn generate_table_struct_and_impls< + Traits: rasql_traits::DbTraits, + TypeGen: TypeGenerator, + Client: rasql_traits::r#async::Client, + ClientGen: AsyncClientCodeGenerator, +>( + table: &crate::sql::Table, + module_config: Option<&ModuleCodeGenConfig>, +) -> (GeneratedTableStruct, TableStructImpls) { + let name = sql_ident_to_type_name(table.name.0.last().unwrap()); + let default_struct_config = StructCodeGenConfig { + field_configs: HashMap::new(), + deny_extra_fields: false, + }; + let struct_config = module_config + .and_then(|config| config.struct_configs.get(&name)) + .unwrap_or(&default_struct_config); + + let fields = table + .columns + .iter() + .map(|column| { + let default_field_config = StructFieldCodeGenConfig { + rename: None, + override_type: None, + attrs: vec![], + id_promote_mode: IdPromoteMode::None, + }; + + let field_config = struct_config + .field_configs + .get(&column.name) + .unwrap_or(&default_field_config); + + let (name, db_alias) = match &field_config.rename { + Some(rename) => (rename.clone(), Some(column.name.value.clone())), + None => { + let name = sql_ident_to_field_name(&column.name); + if name.to_string() == column.name.value { + (name, None) + } else { + (name, Some(column.name.value.clone())) + } + } + }; + + let r#type = match (&field_config.override_type, field_config.id_promote_mode) { + (Some(r#type), _) => r#type.clone(), + (None, IdPromoteMode::None) => column.data_type, + (None, IdPromoteMode::TrustedId) => todo!(), + (None, IdPromoteMode::Id) => todo!(), + }; + + TableStructField { + name, + r#type, + db_alias, + } + }) + .collect(); + + let table_struct = TableStruct { + name, + fields, + db_alias: todo!(), + }; + ( + GeneratedTableStruct(TypeGen::generate_table_struct(&table_struct)), + TableStructImpls { + base_table_impl: todo!(), + table_with_pk_impl: todo!(), + }, + ) +} + +pub struct CodeGenConfig { + pub module_configs: HashMap, +} + +pub struct ModuleCodeGenConfig { + pub use_statements: Vec, + pub struct_configs: HashMap, +} + +pub struct StructCodeGenConfig { + pub field_configs: HashMap, + pub deny_extra_fields: bool, +} + +#[derive(Default)] +pub struct StructFieldCodeGenConfig { + rename: Option, + override_type: Option, + attrs: Vec, + id_promote_mode: IdPromoteMode, +} + +#[derive(Clone, Copy, Default)] +pub enum IdPromoteMode { + #[default] + None, + TrustedId, + Id, +} diff --git a/rasql-core/src/rust/type_gen.rs b/rasql-core/src/rust/type_gen.rs new file mode 100644 index 0000000..5c24509 --- /dev/null +++ b/rasql-core/src/rust/type_gen.rs @@ -0,0 +1,63 @@ +use crate::rust::TableStructField; + +use super::TableStruct; + +pub trait TypeGenerator { + fn sql_datatype_to_rust_type(datatype: &sqlparser::ast::DataType) -> syn::Type; + + fn generate_table_struct(table_struct: &TableStruct) -> proc_macro2::TokenStream; +} + + + +#[cfg(feature = "tokio-postgres")] +pub struct TokioPostgresGenerator; + +#[cfg(feature = "tokio-postgres")] +impl TokioPostgresGenerator { + pub(super) fn generate_execute( + client: &syn::Expr, + prepared_statement: &syn::Expr, + parameters: &[&syn::Expr], + ) -> proc_macro2::TokenStream { + quote::quote!(#client.execute(#prepared_statement, &[#(#parameters,)*]).await) + } +} + +#[cfg(feature = "tokio-postgres")] +impl TypeGenerator for TokioPostgresGenerator { + fn sql_datatype_to_rust_type(datatype: &sqlparser::ast::DataType) -> syn::Type { + todo!() + } + + fn generate_table_struct(table_struct: &TableStruct) -> proc_macro2::TokenStream { + let TableStruct { + name, + fields, + db_alias, + } = table_struct; + let db_alias = db_alias + .as_deref() + .map(|db_alias| quote::quote!(#[postgres(name = #db_alias)])); + + let fields = fields.iter().map( + |TableStructField { + name, + r#type, + db_alias, + }| { + let db_alias = db_alias + .as_deref() + .map(|db_alias| quote::quote!(#[postgres(name = #db_alias)])); + quote::quote!(#db_alias #name : #r#type) + }, + ); + quote::quote!( + #[derive(ToSql, FromSql)] + #db_alias + struct #name { + #(#fields,)* + } + ) + } +} diff --git a/rasql-core/src/sql.rs b/rasql-core/src/sql.rs new file mode 100644 index 0000000..cb925eb --- /dev/null +++ b/rasql-core/src/sql.rs @@ -0,0 +1,169 @@ +use std::collections::HashMap; + +use sqlparser::ast::{ColumnDef, DataType, Ident, ObjectName, SchemaName, TableConstraint}; + +pub fn parse_sql_schema( + sql_statements: impl IntoIterator< + Item = impl TryInto, + >, +) { + let mut schemas = HashMap::new(); + for statement in sql_statements { + let statement: sqlparser::ast::Statement = statement.try_into().unwrap(); + match statement { + sqlparser::ast::Statement::CreateSchema { schema_name, .. } => { + schemas + .entry(schema_name.clone()) + .or_insert_with(|| Schema { + name: schema_name, + tables: Default::default(), + types: Default::default(), + }); + } + sqlparser::ast::Statement::CreateTable(sqlparser::ast::CreateTable { + name, + columns, + constraints, + .. + }) => { + let schema = schema_for_object(&mut schemas, &name); + schema.types.insert( + name.clone(), + Type::Composite { + name: name.clone(), + fields: columns + .iter() + .map(|column| Field { + name: column.name.clone(), + r#type: column.data_type.clone(), + }) + .collect(), + }, + ); + schema.tables.insert( + name.clone(), + Table { + name, + columns, + constraints, + }, + ); + } + sqlparser::ast::Statement::AlterTable { + name, operations, .. + } => { + let schema = schema_for_object(&mut schemas, &name); + let Some(table) = schema.tables.get_mut(&name) else { + continue; + }; + for op in operations { + match op { + sqlparser::ast::AlterTableOperation::AddConstraint(table_constraint) => { + table.constraints.push(table_constraint); + } + _ => (), + } + } + } + sqlparser::ast::Statement::CreateType { + name, + representation: + sqlparser::ast::UserDefinedTypeRepresentation::Composite { attributes }, + } => { + let schema = schema_for_object(&mut schemas, &name); + schema.types.insert( + name.clone(), + Type::Composite { + name, + fields: attributes + .into_iter() + .map(|attr| Field { + name: attr.name, + r#type: attr.data_type, + }) + .collect(), + }, + ); + } + sqlparser::ast::Statement::CreateType { + name, + representation: sqlparser::ast::UserDefinedTypeRepresentation::Enum { labels }, + } => { + let schema = schema_for_object(&mut schemas, &name); + schema.types.insert( + name.clone(), + Type::Enum { + name, + variants: labels, + }, + ); + } + _ => (), + } + } +} + +fn schema_for_object<'a>( + schemas: &'a mut HashMap, + object_name: &ObjectName, +) -> &'a mut Schema { + let schema = match object_name.0.as_slice() { + [_table_name] => schemas + .entry(SchemaName::Simple(ObjectName(vec![Ident::new("public")]))) + .or_insert_with(|| Schema { + name: SchemaName::Simple(ObjectName(vec![Ident::new("public")])), + tables: Default::default(), + types: Default::default(), + }), + [schema_name, _table_name] => schemas + .entry(SchemaName::Simple(ObjectName(vec![schema_name.clone()]))) + .or_insert_with(|| Schema { + name: SchemaName::Simple(ObjectName(vec![schema_name.clone()])), + tables: Default::default(), + types: Default::default(), + }), + [catalog_name, schema_name, _table_name] => schemas + .entry(SchemaName::Simple(ObjectName(vec![ + catalog_name.clone(), + schema_name.clone(), + ]))) + .or_insert_with(|| Schema { + name: SchemaName::Simple(ObjectName(vec![ + catalog_name.clone(), + schema_name.clone(), + ])), + tables: Default::default(), + types: Default::default(), + }), + _ => unreachable!(), + }; + schema +} + +pub struct Schema { + pub name: SchemaName, + pub tables: HashMap, + pub types: HashMap, +} + +pub struct Table { + pub name: ObjectName, + pub columns: Vec, + pub constraints: Vec, +} + +pub enum Type { + Composite { + name: ObjectName, + fields: Vec, + }, + Enum { + name: ObjectName, + variants: Vec, + }, +} + +pub struct Field { + pub name: Ident, + pub r#type: DataType, +}