-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Description
Backprop RGB image losses to mesh shape?
I am trying to fit a source mesh (a sphere) to a target mesh (a cow), following the tutorial at https://pytorch3d.org/tutorials/fit_textured_mesh. Specifically, I would like to backpropagate losses computed on the rendered RGB images of the current and target meshes to the vertex positions of the mesh being deformed.
I am able to achieve the desired results using a 50-50 weighting of L1 silhouette loss and L1 RGB loss taken over the rendered images, and the training progression across 200 iterations is shown here:
However, using only L1 RGB loss, the mesh doesn't converge to the desired shape as shown here:
I have tried an L2 RGB loss and changing the texture color, but the problem persists. Is it possible to use RGB image supervision alone, without silhouettes, and still propagate gradients to the mesh shape? I have looked at a similar, previously reported issue and confirmed that the rasterization-settings problem described there does not apply here.
This can be reproduced by running the following code after replacing `losses = {"rgb": {"weight": 0.5, "values": []}, "silhouette": {"weight": 0.5, "values": []}}`
with `losses = {"rgb": {"weight": 1.0, "values": []}, "silhouette": {"weight": 0.0, "values": []}}` (i.e., RGB loss only):
import os
import cv2
import sys
sys.path.append(os.path.abspath(''))
import torch
import os
import torch
import matplotlib.pyplot as plt
from pytorch3d.utils import ico_sphere
import numpy as np
from pytorch3d.io import load_objs_as_meshes, save_obj
from plot_image_grid import image_grid
from pytorch3d.loss import (
chamfer_distance,
mesh_edge_loss,
mesh_laplacian_smoothing,
mesh_normal_consistency,
)
# Data structures and functions for rendering
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
look_at_view_transform,
OpenGLPerspectiveCameras,
PointLights,
DirectionalLights,
Materials,
RasterizationSettings,
MeshRenderer,
MeshRasterizer,
SoftPhongShader,
SoftSilhouetteShader,
SoftPhongShader,
TexturesVertex
)
# Pick the first GPU when available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")

# Load the target cow mesh and replace its texture with a uniform white
# per-vertex texture.
DATA_DIR = "./data"
obj_filename = os.path.join(DATA_DIR, "cow_mesh/cow.obj")
mesh = load_objs_as_meshes([obj_filename], device=device)
white_tex = torch.ones_like(mesh.verts_packed())
white_tex = white_tex.unsqueeze(0)
white_tex = TexturesVertex(verts_features=white_tex.to(device))
mesh.textures = white_tex

# Center the mesh at the origin and scale it so that its largest absolute
# coordinate is 1. `center` and `scale` are kept so the fitted mesh can be
# mapped back to the original frame when it is saved.
verts = mesh.verts_packed()
N = verts.shape[0]
center = verts.mean(0)
scale = max((verts - center).abs().max(0)[0])
mesh.offset_verts_(-center)
mesh.scale_verts_((1.0 / float(scale)))

# Camera viewpoints spread over elevation and azimuth.
num_views = 20
#num_views = 2
elev = torch.linspace(0, 360, num_views)
azim = torch.linspace(-180, 180, num_views)
#azim = torch.linspace(-180, -90, num_views)
# A single point light at a fixed world-space position, shared by all views.
lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])
R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim)
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
camera = cameras[0]

# Hard Phong renderer (no blur, one face per pixel) used to render the target
# views shown in the image grid below.
raster_settings = RasterizationSettings(
    image_size=128,
    blur_radius=0.0,
    faces_per_pixel=1,
    perspective_correct=False
)
renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=camera,
        raster_settings=raster_settings
    ),
    shader=SoftPhongShader(
        device=device,
        cameras=camera,
        lights=lights
    )
)

# Soft silhouette renderer; blur_radius follows the SoftRas initialization
# with sigma = 1e-4 and 50 faces are blended per pixel.
sigma = 1e-4
raster_settings_silhouette = RasterizationSettings(
    image_size=128,
    blur_radius=np.log(1. / 1e-4 - 1.)*sigma,
    faces_per_pixel=50,
    perspective_correct=False
)
renderer_silhouette = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=camera,
        raster_settings=raster_settings_silhouette
    ),
    shader=SoftSilhouetteShader()
)

