Burn Book: modern-lstm Example
The Burn Book modern-lstm example is a clean, modular implementation of a recurrent neural network using an LSTM (Long Short-Term Memory) architecture. It is built using the Burn deep learning framework in Rust. This example helps demonstrate how to build production-quality, well-structured deep learning models using Rust's powerful type system and modern syntax.
Purpose and Context
The goal of this example is to provide a foundation for training LSTM-based models on sequential data. This could include tasks like:
- Time-series forecasting
- Character-level language modeling
- Sequence classification
Rather than focusing on a specific application, the example emphasizes architectural clarity, separation of concerns, and idiomatic use of the Burn framework.
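For instance, for a time-series task the raw series is usually sliced into fixed-length windows before being batched into the [batch_size, sequence_length, input_size] tensor the model expects. The helper below is a purely illustrative sketch and is not part of the example:

// Hypothetical helper: turn a univariate series into overlapping windows of
// length `seq_len`, ready to be stacked into a [batch, seq_len, 1] input tensor.
fn make_windows(series: &[f32], seq_len: usize) -> Vec<Vec<f32>> {
    series
        .windows(seq_len)
        .map(|window| window.to_vec())
        .collect()
}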
Model Architecture
The Model struct defined in model.rs contains two key components:
- LSTM Layer: This handles the sequence modeling, learning temporal dependencies across time steps. Input is expected as a 3D tensor with shape [batch_size, sequence_length, input_size].
- Linear Layer: After processing the sequence, the LSTM output is passed through a dense layer to produce the final output—usually logits or predicted values.
The model is defined generically over the backend (B: Backend), allowing it to be used with both CPU and GPU environments by simply changing the backend configuration.
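As a hedged illustration of that flexibility, the following sketch defines type aliases for a CPU and a GPU backend. The backend names and the Cargo features they require (ndarray, wgpu) are assumptions about a recent Burn version, not part of the example.

use burn::backend::{NdArray, Wgpu};

// CPU backend (requires the "ndarray" feature in Cargo.toml).
type Cpu = NdArray;

// GPU backend via WebGPU (requires the "wgpu" feature).
type Gpu = Wgpu;

// The same generic Model<B> can then be instantiated as Model<Cpu> or Model<Gpu>.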
Code Structure
The project is organized cleanly:
- model.rs: Contains the model definition and configuration logic.
- main.rs: Handles training setup, inference logic, and CLI argument parsing.
The ModelConfig struct uses the #[derive(Config)] macro, enabling users to define model parameters like input_size, hidden_size, and output_size in code and to save or load them as config files. This makes the model flexible and easy to experiment with.
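For example, the Config derive also provides save and load helpers, so a configuration can be round-tripped through a JSON file. The sketch below is illustrative only; it assumes the ModelConfig shown later in model.rs is in scope and that the process can write to the working directory.

use burn::config::Config;

fn config_roundtrip() -> std::io::Result<()> {
    // Build a config in code (sizes are hypothetical).
    let config = ModelConfig::new(8, 32, 1);

    // Persist it as JSON, then load it back.
    config.save("model-config.json")?;
    let loaded = ModelConfig::load("model-config.json").expect("config file should parse");
    assert_eq!(loaded.hidden_size, 32);
    Ok(())
}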
Forward Pass
The forward method of the model performs the computation pipeline:
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 2> {
    // Recent Burn versions take an optional initial state; `None` starts from zeros.
    let (output, _state) = self.lstm.forward(input, None);
    let [batch_size, seq_len, hidden_size] = output.dims();
    // Flatten batch and time so the linear layer maps each timestep's hidden state.
    let output = output.reshape([batch_size * seq_len, hidden_size]);
    self.linear.forward(output)
}
Here’s what happens step by step:
- The input sequence is passed through the LSTM, producing a hidden state for every timestep.
- The hidden states are flattened across the batch and time dimensions and passed through the linear layer.
- The result is a 2D tensor with shape [batch_size * sequence_length, output_size].
Training Integration
Although the model logic is self-contained in model.rs, training is managed in main.rs. This includes:
- Data loading and preprocessing
- Loss function definition
- Optimizer setup
- Training loop using Burn's LearnerBuilder
Burn provides a clean API to abstract repetitive deep learning routines, while still allowing for customization and performance tuning.
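The example's real training pipeline in main.rs is built around LearnerBuilder; as a rough, self-contained sketch of the individual pieces listed above, the snippet below instead runs a single manual optimization step on random data. The backend, sizes, learning rate, and loss choice are assumptions for illustration, and it presumes the Model and ModelConfig from model.rs (with the device-aware init shown later) are in scope.

use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
use burn::nn::loss::{MseLoss, Reduction};
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
use burn::tensor::{Distribution, Tensor};

// The Autodiff wrapper enables backpropagation on top of the CPU backend.
type TrainBackend = Autodiff<NdArray>;

fn single_training_step() {
    let device = NdArrayDevice::default();
    let config = ModelConfig::new(8, 32, 1);
    let model = config.init::<TrainBackend>(&device);
    let mut optim = AdamConfig::new().init();

    // Random stand-ins for one batch: 4 sequences, 10 timesteps, 8 features each.
    let input = Tensor::<TrainBackend, 3>::random([4, 10, 8], Distribution::Default, &device);
    let targets = Tensor::<TrainBackend, 2>::random([4 * 10, 1], Distribution::Default, &device);

    // Forward pass, mean squared error loss, backward pass.
    let predictions = model.forward(input);
    let loss = MseLoss::new().forward(predictions, targets, Reduction::Mean);
    let grads = GradientsParams::from_grads(loss.backward(), &model);

    // One Adam update; `step` returns the updated model.
    let _updated_model = optim.step(1e-3, model, grads);
}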
Summary
The modern-lstm project in Burn Book is a thoughtfully designed example for building LSTM-based models in Rust. It demonstrates:
- Clear separation between model logic, configuration, and training
- Idiomatic use of Rust’s type system for safe and efficient code
- Backend-agnostic architecture using Burn’s trait abstractions
More Details of the Example
The following sections walk through the Burn Book modern-lstm example in more detail. The example demonstrates a clean, idiomatic LSTM model implementation using the Burn framework in Rust. The code is modular, easy to follow, and designed for sequence prediction tasks like language modeling or time-series forecasting.
Project Location
burn-book/examples/modern-lstm/
Model Code: model.rs
Imports
use burn::config::Config;
use burn::module::Module;
use burn::nn::{Linear, LinearConfig, Lstm, LstmConfig};
use burn::tensor::{backend::Backend, Tensor};
- Linear / LinearConfig – Fully connected layer for output projection, and its configuration
- Lstm / LstmConfig – Long Short-Term Memory recurrent layer, and its configuration
- Backend – Trait for hardware abstraction (e.g., CPU, GPU)
- Module – Core trait for neural network modules in Burn
- Config – Used to define hyperparameters in a clean, reusable way
Model Structure
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    lstm: Lstm<B>,
    linear: Linear<B>,
}
- #[derive(Module)] – Auto-implements Burn’s neural network interfaces.
- lstm – An LSTM layer that handles sequence processing.
- linear – Projects the LSTM output to the desired output size.
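Beyond wiring the layers together, deriving Module gives the struct Burn's standard module utilities (parameter iteration, record types for saving and loading, device mapping, and so on). The sketch below is a hedged illustration assuming a recent Burn version; CompactRecorder and the file name are not part of the example.

use burn::module::Module;
use burn::record::CompactRecorder;
use burn::tensor::backend::Backend;

fn inspect_and_save<B: Backend>(model: Model<B>) {
    // Count every trainable parameter across the LSTM and linear layers.
    println!("parameters: {}", model.num_params());

    // Persist the weights; the recorder determines the on-disk format.
    model
        .save_file("modern-lstm-weights", &CompactRecorder::new())
        .expect("weights should be saved");
}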
Model Configuration
#[derive(Config)]
pub struct ModelConfig {
    pub input_size: usize,
    pub hidden_size: usize,
    pub output_size: usize,
}
- input_size – Number of features per timestep.
- hidden_size – Size of the LSTM hidden state.
- output_size – Final output dimension (e.g., class count or regression target size).
Model Initialization
impl ModelConfig {
    pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
        // Recent Burn versions initialize layers on an explicit device and take
        // the LSTM bias flag as a config argument.
        let lstm = LstmConfig::new(self.input_size, self.hidden_size, true).init(device);
        let linear = LinearConfig::new(self.hidden_size, self.output_size).init(device);
        Model { lstm, linear }
    }
}
This function builds a model on a given device using the config parameters:
- LstmConfig::new(...).init(device) – Initializes the recurrent LSTM layer.
- LinearConfig::new(...).init(device) – Initializes the final layer mapping hidden states to outputs.
Forward Pass
impl<B: Backend> Model<B> {
    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 2> {
        // Recent Burn versions take an optional initial state; `None` starts from zeros.
        let (output, _state) = self.lstm.forward(input, None);
        let [batch_size, seq_len, hidden_size] = output.dims();
        // Flatten batch and time so the linear layer maps each timestep's hidden state.
        let output = output.reshape([batch_size * seq_len, hidden_size]);
        self.linear.forward(output)
    }
}
Expected Tensor Shapes:
- input: [batch_size, sequence_length, input_size]
- output: [batch_size * sequence_length, output_size] (batch and time dimensions flattened together)
Steps:
- self.lstm.forward processes the sequence input and returns the per-timestep hidden states along with the final LSTM state.
- The hidden states are reshaped so that the batch and time dimensions are flattened into one.
- self.linear.forward maps each hidden state to output logits or predictions.
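To make those shapes concrete, here is a small sketch that runs the forward pass on random data; the CPU backend and the sizes are assumptions for illustration only.

use burn::backend::{ndarray::NdArrayDevice, NdArray};
use burn::tensor::{Distribution, Tensor};

fn shape_check() {
    let device = NdArrayDevice::default();
    // 8 input features, 32 hidden units, 1 output per timestep (illustrative sizes).
    let model = ModelConfig::new(8, 32, 1).init::<NdArray>(&device);

    // A batch of 4 sequences, each 10 timesteps long.
    let input = Tensor::<NdArray, 3>::random([4, 10, 8], Distribution::Default, &device);
    let output = model.forward(input);

    // 4 * 10 = 40 flattened timesteps, each mapped to one output value.
    assert_eq!(output.dims(), [40, 1]);
}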
Summary Table
| Component   | Description                                                          |
|-------------|----------------------------------------------------------------------|
| Model       | Combines LSTM and Linear layers for sequence modeling                 |
| ModelConfig | Defines the model’s hyperparameters like input and output sizes      |
| Lstm        | Processes time-series or sequential data across timesteps            |
| Linear      | Maps LSTM outputs to predictions (e.g., class logits or regression)  |