How to use the ptflops.get_model_complexity_info function in ptflops

To help you get started, we’ve selected a few ptflops examples, based on popular ways it is used in public projects.


GitHub: xieydd/Pytorch-Single-Path-One-Shot-NAS, utils/count_flops.py
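from ptflops import get_model_complexity_info
from torchvision.models import resnet50  # assumed source of resnet50


def test_resnet50_ptflops():
    net = resnet50()
    flops, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, print_per_layer_stat=True)
    print('Flops:  ' + flops)
    print('Params: ' + params)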
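According to its README, ptflops computes the theoretical number of multiply-add operations (MACs), so the value stored in `flops` above is a MAC count rather than FLOPs in the strict sense. As a sanity check on what such a count means, here is a minimal hand calculation for one 2D convolution using the usual formula out_h × out_w × out_channels × (in_channels / groups) × k × k. The helper names are hypothetical, and ptflops' own per-layer numbers may differ in details such as how bias terms are counted.

```python
def conv_output_size(size: int, kernel: int, stride: int = 1,
                     padding: int = 0, dilation: int = 1) -> int:
    """Spatial output size of a convolution along one dimension (PyTorch's formula)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv2d_macs(in_ch: int, out_ch: int, kernel: int, in_hw: int,
                stride: int = 1, padding: int = 0, groups: int = 1,
                bias: bool = False) -> int:
    """Multiply-accumulate count of a square Conv2d on a square input, batch size 1."""
    out_hw = conv_output_size(in_hw, kernel, stride, padding)
    positions = out_hw * out_hw
    # Each output element needs (in_ch / groups) * kernel * kernel multiply-adds.
    macs = positions * out_ch * (in_ch // groups) * kernel * kernel
    if bias:
        # One extra add per output element when the layer has a bias.
        macs += positions * out_ch
    return macs


# First convolution of ResNet-50: 3 -> 64 channels, 7x7 kernel, stride 2, padding 3, no bias.
stem = conv2d_macs(3, 64, 7, 224, stride=2, padding=3)
print(f'{stem:,} MACs')
```

The stem convolution alone works out to about 118 million multiply-adds, which gives a feel for the scale of the per-layer figures printed with `print_per_layer_stat=True`.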
GitHub: tristandb/EfficientDet-PyTorch, train.py
    elif parser.depth == 152:
        model = retinanet.resnet152(num_classes=dataset_train.num_classes(), pretrained=True)
    elif parser.efficientdet:
        model = efficientdet.efficientdet(num_classes=dataset_train.num_classes(), pretrained=True, phi=parser.scaling_compound)
    else:
        raise ValueError('Unsupported model depth, must be one of 18, 34, 50, 101, 152, or specify ')

    use_gpu = True

    if use_gpu:
        model = model.cuda()
    
    model = torch.nn.DataParallel(model).cuda()
    
    if parser.print_model_complexity:
        flops, params = get_model_complexity_info(model, (3, img_size, img_size), as_strings=True, print_per_layer_stat=True)
        print('{:<30}  {:<8}'.format('Computational complexity: ', flops))
        print('{:<30}  {:<8}'.format('Number of parameters: ', params))

    model.training = True

    optimizer = optim.SGD(model.parameters(), lr=4e-5)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)

    loss_hist = collections.deque(maxlen=500)
    
    model.train()
    model.module.freeze_bn()

    print('Num training images: {}'.format(len(dataset_train)))
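Here the resolution handed to `get_model_complexity_info` is `(3, img_size, img_size)`, so the reported complexity depends on `img_size`. For convolutional layers the cost grows with the output area, so doubling the side length roughly quadruples it. A minimal sketch of that scaling for a single 3×3 convolution; the channel count and resolutions are illustrative assumptions, not EfficientDet's actual layers.

```python
def conv3x3_macs(channels: int, side: int) -> int:
    """MACs of a stride-1, padding-1, bias-free 3x3 conv with equal in/out channels."""
    # With stride 1 and padding 1 the output stays side x side.
    return side * side * channels * channels * 3 * 3


base = conv3x3_macs(64, 512)
for side in (512, 640, 768, 1024):
    macs = conv3x3_macs(64, side)
    print(f'{side:>5}px: {macs / 1e9:6.2f} GMac  ({macs / base:.2f}x the 512px cost)')
```

Sweeping the resolution like this is a cheap way to sanity-check the figure printed when `print_model_complexity` is set.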
GitHub: opencv/openvino_training_extensions, pytorch_toolkit/object_detection/tools/count_flops.py
def main():
    args = parse_args()
    with torch.no_grad():
        model = init_detector(args.config)
        model.eval()
        input_res = model.cfg.data['test']['img_scale']
        flops, params = get_model_complexity_info(model, input_res,
                                                  as_strings=True,
                                                  print_per_layer_stat=True,
                                                  input_constructor=inp_fun)
        print('Computational complexity: ' + flops)
        print('Number of parameters: ' + params)
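The `inp_fun` passed as `input_constructor` is defined elsewhere in that script and is not shown. In the ptflops source, when `input_constructor` is given it appears to be called with `input_res`, and its result is passed to the model, unpacked as keyword arguments if it is a dict; otherwise a dummy batch of shape `(1, *input_res)` is built. The sketch below imitates that dispatch with a stand-in model so the convention is easy to see. `run_forward`, `FakeDetector`, `hypothetical_inp_fun` and the `shape` keyword are all hypothetical; the real detector's expected inputs are not shown in the source.

```python
from typing import Any, Callable, Optional, Tuple


def run_forward(model: Callable[..., Any], input_res: Tuple[int, ...],
                input_constructor: Optional[Callable[[Tuple[int, ...]], Any]] = None) -> Any:
    """Imitates how a FLOPs counter might build the dummy input and call the model."""
    if input_constructor is None:
        # Stand-in for the default behaviour: a single batch of shape (1, *input_res).
        return model((1, *input_res))
    batch = input_constructor(input_res)
    if isinstance(batch, dict):
        # A dict is expanded into keyword arguments.
        return model(**batch)
    return model(batch)


class FakeDetector:
    """Hypothetical model that only records how it was called."""

    def __call__(self, *args: Any, **kwargs: Any) -> dict:
        return {'args': args, 'kwargs': kwargs}


def hypothetical_inp_fun(input_res: Tuple[int, int]) -> dict:
    # Assumes an mmdetection-style img_scale given as (width, height).
    width, height = input_res
    return {'shape': (3, height, width)}


result = run_forward(FakeDetector(), (1333, 800), hypothetical_inp_fun)
print(result['kwargs'])
```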
GitHub: NoelShin/Deep-Learning-Bootcamp-with-PyTorch, classification/DenseNet/models.py
return self.layer(x)


class View(nn.Module):
    def __init__(self, *shape):
        super(View, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(x.shape[0], *self.shape)


if __name__ == '__main__':
    from ptflops import get_model_complexity_info
    densenet_bc = DenseNetBC(depth=100, growth_rate=12, n_classes=100, efficient=False)
    flops, params = get_model_complexity_info(densenet_bc, (3, 32, 32), as_strings=False, print_per_layer_stat=False)
    print("flops: {}, params: {}".format(flops, params))
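With `as_strings=False`, as in the DenseNet-BC example, the function returns plain numbers instead of formatted strings, which is convenient for logging or comparing models programmatically. A small sketch for turning raw values into readable figures; `format_count` and `summarize` are hypothetical helpers, and the FLOPs estimate relies on the common convention that one multiply-add counts as two floating-point operations, which is an assumption rather than something ptflops reports.

```python
def format_count(value: float, suffix: str = '') -> str:
    """Format a raw count with G/M/K prefixes, e.g. 118013952 -> '118.01 MMac'."""
    for prefix, scale in (('G', 1e9), ('M', 1e6), ('K', 1e3)):
        if value >= scale:
            return f'{value / scale:.2f} {prefix}{suffix}'
    return f'{value:g} {suffix}'.rstrip()


def summarize(macs: float, params: float) -> str:
    # Assumed convention: 1 multiply-add is roughly 2 FLOPs.
    return (f'MACs: {format_count(macs, "Mac")}, '
            f'~FLOPs: {format_count(2 * macs, "FLOP")}, '
            f'params: {format_count(params)}')


# Illustrative numbers only, not measured from DenseNet-BC.
print(summarize(118_013_952, 9_408))
```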
GitHub: sovrasov/flops-counter.pytorch, sample.py
    parser.add_argument('--model', choices=list(pt_models.keys()),
                        type=str, default='resnet18')
    parser.add_argument('--result', type=str, default=None)
    args = parser.parse_args()

    if args.result is None:
        ost = sys.stdout
    else:
        ost = open(args.result, 'w')

    net = pt_models[args.model]()

    if torch.cuda.is_available():
        net.cuda(device=args.device)

    flops, params = get_model_complexity_info(net, (3, 224, 224),
                                              as_strings=True,
                                              print_per_layer_stat=True,
                                              ost=ost)
    print('{:<30}  {:<8}'.format('Computational complexity: ', flops))
    print('{:<30}  {:<8}'.format('Number of parameters: ', params))
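In sample.py the `ost` argument redirects the per-layer report to a file when `--result` is given, but the file opened for it is not explicitly closed in the lines shown. A context-manager variant avoids leaving it open. `report_stream` and `write_summary` are hypothetical helpers, the summary is written to the same stream (sample.py prints it to stdout instead), and the ptflops call is left as a comment with placeholder strings so the sketch runs without PyTorch.

```python
import contextlib
import os
import sys
import tempfile
from typing import Optional, TextIO


def report_stream(path: Optional[str]):
    """Return a context manager yielding sys.stdout, or a file that is closed on exit."""
    if path is None:
        # nullcontext hands back stdout without closing it afterwards.
        return contextlib.nullcontext(sys.stdout)
    return open(path, 'w')


def write_summary(ost: TextIO, flops: str, params: str) -> None:
    # Same left-aligned layout as sample.py, written to `ost`.
    print('{:<30}  {:<8}'.format('Computational complexity: ', flops), file=ost)
    print('{:<30}  {:<8}'.format('Number of parameters: ', params), file=ost)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'complexity.txt')
    with report_stream(path) as ost:
        # With PyTorch and ptflops installed, the real call would go here, e.g.
        # flops, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True,
        #                                           print_per_layer_stat=True, ost=ost)
        write_summary(ost, '<flops>', '<params>')  # placeholder strings
    with open(path) as fh:
        saved = fh.read()

print(saved)
```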

ptflops: Flops counter for neural networks in the PyTorch framework (MIT license).