Skip to content

Commit

Permalink
replace bvh-distance-queries with PyTorch3D & Kaolin
Browse files Browse the repository at this point in the history
  • Loading branch information
YuliangXiu committed Mar 7, 2022
1 parent 928340e commit 2f926c9
Show file tree
Hide file tree
Showing 12 changed files with 127 additions and 165 deletions.
21 changes: 10 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,7 @@
<p align="center">

<h1 align="center">ICON: Implicit Clothed humans Obtained from Normals</h1>
<div align="center">
<a href="https://paperswithcode.com/sota/3d-human-reconstruction-on-cape?p=icon-implicit-clothed-humans-obtained-from"><img src="https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/icon-implicit-clothed-humans-obtained-from/3d-human-reconstruction-on-cape"></a><br><br>
</div>

<a href="">
<img src="./assets/teaser.jpeg" alt="Logo" width="100%">
</a>

<p align="center">
arXiv, December 2021.
<br />
<a href="https://ps.is.tuebingen.mpg.de/person/yxiu"><strong>Yuliang Xiu</strong></a>
·
<a href="https://ps.is.tuebingen.mpg.de/person/jyang"><strong>Jinlong Yang</strong></a>
Expand All @@ -22,10 +12,19 @@
·
<a href="https://ps.is.tuebingen.mpg.de/person/black"><strong>Michael J. Black</strong></a>
</p>
<h2 align="center">CVPR 2022</h2>
<div align="center">
</div>

<a href="">
<img src="./assets/teaser.jpeg" alt="Logo" width="100%">
</a>

<p align="center">
<br>
<a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-ee4c2c?logo=pytorch&logoColor=white"></a>
<a href="https://pytorchlightning.ai/"><img alt="Lightning" src="https://img.shields.io/badge/-Lightning-792ee5?logo=pytorchlightning&logoColor=white"></a><br><br>
<a href="https://pytorchlightning.ai/"><img alt="Lightning" src="https://img.shields.io/badge/-Lightning-792ee5?logo=pytorchlightning&logoColor=white"></a>
<a href="https://paperswithcode.com/sota/3d-human-reconstruction-on-cape?p=icon-implicit-clothed-humans-obtained-from"><img src="https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/icon-implicit-clothed-humans-obtained-from/3d-human-reconstruction-on-cape"></a><br></br>
<a href='https://arxiv.org/abs/2112.09127'>
<img src='https://img.shields.io/badge/Paper-PDF-green?style=flat&logo=arXiv&logoColor=green' alt='Paper PDF'>
</a>
Expand Down
3 changes: 2 additions & 1 deletion apps/ICON.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,8 @@ def test_single(self, batch):
proj_matrix=None)

verts_pr, faces_pr = self.reconEngine.export_mesh(sdf)
verts_pr, faces_pr = clean_mesh(verts_pr, faces_pr)
if self.clean_mesh_flag:
verts_pr, faces_pr = clean_mesh(verts_pr, faces_pr)

# convert from GT to SDF
verts_pr -= (self.resolutions[-1] - 1) / 2.0
Expand Down
2 changes: 1 addition & 1 deletion apps/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def tensor2variable(tensor, device):
cfg.merge_from_file('../lib/pymaf/configs/pymaf_config.yaml')

cfg_show_list = [
'test_gpus', [args.gpu_device], 'mcube_res', 512, 'clean_mesh', False
'test_gpus', [args.gpu_device], 'mcube_res', 256, 'clean_mesh', True
]

cfg.merge_from_list(cfg_show_list)
Expand Down
9 changes: 1 addition & 8 deletions colab.sh
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
# bvh-distance-queries only support cuda 11.0
# yet cuda 11.1 is the default version for colab
cd /etc/alternatives/
unlink cuda
ln -s /usr/local/cuda-11.0 cuda
cd /content

# conda installation
Expand All @@ -19,7 +14,5 @@ conda env create -f environment.yaml
conda init bash
source ~/.bashrc
source activate icon

# install packages for colab
pip install ipykernel ipywidgets --user --no-warn-script-location
pip install -r requirements.txt --use-deprecated=legacy-resolver

7 changes: 2 additions & 5 deletions docs/Installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,15 @@ conda config --env --set always_yes true
conda update -n base -c defaults conda -y

# Note:
# bvh-distance-queries only support cuda <= 11.0
# Thus, you need to setup suitable "cuda toolkit" firstly
# https://developer.nvidia.com/cuda-11.0-download-archive
# For google colab, please refer to ICON/colab.sh
ln -s {directory of cuda-11.0} /usr/local/cuda

# create conda env and install required libs (~20min)

