diff --git a/Cargo.lock b/Cargo.lock index a0fc90e..31871d9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -114,9 +114,9 @@ dependencies = [ [[package]] name = "cpufeatures" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" dependencies = [ "libc", ] @@ -303,7 +303,7 @@ checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" dependencies = [ "libc", "wasi", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -477,11 +477,11 @@ dependencies = [ [[package]] name = "rasql-build" -version = "0.1.0" +version = "0.0.0" [[package]] name = "rasql-core" -version = "0.1.0" +version = "0.0.0" dependencies = [ "convert_case", "proc-macro2", @@ -489,16 +489,24 @@ dependencies = [ "rasql-traits", "sqlparser", "syn", + "thiserror", "tokio-postgres", ] [[package]] -name = "rasql-query" +name = "rasql-model" version = "0.1.0" +dependencies = [ + "sqlparser", +] + +[[package]] +name = "rasql-query" +version = "0.0.0" [[package]] name = "rasql-traits" -version = "0.1.0" +version = "0.0.0" dependencies = [ "tokio-postgres", ] @@ -589,7 +597,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -612,7 +620,7 @@ dependencies = [ "cfg-if", "libc", "psm", - "windows-sys", + "windows-sys 0.59.0", ] [[package]] @@ -643,6 +651,26 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "thiserror" +version = "2.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tinyvec" version = "1.8.1" @@ -670,7 +698,7 @@ dependencies = [ "mio", "pin-project-lite", "socket2", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -856,6 +884,15 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets", +] + [[package]] name = "windows-targets" version = "0.52.6" diff --git a/Cargo.toml b/Cargo.toml index 128eedc..be3c469 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ resolver = "2" members = [ "rasql-build", "rasql-core", + "rasql-model", "rasql-query", "rasql-traits", ] diff --git a/README.md b/README.md index eea5b66..f088bb1 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,11 @@ would like to hold type definitions, you can use this from your crate's build sc automatically typed row output to prevent runtime type errors. This depends on you already using `rasql-build`. +## `rasql-traits` + +`rasql-traits` provides trait definitions for various database operations and types. The types that +`rasql-core` generates implement these traits depending on the generation config. + ## Acknowledgements Rasql builds upon the work of the [sqlparser-rs](https://github.com/sqlparser-rs/sqlparser-rs) diff --git a/rasql-build/Cargo.toml b/rasql-build/Cargo.toml index 9227858..74aa530 100644 --- a/rasql-build/Cargo.toml +++ b/rasql-build/Cargo.toml @@ -1,7 +1,9 @@ [package] name = "rasql-build" -version = "0.1.0" +version = "0.0.0" edition = "2021" +license = "MIT" +description = "Build script utilities for Rasql" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/rasql-core/Cargo.toml b/rasql-core/Cargo.toml index fe4e437..d87c7f2 100644 --- a/rasql-core/Cargo.toml +++ b/rasql-core/Cargo.toml @@ -1,8 +1,21 @@ [package] name = "rasql-core" -version = "0.1.0" +version = "0.0.0" edition = "2021" +license = "MIT" +description = "SQL analysis and Rust type generation for Rasql" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[features] +tokio-postgres = ["rasql-traits/tokio-postgres", "dep:tokio-postgres"] + [dependencies] +sqlparser = "0.54.0" +rasql-traits = { version = "0.0.0", path = "../rasql-traits" } +quote = "1.0.35" +proc-macro2 = "1.0.93" +syn = { version = "2.0.96", features = ["full"] } +tokio-postgres = { version = "0.7.12", optional = true } +convert_case = "0.7.1" +thiserror = "2.0.11" diff --git a/rasql-core/src/lib.rs b/rasql-core/src/lib.rs index e69de29..b09ce40 100644 --- a/rasql-core/src/lib.rs +++ b/rasql-core/src/lib.rs @@ -0,0 +1,2 @@ +pub mod sql; +pub mod rust; \ No newline at end of file diff --git a/rasql-core/src/rust/client_gen.rs b/rasql-core/src/rust/client_gen.rs new file mode 100644 index 0000000..b9dac1f --- /dev/null +++ b/rasql-core/src/rust/client_gen.rs @@ -0,0 +1,145 @@ + +pub trait AsyncClientCodeGenerator { + /// Create a token stream for usage of `client` to prepare `statement_str` for + /// later execution, evaluating to a value of type + /// `Result` + /// + /// - `client` is an expr of type `&Client` + /// - `statement_str` is an expr of type `&str` + fn generate_prepare_statement( + client: &syn::Expr, + statement_str: &syn::Expr, + ) -> proc_macro2::TokenStream; + + /// Create a token stream for usage of `client` to execute the `prepared_statement` + /// with the provided `parameters`, evaluating to a value of type + /// `Result`. + /// + /// - `client` is an expr of type `&Client` + /// - `prepared_statement` is an expr of type `&Client::PreparedStatement` + /// - exprs in `parameters` are references of types that can be assumed to be compatible with the client + fn generate_query_many_with_statement( + client: &syn::Expr, + prepared_statement: &syn::Expr, + parameters: &[&syn::Expr], + ) -> proc_macro2::TokenStream; + + /// Create a token stream for usage of `client` to execute the `prepared_statement` + /// with the provided `parameters`, evaluating to a value of type + /// `Result, Client::QueryError>`. + /// + /// - `client` is an expr of type `&Client` + /// - `prepared_statement` is an expr of type `&Client::PreparedStatement` + /// - exprs in `parameters` are references of types that can be assumed to be compatible with the client + fn generate_query_one_or_none_with_statement( + client: &syn::Expr, + prepared_statement: &syn::Expr, + parameters: &[&syn::Expr], + ) -> proc_macro2::TokenStream; + + /// Create a token stream for usage of `row` to read a column, evaluating to a value of type + /// `Result` where `T` is any type compatible with the database client. + /// + /// - `row` is an expr of type `&Client::Row` + /// - `column_name` is an expr of type `&str` + fn generate_row_read_column( + row: &syn::Expr, + column_name: &syn::Expr, + ) -> proc_macro2::TokenStream; + + /// Create a token stream for usage of `client` to execute the `prepared_statement` + /// with the provided `parameters`, evaluating to a value of type + /// `Result`. + /// + /// - `client` is an expr of type `&Client` + /// - `prepared_statement` is an expr of type `&Client::PreparedStatement` + /// - exprs in `parameters` are references of types that can be assumed to be compatible with the client + fn generate_insert_with_statement( + client: &syn::Expr, + prepared_statement: &syn::Expr, + parameters: &[&syn::Expr], + ) -> proc_macro2::TokenStream; + + /// Create a token stream for usage of `client` to execute the `prepared_statement` + /// with the provided `parameters`, evaluating to a value of type + /// `Result`. + /// + /// - `client` is an expr of type `&Client` + /// - `prepared_statement` is an expr of type `&Client::PreparedStatement` + /// - exprs in `parameters` are references of types that can be assumed to be compatible with the client + fn generate_update_with_statement( + client: &syn::Expr, + prepared_statement: &syn::Expr, + parameters: &[&syn::Expr], + ) -> proc_macro2::TokenStream; + + /// Create a token stream for usage of `client` to execute the `prepared_statement` + /// with the provided `parameters`, evaluating to a value of type + /// `Result`. + /// + /// - `client` is an expr of type `&Client` + /// - `prepared_statement` is an expr of type `&Client::PreparedStatement` + /// - exprs in `parameters` are references of types that can be assumed to be compatible with the client + fn generate_delete_with_statement( + client: &syn::Expr, + prepared_statement: &syn::Expr, + parameters: &[&syn::Expr], + ) -> proc_macro2::TokenStream; +} + +#[cfg(feature = "tokio-postgres")] +impl AsyncClientCodeGenerator for super::type_gen::TokioPostgresGenerator { + fn generate_prepare_statement( + client: &syn::Expr, + statement_str: &syn::Expr, + ) -> proc_macro2::TokenStream { + quote::quote!(#client.prepare(#statement_str).await) + } + + fn generate_query_many_with_statement( + client: &syn::Expr, + prepared_statement: &syn::Expr, + parameters: &[&syn::Expr], + ) -> proc_macro2::TokenStream { + quote::quote!(#client.query(#prepared_statement, &[#(#parameters,)*]).await) + } + + fn generate_query_one_or_none_with_statement( + client: &syn::Expr, + prepared_statement: &syn::Expr, + parameters: &[&syn::Expr], + ) -> proc_macro2::TokenStream { + quote::quote!(#client.query_opt(#prepared_statement, &[#(#parameters,)*]).await) + } + + fn generate_row_read_column( + row: &syn::Expr, + column_name: &syn::Expr, + ) -> proc_macro2::TokenStream { + quote::quote!(#row.try_get(#column_name)) + } + + fn generate_insert_with_statement( + client: &syn::Expr, + prepared_statement: &syn::Expr, + parameters: &[&syn::Expr], + ) -> proc_macro2::TokenStream { + Self::generate_execute(client, prepared_statement, parameters) + } + + fn generate_update_with_statement( + client: &syn::Expr, + prepared_statement: &syn::Expr, + parameters: &[&syn::Expr], + ) -> proc_macro2::TokenStream { + Self::generate_execute(client, prepared_statement, parameters) + } + + fn generate_delete_with_statement( + client: &syn::Expr, + prepared_statement: &syn::Expr, + parameters: &[&syn::Expr], + ) -> proc_macro2::TokenStream { + Self::generate_execute(client, prepared_statement, parameters) + } +} diff --git a/rasql-core/src/rust/mod.rs b/rasql-core/src/rust/mod.rs new file mode 100644 index 0000000..19cbee7 --- /dev/null +++ b/rasql-core/src/rust/mod.rs @@ -0,0 +1,156 @@ +pub mod type_gen; +pub mod simple_type_gen; +pub mod client_gen; + +use std::collections::HashMap; + +use client_gen::AsyncClientCodeGenerator; +use convert_case::Casing; +use type_gen::TypeGenerator; + +pub struct TableStruct { + pub name: syn::Ident, + pub fields: Vec, + pub db_alias: Option, +} + +pub struct TableStructField { + pub name: syn::Ident, + pub r#type: syn::Type, + pub db_alias: Option, +} + +pub struct GeneratedTableStruct(pub proc_macro2::TokenStream); + +pub struct TableStructImpls { + pub base_table_impl: proc_macro2::TokenStream, + pub table_with_pk_impl: Option, +} + +fn sql_ident_to_type_name(ident: &sqlparser::ast::Ident) -> syn::Ident { + let mut ident = ident.value.to_case(convert_case::Case::Pascal); + if ident.chars().next().unwrap().is_ascii_digit() { + ident.insert(0, '_'); + } + syn::Ident::new(&ident, proc_macro2::Span::call_site()) +} + +fn sql_ident_to_field_name(ident: &sqlparser::ast::Ident) -> syn::Ident { + let mut ident = ident.value.to_case(convert_case::Case::Snake); + if ident.chars().next().unwrap().is_ascii_digit() { + ident.insert(0, '_'); + } + syn::Ident::new(&ident, proc_macro2::Span::call_site()) +} + +#[inline] +fn sql_ident_to_module_name(ident: &sqlparser::ast::Ident) -> syn::Ident { + sql_ident_to_field_name(ident) +} + +fn generate_table_struct_and_impls< + Traits: rasql_traits::DbTraits, + TypeGen: TypeGenerator, + Client: rasql_traits::r#async::Client, + ClientGen: AsyncClientCodeGenerator, +>( + table: &crate::sql::Table, + module_config: Option<&ModuleCodeGenConfig>, + type_gen: &TypeGen, + client_gen: &ClientGen, +) -> (GeneratedTableStruct, TableStructImpls) { + let name = sql_ident_to_type_name(table.name.0.last().unwrap()); + let default_struct_config = StructCodeGenConfig { + field_configs: HashMap::new(), + deny_extra_fields: false, + }; + let struct_config = module_config + .and_then(|config| config.struct_configs.get(&name)) + .unwrap_or(&default_struct_config); + + let fields = table + .columns + .iter() + .map(|column| { + let default_field_config = StructFieldCodeGenConfig { + rename: None, + override_type: None, + attrs: vec![], + id_promote_mode: IdPromoteMode::None, + }; + + let field_config = struct_config + .field_configs + .get(&column.name) + .unwrap_or(&default_field_config); + + let (name, db_alias) = match &field_config.rename { + Some(rename) => (rename.clone(), Some(column.name.value.clone())), + None => { + let name = sql_ident_to_field_name(&column.name); + if name.to_string() == column.name.value { + (name, None) + } else { + (name, Some(column.name.value.clone())) + } + } + }; + + let r#type = match (&field_config.override_type, field_config.id_promote_mode) { + (Some(r#type), _) => r#type.clone(), + (None, IdPromoteMode::None) => type_gen.sql_datatype_to_rust_type(&column.data_type).unwrap(), + (None, IdPromoteMode::TrustedId) => todo!(), + (None, IdPromoteMode::Id) => todo!(), + }; + + TableStructField { + name, + r#type, + db_alias, + } + }) + .collect(); + + let table_struct = TableStruct { + name, + fields, + db_alias: todo!(), + }; + ( + GeneratedTableStruct(type_gen.generate_table_struct(&table_struct)), + TableStructImpls { + base_table_impl: todo!(), + table_with_pk_impl: todo!(), + }, + ) +} + +pub struct CodeGenConfig { + pub module_configs: HashMap, +} + +pub struct ModuleCodeGenConfig { + pub use_statements: Vec, + pub struct_configs: HashMap, +} + +pub struct StructCodeGenConfig { + pub field_configs: HashMap, + pub deny_extra_fields: bool, +} + +#[derive(Default)] +pub struct StructFieldCodeGenConfig { + rename: Option, + override_type: Option, + attrs: Vec, + id_promote_mode: IdPromoteMode, +} + +#[derive(Clone, Copy, Default)] +pub enum IdPromoteMode { + #[default] + None, + TrustedId, + Id, +} diff --git a/rasql-core/src/rust/simple_type_gen.rs b/rasql-core/src/rust/simple_type_gen.rs new file mode 100644 index 0000000..36ece71 --- /dev/null +++ b/rasql-core/src/rust/simple_type_gen.rs @@ -0,0 +1,208 @@ +use std::collections::HashMap; + +use crate::rust::{ + sql_ident_to_field_name, sql_ident_to_type_name, type_gen::UnsupportedDataType, + GeneratedTableStruct, IdPromoteMode, ModuleCodeGenConfig, StructCodeGenConfig, + StructFieldCodeGenConfig, TableStruct, TableStructField, +}; + +fn generate_table_struct_and_impls( + table: &crate::sql::Table, + module_config: Option<&ModuleCodeGenConfig>, +) -> GeneratedTableStruct { + let name = sql_ident_to_type_name(table.name.0.last().unwrap()); + let default_struct_config = StructCodeGenConfig { + field_configs: HashMap::new(), + deny_extra_fields: false, + }; + let struct_config = module_config + .and_then(|config| config.struct_configs.get(&name)) + .unwrap_or(&default_struct_config); + + let fields = table + .columns + .iter() + .map(|column| { + let default_field_config = StructFieldCodeGenConfig { + rename: None, + override_type: None, + attrs: vec![], + id_promote_mode: IdPromoteMode::None, + }; + + let field_config = struct_config + .field_configs + .get(&column.name) + .unwrap_or(&default_field_config); + + let (name, db_alias) = match &field_config.rename { + Some(rename) => (rename.clone(), Some(column.name.value.clone())), + None => { + let name = sql_ident_to_field_name(&column.name); + if name.to_string() == column.name.value { + (name, None) + } else { + (name, Some(column.name.value.clone())) + } + } + }; + + let r#type = match (&field_config.override_type, field_config.id_promote_mode) { + (Some(r#type), _) => r#type.clone(), + (None, IdPromoteMode::None) => { + sql_datatype_to_rust_type(&column.data_type).unwrap() + } + (None, IdPromoteMode::TrustedId) => todo!(), + (None, IdPromoteMode::Id) => todo!(), + }; + + TableStructField { + name, + r#type, + db_alias, + } + }) + .collect(); + + let table_struct = TableStruct { + name, + fields, + db_alias: todo!(), + }; + + GeneratedTableStruct(generate_table_struct(&table_struct)) +} + +fn sql_datatype_to_rust_type( + datatype: &sqlparser::ast::DataType, +) -> Result { + Ok(match datatype { + sqlparser::ast::DataType::Character(..) + | sqlparser::ast::DataType::Char(..) + | sqlparser::ast::DataType::CharacterVarying(..) + | sqlparser::ast::DataType::CharVarying(..) + | sqlparser::ast::DataType::Varchar(..) + | sqlparser::ast::DataType::Nvarchar(..) + | sqlparser::ast::DataType::Text + | sqlparser::ast::DataType::TinyText + | sqlparser::ast::DataType::MediumText + | sqlparser::ast::DataType::LongText + | sqlparser::ast::DataType::String(_) + | sqlparser::ast::DataType::FixedString(_) => syn::Type::Verbatim(quote::quote! {String}), + sqlparser::ast::DataType::Uuid => syn::Type::Verbatim(quote::quote! {uuid::Uuid}), + sqlparser::ast::DataType::Varbinary(_) + | sqlparser::ast::DataType::Blob(_) + | sqlparser::ast::DataType::TinyBlob + | sqlparser::ast::DataType::MediumBlob + | sqlparser::ast::DataType::LongBlob + | sqlparser::ast::DataType::Bytes(_) + | sqlparser::ast::DataType::Bytea + | sqlparser::ast::DataType::Binary(_) => syn::Type::Verbatim(quote::quote! {Vec}), + sqlparser::ast::DataType::Numeric(..) + | sqlparser::ast::DataType::Decimal(..) + | sqlparser::ast::DataType::Dec(..) => { + syn::Type::Verbatim(quote::quote! {rust_decimal::Decimal}) + } + sqlparser::ast::DataType::Int2(_) | sqlparser::ast::DataType::SmallInt(_) => { + syn::Type::Verbatim(quote::quote! {i16}) + } + sqlparser::ast::DataType::UnsignedInt2(_) => syn::Type::Verbatim(quote::quote! {u16}), + sqlparser::ast::DataType::Int16 => todo!(), + sqlparser::ast::DataType::Int128 => todo!(), + sqlparser::ast::DataType::Int256 => todo!(), + sqlparser::ast::DataType::Int(_) + | sqlparser::ast::DataType::Int32 + | sqlparser::ast::DataType::Int4(_) + | sqlparser::ast::DataType::Integer(_) => syn::Type::Verbatim(quote::quote! {i32}), + sqlparser::ast::DataType::UnsignedInt(_) => todo!(), + sqlparser::ast::DataType::UnsignedInt4(_) => todo!(), + sqlparser::ast::DataType::UnsignedInteger(_) => todo!(), + sqlparser::ast::DataType::UInt8 => todo!(), + sqlparser::ast::DataType::UInt16 => todo!(), + sqlparser::ast::DataType::UInt32 => todo!(), + sqlparser::ast::DataType::UInt64 => todo!(), + sqlparser::ast::DataType::UInt128 => todo!(), + sqlparser::ast::DataType::UInt256 => todo!(), + sqlparser::ast::DataType::Int8(_) + | sqlparser::ast::DataType::Int64 + | sqlparser::ast::DataType::BigInt(_) => syn::Type::Verbatim(quote::quote! {i64}), + sqlparser::ast::DataType::UnsignedBigInt(_) => todo!(), + sqlparser::ast::DataType::UnsignedInt8(_) => todo!(), + sqlparser::ast::DataType::Float(_) + | sqlparser::ast::DataType::Float4 + | sqlparser::ast::DataType::Real + | sqlparser::ast::DataType::Float32 => syn::Type::Verbatim(quote::quote! {f32}), + sqlparser::ast::DataType::Float64 + | sqlparser::ast::DataType::Float8 + | sqlparser::ast::DataType::Double(..) + | sqlparser::ast::DataType::DoublePrecision => syn::Type::Verbatim(quote::quote! {f64}), + sqlparser::ast::DataType::Bool => todo!(), + sqlparser::ast::DataType::Boolean => todo!(), + sqlparser::ast::DataType::Date => todo!(), + sqlparser::ast::DataType::Date32 => todo!(), + sqlparser::ast::DataType::Time(_, timezone_info) => todo!(), + sqlparser::ast::DataType::Datetime(_) => todo!(), + sqlparser::ast::DataType::Datetime64(_, _) => todo!(), + sqlparser::ast::DataType::Timestamp(_, timezone_info) => todo!(), + sqlparser::ast::DataType::Interval => todo!(), + sqlparser::ast::DataType::JSON => todo!(), + sqlparser::ast::DataType::JSONB => todo!(), + sqlparser::ast::DataType::Regclass => todo!(), + sqlparser::ast::DataType::Bit(_) => todo!(), + sqlparser::ast::DataType::BitVarying(_) => todo!(), + sqlparser::ast::DataType::Custom(object_name, vec) => todo!(), + sqlparser::ast::DataType::Array(array_elem_type_def) => match array_elem_type_def { + sqlparser::ast::ArrayElemTypeDef::None => { + return Err(UnsupportedDataType(datatype.clone())) + } + sqlparser::ast::ArrayElemTypeDef::AngleBracket(data_type) + | sqlparser::ast::ArrayElemTypeDef::SquareBracket(data_type, _) + | sqlparser::ast::ArrayElemTypeDef::Parenthesis(data_type) => { + let inner_type = sql_datatype_to_rust_type(&datatype)?; + syn::Type::Verbatim(quote::quote! {Vec<#inner_type>}) + } + }, + sqlparser::ast::DataType::Map(data_type, data_type1) => todo!(), + sqlparser::ast::DataType::Tuple(vec) => todo!(), + sqlparser::ast::DataType::Nested(vec) => todo!(), + sqlparser::ast::DataType::Enum(vec, _) => todo!(), + sqlparser::ast::DataType::Set(vec) => todo!(), + sqlparser::ast::DataType::Struct(vec, struct_bracket_kind) => todo!(), + sqlparser::ast::DataType::Union(vec) => todo!(), + sqlparser::ast::DataType::Nullable(data_type) => todo!(), + sqlparser::ast::DataType::LowCardinality(data_type) => todo!(), + sqlparser::ast::DataType::Trigger => todo!(), + _ => return Err(UnsupportedDataType(datatype.clone())), + }) +} + +fn generate_table_struct(table_struct: &TableStruct) -> proc_macro2::TokenStream { + let TableStruct { + name, + fields, + db_alias, + } = table_struct; + let db_alias = db_alias + .as_deref() + .map(|db_alias| quote::quote!(#[postgres(name = #db_alias)])); + + let fields = fields.iter().map( + |TableStructField { + name, + r#type, + db_alias, + }| { + let db_alias = db_alias + .as_deref() + .map(|db_alias| quote::quote!(#[postgres(name = #db_alias)])); + quote::quote!(#db_alias #name : #r#type) + }, + ); + quote::quote!( + #[derive(ToSql, FromSql)] + #db_alias + struct #name { + #(#fields,)* + } + ) +} diff --git a/rasql-core/src/rust/type_gen.rs b/rasql-core/src/rust/type_gen.rs new file mode 100644 index 0000000..9942c7b --- /dev/null +++ b/rasql-core/src/rust/type_gen.rs @@ -0,0 +1,194 @@ +use thiserror::Error; + +use crate::rust::TableStructField; + +use super::TableStruct; + +pub trait TypeGenerator { + fn sql_datatype_to_rust_type( + &self, + datatype: &sqlparser::ast::DataType, + ) -> Result; + + fn generate_table_struct(&self, table_struct: &TableStruct) -> proc_macro2::TokenStream; +} + +#[derive(Debug, Error)] +#[error("Type generator does not support the following SQL datatype: {0}")] +pub struct UnsupportedDataType(pub sqlparser::ast::DataType); + +#[cfg(feature = "tokio-postgres")] +pub struct TokioPostgresGenerator { + pub use_rust_decimal: UseRustDecimal, + pub use_uuid: UseUuid, +} + +#[cfg(feature = "tokio-postgres")] +pub enum UseRustDecimal { + DontUse, + Version1, +} + +#[cfg(feature = "tokio-postgres")] +pub enum UseUuid { + DontUse, + Version0_8, + Version1, +} + +#[cfg(feature = "tokio-postgres")] +impl TokioPostgresGenerator { + pub(super) fn generate_execute( + client: &syn::Expr, + prepared_statement: &syn::Expr, + parameters: &[&syn::Expr], + ) -> proc_macro2::TokenStream { + quote::quote!(#client.execute(#prepared_statement, &[#(#parameters,)*]).await) + } +} + +#[cfg(feature = "tokio-postgres")] +impl TypeGenerator for TokioPostgresGenerator { + fn sql_datatype_to_rust_type( + &self, + datatype: &sqlparser::ast::DataType, + ) -> Result { + Ok(match datatype { + sqlparser::ast::DataType::Character(..) + | sqlparser::ast::DataType::Char(..) + | sqlparser::ast::DataType::CharacterVarying(..) + | sqlparser::ast::DataType::CharVarying(..) + | sqlparser::ast::DataType::Varchar(..) + | sqlparser::ast::DataType::Nvarchar(..) + | sqlparser::ast::DataType::Text + | sqlparser::ast::DataType::TinyText + | sqlparser::ast::DataType::MediumText + | sqlparser::ast::DataType::LongText + | sqlparser::ast::DataType::String(_) + | sqlparser::ast::DataType::FixedString(_) => { + syn::Type::Verbatim(quote::quote! {String}) + } + sqlparser::ast::DataType::Uuid + if matches!(self.use_uuid, UseUuid::Version0_8 | UseUuid::Version1) => + { + syn::Type::Verbatim(quote::quote! {uuid::Uuid}) + } + sqlparser::ast::DataType::Varbinary(_) + | sqlparser::ast::DataType::Blob(_) + | sqlparser::ast::DataType::TinyBlob + | sqlparser::ast::DataType::MediumBlob + | sqlparser::ast::DataType::LongBlob + | sqlparser::ast::DataType::Bytes(_) + | sqlparser::ast::DataType::Bytea + | sqlparser::ast::DataType::Binary(_) => syn::Type::Verbatim(quote::quote! {Vec}), + sqlparser::ast::DataType::Numeric(..) + | sqlparser::ast::DataType::Decimal(..) + | sqlparser::ast::DataType::Dec(..) + if matches!(self.use_rust_decimal, UseRustDecimal::Version1) => + { + syn::Type::Verbatim(quote::quote! {rust_decimal::Decimal}) + } + sqlparser::ast::DataType::Int2(_) | sqlparser::ast::DataType::SmallInt(_) => { + syn::Type::Verbatim(quote::quote! {i16}) + } + sqlparser::ast::DataType::UnsignedInt2(_) => syn::Type::Verbatim(quote::quote! {u16}), + sqlparser::ast::DataType::Int16 => todo!(), + sqlparser::ast::DataType::Int128 => todo!(), + sqlparser::ast::DataType::Int256 => todo!(), + sqlparser::ast::DataType::Int(_) + | sqlparser::ast::DataType::Int32 + | sqlparser::ast::DataType::Int4(_) + | sqlparser::ast::DataType::Integer(_) => syn::Type::Verbatim(quote::quote! {i32}), + sqlparser::ast::DataType::UnsignedInt(_) => todo!(), + sqlparser::ast::DataType::UnsignedInt4(_) => todo!(), + sqlparser::ast::DataType::UnsignedInteger(_) => todo!(), + sqlparser::ast::DataType::UInt8 => todo!(), + sqlparser::ast::DataType::UInt16 => todo!(), + sqlparser::ast::DataType::UInt32 => todo!(), + sqlparser::ast::DataType::UInt64 => todo!(), + sqlparser::ast::DataType::UInt128 => todo!(), + sqlparser::ast::DataType::UInt256 => todo!(), + sqlparser::ast::DataType::Int8(_) + | sqlparser::ast::DataType::Int64 + | sqlparser::ast::DataType::BigInt(_) => syn::Type::Verbatim(quote::quote! {i64}), + sqlparser::ast::DataType::UnsignedBigInt(_) => todo!(), + sqlparser::ast::DataType::UnsignedInt8(_) => todo!(), + sqlparser::ast::DataType::Float(_) + | sqlparser::ast::DataType::Float4 + | sqlparser::ast::DataType::Real + | sqlparser::ast::DataType::Float32 => syn::Type::Verbatim(quote::quote! {f32}), + sqlparser::ast::DataType::Float64 + | sqlparser::ast::DataType::Float8 + | sqlparser::ast::DataType::Double(..) + | sqlparser::ast::DataType::DoublePrecision => syn::Type::Verbatim(quote::quote! {f64}), + sqlparser::ast::DataType::Bool => todo!(), + sqlparser::ast::DataType::Boolean => todo!(), + sqlparser::ast::DataType::Date => todo!(), + sqlparser::ast::DataType::Date32 => todo!(), + sqlparser::ast::DataType::Time(_, timezone_info) => todo!(), + sqlparser::ast::DataType::Datetime(_) => todo!(), + sqlparser::ast::DataType::Datetime64(_, _) => todo!(), + sqlparser::ast::DataType::Timestamp(_, timezone_info) => todo!(), + sqlparser::ast::DataType::Interval => todo!(), + sqlparser::ast::DataType::JSON => todo!(), + sqlparser::ast::DataType::JSONB => todo!(), + sqlparser::ast::DataType::Regclass => todo!(), + sqlparser::ast::DataType::Bit(_) => todo!(), + sqlparser::ast::DataType::BitVarying(_) => todo!(), + sqlparser::ast::DataType::Custom(object_name, vec) => todo!(), + sqlparser::ast::DataType::Array(array_elem_type_def) => match array_elem_type_def { + sqlparser::ast::ArrayElemTypeDef::None => { + return Err(UnsupportedDataType(datatype.clone())) + } + sqlparser::ast::ArrayElemTypeDef::AngleBracket(data_type) + | sqlparser::ast::ArrayElemTypeDef::SquareBracket(data_type, _) + | sqlparser::ast::ArrayElemTypeDef::Parenthesis(data_type) => { + let inner_type = self.sql_datatype_to_rust_type(&datatype)?; + syn::Type::Verbatim(quote::quote! {Vec<#inner_type>}) + } + }, + sqlparser::ast::DataType::Map(data_type, data_type1) => todo!(), + sqlparser::ast::DataType::Tuple(vec) => todo!(), + sqlparser::ast::DataType::Nested(vec) => todo!(), + sqlparser::ast::DataType::Enum(vec, _) => todo!(), + sqlparser::ast::DataType::Set(vec) => todo!(), + sqlparser::ast::DataType::Struct(vec, struct_bracket_kind) => todo!(), + sqlparser::ast::DataType::Union(vec) => todo!(), + sqlparser::ast::DataType::Nullable(data_type) => todo!(), + sqlparser::ast::DataType::LowCardinality(data_type) => todo!(), + sqlparser::ast::DataType::Trigger => todo!(), + _ => return Err(UnsupportedDataType(datatype.clone())), + }) + } + + fn generate_table_struct(&self, table_struct: &TableStruct) -> proc_macro2::TokenStream { + let TableStruct { + name, + fields, + db_alias, + } = table_struct; + let db_alias = db_alias + .as_deref() + .map(|db_alias| quote::quote!(#[postgres(name = #db_alias)])); + + let fields = fields.iter().map( + |TableStructField { + name, + r#type, + db_alias, + }| { + let db_alias = db_alias + .as_deref() + .map(|db_alias| quote::quote!(#[postgres(name = #db_alias)])); + quote::quote!(#db_alias #name : #r#type) + }, + ); + quote::quote!( + #[derive(ToSql, FromSql)] + #db_alias + struct #name { + #(#fields,)* + } + ) + } +} diff --git a/rasql-core/src/sql.rs b/rasql-core/src/sql.rs new file mode 100644 index 0000000..0ab00b3 --- /dev/null +++ b/rasql-core/src/sql.rs @@ -0,0 +1,193 @@ +use std::collections::HashMap; + +use sqlparser::ast::{ColumnDef, DataType, Ident, ObjectName, SchemaName, TableConstraint}; + +pub fn parse_sql_schema( + sql_statements: impl IntoIterator< + Item = impl TryInto, + >, +) { + let mut schemas = HashMap::new(); + for statement in sql_statements { + let statement: sqlparser::ast::Statement = statement.try_into().unwrap(); + match statement { + sqlparser::ast::Statement::CreateSchema { schema_name, .. } => { + schemas + .entry(schema_name.clone()) + .or_insert_with(|| Schema { + name: schema_name, + tables: Default::default(), + types: Default::default(), + }); + } + sqlparser::ast::Statement::CreateTable(sqlparser::ast::CreateTable { + name, + columns, + constraints, + .. + }) => { + let schema = schema_for_object(&mut schemas, &name); + schema.types.insert( + name.clone(), + UserDefinedType::Composite(CompositeType { + name: name.clone(), + fields: columns + .iter() + .map(|column| Field { + name: column.name.clone(), + r#type: column.data_type.clone(), + }) + .collect(), + }), + ); + schema.tables.insert( + name.clone(), + Table { + name, + columns, + constraints, + }, + ); + } + sqlparser::ast::Statement::AlterTable { + name, operations, .. + } => { + let schema = schema_for_object(&mut schemas, &name); + let Some(table) = schema.tables.get_mut(&name) else { + continue; + }; + for op in operations { + match op { + sqlparser::ast::AlterTableOperation::AddConstraint(table_constraint) => { + table.constraints.push(table_constraint); + } + _ => (), + } + } + } + sqlparser::ast::Statement::CreateType { + name, + representation: + sqlparser::ast::UserDefinedTypeRepresentation::Composite { attributes }, + } => { + let schema = schema_for_object(&mut schemas, &name); + schema.types.insert( + name.clone(), + UserDefinedType::Composite(CompositeType { + name, + fields: attributes + .into_iter() + .map(|attr| Field { + name: attr.name, + r#type: attr.data_type, + }) + .collect(), + }), + ); + } + sqlparser::ast::Statement::CreateType { + name, + representation: sqlparser::ast::UserDefinedTypeRepresentation::Enum { labels }, + } => { + let schema = schema_for_object(&mut schemas, &name); + schema.types.insert( + name.clone(), + UserDefinedType::Enum(EnumType { + name, + variants: labels, + }), + ); + } + _ => (), + } + } +} + +fn schema_for_object<'a>( + schemas: &'a mut HashMap, + object_name: &ObjectName, +) -> &'a mut Schema { + let schema = match object_name.0.as_slice() { + [_table_name] => schemas + .entry(SchemaName::Simple(ObjectName(vec![Ident::new("public")]))) + .or_insert_with(|| Schema { + name: SchemaName::Simple(ObjectName(vec![Ident::new("public")])), + tables: Default::default(), + types: Default::default(), + }), + [schema_name, _table_name] => schemas + .entry(SchemaName::Simple(ObjectName(vec![schema_name.clone()]))) + .or_insert_with(|| Schema { + name: SchemaName::Simple(ObjectName(vec![schema_name.clone()])), + tables: Default::default(), + types: Default::default(), + }), + [catalog_name, schema_name, _table_name] => schemas + .entry(SchemaName::Simple(ObjectName(vec![ + catalog_name.clone(), + schema_name.clone(), + ]))) + .or_insert_with(|| Schema { + name: SchemaName::Simple(ObjectName(vec![ + catalog_name.clone(), + schema_name.clone(), + ])), + tables: Default::default(), + types: Default::default(), + }), + _ => unreachable!(), + }; + schema +} + +pub struct Schema { + pub name: SchemaName, + pub tables: HashMap, + pub types: HashMap, +} + +pub struct Table { + pub name: ObjectName, + pub columns: Vec, + pub constraints: Vec, +} + +pub enum UserDefinedType { + Composite(CompositeType), + Enum(EnumType), + Domain(DomainType), +} + +pub struct CompositeType { + name: ObjectName, + fields: Vec, +} + +pub struct EnumType { + name: ObjectName, + variants: Vec, +} + +pub struct DomainType { + +} + +pub struct Field { + pub name: Ident, + pub r#type: DataType, +} + +pub enum Constraint { + PrimaryKey(PrimaryKeyConstraint), + ForeignKey(ForeignKeyConstraint), + Unique(UniqueConstraint), + Check(CheckConstraint), +} + +pub struct PrimaryKeyConstraint {} + +pub struct ForeignKeyConstraint {} + +pub struct UniqueConstraint {} + +pub struct CheckConstraint {} diff --git a/rasql-model/Cargo.toml b/rasql-model/Cargo.toml new file mode 100644 index 0000000..eb012fe --- /dev/null +++ b/rasql-model/Cargo.toml @@ -0,0 +1,7 @@ +[package] +name = "rasql-model" +version = "0.1.0" +edition = "2024" + +[dependencies] +sqlparser = "0.54.0" diff --git a/rasql-model/src/lib.rs b/rasql-model/src/lib.rs new file mode 100644 index 0000000..7c47532 --- /dev/null +++ b/rasql-model/src/lib.rs @@ -0,0 +1,84 @@ +use std::{any::TypeId, borrow::Cow, num::NonZeroU32}; + +#[derive(Debug, Clone)] +pub enum PostgresDatatype { + Bigint, + Bigserial, + Bit(Option), + BitVarying(Option), + Boolean, + Box, + Bytea, + Character(Option), + CharacterVarying(Option), + Cidr, + Circle, + Date, + DoublePrecision, + Inet, + Integer, + Interval(Option,Option), + Json, + Jsonb, + Line, + Lseg, + Macaddr, + Macaddr8, + Money, + Numeric(Option), + Path, + PgLsn, + PgSnapshot, + Point, + Polygon, + Real, + SmallInt, + SmallSerial, + Serial, + Text, + Time(Option), + TimeTz(Option), + Timestamp(Option), + TimestampTz(Option), + TsQuery, + TsVector, + TxidSnapshot, + Uuid, + Xml, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum IntervalFields { + Year, + Month, + Day, + Hour, + Minute, + Second, + YearToMonth, + DayToHour, + DayToMinute, + DayToSecond, + HourToMinute, + HourToSecond, + MinuteToSecond, +} + +#[derive(Debug, Clone, Copy)] +pub enum NumericConfig { + Precision(u16), + PrecisionAndScale(u16, i16), +} + +#[derive(Debug, Clone)] +pub struct Table { + pub columns: Cow<'static, [Column]>, +} + +#[derive(Debug, Clone)] +pub struct Column { + pub postgres_name: Cow<'static, str>, + pub postgres_datatype: PostgresDatatype, + pub field_name: Cow<'static, str>, + pub rust_type_id: fn() -> TypeId, +} diff --git a/rasql-query/Cargo.toml b/rasql-query/Cargo.toml index cbb57d3..bacd0e7 100644 --- a/rasql-query/Cargo.toml +++ b/rasql-query/Cargo.toml @@ -1,7 +1,9 @@ [package] name = "rasql-query" -version = "0.1.0" +version = "0.0.0" edition = "2021" +license = "MIT" +description = "Procedural macros for Rasql powered queries in Rust" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/rasql-traits/Cargo.toml b/rasql-traits/Cargo.toml index 8ffa3af..098215f 100644 --- a/rasql-traits/Cargo.toml +++ b/rasql-traits/Cargo.toml @@ -1,7 +1,9 @@ [package] name = "rasql-traits" -version = "0.1.0" +version = "0.0.0" edition = "2021" +license = "MIT" +description = "Trait definitions for Rasql generated database types" [dependencies] tokio-postgres = { version = "0.7.12", optional = true }