diff --git a/derive/Cargo.toml b/derive/Cargo.toml index 66a1b75..0e3d3f1 100644 --- a/derive/Cargo.toml +++ b/derive/Cargo.toml @@ -11,4 +11,7 @@ proc-macro = true [dependencies] proc-macro2 = "1.0.46" quote = "1.0.21" -syn = "1.0.101" + +[dependencies.syn] +version = "1" +features = ["full"] diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 56a160b..88ed5c9 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -1,8 +1,9 @@ use proc_macro::TokenStream; +use proc_macro2::Span; use quote::quote; use syn::{ - parse_macro_input, parse_quote, DataStruct, DeriveInput, FieldsNamed, FieldsUnnamed, Generics, - WhereClause, WherePredicate, + parse_macro_input, parse_quote, DataEnum, DataStruct, DeriveInput, FieldsNamed, FieldsUnnamed, + Generics, Variant, WhereClause, WherePredicate, }; #[proc_macro_derive(RustyValue)] @@ -13,7 +14,7 @@ pub fn derive_value(input: TokenStream) -> TokenStream { fn derive(input: DeriveInput) -> TokenStream { match &input.data { syn::Data::Struct(s) => derive_struct(&input, s), - syn::Data::Enum(_) => todo!(), + syn::Data::Enum(e) => derive_enum(&input, e), syn::Data::Union(_) => panic!("unions are currently unsupported"), } } @@ -45,7 +46,7 @@ fn derive_struct(input: &DeriveInput, struct_data: &DataStruct) -> TokenStream { Value::Struct(Struct{ name: #name.to_string(), - fields: StructFields::Named(values), + fields: Fields::Named(values), }) } } @@ -71,7 +72,7 @@ fn derive_struct(input: &DeriveInput, struct_data: &DataStruct) -> TokenStream { Value::Struct(Struct{ name: #name.to_string(), - fields: StructFields::Unnamed(values), + fields: Fields::Unnamed(values), }) } } @@ -80,13 +81,107 @@ fn derive_struct(input: &DeriveInput, struct_data: &DataStruct) -> TokenStream { syn::Fields::Unit => TokenStream::from(quote! { impl #impl_generics rusty_value::RustyValue for #ident #ty_generics #where_clause { fn into_rusty_value(self) -> rusty_value::Value { - Value::Unit(#name.to_string()) + use rusty_value::*; + Value::Struct(Struct{ + name: #name.to_string(), + fields: Fields::Unit, + }) } } }), } } +fn derive_enum(input: &DeriveInput, enum_data: &DataEnum) -> TokenStream { + let ident = &input.ident; + let (impl_generics, ty_generics, _) = input.generics.split_for_impl(); + let where_clause = add_rusty_bound(&input.generics); + let variant_matchers = enum_data + .variants + .iter() + .map(|v| create_enum_value_match(ident, v)) + .collect::>(); + + TokenStream::from(quote! { + impl #impl_generics rusty_value::RustyValue for #ident #ty_generics #where_clause { + fn into_rusty_value(self) -> rusty_value::Value { + let enum_val = match self { + #( #variant_matchers )* + }; + rusty_value::Value::Enum(enum_val) + } + } + }) +} + +fn create_enum_value_match(ident: &syn::Ident, variant: &Variant) -> proc_macro2::TokenStream { + let enum_name = ident.to_string(); + let variant_ident = &variant.ident; + let variant_name = variant_ident.to_string(); + + match &variant.fields { + syn::Fields::Named(FieldsNamed { named, .. }) => { + let field_idents = named.iter().map(|f| &f.ident).collect::>(); + let field_names = named + .iter() + .map(|f| f.ident.as_ref().unwrap().to_string()) + .collect::>(); + let field_count = named.len(); + + quote! { + #ident::#variant_ident { #( #field_idents, )* } => { + use rusty_value::*; + + let mut fields = std::collections::HashMap::with_capacity(#field_count); + #( + fields.insert(#field_names.to_string(), #field_idents.into_rusty_value()); + )* + Enum { + name: #enum_name.to_string(), + variant: #variant_name.to_string(), + fields: Fields::Named(fields) + } + } + } + } + syn::Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => { + let field_names = unnamed + .iter() + .enumerate() + .map(|(i, _)| syn::Ident::new(&format!("f{i}"), Span::call_site())) + .collect::>(); + let field_count = unnamed.len(); + + quote! { + #ident::#variant_ident ( #( #field_names, )* ) => { + use rusty_value::*; + + let mut fields = Vec::with_capacity(#field_count); + #( + fields.push(#field_names.into_rusty_value()); + )* + Enum { + name: #enum_name.to_string(), + variant: #variant_name.to_string(), + fields: Fields::Unnamed(fields) + } + } + } + } + syn::Fields::Unit => quote! { + #ident::#variant_ident => { + use rusty_value::*; + + Enum { + name: #enum_name.to_string(), + variant: #variant_name.to_string(), + fields: Fields::Unit + } + } + }, + } +} + fn add_rusty_bound(generics: &Generics) -> WhereClause { let trait_bound: proc_macro2::TokenStream = parse_quote!(rusty_value::RustyValue); diff --git a/src/value.rs b/src/value.rs index 92c43fa..379e64c 100644 --- a/src/value.rs +++ b/src/value.rs @@ -7,26 +7,26 @@ pub enum Value { Enum(Enum), Map(HashMap), List(Vec), - Unit(String), } #[derive(Clone, Debug, PartialEq)] pub struct Enum { pub name: String, pub variant: String, - pub value: Box, + pub fields: Fields, } #[derive(Clone, Debug, PartialEq)] pub struct Struct { pub name: String, - pub fields: StructFields, + pub fields: Fields, } #[derive(Clone, Debug, PartialEq)] -pub enum StructFields { +pub enum Fields { Named(HashMap), Unnamed(Vec), + Unit, } #[derive(Clone, Debug, PartialEq, PartialOrd)] diff --git a/tests/enums.rs b/tests/enums.rs new file mode 100644 index 0000000..97ab0ad --- /dev/null +++ b/tests/enums.rs @@ -0,0 +1,138 @@ +use rusty_value::{Fields, RustyValue, Value}; +use rusty_value_derive::*; + +#[allow(dead_code)] +#[derive(RustyValue)] +enum TestEnumNamed { + Named { foo: String }, + Named2 { foo: String, bar: u64 }, +} + +#[test] +fn it_handles_enums_with_named_fields() { + let enum_value = TestEnumNamed::Named2 { + foo: String::from("hello"), + bar: 12, + }; + let value = enum_value.into_rusty_value(); + dbg!(&value); + + if let Value::Enum(e) = value { + assert_eq!(&e.name, "TestEnumNamed"); + assert_eq!(&e.variant, "Named2"); + + if let Fields::Named(n) = e.fields { + assert_eq!(n.len(), 2); + } else { + panic!("Enum variant doesn't have named fields") + } + } else { + panic!("Value is not an enum") + } +} +#[allow(dead_code)] +#[derive(RustyValue)] +enum TestEnumUnnamed { + Unnamed1(String, u8), + Unnamed2(u8), +} + +#[test] +fn it_handles_enums_with_unamed_fields() { + let enum_value = TestEnumUnnamed::Unnamed1(String::from("hello"), 12); + let value = enum_value.into_rusty_value(); + dbg!(&value); + + if let Value::Enum(e) = value { + assert_eq!(&e.name, "TestEnumUnnamed"); + assert_eq!(&e.variant, "Unnamed1"); + + if let Fields::Unnamed(n) = e.fields { + assert_eq!(n.len(), 2); + } else { + panic!("Enum variant doesn't have unnamed fields") + } + } else { + panic!("Value is not an enum") + } +} + +#[allow(dead_code)] +#[derive(RustyValue)] +enum TestEnumUnit { + Unit1, + Unit2, +} + +#[test] +fn it_handles_unit_enums() { + let enum_val = TestEnumUnit::Unit1; + let value = enum_val.into_rusty_value(); + dbg!(&value); + + if let Value::Enum(e) = value { + assert_eq!(&e.name, "TestEnumUnit"); + assert_eq!(&e.variant, "Unit1"); + + if let Fields::Unit = e.fields { + assert!(true) + } else { + panic!("Enum is variant is not a unit") + } + } else { + panic!("Value is not an enum") + } +} + +#[derive(RustyValue)] +enum TestGeneric { + CloneVar(R), +} + +#[test] +fn it_handles_generic_enums() { + let enum_val = TestGeneric::CloneVar(String::from("test")); + let value = enum_val.into_rusty_value(); + dbg!(&value); + + if let Value::Enum(e) = value { + assert_eq!(&e.name, "TestGeneric"); + assert_eq!(&e.variant, "CloneVar"); + + if let Fields::Unnamed(u) = e.fields { + assert_eq!(u.len(), 1) + } else { + panic!("Enum is variant is not an unnamed enum") + } + } else { + panic!("Value is not an enum") + } +} + +#[allow(dead_code)] +#[derive(RustyValue)] +enum TestMixed { + CloneVar(R), + Unit, + Named { val: R, val2: u8 }, +} + +#[test] +fn it_handles_mixed_enums() { + let enum_val = TestMixed::::Unit; + let value = enum_val.into_rusty_value(); + dbg!(&value); + + if let Value::Enum(e) = value { + assert_eq!(&e.name, "TestMixed"); + assert_eq!(&e.variant, "Unit"); + + if let Fields::Unit = e.fields { + assert!(true) + } else { + panic!("Enum is variant is not a unit") + } + } else { + panic!("Value is not an enum") + } +} diff --git a/tests/structs.rs b/tests/structs.rs index 6c1fc61..a0ef7ac 100644 --- a/tests/structs.rs +++ b/tests/structs.rs @@ -1,4 +1,4 @@ -use rusty_value::{RustyValue, StructFields, Value}; +use rusty_value::{Fields, RustyValue, Value}; use rusty_value_derive::*; #[derive(RustyValue)] @@ -19,7 +19,7 @@ fn it_handles_named_fields() { if let Value::Struct(s) = value { assert_eq!(&s.name, "TestStructNamed"); - if let StructFields::Named(fields) = s.fields { + if let Fields::Named(fields) = s.fields { assert_eq!(fields.len(), 2); } else { panic!("Struct wasn't serialized as named struct") @@ -41,7 +41,7 @@ fn it_handles_unnamed_fields() { if let Value::Struct(s) = value { assert_eq!(&s.name, "TestStructUnnamed"); - if let StructFields::Unnamed(fields) = s.fields { + if let Fields::Unnamed(fields) = s.fields { assert_eq!(fields.len(), 2); } else { panic!("Struct wasn't serialized as unnamed struct") @@ -60,8 +60,13 @@ fn it_handles_unit_structs() { let value = test_struct.into_rusty_value(); dbg!(&value); - if let Value::Unit(s) = value { - assert_eq!(&s, "TestStructUnit"); + if let Value::Struct(s) = value { + assert_eq!(&s.name, "TestStructUnit"); + if let Fields::Unit = s.fields { + assert!(true); + } else { + panic!("Struct wasn't serialized as unit struct") + } } else { panic!("Struct wasn't serialized as struct"); } @@ -81,7 +86,7 @@ fn it_handles_generics() { if let Value::Struct(s) = value { assert_eq!(&s.name, "GenericStruct"); - if let StructFields::Named(fields) = s.fields { + if let Fields::Named(fields) = s.fields { assert_eq!(fields.len(), 1); } else { panic!("Struct wasn't serialized as named struct")