You Do the Math: Fine Tuning Multimodal Models (CLIP) to Match Cartoon Images to Joke Captions

This tutorial shows how to fine-tune multimodal models like CLIP to match images to text captions, using cartoons and their joke captions from The New Yorker caption contest.

  • Dave Berenbaum
  • September 12, 2024 · 8 min read
Hero Picture

“Six rounds. Nine lives. You do the math.” Image and caption from https://www.capcon.dev/ via The New Yorker

Multimodal models like CLIP have opened up new AI use cases by connecting complex objects like images to text descriptions that are easy to understand, generate, and parse. However, off-the-shelf models like CLIP may not be representative of the data typically seen in specific domains, in which case fine-tuning may be needed to adapt the model to that domain.

This post shows how to fine-tune the CLIP model on cartoon images from The New Yorker Magazine and joke captions for those cartoons. It is based on https://www.capcon.dev/, a dataset for various tasks associated with The New Yorker’s cartoon contest. One of the tasks is to take a cartoon image and predict the appropriate caption from a list of possible captions. Let’s see how we can fine-tune CLIP for this task.

Data

The data is hosted and publicly available at gs://datachain-demo/newyorker_caption_contest, and it has two parts:

  • images: A folder of JPEG files, each representing a cartoon image.
  • new_yorker_meta.parquet: A parquet file with metadata about the images, including multiple choices of captions for the image and the correct caption choice.

To work with this data, we will use the open-source library datachain, which helps wrangle unstructured data like this into a more structured format (disclaimer: I helped develop datachain). All of the code used in this post is available in a Jupyter notebook on GitHub, or you can run it in Colab.

To start, we read both the images and metadata from their sources and then join them according to filename (which is available as a column in the metadata):

from datachain import C, DataChain
from datachain.sql.functions import path

img_dc = DataChain.from_storage("gs://datachain-demo/newyorker_caption_contest/images", type="image", anon=True)
meta_dc = DataChain.from_parquet("gs://datachain-demo/newyorker_caption_contest/new_yorker_meta.parquet")
dc = img_dc.mutate(filename=path.name(C("file.path"))).merge(meta_dc, on="filename")

The code first creates a dataset img_dc from images in a directory, storing the essential information about each file, which we will use later to read the images. Then, it creates a dataset meta_dc from the parquet file of metadata. Finally, it merges these two based on the image filename. img_dc contains a column file.path with the full path to the file, and img_dc.mutate(filename=path.name(C("file.path"))) extracts only the last part of that path, which matches the contents of the filename column in meta_dc. The merged dc dataset has both the file info and metadata for each image.

We can view a sample of the data by filtering and collecting the data like this:

sample = dc.filter(C("file.path").endswith("/371.jpeg")).limit(1)
sample_results = list(sample.collect("file", "caption_choices", "label"))

This limits the data to the image ending in /371.jpeg and collects only the "file", "caption_choices", and "label" columns. The resulting output includes an ImageFile (see below), a list of possible captions, and a label for the letter choice of the correct caption. You may end up with slightly different results, since there are multiple rows per image with different caption choices.

[(ImageFile(source='gs://datachain-demo', path='newyorker_caption_contest/images/371.jpeg', size=25555, version='1719848719616822', etag='CLaWgOCXhocDEAE=', is_latest=True, last_modified=datetime.datetime(2024, 7, 1, 15, 45, 19, 669000, tzinfo=datetime.timezone.utc), location=None, vtype=''),
  ["I feel like we've gotten a little soft, Lex.",
   "Hold on, the Senate Committee on Women's Health is getting out.",
   "I know a specialist, but he's in prison.",
   'Six rounds. Nine lives. You do the math.',
   'Growth has exceeded our projections.'],
  'D')]

We can get the image itself from the ImageFile object using its read() method:

example = sample_results[0]
example[0].read()

Example Image from https://www.capcon.dev/ via The New Yorker

In this sample, we have a cartoon of a mouse pointing a gun at a cat, with the correct caption being option D, which reads, “Six rounds. Nine lives. You do the math.”

Applying a base CLIP model

We can apply CLIP to this data to predict the likelihood of each caption. This task closely mirrors the base architecture of CLIP, which uses contrastive learning to take an image and discern the most likely caption from a batch of text captions (and vice versa). During training, CLIP gets batches of image-text pairs as input, with each image mapped to its text caption. For each batch, CLIP calculates the cosine similarity of each image to each text in the batch, so that it has not only the similarities of the matching pairs but also of every mismatched image-text pair (see the images below). It then treats this as a classification problem where the match is the correct label and the mismatches are the incorrect labels. During inference, the model can be used as a zero-shot predictor by feeding in an image and a batch of captions, for which CLIP will return the probability of each caption. For a deeper dive into CLIP, see the original OpenAI post about it, or Chip Huyen's nice summary of how it works here.

Overview A: Image from https://openai.com/index/clip/
Overview B: Image from https://openai.com/index/clip/
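
To make the contrastive objective concrete, here is a minimal sketch of the symmetric loss CLIP optimizes during pre-training. The embeddings and logit scale below are stand-ins (random tensors and a rough constant), not values from the actual model:

import torch
import torch.nn.functional as F

# Stand-ins for a batch of N image and text embeddings from CLIP's two encoders.
N, dim = 4, 512
image_emb = F.normalize(torch.randn(N, dim), dim=-1)
text_emb = F.normalize(torch.randn(N, dim), dim=-1)

# Cosine similarity of every image to every text, scaled by a (learned) temperature.
logit_scale = 100.0
logits_per_image = logit_scale * image_emb @ text_emb.t()
logits_per_text = logits_per_image.t()

# Matching pairs sit on the diagonal, so the target class for row i is i.
labels = torch.arange(N)
loss = (F.cross_entropy(logits_per_image, labels) +
        F.cross_entropy(logits_per_text, labels)) / 2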

For the cartoon dataset, we can feed in our sample image and the caption choices to get back the probability that each is a correct match. Here’s how that looks in code:

import clip
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = example[0].read()
image = preprocess(image).unsqueeze(0).to(device)
text = clip.tokenize(example[1]).to(device)
logits_per_image, logits_per_text = model(image, text)
logits_per_image.softmax(dim=1)[0]

First, we load the pre-trained ViT-B/32 model onto the device, along with its image preprocessor. Then, we transform the image into the expected tensor input and tokenize the text captions. Next, we run the model on those transformed inputs to get the logit similarity scores of the image to each text, and finally apply a softmax to get a relative probability for each caption.

The output shows that CLIP can already confidently predict the correct caption for this example, since caption D (the fourth caption) has a probability of 0.9844 (if you try this yourself, you may have different caption choices in your example, which may lead to different results):

tensor([0.0047, 0.0013, 0.0029, 0.9844, 0.0067], grad_fn=<SelectBackward0>)

Creating a training dataset

Now that we know how to apply CLIP to predict captions, we can build a training dataset for fine-tuning the model. Let's get the similarity scores for 10 random images (you can increase the sample size, but we will keep it small here so it's easy to follow along quickly on a laptop CPU). Here's the code to do that:

from datachain.torch import clip_similarity_scores

train_dc = dc.shuffle().limit(10).save("newyorker_caption_contest_train")
train_dc = train_dc.map(
    func=lambda img_file, txt: clip_similarity_scores(img_file.read(), txt, model, preprocess, clip.tokenize, prob=True)[0],
    params=["file", "caption_choices"],
    output={"scores": list[float]}
)

First, we shuffle the dataset and save a sample of 10 images. Then, we use the map() method to apply a function to each record and save the result as a new column. We use the utility function clip_similarity_scores, which performs the steps from the previous section in one line to get the caption probabilities. The input to the map() function is defined by params=["file", "caption_choices"], and the output column is defined by output={"scores": list[float]}.
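
If you want to spot-check the new column, you can collect the computed scores for a few records (an optional check, reusing the limit() and collect() pattern from earlier):

for scores in train_dc.limit(3).collect("scores"):
    print(scores)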

For training, we also need the ground truth of the correct captions, so we again use map() to calculate the index of the correct caption for each record, along with the CLIP probability of that caption so we can see how well baseline CLIP is performing:

import string

def label_ind(label):
    return string.ascii_uppercase.index(label)

def label_prob(scores, label_ind):
    return scores[label_ind]

train_dc = (
    train_dc.map(label_ind, params=["label"], output={"label_ind": int})
    .map(label_prob, params=["scores", "label_ind"], output={"label_prob": float})
)
train_dc = train_dc.save()

We can run train_dc.avg("label_prob") to get the average probability of the correct caption for the training sample. The average will depend on the random samples in your training dataset, but you should see a much lower value than for the sample image above, so it seems the other images are not as easy for baseline CLIP to predict correctly.

Fine-tuning

To fine-tune CLIP, we need to create a train() function to loop over the training data and update the model:

def train(loader, model, optimizer, epochs=5):
    # CLIP loads in half precision on a GPU; convert to full precision for stable training.
    if device == "cuda":
        model = model.float()
    loss_func = torch.nn.CrossEntropyLoss()

    for epoch in range(epochs):
        total_loss = 0
        for images, texts, labels in loader:
            optimizer.zero_grad()
            batch_loss = 0
            # Score each image against its own candidate captions.
            for image, text, label in zip(images, texts, labels):
                image = image.to(device).unsqueeze(0)
                text = text.to(device)
                label = label.to(device).unsqueeze(0)

                logits_per_image, logits_per_text = model(image, text)

                # The index of the correct caption serves as the classification target.
                batch_loss += loss_func(logits_per_image, label)
            batch_loss.backward()
            optimizer.step()
            batch_loss = batch_loss.item()
            total_loss += batch_loss
        print(f"loss for epoch {epoch}: {total_loss}")

For each pairing of image to text captions, the function calculates the logit similarity scores, uses the correct label index to apply the loss function, and performs a backward pass to update the model.

This is very similar to how base CLIP works, with one difference. Base CLIP expects each batch to contain image-text pairs, where each image has a single corresponding text, and CLIP must get incorrect texts from the other samples in the batch for contrastive learning (see the images above). With the cartoon dataset, each image already comes with not only the correct text caption but also multiple incorrect text captions. Therefore, instead of relying on the other samples in the batch for contrastive learning, the function above relies only on the caption choices provided for that image.
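
To make that difference concrete, here is a minimal sketch using random logits as stand-ins for the model's outputs: standard CLIP derives its targets from the batch structure, while our train() function uses the label index that comes with each cartoon:

import torch
import torch.nn.functional as F

# Standard CLIP: a batch of N image-text pairs; the match for image i is text i,
# so the targets are 0..N-1 and the negatives come from the rest of the batch.
N = 2
logits_pairs = torch.randn(N, N)      # image-to-text similarity matrix for the batch
loss_pairs = F.cross_entropy(logits_pairs, torch.arange(N))

# Cartoon dataset: one image scored against its own candidate captions,
# with the provided label index (e.g. 3 for caption "D") as the target.
logits_choices = torch.randn(1, 5)    # similarities of one image to its 5 captions
loss_choices = F.cross_entropy(logits_choices, torch.tensor([3]))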

To feed the training data into this function, we need to generate a PyTorch dataset and data loader and pass the loader to the train() function along with an optimizer:

from torch.utils.data import DataLoader

ds = train_dc.select("file", "caption_choices", "label_ind").to_pytorch(
    transform=preprocess,
    tokenizer=clip.tokenize,
)

loader = DataLoader(ds, batch_size=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
train(loader, model, optimizer)

The code above selects the columns needed for training ("file", "caption_choices", "label_ind"), and then calls to_pytorch() with the CLIP preprocessor and tokenizer, which returns a PyTorch IterableDataset with the preprocessed image tensors, tokenized text, and label indices. Next, the code creates a PyTorch DataLoader and optimizer and passes them to train() to start training.
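
If you are curious what the loader yields, you can pull a single batch and inspect its shapes, either before or after training; this is a quick, optional sanity check:

images, texts, labels = next(iter(loader))
print(images.shape, texts.shape, labels.shape)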

Since we are using a tiny dataset, we can quickly see the model fit to the sample as the loss decreases dramatically:

loss for epoch 0: 5.243085099384018
loss for epoch 1: 6.937912189641793e-05
loss for epoch 2: 0.0006402461804100312
loss for epoch 3: 0.0009484810252615716
loss for epoch 4: 0.00019728825191123178

This should set off alarm bells about overfitting, but for this exercise, it’s useful to see that train() is doing what we expect: learning the correct captions from the training dataset. We can confirm by calculating the predicted probability of the correct caption for each image in the training data using the fine-tuned model:

train_dc = train_dc.map(
    func=lambda img_file, txt: clip_similarity_scores(img_file.read(), txt, model, preprocess, clip.tokenize, prob=True)[0],
    params=["file", "caption_choices"],
    output={"scores_fine_tune": list[float]}
)

train_dc = train_dc.map(label_prob, params=["scores_fine_tune", "label_ind"], output={"label_prob_fine_tune": float})

This is the same code used earlier to calculate the probabilities, now run with the fine-tuned model and saved to new columns. Running train_dc.avg("label_prob_fine_tune") outputs an average predicted probability >0.99, so it looks like the fine-tuning worked as expected.
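
To see the before and after numbers side by side (values will vary with your random sample):

print("baseline CLIP:   ", train_dc.avg("label_prob"))
print("fine-tuned CLIP: ", train_dc.avg("label_prob_fine_tune"))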

This is an artificial example, but it hopefully gives you an idea of how to fine-tune CLIP. To solve the task of predicting the correct caption in a more robust way, you would want to take a much larger sample and evaluate against a held-out sample of images and texts that weren't seen during training. When trying that, you may find that CLIP does not generalize so well to the caption prediction problem, which should not be too surprising, since CLIP was built to understand the contents of images rather than to understand jokes. CLIP relies on a relatively simple text encoder, and it may be worth exploring different text encoders for that task. That goes beyond fine-tuning and this post, but now that you know how to train CLIP, you can try out this idea or bring your own ideas for how to adapt CLIP to your multimodal use cases.
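
As a starting point for that kind of evaluation, here is a rough sketch that reuses the same building blocks from above (shuffle(), limit(), map(), and avg()); the dataset name is arbitrary, and in practice you would also need to make sure the evaluation images were excluded from the training sample:

# Draw a separate random sample to act as a held-out evaluation set.
eval_dc = dc.shuffle().limit(100).save("newyorker_caption_contest_eval")

# Score the held-out images with the fine-tuned model, exactly as before.
eval_dc = eval_dc.map(
    func=lambda img_file, txt: clip_similarity_scores(img_file.read(), txt, model, preprocess, clip.tokenize, prob=True)[0],
    params=["file", "caption_choices"],
    output={"scores": list[float]}
)

eval_dc = (
    eval_dc.map(label_ind, params=["label"], output={"label_ind": int})
    .map(label_prob, params=["scores", "label_ind"], output={"label_prob": float})
)

print("held-out average probability of correct caption:", eval_dc.avg("label_prob"))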
