本篇主要记录深度学习过程中遇到的各种bug,error以及对应的解决方法。
WARNING:root:NaN or Inf found in input tensor.
这个报错一版搜索的时候出现的是网络训练的时候Loss变为Nan的解决办法,通用的有四种
- 降低学习率
- 加入正则化
- 进行梯度截断
- 检查数据
我遇到的这个问题的场景是在训练一个模型vlbert
,输入是数据包含图像和对应的图像描述,并且我的模型中包含正则化和梯度截断,因此大概率是数据的原因,下面是我进行debug的过程
固定shuffle的种子,检查数据是否出现为nan的情况,在pytorch中,使用如下程序进行对dataset的数据进行检查
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24import numpy as np
import torch
def check_tensor(self, vector, name=None):
try:
if isinstance(vector, torch.Tensor):
if not np.any(np.isnan(vector.cpu().detach().numpy())):
return True
print("[] is false".format(name))
return False
elif isinstance(vector, np.ndarray):
if not np.any(np.isnan(vector)):
return True
print("[] is false".format(name))
return False
elif isinstance(vector, list):
vector = np.asarray(vector, dtype=np.float32)
if not np.any(np.isnan(vector)):
return True
print("[] is false".format(name))
return False
except Exception as ex:
logger.error(traceback.format_exc())
logger.error("name is [{}] value is {}".format(name, vector))
return False加入了数据
check
部分,使用同样的参数配置进行再次训练,发现在上次同样的迭代次数的时候又出来了nan
的错误,但是这一个部分的输入数据并没有触发nan
错误检查模型输出数据是否为
nan
在模型数据输出部分,该部分数据还没有进行
loss
计算,只检查这部分数据是否为nan
,检查发现,在这个迭代次数出现nan错误的时候,模型的输出为nan
调低学习率
降低学习率,实际上这个时候的学习率是很低的$1e-7$,调低到$1e-10$,这个错误还是没有解决
确信以及肯定是数据的问题,但是就是没找到问题所在
- 在dataset的
__geiitem__
程序中记录当前的index
,并将这个数传递到后续的模型计算过程中,这样再出现错误的时候就能根据index
取到原始数据,再次同样参数运行程序,找到了出现问题的的这一组batch中的index是谁,因为我是使用的单机多卡训练的,因此为了保险起见,我把出现错误的前面四个batch的数据全部记录下来
- 在dataset的
check数据,这一组记录的数据进行check,设置为batch=1,发现里面有一个数据会触发nan的错误,检查这个原始数据,在原始数据中记录的是这个样本的图片的url以及对应的图片描述,为了以防万一,打开了图片链接查看,惊喜出现了 https://sf1-ttcdn-tos.pstatp.com/obj/image-pair-videos/image_text_cc_dataset_BM5B_10A2Pglobal001jpglargewid1250_157795348409_jrlVRNYLF 这是这个样本的链接,发现这是一张不正常的图片,查看这个图片记录的其他信息为宽度=0,高度=0,终于找到问题了,如果图片的宽度或者高度为0,那么后续的处理过程中,因为涉及一个计算相对坐标的过程,在这个过程中,原来的绝对坐标处以这个0,直接就导致了数据为nan的错误,把这个数据删除之后一切正常了,天啦撸,太难了
这个问题debug了一天多,因为出发nan的时候在一个epoch的中间位置,一开始没有相应的记录,很难找到出错的样本,实际上针对梯度爆炸或者梯度小时的问题,开始在模型设计的时候就把该加的都加了经过这个事情,应该学会一下几点
- 数据要清洗,针对图像数据,就要看图像是否正常,针对文本数据就要看文本是否为空
- dataset中最好记录index,能进行反向追溯