Using Huggingface with Rust

Hello world! Today, we’re going to talk about Huggingface with Rust. We’re going to cover the following:

  • Downloading a repo from Huggingface
  • Tokenizing input and outputting tokens using Huggingface via Candle
  • Generating text from tokens
  • Using Candle in a web service

By the end of this article, we'll have a fully working web server that runs our model. Interested in the final code? Check it out here.

What is Huggingface?

Huggingface describes itself simply as:

The AI community, building the future.

Huggingface is both a company and a platform that contributes heavily to the AI and NLP fields through open source and open science. A couple of examples of how they do this:

  • Offering a free platform to upload AI models, try other people's AI models and generally gain insight into how LLMs work
  • They have an in-depth NLP course teaching you how to create and use transformers, datasets and tokenizers (although it’s using Python)
  • Giving back to the community by creating frameworks like Candle (Rust)

Huggingface is a huge driving force within the AI community. As you can see, their platform is a great way to start using AI at home without paying for anything (initially).

Using Huggingface

Getting started

To use Huggingface from a Rust application, you'll need to add the hf_hub crate to your Cargo.toml:

cargo add hf_hub

You'll also want the following dependencies which you can grab from this shell snippet:

cargo add anyhow candle-core candle-nn candle-transformers serde serde_json tokenizers -F serde/derive

Before we start initialising the model, let's talk about models. A model's weights determine the strength of the connections within its neural network, and a weight map records which safetensors file each weight tensor lives in. For example, if we go to a Huggingface repo (assuming we're logged in and have been granted access), we might see a bunch of tensor files (see below), a JSON index file describing the weight map, as well as some other files.

Before we continue, let's have a quick look at what tensors actually are, as they're quite important to know about!

From Wikipedia:

a tensor is an algebraic object that describes a multilinear relationship between sets of algebraic objects related to a vector space.

Practically speaking for us, this just means multidimensional arrays of numbers (an array with multiple indexes, if you will). We can manipulate tensors and shape them to get output. A real-world example of using a tensor might be a picture. A picture can have height, width and color - which we can model as a 3D tensor. If we feed this into a model by turning it into a data representation, the model can find similar photos by comparing values from its training data and output a relevant response.
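
As a quick illustration (a minimal sketch using Candle's Tensor type, which we'll be pulling in shortly anyway), a small RGB image could be modelled as a tensor with three dimensions:

use candle_core::{DType, Device, Result, Tensor};

fn image_tensor_demo() -> Result<()> {
    // a 100 (height) x 200 (width) x 3 (RGB channels) "image", filled with zeros
    let image = Tensor::zeros((100, 200, 3), DType::F32, &Device::Cpu)?;
    println!("{:?}", image.dims()); // prints [100, 200, 3]
    Ok(())
}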

To deserialize the weight map file, we'll want to implement a custom deserialization function, as the file can contain an unknown number of keys and values. That means we should deserialize it to a serde_json::Value first. If you try to deserialize it directly into a fixed structure, serde will simply tell you that it got a sequence but expected a map (or something similar).

use std::collections::HashSet;

use serde::{Deserialize, Deserializer};

#[derive(Debug, Deserialize)]
struct Weightmaps {
    #[serde(deserialize_with = "deserialize_weight_map")]
    weight_map: HashSet<String>,
}

// Custom deserializer for the weight_map to directly extract values into a HashSet
fn deserialize_weight_map<'de, D>(deserializer: D) -> Result<HashSet<String>, D::Error>
where
    D: Deserializer<'de>,
{
    let map = serde_json::Value::deserialize(deserializer)?;
    match map {
        serde_json::Value::Object(obj) => Ok(obj
            .values()
            .filter_map(|v| v.as_str().map(ToString::to_string))
            .collect::<HashSet<String>>()),
        _ => Err(serde::de::Error::custom(
            "Expected an object for weight_map",
        )),
    }
}

Next, we’ll write a function to load tensors from a file in the “safetensors” format into our program. Safetensors is a data serialization format for storing tensors safely. Created by Hugging Face, it was made to replace the pickle format. In Python, you may have a PyTorch weight model that gets saved (or “pickled”) into a .bin file with the Python pickle utility. However, this is unsafe and said files may hold malicious code to be executed on un-pickling. Note that repo.get() returns Result<PathBuf>:

pub fn hub_load_safetensors(
    repo: &hf_hub::api::sync::ApiRepo,
    json_file: &str,
) -> Result<Vec<std::path::PathBuf>> {
    let json_file = repo.get(json_file).map_err(candle_core::Error::wrap)?;
    let json_file = std::fs::File::open(json_file)?;
    let json: Weightmaps = serde_json::from_reader(&json_file).map_err(candle_core::Error::wrap)?;

    let pathbufs: Vec<std::path::PathBuf> = json
        .weight_map
        .iter()
        .map(|f| repo.get(f).unwrap())
        .collect();

    Ok(pathbufs)
}

Next, we'll add some code to our program for downloading a mistral-7b model, pinned to the latest revision at the time of writing. Note that the repository download is quite large, clocking in at around 13.4 GB! If you're running this locally, you'll need enough disk space to store the model. You can find the latest revision of a repository by going to the repo, clicking on Files and versions and then going to the commit history. The Mistral-7B-v0.1 repo can be found here.

To make this easier to follow, we'll first write a function that initialises the ApiRepo we'll be downloading files from. We can do this by creating an ApiBuilder that takes our Huggingface token:

fn get_repo(token: String) -> Result<ApiRepo> {
    let api = ApiBuilder::new().with_token(Some(token)).build()?;

    let model_id = "mistralai/Mistral-7B-v0.1".to_string();

    api.repo(Repo::with_revision(
        model_id,
        RepoType::Model,
        "26bca36bde8333b5d7f72e9ed20ccda6a618af24".to_string(),
    ))
}

Next, we’ll need to set up our tokenizer. This function will use repo.get() to download the tokenizer file (returning a path) and then we use Tokenizer::from_file to create it:

fn get_tokenizer(repo: &ApiRepo) -> Result<Tokenizer> {
    let tokenizer_filename = repo.get("tokenizer.json")?;

    Ok(Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?)
}

Finally, we’ll create the model itself:

fn initialise_model(token: String) -> Result<AppState> {
    let repo = get_repo(token)?;
    let tokenizer = get_tokenizer(&repo)?;
    let device = Device::Cpu;
    let filenames = hub_load_safetensors(&repo, "model.safetensors.index.json")?;

    // note that here, we'd set this to true if we were using Flash Attention
    // Flash Attention requires the CUDA feature flag to be enabled, but speeds
    // up inference
    let config = Config::config_7b_v0_1(false);

    let model = {
        let dtype = DType::F32;
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        Mistral::new(&config, vb)?
    };

    Ok((model, device, tokenizer).into())
}

Creating a Token Output Stream

So we’ve downloaded our model, loaded the weight map files, loaded them all in and created our model. But how do we use it?

Internally, when you feed an input to a model, it encodes the input into tokens. Tokens can represent words, sub-words, characters - any unit of data. The main idea is that converting the data to tokens lets the model process it more easily. The more popular pretrained models are often trained on billions, if not trillions, of tokens, which is what allows them to produce sophisticated answers: the training data gives tokens meaning, and the model produces a response by relating your input to the patterns it learned during training.
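
To make this concrete, here's a minimal sketch of an encode/decode round trip using the tokenizers crate directly (this assumes a tokenizer.json has already been downloaded, like the one we fetched in get_tokenizer above; the exact ids depend on the model's vocabulary):

use tokenizers::Tokenizer;

fn tokenizer_demo() -> anyhow::Result<()> {
    let tokenizer = Tokenizer::from_file("tokenizer.json").map_err(anyhow::Error::msg)?;

    // text -> token ids
    let encoding = tokenizer.encode("Hello world!", true).map_err(anyhow::Error::msg)?;
    let ids = encoding.get_ids();
    println!("token ids: {ids:?}");

    // token ids -> text
    let text = tokenizer.decode(ids, true).map_err(anyhow::Error::msg)?;
    println!("decoded: {text}");

    Ok(())
}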

To get started, we will create a stream for encoding our tokens and outputting a stream of tokens:

pub struct TokenOutputStream {
    tokenizer: tokenizers::Tokenizer,
    tokens: Vec<u32>,
    prev_index: usize,
    current_index: usize,
}

impl TokenOutputStream {
    pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
        Self {
            tokenizer,
            tokens: Vec::new(),
            prev_index: 0,
            current_index: 0,
        }
    }
}

Next, we’ll implement some methods to do the following:

  • Decode tokens into UTF-8 strings
  • Advance the internal index and return a string if there is any text

impl TokenOutputStream {
    // decode tokens into a UTF-8 string; this returns candle_core's Result type
    // so that we can use `?` when calling it below
    fn decode(&self, tokens: &[u32]) -> Result<String> {
        match self.tokenizer.decode(tokens, true) {
            Ok(str) => Ok(str),
            Err(err) => candle_core::bail!("cannot decode: {err}"),
        }
    }

    pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
        // if there's nothing there, return an empty string
        let prev_text = if self.tokens.is_empty() {
            String::new()
        } else {
        // otherwise, use the previous decode method to decode tokens to a String
            let tokens = &self.tokens[self.prev_index..self.current_index];
            self.decode(tokens)?
        };

        // add this token to the current list of tokens already processed
        self.tokens.push(token);
        let text = self.decode(&self.tokens[self.prev_index..])?;
        if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
            let text = text.split_at(prev_text.len());
            self.prev_index = self.current_index;
            self.current_index = self.tokens.len();
            Ok(Some(text.1.to_string()))
        } else {
            Ok(None)
        }
    }
}

We’ll also need some auxiliary methods on this struct which we’ll call later on when generating our response text. Some comments have been added below to show what these are used for:

impl TokenOutputStream {
    // resets internal state so tokens from a previous prompt don't leak into the next one
    pub fn clear(&mut self) {
        self.tokens.clear();
        self.prev_index = 0;
        self.current_index = 0;
    }

    // get access to inner tokenizer as a reference
    pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
        &self.tokenizer
    }

    // uses get_vocab() to get a hashmap of tokens to indexes
    // then grabs the index value with .get()
    pub fn get_token(&self, token_s: &str) -> Option<u32> {
        self.tokenizer.get_vocab(true).get(token_s).copied()
    }
}

Generating text from tokens

As the final piece of the puzzle, we'll create a struct for generating text from our tokens. This is arguably the most important part. There is some terminology you may want to be acquainted with before we continue, as otherwise you may get confused:

  • temp (or temperature) - controls how random the sampling is. The lower the temperature, the more focused and deterministic the output; higher values produce more varied (and more hallucination-prone) text.
  • top_k - This restricts the model to sampling from only the K most likely tokens, then samples based on probability. A lower K value makes the model more predictable and consistent.
  • top_p - This allows the model to choose from the smallest subset of tokens whose combined probability reaches or exceeds a threshold p (see the short sketch after this list). This allows you to create diverse responses while still keeping them relevant.
  • repeat_penalty - We can decide how much we want to punish repetitive or redundant output here.
  • repeat_last_n - This lets us choose the size of the context window (in tokens) that the repeat penalty looks back over. Note that a token is usually only a few characters - roughly three-quarters of a word on average - so bear that in mind when choosing the window size!
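
To make top_p a bit more concrete, here's a small illustrative sketch of the idea behind nucleus sampling (this is not Candle's implementation - the LogitsProcessor we use below handles sampling internally):

// Illustrative only: returns the smallest set of token indices (in descending
// probability order) whose probabilities sum to at least `p`.
fn nucleus(probs: &[f32], p: f32) -> Vec<usize> {
    let mut indexed: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
    indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

    let mut kept = Vec::new();
    let mut cumulative = 0.0;
    for (idx, prob) in indexed {
        kept.push(idx);
        cumulative += prob;
        if cumulative >= p {
            break;
        }
    }
    kept
}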

To start generating text from tokens, we can declare our text generator like so:

struct TextGeneration {
    model: Mistral,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Mistral,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        _top_k: Option<usize>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);

        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }
}

Next, we need to write a run function to actually run our simple pipeline. This will do the following:

  • Turn the prompt into encoded tokens
  • Ensure that the </s> token exists (a sort of “end of sequence” marker)
  • Loop over the sample length, create a tensor, apply a repeat penalty and attempt to get the next token

To start with, let’s set up our function - we’ll start by clearing the tokenizer of any leftover tokens from previous prompts, then take the prompt and encode it. We then turn the tokens into a vector so that we can process it later on:

impl TextGeneration {
    fn run(mut self, prompt: String, sample_len: usize) -> Result<String, Box<dyn Error>> {
        // clear the tokenizer of any previous input here
        self.tokenizer.clear();
        let mut tokens = self.tokenizer
            .tokenizer()
            .encode(prompt, true)
            .unwrap()
            .get_ids()
            .to_vec();

        println!("Got tokens!");
        // .. more code here!
    }
}

Next, we need to check whether the tokenizer has a </s> token - this signals the end of the output. If it isn't present, the model won't know where to stop generating. Definitely not good. Let's look it up (and bail out if it's missing):

let eos_token = match self.tokenizer.get_token("</s>") {
    Some(token) => token,
    None => panic!("cannot find the </s> token"),
};

The next step is to iterate over the range 0 to sample_len and, on each iteration, take the tokens from the start position we want to process from. We then create a new Tensor from them and unsqueeze it, which adds an extra dimension of size 1 at the front - the batch dimension the model's forward pass expects.

let mut string = String::new();

for index in 0..sample_len {
    let context_size = if index > 0 { 1 } else { tokens.len() };
    let start_pos = tokens.len().saturating_sub(context_size);
    let ctxt = &tokens[start_pos..];
    let input = Tensor::new(ctxt, &self.device).unwrap().unsqueeze(0).unwrap();

    // more code to come here
}
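
As a quick aside, here's a tiny sketch of what unsqueeze(0) does to the shape - it takes our 1D list of token ids and gives it a leading batch dimension of size 1:

use candle_core::{Device, Result, Tensor};

fn unsqueeze_demo() -> Result<()> {
    let ids: &[u32] = &[1, 2, 3];
    let t = Tensor::new(ids, &Device::Cpu)?; // shape: [3]
    let batched = t.unsqueeze(0)?;           // shape: [1, 3]
    println!("{:?} -> {:?}", t.dims(), batched.dims());
    Ok(())
}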

We then perform a forward pass to get the value of the output layer from the input data. In a neural network, this means passing the input through each layer in turn, with every layer computing its output from the previous layer's output. This gives us the logits (the raw outputs of the network before any activation/softmax is applied).

let logits = self.model.forward(&input, start_pos).unwrap();

Next, we squeeze the logits. This drops the size-1 dimensions from the tensor, leaving us with a plain vector of logits we can sample from. Then, if the repeat penalty is anything other than 1.0, we need to apply it (note that values under 1.0 would actually encourage repetition rather than penalise it).

let logits = logits
    .squeeze(0).unwrap()
    .squeeze(0).unwrap()
    .to_dtype(DType::F32).unwrap();

let logits = if self.repeat_penalty == 1.0 {
    logits
} else {
    let start_at = tokens.len().saturating_sub(self.repeat_last_n);
    candle_transformers::utils::apply_repeat_penalty(
        &logits,
        self.repeat_penalty,
        &tokens[start_at..],
    )
    .unwrap()
};

We then get our next token by sampling the logits, push it to our tokens list and check whether it's the end-of-sequence token: if it is, we break the loop immediately. If not, we decode it and add the resulting text to the final output. Once the loop has finished, we return the output.

fn run(mut self, prompt: String, sample_len: usize) -> Result<String, Box<dyn Error>> {
    // .. previous code here
    for index in 0..sample_len {
        // .. tensor, forward pass and repeat penalty code from above here
        let next_token = self.logits_processor.sample(&logits).unwrap();
        tokens.push(next_token);

        if next_token == eos_token {
            break;
        }

        if let Some(t) = self.tokenizer.next_token(next_token).unwrap() {
            println!("Found a token!");
            string.push_str(&t);
        }
    }

    Ok(string)
}

This is quite a long function - if you get stuck, you can find the final code here:

impl TextGeneration {
    fn run(mut self, prompt: String, sample_len: usize) -> Result<String, Box<dyn Error>> {
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .unwrap()
            .get_ids()
            .to_vec();

        println!("Got tokens!");

        let eos_token = match self.tokenizer.get_token("</s>") {
            Some(token) => token,
            None => panic!("cannot find the </s> token"),
        };

        let mut string = String::new();

        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device).unwrap().unsqueeze(0).unwrap();
            let logits = self.model.forward(&input, start_pos).unwrap();
            let logits = logits.squeeze(0).unwrap().squeeze(0).unwrap().to_dtype(DType::F32).unwrap();
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                ).unwrap()
            };

            let next_token = self.logits_processor.sample(&logits).unwrap();
            tokens.push(next_token);

            if next_token == eos_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token).unwrap() {
                println!("Found a token!");
                string.push_str(&t);
            }
        }

        Ok(string)
    }
}

Using Candle with a web service

Using the previous functions we’ve made, we can create a web service using axum and tokio with an endpoint to run our prompt! No local work required.

To install Axum and Tokio you can use this shell snippet:

cargo add axum tokio -F tokio/macros,tokio/rt-multi-thread

To get started, we'll create an AppState that implements Clone, plus some helper From impls to cut down on boilerplate in our application-related functions:

#[derive(Clone)]
pub struct AppState {
    model: Mistral,
    device: Device,
    tokenizer: Tokenizer
}

impl From<(Mistral, Device, Tokenizer)> for AppState {
    fn from(e: (Mistral, Device, Tokenizer)) -> Self {
        Self { model: e.0, device: e.1, tokenizer: e.2 }
    }
}

impl From<AppState> for TextGeneration {
    fn from(e: AppState) -> Self {
        Self::new(
            e.model,
            e.tokenizer,
            299792458, // seed RNG
            Some(0.),  // temperature
            None,      // top_p - Nucleus sampling probability stuff
            None,      // only sample along the top K samples
            1.1,       // repeat penalty
            64,        // context size to consider for the repeat penalty
            &e.device,
        )
    }
}

Next, we’ll create a Prompt struct that can be deserialized from JSON, then an Axum function endpoint:

#[derive(Deserialize)]
pub struct Prompt {
    prompt: String
}

async fn run_pipeline(
    State(state): State<AppState>,
    Json(Prompt{prompt}): Json<Prompt>
) -> impl IntoResponse {
    let textgen = TextGeneration::from(state);
    textgen.run(prompt, 5).unwrap()
}

Then we add the endpoint to our main function and it’s done:

#[tokio::main]
async fn main() -> Result<()> {
    let Ok(api_token) = std::env::var("HF_TOKEN") else {
        return Err(anyhow::anyhow!("Error getting HF_TOKEN env var"))
    };
    let state = initialise_model(api_token)?;

    let router = Router::new()
        .route("/", get(hello_world))
        .route("/prompt", post(run_pipeline))
        .with_state(state);

    let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:8000").await.unwrap();

    axum::serve(tcp_listener, router).await.unwrap();

    Ok(())
}
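
The router above references a hello_world handler that isn't shown here - if you don't already have one (for example from the Shuttle axum template), a minimal version might look like this:

async fn hello_world() -> &'static str {
    "Hello, world!"
}

Once the server is running, you can try out the prompt endpoint with something like the following (the prompt text is just an example):

curl -X POST http://127.0.0.1:8000/prompt \
    -H 'Content-Type: application/json' \
    -d '{"prompt": "The capital of France is"}'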

Performance considerations

Of course, when you run this locally you may notice that CPU performance is somewhat sub-optimal compared to a high-performance GPU. Generally speaking, while the best performance is attained on a GPU, there are times where that simply isn't feasible.

GPU usage is gated behind the CUDA (Compute Unified Device Architecture) toolkit, exposed as a feature flag on candle-core. However, this would be difficult to use in a deployed web service unless you can have CUDA preinstalled. Nevertheless, it unlocks further performance-enhancing features beyond raw GPU compute. For example, Flash Attention is a technique that improves inference speed with better parallelism and work partitioning. You can find the arXiv paper here.
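
If you do have a CUDA-capable machine, enabling GPU support is roughly a matter of turning on the relevant feature flags. Note that the exact feature names below are an assumption based on current candle releases, so double-check them against the candle documentation for your version:

# assumes the `cuda` feature on candle-core and `flash-attn` on candle-transformers
cargo add candle-core -F cuda
cargo add candle-transformers -F flash-attn

You'd then pass true to Config::config_7b_v0_1 (as the comment in initialise_model notes) and swap Device::Cpu for a CUDA device (e.g. Device::new_cuda(0)).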

This is even more noticeable when you’re running in debug mode. To alleviate this, you can run in release mode using cargo shuttle run --release (or without Shuttle: cargo run --release). This will run the release profile of your application, which is much more optimised.
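
If you want to squeeze a little more out of release builds, you can also tweak the release profile in Cargo.toml (these are standard Cargo settings; the gains are modest compared to simply not running in debug mode):

[profile.release]
lto = true
codegen-units = 1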

Finishing up

Thanks for reading! Huggingface makes it much easier to run your own and other people's models absolutely free (disregarding electricity usage). Taking advantage of AI is becoming easier than ever!

This blog post is powered by shuttle - The Rust-native, open source, cloud development platform. If you have any questions, or want to provide feedback, join our Discord server!