AwesomesErrors

本篇主要记录深度学习过程中遇到的各种bug,error以及对应的解决方法。

WARNING:root:NaN or Inf found in input tensor.

这个报错一版搜索的时候出现的是网络训练的时候Loss变为Nan的解决办法,通用的有四种

  1. 降低学习率
  2. 加入正则化
  3. 进行梯度截断
  4. 检查数据

我遇到的这个问题的场景是在训练一个模型vlbert,输入是数据包含图像和对应的图像描述,并且我的模型中包含正则化和梯度截断,因此大概率是数据的原因,下面是我进行debug的过程

  1. 固定shuffle的种子,检查数据是否出现为nan的情况,在pytorch中,使用如下程序进行对dataset的数据进行检查

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    import numpy as np
    import torch
    def check_tensor(self, vector, name=None):
    try:
    if isinstance(vector, torch.Tensor):
    if not np.any(np.isnan(vector.cpu().detach().numpy())):
    return True
    print("[] is false".format(name))
    return False
    elif isinstance(vector, np.ndarray):
    if not np.any(np.isnan(vector)):
    return True
    print("[] is false".format(name))
    return False
    elif isinstance(vector, list):
    vector = np.asarray(vector, dtype=np.float32)
    if not np.any(np.isnan(vector)):
    return True
    print("[] is false".format(name))
    return False
    except Exception as ex:
    logger.error(traceback.format_exc())
    logger.error("name is [{}] value is {}".format(name, vector))
    return False

    加入了数据check部分,使用同样的参数配置进行再次训练,发现在上次同样的迭代次数的时候又出来了nan的错误,但是这一个部分的输入数据并没有触发nan错误

  2. 检查模型输出数据是否为nan

    在模型数据输出部分,该部分数据还没有进行loss计算,只检查这部分数据是否为nan,检查发现,在这个迭代次数出现nan错误的时候,模型的输出为nan

  3. 调低学习率

    降低学习率,实际上这个时候的学习率是很低的$1e-7$,调低到$1e-10$,这个错误还是没有解决

  4. 确信以及肯定是数据的问题,但是就是没找到问题所在

    1. 在dataset的__geiitem__程序中记录当前的index,并将这个数传递到后续的模型计算过程中,这样再出现错误的时候就能根据index取到原始数据,再次同样参数运行程序,找到了出现问题的的这一组batch中的index是谁,因为我是使用的单机多卡训练的,因此为了保险起见,我把出现错误的前面四个batch的数据全部记录下来
  5. check数据,这一组记录的数据进行check,设置为batch=1,发现里面有一个数据会触发nan的错误,检查这个原始数据,在原始数据中记录的是这个样本的图片的url以及对应的图片描述,为了以防万一,打开了图片链接查看,惊喜出现了 https://sf1-ttcdn-tos.pstatp.com/obj/image-pair-videos/image_text_cc_dataset_BM5B_10A2Pglobal001jpglargewid1250_157795348409_jrlVRNYLF 这是这个样本的链接,发现这是一张不正常的图片,查看这个图片记录的其他信息为宽度=0,高度=0,终于找到问题了,如果图片的宽度或者高度为0,那么后续的处理过程中,因为涉及一个计算相对坐标的过程,在这个过程中,原来的绝对坐标处以这个0,直接就导致了数据为nan的错误,把这个数据删除之后一切正常了,天啦撸,太难了

  6. 这个问题debug了一天多,因为出发nan的时候在一个epoch的中间位置,一开始没有相应的记录,很难找到出错的样本,实际上针对梯度爆炸或者梯度小时的问题,开始在模型设计的时候就把该加的都加了经过这个事情,应该学会一下几点

    1. 数据要清洗,针对图像数据,就要看图像是否正常,针对文本数据就要看文本是否为空
    2. dataset中最好记录index,能进行反向追溯
赏杯咖啡!