Fix RuntimeError when initialising from scratch
When initialising from scratch, the weight tensor is created with `requires_grad` already set, and `normal_` is then called on it in place:
```
w = torch.empty( ... , requires_grad=self.requires_grad)
w.normal_(std=1 / math.sqrt(self.n_in * self.kw ** 2))
```

This raises the following error:
```
/usr/local/lib/python3.7/dist-packages/dall_e/utils.py in __attrs_post_init__(self)
     22                 size = (self.n_out, self.n_in, self.kw, self.kw)
     23                 w = torch.empty(size=size, dtype=torch.float32, device=self.device, requires_grad = self.requires_grad)
---> 24                 w.normal_(std=1 / math.sqrt(self.n_in * self.kw ** 2))
     25 

RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
```

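This commit fixes it by creating the tensor without `requires_grad`, filling it with `normal_`, and only then setting `requires_grad`.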
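
The failure and the reordering can be reproduced outside `dall_e` with a few lines. This is only a minimal sketch: the sizes `n_out`, `n_in`, and `kw` are made-up values, and the device argument is omitted so it runs on CPU.

```python
import math

import torch

# Hypothetical layer sizes, chosen only for illustration.
n_out, n_in, kw = 4, 3, 3
size = (n_out, n_in, kw, kw)
std = 1 / math.sqrt(n_in * kw ** 2)

# Failing order: the leaf tensor already requires grad when normal_ runs.
w_bad = torch.empty(size, dtype=torch.float32, requires_grad=True)
try:
    w_bad.normal_(std=std)
    error_message = None
except RuntimeError as exc:
    error_message = str(exc)

# Working order: fill in place first, then turn on requires_grad.
w = torch.empty(size, dtype=torch.float32)
w.normal_(std=std)
w.requires_grad = True

print(error_message)
print(w.requires_grad, w.is_leaf)
```

The second tensor stays a leaf, so it can still be wrapped in `nn.Parameter` or passed to an optimizer afterwards.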
yashbonde authored Mar 3, 2021
1 parent 3381ae9 commit 91b8f62
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions dall_e/utils.py
@@ -19,10 +19,13 @@ class Conv2d(nn.Module):
 
     def __attrs_post_init__(self) -> None:
         super().__init__()
-
-        w = torch.empty((self.n_out, self.n_in, self.kw, self.kw), dtype=torch.float32,
-            device=self.device, requires_grad=self.requires_grad)
+        size = (self.n_out, self.n_in, self.kw, self.kw)
+        w = torch.empty(size=size, dtype=torch.float32, device=self.device)
         w.normal_(std=1 / math.sqrt(self.n_in * self.kw ** 2))
+
+        # move requires_grad after filling values using normal_
+        # RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
+        w.requires_grad = self.requires_grad
 
         b = torch.zeros((self.n_out,), dtype=torch.float32, device=self.device,
             requires_grad=self.requires_grad)
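
For comparison, a different approach that this commit does not use is to keep `requires_grad` in the `torch.empty` call and run the in-place fill inside `torch.no_grad()`, which permits in-place operations on leaf tensors that require grad. The helper name `init_conv_weight` and its parameters are hypothetical, and this is a sketch rather than code from `dall_e`.

```python
import math

import torch


def init_conv_weight(n_out: int, n_in: int, kw: int, requires_grad: bool = True) -> torch.Tensor:
    """Hypothetical helper: create a conv weight and fill it under no_grad."""
    w = torch.empty((n_out, n_in, kw, kw), dtype=torch.float32,
                    requires_grad=requires_grad)
    # Autograd does not record operations inside no_grad, so the
    # in-place fill is allowed even though w is a leaf requiring grad.
    with torch.no_grad():
        w.normal_(std=1 / math.sqrt(n_in * kw ** 2))
    return w


w = init_conv_weight(n_out=4, n_in=3, kw=3)
print(w.shape, w.requires_grad, w.is_leaf)
```

Both orderings end with the same kind of tensor; the commit's version avoids the context manager by setting `requires_grad` last.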