主页 > 手机  > 

PyTorchLightning-LightningModule训练逻辑(training_step)异常处理t

PyTorchLightning-LightningModule训练逻辑(training_step)异常处理t

欢迎关注我的CSDN: spike.blog.csdn.net/ 本文地址: spike.blog.csdn.net/article/details/133673820

在使用 LightningModule 框架训练模型时,因数据导致的训练错误,严重影响训练稳定性,因此需要使用 try-except 及时捕获错误。即 当错误发生时,在 training_step 异常返回 None,同时,on_before_zero_grad 也需要进行异常处理,处理 training_step 的异常返回 None。

同样的,validation_step 也可以这样处理。

源码如下:

class MyObject(pl.LightningModule): def __init__(self, config, args): # ... def training_step_wrapper(self, batch, batch_idx, log_interval=10): # train key process def training_step(self, batch, batch_idx, log_interval=10): """ typically, each step costs 50 seconds 参考: github /Lightning-AI/lightning/pull/3566 """ try: res = self.training_step_wrapper(batch, batch_idx, log_interval) return res except Exception as e: logger.info(f"[CL] training_step, exception: {e}") return None def on_before_zero_grad(self, *args, **kwargs): try: self.ema.update(self.model) except Exception as e: # 支持 training_step return None logger.info(f"[CL] on_before_zero_grad, exception: {e}") return def validation_step_wrapper(self, batch, batch_idx): # val key process def validation_step(self, batch, batch_idx): try: self.validation_step_wrapper(batch, batch_idx) except Exception as e: logger.info(f"[CL] validation_step, exception: {e}") return

常见错误如下

数组越界:

index 0 is out of bounds for dimension 0 with size 0

字典错误字段:

num_res = int(np_example["seq_length"]) KeyError: 'seq_length'

计算输入数值为空:

V, _, W = torch.linalg.svd(C)

free()异常:

free(): invalid next size (fast)

munmap_chunk() 空指针:

munmap_chunk(): invalid pointer

标签:

PyTorchLightning-LightningModule训练逻辑(training_step)异常处理t由讯客互联手机栏目发布,感谢您对讯客互联的认可,以及对我们原创作品以及文章的青睐,非常欢迎各位朋友分享到个人网站或者朋友圈,但转载请说明文章出处“PyTorchLightning-LightningModule训练逻辑(training_step)异常处理t