-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Completed cumulativeSum #103
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @mei1127 !
I left some comments. Please take a look, thanks!
src/lib/validate-input.js
Outdated
export function validateCumulativeSumParams(input, axis) { | ||
if (axis !== undefined) { | ||
const rank = input.rank; | ||
if (!Number.isInteger(axis) || axis < -rank || axis >= rank) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The axis is of unsigned long
, so its range is in [0, rank)
. Please modify this check and delete the tests with negative axis, thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, we generally want WebNN to be more explicit, resolving these kinds of user-facing API conveniences (like negative numbers and output shape rounding) into lower-level concrete values, like the actual axis and specific output shape.
test/cumulativeSum_test.js
Outdated
1, 2, 3, 4, 5, | ||
], | ||
}; | ||
const axis=0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const axis=0; | |
const axis = 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
src/cumulativeSum.js
Outdated
* @param {MLCumulativeSumOptions} options | ||
* @return {Tensor} | ||
*/ | ||
export function cumulativeSum(input, axis, {exclusive = 0, reverse = 0} = {}) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
export function cumulativeSum(input, axis, {exclusive = 0, reverse = 0} = {}) { | |
export function cumulativeSum(input, axis, {exclusive = false, reverse = false} = {}) { |
dictionary MLCumulativeSumOptions : MLOperatorOptions {
boolean exclusive = false;
boolean reversed = false;
};
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks May for adding this. Some thoughts 💭🧠.
src/cumulativeSum.js
Outdated
export function cumulativeSum(input, axis, {exclusive = false, reverse = false} = {}) { | ||
validateCumulativeSumParams(...arguments); | ||
const inputShape = input.shape; | ||
const outputShape = [...inputShape]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const outputShape = [...inputShape]; | |
const outputShape = inputShape; |
Do we need the [... ]
, since outputShape will always equal inputShape anyway (no modification)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
src/cumulativeSum.js
Outdated
const totalElements = sizeOfShape(outputShape); | ||
|
||
for (let outputIndex = 0; outputIndex < totalElements; outputIndex++) { | ||
const loc = output.locationFromIndex(outputIndex); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
request loc
-> location
(whole words policy helps readability for others). Same for inputLocation
and outputLocation
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
src/cumulativeSum.js
Outdated
const step = reverse ? -1 : 1; | ||
const end = reverse ? -1 : numElementsAlongAxis; | ||
|
||
for (let i = start; reverse ? i > end : i < end; i += step) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we are guaranteed this loop will end (given step and end are selected accordingly above, which they are), this...
reverse ? i > end : i < end
...could just be...
i != end
Alternately (and probably clearer) we could make i
always be relative to the output offset and use <
, computing the input coordinate using the step.
for (let i = 0; i < elementCountAlongAxis; ++i) {
inputLocation[axis] = inputElementStart + i * inputElementStep;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
src/lib/validate-input.js
Outdated
throw new Error(`The axis ${axis} should be an unsigned integer.`); | ||
} | ||
if (axis >= rank) { | ||
throw new Error(`The axis ${axis} should be in the interval [0, ${rank}).`); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, could just combine these two.
if (!Number.isInteger(axis) || axis < 0 || axis >= rank) {
throw new Error(`The axis ${axis} must be an unsigned integer in the interval [0, ${rank}).`);
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
src/cumulativeSum.js
Outdated
|
||
const start = reverse ? numElementsAlongAxis - 1 : 0; | ||
const step = reverse ? -1 : 1; | ||
const end = reverse ? -1 : numElementsAlongAxis; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These three values never change and can be computed once outside the loop.
src/cumulativeSum.js
Outdated
|
||
const totalElements = sizeOfShape(outputShape); | ||
|
||
for (let outputIndex = 0; outputIndex < totalElements; outputIndex++) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It appears this nested for
loop repeats the same computation for the same output axial sliver, overwriting the same output elements multiple times. I recommend skipping it if it's already computed in an earlier loop. e.g.
const elementStart = reverse ? elementCountAlongAxis - 1 : 0;
const elementStep = reverse ? -1 : 1;
for (let outputIndex = 0; outputIndex < totalElements; ++outputIndex) {
let location = output.locationFromIndex(outputIndex);
if (location[axis] > 0) {
continue; // No need to compute this axis again, since it was already done.
}
let cumulativeSumValue = 0;
// Compute the accumulated sum along this entire axis.
for (let i = 0; i < elementCountAlongAxis; ++i) {
location[axis] = elementStart + i * elementStep;
const inputValue = input.getValueByLocation(inputLocation);
const oldCumulativeSumValue = cumulativeSumValue;
cumulativeSumValue += inputValue;
const outputValue = exclusive ? oldCumulativeSumValue : cumulativeSumValue;
output.setValueByLocation(outputLocation, outputValue);
}
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm still seeing the doubly nested loop such that the inner loop redundantly overwrites the previous loop's output (which will be identical each time for a given summation sliver). Notice the continue
statement above, since we only need to compute the sum (inner loop) along the axis once for each distinct sliver (outer loop). Also, I since realized that the input and the output locations are always identical, meaning we can fold them into a single location
variable. Updated above accordingly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Two more requests, and then it should be good. Thanks.
src/cumulativeSum.js
Outdated
|
||
const totalElements = sizeOfShape(outputShape); | ||
|
||
for (let outputIndex = 0; outputIndex < totalElements; outputIndex++) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm still seeing the doubly nested loop such that the inner loop redundantly overwrites the previous loop's output (which will be identical each time for a given summation sliver). Notice the continue
statement above, since we only need to compute the sum (inner loop) along the axis once for each distinct sliver (outer loop). Also, I since realized that the input and the output locations are always identical, meaning we can fold them into a single location
variable. Updated above accordingly.
src/cumulativeSum.js
Outdated
const outputLocation = [...location]; | ||
for (let i = 0; i < elementCountAlongAxis; ++i) { | ||
const idx = inputElementStart + i * inputElementStep; | ||
inputLocation[axis]=idx; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
inputLocation[axis]=idx; | |
inputLocation[axis] = index; |
Request whole words and spacing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your advice!
@BruceDai @huningxin @fdwr @miaobin PTAL, thanks!