From f9d630943650aa059f3d2f7950fca1aa438270da Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Fri, 27 Sep 2024 21:48:57 -0400 Subject: [PATCH] Fix hardcoded shape in low_mem_dropout benchmark --- torchbenchmark/operators/low_mem_dropout/operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchbenchmark/operators/low_mem_dropout/operator.py b/torchbenchmark/operators/low_mem_dropout/operator.py index 82d50a6b0..d171eb8f4 100644 --- a/torchbenchmark/operators/low_mem_dropout/operator.py +++ b/torchbenchmark/operators/low_mem_dropout/operator.py @@ -38,7 +38,7 @@ def triton_dropout(self, p, x): n_elements = x.numel() grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - x_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda() + x_keep = (torch.rand(size=(n_elements,)) > p).to(torch.int32).cuda() def _inner(): return _triton_dropout[grid](