Finetuning a Large Language Model to Write Emails

May 8, 2024

As members of UBC Biztech’s partnerships team, we take pride in ensuring our partners are treated with the utmost care. An essential part of maintaining the quality of our outreach is making sure that the briefing and confirmation emails sent to every partner involved in an event are as detailed and personalized as possible.

This is usually done in two phases:

  • In phase 1, one would usually gather relevant information for the email, create a rough outline of the email’s structure, and then fill in the gaps with details from the information gathered.
  • In phase 2, the focus shifts to personalizing the message for each partner, making modifications wherever necessary (for example, if a partner is a UBC alumnus, we’d usually add something about “returning to your old stomping grounds” or “coming back to home territory”).

The first phase usually takes the longest. I’ve found this to be the case because you’re going from zero to one, trying to figure out a message format that works and is concise enough for partners to skim easily.

Thus, I wanted to make creating this first-phase draft easier and faster. Sounds like a job for LLMs if you ask me.

Why can’t you use something like ChatGPT instead?

You definitely can. With very simple prompting, a few example emails, and clear instructions, you could absolutely do this in ChatGPT.

However, I wanted a solution that:

  • Avoids putting sensitive information like partners’ names or emails in remote locations (like OpenAI servers)
  • Doesn’t require few-shot learning or pre-prompting
  • Can receive inputs in a short-form format (instead of giving instructions as sentences, I want to be lazy and just write short phrases of the email’s content)

For example, a prompt-completion pair would look something like this:

prompt: "[ORIGINAL] None [CONTEXT] to: john@company.com, about: attend HelloHacks as mentor, details: application successful, would love to invite, date is May 1-2 2025, hellohacks is 2 day hackathon for beginners, appreciate your expertise"
completion: "Hi John,  
  
Thanks for reaching out! We are excited to have you join us as a mentor at  
HelloHacks. Your application was successful and we would love for you to   
attend on May 1-2 2025.  
  
HelloHacks is UBC's premiere hackathon, aimed at creating an inclusive   
and supportive environment for students with little to no coding   
experience. We hope that your expertise will help guide our teams to   
success.  
  
Please let us know if you have any questions or would like further   
information about the event. We look forward to hearing back from you.  
  
Best regards,  
Garreth Lee   
UBC Biztech Partnerships Lead"  

From this completion, we can quickly start adding personalized touches and other relevant information, allowing for greater outreach productivity.

Data Collection & Preprocessing

What data should we use for finetuning?

The dataset I decided to use is all my emails sent out to partners from May 2023 to April 2024. Exporting email data was relatively simple using Google Takeout, ultimately resulting in .mbox files corresponding to each message category (Archived, Drafts, Inbox, Sent, and Starred).

I wrote a simple reader that reads a .mbox file and returns a list of email objects.

# Adapted from: https://stackoverflow.com/questions/59681461/read-a-big-mbox-file-with-python

import email
from email.policy import default

class MboxReader:
    """Streams messages out of a (potentially very large) .mbox file one at a time."""

    def __init__(self, path):
        self.file_obj = open(path, "rb")
        # Every message in an mbox file starts with a "From " separator line
        assert self.file_obj.readline().startswith(b"From")
        self.file_obj.seek(0)  # rewind so the first message is not skipped

    def __iter__(self):
        lines = []
        while True:
            line = self.file_obj.readline()
            # A new "From " line (or EOF) ends the message collected so far
            if line == b"" or line.startswith(b"From "):
                if lines:
                    yield email.message_from_bytes(b"".join(lines), policy=default)
                lines = []
                if line == b"":
                    break
            lines.append(line)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.file_obj.close()


def get_emails(path="takeout/Mail/Inbox.mbox"):
    with MboxReader(path) as reader:
        for item in reader:
            if item:  # skip messages without any headers
                yield item

emails = get_emails()

Representing the Data

A .mbox file contains multiple objects or entries, where each entry can be considered a “thread” of messages — an email, the reply to that email, the reply to that reply, and so on.

Each message in the thread also includes multiple content types, corresponding to plaintext, HTML, or alternative formats. Thus, we also need to remove duplicated messages that show up in different formats, since the underlying content is the same and would be redundant.

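for msg_obj in email_thread:
    seen_plaintext = False
    # walk() visits every MIME part, e.g. the text/plain and text/html
    # alternatives of the same message
    for part in msg_obj.walk():
        if part.get_content_type() == "text/plain":
            process_part_plaintext(part)  # cleaning helper (not shown)
            seen_plaintext = True
        # We don't want duplicate emails in plaintext & HTML form
        elif part.get_content_type() == "text/html" and not seen_plaintext:
            process_part_html(part)  # cleaning helper (not shown)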

Between plaintext and HTML, there weren’t any major differences in the regex functions used. However, the plaintext threads had multiple encodings and I had to hack together a solution that tried to ‘guess’ the proper encoding.
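The exact encoding-guessing hack isn’t shown in this post. The sketch below illustrates the general idea, but the function name `decode_payload` and the fallback order are assumptions rather than the original code. It tries the charset the message declares and then falls back to common encodings until one decodes cleanly:

```python
def decode_payload(raw, declared_charset=None, fallbacks=("utf-8", "cp1252", "latin-1")):
    """Decode raw email bytes, trying the declared charset first and then common fallbacks.

    Hypothetical helper: the fallback order is an assumption, not the original code.
    """
    candidates = [declared_charset] if declared_charset else []
    candidates += [c for c in fallbacks if c != declared_charset]
    for charset in candidates:
        try:
            return raw.decode(charset)
        except (UnicodeDecodeError, LookupError):
            # Wrong guess, or a charset name Python doesn't recognize: try the next one
            continue
    # Lossy last resort, only reached if no candidate decoded cleanly
    return raw.decode("utf-8", errors="replace")
```

In practice, `raw` would come from `part.get_payload(decode=True)` and `declared_charset` from `part.get_content_charset()`. Both are standard methods on Python’s `email.message.Message`. Latin-1 maps every byte to a character, so it always succeeds and only belongs at the end of the fallback list.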

Using regular expressions, I cleaned each thread and split it into individual messages, each containing a sender and a message body.

import re

MARKDOWN_SYMBOLS_REGEX = re.compile(r"\*\*")                     # bold markers, e.g. **text**
NON_ASCII_REGEX = re.compile(r"[^\x00-\x7F]+")
PREVIOUS_REPLY_REGEX = re.compile(r"^>+ *", flags=re.MULTILINE)  # quoted-reply prefixes
URL_REGEX = re.compile(r"<?(https?://(?:www\.|(?!www))[a-zA-Z0-9][a-zA-Z0-9-]+[a-zA-Z0-9]\.[^\s]{2,}|www\.[a-zA-Z0-9][a-zA-Z0-9-]+[a-zA-Z0-9]\.[^\s]{2,}|https?://(?:www\.|(?!www))[a-zA-Z0-9]+\.[^\s]{2,}|www\.[a-zA-Z0-9]+\.[^\s]{2,})>?")
RANDOM_NEWLINE_REGEX = re.compile(r"(\S)\n(\S)")   # hard-wrapped line break inside a sentence
STRIP_NEWLINE_REGEX = re.compile(r"\s{3,}")        # runs of 3+ whitespace characters

...

thread_str = PREVIOUS_REPLY_REGEX.sub("", thread_str)
thread_str = URL_REGEX.sub("[URL]", thread_str)
thread_str = RANDOM_NEWLINE_REGEX.sub(r"\g<1> \g<2>", thread_str)
thread_str = NON_ASCII_REGEX.sub("", thread_str)
thread_str = MARKDOWN_SYMBOLS_REGEX.sub("", thread_str)
thread_str = STRIP_NEWLINE_REGEX.sub("\n\n", thread_str)

individual_email_list = separate_emails(thread_str)  # splitting helper (not shown)

The preprocessing resulted in an EmailThread object containing one or more EmailMessage objects. An example is seen below:

>>> thread
<EmailThread from Alice <alice@gmail.com> to Bob <bob@gmail.com>>
>>> thread.messages
[<EmailMessage from Alice <alice@gmail.com>: Hi Bob...>,
 <EmailMessage from Bob: Hello, Thank you for ...>]
>>> message = thread.messages[1]
>>> message.sender
"Bob"
>>> message.message
"Hello, Thank you for saying hi, Alice."

The reason I went through all the trouble of formatting the emails as a chain of replies, instead of individual and independent messages, is so that the model can learn to reply to an existing message, instead of only synthesizing new emails from scratch.

We want the fine-tuned model to learn two things:

  • How do I reply to an email when given the source email + information on the content?
  • How do I send an email from scratch when given content information?

To address these questions, we’ll collect all email-reply pairs where I was the replier AND all emails where I was the first sender. This is where the EmailThread data structure shines since we can just iterate over consecutive pairs to find email-reply pairs that fit our criteria.

import pandas as pd

# 'original' is the source email (None if I started the thread)
# 'generation' is the reply email I wrote
rows = []

for thread in email_threads:
    if len(thread) == 0:
        continue
    # Case 1: I sent the first email in the thread
    if "garreth" in thread[0].sender.lower():
        rows.append({"original": None, "generation": thread[0].message})
    # Case 2: consecutive (original, reply) pairs where I replied to someone else
    for original, reply in zip(thread, thread[1:]):
        if "garreth" in reply.sender.lower() and "garreth" not in original.sender.lower():
            rows.append({"original": original.message, "generation": reply.message})

# Build the DataFrame once instead of calling pd.concat inside the loop
data = pd.DataFrame(rows, columns=["original", "generation"])

Data Deduplication

Unfortunately, a lot of the data is redundant. When someone replies to one of my email threads, the existing thread isn’t updated; instead, a new thread containing that additional reply is created.

Due to the overlapping nature of email threads, there needs to be deduplication of the data so we don’t finetune the model using redundant data.

The deduplication pipeline will have several steps:

  1. Exact dedup (exact string comparison, keeping first among duplicates)
  2. Jaccard Similarity dedup (we can get away with this since most duplicates have very little difference)
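The exact-dedup step (1) isn’t shown in code. Here is a minimal sketch with pandas. It assumes `data` is the DataFrame of (original, generation) pairs built earlier; the helper name `exact_dedup` is hypothetical:

```python
import pandas as pd


def exact_dedup(data: pd.DataFrame) -> pd.DataFrame:
    """Drop rows whose (original, generation) pair exactly matches an earlier row."""
    return (
        data.drop_duplicates(subset=["original", "generation"], keep="first")
        .reset_index(drop=True)
    )
```

Calling `data_dedup = exact_dedup(data)` would produce the `data_dedup` frame that the Jaccard pass below works on.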
import re

# Step 2: near-duplicate removal with Jaccard similarity
SIMILARITY_THRESHOLD = 0.65

# Regular expression to strip my signature, which is the same in every email
# and would otherwise become a false signal for similarity
signature_regex = re.compile("Best,(.*)Partnerships Director", flags=re.DOTALL)

def jaccard_similarity(tokens1, tokens2):
    """Jaccard similarity (size of intersection / size of union) of two token collections."""
    set1, set2 = set(tokens1), set(tokens2)
    union_cardinality = len(set1 | set2)
    if union_cardinality == 0:
        return 0.0
    return len(set1 & set2) / union_cardinality

# Compare sets of words (not characters), with the signature removed
token_sets = [set(signature_regex.sub("", gen).split()) for gen in data_dedup["generation"]]

# Keep the first email of each group of near-duplicates and mark every
# later email that is too similar to an already-kept one as a duplicate
duplicate_indices = set()
for i in range(len(token_sets)):
    if i in duplicate_indices:
        continue
    for j in range(i + 1, len(token_sets)):
        if j not in duplicate_indices and jaccard_similarity(token_sets[i], token_sets[j]) >= SIMILARITY_THRESHOLD:
            duplicate_indices.add(j)

# Emails without any near-duplicate are kept as well
unique_indices = [i for i in range(len(token_sets)) if i not in duplicate_indices]
final_data = data_dedup.iloc[unique_indices].reset_index(drop=True)

We end up with a deduplicated dataset of email-reply pairs.

At this point, the only thing left to do is to turn the “original” column into a new “prompt” column. This column contains the instructions and context that, when fed into an LLM, should produce the text in the “generation” column.

In other words, if I copied an example from the “prompt” column directly to the fine-tuned LLM, it should generate something similar to what is in the “generation” column.

Formatting the Prompt

There are two parts to the instruction: the original email and the context for the reply. I settled on the following format for its simplicity:

[ORIGINAL]  
{If it exists, the original email that is being replied to, otherwise None}  
[CONTEXT]  
{the information being conveyed in the generated email}

An example is as follows:

[ORIGINAL]   
None   
[CONTEXT]   
to: john@company.com,   
about: attend HelloHacks as mentor,   
details: application successful, would love to invite,  
         date is May 1-2 2025, hellohacks is 2 day hackathon for beginners,   
         appreciate your expertise

For the [CONTEXT], I’ve decided to stick with the “to”, “about”, and “details” structure since it encompasses the most basic details required to compose an email (the recipient, subject, and message).
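The code that assembles these prompts isn’t shown. Here is a minimal sketch, assuming the fields are joined as in the template above. The helper name `build_prompt` and the newline placement are assumptions; the example at the top of the post puts everything on a single line:

```python
def build_prompt(to, about, details, original=None):
    """Assemble the [ORIGINAL] / [CONTEXT] prompt (hypothetical helper)."""
    context = f"to: {to}, about: {about}, details: {details}"
    # Emails that start a thread have no original; the template writes this as "None".
    # The isinstance check also treats pandas NaN values as missing.
    original_text = original if isinstance(original, str) and original.strip() else "None"
    return f"[ORIGINAL]\n{original_text}\n[CONTEXT]\n{context}"


print(build_prompt(
    to="john@company.com",
    about="attend HelloHacks as mentor",
    details="application successful, would love to invite, date is May 1-2 2025",
))
```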

I hand-labeled around a hundred examples and used them as the fine-tuning dataset. More data would likely give better results, but I wanted to test whether a dataset this small was enough to be useful.

Finally, I converted the dataset to JSONL (one JSON object per line) and loaded it as a HuggingFace Dataset. We are now ready to start fine-tuning!
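Writing JSONL needs only the standard library. In the sketch below, the function names are hypothetical. It assumes each record uses the keys `prompt` and `completion`, because those are the column names the formatting function in the training code reads; the “generation” column would need renaming to match:

```python
import json


def to_jsonl(records):
    """Serialize a list of dicts as JSON Lines: one JSON object per line."""
    # json.dumps escapes newlines inside strings, so each record stays on one line
    return "".join(json.dumps(record, ensure_ascii=False) + "\n" for record in records)


def write_jsonl(records, path):
    with open(path, "w", encoding="utf-8") as f:
        f.write(to_jsonl(records))
```

The resulting file can then be loaded with `datasets.load_dataset("json", data_files="emails.jsonl", split="train")`, where the file name is just an example.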

Fine-Tuning the Model

A word on pretrained and fine-tuned models

Pretrained models are (usually) trained to predict the next token given the preceding ones. If you give one an instruction, it will probably not answer it well, since its objective is to ‘complete’ the text (in this case, your instruction). However, these pretrained models can be fine-tuned to follow instructions (ChatGPT is one of these); the results are commonly referred to as ‘instruction-tuned models’.

We can fine-tune an instruction-tuned model further, typically to replicate a particular output style for a given kind of instruction (which is what we are aiming for), or to learn domain-specific jargon, such as the language of 10-K and 10-Q financial reports.

Many instruction-tuned models are very large (billions of parameters), which means they most likely won’t fit in your local machine’s memory. Google Colab is super useful in situations like this: you can connect to a machine with a dedicated GPU and/or enough RAM to load these models into memory.

The HuggingFace Ecosystem

Fine-tuning is complicated in its own right and involves many tweakable parameters that affect output quality, but HuggingFace’s ecosystem of libraries and tools simplifies it massively. These tools abstract away the complexity of building a fine-tuning pipeline from scratch and make it possible for almost anyone to fine-tune these models. As someone with limited fine-tuning experience, I found it really easy to get started, since there were tons of resources online that helped me learn as I went.

Several notable libraries:

  • peft: A library filled with methods for Parameter-Efficient Fine-Tuning (PEFT), a series of techniques that enables fine-tuning models more efficiently by only tweaking a small number of extra model parameters instead of all of the original parameters.
  • trl: Provides trainers (such as the SFTTrainer used below) that make it easy to fine-tune models. PEFT is well integrated here too, so the two libraries work together smoothly.
  • bitsandbytes: Enables quantization of large language models, letting us load them with a fraction of the memory while largely maintaining performance. The smaller memory footprint per parameter is what makes it practical to fine-tune a multi-billion-parameter model on a single Colab GPU.

Picking a Model

I decided to go with version 0.2 of Mistral’s 7B Instruct model due to its relatively strong performance and moderate size. When loading a model in HuggingFace, you typically also load the tokenizer specific to that model. Here, I also define a bitsandbytes config that tells the model to load its weights in 4-bit (NF4) precision instead of 8-bit, with computations done in bfloat16.

import torch  
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig  
  
model_id = "mistralai/Mistral-7B-Instruct-v0.2"  
  
# 4bit integer config for memory efficiency  
bnb_config = BitsAndBytesConfig(load_in_4bit = True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16)  
  
model = AutoModelForCausalLM.from_pretrained(model_id, device_map = "auto", torch_dtype = torch.bfloat16, quantization_config = bnb_config)  
tokenizer = AutoTokenizer.from_pretrained(model_id)
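A quick back-of-the-envelope calculation shows why 4-bit loading matters. The helper below is hypothetical and uses a round figure of 7 billion parameters. It counts only the weights; activations, the KV cache, optimizer state, and quantization metadata all add to the real footprint:

```python
def weight_memory_gb(num_params, bits_per_param):
    """Rough memory for the model weights alone, in GB (1 GB = 1e9 bytes).

    Ignores activations, the KV cache, optimizer state, and quantization
    metadata such as per-block scaling constants.
    """
    return num_params * bits_per_param / 8 / 1e9


for label, bits in [("bfloat16", 16), ("8-bit", 8), ("4-bit NF4", 4)]:
    print(f"{label}: ~{weight_memory_gb(7e9, bits):.1f} GB")
```

Going from 16-bit to 4-bit cuts the weight memory by roughly 4x, from about 14 GB to about 3.5 GB. That leaves room on a single GPU for the LoRA adapters, activations, and optimizer state during fine-tuning.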

Fine-Tuning Config Initialization

Next, I defined the fine-tuning hyperparameters, starting from values that had worked for a similar task (causal language modeling). Philipp Schmid from HuggingFace has a good blog post that I heavily adapted my hyperparameters from:

from peft import LoraConfig
from transformers import TrainingArguments
from trl import SFTTrainer

peft_config = LoraConfig(
    lora_alpha=128,         # scaling factor: the adapter update is scaled by lora_alpha / r
    lora_dropout=0.05,
    r=64,                   # rank of the low-rank adapter matrices
    bias="none",
    task_type="CAUSAL_LM",
)

args = TrainingArguments(
    output_dir="biztech_email_mistral_7b_instruct_v02", # directory to save and repository id
    num_train_epochs=3,                     # number of training epochs (since dataset is small)
    per_device_train_batch_size=3,          # batch size per device during training
    gradient_accumulation_steps=2,          # number of batches to accumulate gradients over before each optimizer update
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_torch_fused",              # use fused adamw optimizer
    logging_steps=10,                       # log every 10 steps
    save_strategy="epoch",                  # save checkpoint every epoch
    learning_rate=2e-4,                     # learning rate, based on QLoRA paper
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",           # use constant learning rate scheduler
    push_to_hub=True,                       # push model to hub
    report_to="tensorboard",                # report metrics to tensorboard
)

def formatting_prompts_func(example):
    # `example` is a batch: a dict mapping column names to lists of values
    output_texts = []
    for i in range(len(example['prompt'])):
        text = f"### Instruction: {example['prompt'][i]}\n ### Answer: {example['completion'][i]}"
        output_texts.append(text)
    return output_texts

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,  # HuggingFace Dataset loaded from the JSONL file
    peft_config=peft_config,
    formatting_func = formatting_prompts_func,
    tokenizer=tokenizer,
    dataset_kwargs={
        "add_special_tokens": False,  # We template with special tokens
        "append_concat_token": False, # No need to add additional separator token
    }
)

trainer.train()
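With so little data, it helps to know how many optimizer updates a run actually performs. The rough sketch below uses a hypothetical helper and a round 100 examples for “around a hundred”. The exact count the Trainer reports may differ slightly depending on how it handles the last partial batch:

```python
import math


def training_schedule(num_examples, per_device_batch_size, grad_accum_steps, num_epochs, num_devices=1):
    """Back-of-the-envelope count of optimizer updates for a small fine-tuning run."""
    effective_batch = per_device_batch_size * grad_accum_steps * num_devices
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    total_steps = steps_per_epoch * num_epochs
    return effective_batch, steps_per_epoch, total_steps


# ~100 hand-labeled examples with the TrainingArguments above
print(training_schedule(100, per_device_batch_size=3, grad_accum_steps=2, num_epochs=3))
# (6, 17, 51)
```

With an effective batch of 6 and only about 50 updates in total, `logging_steps=10` produces roughly five loss readings for the whole run.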

This training process produces a LoRA adapter, a lightweight representation of the fine-tune. Instead of saving a full copy of the model, we only save the small set of adapter weights that PEFT trained.
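To get a sense of how small the adapter is, recall that LoRA replaces the update to a d_out x d_in weight matrix with two low-rank factors, A (r x d_in) and B (d_out x r). The sketch below applies this to a 4096 x 4096 projection (Mistral 7B’s hidden size is 4096) with the r=64 and lora_alpha=128 values from the config above. Which modules PEFT actually targets by default depends on the model and isn’t specified in the config, so this counts a single matrix only:

```python
def lora_extra_params(d_in, d_out, r):
    """Trainable parameters LoRA adds to one d_out x d_in weight matrix.

    A has shape (r, d_in) and B has shape (d_out, r).
    """
    return r * d_in + d_out * r


d = 4096  # hidden size of Mistral 7B
full = d * d
extra = lora_extra_params(d, d, r=64)
print(full, extra, f"{extra / full:.1%}")  # 16777216 524288 3.1%
print("scaling lora_alpha / r =", 128 / 64)  # 2.0
```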

At this point, to load the fine-tuned model I’d have to instantiate the base model and then put the LoRA adapter on top of it, which gives a PeftModel.

from peft import PeftModel  
from transformers import AutoModelForCausalLM, AutoTokenizer  
  
# Load tokenizer with updated vocabulary after fine-tuning  
tokenizer = AutoTokenizer.from_pretrained("garrethlee/biztech_email_mistral_7b_instruct_v02")  
  
# Load base model  
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")  
  
# Since additional tokens have been added as a result of fine-tuning  
model.resize_token_embeddings(len(tokenizer))  
  
# Add LoRA adapters on top of the base model
finetuned_model = PeftModel.from_pretrained(model, "garrethlee/biztech_email_mistral_7b_instruct_v02")

However, in this form inference is quite slow, since every adapted layer computes the adapter path separately from the base weights. To address this, I merged the adapters into the base model, so the result can be saved and loaded as a regular model instead of a PeftModel.
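The reason merging removes the overhead follows from the standard LoRA formulation. This is general LoRA math, not something specific to this project. Let W_0 be a frozen base weight, B (d_out x r) and A (r x d_in) the trained adapter matrices, and take the config above (lora_alpha = 128, r = 64). Then:

```latex
h = W_0 x + \frac{\alpha}{r} B A x,
\qquad B \in \mathbb{R}^{d_{\text{out}} \times r},\; A \in \mathbb{R}^{r \times d_{\text{in}}}

W_{\text{merged}} = W_0 + \frac{\alpha}{r} B A
\quad\Longrightarrow\quad
h = W_{\text{merged}}\, x,
\qquad \frac{\alpha}{r} = \frac{128}{64} = 2
```

Once merged, each layer is a single matrix multiplication again, just like in the base model. Adapter dropout is only active during training, so it plays no role at inference time.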

import torch
from peft import AutoPeftModelForCausalLM

# Load PEFT model (base model + adapter) on CPU
model = AutoPeftModelForCausalLM.from_pretrained(
    args.output_dir,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# Merge LoRA and base model and save
model.resize_token_embeddings(len(tokenizer))
merged_model = model.merge_and_unload()  # folds the adapter weights into the base weights
merged_model.save_pretrained(args.output_dir, safe_serialization=True, max_shard_size="2GB", push_to_hub=True)

Now, we can load the merged model directly for inference. Using the pipeline object, we can quickly generate an output from a given instruction.

import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, BitsAndBytesConfig, pipeline

finetuned_model_id = "garrethlee/biztech_email_mistral_7b_instruct_v02"
bnb_config = BitsAndBytesConfig(load_in_4bit = True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16)

finetuned_model = AutoPeftModelForCausalLM.from_pretrained(finetuned_model_id, device_map = "auto", torch_dtype = torch.bfloat16, quantization_config = bnb_config)
tokenizer = AutoTokenizer.from_pretrained(finetuned_model_id)

pipe = pipeline("text-generation", model=finetuned_model, tokenizer=tokenizer)

instruction = "to: josh, about: coming to Produhacks as a mentor (ProduHacks, a process-centric hackathon focused on making the right decisions to develop a product that matters. Discover the importance of product planning, research, and development through a unique competition that sits between a case competition and a hackathon.), details: event will be March 23-24 2024, time commitment is 11 AM on the first day, 1-3 PM on the second, so grateful if come"
# Same "### Instruction / ### Answer" template used during fine-tuning
prompt = f"### Instruction: {instruction}\n ### Answer: "
pipe(prompt,
     max_new_tokens=256,
     do_sample=False,  # greedy decoding, so sampling settings like top_k/temperature would be ignored
     eos_token_id=pipe.tokenizer.eos_token_id,
     pad_token_id=pipe.tokenizer.pad_token_id)
# Hi Josh,\n\nI hope this message finds you well. I'm Garreth, the Partnerships Director for BizTech, the University of British Columbia's prominent business and technology organization.\n\nWe're excited to have you join us as a mentor for ProduHacks, an event aimed at creating an inclusive and supportive environment for students to develop their product ideas into prototypes. With your wealth of experience, we believe you'll be an invaluable source of guidance and insight for our participants...

We can see that the output is a great starting point for me to refine and personalize. The model sometimes generates incorrect information (hallucinates), which is not surprising given the small size of the dataset. I experimented with 1 and 2 training epochs, but the model didn’t perform well enough (underfitting). With 3 training epochs, the quality of the outputs improved, but some incorrect information started to appear (hallucinations).

In future experiments, I’ll likely adjust other hyperparameters to find a better balance between improving output quality and avoiding overfitting.

Local inference with Ollama

In its current form, inference was slow because of the long load times. It was also inconvenient, since we had to load the model on Google Colab every time we wanted to generate a completion.

In the final step of this project, I explored quantization and converted the model to GGUF, a binary format designed for fast loading and saving of models, which can then be loaded into Ollama. Ollama lets users run LLMs locally without needing to connect to a hosted runtime.

GGUF Conversion

After downloading the fine-tuned model weights, I used llama.cpp to convert the HuggingFace model into a GGUF model.

python llama.cpp/convert.py biztech-email-mistral --outtype q8_0 --outfile biztech-email-mistral.gguf

Specifying the outtype as q8_0 results in 8-bit quantization. This makes the model smaller and faster in exchange for a slight (and often unnoticeable) dip in quality compared to its original precision (16-bit or 32-bit floating point).
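In llama.cpp’s q8_0 format, weights are stored in blocks of 32 8-bit integers that share one 16-bit scale factor, so the effective cost is slightly above 8 bits per weight. The estimate below uses a hypothetical helper and a round 7 billion parameters, and it ignores file metadata and any tensors kept at a different precision:

```python
def gguf_q8_0_bits_per_weight(block_size=32):
    """Bits per weight for q8_0: block_size int8 values plus one fp16 scale per block."""
    return (block_size * 8 + 16) / block_size


def gguf_q8_0_size_gb(num_params):
    """Approximate q8_0 weight size in GB (1 GB = 1e9 bytes), ignoring metadata."""
    return num_params * gguf_q8_0_bits_per_weight() / 8 / 1e9


print(gguf_q8_0_bits_per_weight())  # 8.5
print(gguf_q8_0_size_gb(7e9))       # roughly 7.4 GB, versus ~14 GB at 16-bit
```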

llama.cpp also has a dedicated quantization tool that supports a wider variety of precisions than the convert.py script. However, I settled on 8-bit quantization, which is quite conservative compared to the other options available in the quantize tool.

Loading the model into Ollama

Now that the model is in GGUF form, I loaded it into Ollama by writing a Modelfile, a set of instructions and configurations for the model (similar to a Dockerfile).

I also specified the boilerplate that wraps the actual user input in a TEMPLATE, which uses Go template syntax. This step was necessary since the model was fine-tuned with this specific instruction-answer format (you can see the code that specifies this in my previous post), so the template has to match it exactly.

FROM ./biztech-email-mistral.gguf  
TEMPLATE """### Instruction: {{ .Prompt }}
 ### Answer: """

From here, I created the model…

ollama create biztech-mistral-q8 -f "Modelfile"

And it’s done! I now had a quantized, fine-tuned model ready for inference on my local machine.

Now what?

Ollama offers multiple interfaces to generate outputs. You can use it as an interactive shell, invoke a REST API via localhost, or even use a Python client that has a similar interface to OpenAI’s API library.

For my use case, invoking it from the interactive shell was simple enough, since all I had to do was run:

ollama run biztech-mistral-q8 
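For scripting, Ollama’s REST API listens on localhost:11434 by default. Below is a minimal standard-library sketch of calling its /api/generate endpoint. The helper names are hypothetical, and the sketch assumes the model was created as biztech-mistral-q8 and the Ollama server is running:

```python
import json
import urllib.request

OLLAMA_URL = "http://localhost:11434/api/generate"  # Ollama's default local endpoint


def build_generate_request(model, prompt):
    """Build a non-streaming request for Ollama's /api/generate endpoint."""
    payload = {"model": model, "prompt": prompt, "stream": False}
    return urllib.request.Request(
        OLLAMA_URL,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def generate(model, prompt):
    """Send the request to a running Ollama server and return the generated text."""
    with urllib.request.urlopen(build_generate_request(model, prompt)) as resp:
        return json.loads(resp.read())["response"]
```

Because the Modelfile’s TEMPLATE wraps the prompt, only the short-form instruction needs to be sent, for example `generate("biztech-mistral-q8", "to: josh, about: ..., details: ...")`.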

Caveats and Limitations

While this is an active part of my workflow, it isn’t without its faults. The model still hallucinates (quite a lot), and some completions repeat a token that appeared frequently in the training dataset over and over. There are also instances where the model generates the reply FOR the reply (essentially producing an entire back-and-forth thread of emails).

However, the bulk of the content itself is still present and almost entirely aligns with the user input. There’s just an extra step of removing the hallucinations (albeit a big one), which I’m totally fine with for my use case.
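One possible mitigation for the reply-to-the-reply problem, which this post doesn’t describe, is to cut the completion as soon as a marker of a new turn appears. In the hypothetical sketch below, the marker list is an assumption: it guesses that the training template’s “### Instruction” or “### Answer” prefixes show up when the model starts a new turn:

```python
def trim_completion(text, stop_markers=("### Instruction", "### Answer")):
    """Cut a completion at the first marker that suggests the model started a new turn.

    Hypothetical post-processing step; the default markers are assumptions.
    """
    cut = len(text)
    for marker in stop_markers:
        idx = text.find(marker)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut].rstrip()
```

Ollama can do something similar at generation time through a `PARAMETER stop` line in the Modelfile.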

Thanks for reading through this series of posts! If you’ve made it this far and would like to check out my other work, I’m going to start posting on Medium more frequently.

In future projects, I plan to explore different levels of quantization and try to acquire larger datasets. I might even consolidate the mailboxes of all UBC Biztech members to create a larger dataset.


Ultimately, this was an interesting first stab at fine-tuning a language model for a downstream application. I learned many techniques, especially regarding inference, fine-tuning configurations, and their trade-offs.