Secure your code as it's written: use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def get_parameters(model, bias=False):
    """Yield the Conv2d weights (or biases) of an FCN model.

    Meant for building separate optimizer parameter groups, e.g. one group
    for weights and one for biases with different options.

    Args:
        model: the network whose ``modules()`` are walked (recursively,
            including ``model`` itself).
        bias: if True yield each ``nn.Conv2d`` bias, otherwise each
            ``nn.Conv2d`` weight.

    Yields:
        Parameter tensors of every ``nn.Conv2d``. A convolution created with
        ``bias=False`` has no bias and is skipped when biases are requested.

    Raises:
        ValueError: on a module type this function does not know how to
            handle, or, when ``bias`` is True, on an ``nn.ConvTranspose2d``
            that carries a bias.
    """
    import torch.nn as nn

    # Layers with no parameters of their own (activation, pooling, dropout)
    # and containers whose children modules() visits separately.
    builtin_skipped = (nn.ReLU, nn.MaxPool2d, nn.Dropout2d, nn.Sequential)

    def is_fcn_model(m):
        # modules() also yields the top-level model itself; skip it.
        # torchfcn is only looked up when this check is actually reached.
        return isinstance(m, (
            torchfcn.models.FCN32s,
            torchfcn.models.FCN16s,
            torchfcn.models.FCN8s,
        ))

    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            param = m.bias if bias else m.weight
            # Conv2d(..., bias=False) leaves m.bias as None; optimizers only
            # accept tensors, so a None entry must not be yielded.
            if param is not None:
                yield param
        elif isinstance(m, nn.ConvTranspose2d):
            # weight is frozen because it is just a bilinear upsampling:
            # it is never yielded, so an optimizer built from this never
            # updates it. Such layers are expected to have no bias.
            # Explicit raise (not assert) so the check survives `python -O`.
            if bias and m.bias is not None:
                raise ValueError(
                    'Unexpected bias in upsampling layer: %s' % str(m))
        elif isinstance(m, builtin_skipped) or is_fcn_model(m):
            continue
        else:
            raise ValueError('Unexpected module: %s' % str(m))
# Build the VOC2011 "seg11valid" loader and load an FCN checkpoint into the
# architecture inferred from the checkpoint's file name.
# NOTE(review): `args`, `os`, `osp`, `torch` and `torchfcn` are defined
# outside this excerpt (presumably argparse output and module imports).
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
model_file = args.model_file
# Evaluated after CUDA_VISIBLE_DEVICES is set, as in the original flow.
use_cuda = torch.cuda.is_available()

root = osp.expanduser('~/data/datasets')
val_loader = torch.utils.data.DataLoader(
    torchfcn.datasets.VOC2011ClassSeg(
        root, split='seg11valid', transform=True),
    batch_size=1, shuffle=False,
    # Pinned host memory only helps host-to-GPU copies.
    num_workers=4, pin_memory=use_cuda)
# Size the classifier from the dataset rather than a hard-coded count.
n_class = len(val_loader.dataset.class_names)

# The file-name prefix selects the architecture. 'fcn8s-atonce' must be
# tested before the more general 'fcn8s' prefix it also starts with.
model_name = osp.basename(model_file)
if model_name.startswith('fcn32s'):
    model = torchfcn.models.FCN32s(n_class=n_class)
elif model_name.startswith('fcn16s'):
    model = torchfcn.models.FCN16s(n_class=n_class)
elif model_name.startswith('fcn8s-atonce'):
    model = torchfcn.models.FCN8sAtOnce(n_class=n_class)
elif model_name.startswith('fcn8s'):
    model = torchfcn.models.FCN8s(n_class=n_class)
else:
    raise ValueError(
        "Cannot infer model type from file name %r (expected a prefix of "
        "'fcn32s', 'fcn16s', 'fcn8s-atonce' or 'fcn8s')" % model_file)
if use_cuda:
    model = model.cuda()
