Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def main():
# Set up an active-learning experiment: parse options, seed RNGs, build the
# datasets and query heuristic, and construct a VGG16 whose dropout layers are
# patched for MC-Dropout, wrapped in ModelWrapper together with the loss.
# NOTE(review): this function appears to continue past the end of this chunk;
# heuristic, optimizer, active_set, test_set and the wrapped model are created
# here but not used in the visible lines. Presumably the code below uses them.
# NOTE(review): indentation in this chunk is flattened to column 0, so it will
# not parse as Python as-is. Restore the function-body indent.
args = parse_args()
use_cuda = torch.cuda.is_available()
# Let cuDNN autotune its convolution algorithms.
# NOTE(review): per PyTorch's reproducibility notes, benchmark mode can select
# nondeterministic kernels, which weakens the fixed seeds below. Confirm
# whether exact reruns matter for these experiments.
torch.backends.cudnn.benchmark = True
# Fixed seed 1337 for Python's `random` module and torch's RNG. NumPy's global
# RNG is not seeded in the visible code.
random.seed(1337)
torch.manual_seed(1337)
if not use_cuda:
print("warning, the experiments would take ages to run on cpu")
# Plain dict of the parsed options. The keys read in this chunk are
# 'initial_pool', 'heuristic', 'shuffle_prop' and 'lr'.
hyperparams = vars(args)
# get_datasets is defined outside this chunk. It presumably returns the
# active-learning pool and a held-out test set (TODO confirm).
active_set, test_set = get_datasets(hyperparams['initial_pool'])
heuristic = get_heuristic(hyperparams['heuristic'],
hyperparams['shuffle_prop'])
criterion = CrossEntropyLoss()
# Build a 10-class VGG16, then initialise it from the checkpoint at this URL.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=`. Check the pinned torchvision version.
model = vgg16(pretrained=False, num_classes=10)
weights = load_state_dict_from_url('https://download.pytorch.org/models/vgg16-397923af.pth')
# Drop every key containing 'classifier.6'. That is presumably the final
# classifier layer, whose shape cannot match num_classes=10. strict=False lets
# load_state_dict accept the resulting missing keys instead of raising.
weights = {k: v for k, v in weights.items() if 'classifier.6' not in k}
model.load_state_dict(weights, strict=False)
# change dropout layer to MCDropout
model = patch_module(model)
if use_cuda:
model.cuda()
# The optimizer is built after the device move. torch.optim's docs recommend
# moving the model to GPU before constructing its optimizer.
optimizer = optim.SGD(model.parameters(), lr=hyperparams["lr"], momentum=0.9)
# Wraps the model into a usable API.
model = ModelWrapper(model, criterion)