Pytorch与深度学习-02.回归预测
作者:Sun zi chao     发布时间:2021-06-25 20:40:17    阅读次数:642
线性回归是最容易理解的一种预测方式。线性方程Y = aX+b,大家都认识。回归预测就是知道一堆X和Y的值,计算出最接近真实a和b的两个值。这个一般运用在连续的,线性的数据预测上。

这里我将举一个线性回归的例子。线性就必须是连续的数据,Y=aX+b,是最好理解的线性方程,我们随机给出一组连续的X值(100个),然后把这些数据代入Y=aX+b方程里求出100个Y,a的值我取2.5,b的值我取随机数。这样,我就得到一组经过“抖动”后的数据。并用这两组数据X和Y去训练,求出最接近的a`和b`。

然后我再用求出的a`和b`和原始数据X套入公式,求出Y,然后画出Y和Y`。比较两组数据差异。


示例程序

			import numpy as np
			import torch
			import matplotlib.pyplot as plt

			def prepare_db():
				#生成100个点的X值
				train_X = np.linspace(-2*np.pi,2*np.pi,100)
				
				#根据train_db的值,生成相应的Y=aX+b值,并进行随机加减
				a=2.5
				train_Y = train_X*a+np.random.rand(100)
				
				train_X = train_X.reshape(-1,1)
				train_Y = train_Y.reshape(-1,1)
				
				return train_X,train_Y
			#=============================================
			class LinearRegression(torch.nn.Module):

				def __init__(self):
					super().__init__()
					self.linear = torch.nn.Linear(1,1)
					
				def forward(self,x):
					out = self.linear(x)
					return out
				
			#==============================================
			class Linear_Model():
				def __init__(self):
					self.learning_rate = 0.001
					self.epoches = 10000
					self.loss_function = torch.nn.MSELoss()
					self.create_model()
					
				def create_model(self):
					self.model = LinearRegression()
					self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate)
				
				def train(self, train_X,train_Y, model_save_path="model.pth"):
				 
					x = torch.tensor(train_X).float()
					y = torch.tensor(train_Y).float()
					
					for epoch in range(self.epoches):
						prediction = self.model(x)
						loss = self.loss_function(prediction, y)

						self.optimizer.zero_grad()
						loss.backward()
						self.optimizer.step()

						if epoch % 1000 == 0:
							print("epoch: {}, loss is: {}".format(epoch, loss.item()))
							
					torch.save(self.model.state_dict(), "linear.pth")

					
				def test(self,test_db,model_open_path='model.pth'):
					self.model.load_state_dict(torch.load(model_open_path))
					prediction = self.model(torch.tensor(test_db).float())
					return prediction.detach().numpy()

			#==================================================
			if __name__ == '__main__':
				
				train_X,train_Y = prepare_db()
				linear = Linear_Model()    
				linear.train(train_X,train_Y)
				ret = linear.test(train_X,"linear.pth")
				
				plt.plot(train_Y,'g')
				plt.plot(ret,'r')
				plt.show()
			


运行结果

绿色原始真实值,红色是预测值



桂ICP备11003301号-1 公安备案号:45040302000027 Copyright @ 2021- 2022 By Sun zi chao

阅读统计: 2.09W 文章数量: 76 运行天数: 464天 返回cmnsoft