Skip to content

Optimize tensor.slice() #1381

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 38 additions & 7 deletions src/utils/tensor.js
Original file line number Diff line number Diff line change
Expand Up @@ -443,15 +443,46 @@ export class Tensor {
// Precompute strides
const stride = this.stride();

for (let i = 0; i < newBufferSize; ++i) {
let originalIndex = 0;
for (let j = newDims.length - 1, num = i; j >= 0; --j) {
const size = newDims[j];
originalIndex += ((num % size) + newOffsets[j][0]) * stride[j];
num = Math.floor(num / size);
// Detect if the slice is contiguous
let isContiguous = true;
for (let i = 1; i < newDims.length; ++i) {
if (newOffsets[i][0] !== 0 || newOffsets[i][1] !== this.dims[i]) {
isContiguous = false;
break;
}
data[i] = this_data[originalIndex];
}

if (isContiguous) {
// Perform bulk copy for contiguous slices to improve performance
const start = newOffsets[0][0] * stride[0];
const end = newOffsets[0][1] * stride[0];

if (ArrayBuffer.isView(this_data)) {
// If this.data is a TypedArray, use subarray
// @ts-ignore
data.set(this_data.subarray(start, end));
} else if (Array.isArray(this_data)) {
// If this.data is a plain array, use slice
const slicedData = this_data.slice(start, end);
for (let i = 0; i < slicedData.length; ++i) {
data[i] = slicedData[i];
}
} else {
throw new Error("Unsupported data type for slicing");
}
} else {
// Fallback to manual copying for non-contiguous slices
for (let i = 0; i < newBufferSize; ++i) {
let originalIndex = 0;
for (let j = newDims.length - 1, num = i; j >= 0; --j) {
const size = newDims[j];
originalIndex += ((num % size) + newOffsets[j][0]) * stride[j];
num = Math.floor(num / size);
}
data[i] = this_data[originalIndex];
}
}

return new Tensor(this.type, data, newTensorDims);
}

Expand Down
62 changes: 59 additions & 3 deletions tests/utils/tensor.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,13 @@ describe("Tensor operations", () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
const t2 = t1.slice(1);
const target = new Tensor("float32", [3, 4], [2]);

compare(t2, target);
});

it("should return a range of rows", () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
const t2 = t1.slice([1, 3]);
const target = new Tensor("float32", [3, 4, 5, 6], [2, 2]);

compare(t2, target);
});

Expand All @@ -78,9 +76,67 @@ describe("Tensor operations", () => {
[4, 7],
);
const t2 = t1.slice([1, -1], [1, -1]);

const target = new Tensor("float32", [9, 10, 11, 12, 13, 16, 17, 18, 19, 20], [2, 5]);
compare(t2, target);
});

it("should return the whole tensor when all indices are null/unset", () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
const t2 = t1.slice();
compare(t2, t1);
});

it("should return the whole dimension when index is null", () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
const t2 = t1.slice(null);
compare(t2, t1);
});

it("should slice from index to end when [start, null] is used", () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
const t2 = t1.slice([1, null]);
const target = new Tensor("float32", [3, 4, 5, 6], [2, 2]);
compare(t2, target);
});

it("should slice from beginning to index when [null, end] is used", () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
const t2 = t1.slice([null, 2]);
const target = new Tensor("float32", [1, 2, 3, 4], [2, 2]);
compare(t2, target);
});

it("should handle [null, null] as full slice", () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
const t2 = t1.slice([null, null]);
compare(t2, t1);
});

it("should select a single element when a number is used in slice", () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
const t2 = t1.slice(2, 1);
const target = new Tensor("float32", [6], []);
compare(t2, target);
});

it("should select a single row when a number is used in slice", () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
const t2 = t1.slice(0);
const target = new Tensor("float32", [1, 2], [2]);
compare(t2, target);
});

it("should select a single column when a number is used in slice", () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
const t2 = t1.slice(null, 1);
const target = new Tensor("float32", [2, 4, 6], [3]);
compare(t2, target);
});

it("should handle negative indices in slice", () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
const t2 = t1.slice(-1);
const target = new Tensor("float32", [5, 6], [2]);
compare(t2, target);
});
});
Expand Down