
Burn Book: modern-lstm Example

The Burn Book modern-lstm example is a clean, modular implementation of a recurrent neural network using an LSTM (Long Short-Term Memory) architecture. It is built using the Burn deep learning framework in Rust. This example helps demonstrate how to build production-quality, well-structured deep learning models using Rust's powerful type system and modern syntax.

Purpose and Context

The goal of this example is to provide a foundation for training LSTM-based models on sequential data. This could include tasks such as time series forecasting, language modeling, or other sequence prediction problems.

Rather than focusing on a specific application, the example emphasizes architectural clarity, separation of concerns, and idiomatic use of the Burn framework.

Model Architecture

The Model struct defined in model.rs contains two key components:

  1. LSTM Layer: This handles the sequence modeling, learning temporal dependencies across time steps. Input is expected as a 3D tensor with shape [batch_size, sequence_length, input_size].
  2. Linear Layer: After processing the sequence, the LSTM output is passed through a dense layer to produce the final output—usually logits or predicted values.

The model is defined generically over the backend (B: Backend), allowing it to be used with both CPU and GPU environments by simply changing the backend configuration.
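For instance, a minimal sketch of backend selection might look like the following (the aliases are illustrative and assume the corresponding Cargo features of the burn crate are enabled):

use burn::backend::{Autodiff, NdArray, Wgpu};

// Swap the alias to move the same model between CPU and GPU.
type CpuBackend = NdArray;
type GpuBackend = Wgpu;

// Training additionally requires an autodiff wrapper around the chosen backend.
type TrainingBackend = Autodiff<CpuBackend>;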

Code Structure

The project is organized cleanly: the model definition lives in model.rs, while training and data handling are kept in main.rs.

The ModelConfig struct uses the #[derive(Config)] macro, enabling users to configure model parameters like input_size, hidden_size, and output_size in code or through serialized config files. This makes the model flexible and easy to experiment with.
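For example, a config could be built in code and persisted for later runs. This is a small sketch using the save and load methods of Burn's Config trait, assuming ModelConfig from model.rs is in scope and using illustrative sizes:

use burn::config::Config;

fn main() {
    // input_size = 16, hidden_size = 32, output_size = 4 (illustrative values).
    let config = ModelConfig::new(16, 32, 4);

    // Persist the hyperparameters, then restore them in a later run.
    config.save("model_config.json").expect("failed to save config");
    let restored = ModelConfig::load("model_config.json").expect("failed to load config");
    assert_eq!(config.hidden_size, restored.hidden_size);
}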

Forward Pass

The forward method of the model performs the computation pipeline. The version below is written against a recent Burn API, where Lstm::forward takes an optional initial state and returns the output together with the final state; older releases differ slightly:

pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 2> {
    // The LSTM returns per-timestep hidden states plus the final state.
    let (output, _state) = self.lstm.forward(input, None);
    // Flatten to [batch_size * seq_len, hidden_size] for the linear layer.
    let [batch_size, seq_len, hidden_size] = output.dims();
    self.linear.forward(output.reshape([batch_size * seq_len, hidden_size]))
}

Here’s what happens step by step:

  1. The input sequence is passed through the LSTM, which produces a hidden state for every timestep.
  2. The hidden states are flattened and passed through the linear layer.
  3. The result is a 2D tensor with shape [batch_size * sequence_length, output_size].
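As a quick sanity check, a hypothetical snippet (assuming the NdArray backend, that ModelConfig from model.rs is in scope, and illustrative sizes input_size = 16, hidden_size = 32, output_size = 4) could verify those shapes:

use burn::backend::NdArray;
use burn::tensor::{Distribution, Tensor};

fn main() {
    let device = Default::default();
    let model = ModelConfig::new(16, 32, 4).init::<NdArray>(&device);

    // A batch of 8 sequences, 10 timesteps each, 16 features per step.
    let input = Tensor::<NdArray, 3>::random([8, 10, 16], Distribution::Default, &device);
    let output = model.forward(input);
    assert_eq!(output.dims(), [80, 4]); // [batch_size * sequence_length, output_size]
}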

Training Integration

Although the model logic is self-contained in model.rs, training is managed in main.rs. This typically includes:

  1. Loading and batching the sequence data.
  2. Configuring the optimizer and learning rate.
  3. Running the training loop and tracking the loss.

Burn provides a clean API to abstract repetitive deep learning routines, while still allowing for customization and performance tuning.
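As a rough sketch of what one training step could look like (the function below is illustrative, not the example's actual main.rs; it assumes an autodiff backend and a mean-squared-error objective):

use burn::nn::loss::{MseLoss, Reduction};
use burn::optim::{GradientsParams, Optimizer};
use burn::tensor::{backend::AutodiffBackend, Tensor};

// Hypothetical single training step: forward, loss, backward, optimizer update.
fn train_step<B: AutodiffBackend>(
    model: Model<B>,
    optimizer: &mut impl Optimizer<Model<B>, B>,
    input: Tensor<B, 3>,
    target: Tensor<B, 2>,
) -> Model<B> {
    let prediction = model.forward(input);
    let loss = MseLoss::new().forward(prediction, target, Reduction::Mean);
    // Collect parameter gradients and apply one optimizer step.
    let grads = GradientsParams::from_grads(loss.backward(), &model);
    optimizer.step(1e-3, model, grads)
}

In main.rs a function like this would sit inside a loop over batches, with the optimizer built once beforehand (for example, from AdamConfig::new().init()).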

Summary

The modern-lstm project in the Burn Book is a thoughtfully designed example for building LSTM-based models in Rust. It demonstrates:

  1. A clean separation between model definition (model.rs) and training logic (main.rs).
  2. Backend-generic code that runs on CPU or GPU without modification.
  3. Config-driven hyperparameters via the #[derive(Config)] macro.

More Details of the Example

The following sections walk through the Burn Book modern-lstm example in more detail.

This example demonstrates a clean, idiomatic LSTM model implementation using the Burn framework in Rust. The code is modular, easy to follow, and designed for sequence prediction tasks like language modeling or time series forecasting.

Project Location

burn-book/examples/modern-lstm/

Model Code: model.rs

Imports

use burn::config::Config;
use burn::module::Module;
use burn::nn::{Linear, LinearConfig, Lstm, LstmConfig};
use burn::tensor::{backend::Backend, Tensor};

Model Structure

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    lstm: Lstm<B>,
    linear: Linear<B>,
}

Model Configuration

#[derive(Config)]
pub struct ModelConfig {
    pub input_size: usize,
    pub hidden_size: usize,
    pub output_size: usize,
}

Model Initialization

impl ModelConfig {
    pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
        // Recent Burn releases take a device here and a bias flag on
        // LstmConfig; adjust these calls to the version in use.
        let lstm = LstmConfig::new(self.input_size, self.hidden_size, true).init(device);
        let linear = LinearConfig::new(self.hidden_size, self.output_size).init(device);
        Model { lstm, linear }
    }
}

This function builds a model from the config parameters, initializing both layers on the provided device.

Forward Pass

impl<B: Backend> Model<B> {
    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 2> {
        // The LSTM returns per-timestep hidden states plus the final state.
        let (output, _state) = self.lstm.forward(input, None);
        // Flatten to [batch_size * seq_len, hidden_size] for the linear layer.
        let [batch_size, seq_len, hidden_size] = output.dims();
        self.linear.forward(output.reshape([batch_size * seq_len, hidden_size]))
    }
}

Expected Tensor Shapes:

  Input: [batch_size, sequence_length, input_size]
  LSTM output: [batch_size, sequence_length, hidden_size]
  Final output: [batch_size * sequence_length, output_size]

Steps:

  1. self.lstm.forward processes the sequence input and returns the per-timestep hidden states together with the final state.
  2. The hidden states are flattened so that each timestep becomes one row.
  3. self.linear.forward maps each hidden state to output logits or predictions.
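A common variant, shown as a sketch below (the forward_last helper is hypothetical, not part of the example), keeps only the final timestep's hidden state, which suits sequence classification where one prediction per sequence is wanted:

impl<B: Backend> Model<B> {
    // Hypothetical helper: one prediction per sequence instead of per timestep.
    pub fn forward_last(&self, input: Tensor<B, 3>) -> Tensor<B, 2> {
        let (output, _state) = self.lstm.forward(input, None);
        let [batch_size, seq_len, hidden_size] = output.dims();
        // Keep only the last timestep's hidden state: [batch_size, hidden_size].
        let last = output
            .slice([0..batch_size, seq_len - 1..seq_len, 0..hidden_size])
            .reshape([batch_size, hidden_size]);
        self.linear.forward(last) // [batch_size, output_size]
    }
}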

Summary Table

Component   | Description
Model       | Combines the LSTM and Linear layers for sequence modeling
ModelConfig | Defines the model's hyperparameters, such as input and output sizes
Lstm        | Processes time-series or sequential data across timesteps
Linear      | Maps LSTM outputs to predictions (e.g., class logits or regression values)