Classification
In this notebook we’ll demonstrate training a PyTorch ResNet-50 classifier with Limbo data to determine whether an image contains a type 48 cylinder or not. We’ll use data from Campaign 17, which contains images of type 48G, 48X, and 48Y cylinders scattered among distractors in a synthetic outdoor industrial environment.
We will evaluate our model using synthetic data (a case we refer to as “train synthetic, test synthetic”) to confirm that it is learning, and using real-world images (a case we refer to as “train synthetic, test real”) to find out how well it would work in reality. Note that the latter is extremely challenging for models, and a ripe area for new research, which is why we created the Limbo data in the first place!
To begin, we’ll define a variable containing the path to the data:
[1]:
DATA_ROOT = "/mnt/mc1/limbo"
… if you’re running this notebook yourself, you’ll need to set DATA_ROOT to point to a directory containing Limbo campaign data that you’ve downloaded.
Every campaign in the Limbo data is split into numbered subdirectories, each of which contains one thousand images. We will load one subdirectory (one thousand synthetic images) for training, a second subdirectory (one thousand synthetic images) for synthetic testing, and all of the available real-world data for real testing:
[2]:
import os
import limbo.data
training_data = limbo.data.Dataset([os.path.join(DATA_ROOT, "campaign17", "0000")])
test_data = limbo.data.Dataset([os.path.join(DATA_ROOT, "campaign17", "0001")])
real_data = limbo.data.Dataset([os.path.join(DATA_ROOT, "ref")])
print(f"Loaded {len(training_data)} training images.")
print(f"Loaded {len(test_data)} test images.")
print(f"Loaded {len(real_data)} real images.")
Loaded 1000 training images.
Loaded 1000 test images.
Loaded 482 real images.
Tip
One thousand images for training is just enough to prove that our model works for purposes of pedagogy. In practice, we assume that you’ll train with more - most campaigns in the Limbo data contain 50000 images each. Note that you can pass more than one subdirectory to the Dataset initializer, or you can pass the top-level campaign directory to load everything.
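As a minimal sketch of the tip above, the following builds the list of paths for the first few numbered subdirectories of a campaign. The helper name `campaign_subdirs` is hypothetical, and the zero-padded, four-digit naming is assumed from the `0000` and `0001` directories used earlier; the resulting list is what you would pass to `limbo.data.Dataset` in place of the single-element lists above.

```python
import os

def campaign_subdirs(data_root, campaign, count):
    """Return paths to the first `count` numbered subdirectories of a campaign.

    Hypothetical helper: assumes subdirectories are named 0000, 0001, ...
    """
    return [os.path.join(data_root, campaign, f"{index:04d}") for index in range(count)]

paths = campaign_subdirs("/mnt/mc1/limbo", "campaign17", 5)
print(paths)
# With the Limbo data available: training_data = limbo.data.Dataset(paths)
```

Loading more subdirectories increases the time needed to compute labels and train, so it may be worth growing the list gradually.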
Next, we’ll display some images to be sure we have the right datasets:
[3]:
from IPython.display import display, HTML
display(HTML("<h3>Training sample</h3>"))
display(training_data[0].image)
display(HTML("<h3>Test sample</h3>"))
display(test_data[0].image)
display(HTML("<h3>Real sample</h3>"))
display(real_data[0].image)
Training sample
Test sample
Real sample
… looking good!
To train our classifier to label each image as “48” or “not”, we’ll need to generate a corresponding set of labels. The samples in a Limbo dataset could contain dozens or even hundreds of annotations, so we’ll have to reduce that to a single label, which will be “1” for images that contain type 48 containers, and “0” for images that don’t.
To do so, we’ll scan the “categories” property of each sample (which returns a set containing every unique category that appears in the sample) to identify which ones include containers. Note that the following will probably take several minutes to execute, since it loads and parses a JSON file from disk for each sample. Just be patient, and keep in mind that as you begin working with larger subsets of the data, you may wish to cache these kinds of intermediate results to disk to speed up your workflow:
[4]:
import numpy
targets = set(["48G", "48X", "48Y"])
training_labels = [1 if sample.categories & targets else 0 for sample in training_data]
test_labels = [1 if sample.categories & targets else 0 for sample in test_data]
real_labels = [1 if sample.categories & targets else 0 for sample in real_data]
Since there are three styles of type 48 container in the dataset, we use Python set intersection to test whether a sample contains any of the three.
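Because computing the labels is slow, one way to cache them on disk is sketched below. The helper name `cached_labels` and the JSON file layout are assumptions made for illustration, not part of the Limbo API; the demonstration uses a stand-in function in place of the real label computation so that it runs without the data.

```python
import json
import os
import tempfile

def cached_labels(cache_path, compute_labels):
    """Load labels from `cache_path` if it exists, else compute and save them.

    `compute_labels` is any zero-argument callable returning a list of 0/1 ints.
    """
    if os.path.exists(cache_path):
        with open(cache_path) as stream:
            return json.load(stream)
    labels = compute_labels()
    with open(cache_path, "w") as stream:
        json.dump(labels, stream)
    return labels

calls = []
def slow_labels():
    calls.append(1)  # record that the expensive path actually ran
    return [1, 0, 1]

with tempfile.TemporaryDirectory() as directory:
    path = os.path.join(directory, "labels.json")
    first = cached_labels(path, slow_labels)   # computes the labels and writes the file
    second = cached_labels(path, slow_labels)  # reads the file instead of recomputing

print(first, second, len(calls))
```

In the notebook, the callable would be something like `lambda: [1 if sample.categories & targets else 0 for sample in training_data]`, with one cache file per dataset.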
It’s always a good idea to see how skewed your data is before you go too far, so let’s figure out how many images contain our objects of interest:
[5]:
print(f"{numpy.sum(training_labels) / len(training_labels):.1%} of training images have type 48 containers.")
print(f"{numpy.sum(test_labels) / len(test_labels):.1%} of test images have type 48 containers.")
print(f"{numpy.sum(real_labels) / len(real_labels):.1%} of real images have type 48 containers.")
60.6% of training images have type 48 containers.
60.1% of test images have type 48 containers.
19.3% of real images have type 48 containers.
An even mix of container and not-container images for training would be ideal, but ~60% containers isn’t too bad, so we’ll live with it for this demo. When conducting your own experiments, you will typically want to load more images than you need, downsampling based on the labels to achieve your desired balance between positive and negative examples.
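A minimal sketch of the downsampling described above, using only the standard library: it keeps every example of the minority class and an equally sized random subset of the majority class. The function name `balanced_indices` and the fixed seed are illustrative assumptions.

```python
import random

def balanced_indices(labels, seed=0):
    """Return shuffled sample indices with equal numbers of 0 and 1 labels."""
    rng = random.Random(seed)
    positives = [i for i, label in enumerate(labels) if label == 1]
    negatives = [i for i, label in enumerate(labels) if label == 0]
    count = min(len(positives), len(negatives))
    chosen = rng.sample(positives, count) + rng.sample(negatives, count)
    rng.shuffle(chosen)
    return chosen

toy_labels = [1] * 6 + [0] * 4  # toy labels, 60% positive like our training set
indices = balanced_indices(toy_labels)
print(len(indices), sum(toy_labels[i] for i in indices))
```

The resulting indices could then be used to select a subset of the wrapped dataset, for example with `torch.utils.data.Subset`.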
Next, we need to wrap our Limbo dataset object to make it usable with PyTorch. We’ll create a simple adapter that loads Limbo images on demand, resizes them to \(224\times224\), and caches them in memory to speed up training, since our data is relatively small:
[6]:
import torch.utils.data
import torchvision

class LimboToTorch(torch.utils.data.Dataset):
    def __init__(self, dataset, labels, transforms):
        self.dataset = dataset
        # Store labels as an (N, 1) float32 array so each label matches the model's single output.
        self.labels = numpy.expand_dims(labels, 1).astype(numpy.float32)
        self.transforms = transforms
        self.cache = {}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, key):
        if key not in self.cache:
            sample = self.dataset[key]
            image = sample.image
            if "C" in image.layers:
                # Use the "C" layer if present, scaling its values to 8 bits.
                image = (image.layers["C"].data * 255).astype(numpy.uint8)
            elif "Y" in image.layers:
                # Otherwise, tile the single-channel "Y" layer into three channels.
                image = numpy.tile(image.layers["Y"].data * 255, (1, 1, 3)).astype(numpy.uint8)
            image = torchvision.transforms.functional.to_pil_image(image, mode="RGB")
            image = torchvision.transforms.functional.resize(image, (224, 224))
            # Cache the resized image; the (random) transforms are applied on every access.
            self.cache[key] = image
        return self.transforms(self.cache[key]), self.labels[key]
Next, we’ll define a set of data augmentation transforms to apply to the images for training, and a much simpler set of transforms for evaluation:
[7]:
import torchvision
training_transforms = torchvision.transforms.Compose([
    torchvision.transforms.RandomAffine(
        degrees=90,
        scale=(0.8, 1.2),
        shear=20,
        interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
        fill=(int(255*.485), int(255*.456), int(255*.406)),
    ),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
evaluation_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
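One plausible reading of the `fill` argument in the training transforms is that it matches the normalization mean: each channel is the mean value scaled to the 0-255 range. This is an interpretation of the code rather than something stated in the notebook. The arithmetic sketch below mimics what `ToTensor` (divide by 255) and `Normalize` (subtract the mean, divide by the standard deviation) do to a fill pixel, showing that regions exposed by rotation or shear end up close to zero after normalization.

```python
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Fill color passed to RandomAffine, as 8-bit RGB values.
fill = [int(255 * m) for m in mean]

# ToTensor scales 8-bit values to [0, 1]; Normalize then standardizes each channel.
normalized_fill = [(f / 255 - m) / s for f, m, s in zip(fill, mean, std)]

print(fill)
print([round(value, 3) for value in normalized_fill])
```

The small residual values come from truncating the fill color to whole 8-bit integers.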
Now we can wrap our Limbo datasets for use with PyTorch:
[8]:
training_data = LimboToTorch(training_data, training_labels, training_transforms)
test_data = LimboToTorch(test_data, test_labels, evaluation_transforms)
real_data = LimboToTorch(real_data, real_labels, evaluation_transforms)
We will hold back 20% of our training data for validation:
[9]:
training_data, validation_data = torch.utils.data.random_split(training_data, [0.8, 0.2])
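Note that this split is random, so the validation subset will differ from run to run unless the random number generator is seeded; `torch.utils.data.random_split` accepts a `generator` argument for that purpose. As a library-free sketch of what an 80/20 random split does, the following shuffles sample indices with a fixed seed and slices them. The function name `split_indices` and the seed value are illustrative assumptions, not part of PyTorch.

```python
import random

def split_indices(count, validation_fraction=0.2, seed=1234):
    """Shuffle range(count) reproducibly and split it into two index lists."""
    indices = list(range(count))
    random.Random(seed).shuffle(indices)
    validation_count = round(count * validation_fraction)
    return indices[validation_count:], indices[:validation_count]

training_indices, validation_indices = split_indices(1000)
print(len(training_indices), len(validation_indices))  # prints "800 200"
```

Calling `split_indices(1000)` again returns the same two lists, which is the property you want when comparing experiments.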
And create data loaders that can be used to iterate over the data in batches of 50 images:
[10]:
training_loader = torch.utils.data.DataLoader(training_data, batch_size=50, shuffle=True, num_workers=0)
validation_loader = torch.utils.data.DataLoader(validation_data, batch_size=50, shuffle=False, num_workers=0)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=50, shuffle=False, num_workers=0)
real_loader = torch.utils.data.DataLoader(real_data, batch_size=50, shuffle=False, num_workers=0)
Our model will be based on a ResNet-50 model pre-trained on ImageNet. Since our problem only requires a single “48 or not” output, we will replace the model’s final, fully-connected output layer with one that produces a single sigmoid output. This output can be thought of as representing the model’s confidence that an image contains a type 48 container, and nicely coincides with the “1” and “0” labels we created earlier.
[11]:
model = torchvision.models.resnet50(weights="IMAGENET1K_V2")
model.fc = torch.nn.Sequential(
    torch.nn.Linear(2048, 1),
    torch.nn.Sigmoid(),
)
We will train using NVIDIA GPU hardware; if you don’t have a GPU, you can change the following to the string "cpu", but your training times will be very long:
[12]:
DEVICE = "cuda"
With DEVICE set, we can prepare our model for use on the hardware, and set up our loss function and optimizer:
[13]:
model = model.to(DEVICE)
loss_function = torch.nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
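For reference, `torch.nn.BCELoss` computes the mean binary cross-entropy, -[y log(p) + (1 - y) log(1 - p)], between the sigmoid output p and the label y. The library-free sketch below evaluates that formula by hand; it is only an illustration, not PyTorch's implementation. One useful consequence is that a model that outputs 0.5 for every image has a loss of ln 2, or about 0.693, which is a handy reference point when reading the loss values printed at the start of training.

```python
import math

def binary_cross_entropy(y_true, y_pred):
    """Mean binary cross-entropy between 0/1 labels and predicted probabilities."""
    total = 0.0
    for y, p in zip(y_true, y_pred):
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(y_true)

# An undecided model (0.5 everywhere) versus a confident, correct one.
uncertain = binary_cross_entropy([1, 0, 1, 1], [0.5, 0.5, 0.5, 0.5])
confident = binary_cross_entropy([1, 0, 1, 1], [0.9, 0.1, 0.9, 0.9])
print(f"{uncertain:.3f} {confident:.3f}")
```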
Now we train! The following code will train and validate the model for 500 epochs using synthetic data, and evaluate it using synthetic and real test data. Various performance metrics are printed to stdout, and retained for later analysis.
Rant
You’ll notice that the following training code is extremely verbose. Machine learning researchers often tie themselves in knots trying to eliminate duplication in their training loops; we like tight loops as much as anyone, but don’t be that person - simplicity, clarity, and reproducibility are more important, even if they come at the expense of some repetition. DRY is good advice, but not if it obscures what you’re doing!
[14]:
import collections
import sklearn.metrics
metrics = collections.defaultdict(list)
epochs = numpy.arange(500)
for epoch in epochs:
    if not epoch % 10 or epoch < 10:
        print("*" * 80)
        print(f"epoch: {epoch}")
    # Train the model
    losses = []
    y_true = []
    y_pred = []
    model.train(True)
    for x, y in training_loader:
        x = x.to(DEVICE)
        y = y.to(DEVICE)
        y_hat = model(x)
        loss = loss_function(y_hat, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        y_true.append(y.cpu().detach())
        y_pred.append(y_hat.cpu().detach())
    losses = numpy.row_stack(losses)
    y_true = numpy.row_stack(y_true)
    y_pred = numpy.row_stack(y_pred)
    metrics["training/loss"].append(numpy.mean(losses))
    metrics["training/accuracy"].append(sklearn.metrics.accuracy_score(y_true, y_pred > 0.5))
    metrics["training/precision"].append(sklearn.metrics.precision_score(y_true, y_pred > 0.5))
    metrics["training/recall"].append(sklearn.metrics.recall_score(y_true, y_pred > 0.5))
    # Validate the model
    losses = []
    y_true = []
    y_pred = []
    model.train(False)
    for x, y in validation_loader:
        x = x.to(DEVICE)
        y = y.to(DEVICE)
        y_hat = model(x)
        loss = loss_function(y_hat, y)
        losses.append(loss.item())
        y_true.append(y.cpu().detach())
        y_pred.append(y_hat.cpu().detach())
    losses = numpy.row_stack(losses)
    y_true = numpy.row_stack(y_true)
    y_pred = numpy.row_stack(y_pred)
    metrics["validation/loss"].append(numpy.mean(losses))
    metrics["validation/accuracy"].append(sklearn.metrics.accuracy_score(y_true, y_pred > 0.5))
    metrics["validation/precision"].append(sklearn.metrics.precision_score(y_true, y_pred > 0.5))
    metrics["validation/recall"].append(sklearn.metrics.recall_score(y_true, y_pred > 0.5))
    # Test the model on synthetic data
    losses = []
    y_true = []
    y_pred = []
    model.train(False)
    for x, y in test_loader:
        x = x.to(DEVICE)
        y = y.to(DEVICE)
        y_hat = model(x)
        loss = loss_function(y_hat, y)
        losses.append(loss.item())
        y_true.append(y.cpu().detach())
        y_pred.append(y_hat.cpu().detach())
    losses = numpy.row_stack(losses)
    y_true = numpy.row_stack(y_true)
    y_pred = numpy.row_stack(y_pred)
    metrics["test/loss"].append(numpy.mean(losses))
    metrics["test/accuracy"].append(sklearn.metrics.accuracy_score(y_true, y_pred > 0.5))
    metrics["test/precision"].append(sklearn.metrics.precision_score(y_true, y_pred > 0.5))
    metrics["test/recall"].append(sklearn.metrics.recall_score(y_true, y_pred > 0.5))
    # Test the model on real data
    losses = []
    y_true = []
    y_pred = []
    model.train(False)
    for x, y in real_loader:
        x = x.to(DEVICE)
        y = y.to(DEVICE)
        y_hat = model(x)
        loss = loss_function(y_hat, y)
        losses.append(loss.item())
        y_true.append(y.cpu().detach())
        y_pred.append(y_hat.cpu().detach())
    losses = numpy.row_stack(losses)
    y_true = numpy.row_stack(y_true)
    y_pred = numpy.row_stack(y_pred)
    metrics["real/loss"].append(numpy.mean(losses))
    metrics["real/accuracy"].append(sklearn.metrics.accuracy_score(y_true, y_pred > 0.5))
    metrics["real/precision"].append(sklearn.metrics.precision_score(y_true, y_pred > 0.5))
    metrics["real/recall"].append(sklearn.metrics.recall_score(y_true, y_pred > 0.5))
    if not epoch % 10 or epoch < 10:
        print(f"{'':>16}{'loss':<10}{'accuracy':<10}{'precision':<10}{'recall':<10}")
        print(f"{'training: ':>16}", end="")
        print(f"{metrics['training/loss'][-1]:<10.3f}", end="")
        print(f"{metrics['training/accuracy'][-1]:<10.3f}", end="")
        print(f"{metrics['training/precision'][-1]:<10.3f}", end="")
        print(f"{metrics['training/recall'][-1]:<10}")
        print(f"{'validation: ':>16}", end="")
        print(f"{metrics['validation/loss'][-1]:<10.3f}", end="")
        print(f"{metrics['validation/accuracy'][-1]:<10.3f}", end="")
        print(f"{metrics['validation/precision'][-1]:<10.3f}", end="")
        print(f"{metrics['validation/recall'][-1]:<10}")
        print(f"{'test: ':>16}", end="")
        print(f"{metrics['test/loss'][-1]:<10.3f}", end="")
        print(f"{metrics['test/accuracy'][-1]:<10.3f}", end="")
        print(f"{metrics['test/precision'][-1]:<10.3f}", end="")
        print(f"{metrics['test/recall'][-1]:<10}")
        print(f"{'real: ':>16}", end="")
        print(f"{metrics['real/loss'][-1]:<10.3f}", end="")
        print(f"{metrics['real/accuracy'][-1]:<10.3f}", end="")
        print(f"{metrics['real/precision'][-1]:<10.3f}", end="")
        print(f"{metrics['real/recall'][-1]:<10}")
********************************************************************************
epoch: 0
loss accuracy precision recall
training: 0.692 0.520 0.589 0.7002053388090349
validation: 0.684 0.540 0.589 0.7478991596638656
test: 0.689 0.536 0.626 0.56738768718802
real: 0.727 0.297 0.147 0.5483870967741935
********************************************************************************
epoch: 1
loss accuracy precision recall
training: 0.686 0.571 0.607 0.8418891170431212
validation: 0.684 0.590 0.613 0.8403361344537815
test: 0.681 0.588 0.626 0.7803660565723793
real: 0.750 0.249 0.188 0.8709677419354839
********************************************************************************
epoch: 2
loss accuracy precision recall
training: 0.680 0.585 0.608 0.893223819301848
validation: 0.680 0.575 0.592 0.9159663865546218
test: 0.676 0.625 0.631 0.9068219633943427
real: 0.765 0.207 0.184 0.9032258064516129
********************************************************************************
epoch: 3
loss accuracy precision recall
training: 0.676 0.615 0.619 0.9548254620123203
validation: 0.679 0.600 0.604 0.9495798319327731
test: 0.673 0.626 0.627 0.9351081530782029
real: 0.778 0.203 0.188 0.946236559139785
********************************************************************************
epoch: 4
loss accuracy precision recall
training: 0.676 0.604 0.611 0.9589322381930184
validation: 0.676 0.605 0.604 0.9747899159663865
test: 0.670 0.628 0.624 0.9584026622296173
real: 0.790 0.189 0.187 0.956989247311828
********************************************************************************
epoch: 5
loss accuracy precision recall
training: 0.671 0.608 0.610 0.9835728952772074
validation: 0.668 0.605 0.601 1.0
test: 0.667 0.624 0.618 0.978369384359401
real: 0.802 0.193 0.192 0.989247311827957
********************************************************************************
epoch: 6
loss accuracy precision recall
training: 0.671 0.610 0.611 0.9917864476386037
validation: 0.674 0.600 0.600 0.9831932773109243
test: 0.665 0.621 0.616 0.9833610648918469
real: 0.810 0.193 0.193 1.0
********************************************************************************
epoch: 7
loss accuracy precision recall
training: 0.666 0.609 0.610 0.9917864476386037
validation: 0.666 0.605 0.601 1.0
test: 0.662 0.623 0.615 0.9966722129783694
real: 0.821 0.193 0.193 1.0
********************************************************************************
epoch: 8
loss accuracy precision recall
training: 0.666 0.611 0.611 0.9938398357289527
validation: 0.668 0.600 0.598 1.0
test: 0.660 0.625 0.616 0.9966722129783694
real: 0.826 0.193 0.193 1.0
********************************************************************************
epoch: 9
loss accuracy precision recall
training: 0.661 0.611 0.611 0.9958932238193019
validation: 0.669 0.605 0.601 1.0
test: 0.658 0.621 0.613 0.9983361064891847
real: 0.834 0.193 0.193 1.0
********************************************************************************
epoch: 10
loss accuracy precision recall
training: 0.662 0.615 0.613 0.997946611909651
validation: 0.667 0.605 0.601 1.0
test: 0.656 0.619 0.612 0.9983361064891847
real: 0.847 0.193 0.193 1.0
********************************************************************************
epoch: 20
loss accuracy precision recall
training: 0.647 0.610 0.610 1.0
validation: 0.652 0.600 0.598 1.0
test: 0.644 0.614 0.609 1.0
real: 0.907 0.193 0.193 1.0
********************************************************************************
epoch: 30
loss accuracy precision recall
training: 0.641 0.615 0.613 1.0
validation: 0.645 0.605 0.601 1.0
test: 0.636 0.620 0.613 1.0
real: 0.945 0.193 0.193 1.0
********************************************************************************
epoch: 40
loss accuracy precision recall
training: 0.632 0.626 0.620 0.997946611909651
validation: 0.635 0.620 0.611 0.9915966386554622
test: 0.627 0.629 0.619 0.9966722129783694
real: 0.968 0.193 0.193 1.0
********************************************************************************
epoch: 50
loss accuracy precision recall
training: 0.623 0.636 0.627 0.9938398357289527
validation: 0.633 0.635 0.624 0.9747899159663865
test: 0.620 0.641 0.627 0.9916805324459235
real: 0.990 0.193 0.193 1.0
********************************************************************************
epoch: 60
loss accuracy precision recall
training: 0.612 0.662 0.645 0.9897330595482546
validation: 0.625 0.635 0.624 0.9747899159663865
test: 0.612 0.640 0.628 0.9866888519134775
real: 0.998 0.193 0.193 1.0
********************************************************************************
epoch: 70
loss accuracy precision recall
training: 0.604 0.672 0.653 0.9835728952772074
validation: 0.625 0.635 0.625 0.9663865546218487
test: 0.604 0.653 0.637 0.9833610648918469
real: 1.011 0.193 0.193 1.0
********************************************************************************
epoch: 80
loss accuracy precision recall
training: 0.591 0.688 0.666 0.9753593429158111
validation: 0.615 0.640 0.630 0.957983193277311
test: 0.596 0.659 0.642 0.978369384359401
real: 1.032 0.193 0.193 1.0
********************************************************************************
epoch: 90
loss accuracy precision recall
training: 0.587 0.693 0.674 0.9568788501026694
validation: 0.609 0.665 0.648 0.957983193277311
test: 0.588 0.671 0.651 0.9750415973377704
real: 1.033 0.193 0.193 1.0
********************************************************************************
epoch: 100
loss accuracy precision recall
training: 0.575 0.696 0.677 0.9589322381930184
validation: 0.591 0.670 0.650 0.9663865546218487
test: 0.583 0.675 0.655 0.9683860232945092
real: 1.041 0.193 0.193 1.0
********************************************************************************
epoch: 110
loss accuracy precision recall
training: 0.562 0.726 0.701 0.9609856262833676
validation: 0.571 0.690 0.669 0.9495798319327731
test: 0.570 0.690 0.667 0.9650582362728786
real: 1.058 0.195 0.193 1.0
********************************************************************************
epoch: 120
loss accuracy precision recall
training: 0.557 0.719 0.703 0.9322381930184805
validation: 0.580 0.665 0.662 0.8907563025210085
test: 0.564 0.692 0.670 0.9600665557404326
real: 1.071 0.195 0.193 1.0
********************************************************************************
epoch: 130
loss accuracy precision recall
training: 0.551 0.720 0.706 0.9240246406570842
validation: 0.571 0.675 0.657 0.9495798319327731
test: 0.557 0.693 0.672 0.9550748752079867
real: 1.078 0.195 0.193 1.0
********************************************************************************
epoch: 140
loss accuracy precision recall
training: 0.538 0.723 0.712 0.9137577002053389
validation: 0.570 0.670 0.667 0.8907563025210085
test: 0.548 0.697 0.677 0.9484193011647255
real: 1.085 0.197 0.194 1.0
********************************************************************************
epoch: 150
loss accuracy precision recall
training: 0.531 0.743 0.731 0.9117043121149897
validation: 0.563 0.660 0.665 0.865546218487395
test: 0.542 0.702 0.681 0.9467554076539102
real: 1.080 0.203 0.194 0.989247311827957
********************************************************************************
epoch: 160
loss accuracy precision recall
training: 0.521 0.738 0.726 0.9137577002053389
validation: 0.557 0.685 0.677 0.8991596638655462
test: 0.533 0.712 0.692 0.940099833610649
real: 1.094 0.205 0.195 1.0
********************************************************************************
epoch: 170
loss accuracy precision recall
training: 0.517 0.746 0.738 0.9034907597535934
validation: 0.564 0.680 0.680 0.8739495798319328
test: 0.526 0.716 0.694 0.9450915141430949
real: 1.104 0.207 0.196 1.0
********************************************************************************
epoch: 180
loss accuracy precision recall
training: 0.507 0.748 0.744 0.891170431211499
validation: 0.546 0.685 0.692 0.8487394957983193
test: 0.517 0.727 0.709 0.9251247920133111
real: 1.102 0.212 0.197 1.0
********************************************************************************
epoch: 190
loss accuracy precision recall
training: 0.492 0.766 0.756 0.9096509240246407
validation: 0.547 0.690 0.691 0.865546218487395
test: 0.516 0.724 0.701 0.9417637271214643
real: 1.110 0.212 0.197 1.0
********************************************************************************
epoch: 200
loss accuracy precision recall
training: 0.480 0.771 0.769 0.893223819301848
validation: 0.550 0.705 0.708 0.8571428571428571
test: 0.510 0.730 0.708 0.9367720465890182
real: 1.122 0.216 0.197 1.0
********************************************************************************
epoch: 210
loss accuracy precision recall
training: 0.474 0.776 0.781 0.8788501026694046
validation: 0.535 0.670 0.680 0.8403361344537815
test: 0.502 0.737 0.716 0.9334442595673876
real: 1.133 0.218 0.197 0.989247311827957
********************************************************************************
epoch: 220
loss accuracy precision recall
training: 0.467 0.771 0.774 0.8809034907597536
validation: 0.541 0.705 0.711 0.8487394957983193
test: 0.496 0.743 0.724 0.9251247920133111
real: 1.146 0.216 0.196 0.989247311827957
********************************************************************************
epoch: 230
loss accuracy precision recall
training: 0.456 0.791 0.792 0.891170431211499
validation: 0.520 0.715 0.715 0.865546218487395
test: 0.487 0.748 0.729 0.9234608985024958
real: 1.150 0.222 0.199 1.0
********************************************************************************
epoch: 240
loss accuracy precision recall
training: 0.446 0.794 0.795 0.891170431211499
validation: 0.530 0.695 0.713 0.8151260504201681
test: 0.484 0.745 0.729 0.9151414309484193
real: 1.158 0.220 0.197 0.989247311827957
********************************************************************************
epoch: 250
loss accuracy precision recall
training: 0.441 0.786 0.797 0.8706365503080082
validation: 0.516 0.705 0.724 0.8151260504201681
test: 0.479 0.760 0.745 0.913477537437604
real: 1.157 0.228 0.199 0.989247311827957
********************************************************************************
epoch: 260
loss accuracy precision recall
training: 0.431 0.806 0.807 0.8952772073921971
validation: 0.517 0.715 0.721 0.8487394957983193
test: 0.474 0.752 0.738 0.9118136439267887
real: 1.179 0.226 0.198 0.989247311827957
********************************************************************************
epoch: 270
loss accuracy precision recall
training: 0.437 0.806 0.810 0.891170431211499
validation: 0.512 0.720 0.730 0.8403361344537815
test: 0.474 0.757 0.740 0.9184692179700499
real: 1.166 0.226 0.197 0.978494623655914
********************************************************************************
epoch: 280
loss accuracy precision recall
training: 0.425 0.801 0.812 0.8767967145790554
validation: 0.508 0.710 0.729 0.8151260504201681
test: 0.466 0.757 0.746 0.9034941763727121
real: 1.167 0.239 0.200 0.978494623655914
********************************************************************************
epoch: 290
loss accuracy precision recall
training: 0.420 0.805 0.812 0.8850102669404517
validation: 0.532 0.685 0.703 0.8151260504201681
test: 0.466 0.759 0.747 0.9068219633943427
real: 1.179 0.239 0.200 0.978494623655914
********************************************************************************
epoch: 300
loss accuracy precision recall
training: 0.414 0.809 0.821 0.8767967145790554
validation: 0.522 0.715 0.742 0.7983193277310925
test: 0.460 0.767 0.755 0.9068219633943427
real: 1.176 0.245 0.201 0.978494623655914
********************************************************************************
epoch: 310
loss accuracy precision recall
training: 0.398 0.821 0.831 0.8870636550308009
validation: 0.497 0.720 0.733 0.8319327731092437
test: 0.456 0.764 0.755 0.9001663893510815
real: 1.186 0.245 0.200 0.967741935483871
********************************************************************************
epoch: 320
loss accuracy precision recall
training: 0.405 0.818 0.822 0.893223819301848
validation: 0.493 0.705 0.734 0.7899159663865546
test: 0.453 0.766 0.758 0.8968386023294509
real: 1.186 0.253 0.201 0.967741935483871
********************************************************************************
epoch: 330
loss accuracy precision recall
training: 0.396 0.807 0.825 0.8685831622176592
validation: 0.507 0.730 0.744 0.8319327731092437
test: 0.450 0.765 0.759 0.891846921797005
real: 1.202 0.251 0.201 0.967741935483871
********************************************************************************
epoch: 340
loss accuracy precision recall
training: 0.372 0.839 0.848 0.8952772073921971
validation: 0.506 0.730 0.756 0.8067226890756303
test: 0.448 0.766 0.761 0.8901830282861897
real: 1.200 0.259 0.203 0.967741935483871
********************************************************************************
epoch: 350
loss accuracy precision recall
training: 0.376 0.824 0.841 0.8767967145790554
validation: 0.533 0.715 0.738 0.8067226890756303
test: 0.444 0.770 0.766 0.8885191347753744
real: 1.208 0.263 0.204 0.967741935483871
********************************************************************************
epoch: 360
loss accuracy precision recall
training: 0.364 0.849 0.859 0.8993839835728953
validation: 0.523 0.705 0.746 0.7647058823529411
test: 0.443 0.777 0.773 0.8901830282861897
real: 1.213 0.266 0.204 0.967741935483871
********************************************************************************
epoch: 370
loss accuracy precision recall
training: 0.354 0.848 0.861 0.893223819301848
validation: 0.503 0.710 0.744 0.7815126050420168
test: 0.439 0.773 0.777 0.8735440931780366
real: 1.223 0.276 0.206 0.967741935483871
********************************************************************************
epoch: 380
loss accuracy precision recall
training: 0.340 0.855 0.864 0.9034907597535934
validation: 0.503 0.750 0.776 0.8151260504201681
test: 0.439 0.777 0.776 0.8851913477537438
real: 1.237 0.278 0.206 0.956989247311828
********************************************************************************
epoch: 390
loss accuracy precision recall
training: 0.340 0.856 0.866 0.9034907597535934
validation: 0.509 0.710 0.756 0.7563025210084033
test: 0.438 0.777 0.776 0.8851913477537438
real: 1.248 0.284 0.207 0.956989247311828
********************************************************************************
epoch: 400
loss accuracy precision recall
training: 0.330 0.873 0.883 0.9117043121149897
validation: 0.505 0.720 0.760 0.773109243697479
test: 0.434 0.775 0.776 0.8785357737104825
real: 1.251 0.293 0.208 0.946236559139785
********************************************************************************
epoch: 410
loss accuracy precision recall
training: 0.325 0.875 0.885 0.9137577002053389
validation: 0.520 0.715 0.728 0.8319327731092437
test: 0.435 0.775 0.780 0.8718801996672213
real: 1.260 0.288 0.207 0.946236559139785
********************************************************************************
epoch: 420
loss accuracy precision recall
training: 0.308 0.876 0.883 0.917864476386037
validation: 0.495 0.750 0.767 0.8319327731092437
test: 0.432 0.771 0.778 0.8668885191347754
real: 1.277 0.284 0.206 0.946236559139785
********************************************************************************
epoch: 430
loss accuracy precision recall
training: 0.318 0.849 0.867 0.8870636550308009
validation: 0.501 0.735 0.754 0.8235294117647058
test: 0.431 0.778 0.780 0.8785357737104825
real: 1.292 0.290 0.206 0.9354838709677419
********************************************************************************
epoch: 440
loss accuracy precision recall
training: 0.291 0.891 0.913 0.9075975359342916
validation: 0.501 0.740 0.760 0.8235294117647058
test: 0.436 0.774 0.776 0.8768718801996672
real: 1.329 0.293 0.206 0.9354838709677419
********************************************************************************
epoch: 450
loss accuracy precision recall
training: 0.309 0.855 0.869 0.8973305954825462
validation: 0.521 0.695 0.734 0.7647058823529411
test: 0.432 0.774 0.781 0.8668885191347754
real: 1.334 0.295 0.207 0.9354838709677419
********************************************************************************
epoch: 460
loss accuracy precision recall
training: 0.287 0.889 0.904 0.9137577002053389
validation: 0.540 0.725 0.750 0.8067226890756303
test: 0.429 0.771 0.784 0.8552412645590682
real: 1.324 0.297 0.206 0.9247311827956989
********************************************************************************
epoch: 470
loss accuracy precision recall
training: 0.288 0.874 0.894 0.8993839835728953
validation: 0.503 0.745 0.770 0.8151260504201681
test: 0.428 0.781 0.788 0.870216306156406
real: 1.311 0.303 0.207 0.9247311827956989
********************************************************************************
epoch: 480
loss accuracy precision recall
training: 0.260 0.892 0.900 0.9260780287474333
validation: 0.494 0.720 0.744 0.8067226890756303
test: 0.432 0.777 0.785 0.8668885191347754
real: 1.367 0.305 0.208 0.9247311827956989
********************************************************************************
epoch: 490
loss accuracy precision recall
training: 0.260 0.902 0.913 0.9281314168377823
validation: 0.528 0.745 0.770 0.8151260504201681
test: 0.428 0.777 0.792 0.8535773710482529
real: 1.377 0.315 0.210 0.9247311827956989
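The accuracy, precision, and recall columns above come from thresholding the sigmoid output at 0.5 and comparing the result against the labels. The sketch below computes the same three quantities from confusion counts without scikit-learn; the function name `binary_metrics` is hypothetical. As a usage example, it scores a classifier that calls every image positive on a set with 93 positives out of 482 images, which is the composition we infer for the real data from the 482-image count and the 19.3% positive rate printed earlier. The result matches the real-data rows for much of the run (accuracy 0.193, precision 0.193, recall 1.0), which suggests that during those epochs the model was labelling every real image as containing a type 48 container.

```python
def binary_metrics(y_true, y_score, threshold=0.5):
    """Return (accuracy, precision, recall) for scores thresholded into 0/1 predictions."""
    y_pred = [1 if score > threshold else 0 for score in y_score]
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    accuracy = (tp + tn) / len(y_true)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return accuracy, precision, recall

# 93 positive and 389 negative labels; every score is above the threshold.
real_truth = [1] * 93 + [0] * 389
real_scores = [0.9] * len(real_truth)
accuracy, precision, recall = binary_metrics(real_truth, real_scores)
print(f"{accuracy:.3f} {precision:.3f} {recall:.3f}")
```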
Let’s plot the results, so we can look for any broad trends …
[15]:
import toyplot
canvas = toyplot.Canvas(width=600, height=600)
axes = canvas.cartesian(grid=(2, 2, 0), xlabel="Epoch", ylabel="Loss")
training = axes.plot(epochs, metrics["training/loss"])
validation = axes.plot(epochs, metrics["validation/loss"])
test = axes.plot(epochs, metrics["test/loss"])
real = axes.plot(epochs, metrics["real/loss"])
axes = canvas.cartesian(grid=(2, 2, 1), xlabel="Epoch", ylabel="Accuracy")
training = axes.plot(epochs, metrics["training/accuracy"])
validation = axes.plot(epochs, metrics["validation/accuracy"])
test = axes.plot(epochs, metrics["test/accuracy"])
real = axes.plot(epochs, metrics["real/accuracy"])
axes = canvas.cartesian(grid=(2, 2, 2), xlabel="Epoch", ylabel="Precision")
training = axes.plot(epochs, metrics["training/precision"])
validation = axes.plot(epochs, metrics["validation/precision"])
test = axes.plot(epochs, metrics["test/precision"])
real = axes.plot(epochs, metrics["real/precision"])
axes = canvas.cartesian(grid=(2, 2, 3), xlabel="Epoch", ylabel="Recall")
training = axes.plot(epochs, metrics["training/recall"])
validation = axes.plot(epochs, metrics["validation/recall"])
test = axes.plot(epochs, metrics["test/recall"])
real = axes.plot(epochs, metrics["real/recall"])
title = "Training {} validation {} test {} and real {} metrics".format(
    training.markers[0],
    validation.markers[0],
    test.markers[0],
    real.markers[0],
)
canvas.text(300, 15, text=title, style={"font-size": "16px", "font-weight":"bold"});
Looking at the training, validation, and test results (using synthetic data), we see patterns that are typical of a successful model:
Training loss steadily decreases, while validation and test loss appear to be nearing a minimum. Under normal circumstances, we would stop training at this point to avoid overfitting.
Test accuracy steadily increases, while precision and recall eventually reach a good balance.
Overall, we can have confidence that the train synthetic, test synthetic use case is working the way we expect.
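The stopping rule mentioned above can be sketched as follows: track the best validation loss seen so far, and stop once it has failed to improve for a set number of epochs. The function name `early_stopping_epoch` and the `patience` parameter are illustrative assumptions; the notebook itself simply trains for a fixed 500 epochs.

```python
def early_stopping_epoch(validation_losses, patience=3):
    """Return the index of the epoch at which training would stop.

    Training stops once `patience` consecutive epochs fail to improve on the
    best validation loss; if that never happens, the last epoch is returned.
    """
    best_loss = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(validation_losses):
        if loss < best_loss:
            best_loss = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return len(validation_losses) - 1

toy_losses = [0.68, 0.65, 0.60, 0.61, 0.62, 0.63, 0.59]
stop = early_stopping_epoch(toy_losses, patience=3)
print(stop)
```

Because per-epoch validation loss is noisy, as the log above shows, a larger `patience` would probably be needed in practice to avoid stopping too early.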
For the results using real images, there are interesting patterns that we commonly find with the train synthetic, test real use case:
The loss on real images steadily increases (the opposite of what we would like) - this highlights the fact that the synthetic and real images are drawn from different distributions, even though the synthetic data is often mistaken for real data by humans.
In spite of this, and after an initial low-precision, high-recall configuration, we see gradual balancing of precision and recall, leading to steady improvements in accuracy. Although the accuracy is relatively low (around 32%), it is still better than the accuracy of a classifier that labels every image as containing a type 48 container (around 19%; the proportion of real images that contain type 48 containers).
Furthermore, some of the poor precision can be traced to our decision to train only on type 48 containers; since the real data contains both 30B and type 48 containers, it’s likely that the model is confusing 30B containers for type 48 containers, something we can address by training using examples of each.
Even though our validation loss seems to be levelling out by epoch 500, the metrics for real-world data are still improving at that point, suggesting that further training could be useful.