20250303-代码笔记-classCVRPTester
- 互联网
- 2025-09-14 20:21:02

文章目录 前言一、class CVRPTester:__init__(self,env_params,model_params, tester_params)1.1函数解析1.2函数分析1.2.1加载预训练模型 1.2函数代码 二、class CVRPTester:run(self)函数解析函数代码 三、class CVRPTester:_test_one_batch(self, batch_size)函数解析函数代码 附录代码(全)
前言
学习代码CVRPTester.py,对代码的分析如下。
/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPTester.py
一、class CVRPTester:__init__(self, env_params, model_params, tester_params) 1.1函数解析
执行流程图链接
1.2函数分析 1.2.1加载预训练模型代码:
# Restore model_load = tester_params['model_load'] checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load) checkpoint = torch.load(checkpoint_fullname, map_location=device) self.model.load_state_dict(checkpoint['model_state_dict']) model_load: 这是一个字典,包含了从哪里加载预训练模型的路径信息以及具体的 epoch: model_load = tester_params['model_load'] checkpoint_fullname: 使用 Python 的字符串格式化功能,构造预训练模型的文件路径。 这会生成形如 /path/to/model/checkpoint-8100.pt 的文件路径。即需要输入参数path和epoch。 checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load) 加载模型: torch.load(checkpoint_fullname, map_location=device):从磁盘加载模型检查点(即 .pt 文件),并将其存储在 checkpoint 变量中。map_location=device 确保模型会被加载到正确的设备上(GPU 或 CPU)。self.model.load_state_dict(checkpoint['model_state_dict']):从加载的检查点中提取模型的状态字典,并将其加载到 self.model 中。 checkpoint = torch.load(checkpoint_fullname, map_location=device) self.model.load_state_dict(checkpoint['model_state_dict'])示例 假设 tester_params_regret[‘model_load’] 如下所示:
tester_params_regret = { 'model_load': { 'path': '../../pretrained/vrp100', 'epoch': 8100, }, # 其他参数... }然后 checkpoint_fullname 会被构造为 '../../pretrained/vrp100/checkpoint-8100.pt'(相对于脚本所在目录解析,即 /home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/pretrained/vrp100/checkpoint-8100.pt),模型会从该路径加载。
1.2函数代码 def __init__(self, env_params, model_params, tester_params): # save arguments self.env_params = env_params self.model_params = model_params self.tester_params = tester_params # result folder, logger self.logger = getLogger(name='trainer') self.result_folder = get_result_folder() # cuda USE_CUDA = self.tester_params['use_cuda'] if USE_CUDA: cuda_device_num = self.tester_params['cuda_device_num'] torch.cuda.set_device(cuda_device_num) device = torch.device('cuda', cuda_device_num) torch.set_default_tensor_type('torch.cuda.FloatTensor') else: device = torch.device('cpu') torch.set_default_tensor_type('torch.FloatTensor') self.device = device # ENV and MODEL self.env = Env(**self.env_params) self.model = Model(**self.model_params) # Restore model_load = tester_params['model_load'] checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load) checkpoint = torch.load(checkpoint_fullname, map_location=device) self.model.load_state_dict(checkpoint['model_state_dict']) # utility self.time_estimator = TimeEstimator()
二、class CVRPTester:run(self) 函数解析
函数执行流程图链接
函数代码 def run(self): self.time_estimator.reset() score_AM = AverageMeter() aug_score_AM = AverageMeter() if self.tester_params['test_data_load']['enable']: self.env.use_saved_problems(self.tester_params['test_data_load']['filename'], self.device) test_num_episode = self.tester_params['test_episodes'] episode = 0 while episode < test_num_episode: remaining = test_num_episode - episode batch_size = min(self.tester_params['test_batch_size'], remaining) score, aug_score = self._test_one_batch(batch_size) score_AM.update(score, batch_size) aug_score_AM.update(aug_score, batch_size) episode += batch_size ############################ # Logs ############################ elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(episode, test_num_episode) self.logger.info("episode {:3d}/{:3d}, Elapsed[{}], Remain[{}], score:{:.3f}, aug_score:{:.3f}".format( episode, test_num_episode, elapsed_time_str, remain_time_str, score, aug_score)) all_done = (episode == test_num_episode) if all_done: self.logger.info(" *** Test Done *** ") self.logger.info(" NO-AUG SCORE: {:.4f} ".format(score_AM.avg)) self.logger.info(" AUGMENTATION SCORE: {:.4f} ".format(aug_score_AM.avg))三、class CVRPTester:_test_one_batch(self, batch_size) 函数解析
执行流程图链接
函数代码 def _test_one_batch(self, batch_size): # Augmentation ############################################### if self.tester_params['augmentation_enable']: aug_factor = self.tester_params['aug_factor'] else: aug_factor = 1 # Ready ############################################### self.model.eval() with torch.no_grad(): self.env.load_problems(batch_size, aug_factor) reset_state, _, _ = self.env.reset() self.model.pre_forward(reset_state) # POMO Rollout ############################################### state, reward, done = self.env.pre_step() while not done: selected, _ = self.model(state) # shape: (batch, pomo) state, reward, done = self.env.step(selected) # Return ############################################### aug_reward = reward.reshape(aug_factor, batch_size, self.env.pomo_size) # shape: (augmentation, batch, pomo) max_pomo_reward, _ = aug_reward.max(dim=2) # get best results from pomo # shape: (augmentation, batch) no_aug_score = -max_pomo_reward[0, :].float().mean() # negative sign to make positive value max_aug_pomo_reward, _ = max_pomo_reward.max(dim=0) # get best results from augmentation # shape: (batch,) aug_score = -max_aug_pomo_reward.float().mean() # negative sign to make positive value return no_aug_score.item(), aug_score.item()附录 代码(全) import torch import os from logging import getLogger from CVRPEnv import CVRPEnv as Env from CVRPModel import CVRPModel as Model from utils.utils import * class CVRPTester: def __init__(self, env_params, model_params, tester_params): # save arguments self.env_params = env_params self.model_params = model_params self.tester_params = tester_params # result folder, logger self.logger = getLogger(name='trainer') self.result_folder = get_result_folder() # cuda USE_CUDA = self.tester_params['use_cuda'] if USE_CUDA: cuda_device_num = self.tester_params['cuda_device_num'] torch.cuda.set_device(cuda_device_num) device = torch.device('cuda', cuda_device_num) torch.set_default_tensor_type('torch.cuda.FloatTensor') else: device = 
torch.device('cpu') torch.set_default_tensor_type('torch.FloatTensor') self.device = device # ENV and MODEL self.env = Env(**self.env_params) self.model = Model(**self.model_params) # Restore model_load = tester_params['model_load'] checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load) checkpoint = torch.load(checkpoint_fullname, map_location=device) self.model.load_state_dict(checkpoint['model_state_dict']) # utility self.time_estimator = TimeEstimator() def run(self): self.time_estimator.reset() score_AM = AverageMeter() aug_score_AM = AverageMeter() if self.tester_params['test_data_load']['enable']: self.env.use_saved_problems(self.tester_params['test_data_load']['filename'], self.device) test_num_episode = self.tester_params['test_episodes'] episode = 0 while episode < test_num_episode: remaining = test_num_episode - episode batch_size = min(self.tester_params['test_batch_size'], remaining) score, aug_score = self._test_one_batch(batch_size) score_AM.update(score, batch_size) aug_score_AM.update(aug_score, batch_size) episode += batch_size ############################ # Logs ############################ elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(episode, test_num_episode) self.logger.info("episode {:3d}/{:3d}, Elapsed[{}], Remain[{}], score:{:.3f}, aug_score:{:.3f}".format( episode, test_num_episode, elapsed_time_str, remain_time_str, score, aug_score)) all_done = (episode == test_num_episode) if all_done: self.logger.info(" *** Test Done *** ") self.logger.info(" NO-AUG SCORE: {:.4f} ".format(score_AM.avg)) self.logger.info(" AUGMENTATION SCORE: {:.4f} ".format(aug_score_AM.avg)) def _test_one_batch(self, batch_size): # Augmentation ############################################### if self.tester_params['augmentation_enable']: aug_factor = self.tester_params['aug_factor'] else: aug_factor = 1 # Ready ############################################### self.model.eval() with torch.no_grad(): 
self.env.load_problems(batch_size, aug_factor) reset_state, _, _ = self.env.reset() self.model.pre_forward(reset_state) # POMO Rollout ############################################### state, reward, done = self.env.pre_step() while not done: selected, _ = self.model(state) # shape: (batch, pomo) state, reward, done = self.env.step(selected) # Return ############################################### aug_reward = reward.reshape(aug_factor, batch_size, self.env.pomo_size) # shape: (augmentation, batch, pomo) max_pomo_reward, _ = aug_reward.max(dim=2) # get best results from pomo # shape: (augmentation, batch) no_aug_score = -max_pomo_reward[0, :].float().mean() # negative sign to make positive value max_aug_pomo_reward, _ = max_pomo_reward.max(dim=0) # get best results from augmentation # shape: (batch,) aug_score = -max_aug_pomo_reward.float().mean() # negative sign to make positive value return no_aug_score.item(), aug_score.item()
20250303-代码笔记-classCVRPTester由讯客互联互联网栏目发布,感谢您对讯客互联的认可,以及对我们原创作品以及文章的青睐,非常欢迎各位朋友分享到个人网站或者朋友圈,但转载请说明文章出处“20250303-代码笔记-classCVRPTester”
上一篇
Ribbon实现原理