Re-export context_deserialize_derive inside context_deserialize (#7852)

Re-export `context_deserialize_derive` inside of `context_deserialize` so they are both available from the same interface, which matches how popular crates (like `serde`) handle this.

This also nests both crates inside a new `context_deserialize` directory which will make it easier to eventually spin out into a different repo (if/when) we decide to do that (plus I prefer it aesthetically).
This commit is contained in:
Mac L
2025-08-12 15:16:19 +10:00
committed by GitHub
parent 918121e313
commit 152f2bb2e4
21 changed files with 56 additions and 39 deletions

View File

@@ -0,0 +1,16 @@
[package]
name = "context_deserialize_derive"
version = "0.1.0"
edition = "2021"
[lib]
proc-macro = true
[dependencies]
quote = { workspace = true }
syn = { workspace = true }
[dev-dependencies]
context_deserialize = { path = "../context_deserialize" }
serde = { workspace = true }
serde_json = "1.0"

View File

@@ -0,0 +1,118 @@
extern crate proc_macro;
extern crate quote;
extern crate syn;
use proc_macro::TokenStream;
use quote::quote;
use syn::{
parse_macro_input, AttributeArgs, DeriveInput, GenericParam, LifetimeDef, Meta, NestedMeta,
WhereClause,
};
#[proc_macro_attribute]
pub fn context_deserialize(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as AttributeArgs);
let input = parse_macro_input!(item as DeriveInput);
let ident = &input.ident;
let mut ctx_types = Vec::new();
let mut explicit_where: Option<WhereClause> = None;
for meta in args {
match meta {
NestedMeta::Meta(Meta::Path(p)) => {
ctx_types.push(p);
}
NestedMeta::Meta(Meta::NameValue(nv)) if nv.path.is_ident("bound") => {
if let syn::Lit::Str(lit_str) = &nv.lit {
let where_string = format!("where {}", lit_str.value());
match syn::parse_str::<WhereClause>(&where_string) {
Ok(where_clause) => {
explicit_where = Some(where_clause);
}
Err(err) => {
return syn::Error::new_spanned(
lit_str,
format!("Invalid where clause '{}': {}", lit_str.value(), err),
)
.to_compile_error()
.into();
}
}
} else {
return syn::Error::new_spanned(
&nv,
"Expected a string literal for `bound` value",
)
.to_compile_error()
.into();
}
}
_ => {
return syn::Error::new_spanned(
&meta,
"Expected paths or `bound = \"...\"` in #[context_deserialize(...)]",
)
.to_compile_error()
.into();
}
}
}
if ctx_types.is_empty() {
return quote! {
compile_error!("Usage: #[context_deserialize(Type1, Type2, ..., bound = \"...\")]");
}
.into();
}
let original_generics = input.generics.clone();
// Clone and clean generics for impl use (remove default params)
let mut impl_generics = input.generics.clone();
for param in impl_generics.params.iter_mut() {
if let GenericParam::Type(ty) = param {
ty.eq_token = None;
ty.default = None;
}
}
// Ensure 'de lifetime exists in impl generics
let has_de = impl_generics
.lifetimes()
.any(|LifetimeDef { lifetime, .. }| lifetime.ident == "de");
if !has_de {
impl_generics.params.insert(0, syn::parse_quote! { 'de });
}
let (_, ty_generics, _) = original_generics.split_for_impl();
let (impl_gens, _, _) = impl_generics.split_for_impl();
// Generate: no `'de` applied to the type name
let mut impls = quote! {};
for ctx in ctx_types {
impls.extend(quote! {
impl #impl_gens context_deserialize::ContextDeserialize<'de, #ctx>
for #ident #ty_generics
#explicit_where
{
fn context_deserialize<D>(
deserializer: D,
_context: #ctx,
) -> Result<Self, D::Error>
where
D: serde::de::Deserializer<'de>,
{
<Self as serde::Deserialize>::deserialize(deserializer)
}
}
});
}
quote! {
#input
#impls
}
.into()
}

View File

@@ -0,0 +1,93 @@
use context_deserialize::{context_deserialize, ContextDeserialize};
use serde::{Deserialize, Serialize};
#[test]
fn test_context_deserialize_derive() {
type TestContext = ();
#[context_deserialize(TestContext)]
#[derive(Debug, PartialEq, Serialize, Deserialize)]
struct Test {
field: String,
}
let test = Test {
field: "test".to_string(),
};
let serialized = serde_json::to_string(&test).unwrap();
let deserialized =
Test::context_deserialize(&mut serde_json::Deserializer::from_str(&serialized), ())
.unwrap();
assert_eq!(test, deserialized);
}
#[test]
fn test_context_deserialize_derive_multiple_types() {
#[allow(dead_code)]
struct TestContext1(u64);
#[allow(dead_code)]
struct TestContext2(String);
// This will derive:
// - ContextDeserialize<TestContext1> for Test
// - ContextDeserialize<TestContext2> for Test
// by just leveraging the Deserialize impl
#[context_deserialize(TestContext1, TestContext2)]
#[derive(Debug, PartialEq, Serialize, Deserialize)]
struct Test {
field: String,
}
let test = Test {
field: "test".to_string(),
};
let serialized = serde_json::to_string(&test).unwrap();
let deserialized = Test::context_deserialize(
&mut serde_json::Deserializer::from_str(&serialized),
TestContext1(1),
)
.unwrap();
assert_eq!(test, deserialized);
let deserialized = Test::context_deserialize(
&mut serde_json::Deserializer::from_str(&serialized),
TestContext2("2".to_string()),
)
.unwrap();
assert_eq!(test, deserialized);
}
#[test]
fn test_context_deserialize_derive_bound() {
use std::fmt::Debug;
struct TestContext;
#[derive(Debug, PartialEq, Serialize, Deserialize)]
struct Inner {
value: u64,
}
#[context_deserialize(
TestContext,
bound = "T: Serialize + for<'a> Deserialize<'a> + Debug + PartialEq"
)]
#[derive(Debug, PartialEq, Serialize, Deserialize)]
struct Wrapper<T> {
inner: T,
}
let val = Wrapper {
inner: Inner { value: 42 },
};
let serialized = serde_json::to_string(&val).unwrap();
let deserialized = Wrapper::<Inner>::context_deserialize(
&mut serde_json::Deserializer::from_str(&serialized),
TestContext,
)
.unwrap();
assert_eq!(val, deserialized);
}