大蟒蛇python教程共享Python深度学习pyTorch权重衰减与L2范数正则化解析

Python深度学习pyTorch权重衰减与L2范数正则化解析

下面进行一个高维线性实验

假设我们的真实方程是:

Python深度学习pyTorch权重衰减与L2范数正则化解析

假设feature数200,训练样本和测试样本各20个

模拟数据集

  num_train,num_test = 10,10  num_features = 200  true_w = torch.ones((num_features,1),dtype=torch.float32) * 0.01  true_b = torch.tensor(0.5)  samples = torch.normal(0,1,(num_train+num_test,num_features))  noise = torch.normal(0,0.01,(num_train+num_test,1))  labels = samples.matmul(true_w) + true_b + noise  train_samples, train_labels= samples[:num_train],labels[:num_train]  test_samples, test_labels = samples[num_train:],labels[num_train:]  

定义带正则项的loss function

  def loss_function(predict,label,w,lambd):      loss = (predict - label) ** 2      loss = loss.mean() + lambd * (w**2).mean()      return loss  

画图的方法

  def semilogy(x_val,y_val,x_label,y_label,x2_val,y2_val,legend):      plt.figure(figsize=(3,3))      plt.xlabel(x_label)      plt.ylabel(y_label)      plt.semilogy(x_val,y_val)      if x2_val and y2_val:          plt.semilogy(x2_val,y2_val)          plt.legend(legend)      plt.show()  

拟合和画图

  def fit_and_plot(train_samples,train_labels,test_samples,test_labels,num_epoch,lambd):      w = torch.normal(0,1,(train_samples.shape[-1],1),requires_grad=true)      b = torch.tensor(0.,requires_grad=true)      optimizer = torch.optim.adam([w,b],lr=0.05)      train_loss = []      test_loss = []      for epoch in range(num_epoch):          predict = train_samples.matmul(w) + b          epoch_train_loss = loss_function(predict,train_labels,w,lambd)          optimizer.zero_grad()          epoch_train_loss.backward()          optimizer.step()          test_predict = test_sapmles.matmul(w) + b          epoch_test_loss = loss_function(test_predict,test_labels,w,lambd)          train_loss.append(epoch_train_loss.item())          test_loss.append(epoch_test_loss.item())      semilogy(range(1,num_epoch+1),train_loss,'epoch','loss',range(1,num_epoch+1),test_loss,['train','test'])  

Python深度学习pyTorch权重衰减与L2范数正则化解析
可以发现加了正则项的模型,在测试集上的loss确实下降了

以上就是python深度学习pytorch权重衰减与l2范数正则化解析的详细内容,更多关于python pytorch权重与l2范数正则化的资料请关注<计算机技术网(www.ctvol.com)!!>其它相关文章!

需要了解更多python教程分享Python深度学习pyTorch权重衰减与L2范数正则化解析,都可以关注python教程分享栏目—计算机技术网(www.ctvol.com)!

www.ctvol.com true Article 大蟒蛇python教程共享Python深度学习pyTorch权重衰减与L2范数正则化解析

本文来自网络收集,不代表计算机技术网立场,如涉及侵权请联系管理员删除。

ctvol管理联系方式QQ:251552304

本文章地址:https://www.ctvol.com/pythontutorial/837098.html

(0)
上一篇 2021年9月30日 下午6:53
下一篇 2021年9月30日 下午6:58

精彩推荐