Skip to content

Commit 5043d15

Browse files
ynonaolgafacebook-github-bot
authored and committed
avoid CPU/GPU sync in sample_farthest_points
Summary: Optimizing sample_farthest_points by reducing CPU/GPU sync: 1. Replacing the iterative randint for starting indexes with 1 function call, if length is constant 2. Avoiding sync in fetching the maximum of sample points, if we sample the same amount 3. Initializing 1 tensor for samples and indices. Compare https://fburl.com/mlhub/7wk0xi98 Before {F1980383703} after {F1980383707} Histograms match pretty closely {F1980464338} Reviewed By: bottler Differential Revision: D78731869 fbshipit-source-id: 060528ae7a1e0fbbd005d129c151eaf9405841de
1 parent e3d3a67 commit 5043d15

File tree

3 files changed

+24
-10
lines changed

3 files changed

+24
-10
lines changed

pytorch3d/csrc/sample_farthest_points/sample_farthest_points.cu

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ at::Tensor FarthestPointSamplingCuda(
107107
const at::Tensor& points, // (N, P, 3)
108108
const at::Tensor& lengths, // (N,)
109109
const at::Tensor& K, // (N,)
110-
const at::Tensor& start_idxs) {
110+
const at::Tensor& start_idxs,
111+
const int64_t max_K_known = -1) {
111112
// Check inputs are on the same device
112113
at::TensorArg p_t{points, "points", 1}, lengths_t{lengths, "lengths", 2},
113114
k_t{K, "K", 3}, start_idxs_t{start_idxs, "start_idxs", 4};
@@ -129,7 +130,12 @@ at::Tensor FarthestPointSamplingCuda(
129130

130131
const int64_t N = points.size(0);
131132
const int64_t P = points.size(1);
132-
const int64_t max_K = at::max(K).item<int64_t>();
133+
int64_t max_K;
134+
if (max_K_known > 0) {
135+
max_K = max_K_known;
136+
} else {
137+
max_K = at::max(K).item<int64_t>();
138+
}
133139

134140
// Initialize the output tensor with the sampled indices
135141
auto idxs = at::full({N, max_K}, -1, lengths.options());

pytorch3d/csrc/sample_farthest_points/sample_farthest_points.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ at::Tensor FarthestPointSamplingCuda(
4343
const at::Tensor& points,
4444
const at::Tensor& lengths,
4545
const at::Tensor& K,
46-
const at::Tensor& start_idxs);
46+
const at::Tensor& start_idxs,
47+
const int64_t max_K_known = -1);
4748

4849
at::Tensor FarthestPointSamplingCpu(
4950
const at::Tensor& points,
@@ -56,14 +57,16 @@ at::Tensor FarthestPointSampling(
5657
const at::Tensor& points,
5758
const at::Tensor& lengths,
5859
const at::Tensor& K,
59-
const at::Tensor& start_idxs) {
60+
const at::Tensor& start_idxs,
61+
const int64_t max_K_known = -1) {
6062
if (points.is_cuda() || lengths.is_cuda() || K.is_cuda()) {
6163
#ifdef WITH_CUDA
6264
CHECK_CUDA(points);
6365
CHECK_CUDA(lengths);
6466
CHECK_CUDA(K);
6567
CHECK_CUDA(start_idxs);
66-
return FarthestPointSamplingCuda(points, lengths, K, start_idxs);
68+
return FarthestPointSamplingCuda(
69+
points, lengths, K, start_idxs, max_K_known);
6770
#else
6871
AT_ERROR("Not compiled with GPU support.");
6972
#endif

pytorch3d/ops/sample_farthest_points.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def sample_farthest_points(
5555
N, P, D = points.shape
5656
device = points.device
5757

58+
constant_length = lengths is None
5859
# Validate inputs
5960
if lengths is None:
6061
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
@@ -65,7 +66,9 @@ def sample_farthest_points(
6566
raise ValueError("A value in lengths was too large.")
6667

6768
# TODO: support providing K as a ratio of the total number of points instead of as an int
69+
max_K = -1
6870
if isinstance(K, int):
71+
max_K = K
6972
K = torch.full((N,), K, dtype=torch.int64, device=device)
7073
elif isinstance(K, list):
7174
K = torch.tensor(K, dtype=torch.int64, device=device)
@@ -82,15 +85,17 @@ def sample_farthest_points(
8285
K = K.to(torch.int64)
8386

8487
# Generate the starting indices for sampling
85-
start_idxs = torch.zeros_like(lengths)
8688
if random_start_point:
87-
for n in range(N):
88-
# pyre-fixme[6]: For 1st param expected `int` but got `Tensor`.
89-
start_idxs[n] = torch.randint(high=lengths[n], size=(1,)).item()
89+
if constant_length:
90+
start_idxs = torch.randint(high=P, size=(N,), device=device)
91+
else:
92+
start_idxs = (lengths * torch.rand(lengths.size())).to(torch.int64)
93+
else:
94+
start_idxs = torch.zeros_like(lengths)
9095

9196
with torch.no_grad():
9297
# pyre-fixme[16]: `pytorch3d_._C` has no attribute `sample_farthest_points`.
93-
idx = _C.sample_farthest_points(points, lengths, K, start_idxs)
98+
idx = _C.sample_farthest_points(points, lengths, K, start_idxs, max_K)
9499
sampled_points = masked_gather(points, idx)
95100

96101
return sampled_points, idx

0 commit comments

Comments
 (0)