from typing import Optional, Callable
import torch.nn as nn
from torch import Tensor
from torch.ao.nn.quantized import FloatFunctional
from torchvision.models.resnet import conv1x1, conv3x3
class BottleneckQuantizeable(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu3 = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.ff_add = FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.ff_add.add(out, identity)
        out = self.relu3(out)

        return out
Model quantization is a powerful tool for reducing memory and compute by using lower-precision data types. Although I’m familiar with the concept and a few high-level APIs, I hadn’t yet implemented it in PyTorch and wanted a deep dive on something non-trivial. After reading about PyTorch’s different quantization modes (eager vs. FX graph) and methods (dynamic, post-training static, and quantization aware training), I decided to start with eager mode quantization aware training (QAT). At the time of writing, eager mode quantization was the more mature feature, so this seemed like a natural place to start.
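To keep those methods straight, here is a rough sketch of how the three approaches differ in eager mode on a toy module. The API calls below reflect my understanding of the torch.ao.quantization workflow and are illustrative rather than code from this project.

import torch
import torch.nn as nn
from torch.ao.quantization import (
    QuantStub, DeQuantStub, quantize_dynamic,
    get_default_qconfig, get_default_qat_qconfig,
    prepare, prepare_qat, convert,
)

toy = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

# dynamic: weights quantized ahead of time, activations quantized on the fly
dynamic_model = quantize_dynamic(toy, {nn.Linear}, dtype=torch.qint8)

# post-training static: insert observers, calibrate on sample data, then convert
ptq_model = nn.Sequential(QuantStub(), toy, DeQuantStub())
ptq_model.qconfig = get_default_qconfig('fbgemm')
ptq_model = prepare(ptq_model.eval())
ptq_model(torch.rand(8, 16))  # calibration pass(es)
ptq_model = convert(ptq_model)

# quantization aware training: train with fake-quant observers, then convert
qat_model = nn.Sequential(QuantStub(), toy, DeQuantStub())
qat_model.qconfig = get_default_qat_qconfig('fbgemm')
qat_model = prepare_qat(qat_model.train())
# ... training loop goes here ...
qat_model = convert(qat_model.eval())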
In my research for a project, I came across multiple discussions online requesting either help or a tutorial for quantizing the backbone of an object detection model (Faster R-CNN in this case). As far as I could tell, there was nothing available, so this was the perfect excuse.
So in this post, I will go through that process of QAT while quantizing the backbone of Faster R-CNN and follow up with an analysis of the benefits.
ResNet and Feature Pyramid Network
This assumes some familiarity with the R-CNN architecture, but as a refresher, the feature extraction backbone consists of two components: the ResNet and the feature pyramid network (FPN). The FPN combines outputs from consecutive layers of the ResNet (via upsampling), which allows it to extract semantic information at higher resolutions. These two components of the backbone can be quantized while the rest of the network still uses floating point precision.
From an implementation standpoint, there is a utility class IntermediateLayerGetter for extracting each layer output (no fully connected layer) from the ResNet, and another convenience class for the FPN which takes those layer outputs as input. Combining these two is BackboneWithFPN, which is mostly just a thin wrapper around both.
As we are doing eager mode static quantization, we’ll need to prepare the model before we can train and subsequently quantize it.
Model Preparation
The first step in preparing the network for quantization is to create a modified bottleneck block. This isn’t obvious until you try to quantize the ResNet without it, at which point you will get an error along the lines of .. out += identity .. Could not run 'aten::add.out' .. which means that PyTorch isn’t able to quantize the skip connection that uses the += operator in eager mode. This discussion on the pytorch forums was helpful for describing the error as well as how to fix it. The modified bottleneck block just uses FloatFunctional, which has a quantized addition operator. I’m using ResNet 101 here, but for much smaller networks you would want to modify the basic block instead. Also, the original bottleneck class reuses the ReLU layer, which won’t work when fusing; finding this blog post about quantizing ResNet was helpful for recognizing and avoiding that pitfall.
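To isolate just that change, here is a minimal toy residual block (not part of the actual backbone; the module and shapes are made up for illustration) showing the += line that fails and the FloatFunctional replacement:

import torch
import torch.nn as nn
from torch.ao.nn.quantized import FloatFunctional

class TinyResidual(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(8, 8, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.ff_add = FloatFunctional()  # quantizeable stand-in for '+'

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv(x)
        # out += x  # fails under eager mode quantization ('aten::add.out')
        out = self.ff_add.add(out, x)  # quantized add is supported
        return self.relu(out)

TinyResidual()(torch.rand(1, 8, 16, 16)).shape  # torch.Size([1, 8, 16, 16])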
Now that we have a quantizeable bottleneck, we can simply reference it when generating the ResNet. Even though the float functional operator was added, we can still load pretrained ImageNet weights since the trainable submodules didn’t change. Note that the number of classes for the ResNet doesn’t matter here because we will extract intermediate layers and ignore the final fully connected layer.
from torchvision.models.resnet import ResNet, ResNet50_Weights, ResNet101_Weights
def resnet_101():
    resnet = ResNet(block=BottleneckQuantizeable, layers=[3, 4, 23, 3])
    resnet.load_state_dict(ResNet101_Weights.DEFAULT.get_state_dict(progress=True))
    return resnet

resnet = resnet_101()
The next step is to pass the resnet to the IntermediateLayerGetter. In addition to the resnet we created, this class also requires a dictionary of the layer names (to know what to extract). It returns an OrderedDict of those layer outputs. Here’s an example using a toy image.
import torch
from torchvision.models._utils import IntermediateLayerGetter
returned_layers = [1, 2, 3, 4]  # get all 4 layers
return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)}  # {'layer1': '0', 'layer2': '1', ...}

resnet_layers = IntermediateLayerGetter(resnet, return_layers=return_layers)

out = resnet_layers(torch.rand(1, 3, 200, 200))  # e.g. 200 x 200 image with 3 channels
[(k, v.shape) for k, v in out.items()]
[('0', torch.Size([1, 256, 50, 50])),
('1', torch.Size([1, 512, 25, 25])),
('2', torch.Size([1, 1024, 13, 13])),
('3', torch.Size([1, 2048, 7, 7]))]
Again, the output of the resnet layers will be fed to the feature pyramid network. Before we can do that, the FPN also needs to be modified as it uses a + addition operator. Note that there is also a functional F.interpolate, but that doesn’t actually need to be replaced. However, it does need to be referenced differently, as importing torch.nn.functional as F causes a namespace issue later with torchvision.
from collections import OrderedDict
from typing import List, Dict
# importing as 'F' causes namespace collision with torchvision and QAT fails later
# import torch.nn.functional as F
import torch
from torchvision.ops.misc import Conv2dNormActivation
from torchvision.utils import _log_api_usage_once
from torchvision.ops.feature_pyramid_network import ExtraFPNBlock
class FeaturePyramidNetworkQuantizeable(nn.Module):
    """
    Module that adds a FPN from on top of a set of feature maps. This is based on
    `"Feature Pyramid Network for Object Detection" <https://arxiv.org/abs/1612.03144>`_.

    The feature maps are currently supposed to be in increasing depth
    order.

    The input to the model is expected to be an OrderedDict[Tensor], containing
    the feature maps on top of which the FPN will be added.

    Args:
        in_channels_list (list[int]): number of channels for each feature map that
            is passed to the module
        out_channels (int): number of channels of the FPN representation
        extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
            be performed. It is expected to take the fpn features, the original
            features and the names of the original features as input, and returns
            a new list of feature maps and their corresponding names
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None

    Examples::

        >>> m = torchvision.ops.FeaturePyramidNetwork([10, 20, 30], 5)
        >>> # get some dummy data
        >>> x = OrderedDict()
        >>> x['feat0'] = torch.rand(1, 10, 64, 64)
        >>> x['feat2'] = torch.rand(1, 20, 16, 16)
        >>> x['feat3'] = torch.rand(1, 30, 8, 8)
        >>> # compute the FPN on top of x
        >>> output = m(x)
        >>> print([(k, v.shape) for k, v in output.items()])
        >>> # returns
        >>>   [('feat0', torch.Size([1, 5, 64, 64])),
        >>>    ('feat2', torch.Size([1, 5, 16, 16])),
        >>>    ('feat3', torch.Size([1, 5, 8, 8]))]

    """

    _version = 2

    def __init__(
        self,
        in_channels_list: List[int],
        out_channels: int,
        extra_blocks: Optional[ExtraFPNBlock] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        super().__init__()
        _log_api_usage_once(self)
        self.inner_blocks = nn.ModuleList()
        self.layer_blocks = nn.ModuleList()
        for in_channels in in_channels_list:
            if in_channels == 0:
                raise ValueError("in_channels=0 is currently not supported")
            inner_block_module = Conv2dNormActivation(
                in_channels, out_channels, kernel_size=1, padding=0, norm_layer=norm_layer, activation_layer=None
            )
            layer_block_module = Conv2dNormActivation(
                out_channels, out_channels, kernel_size=3, norm_layer=norm_layer, activation_layer=None
            )
            self.inner_blocks.append(inner_block_module)
            self.layer_blocks.append(layer_block_module)

        # initialize parameters now to avoid modifying the initialization of top_blocks
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        if extra_blocks is not None:
            if not isinstance(extra_blocks, ExtraFPNBlock):
                raise TypeError(f"extra_blocks should be of type ExtraFPNBlock not {type(extra_blocks)}")
        self.extra_blocks = extra_blocks
        self.ff_add = FloatFunctional()

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            num_blocks = len(self.inner_blocks)
            for block in ["inner_blocks", "layer_blocks"]:
                for i in range(num_blocks):
                    for type in ["weight", "bias"]:
                        old_key = f"{prefix}{block}.{i}.{type}"
                        new_key = f"{prefix}{block}.{i}.0.{type}"
                        if old_key in state_dict:
                            state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor:
        """
        This is equivalent to self.inner_blocks[idx](x),
        but torchscript doesn't support this yet
        """
        num_blocks = len(self.inner_blocks)
        if idx < 0:
            idx += num_blocks
        out = x
        for i, module in enumerate(self.inner_blocks):
            if i == idx:
                out = module(x)
        return out

    def get_result_from_layer_blocks(self, x: Tensor, idx: int) -> Tensor:
        """
        This is equivalent to self.layer_blocks[idx](x),
        but torchscript doesn't support this yet
        """
        num_blocks = len(self.layer_blocks)
        if idx < 0:
            idx += num_blocks
        out = x
        for i, module in enumerate(self.layer_blocks):
            if i == idx:
                out = module(x)
        return out

    def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """
        Computes the FPN for a set of feature maps.

        Args:
            x (OrderedDict[Tensor]): feature maps for each feature level.

        Returns:
            results (OrderedDict[Tensor]): feature maps after FPN layers.
                They are ordered from the highest resolution first.
        """
        # unpack OrderedDict into two lists for easier handling
        names = list(x.keys())
        x = list(x.values())

        last_inner = self.get_result_from_inner_blocks(x[-1], -1)
        results = []
        results.append(self.get_result_from_layer_blocks(last_inner, -1))

        for idx in range(len(x) - 2, -1, -1):
            inner_lateral = self.get_result_from_inner_blocks(x[idx], idx)
            feat_shape = inner_lateral.shape[-2:]
            inner_top_down = torch.nn.functional.interpolate(last_inner, size=feat_shape, mode="nearest")
            # last_inner = inner_lateral + inner_top_down
            last_inner = self.ff_add.add(inner_lateral, inner_top_down)
            results.insert(0, self.get_result_from_layer_blocks(last_inner, idx))

        if self.extra_blocks is not None:
            results, names = self.extra_blocks(results, x, names)

        # make it back an OrderedDict
        out = OrderedDict([(k, v) for k, v in zip(names, results)])

        return out
As you can see from the signature of the modified FPN, it also needs to know the input dimensions of each layer from the resnet. There are several ways to get this, but one is to simply read the number of features from the final module of each layer.
# from backbone_utils.py
# https://github.com/pytorch/vision/blob/main/torchvision/models/detection/backbone_utils.py#L145
# in_channels_stage2 = res101_layers.inplanes // 8
# in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
in_channels_list = []
for k1, m1 in resnet.named_children():
    if 'layer' in k1:
        in_channels_list.append(m1[-1].bn3.num_features)

in_channels_list
[256, 512, 1024, 2048]
The next step is to create a modified BackboneWithFPN that uses our FeaturePyramidNetworkQuantizeable. Here we’ll also make sure that the inputs are quantized and the outputs subsequently dequantized so that they can be fed to the rest of the R-CNN.
One important note is that regular BatchNorm2d is used here as the default normalization layer instead of FrozenBatchNorm2d. Frozen batch norm is the recommended layer because batches are generally too small for good estimates of the mean and variance statistics, but that module isn’t quantizeable. So using regular batch norm could be unstable and less performant if those layers aren’t frozen before training.
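If you do want to guard against that, one option is to freeze the regular batch norm layers yourself before training. The helper below is only a sketch of the idea and isn’t used in the rest of this post, which instead relies on freeze_bn_stats after prepare_qat.

import torch.nn as nn

def freeze_batchnorm(module: nn.Module) -> None:
    # sketch: stop gradient updates to the affine parameters and put the layers
    # in eval mode so the running statistics stop updating. Note that a later
    # call to model.train() flips them back, so this would need re-applying.
    for m in module.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()
            for p in m.parameters():
                p.requires_grad_(False)

# hypothetical usage on the backbone constructed below:
# freeze_batchnorm(bbfpn.body)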
from torchvision.ops.feature_pyramid_network import LastLevelMaxPool
from torchvision.models._utils import IntermediateLayerGetter
from torch.ao.quantization import QuantStub, DeQuantStub
class BackboneWithFPNQuantizeable(nn.Module):
    def __init__(
        self,
        backbone: nn.Module,
        return_layers: Dict[str, str],
        in_channels_list: List[int],
        out_channels: int,
        extra_blocks: Optional[ExtraFPNBlock] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

        if extra_blocks is None:
            extra_blocks = LastLevelMaxPool()

        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.fpn = FeaturePyramidNetworkQuantizeable(
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            extra_blocks=extra_blocks,
            norm_layer=norm_layer
        )
        self.out_channels = out_channels

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        x = self.quant(x)
        x = self.body(x)
        x = self.fpn(x)
        for k, v in x.items():
            x[k] = self.dequant(v)
        return x
Now we can create the modified backbone with FPN. Once created, there should be quant/dequant stubs visible in the network like so
BackboneWithFPNQuantizeable(
(quant): QuantStub()
(dequant): DeQuantStub()
(body): IntermediateLayerGetter(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BottleneckQuantizeable(
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
...
# resnet = resnet_101()
# returned_layers = [1, 2, 3, 4]
# return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)}
# in_channels_list = []
# for k1, m1 in resnet.named_children():
# if 'layer' in k1:
# in_channels_list.append((m1[-1].bn3.num_features))
bbfpn = BackboneWithFPNQuantizeable(
    backbone=resnet,
    return_layers=return_layers,
    in_channels_list=in_channels_list,
    out_channels=256,
    extra_blocks=None,
    norm_layer=None,
)
# bbfpn
The last step is to plug in the modified backbone with FPN when creating the Faster R-CNN. Note that the number of classes is set to 2 (object or background) which is specific to the dataset used.
from torchvision.models.detection.faster_rcnn import FasterRCNN
quant_rcnn = FasterRCNN(bbfpn, num_classes=2)
Layer Fusion and Quantization Config
Before training and subsequently converting the model, we can fuse specific sequences of modules in the backbone. Fusing compresses the model, making it smaller and faster, by combining sequences of modules like Conv2d-BatchNorm2d-ReLU and Conv2d-BatchNorm2d. After fusing, you should see new fused modules in the network like ConvReLU2d, as well as Identity modules where the previous modules were.
FasterRCNN(
(transform): GeneralizedRCNNTransform(
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
Resize(min_size=(800,), max_size=1333, mode='bilinear')
)
(backbone): BackboneWithFPNQuantizeable(
(body): IntermediateLayerGetter(
(conv1): ConvReLU2d(
(0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
(1): ReLU(inplace=True)
)
(bn1): Identity()
(relu): Identity()
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BottleneckQuantizeable(
(conv1): ConvReLU2d(
(0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(1): ReLU(inplace=True)
)
(bn1): Identity()
(relu1): Identity()
...
from torch.ao.quantization import fuse_modules
quant_rcnn.eval()

# fuse stem
fuse_modules(quant_rcnn.backbone.body, [['conv1', 'bn1', 'relu']], inplace=True)

# fuse blocks
for k1, m1 in quant_rcnn.backbone.body.named_children():
    if "layer" in k1:  # in sequential layer with blocks
        for k2, m2 in m1.named_children():
            fuse_modules(m2, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], inplace=True)
            for k3, m3 in m2.named_children():
                if "downsample" in k3:  # fuse downsample
                    fuse_modules(m3, [["0", "1"]], inplace=True)
Before training, the quantization config needs to be set on the backbone only. And again, because the batches are so small, batch norm gets frozen (see this PyTorch tutorial for another example). Lastly, I’ll freeze the stem and the first layer in the backbone since the pretrained ImageNet weights were loaded. After preparation, you should be able to see the observers in the network.
FasterRCNN(
(transform): GeneralizedRCNNTransform(
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
Resize(min_size=(800,), max_size=1333, mode='bilinear')
)
(backbone): BackboneWithFPNQuantizeable(
(body): IntermediateLayerGetter(
(conv1): ConvReLU2d(
3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)
(weight_fake_quant): FusedMovingAvgObsFakeQuantize(
fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.qint8, quant_min=-128, quant_max=127, qscheme=torch.per_channel_symmetric, reduce_range=False
(activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
)
(activation_post_process): FusedMovingAvgObsFakeQuantize(
fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
(activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
)
)
(bn1): Identity()
...
import re
import torch
from torch.ao.quantization import get_default_qat_qconfig, prepare_qat
quant_rcnn.train()
quant_rcnn.backbone.qconfig = get_default_qat_qconfig('fbgemm')
quant_rcnn_prepared = prepare_qat(quant_rcnn, inplace=False)

quant_rcnn_prepared = quant_rcnn_prepared.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

quant_rcnn_prepared.backbone.body.conv1.weight.requires_grad = False
for name, parameter in quant_rcnn_prepared.backbone.named_parameters():
    if re.search(r".layer1", name):
        parameter.requires_grad = False
Dataset, Training, and Conversion
I’ll be using the PennFudan dataset from the Torchvision object detection finetuning tutorial for QAT. Most of the code below is borrowed from that tutorial with slight modifications and no segmentation.
import os
import torch
from torchvision.io import read_image
from torchvision.ops.boxes import masks_to_boxes
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F # careful namespace 'F'
from torchvision.transforms import v2 as T
class PennFudanDataset(torch.utils.data.Dataset):
    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        # load all image files, sorting them to
        # ensure that they are aligned
        self.imgs = list(sorted(os.listdir(os.path.join(root, "PNGImages"))))
        self.masks = list(sorted(os.listdir(os.path.join(root, "PedMasks"))))

    def __getitem__(self, idx):
        # load images and masks
        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
        img = read_image(img_path)
        mask = read_image(mask_path)
        # instances are encoded as different colors
        obj_ids = torch.unique(mask)
        # first id is the background, so remove it
        obj_ids = obj_ids[1:]
        num_objs = len(obj_ids)

        # split the color-encoded mask into a set
        # of binary masks
        masks = (mask == obj_ids[:, None, None]).to(dtype=torch.uint8)

        # get bounding box coordinates for each mask
        boxes = masks_to_boxes(masks)

        # there is only one class
        labels = torch.ones((num_objs,), dtype=torch.int64)

        image_id = idx
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        # Wrap sample and targets into torchvision tv_tensors:
        img = tv_tensors.Image(img)

        target = {}
        target["boxes"] = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=F.get_size(img))
        target["masks"] = tv_tensors.Mask(masks)
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.imgs)


def get_transform(train):
    transforms = []
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))
    transforms.append(T.ToDtype(torch.float, scale=True))
    transforms.append(T.ToPureTensor())
    return T.Compose(transforms)
%%capture
"wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/engine.py")
os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/utils.py")
os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/coco_utils.py")
os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/coco_eval.py")
os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/transforms.py")
os.system(
!wget https://www.cis.upenn.edu/~jshi/ped_html/PennFudanPed.zip
!unzip PennFudanPed.zip -d ./
import utils
from engine import train_one_epoch, evaluate
# train on the GPU or on the CPU, if a GPU is not available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# our dataset has two classes only - background and person
num_classes = 2

# use our dataset and defined transformations
dataset = PennFudanDataset('PennFudanPed', get_transform(train=True))
dataset_test = PennFudanDataset('PennFudanPed', get_transform(train=False))

# split the dataset in train and test set
indices = torch.randperm(len(dataset)).tolist()
dataset = torch.utils.data.Subset(dataset, indices[:-50])
dataset_test = torch.utils.data.Subset(dataset_test, indices[-50:])

# define training and validation data loaders
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    num_workers=1,
    collate_fn=utils.collate_fn
)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=1,
    shuffle=False,
    num_workers=1,
    collate_fn=utils.collate_fn
)
# move model to the right device
quant_rcnn_prepared.to(device)
# construct an optimizer
params = [p for p in quant_rcnn_prepared.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005
)

# and a learning rate scheduler
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1
)

# let's train it for 10 epochs
num_epochs = 10

for epoch in range(num_epochs):
    # train for one epoch, printing every 20 iterations
    train_one_epoch(quant_rcnn_prepared, optimizer, data_loader, device, epoch, print_freq=20)
    # update the learning rate
    lr_scheduler.step()
    # evaluate on the test dataset
    evaluate(quant_rcnn_prepared, data_loader_test, device=device)
Epoch: [0] [ 0/60] eta: 0:00:49 lr: 0.000090 loss: 1.5364 (1.5364) loss_classifier: 0.8895 (0.8895) loss_box_reg: 0.0007 (0.0007) loss_objectness: 0.6395 (0.6395) loss_rpn_box_reg: 0.0069 (0.0069) time: 0.8299 data: 0.1936 max mem: 6367
Epoch: [0] [20/60] eta: 0:00:25 lr: 0.001783 loss: 0.6043 (0.7210) loss_classifier: 0.1105 (0.2343) loss_box_reg: 0.0665 (0.0648) loss_objectness: 0.3679 (0.3956) loss_rpn_box_reg: 0.0270 (0.0263) time: 0.6163 data: 0.0064 max mem: 6367
Epoch: [0] [40/60] eta: 0:00:12 lr: 0.003476 loss: 0.3219 (0.5298) loss_classifier: 0.0957 (0.1841) loss_box_reg: 0.1185 (0.0904) loss_objectness: 0.0582 (0.2321) loss_rpn_box_reg: 0.0194 (0.0233) time: 0.6057 data: 0.0056 max mem: 6367
Epoch: [0] [59/60] eta: 0:00:00 lr: 0.005000 loss: 0.2767 (0.4805) loss_classifier: 0.0943 (0.1721) loss_box_reg: 0.1098 (0.0993) loss_objectness: 0.0459 (0.1849) loss_rpn_box_reg: 0.0160 (0.0242) time: 0.6120 data: 0.0057 max mem: 6367
Epoch: [0] Total time: 0:00:37 (0.6183 s / it)
creating index...
index created!
Test: [ 0/50] eta: 0:00:23 model_time: 0.2870 (0.2870) evaluator_time: 0.0120 (0.0120) time: 0.4753 data: 0.1745 max mem: 6367
Test: [49/50] eta: 0:00:00 model_time: 0.2425 (0.2465) evaluator_time: 0.0037 (0.0049) time: 0.2517 data: 0.0030 max mem: 6367
Test: Total time: 0:00:13 (0.2624 s / it)
Averaged stats: model_time: 0.2425 (0.2465) evaluator_time: 0.0037 (0.0049)
Accumulating evaluation results...
DONE (t=0.04s).
IoU metric: bbox
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.111
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.432
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.013
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.002
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.119
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.074
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.238
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.299
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.013
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.322
Epoch: [1] [ 0/60] eta: 0:00:50 lr: 0.005000 loss: 0.2406 (0.2406) loss_classifier: 0.0861 (0.0861) loss_box_reg: 0.1205 (0.1205) loss_objectness: 0.0287 (0.0287) loss_rpn_box_reg: 0.0053 (0.0053) time: 0.8447 data: 0.1997 max mem: 6367
Epoch: [1] [20/60] eta: 0:00:25 lr: 0.005000 loss: 0.2323 (0.2870) loss_classifier: 0.0756 (0.1012) loss_box_reg: 0.1083 (0.1261) loss_objectness: 0.0416 (0.0443) loss_rpn_box_reg: 0.0130 (0.0154) time: 0.6140 data: 0.0056 max mem: 6367
Epoch: [1] [40/60] eta: 0:00:12 lr: 0.005000 loss: 0.3195 (0.3082) loss_classifier: 0.1031 (0.1040) loss_box_reg: 0.1658 (0.1469) loss_objectness: 0.0288 (0.0396) loss_rpn_box_reg: 0.0207 (0.0177) time: 0.6341 data: 0.0056 max mem: 6983
Epoch: [1] [59/60] eta: 0:00:00 lr: 0.005000 loss: 0.2555 (0.2919) loss_classifier: 0.0822 (0.0958) loss_box_reg: 0.1283 (0.1443) loss_objectness: 0.0088 (0.0319) loss_rpn_box_reg: 0.0179 (0.0199) time: 0.6182 data: 0.0055 max mem: 6983
Epoch: [1] Total time: 0:00:37 (0.6288 s / it)
creating index...
index created!
Test: [ 0/50] eta: 0:00:24 model_time: 0.3046 (0.3046) evaluator_time: 0.0079 (0.0079) time: 0.4994 data: 0.1854 max mem: 6983
Test: [49/50] eta: 0:00:00 model_time: 0.2479 (0.2522) evaluator_time: 0.0022 (0.0028) time: 0.2579 data: 0.0033 max mem: 6983
Test: Total time: 0:00:13 (0.2663 s / it)
Averaged stats: model_time: 0.2479 (0.2522) evaluator_time: 0.0022 (0.0028)
Accumulating evaluation results...
DONE (t=0.02s).
IoU metric: bbox
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.309
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.831
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.139
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.172
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.321
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.190
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.413
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.437
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.300
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.448
Epoch: [2] [ 0/60] eta: 0:00:51 lr: 0.005000 loss: 0.2004 (0.2004) loss_classifier: 0.0508 (0.0508) loss_box_reg: 0.1245 (0.1245) loss_objectness: 0.0078 (0.0078) loss_rpn_box_reg: 0.0173 (0.0173) time: 0.8501 data: 0.1844 max mem: 6983
Epoch: [2] [20/60] eta: 0:00:26 lr: 0.005000 loss: 0.2482 (0.2507) loss_classifier: 0.0578 (0.0676) loss_box_reg: 0.1521 (0.1533) loss_objectness: 0.0069 (0.0085) loss_rpn_box_reg: 0.0191 (0.0213) time: 0.6400 data: 0.0056 max mem: 6983
Epoch: [2] [40/60] eta: 0:00:12 lr: 0.005000 loss: 0.1892 (0.2265) loss_classifier: 0.0588 (0.0633) loss_box_reg: 0.1038 (0.1351) loss_objectness: 0.0061 (0.0092) loss_rpn_box_reg: 0.0143 (0.0189) time: 0.6334 data: 0.0054 max mem: 6983
Epoch: [2] [59/60] eta: 0:00:00 lr: 0.005000 loss: 0.1695 (0.2197) loss_classifier: 0.0545 (0.0631) loss_box_reg: 0.0872 (0.1279) loss_objectness: 0.0054 (0.0102) loss_rpn_box_reg: 0.0138 (0.0186) time: 0.6422 data: 0.0060 max mem: 6983
Epoch: [2] Total time: 0:00:38 (0.6448 s / it)
creating index...
index created!
Test: [ 0/50] eta: 0:00:23 model_time: 0.2871 (0.2871) evaluator_time: 0.0039 (0.0039) time: 0.4719 data: 0.1796 max mem: 6983
Test: [49/50] eta: 0:00:00 model_time: 0.2536 (0.2540) evaluator_time: 0.0017 (0.0021) time: 0.2615 data: 0.0031 max mem: 6983
Test: Total time: 0:00:13 (0.2672 s / it)
Averaged stats: model_time: 0.2536 (0.2540) evaluator_time: 0.0017 (0.0021)
Accumulating evaluation results...
DONE (t=0.02s).
IoU metric: bbox
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.482
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.949
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.386
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.324
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.496
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.272
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.564
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.572
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.412
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.584
Epoch: [3] [ 0/60] eta: 0:00:53 lr: 0.000500 loss: 0.1430 (0.1430) loss_classifier: 0.0370 (0.0370) loss_box_reg: 0.0875 (0.0875) loss_objectness: 0.0089 (0.0089) loss_rpn_box_reg: 0.0096 (0.0096) time: 0.8995 data: 0.2071 max mem: 6983
Epoch: [3] [20/60] eta: 0:00:26 lr: 0.000500 loss: 0.1757 (0.1837) loss_classifier: 0.0432 (0.0492) loss_box_reg: 0.1194 (0.1130) loss_objectness: 0.0066 (0.0068) loss_rpn_box_reg: 0.0141 (0.0147) time: 0.6443 data: 0.0057 max mem: 6983
Epoch: [3] [40/60] eta: 0:00:13 lr: 0.000500 loss: 0.1488 (0.1730) loss_classifier: 0.0427 (0.0477) loss_box_reg: 0.0932 (0.1054) loss_objectness: 0.0046 (0.0062) loss_rpn_box_reg: 0.0106 (0.0136) time: 0.6576 data: 0.0057 max mem: 6983
Epoch: [3] [59/60] eta: 0:00:00 lr: 0.000500 loss: 0.1673 (0.1781) loss_classifier: 0.0469 (0.0508) loss_box_reg: 0.0972 (0.1080) loss_objectness: 0.0045 (0.0059) loss_rpn_box_reg: 0.0096 (0.0134) time: 0.6530 data: 0.0053 max mem: 6983
Epoch: [3] Total time: 0:00:39 (0.6593 s / it)
creating index...
index created!
Test: [ 0/50] eta: 0:00:23 model_time: 0.2835 (0.2835) evaluator_time: 0.0039 (0.0039) time: 0.4670 data: 0.1781 max mem: 6983
Test: [49/50] eta: 0:00:00 model_time: 0.2530 (0.2555) evaluator_time: 0.0013 (0.0018) time: 0.2578 data: 0.0032 max mem: 6983
Test: Total time: 0:00:13 (0.2685 s / it)
Averaged stats: model_time: 0.2530 (0.2555) evaluator_time: 0.0013 (0.0018)
Accumulating evaluation results...
DONE (t=0.02s).
IoU metric: bbox
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.493
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.942
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.419
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.362
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.504
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.285
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.571
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.574
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.425
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.585
Epoch: [4] [ 0/60] eta: 0:00:54 lr: 0.000500 loss: 0.1610 (0.1610) loss_classifier: 0.0388 (0.0388) loss_box_reg: 0.1063 (0.1063) loss_objectness: 0.0052 (0.0052) loss_rpn_box_reg: 0.0107 (0.0107) time: 0.9067 data: 0.2171 max mem: 6983
Epoch: [4] [20/60] eta: 0:00:27 lr: 0.000500 loss: 0.1496 (0.1722) loss_classifier: 0.0469 (0.0516) loss_box_reg: 0.0883 (0.1041) loss_objectness: 0.0025 (0.0046) loss_rpn_box_reg: 0.0117 (0.0120) time: 0.6716 data: 0.0062 max mem: 6983
Epoch: [4] [40/60] eta: 0:00:13 lr: 0.000500 loss: 0.1522 (0.1744) loss_classifier: 0.0459 (0.0516) loss_box_reg: 0.1082 (0.1082) loss_objectness: 0.0038 (0.0043) loss_rpn_box_reg: 0.0068 (0.0103) time: 0.6789 data: 0.0058 max mem: 6983
Epoch: [4] [59/60] eta: 0:00:00 lr: 0.000500 loss: 0.1792 (0.1787) loss_classifier: 0.0513 (0.0527) loss_box_reg: 0.1076 (0.1115) loss_objectness: 0.0029 (0.0041) loss_rpn_box_reg: 0.0100 (0.0105) time: 0.6734 data: 0.0055 max mem: 6983
Epoch: [4] Total time: 0:00:40 (0.6817 s / it)
creating index...
index created!
Test: [ 0/50] eta: 0:00:23 model_time: 0.2894 (0.2894) evaluator_time: 0.0040 (0.0040) time: 0.4672 data: 0.1722 max mem: 6983
Test: [49/50] eta: 0:00:00 model_time: 0.2520 (0.2579) evaluator_time: 0.0014 (0.0018) time: 0.2610 data: 0.0030 max mem: 6983
Test: Total time: 0:00:13 (0.2706 s / it)
Averaged stats: model_time: 0.2520 (0.2579) evaluator_time: 0.0014 (0.0018)
Accumulating evaluation results...
DONE (t=0.02s).
IoU metric: bbox
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.528
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.952
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.546
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.338
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.542
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.293
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.605
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.615
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.463
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.626
Epoch: [5] [ 0/60] eta: 0:00:52 lr: 0.000500 loss: 0.1348 (0.1348) loss_classifier: 0.0358 (0.0358) loss_box_reg: 0.0862 (0.0862) loss_objectness: 0.0018 (0.0018) loss_rpn_box_reg: 0.0110 (0.0110) time: 0.8712 data: 0.1820 max mem: 6983
Epoch: [5] [20/60] eta: 0:00:27 lr: 0.000500 loss: 0.1626 (0.1758) loss_classifier: 0.0444 (0.0485) loss_box_reg: 0.0996 (0.1133) loss_objectness: 0.0029 (0.0039) loss_rpn_box_reg: 0.0098 (0.0101) time: 0.6826 data: 0.0058 max mem: 6983
Epoch: [5] [40/60] eta: 0:00:13 lr: 0.000500 loss: 0.1343 (0.1651) loss_classifier: 0.0395 (0.0468) loss_box_reg: 0.0837 (0.1052) loss_objectness: 0.0023 (0.0037) loss_rpn_box_reg: 0.0089 (0.0095) time: 0.6910 data: 0.0059 max mem: 7287
Epoch: [5] [59/60] eta: 0:00:00 lr: 0.000500 loss: 0.1644 (0.1699) loss_classifier: 0.0499 (0.0484) loss_box_reg: 0.0954 (0.1076) loss_objectness: 0.0026 (0.0038) loss_rpn_box_reg: 0.0098 (0.0100) time: 0.6861 data: 0.0062 max mem: 7287
Epoch: [5] Total time: 0:00:41 (0.6929 s / it)
creating index...
index created!
Test: [ 0/50] eta: 0:00:24 model_time: 0.3069 (0.3069) evaluator_time: 0.0041 (0.0041) time: 0.4943 data: 0.1818 max mem: 7287
Test: [49/50] eta: 0:00:00 model_time: 0.2673 (0.2650) evaluator_time: 0.0014 (0.0018) time: 0.2721 data: 0.0030 max mem: 7287
Test: Total time: 0:00:13 (0.2780 s / it)
Averaged stats: model_time: 0.2673 (0.2650) evaluator_time: 0.0014 (0.0018)
Accumulating evaluation results...
DONE (t=0.02s).
IoU metric: bbox
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.560
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.957
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.602
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.451
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.570
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.311
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.644
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.647
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.550
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.655
Epoch: [6] [ 0/60] eta: 0:00:55 lr: 0.000050 loss: 0.1775 (0.1775) loss_classifier: 0.0405 (0.0405) loss_box_reg: 0.1216 (0.1216) loss_objectness: 0.0029 (0.0029) loss_rpn_box_reg: 0.0125 (0.0125) time: 0.9273 data: 0.2030 max mem: 7287
Epoch: [6] [20/60] eta: 0:00:28 lr: 0.000050 loss: 0.1606 (0.1758) loss_classifier: 0.0500 (0.0530) loss_box_reg: 0.0945 (0.1087) loss_objectness: 0.0032 (0.0039) loss_rpn_box_reg: 0.0090 (0.0102) time: 0.6964 data: 0.0055 max mem: 7287
Epoch: [6] [40/60] eta: 0:00:14 lr: 0.000050 loss: 0.1137 (0.1597) loss_classifier: 0.0357 (0.0480) loss_box_reg: 0.0660 (0.0997) loss_objectness: 0.0018 (0.0034) loss_rpn_box_reg: 0.0057 (0.0086) time: 0.6934 data: 0.0057 max mem: 7287
Epoch: [6] [59/60] eta: 0:00:00 lr: 0.000050 loss: 0.1633 (0.1669) loss_classifier: 0.0410 (0.0493) loss_box_reg: 0.0954 (0.1050) loss_objectness: 0.0024 (0.0035) loss_rpn_box_reg: 0.0081 (0.0091) time: 0.6901 data: 0.0058 max mem: 7287
Epoch: [6] Total time: 0:00:41 (0.6998 s / it)
creating index...
index created!
Test: [ 0/50] eta: 0:00:24 model_time: 0.3050 (0.3050) evaluator_time: 0.0035 (0.0035) time: 0.4857 data: 0.1758 max mem: 7287
Test: [49/50] eta: 0:00:00 model_time: 0.2615 (0.2624) evaluator_time: 0.0013 (0.0018) time: 0.2673 data: 0.0031 max mem: 7287
Test: Total time: 0:00:13 (0.2757 s / it)
Averaged stats: model_time: 0.2615 (0.2624) evaluator_time: 0.0013 (0.0018)
Accumulating evaluation results...
DONE (t=0.02s).
IoU metric: bbox
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.537
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.948
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.534
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.438
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.547
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.294
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.628
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.642
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.550
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.649
Epoch: [7] [ 0/60] eta: 0:00:55 lr: 0.000050 loss: 0.0931 (0.0931) loss_classifier: 0.0252 (0.0252) loss_box_reg: 0.0566 (0.0566) loss_objectness: 0.0020 (0.0020) loss_rpn_box_reg: 0.0093 (0.0093) time: 0.9269 data: 0.2053 max mem: 7287
Epoch: [7] [20/60] eta: 0:00:28 lr: 0.000050 loss: 0.1182 (0.1405) loss_classifier: 0.0339 (0.0406) loss_box_reg: 0.0713 (0.0877) loss_objectness: 0.0017 (0.0043) loss_rpn_box_reg: 0.0049 (0.0079) time: 0.6926 data: 0.0066 max mem: 7462
Epoch: [7] [40/60] eta: 0:00:14 lr: 0.000050 loss: 0.1324 (0.1550) loss_classifier: 0.0367 (0.0454) loss_box_reg: 0.0773 (0.0970) loss_objectness: 0.0020 (0.0038) loss_rpn_box_reg: 0.0087 (0.0088) time: 0.6966 data: 0.0055 max mem: 7462
Epoch: [7] [59/60] eta: 0:00:00 lr: 0.000050 loss: 0.1809 (0.1661) loss_classifier: 0.0493 (0.0488) loss_box_reg: 0.1123 (0.1040) loss_objectness: 0.0038 (0.0042) loss_rpn_box_reg: 0.0103 (0.0092) time: 0.6901 data: 0.0055 max mem: 7462
Epoch: [7] Total time: 0:00:42 (0.7001 s / it)
creating index...
index created!
Test: [ 0/50] eta: 0:00:24 model_time: 0.2994 (0.2994) evaluator_time: 0.0039 (0.0039) time: 0.4838 data: 0.1787 max mem: 7462
Test: [49/50] eta: 0:00:00 model_time: 0.2554 (0.2593) evaluator_time: 0.0013 (0.0017) time: 0.2629 data: 0.0029 max mem: 7462
Test: Total time: 0:00:13 (0.2721 s / it)
Averaged stats: model_time: 0.2554 (0.2593) evaluator_time: 0.0013 (0.0017)
Accumulating evaluation results...
DONE (t=0.02s).
IoU metric: bbox
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.530
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.951
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.491
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.451
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.540
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.283
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.616
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.618
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.562
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.623
Epoch: [8] [ 0/60] eta: 0:00:58 lr: 0.000050 loss: 0.3004 (0.3004) loss_classifier: 0.0767 (0.0767) loss_box_reg: 0.2012 (0.2012) loss_objectness: 0.0051 (0.0051) loss_rpn_box_reg: 0.0173 (0.0173) time: 0.9718 data: 0.2255 max mem: 7462
Epoch: [8] [20/60] eta: 0:00:28 lr: 0.000050 loss: 0.1537 (0.1833) loss_classifier: 0.0469 (0.0556) loss_box_reg: 0.0927 (0.1127) loss_objectness: 0.0018 (0.0034) loss_rpn_box_reg: 0.0098 (0.0116) time: 0.7003 data: 0.0058 max mem: 7462
Epoch: [8] [40/60] eta: 0:00:14 lr: 0.000050 loss: 0.1291 (0.1668) loss_classifier: 0.0407 (0.0491) loss_box_reg: 0.0915 (0.1050) loss_objectness: 0.0029 (0.0032) loss_rpn_box_reg: 0.0055 (0.0094) time: 0.6929 data: 0.0056 max mem: 7462
Epoch: [8] [59/60] eta: 0:00:00 lr: 0.000050 loss: 0.1634 (0.1668) loss_classifier: 0.0514 (0.0498) loss_box_reg: 0.1034 (0.1045) loss_objectness: 0.0024 (0.0032) loss_rpn_box_reg: 0.0083 (0.0094) time: 0.6999 data: 0.0056 max mem: 7462
Epoch: [8] Total time: 0:00:42 (0.7050 s / it)
creating index...
index created!
Test: [ 0/50] eta: 0:00:24 model_time: 0.3030 (0.3030) evaluator_time: 0.0037 (0.0037) time: 0.4906 data: 0.1825 max mem: 7462
Test: [49/50] eta: 0:00:00 model_time: 0.2565 (0.2625) evaluator_time: 0.0012 (0.0017) time: 0.2623 data: 0.0030 max mem: 7462
Test: Total time: 0:00:13 (0.2753 s / it)
Averaged stats: model_time: 0.2565 (0.2625) evaluator_time: 0.0012 (0.0017)
Accumulating evaluation results...
DONE (t=0.02s).
IoU metric: bbox
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.534
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.951
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.526
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.293
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.550
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.289
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.618
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.619
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.450
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.632
Epoch: [9] [ 0/60] eta: 0:00:57 lr: 0.000005 loss: 0.2058 (0.2058) loss_classifier: 0.0578 (0.0578) loss_box_reg: 0.1311 (0.1311) loss_objectness: 0.0029 (0.0029) loss_rpn_box_reg: 0.0140 (0.0140) time: 0.9631 data: 0.2252 max mem: 7462
Epoch: [9] [20/60] eta: 0:00:28 lr: 0.000005 loss: 0.1616 (0.1713) loss_classifier: 0.0448 (0.0518) loss_box_reg: 0.0901 (0.1049) loss_objectness: 0.0042 (0.0042) loss_rpn_box_reg: 0.0101 (0.0105) time: 0.7072 data: 0.0055 max mem: 7462
Epoch: [9] [40/60] eta: 0:00:14 lr: 0.000005 loss: 0.1449 (0.1669) loss_classifier: 0.0415 (0.0500) loss_box_reg: 0.1006 (0.1033) loss_objectness: 0.0021 (0.0038) loss_rpn_box_reg: 0.0069 (0.0098) time: 0.7147 data: 0.0056 max mem: 7462
Epoch: [9] [59/60] eta: 0:00:00 lr: 0.000005 loss: 0.1712 (0.1657) loss_classifier: 0.0517 (0.0498) loss_box_reg: 0.0998 (0.1030) loss_objectness: 0.0019 (0.0034) loss_rpn_box_reg: 0.0084 (0.0094) time: 0.7149 data: 0.0065 max mem: 7462
Epoch: [9] Total time: 0:00:43 (0.7191 s / it)
creating index...
index created!
Test: [ 0/50] eta: 0:00:24 model_time: 0.3057 (0.3057) evaluator_time: 0.0038 (0.0038) time: 0.4838 data: 0.1729 max mem: 7462
Test: [49/50] eta: 0:00:00 model_time: 0.2585 (0.2639) evaluator_time: 0.0013 (0.0018) time: 0.2670 data: 0.0030 max mem: 7462
Test: Total time: 0:00:13 (0.2766 s / it)
Averaged stats: model_time: 0.2585 (0.2639) evaluator_time: 0.0013 (0.0018)
Accumulating evaluation results...
DONE (t=0.02s).
IoU metric: bbox
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.535
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.950
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.550
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.346
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.550
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.297
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.624
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.634
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.475
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.646
Now to convert and save the model. Make sure to put the model on CPU before conversion to avoid any errors. After conversion you should see quantized modules like QuantizedConvReLU2d.
FasterRCNN(
(transform): GeneralizedRCNNTransform(
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
Resize(min_size=(800,), max_size=1333, mode='bilinear')
)
(backbone): BackboneWithFPNQuantizeable(
(quant): Quantize(scale=tensor([0.0375]), zero_point=tensor([57]), dtype=torch.quint8)
(dequant): DeQuantize()
(body): IntermediateLayerGetter(
(conv1): QuantizedConvReLU2d(3, 64, kernel_size=(7, 7), stride=(2, 2), scale=0.2553767263889313, zero_point=0, padding=(3, 3))
(bn1): Identity()
(relu): Identity()
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BottleneckQuantizeable(
(conv1): QuantizedConvReLU2d(64, 64, kernel_size=(1, 1), stride=(1, 1), scale=0.12423195689916611, zero_point=0)
(bn1): Identity()
(relu1): Identity()
...
from torch.ao.quantization import convert
quant_rcnn_prepared.eval()
quant_rcnn_prepared.to(torch.device('cpu'))

quant_rcnn_converted = convert(quant_rcnn_prepared, inplace=False)

quant_model_path = "/content/quant_model.pth"
torch.save(quant_rcnn_converted.state_dict(), quant_model_path)
For comparison, I’ll generate the same network without any of the modifications made for quantization (including fusion). Then we can compare model sizes and latency. Note that this only compares latency on the CPU; if the float model were on a GPU it could be significantly faster depending upon the hardware.
%%capture
from torchvision.models.detection.backbone_utils import BackboneWithFPN
from torchvision.models.resnet import Bottleneck
resnet_bb = resnet_101()
rcnn = FasterRCNN(
    BackboneWithFPN(
        backbone=resnet_bb,
        return_layers=return_layers,
        in_channels_list=in_channels_list,
        out_channels=256,
        extra_blocks=None,
        norm_layer=None,
    ),
    num_classes=2
)

rcnn.eval()
rcnn.to(torch.device('cpu'))

model_path = "/content/float_model.pth"
torch.save(rcnn.state_dict(), model_path)
print(f'size of quantized model: {round(os.path.getsize("/content/quant_model.pth") / 1e6)} MB')
print(f'size of float model: {round(os.path.getsize("/content/float_model.pth") / 1e6)} MB')
size of quantized model: 105 MB
size of float model: 242 MB
from time import perf_counter
quant_rcnn_converted.to(torch.device('cpu'))

# just grab one test image/batch
images, targets = next(iter(data_loader_test))
images = list(img.to(torch.device('cpu')) for img in images)
n = 10

start = perf_counter()
for _ in range(n):
    __ = quant_rcnn_converted(images)
print(f"quant model avg time: {(perf_counter() - start) / n:.2f}")

start = perf_counter()
for _ in range(n):
    __ = rcnn(images)
print(f"float model avg time (cpu): {(perf_counter() - start) / n:.2f}")
quant model avg time: 1.42
float model avg time (cpu): 2.20
I believe a fully quantized model would be even smaller and faster by comparison. In this case, while we did quantize the backbone for the R-CNN, it only accounted for roughly 75% of model parameters. So a significant number of float operations still occur after the quantized backbone.
num_model_params = sum(p.numel() for p in rcnn.parameters() if p.requires_grad)
num_backbone_params = sum(p.numel() for p in rcnn.backbone.parameters() if p.requires_grad)
print(f"total number of parameters in model: {num_model_params}")
print(f"total number of parameters in backbone: {num_backbone_params}")
print(f"ratio of quantized parameters: {num_backbone_params / num_model_params:.2f}")
total number of parameters in model: 60344409
total number of parameters in backbone: 45844544
ratio of quantized parameters: 0.76
We can also profile each model to see where each spends the most time during a forward pass.
from torch.profiler import profile, record_function, ProfilerActivity
with profile(activities=[ProfilerActivity.CPU], record_shapes=False) as prof:
    with record_function("model_inference"):
        quant_rcnn_converted(images)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

with profile(activities=[ProfilerActivity.CPU], record_shapes=False) as prof:
    with record_function("model_inference"):
        rcnn(images)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
model_inference 2.15% 29.904ms 100.00% 1.388s 1.388s 1
quantized::conv2d_relu 26.61% 369.449ms 26.96% 374.209ms 5.585ms 67
quantized::conv2d 23.33% 323.889ms 23.44% 325.335ms 7.230ms 45
torchvision::roi_align 17.54% 243.538ms 20.78% 288.475ms 72.119ms 4
aten::conv2d 0.00% 64.000us 13.45% 186.652ms 12.443ms 15
aten::convolution 0.03% 400.000us 13.44% 186.588ms 12.439ms 15
aten::_convolution 0.02% 280.000us 13.41% 186.188ms 12.413ms 15
aten::mkldnn_convolution 13.25% 183.875ms 13.29% 184.541ms 12.303ms 15
aten::linear 0.01% 115.000us 5.54% 76.905ms 19.226ms 4
aten::addmm 5.50% 76.300ms 5.52% 76.641ms 19.160ms 4
--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 1.388s
--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg # of Calls
--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
model_inference 3.80% 79.209ms 100.00% 2.086s 2.086s 1
aten::conv2d 0.04% 759.000us 66.00% 1.377s 10.841ms 127
aten::convolution 0.14% 3.015ms 65.96% 1.376s 10.835ms 127
aten::_convolution 0.11% 2.234ms 65.82% 1.373s 10.811ms 127
aten::mkldnn_convolution 65.44% 1.365s 65.71% 1.371s 10.793ms 127
torchvision::roi_align 14.10% 294.166ms 14.37% 299.835ms 59.967ms 5
aten::batch_norm 0.02% 462.000us 3.95% 82.490ms 793.173us 104
aten::_batch_norm_impl_index 0.05% 948.000us 3.93% 82.028ms 788.731us 104
aten::native_batch_norm 3.73% 77.820ms 3.87% 80.832ms 777.231us 104
aten::linear 0.00% 41.000us 3.54% 73.933ms 18.483ms 4
--------------------------------- ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 2.086s
The following loads the saved quantized model. It’s important that the same process of fusing, preparing, and converting be done before loading the weights, since quantization significantly alters the network. For the sake of completeness, we can then look at a prediction from the partially quantized R-CNN.
%%capture
quant_model_loaded = FasterRCNN(
    BackboneWithFPNQuantizeable(
        backbone=resnet_101(),
        return_layers=return_layers,
        in_channels_list=in_channels_list,
        out_channels=256,
        extra_blocks=None,
        norm_layer=None
    ),
    num_classes=2
)

quant_model_loaded.eval()
fuse_modules(quant_model_loaded.backbone.body, [['conv1', 'bn1', 'relu']], inplace=True)
for k1, m1 in quant_model_loaded.backbone.body.named_children():
    if "layer" in k1:  # in sequential layer with blocks
        for k2, m2 in m1.named_children():
            fuse_modules(m2, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], inplace=True)
            for k3, m3 in m2.named_children():
                if "downsample" in k3:  # fuse downsample
                    fuse_modules(m3, [["0", "1"]], inplace=True)

quant_model_loaded.train()
quant_model_loaded.backbone.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare_qat(quant_model_loaded, inplace=True)
torch.quantization.convert(quant_model_loaded, inplace=True)

quant_model_loaded.eval()
quant_model_loaded.load_state_dict(torch.load(quant_model_path, map_location=torch.device('cpu')))
import matplotlib.pyplot as plt
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
image = read_image("PennFudanPed/PNGImages/FudanPed00022.png")  # 7, 22
eval_transform = get_transform(train=False)

with torch.no_grad():
    x = eval_transform(image)
    # convert RGBA -> RGB and move to device
    x = x[:3, ...].to(torch.device('cpu'))
    predictions = quant_model_loaded([x, ])
    pred = predictions[0]

threshold = 0.50
image = (255.0 * (image - image.min()) / (image.max() - image.min())).to(torch.uint8)
image = image[:3, ...]
pred_labels = [f"pedestrian: {score:.3f}" for label, score in zip(pred["labels"], pred["scores"]) if score > threshold]
pred_boxes = pred["boxes"].long()[pred["scores"] > threshold]

output_image = draw_bounding_boxes(image, pred_boxes, pred_labels, colors="red")

# masks = (pred["masks"] > 0.7).squeeze(1)
# output_image = draw_segmentation_masks(output_image, masks, alpha=0.5, colors="blue")

plt.figure(figsize=(12, 12))
plt.imshow(output_image.permute(1, 2, 0))
Although this post looks polished and the process linear, early attempts were anything but. I ran into all sorts of issues during the model preparation phase, which appears to be a major downside of eager mode quantization. In another post, I plan to use FX graph mode quantization and compare the two methods.
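As a rough preview, the FX workflow traces the model and handles fusion and quant/dequant placement automatically. The sketch below on a toy module reflects my current understanding of that API and is not code from this project.

import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx

toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
example_inputs = (torch.rand(1, 3, 32, 32),)

qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")
prepared = prepare_qat_fx(toy.train(), qconfig_mapping, example_inputs)  # traces, fuses, inserts fake-quant
# ... training loop goes here ...
quantized = convert_fx(prepared.eval())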