At 4:56, after fixing the indentation error in the forward() function, running y_preds should return 10 predicted values. For some reason I am getting a single value instead (see step 6 below). I have rerun the whole thing from the beginning several times, and I checked that X_test and y_test each contain 10 values. The entire code is below.

import torch

# 1. Creating dataset
# Create known parameters
weight = 0.7
bias = 0.3

# Create data
start = 0
end = 1
step = 0.02
X = torch.arange(start, end, step).unsqueeze(dim=1)
y = weight * X + bias
X[:10], y[:10]
# 2. Create train/test split
train_split = int(0.8 * len(X))
X_train, y_train = X[:train_split], y[:train_split]  # indexing
X_test, y_test = X[train_split:], y[train_split:]
len(X_train), len(y_train), len(X_test), len(y_test)
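To confirm that the data side is fine, it can help to print shapes rather than only lengths. A minimal sketch that re-creates the dataset and 80/20 split above and checks the shape of each piece:

```python
import torch

weight, bias = 0.7, 0.3
X = torch.arange(0, 1, 0.02).unsqueeze(dim=1)  # 50 values, reshaped to [50, 1]
y = weight * X + bias

train_split = int(0.8 * len(X))  # 40 rows for training
X_train, y_train = X[:train_split], y[:train_split]
X_test, y_test = X[train_split:], y[train_split:]  # remaining 10 rows

print(X.shape, X_train.shape, X_test.shape)
```

If the test tensor has shape [10, 1], the inputs are correct and a single-value output has to come from the model itself.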
# 3. Defining the model
from torch import nn

class LinearRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(1,
                                                requires_grad=True,
                                                dtype=torch.float))
        self.bias = nn.Parameter(torch.randn(1,
                                             requires_grad=True,
                                             dtype=torch.float))

    # Forward method to define the computation in the model
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.weights * self.bias
# 4. Instantiating the model
# Set a random seed so the parameter initialisation is reproducible
torch.manual_seed(42)

# Create an instance of the model (LinearRegressionModel is a subclass of nn.Module)
model_0 = LinearRegressionModel()
# 5. Predicting the y-value
with torch.inference_mode():
    y_preds = model_0(X_test)

y_preds
# 6. Wrong output
The output I get is tensor([0.0434]), a single value instead of 10. Thanks a lot; I appreciate any help with this. I am sure it is something silly, but for the love of PyTorch I cannot figure it out.
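The single value follows from the shapes inside forward(): as written it returns self.weights * self.bias, which never uses x. Both parameters have shape [1], so their product has shape [1] no matter how many rows X_test has. The sketch below isolates that expression; the names weights and bias mirror the model's attributes, and the parameters are drawn in the same order after torch.manual_seed(42), so the buggy product should match the 0.0434 reported in step 6. For contrast it also computes weights * x + bias, which broadcasts [1] against [10, 1]:

```python
import torch

torch.manual_seed(42)
# Same draws, in the same order, as LinearRegressionModel.__init__
weights = torch.nn.Parameter(torch.randn(1, dtype=torch.float))
bias = torch.nn.Parameter(torch.randn(1, dtype=torch.float))

X = torch.arange(0, 1, 0.02).unsqueeze(dim=1)
X_test = X[40:]  # last 10 rows, as in the 80/20 split

with torch.inference_mode():
    buggy = weights * bias            # x is never used, so the shape stays [1]
    fixed = weights * X_test + bias   # broadcasting gives one prediction per row: [10, 1]

print(buggy, buggy.shape, fixed.shape)
```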
Reply (from the original poster):
I am not able to delete the question, but I found the mistake: I had written the formula wrong inside forward(). It returned self.weights * self.bias instead of self.weights * x + self.bias. It is fixed now. Thanks.
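For reference, the fix amounts to using the same linear formula that generated the data (y = weight * X + bias) inside forward(). A minimal sketch of the corrected model, assuming the rest of the code from the question stays unchanged:

```python
import torch
from torch import nn


class LinearRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Learnable parameters, initialised randomly
        self.weights = nn.Parameter(torch.randn(1, requires_grad=True, dtype=torch.float))
        self.bias = nn.Parameter(torch.randn(1, requires_grad=True, dtype=torch.float))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Linear regression: uses x, so there is one prediction per input row
        return self.weights * x + self.bias


torch.manual_seed(42)
model_0 = LinearRegressionModel()

X = torch.arange(0, 1, 0.02).unsqueeze(dim=1)
X_test = X[int(0.8 * len(X)):]  # last 10 rows

with torch.inference_mode():
    y_preds = model_0(X_test)

print(y_preds.shape)  # torch.Size([10, 1])
```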