# One copy of the cow per camera, rendered in a single batch.
meshes = mesh.extend(num_views)
target_images = renderer(meshes, cameras=cameras, lights=lights)
target_rgb = [target_images[i, ..., :3] for i in range(num_views)]
# Per-view cameras so individual views can be rendered in the training loop.
target_cameras = [OpenGLPerspectiveCameras(device=device, R=R[None, i, ...],
                                           T=T[None, i, ...]) for i in range(num_views)]
image_grid(target_images.cpu().numpy(), rows=4, cols=5, rgb=True)
# Save BEFORE showing: once plt.show() returns the shown figure has typically
# been closed, so calling savefig afterwards writes a blank new figure.
plt.savefig('cows_rgb.png')
plt.show()
plt.clf()
# Show a visualization comparing the rendered predicted mesh to the ground truth
def visualize_pred(curr_image, ref_image, fname):
    """Write the predicted and reference renders side by side to an image file.

    Args:
        curr_image: render of the current (deformed) mesh. The caller passes
            ``renderer(...)[..., :3][0]``, i.e. an (H, W, 3) RGB tensor;
            values are assumed to lie in [0, 1] — TODO confirm for other shaders.
        ref_image: render of the reference mesh, same shape as ``curr_image``.
        fname: output path handed to ``cv2.imwrite``.
    """
    curr = curr_image.detach().cpu().numpy()
    ref = ref_image.detach().cpu().numpy()
    visualization = np.hstack((curr, ref))
    # cv2.imwrite expects 8-bit BGR. Scale the float RGB render to [0, 255],
    # round it, and reverse the channel order so red and blue are not swapped.
    visualization = np.rint(np.clip(visualization * 255.0, 0, 255)).astype(np.uint8)
    # The channel-reversing slice has a negative stride; OpenCV needs a
    # contiguous buffer.
    visualization = np.ascontiguousarray(visualization[..., ::-1])
    cv2.imwrite(str(fname), visualization)
# Plot losses as a function of optimization iteration
def plot_losses(losses):
    """Draw every recorded loss history as a curve on a new figure.

    Args:
        losses: mapping from loss name to a dict whose ``'values'`` entry is
            the list of per-iteration values for that loss.

    The figure is created and becomes pyplot's current figure; it is neither
    shown nor saved here.
    """
    label_size = "16"
    figure = plt.figure(figsize=(13, 5))
    axes = figure.gca()
    for name, record in losses.items():
        curve_label = name + " loss"
        axes.plot(record['values'], label=curve_label)
    axes.legend(fontsize=label_size)
    axes.set_xlabel("Iteration", fontsize=label_size)
    axes.set_ylabel("Loss", fontsize=label_size)
    axes.set_title("Loss vs iterations", fontsize=label_size)
# Local import: BlendParams is not in the file's top-level import list.
from pytorch3d.renderer import BlendParams

# We initialize the source shape to be a sphere of radius 1.
src_mesh = ico_sphere(4, device)
# Give the source sphere the same uniform white per-vertex texture as the target.
white_tex = torch.ones_like(src_mesh.verts_packed())
white_tex = white_tex.unsqueeze(0)
white_tex = TexturesVertex(verts_features=white_tex.to(device))
src_mesh.textures = white_tex

# Rasterization settings for differentiable rendering, where the blur_radius
# initialization is based on Liu et al, 'Soft Rasterizer: A Differentiable
# Renderer for Image-based 3D Reasoning', ICCV 2019
sigma = 1e-4
raster_settings_soft = RasterizationSettings(
    image_size=128,
    blur_radius=np.log(1. / 1e-4 - 1.)*sigma,
    faces_per_pixel=50,
    perspective_correct=False
)

