基于 ResNet18 架构、使用 deformable convolution 的车道线检测

下面是一个基于关键点的车道线检测网络的 PyTorch 代码示例,其中使用了 deformable convolution。该代码示例基于 ResNet18 架构,可以根据实际情况进行修改。

首先,需要导入必要的库和模块:

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair
from torch.nn.parameter import Parameter
from torchvision.models.resnet import resnet18

然后,定义一个基于 ResNet18 架构的车道线检测网络模型:

class LaneDetectionNet(nn.Module):
    """ResNet18 backbone with deformable 3x3 convolutions and a linear head.

    In the first residual block of every ResNet stage both 3x3 convolutions
    are replaced by ``DeformConv2d`` layers. The ImageNet classifier of the
    backbone is removed so that it returns the 512-dimensional pooled
    feature, which two linear layers map to ``num_classes`` outputs.

    Args:
        num_classes: Number of values predicted per input image.
        deformable_groups: Number of offset groups used by every deformable
            convolution.
    """

    def __init__(self, num_classes=1, deformable_groups=2):
        super(LaneDetectionNet, self).__init__()

        # load ResNet18
        self.resnet = resnet18(pretrained=True)

        # replace the first conv layer
        # NOTE(review): the replacement has the same configuration as the
        # stock stem, so this only discards its pretrained weights. Kept for
        # users who change the number of input channels here -- confirm.
        self.resnet.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # Swap in deformable convolutions. Channel counts, stride and padding
        # are copied from the layer being replaced: the first block of stages
        # 2-4 changes the channel count and downsamples by 2, so hard-coding
        # "C -> C, stride 1" breaks both the conv input and the residual add.
        stages = (self.resnet.layer1, self.resnet.layer2, self.resnet.layer3, self.resnet.layer4)
        for stage in stages:
            block = stage[0]
            for attr in ("conv1", "conv2"):
                plain = getattr(block, attr)
                setattr(block, attr, DeformConv2d(
                    plain.in_channels,
                    plain.out_channels,
                    kernel_size=plain.kernel_size,
                    padding=plain.padding,
                    stride=plain.stride,
                    bias=False,
                    deformable_groups=deformable_groups,
                ))

        # Drop the 1000-way ImageNet classifier so the backbone outputs the
        # 512-d pooled feature that ``fc1`` expects.
        self.resnet.fc = nn.Identity()

        # add the output layers
        self.fc1 = nn.Linear(512, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Map an image batch ``(B, 3, H, W)`` to predictions ``(B, num_classes)``."""
        x = self.resnet(x)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

其中,DeformConv2d 是一个 deformable convolution 的实现类。其代码如下:

class DeformConv2d(nn.Module):
    """2-D deformable convolution layer with a learned offset predictor.

    A regular convolution (``offset_conv``) predicts
    ``2 * kernel_h * kernel_w * deformable_groups`` offset channels from the
    input; the input, the offsets and this layer's weight are then handed to
    ``deform_conv2d``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size as an int or an ``(h, w)`` pair.
        stride: Convolution stride (int or pair).
        padding: Zero padding on each side (int or pair).
        dilation: Kernel dilation (int or pair).
        groups: Number of convolution groups; must divide both channel
            counts.
        bias: Whether to learn an additive output bias.
        deformable_groups: Number of channel groups that share one offset
            field.

    Raises:
        ValueError: If ``groups`` does not divide ``in_channels`` and
            ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, deformable_groups=1):
        super(DeformConv2d, self).__init__()

        # Accept an int as well as a pair, like stride/padding/dilation do.
        kernel_size = _pair(kernel_size)
        if in_channels % groups != 0 or out_channels % groups != 0:
            raise ValueError("in_channels and out_channels must be divisible by groups")

        self.kernel_size = kernel_size
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.deformable_groups = deformable_groups

        self.offset_conv = nn.Conv2d(in_channels, 2 * kernel_size[0] * kernel_size[1] * deformable_groups, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=True)

        # Each output channel only sees in_channels // groups input channels.
        self.weight = Parameter(torch.empty(out_channels, in_channels // groups, kernel_size[0], kernel_size[1]))
        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the kernel like ``nn.Conv2d`` and the offsets to zero."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)
        # Zero offsets make the layer start out as a plain convolution, which
        # is the usual initialisation for deformable convolutions.
        nn.init.zeros_(self.offset_conv.weight)
        nn.init.zeros_(self.offset_conv.bias)

    def forward(self, x):
        """Predict offsets from ``x`` and apply the deformable convolution."""
        offset = self.offset_conv(x)
        output = deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups, self.deformable_groups)
        return output

接着,定义 deformable convolution 的实现函数 deform_conv2d,代码如下:

def deform_conv2d(input, offset, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, deformable_groups=1):
    """Apply a 2-D deformable convolution using bilinear sampling.

    Every kernel tap reads the input at its regular convolution position
    plus a per-output-pixel displacement taken from ``offset``. Positions
    outside the input contribute zero, so zero padding is implicit and all
    offsets equal to zero reproduce ``F.conv2d``.

    Args:
        input: Tensor of shape ``(B, C_in, H, W)``.
        offset: Tensor of shape
            ``(B, 2 * deformable_groups * kernel_h * kernel_w, out_h, out_w)``.
            Channels are ordered by deformable group, then by kernel tap
            (row-major), then ``(dy, dx)``; displacements are in input pixels.
        weight: Tensor of shape ``(C_out, C_in // groups, kernel_h, kernel_w)``.
        bias: Optional tensor of shape ``(C_out,)``.
        stride: Convolution stride (int or pair).
        padding: Implicit zero padding on each side (int or pair).
        dilation: Kernel dilation (int or pair).
        groups: Number of convolution groups.
        deformable_groups: Number of channel groups sharing one offset field.

    Returns:
        Tensor of shape ``(B, C_out, out_h, out_w)``.

    Raises:
        ValueError: If channel counts, group counts or the offset shape are
            inconsistent, or if the output would be empty.
    """
    batch_size, in_channels, in_h, in_w = input.shape
    out_channels, weight_in_channels, kernel_h, kernel_w = weight.shape
    stride_h, stride_w = _pair(stride)
    pad_h, pad_w = _pair(padding)
    dilation_h, dilation_w = _pair(dilation)
    taps = kernel_h * kernel_w

    if in_channels % groups != 0 or out_channels % groups != 0:
        raise ValueError("in_channels and out_channels must be divisible by groups")
    if weight_in_channels != in_channels // groups:
        raise ValueError("weight.shape[1] must equal in_channels // groups")
    if in_channels % deformable_groups != 0:
        raise ValueError("in_channels must be divisible by deformable_groups")

    # calculate output shape
    out_h = (in_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) // stride_h + 1
    out_w = (in_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) // stride_w + 1
    if out_h <= 0 or out_w <= 0:
        raise ValueError("kernel does not fit into the padded input")

    expected_shape = (batch_size, 2 * deformable_groups * taps, out_h, out_w)
    if tuple(offset.shape) != expected_shape:
        raise ValueError(
            "offset must have shape %s, got %s" % (expected_shape, tuple(offset.shape))
        )

    device, dtype = input.device, input.dtype

    # Regular (undeformed) sampling position of every tap at every output
    # pixel, expressed in coordinates of the unpadded input.
    out_y = torch.arange(out_h, device=device, dtype=dtype) * stride_h - pad_h
    out_x = torch.arange(out_w, device=device, dtype=dtype) * stride_w - pad_w
    tap_y = torch.arange(kernel_h, device=device, dtype=dtype) * dilation_h
    tap_x = torch.arange(kernel_w, device=device, dtype=dtype) * dilation_w
    full = (kernel_h, kernel_w, out_h, out_w)
    base_y = (tap_y.view(-1, 1, 1, 1) + out_y.view(1, 1, -1, 1)).expand(full).reshape(taps, out_h, out_w)
    base_x = (tap_x.view(1, -1, 1, 1) + out_x.view(1, 1, 1, -1)).expand(full).reshape(taps, out_h, out_w)

    # Fold the deformable groups into the batch axis so that one grid_sample
    # call handles every group with its own offset field.
    offset = offset.to(dtype).reshape(batch_size * deformable_groups, taps, 2, out_h, out_w)
    pos_y = base_y + offset[:, :, 0]
    pos_x = base_x + offset[:, :, 1]

    # Pixel -> normalised coordinates for align_corners=False; this mapping
    # has no division by (size - 1), so 1-pixel-wide inputs are fine.
    grid = torch.stack(((2 * pos_x + 1) / in_w - 1, (2 * pos_y + 1) / in_h - 1), dim=-1)
    grid = grid.reshape(batch_size * deformable_groups, taps * out_h, out_w, 2)

    grouped_input = input.reshape(batch_size * deformable_groups, in_channels // deformable_groups, in_h, in_w)
    sampled = F.grid_sample(grouped_input, grid, mode='bilinear', padding_mode='zeros', align_corners=False)

    # (B * dg, C / dg, taps * out_h, out_w) -> im2col layout split by conv group.
    columns = sampled.reshape(batch_size, groups, (in_channels // groups) * taps, out_h * out_w)
    kernel = weight.reshape(groups, out_channels // groups, (in_channels // groups) * taps)
    output = torch.einsum('gok,bgkl->bgol', kernel, columns)
    output = output.reshape(batch_size, out_channels, out_h, out_w)

    if bias is not None:
        output = output + bias.view(1, -1, 1, 1)

    return output

deformable_conv2d_compute 函数负责按偏移量(offset)对特征图进行双线性插值采样,其代码如下:

def deformable_conv2d_compute(input, offset):
    """Bilinearly resample ``input`` on a regular grid displaced by ``offset``.

    The sampling grid spans ``[-1, 1]`` along each axis, with ``-1`` and
    ``1`` referring to the centres of the first and last pixel. Displaced
    sampling positions are clamped to ``[-1, 1]`` before interpolation.

    Args:
        input: Tensor of shape ``(B, C, H, W)``.
        offset: Tensor of shape ``(B, 2 * G, H, W)`` where ``G`` divides
            ``C``. The first ``G`` channels are vertical displacements and
            the last ``G`` horizontal ones, both in normalised grid units.
            Channel group ``g`` of ``input`` uses offset field ``g``.

    Returns:
        Tensor of shape ``(B, C, H, W)`` holding the resampled values.

    Raises:
        ValueError: If the shapes of ``input`` and ``offset`` are
            inconsistent.
    """
    if input.dim() != 4 or offset.dim() != 4:
        raise ValueError("input and offset must both be 4-D tensors")

    batch_size, channels, out_h, out_w = input.shape
    if offset.size(1) == 0 or offset.size(1) % 2 != 0:
        raise ValueError("offset must have a positive, even number of channels")
    offset_groups = offset.size(1) // 2
    if channels % offset_groups != 0:
        raise ValueError("input channels must be divisible by offset.size(1) // 2")
    if offset.size(0) != batch_size or tuple(offset.shape[2:]) != (out_h, out_w):
        raise ValueError("offset must match input in batch and spatial size")

    # sample input according to offset
    grid_h = torch.linspace(-1, 1, out_h, device=input.device, dtype=input.dtype).view(1, 1, out_h, 1)
    grid_w = torch.linspace(-1, 1, out_w, device=input.device, dtype=input.dtype).view(1, 1, 1, out_w)
    offset = offset.to(input.dtype)
    sample_h = (grid_h + offset[:, :offset_groups]).clamp(-1, 1)
    sample_w = (grid_w + offset[:, offset_groups:]).clamp(-1, 1)

    # grid_sample wants (x, y) in the last axis. Offset groups are folded
    # into the batch axis so that every group is sampled with its own field.
    grid = torch.stack((sample_w, sample_h), dim=-1)
    grid = grid.reshape(batch_size * offset_groups, out_h, out_w, 2)
    grouped = input.reshape(batch_size * offset_groups, channels // offset_groups, out_h, out_w)

    # align_corners=True matches the linspace(-1, 1, size) grid above; it
    # addresses the four neighbouring pixels jointly and stays exact when a
    # sample lands on an integer coordinate.
    sampled = F.grid_sample(grouped, grid, mode='bilinear', padding_mode='border', align_corners=True)

    return sampled.reshape(batch_size, channels, out_h, out_w)

最后,可以使用以下代码进行网络的测试:

# Smoke test: build the model, then push one random image through it.
net = LaneDetectionNet(
    deformable_groups=2,
    num_classes=1,
)
# A single 3-channel image of 100 x 100 pixels filled with Gaussian noise.
input = torch.randn((1, 3, 100, 100))
output = net(input)
# Report the shape of the prediction tensor.
print(output.shape)

输出的形状应为 torch.Size([1, 1]),即 (batch_size, num_classes)。这说明网络将一张 100×100 的输入图像映射成了一个标量输出。可以根据实际情况进行调整和优化,来达到更好的性能。