Started rasql-core

This commit is contained in:
ZacJW 2025-01-27 00:16:15 +00:00
parent 336aba3e0e
commit 93154dc706
6 changed files with 646 additions and 0 deletions

View file

@ -5,4 +5,14 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
tokio-postgres = ["rasql-traits/tokio-postgres", "dep:tokio-postgres"]
[dependencies]
sqlparser = "0.54.0"
rasql-traits = { version = "0.1.0", path = "../rasql-traits" }
quote = "1.0.35"
proc-macro2 = "1.0.93"
syn = { version = "2.0.96", features = ["full"] }
tokio-postgres = { version = "0.7.12", optional = true }
convert_case = "0.7.1"

View file

@ -0,0 +1,2 @@
pub mod sql;
pub mod rust;

View file

@ -0,0 +1,145 @@
pub trait AsyncClientCodeGenerator<Client: rasql_traits::r#async::Client> {
/// Create a token stream for usage of `client` to prepare `statement_str` for
/// later execution, evaluating to a value of type
/// `Result<Client::PreparedStatement, Client::PrepareError>`
///
/// - `client` is an expr of type `&Client`
/// - `statement_str` is an expr of type `&str`
fn generate_prepare_statement(
client: &syn::Expr,
statement_str: &syn::Expr,
) -> proc_macro2::TokenStream;
/// Create a token stream for usage of `client` to execute the `prepared_statement`
/// with the provided `parameters`, evaluating to a value of type
/// `Result<Client::Rows, Client::QueryError>`.
///
/// - `client` is an expr of type `&Client`
/// - `prepared_statement` is an expr of type `&Client::PreparedStatement`
/// - exprs in `parameters` are references of types that can be assumed to be compatible with the client
fn generate_query_many_with_statement(
client: &syn::Expr,
prepared_statement: &syn::Expr,
parameters: &[&syn::Expr],
) -> proc_macro2::TokenStream;
/// Create a token stream for usage of `client` to execute the `prepared_statement`
/// with the provided `parameters`, evaluating to a value of type
/// `Result<Option<Client::Row>, Client::QueryError>`.
///
/// - `client` is an expr of type `&Client`
/// - `prepared_statement` is an expr of type `&Client::PreparedStatement`
/// - exprs in `parameters` are references of types that can be assumed to be compatible with the client
fn generate_query_one_or_none_with_statement(
client: &syn::Expr,
prepared_statement: &syn::Expr,
parameters: &[&syn::Expr],
) -> proc_macro2::TokenStream;
/// Create a token stream for usage of `row` to read a column, evaluating to a value of type
/// `Result<T, Client::RowReadColumnError>` where `T` is any type compatible with the database client.
///
/// - `row` is an expr of type `&Client::Row`
/// - `column_name` is an expr of type `&str`
fn generate_row_read_column(
row: &syn::Expr,
column_name: &syn::Expr,
) -> proc_macro2::TokenStream;
/// Create a token stream for usage of `client` to execute the `prepared_statement`
/// with the provided `parameters`, evaluating to a value of type
/// `Result<Client::InsertOutcome, Client::InsertError>`.
///
/// - `client` is an expr of type `&Client`
/// - `prepared_statement` is an expr of type `&Client::PreparedStatement`
/// - exprs in `parameters` are references of types that can be assumed to be compatible with the client
fn generate_insert_with_statement(
client: &syn::Expr,
prepared_statement: &syn::Expr,
parameters: &[&syn::Expr],
) -> proc_macro2::TokenStream;
/// Create a token stream for usage of `client` to execute the `prepared_statement`
/// with the provided `parameters`, evaluating to a value of type
/// `Result<Client::UpdateOutcome, Client::UpdateError>`.
///
/// - `client` is an expr of type `&Client`
/// - `prepared_statement` is an expr of type `&Client::PreparedStatement`
/// - exprs in `parameters` are references of types that can be assumed to be compatible with the client
fn generate_update_with_statement(
client: &syn::Expr,
prepared_statement: &syn::Expr,
parameters: &[&syn::Expr],
) -> proc_macro2::TokenStream;
/// Create a token stream for usage of `client` to execute the `prepared_statement`
/// with the provided `parameters`, evaluating to a value of type
/// `Result<Client::DeleteOutcome, Client::DeleteError>`.
///
/// - `client` is an expr of type `&Client`
/// - `prepared_statement` is an expr of type `&Client::PreparedStatement`
/// - exprs in `parameters` are references of types that can be assumed to be compatible with the client
fn generate_delete_with_statement(
client: &syn::Expr,
prepared_statement: &syn::Expr,
parameters: &[&syn::Expr],
) -> proc_macro2::TokenStream;
}
#[cfg(feature = "tokio-postgres")]
impl AsyncClientCodeGenerator<tokio_postgres::Client> for super::type_gen::TokioPostgresGenerator {
fn generate_prepare_statement(
client: &syn::Expr,
statement_str: &syn::Expr,
) -> proc_macro2::TokenStream {
quote::quote!(#client.prepare(#statement_str).await)
}
fn generate_query_many_with_statement(
client: &syn::Expr,
prepared_statement: &syn::Expr,
parameters: &[&syn::Expr],
) -> proc_macro2::TokenStream {
quote::quote!(#client.query(#prepared_statement, &[#(#parameters,)*]).await)
}
fn generate_query_one_or_none_with_statement(
client: &syn::Expr,
prepared_statement: &syn::Expr,
parameters: &[&syn::Expr],
) -> proc_macro2::TokenStream {
quote::quote!(#client.query_opt(#prepared_statement, &[#(#parameters,)*]).await)
}
fn generate_row_read_column(
row: &syn::Expr,
column_name: &syn::Expr,
) -> proc_macro2::TokenStream {
quote::quote!(#row.try_get(#column_name))
}
fn generate_insert_with_statement(
client: &syn::Expr,
prepared_statement: &syn::Expr,
parameters: &[&syn::Expr],
) -> proc_macro2::TokenStream {
Self::generate_execute(client, prepared_statement, parameters)
}
fn generate_update_with_statement(
client: &syn::Expr,
prepared_statement: &syn::Expr,
parameters: &[&syn::Expr],
) -> proc_macro2::TokenStream {
Self::generate_execute(client, prepared_statement, parameters)
}
fn generate_delete_with_statement(
client: &syn::Expr,
prepared_statement: &syn::Expr,
parameters: &[&syn::Expr],
) -> proc_macro2::TokenStream {
Self::generate_execute(client, prepared_statement, parameters)
}
}

257
rasql-core/src/rust/mod.rs Normal file
View file

@ -0,0 +1,257 @@
pub mod type_gen;
pub mod client_gen;
use std::collections::HashMap;
use client_gen::AsyncClientCodeGenerator;
use convert_case::Casing;
use type_gen::TypeGenerator;
pub struct TableStruct {
pub name: syn::Ident,
pub fields: Vec<TableStructField>,
pub db_alias: Option<String>,
}
pub struct TableStructField {
pub name: syn::Ident,
pub r#type: syn::Type,
pub db_alias: Option<String>,
}
pub struct GeneratedTableStruct(pub proc_macro2::TokenStream);
pub struct TableStructImpls {
pub base_table_impl: proc_macro2::TokenStream,
pub table_with_pk_impl: Option<proc_macro2::TokenStream>,
}
fn sql_ident_to_type_name(ident: &sqlparser::ast::Ident) -> syn::Ident {
let mut ident = ident.value.to_case(convert_case::Case::Pascal);
if ident.chars().next().unwrap().is_ascii_digit() {
ident.insert(0, '_');
}
syn::Ident::new(&ident, proc_macro2::Span::call_site())
}
fn sql_ident_to_field_name(ident: &sqlparser::ast::Ident) -> syn::Ident {
let mut ident = ident.value.to_case(convert_case::Case::Snake);
if ident.chars().next().unwrap().is_ascii_digit() {
ident.insert(0, '_');
}
syn::Ident::new(&ident, proc_macro2::Span::call_site())
}
#[inline]
fn sql_ident_to_module_name(ident: &sqlparser::ast::Ident) -> syn::Ident {
sql_ident_to_field_name(ident)
}
fn sql_datatype_to_rust_type(datatype: &sqlparser::ast::DataType) -> syn::Type {
match datatype {
sqlparser::ast::DataType::Character(..)
| sqlparser::ast::DataType::Char(..)
| sqlparser::ast::DataType::CharacterVarying(..)
| sqlparser::ast::DataType::CharVarying(..)
| sqlparser::ast::DataType::Varchar(..)
| sqlparser::ast::DataType::Nvarchar(..)
| sqlparser::ast::DataType::Text
| sqlparser::ast::DataType::TinyText
| sqlparser::ast::DataType::MediumText
| sqlparser::ast::DataType::LongText
| sqlparser::ast::DataType::String(_)
| sqlparser::ast::DataType::FixedString(_) => syn::Type::Verbatim(quote::quote! {String}),
sqlparser::ast::DataType::Uuid => todo!(),
sqlparser::ast::DataType::CharacterLargeObject(_) => todo!(),
sqlparser::ast::DataType::CharLargeObject(_) => todo!(),
sqlparser::ast::DataType::Clob(_) => todo!(),
sqlparser::ast::DataType::Binary(_) => todo!(),
sqlparser::ast::DataType::Varbinary(_) => todo!(),
sqlparser::ast::DataType::Blob(_) => todo!(),
sqlparser::ast::DataType::TinyBlob => todo!(),
sqlparser::ast::DataType::MediumBlob => todo!(),
sqlparser::ast::DataType::LongBlob => todo!(),
sqlparser::ast::DataType::Bytes(_) => todo!(),
sqlparser::ast::DataType::Numeric(exact_number_info) => todo!(),
sqlparser::ast::DataType::Decimal(exact_number_info) => todo!(),
sqlparser::ast::DataType::BigNumeric(exact_number_info) => todo!(),
sqlparser::ast::DataType::BigDecimal(exact_number_info) => todo!(),
sqlparser::ast::DataType::Dec(exact_number_info) => todo!(),
sqlparser::ast::DataType::Float(_) => todo!(),
sqlparser::ast::DataType::TinyInt(_) => todo!(),
sqlparser::ast::DataType::UnsignedTinyInt(_) => todo!(),
sqlparser::ast::DataType::Int2(_) => todo!(),
sqlparser::ast::DataType::UnsignedInt2(_) => todo!(),
sqlparser::ast::DataType::SmallInt(_) => todo!(),
sqlparser::ast::DataType::UnsignedSmallInt(_) => todo!(),
sqlparser::ast::DataType::MediumInt(_) => todo!(),
sqlparser::ast::DataType::UnsignedMediumInt(_) => todo!(),
sqlparser::ast::DataType::Int(_) => todo!(),
sqlparser::ast::DataType::Int16 => todo!(),
sqlparser::ast::DataType::Int128 => todo!(),
sqlparser::ast::DataType::Int256 => todo!(),
sqlparser::ast::DataType::Int32
| sqlparser::ast::DataType::Int4(_)
| sqlparser::ast::DataType::Integer(_) => syn::Type::Verbatim(quote::quote! {i32}),
sqlparser::ast::DataType::UnsignedInt(_) => todo!(),
sqlparser::ast::DataType::UnsignedInt4(_) => todo!(),
sqlparser::ast::DataType::UnsignedInteger(_) => todo!(),
sqlparser::ast::DataType::UInt8 => todo!(),
sqlparser::ast::DataType::UInt16 => todo!(),
sqlparser::ast::DataType::UInt32 => todo!(),
sqlparser::ast::DataType::UInt64 => todo!(),
sqlparser::ast::DataType::UInt128 => todo!(),
sqlparser::ast::DataType::UInt256 => todo!(),
sqlparser::ast::DataType::Int8(_)
| sqlparser::ast::DataType::Int64
| sqlparser::ast::DataType::BigInt(_) => syn::Type::Verbatim(quote::quote! {i64}),
sqlparser::ast::DataType::UnsignedBigInt(_) => todo!(),
sqlparser::ast::DataType::UnsignedInt8(_) => todo!(),
sqlparser::ast::DataType::Float4
| sqlparser::ast::DataType::Real
| sqlparser::ast::DataType::Float32 => syn::Type::Verbatim(quote::quote! {f32}),
sqlparser::ast::DataType::Float64
| sqlparser::ast::DataType::Float8
| sqlparser::ast::DataType::Double(..)
| sqlparser::ast::DataType::DoublePrecision => syn::Type::Verbatim(quote::quote! {f64}),
sqlparser::ast::DataType::Bool => todo!(),
sqlparser::ast::DataType::Boolean => todo!(),
sqlparser::ast::DataType::Date => todo!(),
sqlparser::ast::DataType::Date32 => todo!(),
sqlparser::ast::DataType::Time(_, timezone_info) => todo!(),
sqlparser::ast::DataType::Datetime(_) => todo!(),
sqlparser::ast::DataType::Datetime64(_, _) => todo!(),
sqlparser::ast::DataType::Timestamp(_, timezone_info) => todo!(),
sqlparser::ast::DataType::Interval => todo!(),
sqlparser::ast::DataType::JSON => todo!(),
sqlparser::ast::DataType::JSONB => todo!(),
sqlparser::ast::DataType::Regclass => todo!(),
sqlparser::ast::DataType::Bytea => todo!(),
sqlparser::ast::DataType::Bit(_) => todo!(),
sqlparser::ast::DataType::BitVarying(_) => todo!(),
sqlparser::ast::DataType::Custom(object_name, vec) => todo!(),
sqlparser::ast::DataType::Array(array_elem_type_def) => match array_elem_type_def {
sqlparser::ast::ArrayElemTypeDef::None => unimplemented!(),
sqlparser::ast::ArrayElemTypeDef::AngleBracket(data_type)
| sqlparser::ast::ArrayElemTypeDef::SquareBracket(data_type, _)
| sqlparser::ast::ArrayElemTypeDef::Parenthesis(data_type) => todo!(),
},
sqlparser::ast::DataType::Map(data_type, data_type1) => todo!(),
sqlparser::ast::DataType::Tuple(vec) => todo!(),
sqlparser::ast::DataType::Nested(vec) => todo!(),
sqlparser::ast::DataType::Enum(vec, _) => todo!(),
sqlparser::ast::DataType::Set(vec) => todo!(),
sqlparser::ast::DataType::Struct(vec, struct_bracket_kind) => todo!(),
sqlparser::ast::DataType::Union(vec) => todo!(),
sqlparser::ast::DataType::Nullable(data_type) => todo!(),
sqlparser::ast::DataType::LowCardinality(data_type) => todo!(),
sqlparser::ast::DataType::Unspecified => todo!(),
sqlparser::ast::DataType::Trigger => todo!(),
sqlparser::ast::DataType::AnyType => todo!(),
}
}
fn generate_table_struct_and_impls<
Traits: rasql_traits::DbTraits,
TypeGen: TypeGenerator<Traits>,
Client: rasql_traits::r#async::Client<Traits = Traits>,
ClientGen: AsyncClientCodeGenerator<Client>,
>(
table: &crate::sql::Table,
module_config: Option<&ModuleCodeGenConfig>,
) -> (GeneratedTableStruct, TableStructImpls) {
let name = sql_ident_to_type_name(table.name.0.last().unwrap());
let default_struct_config = StructCodeGenConfig {
field_configs: HashMap::new(),
deny_extra_fields: false,
};
let struct_config = module_config
.and_then(|config| config.struct_configs.get(&name))
.unwrap_or(&default_struct_config);
let fields = table
.columns
.iter()
.map(|column| {
let default_field_config = StructFieldCodeGenConfig {
rename: None,
override_type: None,
attrs: vec![],
id_promote_mode: IdPromoteMode::None,
};
let field_config = struct_config
.field_configs
.get(&column.name)
.unwrap_or(&default_field_config);
let (name, db_alias) = match &field_config.rename {
Some(rename) => (rename.clone(), Some(column.name.value.clone())),
None => {
let name = sql_ident_to_field_name(&column.name);
if name.to_string() == column.name.value {
(name, None)
} else {
(name, Some(column.name.value.clone()))
}
}
};
let r#type = match (&field_config.override_type, field_config.id_promote_mode) {
(Some(r#type), _) => r#type.clone(),
(None, IdPromoteMode::None) => column.data_type,
(None, IdPromoteMode::TrustedId) => todo!(),
(None, IdPromoteMode::Id) => todo!(),
};
TableStructField {
name,
r#type,
db_alias,
}
})
.collect();
let table_struct = TableStruct {
name,
fields,
db_alias: todo!(),
};
(
GeneratedTableStruct(TypeGen::generate_table_struct(&table_struct)),
TableStructImpls {
base_table_impl: todo!(),
table_with_pk_impl: todo!(),
},
)
}
pub struct CodeGenConfig {
pub module_configs: HashMap<syn::Ident, ModuleCodeGenConfig>,
}
pub struct ModuleCodeGenConfig {
pub use_statements: Vec<syn::ItemUse>,
pub struct_configs: HashMap<syn::Ident, StructCodeGenConfig>,
}
pub struct StructCodeGenConfig {
pub field_configs: HashMap<sqlparser::ast::Ident, StructFieldCodeGenConfig>,
pub deny_extra_fields: bool,
}
#[derive(Default)]
pub struct StructFieldCodeGenConfig {
rename: Option<syn::Ident>,
override_type: Option<syn::Type>,
attrs: Vec<syn::Attribute>,
id_promote_mode: IdPromoteMode,
}
#[derive(Clone, Copy, Default)]
pub enum IdPromoteMode {
#[default]
None,
TrustedId,
Id,
}

View file

@ -0,0 +1,63 @@
use crate::rust::TableStructField;
use super::TableStruct;
pub trait TypeGenerator<Traits: rasql_traits::DbTraits> {
fn sql_datatype_to_rust_type(datatype: &sqlparser::ast::DataType) -> syn::Type;
fn generate_table_struct(table_struct: &TableStruct) -> proc_macro2::TokenStream;
}
#[cfg(feature = "tokio-postgres")]
pub struct TokioPostgresGenerator;
#[cfg(feature = "tokio-postgres")]
impl TokioPostgresGenerator {
pub(super) fn generate_execute(
client: &syn::Expr,
prepared_statement: &syn::Expr,
parameters: &[&syn::Expr],
) -> proc_macro2::TokenStream {
quote::quote!(#client.execute(#prepared_statement, &[#(#parameters,)*]).await)
}
}
#[cfg(feature = "tokio-postgres")]
impl TypeGenerator<rasql_traits::PostgresTypesTraits> for TokioPostgresGenerator {
fn sql_datatype_to_rust_type(datatype: &sqlparser::ast::DataType) -> syn::Type {
todo!()
}
fn generate_table_struct(table_struct: &TableStruct) -> proc_macro2::TokenStream {
let TableStruct {
name,
fields,
db_alias,
} = table_struct;
let db_alias = db_alias
.as_deref()
.map(|db_alias| quote::quote!(#[postgres(name = #db_alias)]));
let fields = fields.iter().map(
|TableStructField {
name,
r#type,
db_alias,
}| {
let db_alias = db_alias
.as_deref()
.map(|db_alias| quote::quote!(#[postgres(name = #db_alias)]));
quote::quote!(#db_alias #name : #r#type)
},
);
quote::quote!(
#[derive(ToSql, FromSql)]
#db_alias
struct #name {
#(#fields,)*
}
)
}
}

169
rasql-core/src/sql.rs Normal file
View file

@ -0,0 +1,169 @@
use std::collections::HashMap;
use sqlparser::ast::{ColumnDef, DataType, Ident, ObjectName, SchemaName, TableConstraint};
pub fn parse_sql_schema(
sql_statements: impl IntoIterator<
Item = impl TryInto<sqlparser::ast::Statement, Error = impl std::fmt::Debug>,
>,
) {
let mut schemas = HashMap::new();
for statement in sql_statements {
let statement: sqlparser::ast::Statement = statement.try_into().unwrap();
match statement {
sqlparser::ast::Statement::CreateSchema { schema_name, .. } => {
schemas
.entry(schema_name.clone())
.or_insert_with(|| Schema {
name: schema_name,
tables: Default::default(),
types: Default::default(),
});
}
sqlparser::ast::Statement::CreateTable(sqlparser::ast::CreateTable {
name,
columns,
constraints,
..
}) => {
let schema = schema_for_object(&mut schemas, &name);
schema.types.insert(
name.clone(),
Type::Composite {
name: name.clone(),
fields: columns
.iter()
.map(|column| Field {
name: column.name.clone(),
r#type: column.data_type.clone(),
})
.collect(),
},
);
schema.tables.insert(
name.clone(),
Table {
name,
columns,
constraints,
},
);
}
sqlparser::ast::Statement::AlterTable {
name, operations, ..
} => {
let schema = schema_for_object(&mut schemas, &name);
let Some(table) = schema.tables.get_mut(&name) else {
continue;
};
for op in operations {
match op {
sqlparser::ast::AlterTableOperation::AddConstraint(table_constraint) => {
table.constraints.push(table_constraint);
}
_ => (),
}
}
}
sqlparser::ast::Statement::CreateType {
name,
representation:
sqlparser::ast::UserDefinedTypeRepresentation::Composite { attributes },
} => {
let schema = schema_for_object(&mut schemas, &name);
schema.types.insert(
name.clone(),
Type::Composite {
name,
fields: attributes
.into_iter()
.map(|attr| Field {
name: attr.name,
r#type: attr.data_type,
})
.collect(),
},
);
}
sqlparser::ast::Statement::CreateType {
name,
representation: sqlparser::ast::UserDefinedTypeRepresentation::Enum { labels },
} => {
let schema = schema_for_object(&mut schemas, &name);
schema.types.insert(
name.clone(),
Type::Enum {
name,
variants: labels,
},
);
}
_ => (),
}
}
}
fn schema_for_object<'a>(
schemas: &'a mut HashMap<SchemaName, Schema>,
object_name: &ObjectName,
) -> &'a mut Schema {
let schema = match object_name.0.as_slice() {
[_table_name] => schemas
.entry(SchemaName::Simple(ObjectName(vec![Ident::new("public")])))
.or_insert_with(|| Schema {
name: SchemaName::Simple(ObjectName(vec![Ident::new("public")])),
tables: Default::default(),
types: Default::default(),
}),
[schema_name, _table_name] => schemas
.entry(SchemaName::Simple(ObjectName(vec![schema_name.clone()])))
.or_insert_with(|| Schema {
name: SchemaName::Simple(ObjectName(vec![schema_name.clone()])),
tables: Default::default(),
types: Default::default(),
}),
[catalog_name, schema_name, _table_name] => schemas
.entry(SchemaName::Simple(ObjectName(vec![
catalog_name.clone(),
schema_name.clone(),
])))
.or_insert_with(|| Schema {
name: SchemaName::Simple(ObjectName(vec![
catalog_name.clone(),
schema_name.clone(),
])),
tables: Default::default(),
types: Default::default(),
}),
_ => unreachable!(),
};
schema
}
pub struct Schema {
pub name: SchemaName,
pub tables: HashMap<ObjectName, Table>,
pub types: HashMap<ObjectName, Type>,
}
pub struct Table {
pub name: ObjectName,
pub columns: Vec<ColumnDef>,
pub constraints: Vec<TableConstraint>,
}
pub enum Type {
Composite {
name: ObjectName,
fields: Vec<Field>,
},
Enum {
name: ObjectName,
variants: Vec<Ident>,
},
}
pub struct Field {
pub name: Ident,
pub r#type: DataType,
}