"""
神经网络代码,根据一组线性回归的点,预测线性模型
"""

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# 定义简单的线性模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.weight = nn.Parameter(torch.rand((1,), requires_grad=True))
        self.bias = nn.Parameter(torch.rand((1,), requires_grad=True))

    def forward(self, x):
        return x * self.weight + self.bias

# 实例化模型
model = SimpleModel()

# 生成训练数据 我们将生成一些服从线性函数 y = 2 * x + 3 的随机点,并添加一些噪声。
torch.manual_seed(42)  # 为了结果可复现
x_train = torch.linspace(0, 10, 100).unsqueeze(1)  # 100个点,形状为(100, 1)
y_train = 2 * x_train + 3 + torch.randn(x_train.size())  # 加上噪声

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 真实的参数
print(f'Right weight: 2')
print(f'Right bias: 3')

# 打印训练前的参数
print(f'Init weight: {model.weight.item()}')
print(f'Init bias: {model.bias.item()}')

# 训练模型
num_epochs = 1000
for epoch in range(num_epochs):
    model.train()  # 设置模型为训练模式

    # 前向传播
    outputs = model(x_train)
    loss = criterion(outputs, y_train)

    # 反向传播和优化
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# 打印训练后的参数
print(f'Trained weight: {model.weight.item()}')
print(f'Trained bias: {model.bias.item()}')

# 使用训练好的模型进行预测
model.eval()  # 设置模型为评估模式
with torch.no_grad():  # 在评估时不计算梯度
    predicted = model(x_train)

# 可视化训练数据和预测结果
plt.scatter(x_train.numpy(), y_train.numpy(), label='Original data')
plt.plot(x_train.numpy(), predicted.numpy(), color='red', label='Fitted line')
plt.legend()
plt.show()