# Differentiable soft Phong (RGB) renderer used for the RGB loss.
# Both meshes are textured white, and SoftPhongShader blends onto a white
# background by default. Brightly lit surface near the silhouette is then
# almost indistinguishable from the background, so an RGB-only loss carries
# very little signal about where the boundary should move. Rendering onto a
# black background restores that contrast. sigma/gamma keep the BlendParams
# default values; only the background changes.
blend_params = BlendParams(sigma=sigma, gamma=1e-4, background_color=(0.0, 0.0, 0.0))
renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=camera,
        raster_settings=raster_settings_soft
    ),
    shader=SoftPhongShader(
        device=device,
        cameras=camera,
        lights=lights,
        blend_params=blend_params
    )
)

num_views_per_iteration = 5 # use 5 which works best
Niter = 350
plot_period = 50
losses = {"rgb": {"weight": 0.5, "values": []},
          "silhouette": {"weight": 0.5, "values": []}}

# Per-vertex offsets applied to the sphere; these are the only optimized
# parameters.
verts_shape = src_mesh.verts_packed().shape
deform_verts = torch.full(verts_shape, 0.0, device=device, requires_grad=True)
# The optimizer
optimizer = torch.optim.Adam([deform_verts], lr=0.01)

# The reference renders do not depend on the optimized parameters, so render
# every view once up front instead of re-rendering them on each iteration.
with torch.no_grad():
    ref_images = [renderer(meshes[j], cameras=target_cameras[j], lights=lights)[..., :3]
                  for j in range(num_views)]
    ref_sils = [renderer_silhouette(meshes[j], cameras=target_cameras[j], lights=lights)[..., 3]
                for j in range(num_views)]

loop = range(Niter)
out_dir = os.path.join('out')
os.makedirs(out_dir, exist_ok=True)
for i in loop:
    # Initialize optimizer
    optimizer.zero_grad()
    # Deform the source mesh by the current offsets.
    new_src_mesh = src_mesh.offset_verts(deform_verts)
    # Image losses, averaged over a random subset of the views.
    # (The imported mesh regularizers are not used in this experiment.)
    loss = {k: torch.tensor(0.0, device=device) for k in losses}
    for j in np.random.permutation(num_views).tolist()[:num_views_per_iteration]:
        curr_image = renderer(new_src_mesh, cameras=target_cameras[j], lights=lights)[..., :3]
        curr_sil = renderer_silhouette(new_src_mesh, cameras=target_cameras[j], lights=lights)[..., 3]
        ref_image = ref_images[j]
        ref_sil = ref_sils[j]
        loss_rgb = (torch.abs(curr_image - ref_image)).mean() # L1 rgb loss
        loss_silhouette = (torch.abs(curr_sil - ref_sil)).mean() # L1 silhouette loss
        loss["silhouette"] += loss_silhouette / num_views_per_iteration
        loss["rgb"] += loss_rgb / num_views_per_iteration
    # Weighted sum of the losses
    sum_loss = torch.tensor(0.0, device=device)
    for name, value in loss.items():
        sum_loss += value * losses[name]["weight"]
        losses[name]["values"].append(float(value.detach().cpu()))
    # Print the losses
    print("iter: %d/%d, total_loss = %.6f" % (i,Niter,sum_loss), end='\r')
    sys.stdout.flush()
    # Periodically save a side-by-side RGB comparison for the last sampled view.
    if i % plot_period == 0:
        visualize_pred(curr_image[0], ref_image[0], fname="%s/%05d_depth.png"%(out_dir, i))
    # Optimization step
    sum_loss.backward()
    optimizer.step()

# new_src_mesh was built before the final optimizer.step(), so rebuild the mesh
# from the final offsets before saving it.
final_mesh = src_mesh.offset_verts(deform_verts.detach())
final_verts, final_faces = final_mesh.get_mesh_verts_faces(0)
# Undo the centering/normalization that was applied to the target cow mesh.
final_verts = final_verts * scale + center
final_obj = os.path.join('./', 'cow_sil_model.obj')
save_obj(final_obj, final_verts, final_faces)
plot_losses(losses)
# plot_losses only draws the figure; save it so the curves survive a
# non-interactive (script) run.
plt.savefig(os.path.join(out_dir, 'losses.png'))
The cow mesh data can be downloaded by running:
!mkdir -p data/cow_mesh
!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.obj
!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.mtl
!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow_texture.png
Thanks in advance for your help!