
My first Rust macro.

Background

I am new to Rust and am currently working on my first large Rust project. As part of my research we are building a data streaming framework (it will be released as OSS after the paper is published). In this system, different types of messages can be passed between the different components.

For example:-

Message
pub enum FromWorkerToCoord {
    Done(String),
    ReconfAck((String, Box<ReconfMsg>)),
    PollTPResponse((String, RecvThrougputMsg)),
    Handshake(String),
}

For each such message we had to write a Codec. A codec can encode an object into bytes that can be sent over the wire, as well as decode bytes received over the wire back into an object. We are using bincode for this purpose.

A Codec implementation for the above enum looks something like this:-

Codec
#[derive(Clone)]
pub struct FromWorkerToCoordCodec {
    config: bincode::config::Configuration,
    length_codec: tokio_util::codec::LengthDelimitedCodec,
}

impl FromWorkerToCoordCodec {
    pub fn new() -> Self {
        FromWorkerToCoordCodec {
            config: bincode::config::standard(),
            length_codec: tokio_util::codec::LengthDelimitedCodec::builder()
                .length_field_length(4)
                .max_frame_length(u32::MAX as usize)
                .new_codec(),
        }
    }
}

impl tokio_util::codec::Encoder<FromWorkerToCoord> for FromWorkerToCoordCodec {
    type Error = std::io::Error;

    fn encode(
        &mut self,
        item: FromWorkerToCoord,
        dst: &mut bytes::BytesMut,
    ) -> Result<(), Self::Error> {
        ...
    }
}
impl tokio_util::codec::Decoder for FromWorkerToCoordCodec {
    type Item = FromWorkerToCoord;
    type Error = std::io::Error;

    fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<FromWorkerToCoord>, Self::Error> {
      ...
    }
}
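
To make the framing step concrete, here is a minimal sketch of what a length-delimited codec does, written with only the standard library. The helpers `encode_frame` and `decode_frame` are hypothetical names for illustration and are not part of tokio_util; the real codec delegates this work to `LengthDelimitedCodec`. The sketch assumes a 4-byte big-endian length prefix, which matches `length_field_length(4)` above and, as far as I know, the default byte order of `LengthDelimitedCodec`.

```rust
// Sketch only: hypothetical helpers that mimic length-delimited framing.

/// Appends a 4-byte big-endian length prefix followed by the payload.
fn encode_frame(payload: &[u8], dst: &mut Vec<u8>) {
    let len = payload.len() as u32; // assumption: the payload fits in a u32
    dst.extend_from_slice(&len.to_be_bytes());
    dst.extend_from_slice(payload);
}

/// Tries to remove one complete frame from the front of `src`.
/// Returns `None` when more bytes are needed and leaves `src` untouched.
fn decode_frame(src: &mut Vec<u8>) -> Option<Vec<u8>> {
    if src.len() < 4 {
        return None; // the length prefix has not fully arrived yet
    }
    let len = u32::from_be_bytes([src[0], src[1], src[2], src[3]]) as usize;
    if src.len() < 4 + len {
        return None; // the payload has not fully arrived yet
    }
    let frame = src[4..4 + len].to_vec();
    src.drain(..4 + len); // consume the prefix and the payload
    Some(frame)
}
```

The `None` case plays the same role as `Ok(None)` in a `Decoder`: it tells the caller to wait for more bytes instead of treating a partial frame as an error.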

Motivation

We had a bunch of such messages and needed a similar Codec for each of them. There is no difference in logic between the Codecs: they all use bincode internally, and the only difference is the message type that they encode and decode. Hence there is a lot of code duplication. I had heard that macros in Rust let us avoid this.

Macros primer

A macro is a metaprogramming tool that allows us to write code that generates code. Macro expansion happens at compile time, so the macro itself adds no runtime overhead.

A very good introduction to Rust macros can be found here. The summary of that post is as follows. There are two types of macros:-

  1. Declarative:- Replaces the macro invocation with the code generated by the given macro. It uses pattern matching on the input tokens to generate the code.
  2. Procedural:- A Rust function that takes in a token stream (typically parsed into an AST) and returns a new token stream. Since it is ordinary Rust code, it can use arbitrary logic to generate code.

| Type | Use case | Examples |
| --- | --- | --- |
| declarative | Pattern matching for repetitive code | vec!, println! |
| procedural-attribute | Modify functions or structs using attributes | #[tokio::main] |
| procedural-derive | Automatic implementation of traits | #[derive(Clone)] |
| procedural-function | Generate new functions using some parameters | |
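
As a rough illustration of the declarative style, the sketch below uses `macro_rules!` to generate one codec struct per message. Everything in it is hypothetical: `MessageCodec` is a std-only stand-in for the real `Encoder`/`Decoder` traits, and `Ping`, `Pong` and `define_codecs!` are made-up names. Note that the codec name is passed in explicitly, since on stable Rust `macro_rules!` cannot build a new identifier such as `PingCodec` from `Ping` without a helper crate.

```rust
// Hypothetical stand-in for the real codec traits, to keep the sketch std-only.
trait MessageCodec {
    type Message;
    fn message_name(&self) -> &'static str;
}

// Two made-up message types.
struct Ping;
struct Pong;

// The pattern matches a comma-separated list of `Message => Codec` pairs
// and emits one struct plus one trait impl for every pair.
macro_rules! define_codecs {
    ($($msg:ident => $codec:ident),* $(,)?) => {
        $(
            #[derive(Clone, Default)]
            struct $codec;

            impl MessageCodec for $codec {
                type Message = $msg;

                fn message_name(&self) -> &'static str {
                    // stringify! turns the matched identifier into a string literal.
                    stringify!($msg)
                }
            }
        )*
    };
}

define_codecs! {
    Ping => PingCodec,
    Pong => PongCodec,
}
```

A derive macro avoids listing the pairs by hand, because it can compute the codec name from the message name.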

Implementation

Since I had to implement the Encoder and Decoder traits, I went with a derive macro.

  1. Created a new crate using cargo new --lib codec-derive
  2. Instructed the compiler that this library implements a procedural macro by adding the following lines in Cargo.toml
    [lib]
    proc-macro = true
    
Now let's look at the code. The numbered comments, such as // (1), are explained in the list below each code block.
Entrypoint
use proc_macro::{self, TokenStream};
use syn::parse_macro_input; // (1)
use quote::{format_ident, quote};  // (2)

#[proc_macro_derive(Codec)]
pub fn codec_derive(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as syn::DeriveInput);  // (3)

    let name: &syn::Ident = &input.ident;
    let codec_name: syn::Ident = format_ident!("{}Codec", name);  // (4)

    let generics = &input.generics;
    // (5)
    if generics.params.is_empty() {
        gen_without_generic(name, codec_name)
    } else {
        gen_with_generic(name, codec_name)
    }
}
  1. syn parses a TokenStream into an AST.
  2. quote turns Rust-like template code back into a TokenStream.
  3. Using the syn library, we first convert the input token stream into an AST.
  4. We then generate the name of the output struct as {InputStructName}Codec. So the name of the generated Codec struct for FromWorkerToCoord will be FromWorkerToCoordCodec.
  5. Codegen logic.

Some of our messages were generic, for example:-

Generic Message
pub enum FromPeerToPeer<S> {
    Handshake(String),
    State((StateBatch<S>, ReconfId)),
}

To deal with this I had to create separate logic for messages with and without a generic parameter. There has to be a better way to do this; I will refactor it when I learn that.
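
To show why generic messages need extra care, here is a std-only sketch of the shape that generated code for a generic message has to take. It is not the code our macro emits: `ToBytes` is a hypothetical stand-in for bincode's `Encode`, and `PeerMsg` and `PeerMsgCodec` are made-up names loosely modelled on `FromPeerToPeer<S>`. The point is that the type parameter must appear on the codec struct, after `impl`, and on both type names. For what it is worth, syn offers `Generics::split_for_impl` for producing these pieces, which may be one way to merge the two code paths.

```rust
use std::marker::PhantomData;

// Hypothetical stand-in for bincode's Encode trait.
trait ToBytes {
    fn to_bytes(&self) -> Vec<u8>;
}

impl ToBytes for u8 {
    fn to_bytes(&self) -> Vec<u8> {
        vec![*self]
    }
}

// A made-up generic message.
enum PeerMsg<S> {
    Handshake(String),
    State(S),
}

// The codec never stores an S, so PhantomData records the parameter.
struct PeerMsgCodec<S> {
    _marker: PhantomData<S>,
}

impl<S> PeerMsgCodec<S> {
    fn new() -> Self {
        PeerMsgCodec { _marker: PhantomData }
    }
}

// The parameter shows up after `impl`, on the codec type and on the message type.
impl<S: ToBytes> PeerMsgCodec<S> {
    fn encode(&self, item: &PeerMsg<S>) -> Vec<u8> {
        match item {
            PeerMsg::Handshake(name) => {
                let mut out = vec![0u8]; // tag byte for Handshake
                out.extend_from_slice(name.as_bytes());
                out
            }
            PeerMsg::State(state) => {
                let mut out = vec![1u8]; // tag byte for State
                out.extend(state.to_bytes());
                out
            }
        }
    }
}
```

A non-generic message needs none of the `<S>` annotations or the marker field, which is why the two cases ended up in separate functions here.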

Now let's look at the main code generation logic.

Codegen without generics
fn gen_without_generic(msg_name: &syn::Ident, codec_name: syn::Ident) -> TokenStream {
    // (1)
    let expanded = quote! {
        #[derive(Clone)]
        pub struct #codec_name {   // (2)
            config: bincode::config::Configuration,
            length_codec: tokio_util::codec::LengthDelimitedCodec,
        }

        impl #codec_name {
            pub fn new() -> Self {
                #codec_name {
                    config: bincode::config::standard(),
                    length_codec: tokio_util::codec::LengthDelimitedCodec::builder()
                        .length_field_length(4)
                        .max_frame_length(u32::MAX as usize)
                        .new_codec(),
                }
            }
        }

        impl tokio_util::codec::Encoder<#msg_name> for #codec_name {
            type Error = std::io::Error;

            fn encode(
                &mut self,
                item: #msg_name,    // (3)
                dst: &mut bytes::BytesMut,
            ) -> Result<(), Self::Error> {
                // Serialize the message with bincode.
                let encoded_data = bincode::encode_to_vec(&item, self.config)
                    .map_err(|_| std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "Failed to encode data",
                    ))?;

                // Write the length prefix and the payload. The fully qualified call
                // works even if the Encoder trait is not imported at the use site.
                tokio_util::codec::Encoder::encode(
                    &mut self.length_codec,
                    bytes::Bytes::from(encoded_data),
                    dst,
                )
                .map_err(|_| std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "Couldn't encode length-delimited data",
                ))?;

                Ok(())
            }
        }

        impl tokio_util::codec::Decoder for #codec_name {
            type Item = #msg_name;
            type Error = std::io::Error;

            fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<#msg_name>, Self::Error> {
                // Split off one complete frame, if enough bytes have arrived.
                let frame = match tokio_util::codec::Decoder::decode(&mut self.length_codec, src)
                    .map_err(|_| std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "Couldn't decode length-delimited data",
                    ))?
                {
                    Some(frame) => frame,
                    None => return Ok(None), // Not enough data yet
                };

                // Deserialize the frame with bincode.
                let (message, _) = bincode::decode_from_slice(&frame, self.config)
                    .map_err(|_| std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "Couldn't decode message from bincode",
                    ))?;

                Ok(Some(message))
            }
        }
    };

    // quote! yields a proc_macro2::TokenStream; convert it to proc_macro::TokenStream.
    expanded.into()
}

  1. Remember that we need to return the code as a TokenStream. quote provides a very handy macro that takes a Rust-looking template and produces a valid Rust TokenStream.

  2. quote will replace this with the codec name.

  3. quote will replace this with the message name.

That's it!

Usage

Now that we have implemented our derive macro, we can use it as follows:-

#[derive(Debug, PartialEq, Eq, Clone, Encode, Decode, Codec)]
pub enum FromCoordToWorker {
    Done,
    PollTPRequest(u64),
    Reconf(Box<ReconfMsg>),
    Exit,
}
Now when we create a new message we just need to add #[derive(Codec)], and a codec will be generated for that message. Note that the message must also implement Encode and Decode in order to derive a Codec.
