Load model in pytorch c++ · Issue #406 · qubvel-org/segmentation_models.pytorch · GitHub

Load model in pytorch c++ #406
@fselka

Description


Hi! I'm also trying to load the models in C++, but I'm getting the following error:

terminate called after throwing an instance of 'torch::jit::ErrorReport' what():
aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> (Tensor):
Expected at most 12 arguments but found 13 positional arguments.

Serialized File "code/__torch__/torch/nn/modules/conv.py", line 8
def forward(self: __torch__.torch.nn.modules.conv.Conv2d,
input: Tensor) -> Tensor:
input0 = torch.convolution(input, self.weight, None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1, False, False, True, True)
~~~~~~~~~~~~~~~~~~ <--- HERE
return input0
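The message comes down to a parameter-count mismatch: the schema the C++ loader knows accepts fewer arguments than the serialized call passes. The stdlib-only sketch below (the helper `split_top_level` is hypothetical, written for illustration) counts the top-level entries in both strings copied from the error above. The extra trailing `True` is consistent with the `allow_tf32` flag that PyTorch 1.7 added to `aten::_convolution`, which would suggest the libtorch build used for loading is older than the PyTorch that traced the model. That is an inference worth verifying, not something confirmed in this issue.

```python
def split_top_level(args: str) -> list[str]:
    """Split a comma-separated list, ignoring commas nested inside [] or ()."""
    parts, current, depth = [], [], 0
    for ch in args:
        if ch in "[(":
            depth += 1
        elif ch in "])":
            depth -= 1
        if ch == "," and depth == 0:
            parts.append("".join(current).strip())
            current = []
        else:
            current.append(ch)
    tail = "".join(current).strip()
    if tail:
        parts.append(tail)
    return parts


# Parameter list of aten::_convolution as printed by the C++ loader.
schema = ("Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, "
          "int[] dilation, bool transposed, int[] output_padding, int groups, "
          "bool benchmark, bool deterministic, bool cudnn_enabled")

# Arguments in the serialized TorchScript call (the line marked HERE).
call = ("input, self.weight, None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1, "
        "False, False, True, True")

expected = len(split_top_level(schema))
found = len(split_top_level(call))
print(f"schema accepts {expected} arguments, call passes {found}")
```

This prints `schema accepts 12 arguments, call passes 13`, matching the "Expected at most 12 arguments but found 13" wording of the error.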

I used the model trained in the cars example (Jupyter notebook):

import segmentation_models_pytorch as smp

ENCODER = 'se_resnext50_32x4d'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['car']
ACTIVATION = 'sigmoid'  # could be None for logits or 'softmax2d' for multiclass segmentation
DEVICE = 'cuda'

# FPN segmentation model with an ImageNet-pretrained SE-ResNeXt50 encoder
model = smp.FPN(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
)
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

Then I traced it with torch.jit.trace and saved it as TorchScript:

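import torch

input_img = torch.rand(1, 3, 256, 256).to(DEVICE)
best_model = torch.load('./best_model.pth')  # full model object saved during training
best_model.eval()
# Trace with an example input and save the TorchScript archive for C++
traced_script_module = torch.jit.trace(best_model, input_img)
traced_script_module.save('./best_model.pt')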
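Before debugging the C++ side, it can help to confirm that the traced file loads and matches the eager model in Python under the same PyTorch version. A minimal sketch, assuming a CPU-only setup and using a single `torch.nn.Conv2d` (7x7, stride 2, padding 3, like the call in the error) as a stand-in for the FPN model; the file name `stand_in_model.pt` is hypothetical:

```python
import os
import tempfile

import torch

# Stand-in for the FPN model: one 7x7 conv with stride 2 and padding 3,
# matching the stride/padding of the serialized convolution in the error.
model = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
model.eval()

example = torch.rand(1, 3, 256, 256)

# Trace without building an autograd graph.
with torch.no_grad():
    traced = torch.jit.trace(model, example)

path = os.path.join(tempfile.mkdtemp(), "stand_in_model.pt")
traced.save(path)

# Reload with Python's TorchScript loader, the counterpart of torch::jit::load.
reloaded = torch.jit.load(path)

with torch.no_grad():
    max_diff = (model(example) - reloaded(example)).abs().max().item()

# Record the exporting version so it can be compared with the libtorch build.
print("exported with PyTorch", torch.__version__)
print("max abs difference:", max_diff)
```

If this round trip works in Python but `torch::jit::load` still fails, comparing the printed `torch.__version__` with the libtorch version linked into the C++ program is a reasonable next step.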

When I load the model in C++ with torch::jit::load(model_path.string());, I get the error above.
I'm using PyTorch 1.7.1.
