diff --git a/dall_e/utils.py b/dall_e/utils.py index cdb1cad..0f6b2fc 100644 --- a/dall_e/utils.py +++ b/dall_e/utils.py @@ -19,10 +19,13 @@ class Conv2d(nn.Module): def __attrs_post_init__(self) -> None: super().__init__() - - w = torch.empty((self.n_out, self.n_in, self.kw, self.kw), dtype=torch.float32, - device=self.device, requires_grad=self.requires_grad) + size = (self.n_out, self.n_in, self.kw, self.kw) + w = torch.empty(size=size, dtype=torch.float32, device=self.device) w.normal_(std=1 / math.sqrt(self.n_in * self.kw ** 2)) + + # Enable requires_grad only after the in-place normal_() initialization; setting it earlier raises + # "RuntimeError: a leaf Variable that requires grad is being used in an in-place operation." + w.requires_grad = self.requires_grad b = torch.zeros((self.n_out,), dtype=torch.float32, device=self.device, requires_grad=self.requires_grad)