cd ICON
conda env create -f environment.yaml
conda init bash
source ~/.bashrc
source activate icon
pip install -r requirements.txt --use-deprecated=legacy-resolver
```

For data generation and training
Expand Down
4 changes: 1 addition & 3 deletions environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,4 @@ dependencies:
- iopath
- nvidiacub
- pyembree
- pip
- pip:
- -r requirements.txt
- pip
1 change: 1 addition & 0 deletions fetch_hps.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ rm -rf data && rm -f data.tar.gz

# PyMAF pre-trained model
source activate icon
pip install gdown --upgrade
gdown https://drive.google.com/drive/u/1/folders/1CkF79XRaZzdRlj6eJUt4W0nbTORv2t7O -O pretrained_model --folder
cd ..
echo "PyMAF done!"
Expand Down
9 changes: 2 additions & 7 deletions lib/common/render_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,19 +152,14 @@ def face_vertices(vertices, faces):
:param faces: [batch size, number of faces, 3]
:return: [batch size, number of faces, 3, 3]
"""
assert (vertices.ndimension() == 3)
assert (faces.ndimension() == 3)
assert (vertices.shape[0] == faces.shape[0])
assert (vertices.shape[2] == 3)
assert (faces.shape[2] == 3)

bs, nv = vertices.shape[:2]
bs, nf = faces.shape[:2]
device = vertices.device
faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) *
nv)[:, None, None]
vertices = vertices.reshape((bs * nv, 3))
# pytorch only supports long and byte tensors for indexing
vertices = vertices.reshape((bs * nv, vertices.shape[-1]))

return vertices[faces.long()]


Expand Down
17 changes: 9 additions & 8 deletions lib/common/seg3d_lossless.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import numpy as np
import torch.nn.functional as F
import mcubes
from torchmcubes import marching_cubes
from kaolin.ops.conversions import voxelgrids_to_trianglemeshes
import logging

logging.getLogger("lightning").setLevel(logging.ERROR)
Expand Down Expand Up @@ -585,16 +585,17 @@ def export_mesh(self, occupancys):
final = occupancys[:-1, :-1, :-1].contiguous()

if final.shape[0] > 256:
# skimage marching cubes (0.2s for 256^3)
# occu_arr = final.detach().cpu().numpy() # non-smooth surface
occu_arr = mcubes.smooth(final.detach().cpu().numpy()) # smooth surface
# for voxelgrid larger than 256^3, the required GPU memory will be > 9GB
# thus we use CPU marching_cube to avoid "CUDA out of memory"
occu_arr = final.detach().cpu().numpy() # non-smooth surface
# occu_arr = mcubes.smooth(final.detach().cpu().numpy()) # smooth surface
vertices, triangles = mcubes.marching_cubes(occu_arr, self.balance_value)
verts = torch.as_tensor(vertices[:,[2,1,0]])
faces = torch.as_tensor(triangles.astype(np.long), dtype=torch.long)[:,[0,2,1]]
else:
# torchmcubes (0.01s for 256^3, but CUDA memory explosion for 512^3)
vertices, triangles = marching_cubes(final, self.balance_value)
verts = torch.as_tensor(vertices)
faces = torch.as_tensor(triangles, dtype=torch.long)[:,[0,2,1]]
torch.cuda.empty_cache()
vertices, triangles = voxelgrids_to_trianglemeshes(final.unsqueeze(0))
verts = vertices[0][:,[2,1,0]].cpu()
faces = triangles[0][:,[0,2,1]].cpu()

return verts, faces
152 changes: 68 additions & 84 deletions lib/dataset/mesh_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from termcolor import colored
import os.path as osp
from scipy.spatial import cKDTree
import bvh_distance_queries
from kaolin.ops.mesh import check_sign
from kaolin.metrics.trianglemesh import point_to_mesh_distance

from pytorch3d.loss import (
mesh_edge_loss,
Expand All @@ -37,7 +38,7 @@

sys.path.append(os.path.join(os.path.dirname(__file__), "../../"))
from pytorch3d.renderer.mesh import rasterize_meshes
from lib.common.render_utils import Pytorch3dRasterizer, face_vertices, batch_contains
from lib.common.render_utils import Pytorch3dRasterizer, face_vertices
from lib.pymaf.utils.imutils import uncrop
from pytorch3d.structures import Meshes

Expand Down Expand Up @@ -225,52 +226,76 @@ def get_visibility(xy, z, faces):
return vis_mask


def cal_sdf(verts, faces, points, edge=1.0):

n2t = lambda var: torch.from_numpy(var).float() if isinstance(
var, (np.ndarray, np.generic)) else var
[verts, faces, points] = [n2t(item) for item in [verts, faces, points]]

mesh_tree = cKDTree(verts)
pts_dist, pts_ind = mesh_tree.query(points, p=2)

pts_dist = pts_dist / torch.sqrt(torch.tensor(3 * (edge**2))) # p=2
# pts_dist = pts_dist / torch.tensor(3*edge) # p=1
mesh = trimesh.Trimesh(verts, faces, process=False)

pts_occ = mesh.contains(points)
pts_norm = 0.5 * (torch.as_tensor(mesh.vertex_normals[pts_ind] * np.array(
[-1.0, 1.0, -1.0])).float() + 1.0)
pts_sdf = (pts_dist * ((pts_occ - 0.5) / 0.5))[..., None].float()

return pts_sdf, pts_norm, pts_ind


def cal_sdf_batch(verts, faces, cmaps, points, edge=1.0):
def barycentric_coordinates_of_projection(points, vertices):
''' https://github.com/MPI-IS/mesh/blob/master/mesh/geometry/barycentric_coordinates_of_projection.py
'''
"""Given a point, gives projected coords of that point to a triangle
in barycentric coordinates.
See
**Heidrich**, Computing the Barycentric Coordinates of a Projected Point, JGT 05
at http://www.cs.ubc.ca/~heidrich/Papers/JGT.05.pdf
:param p: point to project. [B, 3]
:param v0: first vertex of triangles. [B, 3]
:returns: barycentric coordinates of ``p``'s projection in triangle defined by ``q``, ``u``, ``v``
vectorized so ``p``, ``q``, ``u``, ``v`` can all be ``3xN``
"""
#(p, q, u, v)
v0, v1, v2 = vertices[:,0], vertices[:,0], vertices[:,0]
p = points

q = v0
u = v1 - v0
v = v2 - v0
n = torch.cross(u, v)
s = torch.sum(n * n, dim=1)
# If the triangle edges are collinear, cross-product is zero,
# which makes "s" 0, which gives us divide by zero. So we
# make the arbitrary choice to set s to epsv (=numpy.spacing(1)),
# the closest thing to zero
s[s == 0] = 1e-6
oneOver4ASquared = 1.0 / s
w = p - q
b2 = torch.sum(torch.cross(u, w) * n, dim=1) * oneOver4ASquared
b1 = torch.sum(torch.cross(w, v) * n, dim=1) * oneOver4ASquared
weights = torch.stack((1 - b1 - b2, b1, b2), dim=-1)
# check barycenric weights
# p_n = v0*weights[:,0:1] + v1*weights[:,1:2] + v2*weights[:,2:3]
return weights


def cal_sdf_batch(verts, faces, cmaps, vis, points):

# verts [B, N_vert, 3]
# faces [B, N_face, 3]
# triangles [B, N_face, 3, 3]
# points [B, N_point, 3]
# cmaps [B, N_vert, 3]

func = bvh_distance_queries.PointToMeshResidual()
torch.cuda.synchronize()
norms = Meshes(verts, faces).verts_normals_padded()
normals = Meshes(verts, faces).verts_normals_padded()

triangles = face_vertices(verts, faces)
residues, normals, pts_cmap, pts_ind = func(
triangles.contiguous(), points.contiguous(),
face_vertices(norms, faces).contiguous(),
face_vertices(cmaps, faces).contiguous(), faces)
torch.cuda.synchronize()

pts_dist = torch.norm(residues, p=2, dim=2) / torch.sqrt(
torch.tensor(3 * (edge**2)))
pts_norm = normals * torch.tensor([-1.0, 1.0, -1.0]).type_as(normals)

# pts_sign = ((residues * normals).sum(dim=2) < 0) * 2.0 - 1.0
# if pts_signs.shape[1] != points.shape[1]:
pts_signs = (batch_contains(verts, faces, points)).type_as(verts)
# if pts_signs.shape[1] != points.shape[1]:
# pts_signs = ((winding_numbers(points, triangles).le(0.99)*2.0)-1.0).type_as(verts)
normals = face_vertices(normals, faces)
cmaps = face_vertices(cmaps, faces)
vis = face_vertices(vis, faces)

residues, pts_ind, _ = point_to_mesh_distance(points, triangles)
closest_triangles = torch.gather(triangles, 1, pts_ind[:,:,None,None].expand(-1,-1,3,3)).view(-1,3,3)
closest_normals = torch.gather(normals, 1, pts_ind[:,:,None,None].expand(-1,-1,3,3)).view(-1,3,3)
closest_cmaps = torch.gather(cmaps, 1, pts_ind[:,:,None,None].expand(-1,-1,3,3)).view(-1,3,3)
closest_vis = torch.gather(vis, 1, pts_ind[:,:,None,None].expand(-1,-1,3,1)).view(-1,3,1)

bary_weights = barycentric_coordinates_of_projection(points[0], closest_triangles)

pts_cmap = (closest_cmaps*bary_weights[:,:,None]).sum(1).unsqueeze(0)
pts_vis = (closest_vis*bary_weights[:,:,None]).sum(1).unsqueeze(0).ge(1e-1)
pts_norm = (closest_normals*bary_weights[:,:,None]).sum(1).unsqueeze(0) * torch.tensor([-1.0, 1.0, -1.0]).type_as(normals)
pts_dist = torch.sqrt(residues) / torch.sqrt(torch.tensor(3))

pts_signs = 2.0 * (check_sign(verts, faces[0], points).float() - 0.5)
pts_sdf = (pts_dist * pts_signs).unsqueeze(-1)

return pts_sdf, pts_norm, pts_cmap, pts_ind
return pts_sdf, pts_norm, pts_cmap, pts_vis


def orthogonal(points, calibrations, transforms=None):
Expand Down Expand Up @@ -702,47 +727,6 @@ def mesh_move(mesh_lst, step, scale=1.0):
return results


def scan2smpl(verts, dcs, smpl_cmap):
tree = cKDTree(dcs, leafsize=1)
dist, ind = tree.query(smpl_cmap, k=1)
return verts[ind, :], dcs[ind, :]


def fusion(ref_verts, part_verts, vis_lst, dc_lst):
"""fusion several partial verts with one reference verts
visibility of verts are known
ref_verts: [N, 3]
parts_verts: list([A,3], [B,3], ...)
vis_lst: list([N,1], [A,1], [B,1], ...)
dc_lst: list([N,1], [A,1], [B,1], ...)
Returns:
final_verts: [N, 3]
final_dcs: [N, 3]
"""

vis_lst = [vis.flatten().astype(dtype=np.bool) for vis in vis_lst]

deform_verts = ref_verts[np.invert(vis_lst[0]), :]
vis_part_verts = np.concatenate(
([verts[vis_lst[idx + 1], :] for idx, verts in enumerate(part_verts)]),
axis=0)
part_tree = cKDTree(vis_part_verts, leafsize=1)
_, ind = part_tree.query(deform_verts, k=10)
new_deform_verts = vis_part_verts[ind, :].mean(axis=1)

final_verts = ref_verts
final_verts[np.invert(vis_lst[0]), :] = new_deform_verts

final_dcs = dc_lst[0]
new_deform_dcs = np.concatenate(
[dc[vis_lst[idx + 1], :] for idx, dc in enumerate(dc_lst[1:])], axis=0)
final_dcs[np.invert(vis_lst[0]), :] = new_deform_dcs[ind, :].mean(axis=1)

return final_verts, final_dcs


class SMPLX():
def __init__(self):

Expand Down
20 changes: 10 additions & 10 deletions lib/net/HGPIFuNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,29 +279,29 @@ def query(self, features, points, calibs, transforms=None, regressor=None):

if self.prior_type == 'icon':

# smpl_verts [B, 10475, 3]
# smpl_faces [B, 20908, 3]
# smpl_verts [B, N_vert, 3]
# smpl_faces [B, N_face, 3]
# points [B, 3, N]
smpl_sdf, smpl_norm, smpl_cmap, smpl_ind = cal_sdf_batch(

smpl_sdf, smpl_norm, smpl_cmap, smpl_vis = cal_sdf_batch(
self.smpl_feat_dict['smpl_verts'],
self.smpl_feat_dict['smpl_faces'],
self.smpl_feat_dict['smpl_cmap'],
xyz.permute(0, 2, 1).contiguous(),
edge=1.0)

self.smpl_feat_dict['smpl_vis'],
xyz.permute(0, 2, 1).contiguous())
# smpl_sdf [B, N, 1]
# smpl_norm [B, N, 3]
# smpl_ind [B, N]
smpl_vis = torch.gather(self.smpl_feat_dict['smpl_vis'], 1,
smpl_ind.unsqueeze(2))
# smpl_cmap [B, N, 3]
# smpl_vis [B, N, 1]

# set ourlier point features as uniform values
smpl_outlier = torch.abs(smpl_sdf).ge(self.sdf_clip)
smpl_sdf[smpl_outlier] = torch.sign(smpl_sdf[smpl_outlier])

feat_lst = [smpl_sdf]
if 'cmap' in self.smpl_feats:
# smpl_cmap[smpl_outlier.repeat(1,1,3)] = smpl_sdf[smpl_outlier].repeat(1,1,3)
smpl_cmap[smpl_outlier.repeat(1,1,3)] = smpl_sdf[smpl_outlier].repeat(1,1,3)
feat_lst.append(smpl_cmap)
if 'norm' in self.smpl_feats:
feat_lst.append(smpl_norm)
Expand Down
Loading

0 comments on commit 2f926c9

Please sign in to comment.