Classification

In this notebook we’ll demonstrate training a PyTorch ResNet-50 classifier with Limbo data to determine whether an image contains a type 48 cylinder or not. We’ll use data from Campaign 17, which contains images of type 48G, 48X, and 48Y cylinders scattered among distractors in a synthetic outdoor industrial environment.

We will evaluate our model using synthetic data (a case we refer to as “train synthetic, test synthetic”) to confirm that it is learning, and using real-world images (a case we refer to as “train synthetic, test real”) to find out how well it would work in reality. Note that the latter is extremely challenging for models, and a ripe area for new research, which is why we created the Limbo data in the first place!

To begin, we’ll define a variable containing the path to the data:

[1]:
DATA_ROOT = "/mnt/mc1/limbo"

… if you’re running this notebook yourself, you’ll need to set DATA_ROOT to point to a directory containing Limbo campaign data that you’ve downloaded.
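If you share the notebook between machines, one option is to read the path from the environment instead of editing the cell. The following is a minimal sketch; the variable name LIMBO_DATA_ROOT and the helper resolve_data_root are hypothetical names chosen for this illustration, not part of Limbo:

```python
import os

# Fallback path taken from the cell above.
DEFAULT_DATA_ROOT = "/mnt/mc1/limbo"

def resolve_data_root(environ=os.environ):
    """Return the data directory, preferring an environment override.

    LIMBO_DATA_ROOT is a hypothetical variable name used only in this sketch.
    """
    return environ.get("LIMBO_DATA_ROOT", DEFAULT_DATA_ROOT)

DATA_ROOT = resolve_data_root()

# Fail early with a readable message rather than deep inside a data loader.
if not os.path.isdir(DATA_ROOT):
    print(f"Warning: {DATA_ROOT} does not exist; download the Limbo data first.")
```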

Every campaign in the Limbo data is split into numbered subdirectories, each containing one thousand images. We will load one subdirectory of synthetic images for training, a second subdirectory of synthetic images for (synthetic) testing, and all of the available real-world data for (real) testing:

[2]:
import os

import limbo.data

training_data = limbo.data.Dataset([os.path.join(DATA_ROOT, "campaign17", "0000")])
test_data = limbo.data.Dataset([os.path.join(DATA_ROOT, "campaign17", "0001")])
real_data = limbo.data.Dataset([os.path.join(DATA_ROOT, "ref")])

print(f"Loaded {len(training_data)} training images.")
print(f"Loaded {len(test_data)} test images.")
print(f"Loaded {len(real_data)} real images.")
Loaded 1000 training images.
Loaded 1000 test images.
Loaded 482 real images.

Tip

One thousand images for training is just enough to prove that our model works for purposes of pedagogy. In practice, we assume that you’ll train with more, since most campaigns in the Limbo data contain 50,000 images each. Note that you can pass more than one subdirectory to the Dataset initializer, or you can pass the top-level campaign directory to load everything.
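As a rough sketch of loading several subdirectories at once, the list of paths can be built programmatically. This assumes that the subdirectories are named with zero-padded four-digit numbers, as "0000" and "0001" above suggest; the helper campaign_subdirs is a hypothetical name introduced here:

```python
import os

DATA_ROOT = "/mnt/mc1/limbo"

def campaign_subdirs(root, campaign, start, stop):
    """Build paths for the numbered subdirectories start..stop-1 of a campaign.

    Assumes zero-padded, four-digit directory names such as "0000".
    """
    return [os.path.join(root, campaign, f"{index:04d}") for index in range(start, stop)]

# Five subdirectories, i.e. five thousand images at one thousand images each.
training_paths = campaign_subdirs(DATA_ROOT, "campaign17", 0, 5)
print(training_paths)

# The list could then be passed to the initializer used earlier:
# training_data = limbo.data.Dataset(training_paths)
```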

Next, we’ll display some images to be sure we have the right datasets:

[3]:
from IPython.display import display, HTML

display(HTML("<h3>Training sample</h3>"))
display(training_data[0].image)

display(HTML("<h3>Test sample</h3>"))
display(test_data[0].image)

display(HTML("<h3>Real sample</h3>"))
display(real_data[0].image)

Training sample

../_images/user-guide_classification_5_1.png

Test sample

../_images/user-guide_classification_5_3.png

Real sample

../_images/user-guide_classification_5_5.png

… looking good!

To train our classifier to label each image as “48” or “not”, we’ll need to generate a corresponding set of labels. The samples in a Limbo dataset could contain dozens or even hundreds of annotations, so we’ll have to reduce that to a single label, which will be “1” for images that contain type 48 containers, and “0” for images that don’t.

To do so, we’ll scan the “categories” property of each sample (a set containing every unique category that appears in that sample) to identify which ones include type 48 containers. Note that the following will probably take several minutes to execute, since it loads and parses a JSON file from disk for each sample. As you begin working with larger subsets of the data, you may wish to cache intermediate results like these to disk to speed up your workflow:

[4]:
import numpy

targets = {"48G", "48X", "48Y"}

training_labels = [1 if sample.categories & targets else 0 for sample in training_data]
test_labels = [1 if sample.categories & targets else 0 for sample in test_data]
real_labels = [1 if sample.categories & targets else 0 for sample in real_data]
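Because the scan is slow, it may be worth storing the resulting labels on disk. The following is a minimal sketch of one way to do that with a JSON file; the helper cached_labels, the file name, and the stand-in slow_scan function are all hypothetical and only illustrate the idea:

```python
import json
import os
import tempfile

def cached_labels(cache_path, compute):
    """Load labels from cache_path if it exists, else compute and store them."""
    if os.path.exists(cache_path):
        with open(cache_path) as stream:
            return json.load(stream)
    labels = [int(label) for label in compute()]
    with open(cache_path, "w") as stream:
        json.dump(labels, stream)
    return labels

# Stand-in for the slow per-sample scan. With real data this could be, e.g.,
# lambda: [1 if sample.categories & targets else 0 for sample in training_data]
calls = []
def slow_scan():
    calls.append(1)
    return [1, 0, 1, 1]

with tempfile.TemporaryDirectory() as directory:
    path = os.path.join(directory, "training_labels.json")
    first = cached_labels(path, slow_scan)   # computes and writes the cache
    second = cached_labels(path, slow_scan)  # reads the cache, no recomputation

print(first, second, len(calls))
```

Remember to delete the cache file whenever the set of target categories or the loaded subdirectories change, since the sketch does not detect stale results.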

Since there are three styles of type 48 container in the dataset, we use Python set intersection to test whether a sample contains any of the three.
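To make the set-intersection test concrete, here is a small standalone sketch. The category names other than 48G, 48X, and 48Y are made up for illustration, and contains_target is a hypothetical helper:

```python
targets = {"48G", "48X", "48Y"}

def contains_target(categories, targets=targets):
    """Return 1 if any target category is present in the sample, else 0."""
    # An empty intersection is falsy; a non-empty one is truthy.
    return 1 if categories & targets else 0

# "barrel" and "crate" are invented distractor names, not Limbo categories.
print(contains_target({"48Y", "barrel"}))
print(contains_target({"barrel", "crate"}))
print(contains_target(set()))
```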

It’s always a good idea to see how skewed your data is before you go too far, so let’s figure out how many images contain our objects of interest:

[5]:
print(f"{numpy.sum(training_labels) / len(training_labels):.1%} of training images have type 48 containers.")
print(f"{numpy.sum(test_labels) / len(test_labels):.1%} of test images have type 48 containers.")
print(f"{numpy.sum(real_labels) / len(real_labels):.1%} of real images have type 48 containers.")
60.6% of training images have type 48 containers.
60.1% of test images have type 48 containers.
19.3% of real images have type 48 containers.

An even mix of container and not-container images for training would be ideal, but ~60% containers isn’t too bad, so we’ll live with it for this demo. When conducting your own experiments, you will typically want to load more images than you need, downsampling based on the labels to achieve your desired balance between positive and negative examples.
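The downsampling mentioned above could look something like the following sketch, which keeps every example of the rarer class and an equal-sized random subset of the other. The function balanced_indices and its seed parameter are hypothetical names for this illustration:

```python
import random

def balanced_indices(labels, seed=0):
    """Return shuffled indices with equal numbers of positive and negative labels."""
    rng = random.Random(seed)
    positives = [i for i, label in enumerate(labels) if label == 1]
    negatives = [i for i, label in enumerate(labels) if label == 0]
    # Downsample the larger class to the size of the smaller one.
    count = min(len(positives), len(negatives))
    keep = rng.sample(positives, count) + rng.sample(negatives, count)
    rng.shuffle(keep)
    return keep

# Toy labels with roughly the same 60/40 skew as the training data above.
labels = [1] * 6 + [0] * 4
indices = balanced_indices(labels)
print(sorted(indices))
print(sum(labels[i] for i in indices) / len(indices))
```

The returned indices could then be used to select samples, for example with torch.utils.data.Subset once the data has been wrapped for PyTorch.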

Next, we need to wrap our Limbo dataset object to make it usable with PyTorch. We’ll create a simple adapter that loads Limbo training images on demand, resizes them to \(224\times224\), and caches them in memory to speed up training, since our data is relatively small:

[6]:
import torch.utils.data
import torchvision  # used by LimboToTorch.__getitem__ below

class LimboToTorch(torch.utils.data.Dataset):
    def __init__(self, dataset, labels, transforms):
        self.dataset = dataset
        self.labels = numpy.expand_dims(labels, 1).astype(numpy.float32)
        self.transforms = transforms
        self.cache = {}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, key):
        if key not in self.cache:
            sample = self.dataset[key]
            image = sample.image
            if "C" in image.layers:
                image = (image.layers["C"].data * 255).astype(numpy.uint8)
            elif "Y" in image.layers:
                image = numpy.tile(image.layers["Y"].data * 255, (1, 1, 3)).astype(numpy.uint8)
            image = torchvision.transforms.functional.to_pil_image(image, mode="RGB")
            image = torchvision.transforms.functional.resize(image, (224, 224))
            self.cache[key] = image
        return self.transforms(self.cache[key]), self.labels[key]

Next, we’ll define a set of data augmentation transforms to apply to the images for training, and a much simpler set of transforms for evaluation:

[7]:
import torchvision

training_transforms = torchvision.transforms.Compose([
    torchvision.transforms.RandomAffine(degrees=90, scale=(0.8, 1.2), shear=20, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, fill=(int(255*.485), int(255*.456), int(255*.406))),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

evaluation_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
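The fill color passed to RandomAffine is not arbitrary: it appears to be the normalization mean scaled to 8-bit values, so that the borders exposed by rotation and shear end up close to zero after normalization. The following sketch checks that arithmetic by hand, assuming the usual ToTensor behavior of rescaling 8-bit values to the range 0 to 1:

```python
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Fill color used by RandomAffine above: the channel means scaled to 0..255.
fill = [int(255 * m) for m in mean]

# ToTensor rescales 8-bit values to 0..1, then Normalize applies (x - mean) / std.
normalized_fill = [(f / 255 - m) / s for f, m, s in zip(fill, mean, std)]

print(fill)
print([round(value, 3) for value in normalized_fill])
```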

Now we can wrap our Limbo datasets for use with PyTorch:

[8]:
training_data = LimboToTorch(training_data, training_labels, training_transforms)
test_data = LimboToTorch(test_data, test_labels, evaluation_transforms)
real_data = LimboToTorch(real_data, real_labels, evaluation_transforms)

We will hold back 20% of our training data for validation:

[9]:
training_data, validation_data = torch.utils.data.random_split(training_data, [0.8, 0.2])

And create data loaders that can be used to iterate over the data in batches of 50 images:

[10]:
training_loader = torch.utils.data.DataLoader(training_data, batch_size=50, shuffle=True, num_workers=0)
validation_loader = torch.utils.data.DataLoader(validation_data, batch_size=50, shuffle=False, num_workers=0)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=50, shuffle=False, num_workers=0)
real_loader = torch.utils.data.DataLoader(real_data, batch_size=50, shuffle=False, num_workers=0)

Our model will be based on a ResNet-50 model pre-trained on ImageNet. Since our problem only requires a single “48 or not” output, we will replace the model’s final, fully-connected output layer with one that produces a single sigmoid output. This output can be thought of as representing the model’s confidence that an image contains a type 48 container, and nicely coincides with the “1” and “0” labels we created earlier.

[11]:
model = torchvision.models.resnet50(weights="IMAGENET1K_V2")

model.fc = torch.nn.Sequential(
    torch.nn.Linear(2048, 1),
    torch.nn.Sigmoid(),
)
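To turn the sigmoid output into a hard “48 or not” decision, the confidence has to be compared against a threshold. The sketch below uses 0.5 and computes accuracy, precision, and recall by hand from the resulting decisions. The function binary_metrics and the toy scores are hypothetical, and returning 0.0 when a ratio is undefined is a choice made for this sketch only:

```python
def binary_metrics(y_true, y_score, threshold=0.5):
    """Return (accuracy, precision, recall) for thresholded confidence scores."""
    y_pred = [1 if score > threshold else 0 for score in y_score]
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    accuracy = (tp + tn) / len(y_true)
    # Fall back to 0.0 when the denominator is zero (a choice of this sketch).
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return accuracy, precision, recall

# Toy labels and made-up model confidences.
y_true = [1, 1, 0, 0, 1]
y_score = [0.9, 0.4, 0.6, 0.2, 0.7]
accuracy, precision, recall = binary_metrics(y_true, y_score)
print(accuracy, precision, recall)
```

Note that a model which answers “48” for every image reaches a recall of 1.0 while its accuracy and precision both equal the fraction of positive images, which is worth keeping in mind when reading metrics for skewed data.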

We will train using NVIDIA GPU hardware; if you don’t have a GPU, you can change the following to the string "cpu", but your training times will be very long:

[12]:
DEVICE = "cuda"

With DEVICE set, we can prepare our model for use on the hardware, and set up our loss function and optimizer:

[13]:
model = model.to(DEVICE)
loss_function = torch.nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
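For intuition about the loss values, binary cross-entropy for a single sigmoid output can be written out by hand. The following pure-Python sketch mirrors the averaged form of the loss; the helper names and the toy inputs are hypothetical:

```python
import math

def sigmoid(z):
    """Map a raw model output to a confidence between 0 and 1."""
    return 1.0 / (1.0 + math.exp(-z))

def binary_cross_entropy(y_true, y_prob):
    """Mean binary cross-entropy for labels in {0, 1} and probabilities in (0, 1)."""
    total = 0.0
    for y, p in zip(y_true, y_prob):
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(y_true)

labels = [1, 0, 1]
undecided = [sigmoid(0.0)] * 3  # a confidence of 0.5 for every image
confident = [sigmoid(3.0), sigmoid(-3.0), sigmoid(3.0)]  # correct and confident

print(binary_cross_entropy(labels, undecided))
print(binary_cross_entropy(labels, confident))
```

A model that outputs 0.5 for everything scores ln 2, roughly 0.693, so losses near that value indicate that the model has learned very little so far.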

Now we train! The following code will train and validate the model for 500 epochs using synthetic data, and evaluate it using synthetic and real test data. Various performance metrics are printed to stdout, and retained for later analysis.

Rant

You’ll notice that the following training code is extremely verbose. Machine learning researchers often tie themselves in knots trying to eliminate duplication in their training loops; we like tight loops as much as anyone, but don’t be that person - simplicity, clarity, and reproducibility are more important, even if they come at the expense of some repetition. DRY is good advice, but not if it obscures what you’re doing!

[14]:
import collections

import sklearn.metrics

metrics = collections.defaultdict(list)

epochs = numpy.arange(500)
for epoch in epochs:
    if not epoch % 10 or epoch < 10:
        print("*" * 80)
        print(f"epoch: {epoch}")

    # Train the model
    losses = []
    y_true = []
    y_pred = []
    model.train(True)

    for x, y in training_loader:
        x = x.to(DEVICE)
        y = y.to(DEVICE)
        y_hat = model(x)
        loss = loss_function(y_hat, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        y_true.append(y.cpu().detach())
        y_pred.append(y_hat.cpu().detach())

    # numpy.row_stack is a deprecated alias of numpy.vstack.
    losses = numpy.vstack(losses)
    y_true = numpy.vstack(y_true)
    y_pred = numpy.vstack(y_pred)

    metrics["training/loss"].append(numpy.mean(losses))
    metrics["training/accuracy"].append(sklearn.metrics.accuracy_score(y_true, y_pred > 0.5))
    metrics["training/precision"].append(sklearn.metrics.precision_score(y_true, y_pred > 0.5))
    metrics["training/recall"].append(sklearn.metrics.recall_score(y_true, y_pred > 0.5))

    # Validate the model
    losses = []
    y_true = []
    y_pred = []
    model.train(False)

    # Gradients aren't needed for evaluation, so skip building the autograd graph.
    with torch.no_grad():
        for x, y in validation_loader:
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            y_hat = model(x)
            loss = loss_function(y_hat, y)
            losses.append(loss.item())
            y_true.append(y.cpu().detach())
            y_pred.append(y_hat.cpu().detach())

    losses = numpy.vstack(losses)
    y_true = numpy.vstack(y_true)
    y_pred = numpy.vstack(y_pred)

    metrics["validation/loss"].append(numpy.mean(losses))
    metrics["validation/accuracy"].append(sklearn.metrics.accuracy_score(y_true, y_pred > 0.5))
    metrics["validation/precision"].append(sklearn.metrics.precision_score(y_true, y_pred > 0.5))
    metrics["validation/recall"].append(sklearn.metrics.recall_score(y_true, y_pred > 0.5))

    # Test the model on synthetic data
    losses = []
    y_true = []
    y_pred = []
    model.train(False)

    # Gradients aren't needed for evaluation, so skip building the autograd graph.
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            y_hat = model(x)
            loss = loss_function(y_hat, y)
            losses.append(loss.item())
            y_true.append(y.cpu().detach())
            y_pred.append(y_hat.cpu().detach())

    losses = numpy.vstack(losses)
    y_true = numpy.vstack(y_true)
    y_pred = numpy.vstack(y_pred)

    metrics["test/loss"].append(numpy.mean(losses))
    metrics["test/accuracy"].append(sklearn.metrics.accuracy_score(y_true, y_pred > 0.5))
    metrics["test/precision"].append(sklearn.metrics.precision_score(y_true, y_pred > 0.5))
    metrics["test/recall"].append(sklearn.metrics.recall_score(y_true, y_pred > 0.5))

    # Test the model on real data
    losses = []
    y_true = []
    y_pred = []
    model.train(False)

    # Gradients aren't needed for evaluation, so skip building the autograd graph.
    with torch.no_grad():
        for x, y in real_loader:
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            y_hat = model(x)
            loss = loss_function(y_hat, y)
            losses.append(loss.item())
            y_true.append(y.cpu().detach())
            y_pred.append(y_hat.cpu().detach())

    losses = numpy.vstack(losses)
    y_true = numpy.vstack(y_true)
    y_pred = numpy.vstack(y_pred)

    metrics["real/loss"].append(numpy.mean(losses))
    metrics["real/accuracy"].append(sklearn.metrics.accuracy_score(y_true, y_pred > 0.5))
    metrics["real/precision"].append(sklearn.metrics.precision_score(y_true, y_pred > 0.5))
    metrics["real/recall"].append(sklearn.metrics.recall_score(y_true, y_pred > 0.5))


    if not epoch % 10 or epoch < 10:
        print(f"{'':>16}{'loss':<10}{'accuracy':<10}{'precision':<10}{'recall':<10}")

        print(f"{'training: ':>16}", end="")
        print(f"{metrics['training/loss'][-1]:<10.3f}", end="")
        print(f"{metrics['training/accuracy'][-1]:<10.3f}", end="")
        print(f"{metrics['training/precision'][-1]:<10.3f}", end="")
        print(f"{metrics['training/recall'][-1]:<10}")

        print(f"{'validation: ':>16}", end="")
        print(f"{metrics['validation/loss'][-1]:<10.3f}", end="")
        print(f"{metrics['validation/accuracy'][-1]:<10.3f}", end="")
        print(f"{metrics['validation/precision'][-1]:<10.3f}", end="")
        print(f"{metrics['validation/recall'][-1]:<10}")

        print(f"{'test: ':>16}", end="")
        print(f"{metrics['test/loss'][-1]:<10.3f}", end="")
        print(f"{metrics['test/accuracy'][-1]:<10.3f}", end="")
        print(f"{metrics['test/precision'][-1]:<10.3f}", end="")
        print(f"{metrics['test/recall'][-1]:<10}")

        print(f"{'real: ':>16}", end="")
        print(f"{metrics['real/loss'][-1]:<10.3f}", end="")
        print(f"{metrics['real/accuracy'][-1]:<10.3f}", end="")
        print(f"{metrics['real/precision'][-1]:<10.3f}", end="")
        print(f"{metrics['real/recall'][-1]:<10}")
********************************************************************************
epoch: 0
                loss      accuracy  precision recall
      training: 0.692     0.520     0.589     0.7002053388090349
    validation: 0.684     0.540     0.589     0.7478991596638656
          test: 0.689     0.536     0.626     0.56738768718802
          real: 0.727     0.297     0.147     0.5483870967741935
********************************************************************************
epoch: 1
                loss      accuracy  precision recall
      training: 0.686     0.571     0.607     0.8418891170431212
    validation: 0.684     0.590     0.613     0.8403361344537815
          test: 0.681     0.588     0.626     0.7803660565723793
          real: 0.750     0.249     0.188     0.8709677419354839
********************************************************************************
epoch: 2
                loss      accuracy  precision recall
      training: 0.680     0.585     0.608     0.893223819301848
    validation: 0.680     0.575     0.592     0.9159663865546218
          test: 0.676     0.625     0.631     0.9068219633943427
          real: 0.765     0.207     0.184     0.9032258064516129
********************************************************************************
epoch: 3
                loss      accuracy  precision recall
      training: 0.676     0.615     0.619     0.9548254620123203
    validation: 0.679     0.600     0.604     0.9495798319327731
          test: 0.673     0.626     0.627     0.9351081530782029
          real: 0.778     0.203     0.188     0.946236559139785
********************************************************************************
epoch: 4
                loss      accuracy  precision recall
      training: 0.676     0.604     0.611     0.9589322381930184
    validation: 0.676     0.605     0.604     0.9747899159663865
          test: 0.670     0.628     0.624     0.9584026622296173
          real: 0.790     0.189     0.187     0.956989247311828
********************************************************************************
epoch: 5
                loss      accuracy  precision recall
      training: 0.671     0.608     0.610     0.9835728952772074
    validation: 0.668     0.605     0.601     1.0
          test: 0.667     0.624     0.618     0.978369384359401
          real: 0.802     0.193     0.192     0.989247311827957
********************************************************************************
epoch: 6
                loss      accuracy  precision recall
      training: 0.671     0.610     0.611     0.9917864476386037
    validation: 0.674     0.600     0.600     0.9831932773109243
          test: 0.665     0.621     0.616     0.9833610648918469
          real: 0.810     0.193     0.193     1.0
********************************************************************************
epoch: 7
                loss      accuracy  precision recall
      training: 0.666     0.609     0.610     0.9917864476386037
    validation: 0.666     0.605     0.601     1.0
          test: 0.662     0.623     0.615     0.9966722129783694
          real: 0.821     0.193     0.193     1.0
********************************************************************************
epoch: 8
                loss      accuracy  precision recall
      training: 0.666     0.611     0.611     0.9938398357289527
    validation: 0.668     0.600     0.598     1.0
          test: 0.660     0.625     0.616     0.9966722129783694
          real: 0.826     0.193     0.193     1.0
********************************************************************************
epoch: 9
                loss      accuracy  precision recall
      training: 0.661     0.611     0.611     0.9958932238193019
    validation: 0.669     0.605     0.601     1.0
          test: 0.658     0.621     0.613     0.9983361064891847
          real: 0.834     0.193     0.193     1.0
********************************************************************************
epoch: 10
                loss      accuracy  precision recall
      training: 0.662     0.615     0.613     0.997946611909651
    validation: 0.667     0.605     0.601     1.0
          test: 0.656     0.619     0.612     0.9983361064891847
          real: 0.847     0.193     0.193     1.0
********************************************************************************
epoch: 20
                loss      accuracy  precision recall
      training: 0.647     0.610     0.610     1.0
    validation: 0.652     0.600     0.598     1.0
          test: 0.644     0.614     0.609     1.0
          real: 0.907     0.193     0.193     1.0
********************************************************************************
epoch: 30
                loss      accuracy  precision recall
      training: 0.641     0.615     0.613     1.0
    validation: 0.645     0.605     0.601     1.0
          test: 0.636     0.620     0.613     1.0
          real: 0.945     0.193     0.193     1.0
********************************************************************************
epoch: 40
                loss      accuracy  precision recall
      training: 0.632     0.626     0.620     0.997946611909651
    validation: 0.635     0.620     0.611     0.9915966386554622
          test: 0.627     0.629     0.619     0.9966722129783694
          real: 0.968     0.193     0.193     1.0
********************************************************************************
epoch: 50
                loss      accuracy  precision recall
      training: 0.623     0.636     0.627     0.9938398357289527
    validation: 0.633     0.635     0.624     0.9747899159663865
          test: 0.620     0.641     0.627     0.9916805324459235
          real: 0.990     0.193     0.193     1.0
********************************************************************************
epoch: 60
                loss      accuracy  precision recall
      training: 0.612     0.662     0.645     0.9897330595482546
    validation: 0.625     0.635     0.624     0.9747899159663865
          test: 0.612     0.640     0.628     0.9866888519134775
          real: 0.998     0.193     0.193     1.0
********************************************************************************
epoch: 70
                loss      accuracy  precision recall
      training: 0.604     0.672     0.653     0.9835728952772074
    validation: 0.625     0.635     0.625     0.9663865546218487
          test: 0.604     0.653     0.637     0.9833610648918469
          real: 1.011     0.193     0.193     1.0
********************************************************************************
epoch: 80
                loss      accuracy  precision recall
      training: 0.591     0.688     0.666     0.9753593429158111
    validation: 0.615     0.640     0.630     0.957983193277311
          test: 0.596     0.659     0.642     0.978369384359401
          real: 1.032     0.193     0.193     1.0
********************************************************************************
epoch: 90
                loss      accuracy  precision recall
      training: 0.587     0.693     0.674     0.9568788501026694
    validation: 0.609     0.665     0.648     0.957983193277311
          test: 0.588     0.671     0.651     0.9750415973377704
          real: 1.033     0.193     0.193     1.0
********************************************************************************
epoch: 100
                loss      accuracy  precision recall
      training: 0.575     0.696     0.677     0.9589322381930184
    validation: 0.591     0.670     0.650     0.9663865546218487
          test: 0.583     0.675     0.655     0.9683860232945092
          real: 1.041     0.193     0.193     1.0
********************************************************************************
epoch: 110
                loss      accuracy  precision recall
      training: 0.562     0.726     0.701     0.9609856262833676
    validation: 0.571     0.690     0.669     0.9495798319327731
          test: 0.570     0.690     0.667     0.9650582362728786
          real: 1.058     0.195     0.193     1.0
********************************************************************************
epoch: 120
                loss      accuracy  precision recall
      training: 0.557     0.719     0.703     0.9322381930184805
    validation: 0.580     0.665     0.662     0.8907563025210085
          test: 0.564     0.692     0.670     0.9600665557404326
          real: 1.071     0.195     0.193     1.0
********************************************************************************
epoch: 130
                loss      accuracy  precision recall
      training: 0.551     0.720     0.706     0.9240246406570842
    validation: 0.571     0.675     0.657     0.9495798319327731
          test: 0.557     0.693     0.672     0.9550748752079867
          real: 1.078     0.195     0.193     1.0
********************************************************************************
epoch: 140
                loss      accuracy  precision recall
      training: 0.538     0.723     0.712     0.9137577002053389
    validation: 0.570     0.670     0.667     0.8907563025210085
          test: 0.548     0.697     0.677     0.9484193011647255
          real: 1.085     0.197     0.194     1.0
********************************************************************************
epoch: 150
                loss      accuracy  precision recall
      training: 0.531     0.743     0.731     0.9117043121149897
    validation: 0.563     0.660     0.665     0.865546218487395
          test: 0.542     0.702     0.681     0.9467554076539102
          real: 1.080     0.203     0.194     0.989247311827957
********************************************************************************
epoch: 160
                loss      accuracy  precision recall
      training: 0.521     0.738     0.726     0.9137577002053389
    validation: 0.557     0.685     0.677     0.8991596638655462
          test: 0.533     0.712     0.692     0.940099833610649
          real: 1.094     0.205     0.195     1.0
********************************************************************************
epoch: 170
                loss      accuracy  precision recall
      training: 0.517     0.746     0.738     0.9034907597535934
    validation: 0.564     0.680     0.680     0.8739495798319328
          test: 0.526     0.716     0.694     0.9450915141430949
          real: 1.104     0.207     0.196     1.0
********************************************************************************
epoch: 180
                loss      accuracy  precision recall
      training: 0.507     0.748     0.744     0.891170431211499
    validation: 0.546     0.685     0.692     0.8487394957983193
          test: 0.517     0.727     0.709     0.9251247920133111
          real: 1.102     0.212     0.197     1.0
********************************************************************************
epoch: 190
                loss      accuracy  precision recall
      training: 0.492     0.766     0.756     0.9096509240246407
    validation: 0.547     0.690     0.691     0.865546218487395
          test: 0.516     0.724     0.701     0.9417637271214643
          real: 1.110     0.212     0.197     1.0
********************************************************************************
epoch: 200
                loss      accuracy  precision recall
      training: 0.480     0.771     0.769     0.893223819301848
    validation: 0.550     0.705     0.708     0.8571428571428571
          test: 0.510     0.730     0.708     0.9367720465890182
          real: 1.122     0.216     0.197     1.0
********************************************************************************
epoch: 210
                loss      accuracy  precision recall
      training: 0.474     0.776     0.781     0.8788501026694046
    validation: 0.535     0.670     0.680     0.8403361344537815
          test: 0.502     0.737     0.716     0.9334442595673876
          real: 1.133     0.218     0.197     0.989247311827957
********************************************************************************
epoch: 220
                loss      accuracy  precision recall
      training: 0.467     0.771     0.774     0.8809034907597536
    validation: 0.541     0.705     0.711     0.8487394957983193
          test: 0.496     0.743     0.724     0.9251247920133111
          real: 1.146     0.216     0.196     0.989247311827957
********************************************************************************
epoch: 230
                loss      accuracy  precision recall
      training: 0.456     0.791     0.792     0.891170431211499
    validation: 0.520     0.715     0.715     0.865546218487395
          test: 0.487     0.748     0.729     0.9234608985024958
          real: 1.150     0.222     0.199     1.0
********************************************************************************
epoch: 240
                loss      accuracy  precision recall
      training: 0.446     0.794     0.795     0.891170431211499
    validation: 0.530     0.695     0.713     0.8151260504201681
          test: 0.484     0.745     0.729     0.9151414309484193
          real: 1.158     0.220     0.197     0.989247311827957
********************************************************************************
epoch: 250
                loss      accuracy  precision recall
      training: 0.441     0.786     0.797     0.8706365503080082
    validation: 0.516     0.705     0.724     0.8151260504201681
          test: 0.479     0.760     0.745     0.913477537437604
          real: 1.157     0.228     0.199     0.989247311827957
********************************************************************************
epoch: 260
                loss      accuracy  precision recall
      training: 0.431     0.806     0.807     0.8952772073921971
    validation: 0.517     0.715     0.721     0.8487394957983193
          test: 0.474     0.752     0.738     0.9118136439267887
          real: 1.179     0.226     0.198     0.989247311827957
********************************************************************************
epoch: 270
                loss      accuracy  precision recall
      training: 0.437     0.806     0.810     0.891170431211499
    validation: 0.512     0.720     0.730     0.8403361344537815
          test: 0.474     0.757     0.740     0.9184692179700499
          real: 1.166     0.226     0.197     0.978494623655914
********************************************************************************
epoch: 280
                loss      accuracy  precision recall
      training: 0.425     0.801     0.812     0.8767967145790554
    validation: 0.508     0.710     0.729     0.8151260504201681
          test: 0.466     0.757     0.746     0.9034941763727121
          real: 1.167     0.239     0.200     0.978494623655914
********************************************************************************
epoch: 290
                loss      accuracy  precision recall
      training: 0.420     0.805     0.812     0.8850102669404517
    validation: 0.532     0.685     0.703     0.8151260504201681
          test: 0.466     0.759     0.747     0.9068219633943427
          real: 1.179     0.239     0.200     0.978494623655914
********************************************************************************
epoch: 300
                loss      accuracy  precision recall
      training: 0.414     0.809     0.821     0.8767967145790554
    validation: 0.522     0.715     0.742     0.7983193277310925
          test: 0.460     0.767     0.755     0.9068219633943427
          real: 1.176     0.245     0.201     0.978494623655914
********************************************************************************
epoch: 310
                loss      accuracy  precision recall
      training: 0.398     0.821     0.831     0.8870636550308009
    validation: 0.497     0.720     0.733     0.8319327731092437
          test: 0.456     0.764     0.755     0.9001663893510815
          real: 1.186     0.245     0.200     0.967741935483871
********************************************************************************
epoch: 320
                loss      accuracy  precision recall
      training: 0.405     0.818     0.822     0.893223819301848
    validation: 0.493     0.705     0.734     0.7899159663865546
          test: 0.453     0.766     0.758     0.8968386023294509
          real: 1.186     0.253     0.201     0.967741935483871
********************************************************************************
epoch: 330
                loss      accuracy  precision recall
      training: 0.396     0.807     0.825     0.8685831622176592
    validation: 0.507     0.730     0.744     0.8319327731092437
          test: 0.450     0.765     0.759     0.891846921797005
          real: 1.202     0.251     0.201     0.967741935483871
********************************************************************************
epoch: 340
                loss      accuracy  precision recall
      training: 0.372     0.839     0.848     0.8952772073921971
    validation: 0.506     0.730     0.756     0.8067226890756303
          test: 0.448     0.766     0.761     0.8901830282861897
          real: 1.200     0.259     0.203     0.967741935483871
********************************************************************************
epoch: 350
                loss      accuracy  precision recall
      training: 0.376     0.824     0.841     0.8767967145790554
    validation: 0.533     0.715     0.738     0.8067226890756303
          test: 0.444     0.770     0.766     0.8885191347753744
          real: 1.208     0.263     0.204     0.967741935483871
********************************************************************************
epoch: 360
                loss      accuracy  precision recall
      training: 0.364     0.849     0.859     0.8993839835728953
    validation: 0.523     0.705     0.746     0.7647058823529411
          test: 0.443     0.777     0.773     0.8901830282861897
          real: 1.213     0.266     0.204     0.967741935483871
********************************************************************************
epoch: 370
                loss      accuracy  precision recall
      training: 0.354     0.848     0.861     0.893223819301848
    validation: 0.503     0.710     0.744     0.7815126050420168
          test: 0.439     0.773     0.777     0.8735440931780366
          real: 1.223     0.276     0.206     0.967741935483871
********************************************************************************
epoch: 380
                loss      accuracy  precision recall
      training: 0.340     0.855     0.864     0.9034907597535934
    validation: 0.503     0.750     0.776     0.8151260504201681
          test: 0.439     0.777     0.776     0.8851913477537438
          real: 1.237     0.278     0.206     0.956989247311828
********************************************************************************
epoch: 390
                loss      accuracy  precision recall
      training: 0.340     0.856     0.866     0.9034907597535934
    validation: 0.509     0.710     0.756     0.7563025210084033
          test: 0.438     0.777     0.776     0.8851913477537438
          real: 1.248     0.284     0.207     0.956989247311828
********************************************************************************
epoch: 400
                loss      accuracy  precision recall
      training: 0.330     0.873     0.883     0.9117043121149897
    validation: 0.505     0.720     0.760     0.773109243697479
          test: 0.434     0.775     0.776     0.8785357737104825
          real: 1.251     0.293     0.208     0.946236559139785
********************************************************************************
epoch: 410
                loss      accuracy  precision recall
      training: 0.325     0.875     0.885     0.9137577002053389
    validation: 0.520     0.715     0.728     0.8319327731092437
          test: 0.435     0.775     0.780     0.8718801996672213
          real: 1.260     0.288     0.207     0.946236559139785
********************************************************************************
epoch: 420
                loss      accuracy  precision recall
      training: 0.308     0.876     0.883     0.917864476386037
    validation: 0.495     0.750     0.767     0.8319327731092437
          test: 0.432     0.771     0.778     0.8668885191347754
          real: 1.277     0.284     0.206     0.946236559139785
********************************************************************************
epoch: 430
                loss      accuracy  precision recall
      training: 0.318     0.849     0.867     0.8870636550308009
    validation: 0.501     0.735     0.754     0.8235294117647058
          test: 0.431     0.778     0.780     0.8785357737104825
          real: 1.292     0.290     0.206     0.9354838709677419
********************************************************************************
epoch: 440
                loss      accuracy  precision recall
      training: 0.291     0.891     0.913     0.9075975359342916
    validation: 0.501     0.740     0.760     0.8235294117647058
          test: 0.436     0.774     0.776     0.8768718801996672
          real: 1.329     0.293     0.206     0.9354838709677419
********************************************************************************
epoch: 450
                loss      accuracy  precision recall
      training: 0.309     0.855     0.869     0.8973305954825462
    validation: 0.521     0.695     0.734     0.7647058823529411
          test: 0.432     0.774     0.781     0.8668885191347754
          real: 1.334     0.295     0.207     0.9354838709677419
********************************************************************************
epoch: 460
                loss      accuracy  precision recall
      training: 0.287     0.889     0.904     0.9137577002053389
    validation: 0.540     0.725     0.750     0.8067226890756303
          test: 0.429     0.771     0.784     0.8552412645590682
          real: 1.324     0.297     0.206     0.9247311827956989
********************************************************************************
epoch: 470
                loss      accuracy  precision recall
      training: 0.288     0.874     0.894     0.8993839835728953
    validation: 0.503     0.745     0.770     0.8151260504201681
          test: 0.428     0.781     0.788     0.870216306156406
          real: 1.311     0.303     0.207     0.9247311827956989
********************************************************************************
epoch: 480
                loss      accuracy  precision recall
      training: 0.260     0.892     0.900     0.9260780287474333
    validation: 0.494     0.720     0.744     0.8067226890756303
          test: 0.432     0.777     0.785     0.8668885191347754
          real: 1.367     0.305     0.208     0.9247311827956989
********************************************************************************
epoch: 490
                loss      accuracy  precision recall
      training: 0.260     0.902     0.913     0.9281314168377823
    validation: 0.528     0.745     0.770     0.8151260504201681
          test: 0.428     0.777     0.792     0.8535773710482529
          real: 1.377     0.315     0.210     0.9247311827956989

Let’s plot the results, so we can look for any broad trends …

[15]:
import toyplot

canvas = toyplot.Canvas(width=600, height=600)
axes = canvas.cartesian(grid=(2, 2, 0), xlabel="Epoch", ylabel="Loss")
training = axes.plot(epochs, metrics["training/loss"])
validation = axes.plot(epochs, metrics["validation/loss"])
test = axes.plot(epochs, metrics["test/loss"])
real = axes.plot(epochs, metrics["real/loss"])

axes = canvas.cartesian(grid=(2, 2, 1), xlabel="Epoch", ylabel="Accuracy")
training = axes.plot(epochs, metrics["training/accuracy"])
validation = axes.plot(epochs, metrics["validation/accuracy"])
test = axes.plot(epochs, metrics["test/accuracy"])
real = axes.plot(epochs, metrics["real/accuracy"])

axes = canvas.cartesian(grid=(2, 2, 2), xlabel="Epoch", ylabel="Precision")
training = axes.plot(epochs, metrics["training/precision"])
validation = axes.plot(epochs, metrics["validation/precision"])
test = axes.plot(epochs, metrics["test/precision"])
real = axes.plot(epochs, metrics["real/precision"])

axes = canvas.cartesian(grid=(2, 2, 3), xlabel="Epoch", ylabel="Recall")
training = axes.plot(epochs, metrics["training/recall"])
validation = axes.plot(epochs, metrics["validation/recall"])
test = axes.plot(epochs, metrics["test/recall"])
real = axes.plot(epochs, metrics["real/recall"])

# Embed each series' line marker in the title so that the title doubles as a legend.
title = "Training {} validation {} test {} and real {} metrics".format(
    training.markers[0],
    validation.markers[0],
    test.markers[0],
    real.markers[0],
)
canvas.text(300, 15, text=title, style={"font-size": "16px", "font-weight": "bold"});
[Figure: 2 × 2 grid of line plots showing Loss, Accuracy, Precision, and Recall against Epoch (0 to 500) for the training, validation, test, and real series. Title: "Training, validation, test and real metrics".]

Looking at the training, validation, and test results (using synthetic data), we see patterns that are typical of a successful model:

  • Training loss steadily decreases, while validation and test loss appear to be nearing a minimum. Under normal circumstances, we would stop training at this point to avoid overfitting.

  • Test accuracy steadily increases, while precision and recall eventually reach a good balance.

  • Overall, we can have confidence that the train synthetic, test synthetic use case is working the way we expect.
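
The stopping rule mentioned in the first bullet can be sketched as patience-based early stopping. This is a minimal illustration and not part of the notebook's training loop: the function name `early_stopping_epoch` and its `patience` and `min_delta` parameters are hypothetical, and the losses are the validation values logged above for epochs 310 through 490.

```python
def early_stopping_epoch(epochs, losses, patience=3, min_delta=0.0):
    """Return (stop_epoch, best_epoch) for a patience-based stopping rule.

    stop_epoch is the first logged epoch at which the loss has failed to
    improve for `patience` consecutive checks, or None if that never happens.
    best_epoch is the epoch with the lowest loss seen before stopping.
    """
    best_loss = float("inf")
    best_epoch = None
    stale_checks = 0
    for epoch, loss in zip(epochs, losses):
        if loss < best_loss - min_delta:
            # New best loss: remember it and reset the patience counter.
            best_loss, best_epoch, stale_checks = loss, epoch, 0
        else:
            stale_checks += 1
            if stale_checks >= patience:
                return epoch, best_epoch
    return None, best_epoch


# Validation losses copied from the log above (one value every 10 epochs).
epochs = list(range(310, 500, 10))
validation_loss = [
    0.497, 0.493, 0.507, 0.506, 0.533, 0.523, 0.503, 0.503, 0.509, 0.505,
    0.520, 0.495, 0.501, 0.501, 0.521, 0.540, 0.503, 0.494, 0.528,
]

stop_epoch, best_epoch = early_stopping_epoch(epochs, validation_loss, patience=3)
print(f"stop at epoch {stop_epoch}, best checkpoint at epoch {best_epoch}")
```

Because the metrics are logged every 10 epochs, a patience of 3 corresponds to 30 epochs without improvement. The validation loss is noisy, so in practice a larger patience or a nonzero `min_delta` may be a better choice.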

For the results using real images, there are interesting patterns that we commonly find with the train synthetic, test real use case:

  • The loss on real images steadily increases (the opposite of what we would like) - this highlights the fact that the synthetic and real images are drawn from different distributions, even though humans often mistake the synthetic images for real ones.

  • In spite of this, and after an initial low-precision, high-recall configuration, we see gradual balancing of precision and recall, leading to steady improvements in accuracy. Although the accuracy is relatively low (around 32%), it is still better than the roughly 19% that a classifier labelling every image as positive would achieve (19% being the proportion of real images that contain type 48 containers).

  • Furthermore, some of the poor precision can be traced to our decision to train only on type 48 containers; since the real data contains both 30B and type 48 containers, it’s likely that the model is mistaking 30B containers for type 48 containers, something we can address by training using examples of each.

  • Even though our validation loss seems to be levelling out by epoch 500, accuracy and precision on real-world data are still improving as of epoch 500, suggesting that further training could be useful.
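
To make the 19% baseline concrete, the sketch below computes accuracy, precision, and recall from confusion-matrix counts for a classifier that labels every real image as positive. The helper `classification_metrics` is hypothetical, and the count of 93 positive images is an inference rather than a figure stated in the notebook: it follows from the logged recall values (for example 0.9677 ≈ 90/93) and the 482 real images loaded earlier.

```python
def classification_metrics(tp, fp, fn, tn):
    """Return (accuracy, precision, recall) from confusion-matrix counts."""
    total = tp + fp + fn + tn
    accuracy = (tp + tn) / total
    # Guard against division by zero when nothing is predicted or labelled positive.
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return accuracy, precision, recall


# Assumption: 93 of the 482 real images are positive (inferred from the logs).
total_images = 482
positive_images = 93

# A classifier that predicts "positive" for every image has no negatives at all.
accuracy, precision, recall = classification_metrics(
    tp=positive_images,
    fp=total_images - positive_images,
    fn=0,
    tn=0,
)
print(f"accuracy={accuracy:.3f} precision={precision:.3f} recall={recall:.3f}")
```

Under these assumptions the always-positive classifier has perfect recall, with accuracy and precision both equal to the proportion of positive images. That is close to the real-image metrics logged in the early epochs shown above, and the model moves away from it as training proceeds.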