print('==> Loading %s model file: %s' %
      (model.__class__.__name__, model_file))
# A checkpoint saved during a GPU run holds CUDA tensors; remap them to the
# CPU on a CPU-only host, otherwise torch.load raises. load_state_dict then
# copies the values onto the model's own device.
# NOTE(review): torch.load unpickles its input - only load trusted files.
model_data = torch.load(
    model_file, map_location=None if use_cuda else 'cpu')
try:
    # The file may be a bare state_dict ...
    model.load_state_dict(model_data)
except Exception:
    # ... or a training checkpoint that wraps it under 'model_state_dict'.
    # Kept broad on purpose: the mismatch error type differs across torch
    # versions (KeyError in old releases, RuntimeError in newer ones).
    model.load_state_dict(model_data['model_state_dict'])
# 1. dataset
root = osp.expanduser('~/data/datasets')
# Worker processes and pinned memory are only requested when running on GPU.
kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
train_loader = torch.utils.data.DataLoader(
    torchfcn.datasets.SBDClassSeg(root, split='train', transform=True),
    batch_size=1, shuffle=True, **kwargs)
val_loader = torch.utils.data.DataLoader(
    torchfcn.datasets.VOC2011ClassSeg(
        root, split='seg11valid', transform=True),
    batch_size=1, shuffle=False, **kwargs)
# 2. model
model = torchfcn.models.FCN16s(n_class=21)
start_epoch = 0
start_iteration = 0
# A checkpoint written during a GPU run stores CUDA tensors; remap them to
# the CPU when `cuda` is False so torch.load does not fail.
# NOTE(review): torch.load unpickles its input - only load trusted files.
map_location = None if cuda else 'cpu'
if args.resume:
    # Resume: restore the weights and the epoch/iteration counters.
    checkpoint = torch.load(args.resume, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    start_iteration = checkpoint['iteration']
else:
    # Fresh run: initialise from the FCN32s weights in args.pretrained_model.
    fcn32s = torchfcn.models.FCN32s()
    state_dict = torch.load(args.pretrained_model, map_location=map_location)
    try:
        # The pretrained file may be a bare state_dict ...
        fcn32s.load_state_dict(state_dict)
    except RuntimeError:
        # ... or a training checkpoint wrapping it under 'model_state_dict'.
        fcn32s.load_state_dict(state_dict['model_state_dict'])
    model.copy_params_from_fcn32s(fcn32s)
if cuda:
torchfcn.datasets.VOC2011ClassSeg(
root, split='seg11valid', transform=True),
batch_size=1, shuffle=False, **kwargs)
# 2. model
model = torchfcn.models.FCN8s(n_class=21)
start_epoch = 0
start_iteration = 0
# A checkpoint written during a GPU run stores CUDA tensors; remap them to
# the CPU when `cuda` is False so torch.load does not fail.
# NOTE(review): torch.load unpickles its input - only load trusted files.
map_location = None if cuda else 'cpu'
if args.resume:
    # Resume: restore the weights and the epoch/iteration counters.
    checkpoint = torch.load(args.resume, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    start_iteration = checkpoint['iteration']
else:
    # Fresh run: initialise from the FCN16s weights in args.pretrained_model.
    fcn16s = torchfcn.models.FCN16s()
    state_dict = torch.load(args.pretrained_model, map_location=map_location)
    try:
        # The pretrained file may be a bare state_dict ...
        fcn16s.load_state_dict(state_dict)
    except RuntimeError:
        # ... or a training checkpoint wrapping it under 'model_state_dict'.
        fcn16s.load_state_dict(state_dict['model_state_dict'])
    model.copy_params_from_fcn16s(fcn16s)
if cuda:
    model = model.cuda()
# 3. optimizer
optim = torch.optim.SGD(
[
{'params': get_parameters(model, bias=False)},
{'params': get_parameters(model, bias=True),
'lr': args.lr * 2, 'weight_decay': 0},