Segmentation

Installation

git clone https://github.com/NathanUA/U-2-Net.git

cd U-2-Net/saved_models
mkdir u2net
mkdir u2netp

Put the pretrained model weights into these folders: u2net.pth in u2net and u2netp.pth in u2netp.

Dependencies:

pip3 install scikit-image
pip3 install torch
pip3 install torchvision

Usage
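
The snippet that follows expects img to be a NumPy array already shaped as a batch of channel-first images (N x 3 x H x W). A minimal sketch of one way to prepare such an array from a single RGB image, assuming ImageNet mean/std normalization; the function name to_model_input is hypothetical, and resizing is left out:

```python
import numpy as np

def to_model_input(image):
    """Convert an H x W x 3 RGB array into a 1 x 3 x H x W float32 batch.

    Assumption: the model expects values scaled to [0, 1] and then
    normalized with the ImageNet per-channel mean and std.
    """
    # Scale by the image maximum; guard against an all-black image
    x = image.astype(np.float32) / max(float(image.max()), 1e-6)
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    x = (x - mean) / std
    # Channels last (H, W, C) -> channels first (C, H, W)
    x = x.transpose(2, 0, 1)
    # Add the batch dimension and make the memory layout contiguous
    return np.ascontiguousarray(x[np.newaxis, ...])

# Tiny synthetic image standing in for a real photo
demo_image = np.zeros((4, 5, 3), dtype=np.uint8)
demo_image[0, 0, :] = 255
img = to_model_input(demo_image)
```

The resulting img can be passed to torch.from_numpy as in the code below.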

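import torch
from model import U2NET  # model definition from the U-2-Net repository

# 3 input channels (RGB), 1 output channel (mask)
net = U2NET(3, 1)
# load_state_dict expects a state dict, not a path, so load the checkpoint first
net.load_state_dict(torch.load("saved_models/u2net/u2net.pth"))
net.cuda()
net.eval()

# img: NumPy array of shape (N, 3, H, W)
inputs = torch.from_numpy(img)
inputs = inputs.type(torch.FloatTensor)
inputs = inputs.cuda()

# The network returns seven maps; only the first is used as the prediction
with torch.no_grad():
    d1, _, _, _, _, _, _ = net(inputs)

predict = d1[:, 0, :, :]
predict = normalize(predict)  # helper not defined here; presumably min-max scaling
predict = predict.squeeze()
predict = predict.cpu().numpy()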
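
The normalize call above is not defined in the snippet. A minimal sketch, assuming it performs min-max scaling of the prediction map to the range [0, 1]; the eps parameter is an addition here to avoid dividing by zero on a constant map. Because it only relies on .min(), .max() and arithmetic, the same function should work on both NumPy arrays and PyTorch tensors:

```python
import numpy as np

def normalize(pred, eps=1e-8):
    """Min-max scale pred to roughly [0, 1].

    Assumed behavior; the original helper is not shown in the source.
    """
    lo = pred.min()
    hi = pred.max()
    # eps keeps the division defined when all values are equal
    return (pred - lo) / (hi - lo + eps)

# NumPy array standing in for the network output
demo = np.array([[0.2, 0.4], [0.6, 1.0]], dtype=np.float32)
scaled = normalize(demo)

# A constant map stays finite instead of producing NaN
flat = normalize(np.zeros((2, 2)))
```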

Source