Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Completed cumulativeSum #103

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions src/cumulativeSum.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
'use strict';

import {Tensor, sizeOfShape} from './lib/tensor.js';
import {validateCumulativeSumParams} from './lib/validate-input.js';

/**
* Computes the cumulative sum of the input tensor along the specified axis.
* @param {Tensor} input
* @param {number} axis
* @param {MLCumulativeSumOptions} options
* @return {Tensor}
*/
export function cumulativeSum(input, axis, {exclusive = false, reverse = false} = {}) {
validateCumulativeSumParams(...arguments);
const inputShape = input.shape;
const outputShape = inputShape;
const output = new Tensor(outputShape);
const elementCountAlongAxis = inputShape[axis];
const totalElements = sizeOfShape(outputShape);
const inputElementStart = reverse ? elementCountAlongAxis - 1 : 0;
const inputElementStep = reverse ? -1 : 1;

const cumulativeSums = new Array(elementCountAlongAxis).fill(0);

for (let outputIndex = 0; outputIndex < totalElements; outputIndex++) {
const location = output.locationFromIndex(outputIndex);

if (location[axis] !== inputElementStart) continue;

for (let i = 0; i < elementCountAlongAxis; ++i) {
const index = inputElementStart + i * inputElementStep;
location[axis] = index;
const inputValue = input.getValueByLocation(location);
cumulativeSums[i] = (i === 0 ? 0 : cumulativeSums[i - 1]) + inputValue;
const outputValue = exclusive ? (i === 0 ? 0 : cumulativeSums[i - 1]) : cumulativeSums[i];
output.setValueByLocation(location, outputValue);
}
}

return output;
}
9 changes: 9 additions & 0 deletions src/lib/validate-input.js
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,15 @@ export function validateGatherParams(input, indices, {axis = 0} = {}) {
}
}

export function validateCumulativeSumParams(input, axis) {
if (axis !== undefined) {
const rank = input.rank;
if (!Number.isInteger(axis) || axis < 0 || axis >= rank) {
throw new Error(`The axis ${axis} must be an unsigned integer in the interval [0, ${rank}).`);
}
}
}

export function validateTriangularParams(input, {diagonal = 0} = {}) {
const inputRank = input.rank;
if (inputRank < 2) {
Expand Down
125 changes: 125 additions & 0 deletions test/cumulativeSum_test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
'use strict';

import {cumulativeSum} from '../src/cumulativeSum.js';
import {Tensor} from '../src/lib/tensor.js';
import * as utils from './utils.js';

describe('test cumulativeSum', function() {
function testCumulativeSum(input, axis, options={}, expected) {
const tensor = new Tensor(input.shape, input.data);
const outputTensor = cumulativeSum(tensor, axis, options);
console.log('outputTensor', outputTensor);
utils.checkShape(outputTensor, expected.shape);
utils.checkValue(outputTensor, expected.data);
}

it('test cumulativeSum 1d', function() {
const input = {
shape: [5],
data: [
1, 2, 3, 4, 5,
],
};
const axis = 0;
const options = {exclusive: 0, reverse: 0};
const expected = {
shape: [5],
data: [
1, 3, 6, 10, 15,
],
};
testCumulativeSum(input, axis, options, expected);
});

it('test cumulativeSum 1d exclusive', function() {
const input = {
shape: [5],
data: [
1, 2, 3, 4, 5,
],
};
const axis = 0;
const options = {exclusive: 1, reverse: 0};
const expected = {
shape: [5],
data: [
0, 1, 3, 6, 10,
],
};
testCumulativeSum(input, axis, options, expected);
});

it('test cumulativeSum 1d reverse', function() {
const input = {
shape: [5],
data: [
1, 2, 3, 4, 5,
],
};
const axis = 0;
const options = {exclusive: 0, reverse: 1};
const expected = {
shape: [5],
data: [
15, 14, 12, 9, 5,
],
};
testCumulativeSum(input, axis, options, expected);
});

it('test cumulativeSum 1d reverse exclusive', function() {
const input = {
shape: [5],
data: [
1, 2, 3, 4, 5,
],
};
const axis = 0;
const options = {exclusive: 1, reverse: 1};
const expected = {
shape: [5],
data: [
14, 12, 9, 5, 0,
],
};
testCumulativeSum(input, axis, options, expected);
});

it('test cumulativeSum 2d', function() {
const input = {
shape: [2, 3],
data: [
1, 2, 3, 4, 5, 6,
],
};
const axis = 0;
const options = {exclusive: 0, reverse: 0};
const expected = {
shape: [2, 3],
data: [
1, 2, 3, 5, 7, 9,
],
};
testCumulativeSum(input, axis, options, expected);
});

it('test cumulativeSum 2d axis=1', function() {
const input = {
shape: [2, 3],
data: [
1, 2, 3, 4, 5, 6,
],
};
const axis = 1;
const options = {exclusive: 0, reverse: 0};
const expected = {
shape: [2, 3],
data: [
1, 3, 6, 4, 9, 15,
],
};
testCumulativeSum(input, axis, options, expected);
});
});


Loading