diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 92c172a..56a160b 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -1,6 +1,9 @@ use proc_macro::TokenStream; use quote::quote; -use syn::{parse_macro_input, DataStruct, DeriveInput, FieldsNamed, FieldsUnnamed}; +use syn::{ + parse_macro_input, parse_quote, DataStruct, DeriveInput, FieldsNamed, FieldsUnnamed, Generics, + WhereClause, WherePredicate, +}; #[proc_macro_derive(RustyValue)] pub fn derive_value(input: TokenStream) -> TokenStream { @@ -18,7 +21,8 @@ fn derive(input: DeriveInput) -> TokenStream { fn derive_struct(input: &DeriveInput, struct_data: &DataStruct) -> TokenStream { let ident = &input.ident; let name = ident.to_string(); - let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let (impl_generics, ty_generics, _) = input.generics.split_for_impl(); + let where_clause = add_rusty_bound(&input.generics); match &struct_data.fields { syn::Fields::Named(FieldsNamed { named, .. }) => { @@ -82,3 +86,19 @@ fn derive_struct(input: &DeriveInput, struct_data: &DataStruct) -> TokenStream { }), } } + +fn add_rusty_bound(generics: &Generics) -> WhereClause { + let trait_bound: proc_macro2::TokenStream = parse_quote!(rusty_value::RustyValue); + + let new_predicates = generics.type_params().map::(|param| { + let param = ¶m.ident; + parse_quote!(#param : #trait_bound) + }); + + let mut generics = generics.clone(); + generics + .make_where_clause() + .predicates + .extend(new_predicates); + generics.where_clause.unwrap() +} diff --git a/tests/structs.rs b/tests/structs.rs index b0f6715..6c1fc61 100644 --- a/tests/structs.rs +++ b/tests/structs.rs @@ -66,3 +66,27 @@ fn it_handles_unit_structs() { panic!("Struct wasn't serialized as struct"); } } + +#[derive(RustyValue)] +struct GenericStruct { + field: T, +} + +#[test] +fn it_handles_generics() { + let test_struct = GenericStruct:: { field: 12 }; + let value = test_struct.into_rusty_value(); + dbg!(&value); + + if let Value::Struct(s) = value { + assert_eq!(&s.name, "GenericStruct"); + + if let StructFields::Named(fields) = s.fields { + assert_eq!(fields.len(), 1); + } else { + panic!("Struct wasn't serialized as named struct") + } + } else { + panic!("Struct wasn't serialized as struct"); + } +}