blob: 53bfdba636e614b188a55b8c69fcd2b5fd766803 [file] [log] [blame]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//! Derives a 9P wire format encoding for a struct by recursively calling
//! `WireFormat::encode` or `WireFormat::decode` on the fields of the struct.
//! This is only intended to be used from within the `p9` crate.
#![recursion_limit = "256"]
extern crate proc_macro;
extern crate proc_macro2;
#[macro_use]
extern crate quote;
#[cfg_attr(test, macro_use)]
extern crate syn;
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::Tokens;
use syn::spanned::Spanned;
use syn::{Data, DeriveInput, Fields, Ident};
/// The function that derives the actual implementation.
#[proc_macro_derive(P9WireFormat)]
pub fn p9_wire_format(input: TokenStream) -> TokenStream {
p9_wire_format_inner(syn::parse(input).unwrap()).into()
}
/// Builds the full expansion of `#[derive(P9WireFormat)]` for `input`.
///
/// The output is a module named `wire_format_<lowercased type name>` holding
/// an `impl WireFormat` for the type. The bodies of `byte_size`, `encode` and
/// `decode` are produced by `byte_size_sum`, `encode_wire_format` and
/// `decode_wire_format` respectively.
///
/// If the type declares any generic parameters, the output is a single
/// `compile_error!` invocation instead of an impl.
fn p9_wire_format_inner(input: DeriveInput) -> Tokens {
    // Generic parameters are not supported: report that to the user rather
    // than generating an impl.
    if !input.generics.params.is_empty() {
        return quote! {
            compile_error!("derive(P9WireFormat) does not support generic parameters");
        };
    }

    let type_ident = input.ident;

    // The impl lives in its own module, so the `use` items emitted below do
    // not land in the namespace of the code that invoked the derive.
    let module = Ident::from(format!("wire_format_{}", type_ident).to_lowercase());

    // Fragments that are spliced into the template at the bottom.
    let receiver = quote!(self);
    let type_name = quote!(#type_ident);

    // The trait import and the trait name are spanned at the call site.
    let caller_span = Span::call_site();
    let trait_import = quote_spanned!(caller_span=> use protocol::WireFormat; );
    let trait_name = quote_spanned!(caller_span=> WireFormat);

    // Per-method bodies, each generated from the struct's fields.
    let byte_size_body = byte_size_sum(&input.data, &trait_name, &receiver);
    let encode_body = encode_wire_format(&input.data, &trait_name, &receiver);
    let decode_body = decode_wire_format(&input.data, &trait_name, &type_name);

    quote! {
        mod #module {
            extern crate std;
            use self::std::io;
            use self::std::result::Result::Ok;

            use super::#type_name;

            #trait_import

            impl #trait_name for #type_name {
                fn byte_size(&#receiver) -> u32 {
                    #byte_size_body
                }

                fn encode<W: io::Write>(&#receiver, _writer: &mut W) -> io::Result<()> {
                    #encode_body
                }

                fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> {
                    #decode_body
                }
            }
        }
    }
}
/// Generates the body of `byte_size`: the expression `0 + ...` that adds up
/// `byte_size` of every field in the struct, in declaration order.
///
/// * `data` - the parsed body of the item the derive was applied to.
/// * `wire_format` - tokens naming the trait whose `byte_size` is called.
/// * `var` - tokens for the receiver the fields are read from (e.g. `self`).
///
/// # Panics
///
/// Panics with a descriptive message if `data` is not a struct with named
/// fields. A panic inside a derive macro is reported by the compiler as an
/// error at the derive site, so the message is what the user will see.
fn byte_size_sum(data: &Data, wire_format: &Tokens, var: &Tokens) -> Tokens {
    // Only structs with named fields can be derived; everything else is
    // rejected with a message that tells the user why.
    let named = match *data {
        Data::Struct(ref data) => match data.fields {
            Fields::Named(ref fields) => &fields.named,
            _ => panic!("derive(P9WireFormat) only supports structs with named fields"),
        },
        _ => panic!("derive(P9WireFormat) can only be applied to structs"),
    };

    let def_site = Span::def_site();
    let call_site = Span::call_site();

    let terms = named.iter().map(|field| {
        let field_name = field.ident;
        // The `self.field` access is spanned at the call site; the trait call
        // carries the field's own span, resolved at the definition site.
        let access = quote_spanned!(call_site=> #var.#field_name);
        let span = field.span().resolved_at(def_site);
        quote_spanned!(span=> #wire_format::byte_size(&#access) )
    });

    // The leading `0` keeps the expression valid for a struct with no fields.
    quote! {
        0 #(+ #terms)*
    }
}
/// Generates the body of `encode`: one `encode(&field, _writer)?;` statement
/// per field of the struct, in declaration order, followed by `Ok(())`.
///
/// * `data` - the parsed body of the item the derive was applied to.
/// * `wire_format` - tokens naming the trait whose `encode` is called.
/// * `var` - tokens for the receiver the fields are read from (e.g. `self`).
///
/// The generated code refers to a binding named `_writer`, which the
/// surrounding method signature must provide.
///
/// # Panics
///
/// Panics with a descriptive message if `data` is not a struct with named
/// fields. A panic inside a derive macro is reported by the compiler as an
/// error at the derive site, so the message is what the user will see.
fn encode_wire_format(data: &Data, wire_format: &Tokens, var: &Tokens) -> Tokens {
    // Only structs with named fields can be derived; everything else is
    // rejected with a message that tells the user why.
    let named = match *data {
        Data::Struct(ref data) => match data.fields {
            Fields::Named(ref fields) => &fields.named,
            _ => panic!("derive(P9WireFormat) only supports structs with named fields"),
        },
        _ => panic!("derive(P9WireFormat) can only be applied to structs"),
    };

    let def_site = Span::def_site();
    let call_site = Span::call_site();

    let statements = named.iter().map(|field| {
        let field_name = field.ident;
        // The `self.field` access is spanned at the call site; the trait call
        // carries the field's own span, resolved at the definition site.
        let access = quote_spanned!(call_site=> #var.#field_name);
        let span = field.span().resolved_at(def_site);
        quote_spanned!(span=> #wire_format::encode(&#access, _writer)?; )
    });

    quote! {
        #(#statements)*
        Ok(())
    }
}
/// Generates the body of `decode`: one `let field = decode(_reader)?;`
/// statement per field of the struct, in declaration order, followed by
/// `Ok(Name { field: field, ... })`.
///
/// * `data` - the parsed body of the item the derive was applied to.
/// * `wire_format` - tokens naming the trait whose `decode` is called.
/// * `name` - tokens naming the struct being constructed.
///
/// The generated code refers to a binding named `_reader`, which the
/// surrounding method signature must provide.
///
/// # Panics
///
/// Panics with a descriptive message if `data` is not a struct with named
/// fields. A panic inside a derive macro is reported by the compiler as an
/// error at the derive site, so the message is what the user will see.
fn decode_wire_format(data: &Data, wire_format: &Tokens, name: &Tokens) -> Tokens {
    // Only structs with named fields can be derived; everything else is
    // rejected with a message that tells the user why.
    let named = match *data {
        Data::Struct(ref data) => match data.fields {
            Fields::Named(ref fields) => &fields.named,
            _ => panic!("derive(P9WireFormat) only supports structs with named fields"),
        },
        _ => panic!("derive(P9WireFormat) can only be applied to structs"),
    };

    let def_site = Span::def_site();
    let call_site = Span::call_site();

    // `let <field> = Trait::decode(_reader)?;` for every field. The binding
    // is spanned at the call site; the statement carries the field's own
    // span, resolved at the definition site.
    let bindings = named.iter().map(|field| {
        let field_name = field.ident;
        let binding = quote_spanned!(call_site=> #field_name);
        let span = field.span().resolved_at(def_site);
        quote_spanned!(span=> let #binding = #wire_format::decode(_reader)?; )
    });

    // `<field>: <field>,` initialisers that move the decoded locals into the
    // struct literal.
    let initialisers = named.iter().map(|field| {
        let field_name = field.ident;
        quote_spanned!(call_site=>
            #field_name: #field_name,
        )
    });

    // `name` here is the struct name passed in by the caller, not a field.
    quote! {
        #(#bindings)*
        Ok(#name {
            #(#initialisers)*
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses the three-field `Item` struct shared by the per-method tests.
    fn item_struct() -> DeriveInput {
        parse_quote! {
            struct Item {
                ident: u32,
                with_underscores: String,
                other: u8,
            }
        }
    }

    /// `byte_size_sum` yields `0` plus one `byte_size` term per field.
    #[test]
    fn byte_size() {
        let item = item_struct();
        let receiver = quote!(self);
        let trait_name = quote!(WireFormat);

        let expected = quote! {
            0
                + WireFormat::byte_size(&self.ident)
                + WireFormat::byte_size(&self.with_underscores)
                + WireFormat::byte_size(&self.other)
        };

        let generated = byte_size_sum(&item.data, &trait_name, &receiver);
        assert_eq!(generated, expected);
    }

    /// `encode_wire_format` yields one `encode` statement per field, then `Ok(())`.
    #[test]
    fn encode() {
        let item = item_struct();
        let receiver = quote!(self);
        let trait_name = quote!(WireFormat);

        let expected = quote! {
            WireFormat::encode(&self.ident, _writer)?;
            WireFormat::encode(&self.with_underscores, _writer)?;
            WireFormat::encode(&self.other, _writer)?;
            Ok(())
        };

        let generated = encode_wire_format(&item.data, &trait_name, &receiver);
        assert_eq!(generated, expected);
    }

    /// `decode_wire_format` yields one `let` per field, then the struct literal.
    #[test]
    fn decode() {
        let item = item_struct();
        let type_name = quote!(Item);
        let trait_name = quote!(WireFormat);

        let expected = quote! {
            let ident = WireFormat::decode(_reader)?;
            let with_underscores = WireFormat::decode(_reader)?;
            let other = WireFormat::decode(_reader)?;
            Ok(Item {
                ident: ident,
                with_underscores: with_underscores,
                other: other,
            })
        };

        let generated = decode_wire_format(&item.data, &trait_name, &type_name);
        assert_eq!(generated, expected);
    }

    /// The whole expansion, including the lowercased module name, for a
    /// struct whose name mixes case and non-ASCII characters.
    #[test]
    fn end_to_end() {
        let item: DeriveInput = parse_quote! {
            struct Niijima_先輩 {
                a: u8,
                b: u16,
                c: u32,
                d: u64,
                e: String,
                f: Vec<String>,
                g: Nested,
            }
        };

        let expected = quote! {
            mod wire_format_niijima_先輩 {
                extern crate std;
                use self::std::io;
                use self::std::result::Result::Ok;

                use super::Niijima_先輩;

                use protocol::WireFormat;

                impl WireFormat for Niijima_先輩 {
                    fn byte_size(&self) -> u32 {
                        0
                            + WireFormat::byte_size(&self.a)
                            + WireFormat::byte_size(&self.b)
                            + WireFormat::byte_size(&self.c)
                            + WireFormat::byte_size(&self.d)
                            + WireFormat::byte_size(&self.e)
                            + WireFormat::byte_size(&self.f)
                            + WireFormat::byte_size(&self.g)
                    }

                    fn encode<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> {
                        WireFormat::encode(&self.a, _writer)?;
                        WireFormat::encode(&self.b, _writer)?;
                        WireFormat::encode(&self.c, _writer)?;
                        WireFormat::encode(&self.d, _writer)?;
                        WireFormat::encode(&self.e, _writer)?;
                        WireFormat::encode(&self.f, _writer)?;
                        WireFormat::encode(&self.g, _writer)?;
                        Ok(())
                    }

                    fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> {
                        let a = WireFormat::decode(_reader)?;
                        let b = WireFormat::decode(_reader)?;
                        let c = WireFormat::decode(_reader)?;
                        let d = WireFormat::decode(_reader)?;
                        let e = WireFormat::decode(_reader)?;
                        let f = WireFormat::decode(_reader)?;
                        let g = WireFormat::decode(_reader)?;
                        Ok(Niijima_先輩 {
                            a: a,
                            b: b,
                            c: c,
                            d: d,
                            e: e,
                            f: f,
                            g: g,
                        })
                    }
                }
            }
        };

        let generated = p9_wire_format_inner(item);
        assert_eq!(generated, expected);
    }
}