-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Completed cumulativeSum #103
base: main
Are you sure you want to change the base?
Changes from 2 commits
353be4c
6d5a48d
f0a4273
94c0a24
3f64303
493f182
aab0675
571f3de
552e2fb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,46 @@ | ||||||
'use strict'; | ||||||
|
||||||
import {Tensor, sizeOfShape} from './lib/tensor.js'; | ||||||
import {validateCumulativeSumParams} from './lib/validate-input.js'; | ||||||
|
||||||
/** | ||||||
* Computes the cumulative sum of the input tensor along the specified axis. | ||||||
* @param {Tensor} input | ||||||
* @param {number} axis | ||||||
* @param {MLCumulativeSumOptions} options | ||||||
* @return {Tensor} | ||||||
*/ | ||||||
export function cumulativeSum(input, axis, {exclusive = 0, reverse = 0} = {}) { | ||||||
validateCumulativeSumParams(...arguments); | ||||||
const inputShape = input.shape; | ||||||
const outputShape = [...inputShape]; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Do we need the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||||||
const output = new Tensor(outputShape); | ||||||
const numElementsAlongAxis = inputShape[axis]; | ||||||
|
||||||
const totalElements = sizeOfShape(outputShape); | ||||||
|
||||||
for (let outputIndex = 0; outputIndex < totalElements; outputIndex++) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It appears this nested
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm still seeing the doubly nested loop such that the inner loop redundantly overwrites the previous loop's output (which will be identical each time for a given summation sliver). Notice the |
||||||
const loc = output.locationFromIndex(outputIndex); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. request There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||||||
let cumulativeSumValue = 0; | ||||||
|
||||||
const start = reverse ? numElementsAlongAxis - 1 : 0; | ||||||
const step = reverse ? -1 : 1; | ||||||
const end = reverse ? -1 : numElementsAlongAxis; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These three values never change and can be computed once outside the loop. |
||||||
|
||||||
for (let i = start; reverse ? i > end : i < end; i += step) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since we are guaranteed this loop will end (given step and end are selected accordingly above, which they are), this...
...could just be...
Alternately (and probably clearer) we could make
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||||||
const inputLoc = [...loc]; | ||||||
inputLoc[axis] = exclusive ? (reverse ? i + 1 : i - 1) : i; | ||||||
|
||||||
if (!exclusive || (exclusive && inputLoc[axis] >= 0 && | ||||||
inputLoc[axis] < numElementsAlongAxis)) { | ||||||
cumulativeSumValue += input.getValueByLocation(inputLoc); | ||||||
} | ||||||
|
||||||
const outputLoc = [...loc]; | ||||||
outputLoc[axis] = i; | ||||||
output.setValueByLocation(outputLoc, cumulativeSumValue); | ||||||
} | ||||||
} | ||||||
|
||||||
return output; | ||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -611,6 +611,18 @@ export function validateGatherParams(input, indices, {axis = 0} = {}) { | |
} | ||
} | ||
|
||
export function validateCumulativeSumParams(input, axis) { | ||
if (axis !== undefined) { | ||
const rank = input.rank; | ||
if (!Number.isInteger(axis) || axis < -rank || axis >= rank) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The axis is of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, we generally want WebNN to be more explicit, resolving these kinds of user-facing API conveniences (like negative numbers and output shape rounding) into lower-level concrete values, like the actual axis and specific output shape. |
||
throw new Error(`The axis ${axis} should be in the range [-rank(input), rank(input)-1].`); | ||
} | ||
if (axis >= rank) { | ||
throw new Error(`The axis ${axis} should be in the interval [0, ${rank}).`); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, could just combine these two.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
} | ||
} | ||
} | ||
|
||
export function validateTriangularParams(input, {diagonal = 0} = {}) { | ||
const inputRank = input.rank; | ||
if (inputRank < 2) { | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,143 @@ | ||||||
'use strict'; | ||||||
|
||||||
import {cumulativeSum} from '../src/cumulativeSum.js'; | ||||||
import {Tensor} from '../src/lib/tensor.js'; | ||||||
import * as utils from './utils.js'; | ||||||
|
||||||
describe('test cumulativeSum', function() { | ||||||
function testCumulativeSum(input, axis, options={}, expected) { | ||||||
const tensor = new Tensor(input.shape, input.data); | ||||||
const outputTensor = cumulativeSum(tensor, axis, options); | ||||||
console.log('outputTensor', outputTensor); | ||||||
utils.checkShape(outputTensor, expected.shape); | ||||||
utils.checkValue(outputTensor, expected.data); | ||||||
} | ||||||
|
||||||
it('test cumulativeSum 1d', function() { | ||||||
const input = { | ||||||
shape: [5], | ||||||
data: [ | ||||||
1, 2, 3, 4, 5, | ||||||
], | ||||||
}; | ||||||
const axis=0; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||||||
const options = {exclusive: 0, reverse: 0}; | ||||||
const expected = { | ||||||
shape: [5], | ||||||
data: [ | ||||||
1, 3, 6, 10, 15, | ||||||
], | ||||||
}; | ||||||
testCumulativeSum(input, axis, options, expected); | ||||||
}); | ||||||
|
||||||
it('test cumulativeSum 1d exclusive', function() { | ||||||
const input = { | ||||||
shape: [5], | ||||||
data: [ | ||||||
1, 2, 3, 4, 5, | ||||||
], | ||||||
}; | ||||||
const axis=0; | ||||||
const options = {exclusive: 1, reverse: 0}; | ||||||
const expected = { | ||||||
shape: [5], | ||||||
data: [ | ||||||
0, 1, 3, 6, 10, | ||||||
], | ||||||
}; | ||||||
testCumulativeSum(input, axis, options, expected); | ||||||
}); | ||||||
|
||||||
it('test cumulativeSum 1d reverse', function() { | ||||||
const input = { | ||||||
shape: [5], | ||||||
data: [ | ||||||
1, 2, 3, 4, 5, | ||||||
], | ||||||
}; | ||||||
const axis=0; | ||||||
const options = {exclusive: 0, reverse: 1}; | ||||||
const expected = { | ||||||
shape: [5], | ||||||
data: [ | ||||||
15, 14, 12, 9, 5, | ||||||
], | ||||||
}; | ||||||
testCumulativeSum(input, axis, options, expected); | ||||||
}); | ||||||
|
||||||
it('test cumulativeSum 1d reverse exclusive', function() { | ||||||
const input = { | ||||||
shape: [5], | ||||||
data: [ | ||||||
1, 2, 3, 4, 5, | ||||||
], | ||||||
}; | ||||||
const axis=0; | ||||||
const options = {exclusive: 1, reverse: 1}; | ||||||
const expected = { | ||||||
shape: [5], | ||||||
data: [ | ||||||
14, 12, 9, 5, 0, | ||||||
], | ||||||
}; | ||||||
testCumulativeSum(input, axis, options, expected); | ||||||
}); | ||||||
|
||||||
it('test cumulativeSum 2d', function() { | ||||||
const input = { | ||||||
shape: [2, 3], | ||||||
data: [ | ||||||
1, 2, 3, 4, 5, 6, | ||||||
], | ||||||
}; | ||||||
const axis=0; | ||||||
const options = {exclusive: 0, reverse: 0}; | ||||||
const expected = { | ||||||
shape: [2, 3], | ||||||
data: [ | ||||||
1, 2, 3, 5, 7, 9, | ||||||
], | ||||||
}; | ||||||
testCumulativeSum(input, axis, options, expected); | ||||||
}); | ||||||
|
||||||
it('test cumulativeSum 2d axis=1', function() { | ||||||
const input = { | ||||||
shape: [2, 3], | ||||||
data: [ | ||||||
1, 2, 3, 4, 5, 6, | ||||||
], | ||||||
}; | ||||||
const axis=1; | ||||||
const options = {exclusive: 0, reverse: 0}; | ||||||
const expected = { | ||||||
shape: [2, 3], | ||||||
data: [ | ||||||
1, 3, 6, 4, 9, 15, | ||||||
], | ||||||
}; | ||||||
testCumulativeSum(input, axis, options, expected); | ||||||
}); | ||||||
|
||||||
it('test cumulativeSum 2d negtive axis', function() { | ||||||
const input = { | ||||||
shape: [2, 3], | ||||||
data: [ | ||||||
1, 2, 3, 4, 5, 6, | ||||||
], | ||||||
}; | ||||||
const axis=1; | ||||||
const options = {exclusive: 0, reverse: 0}; | ||||||
const expected = { | ||||||
shape: [2, 3], | ||||||
data: [ | ||||||
1, 3, 6, 4, 9, 15, | ||||||
], | ||||||
}; | ||||||
testCumulativeSum(input, axis, options, expected); | ||||||
}); | ||||||
}); | ||||||
|
||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Refer to https://chromium-review.googlesource.com/c/chromium/src/+/5845069/9/third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.idl#66
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done