Confidence Levels: Exploring Probabilities in Explainable Large Language Models

In this blog post, we delve into the inner workings of large language models (LLMs) and explore a novel approach to enhancing their explainability. By visualizing the probabilities associated with each generated word, we aim to gain insights into the model’s confidence levels and understand how they might impact the overall quality of the generated text. This research-oriented exploration not only sheds light on the decision-making process of LLMs but also paves the way for more transparent and trustworthy natural language generation systems.
Categories: Data Science, Large Language Models
Author: Daniel Fat
Published: May 10, 2024

Before delving into the implementation details, let us first understand the significance of this exploration. Explainable AI has become a crucial area of research, as it addresses the need for transparency and interpretability in complex machine learning models, particularly in high-stakes applications.

By unveiling the model’s decision-making process, we can foster trust and enable more informed decisions about when and how to use these models effectively.

Importing Required Libraries and Loading the Model

In this initial step, we import the necessary libraries and modules to facilitate our exploration. This includes the termcolor library for enhancing the visual representation of our generated text, matplotlib.pyplot for plotting the probabilities, and the mlx_lm library for loading and generating text from a pre-trained language model.

Code
import os
from termcolor import colored
import matplotlib.pyplot as plt
from mlx_lm import load, generate

Hardware Overview

In this exploration, we leverage the following hardware configuration:

  • Model Name: MacBook Air
  • Model Identifier: Mac14,2
  • Model Number: Z15W001JQB/A
  • Chip: Apple M2
  • Total Number of Cores: 8 (4 performance and 4 efficiency)
  • Memory: 16 GB

Code
model, tokenizer = load("mlx-community/Meta-Llama-3-8B-Instruct-4bit")
Fetching 6 files: 100%|██████████| 6/6 [00:00<00:00, 39199.10it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.

Formatting Input Prompts

To ensure that the language model correctly understands the context and generates appropriate responses, we define a function called promptify that formats the input prompts (both system and user prompts) in a specific way required by the model.

Code
def promptify(system_prompt, user_prompt):
    return f'''
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{system_prompt}
<|eot_id|><|start_header_id|>user<|end_header_id|>
{user_prompt}
<|eot_id|>
'''
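The published Llama 3 instruct template normally ends the prompt with an opened assistant header, so the model knows it is its turn to speak. The run in this post works without it, as the output shows, so the sketch below is an optional variant and not the function used here. The name promptify_for_generation and the shortened system prompt are hypothetical.

```python
def promptify_for_generation(system_prompt: str, user_prompt: str) -> str:
    """Hypothetical variant of promptify that also opens the assistant turn."""
    return (
        "<|begin_of_text|>"
        "<|start_header_id|>system<|end_header_id|>\n\n"
        f"{system_prompt}<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"{user_prompt}<|eot_id|>"
        # Opening the assistant header signals that the reply should start here.
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )


prompt = promptify_for_generation(
    system_prompt="You are a helpful assistant.",
    user_prompt="Why is the sky blue?",
)
print(prompt)
```

Depending on the tokenizer settings, a begin-of-text token may also be added automatically, so it is worth inspecting the final token IDs before relying on either version.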

Generating Text and Storing Probabilities

At the core of our exploration lies the process of generating text and storing the probability of each generated token. The model emits tokens, which are often sub-word pieces and not whole words; for simplicity we still refer to them as words. We define two empty lists, words and probs, to store the generated tokens and their corresponding probabilities, respectively. The helper function add_word strips the spaces from each token and appends the token and its probability to these lists. It is passed to generate as the formatter callback.

Code
words = []
probs = []

def add_word(word, prob):
    words.append(word.replace(' ',''))
    probs.append(prob)

response = generate(
    model, 
    tokenizer, 
    prompt=promptify(
        system_prompt='You are an amazing assitant and you will help me with anything I need.',
        user_prompt='Why is the sky blue?',
    ), 
    temp=0.5,
    max_tokens=1000,
    verbose=True,
    formatter=add_word,  # receives each generated token and its probability
)
==========
Prompt: 
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are an amazing assitant and you will help me with anything I need.
<|eot_id|><|start_header_id|>user<|end_header_id|>
Why is the sky blue?
<|eot_id|>


==========
Prompt: 26.254 tokens-per-sec
Generation: 17.484 tokens-per-sec
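To make the stored numbers more concrete, here is a minimal, generic sketch of how a token probability is obtained from raw model scores (logits) with a softmax, and how a temperature such as the temp=0.5 used above sharpens the distribution. The logits and candidate tokens are made up for illustration, and whether mlx_lm reports the probability before or after temperature scaling is not something this post confirms.

```python
import math


def softmax(logits, temp=1.0):
    """Convert raw scores into probabilities, optionally sharpened by temperature."""
    scaled = [x / temp for x in logits]
    peak = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up logits for three candidate tokens; not taken from the model.
candidates = ["blue", "azure", "green"]
logits = [2.0, 1.0, 0.1]

for temp in (1.0, 0.5):
    probs_at_temp = softmax(logits, temp=temp)
    summary = ", ".join(
        f"{tok}={p:.2f}" for tok, p in zip(candidates, probs_at_temp)
    )
    print(f"temp={temp}: {summary}")
```

A temperature below 1 pushes more probability mass onto the top candidate, which is one reason reported confidences should be read together with the sampling settings.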

Printing Generated Text

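To provide a preliminary view of the generated text before enhancing its visual representation, we simply print the generated tokens joined with a space character. Because the tokens are sub-word pieces whose spaces were stripped by add_word, some words appear split (for example "Ray leigh" and "def lected") and punctuation is set apart by spaces.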

Code
print(' '.join(words))
What a great question ! The sky appears blue because of a phenomenon called Ray leigh scattering . It 's named after the British physicist Lord Ray leigh , who first described it in the late  19 th century .

 In simple terms , Ray leigh scattering occurs when shorter ( blue ) wavelengths of light are scattered more than longer ( red ) wavelengths by the tiny molecules of gases in the atmosphere , like nitrogen and oxygen . This scattering effect gives the sky its blue appearance .

 Imagine the sun 's light as a bunch of different colored balls bouncing around . The shorter , blue balls are more easily def lected by the tiny molecules in the air , so they scatter in all directions and reach our eyes from all parts of the sky . The longer , red balls , on the other hand , travel in more direct paths to our eyes , so they don 't scatter as much and don 't contribute as much to the blue color we see .

 This is why the sky typically appears blue during the daytime , especially in the direction of the sun . However , it 's worth noting that the color of the sky can change depending on various factors , such as the time of day , the amount of cloud cover , and the amount of dust and pollution in the atmosphere .

 I hope that helps you understand why the sky is blue !
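The split words above come from stripping the spaces that mark word boundaries. As a hedged sketch, assuming each token arrives with its leading space intact (which add_word removes), sub-word tokens can be merged back into words, with a word's probability taken as the product of its tokens' probabilities. The helper merge_tokens is hypothetical; the probabilities are the ones shown later in this post for "called Ray leigh scattering".

```python
# Tokens with their assumed leading spaces kept; probabilities copied
# from the coloured output for "called Ray leigh scattering".
tokens = [" called", " Ray", "leigh", " scattering"]
token_probs = [1.00, 0.88, 1.00, 1.00]


def merge_tokens(tokens, token_probs):
    """Group sub-word tokens into words and multiply their probabilities."""
    words, word_probs = [], []
    for tok, p in zip(tokens, token_probs):
        if tok.startswith(" ") or not words:
            # A leading space (or the very first token) starts a new word.
            words.append(tok.strip())
            word_probs.append(p)
        else:
            # No leading space: this token continues the previous word.
            words[-1] += tok
            word_probs[-1] *= p
    return words, word_probs


merged_words, merged_probs = merge_tokens(tokens, token_probs)
print("".join(tokens).strip())
for w, p in zip(merged_words, merged_probs):
    print(f"{w}: {p:.2f}")
```

With this grouping, punctuation tokens that carry no leading space attach to the preceding word, which may or may not be what you want for a word-level view.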

Visualising Probabilities with Coloured Text

In this final step, we colour each generated token according to its probability. The color_map dictionary maps each colour to the lower bound of its probability range, listed from highest to lowest, and the get_color function returns the first colour whose threshold the probability meets. Blue is therefore reserved for a probability of exactly 1.0, so a value such as 0.998, which prints as 1.00, is shown in green.

Code
color_map = {
    'blue': 1.0,
    'green': 0.75,
    'yellow': 0.5,
    'red': 0.25,
    'white': 0.0,
}

def get_color(prob):
    for color, threshold in color_map.items():
        if prob >= threshold:
            return color


# Print each word in its confidence colour, followed by its probability.
for word, prob in zip(words, probs):
    print(colored(word, get_color(prob)), end='')
    print(f"({prob:.2f})", end=' ')
What(0.25) a(1.00) great(0.98) question(1.00) !(0.33) The(0.90) sky(0.88) appears(0.96) blue(1.00) because(0.98) of(1.00) a(0.95) phenomenon(0.94) called(1.00) Ray(0.88) leigh(1.00) scattering(1.00) .(0.36) It(0.42) 's(0.99) named(0.53) after(1.00) the(0.99) British(1.00) physicist(1.00) Lord(0.99) Ray(1.00) leigh(1.00) ,(1.00) who(1.00) first(0.98) described(0.78) it(1.00) in(1.00) the(1.00) late(0.96) (1.00) 19(0.98) th(1.00) century(1.00) .

(0.98) In(0.47) simple(0.82) terms(1.00) ,(1.00) Ray(0.49) leigh(1.00) scattering(1.00) occurs(0.17) when(1.00) shorter(0.52) ((0.97) blue(1.00) )(1.00) wavelengths(0.97) of(1.00) light(1.00) are(0.97) scattered(1.00) more(0.69) than(0.95) longer(1.00) ((1.00) red(1.00) )(1.00) wavelengths(1.00) by(0.98) the(0.97) tiny(0.98) molecules(1.00) of(1.00) gases(1.00) in(1.00) the(1.00) atmosphere(0.78) ,(0.89) like(0.37) nitrogen(1.00) and(0.97) oxygen(1.00) .(1.00) This(0.99) scattering(0.96) effect(0.99) gives(0.78) the(0.98) sky(1.00) its(1.00) blue(1.00) appearance(0.36) .

(0.95) Imagine(0.65) the(0.31) sun(0.52) 's(0.94) light(0.87) as(0.69) a(1.00) bunch(0.48) of(1.00) different(0.96) colored(0.85) balls(0.67) bouncing(0.37) around(0.69) .(0.44) The(0.86) shorter(0.76) ,(0.45) blue(1.00) balls(0.84) are(0.79) more(0.57) easily(0.73) def(0.89) lected(1.00) by(0.97) the(1.00) tiny(0.88) molecules(0.62) in(0.72) the(1.00) air(0.98) ,(1.00) so(0.63) they(0.99) scatter(0.70) in(0.96) all(0.99) directions(1.00) and(0.72) reach(0.99) our(0.92) eyes(1.00) from(0.99) all(0.98) parts(0.99) of(1.00) the(1.00) sky(1.00) .(0.99) The(0.98) longer(1.00) ,(1.00) red(1.00) balls(0.71) ,(0.95) on(0.98) the(1.00) other(1.00) hand(1.00) ,(1.00) travel(0.39) in(0.97) more(0.42) direct(0.90) paths(0.82) to(0.53) our(1.00) eyes(1.00) ,(0.95) so(0.62) they(0.88) don(0.84) 't(1.00) scatter(0.86) as(1.00) much(1.00) and(0.94) don(0.51) 't(1.00) contribute(0.19) as(0.96) much(1.00) to(1.00) the(1.00) blue(0.96) color(1.00) we(0.67) see(1.00) .

(0.99) This(0.54) is(0.27) why(0.96) the(0.99) sky(1.00) typically(0.54) appears(1.00) blue(0.99) during(1.00) the(1.00) daytime(0.99) ,(1.00) especially(0.92) in(0.71) the(0.98) direction(0.99) of(0.64) the(1.00) sun(1.00) .(1.00) However(0.45) ,(1.00) it(0.10) 's(0.99) worth(0.74) noting(1.00) that(1.00) the(0.94) color(0.68) of(0.99) the(1.00) sky(1.00) can(1.00) change(0.93) depending(0.87) on(1.00) various(0.67) factors(0.73) ,(0.51) such(0.94) as(1.00) the(0.52) time(0.99) of(1.00) day(1.00) ,(1.00) the(0.67) amount(0.98) of(1.00) cloud(0.69) cover(1.00) ,(1.00) and(1.00) the(0.85) amount(0.85) of(1.00) dust(0.63) and(0.92) pollution(0.77) in(1.00) the(1.00) atmosphere(0.82) .

(0.81) I(0.96) hope(1.00) that(0.97) helps(1.00) you(0.34) understand(0.98) why(0.98) the(1.00) sky(1.00) is(1.00) blue(1.00) !(1.00) 

Note: the colours are only visible when the code is run in a notebook or terminal.
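Since terminal colours do not survive on a web page, one possible workaround is to render the same information as HTML. The sketch below is an assumption-laden alternative and not part of the original notebook: the names HTML_COLORS, html_color and to_html are hypothetical, and white is swapped for grey so that low-probability tokens stay readable on a light background.

```python
import html

# Same lower bounds as color_map, highest first. "goldenrod" and "grey"
# replace yellow and white for readability on a light page (an assumption).
HTML_COLORS = [
    (1.0, "blue"),
    (0.75, "green"),
    (0.5, "goldenrod"),
    (0.25, "red"),
    (0.0, "grey"),
]


def html_color(prob):
    """Return the first colour whose lower bound the probability meets."""
    for threshold, color in HTML_COLORS:
        if prob >= threshold:
            return color
    return "grey"  # fallback for unexpected values such as negatives


def to_html(tokens, token_probs):
    """Wrap each token in a coloured span with its probability as a tooltip."""
    spans = [
        f'<span style="color:{html_color(p)}" title="{p:.2f}">'
        f"{html.escape(tok)}</span>"
        for tok, p in zip(tokens, token_probs)
    ]
    return " ".join(spans)


# First five tokens and their rounded probabilities from the output above.
snippet = to_html(
    ["What", "a", "great", "question", "!"],
    [0.25, 1.00, 0.98, 1.00, 0.33],
)
print(snippet)
```

The resulting string can be written to a file or displayed in a notebook, and the title attribute shows the probability on hover.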

By running this notebook, we generate text using the loaded language model and visualise the probability associated with each generated token. The colours encode the model's confidence: blue marks a probability of 1.0, green at least 0.75, yellow at least 0.5, red at least 0.25, and white anything lower. In this run, the least confident tokens include it (0.10), occurs (0.17), contribute (0.19) and What (0.25), whereas leigh and scattering, following Ray, sit at 1.00.
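Beyond colouring individual tokens, the stored lists can also be summarised numerically. The sketch below lists the tokens under a chosen threshold and computes the perplexity of a sequence, defined as the exponential of the negative mean log-probability. The helper names and the 0.5 threshold are illustrative choices, the sample data are the first five tokens of the output above, and the perplexity function assumes every probability is greater than zero.

```python
import math

# First five tokens and probabilities, copied from the coloured output.
sample_words = ["What", "a", "great", "question", "!"]
sample_probs = [0.25, 1.00, 0.98, 1.00, 0.33]


def low_confidence(words, probs, threshold=0.5):
    """Return (word, probability) pairs that fall below the threshold."""
    return [(w, p) for w, p in zip(words, probs) if p < threshold]


def perplexity(probs):
    """exp of the negative mean log-probability; assumes all probs > 0."""
    mean_log_prob = sum(math.log(p) for p in probs) / len(probs)
    return math.exp(-mean_log_prob)


print(low_confidence(sample_words, sample_probs))
print(f"perplexity: {perplexity(sample_probs):.2f}")
```

A perplexity of 1 would mean the model assigned probability 1.0 to every token; higher values indicate more uncertainty across the sequence.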

This approach not only enhances the explainability of the language model’s decision-making process but also allows for a deeper understanding of how the model’s confidence levels might impact the overall quality and coherence of the generated text. By unveiling these probabilities, we pave the way for more transparent and trustworthy natural language generation systems, enabling researchers and practitioners to make informed decisions about when and how to employ these powerful models effectively.