[Python Torch] Quick-and-dirty fix for a load_state_dict error (2022. 3. 1. 14:36)
Many of you, when training, save the model every certain number of epochs or iterations,
and then at some point need to load that saved model back...
And that's when, every so often, an error pops up... sadly... ㅜ
I was training something called CcGAN (https://github.com/UBCDingXin/improved_CcGAN)
and tried to load the trained model back, and an error appeared.. haha
import os
import torch

gener_path = os.path.join(check_path, name)  # check_path and name are defined earlier in my script
net_generator = CcGAN_SAGAN_Generator(dim_z=256, dim_embed=64)
check_ge = torch.load(gener_path)
net_generator.load_state_dict(check_ge['netG_state_dict'])
Here, CcGAN_SAGAN_Generator is defined as in the code below.
I'm not posting it for a code review; it's just there in case anyone is curious what CcGAN_SAGAN_Generator actually means.
Anyway, the net_generator I defined above is the model instance that will receive the loaded weights.
class CcGAN_SAGAN_Generator(nn.Module):
    """Generator."""
    # snlinear, snconv2d, GenBlock, Self_Attn, and init_weights are all
    # defined elsewhere in the CcGAN repo
    def __init__(self, dim_z, dim_embed=128, nc=3, gene_ch=64):
        super(CcGAN_SAGAN_Generator, self).__init__()
        self.dim_z = dim_z
        self.gene_ch = gene_ch
        self.snlinear0 = snlinear(in_features=dim_z, out_features=gene_ch*16*4*4)
        self.block1 = GenBlock(gene_ch*16, gene_ch*16, dim_embed)
        self.block2 = GenBlock(gene_ch*16, gene_ch*8, dim_embed)
        self.block3 = GenBlock(gene_ch*8, gene_ch*4, dim_embed)
        self.block4 = GenBlock(gene_ch*4, gene_ch*2, dim_embed)
        self.self_attn = Self_Attn(gene_ch*2)
        self.block5 = GenBlock(gene_ch*2, gene_ch*2, dim_embed)
        self.block6 = GenBlock(gene_ch*2, gene_ch, dim_embed)
        self.bn = nn.BatchNorm2d(gene_ch, eps=1e-5, momentum=0.0001, affine=True)
        self.relu = nn.ReLU(inplace=True)
        self.snconv2d1 = snconv2d(in_channels=gene_ch, out_channels=nc, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()

        # Weight init
        self.apply(init_weights)

    def forward(self, z, labels):
        # n x dim_z
        out = self.snlinear0(z)
        # 4*4
        out = out.view(-1, self.gene_ch*16, 4, 4)
        # 4 x 4
        out = self.block1(out, labels)
        # 8 x 8
        out = self.block2(out, labels)
        # 16 x 16
        out = self.block3(out, labels)
        # 32 x 32
        out = self.block4(out, labels)
        # 64 x 64
        out = self.self_attn(out)
        # 64 x 64
        out = self.block5(out, labels)
        # 128 x 128
        out = self.block6(out, labels)
        # 256 x 256
        out = self.bn(out)
        out = self.relu(out)
        out = self.snconv2d1(out)
        out = self.tanh(out)
        return out
Then I used torch.load to load the checkpoint saved at the given path into check_ge.
But first, before that:
what model.load_state_dict(trained_checkpoint['state_dict']) means is
"put the weights (and so on) that I trained before back into the model I just defined."
In other words, it recreates the exact state the model was in when I trained it.
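Just so the round trip is concrete, here is a minimal sketch of that save/load pattern (the toy_model and 'ckpt.pth' names are made up for illustration, nothing to do with the CcGAN code):

import torch
import torch.nn as nn

# A tiny stand-in model, just for illustration
toy_model = nn.Linear(4, 2)

# Saving: state_dict() is an ordered dict mapping parameter/buffer names to tensors
torch.save({'netG_state_dict': toy_model.state_dict()}, 'ckpt.pth')

# Loading: build a fresh model with the same architecture, then restore the tensors
restored = nn.Linear(4, 2)
ckpt = torch.load('ckpt.pth')
restored.load_state_dict(ckpt['netG_state_dict'])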
But when I ran net_generator.load_state_dict(check_ge['netG_state_dict']), the error below appeared.
Interpreting what the error means:
the parameter names in the model I defined as net_generator and the parameter names in the checkpoint I loaded don't match.
Look closely below: the keys the model (net_generator) wants are 'snlinear0.weight_orig' and so on,
while the keys that my trained checkpoint (check_ge) holds are 'module.snlinear0.weight_orig' and so on.
RuntimeError: Error(s) in loading state_dict for CcGAN_SAGAN_Generator: Missing key(s) in state_dict: "snlinear0.weight_orig", "snlinear0.weight", "snlinear0.weight_u", "snlinear0.bias", "snlinear0.weight_orig", "snlinear0.weight_u", "snlinear0.weight_v", "block1.cond_bn1.bn.running_mean", "block1.cond_bn1.bn.running_var", "block1.cond_bn1.embed_gamma.weight", "block1.cond_bn1.embed_beta.weight", "block1.snconv2d1.weight_orig", "block1.snconv2d1.weight", "block1.snconv2d1.weight_u", "block1.snconv2d1.bias", "block1.snconv2d1.weight_orig", "block1.snconv2d1.weight_u", "block1.snconv2d1.weight_v", "block1.cond_bn2.bn.running_mean", "block1.cond_bn2.bn.running_var", "block1.cond_bn2.embed_gamma.weight", "block1.cond_bn2.embed_beta.weight", "block1.snconv2d2.weight_orig", "block1.snconv2d2.weight", "block1.snconv2d2.weight_u", "block1.snconv2d2.bias", "block1.snconv2d2.weight_orig", "block1.snconv2d2.weight_u", "block1.snconv2d2.weight_v", "block1.snconv2d0.weight_orig", "block1.snconv2d0.weight", "block1.snconv2d0.weight_u", "block1.snconv2d0.bias", "block1.snconv2d0.weight_orig", "block1.snconv2d0.weight_u", "block1.snconv2d0.weight_v", "block2.cond_bn1.bn.running_mean", "block2.cond_bn1.bn.running_var", "block2.cond_bn1.embed_gamma.weight", "block2.cond_bn1.embed_beta.weight", "block2.snconv2d1.weight_orig", "block2.snconv2d1.weight", "block2.snconv2d1.weight_u", "block2.snconv2d1.bias", "block2.snconv2d1.weight_orig", "block2.snconv2d1.weight_u", "block2.snconv2d1.weight_v", "block2.cond_bn2.bn.running_mean", "block2.cond_bn2.bn.running_var", "block2.cond_bn2.embed_gamma.weight", "block2.cond_bn2.embed_beta.weight", "block2.snconv2d2.weight_orig", "block2.snconv2d2.weight", "block2.snconv2d2.weight_u", "block2.snconv2d2.bias", "block2.snconv2d2.weight_orig", "block2.snconv2d2.weight_u", "block2.snconv2d2.weight_v", "block2.snconv2d0.weight_orig", "block2.snconv2d0.weight", "block2.snconv2d0.weight_u", "block2.snconv2d0.bias", "block2.snconv2d0.weight_orig", "block2.snconv2d0.weight_u", "block2.snconv2d0.weight_v", "block3.cond_bn1.bn.running_mean", "block3.cond_bn1.bn.running_var", "block3.cond_bn1.embed_gamma.weight", "block3.cond_bn1.embed_beta.weight", "block3.snconv2d1.weight_orig", "block3.snconv2d1.weight", "block3.snconv2d1.weight_u", "block3.snconv2d1.bias", "block3.snconv2d1.weight_orig", "block3.snconv2d1.weight_u", "block3.snconv2d1.weight_v", "block3.cond_bn2.bn.running_mean", "block3.cond_bn2.bn.running_var", "block3.cond_bn2.embed_gamma.weight", "block3.cond_bn2.embed_beta.weight", "block3.snconv2d2.weight_orig", "block3.snconv2d2.weight", "block3.snconv2d2.weight_u", "block3.snconv2d2.bias", "block3.snconv2d2.weight_orig", "block3.snconv2d2.weight_u", "block3.snconv2d2.weight_v", "block3.snconv2d0.weight_orig", "block3.snconv2d0.weight", "block3.snconv2d0.weight_u", "block3.snconv2d0.bias", "block3.snconv2d0.weight_orig", "block3.snconv2d0.weight_u", "block3.snconv2d0.weight_v", "block4.cond_bn1.bn.running_mean", "block4.cond_bn1.bn.running_var", "block4.cond_bn1.embed_gamma.weight", "block4.cond_bn1.embed_beta.weight", "block4.snconv2d1.weight_orig", "block4.snconv2d1.weight", "block4.snconv2d1.weight_u", "block4.snconv2d1.bias", "block4.snconv2d1.weight_orig", "block4.snconv2d1.weight_u", "block4.snconv2d1.weight_v", "block4.cond_bn2.bn.running_mean", "block4.cond_bn2.bn.running_var", "block4.cond_bn2.embed_gamma.weight", "block4.cond_bn2.embed_beta.weight", "block4.snconv2d2.weight_orig", "block4.snconv2d2.weight", "block4.snconv2d2.weight_u", "block4.snconv2d2.bias", 
"block4.snconv2d2.weight_orig", "block4.snconv2d2.weight_u", "block4.snconv2d2.weight_v", "block4.snconv2d0.weight_orig", "block4.snconv2d0.weight", "block4.snconv2d0.weight_u", "block4.snconv2d0.bias", "block4.snconv2d0.weight_orig", "block4.snconv2d0.weight_u", "block4.snconv2d0.weight_v", "self_attn.sigma", "self_attn.snconv1x1_theta.weight_orig", "self_attn.snconv1x1_theta.weight", "self_attn.snconv1x1_theta.weight_u", "self_attn.snconv1x1_theta.bias", "self_attn.snconv1x1_theta.weight_orig", "self_attn.snconv1x1_theta.weight_u", "self_attn.snconv1x1_theta.weight_v", "self_attn.snconv1x1_phi.weight_orig", "self_attn.snconv1x1_phi.weight", "self_attn.snconv1x1_phi.weight_u", "self_attn.snconv1x1_phi.bias", "self_attn.snconv1x1_phi.weight_orig", "self_attn.snconv1x1_phi.weight_u", "self_attn.snconv1x1_phi.weight_v", "self_attn.snconv1x1_g.weight_orig", "self_attn.snconv1x1_g.weight", "self_attn.snconv1x1_g.weight_u", "self_attn.snconv1x1_g.bias", "self_attn.snconv1x1_g.weight_orig", "self_attn.snconv1x1_g.weight_u", "self_attn.snconv1x1_g.weight_v", "self_attn.snconv1x1_attn.weight_orig", "self_attn.snconv1x1_attn.weight", "self_attn.snconv1x1_attn.weight_u", "self_attn.snconv1x1_attn.bias", "self_attn.snconv1x1_attn.weight_orig", "self_attn.snconv1x1_attn.weight_u", "self_attn.snconv1x1_attn.weight_v", "block5.cond_bn1.bn.running_mean", "block5.cond_bn1.bn.running_var", "block5.cond_bn1.embed_gamma.weight", "block5.cond_bn1.embed_beta.weight", "block5.snconv2d1.weight_orig", "block5.snconv2d1.weight", "block5.snconv2d1.weight_u", "block5.snconv2d1.bias", "block5.snconv2d1.weight_orig", "block5.snconv2d1.weight_u", "block5.snconv2d1.weight_v", "block5.cond_bn2.bn.running_mean", "block5.cond_bn2.bn.running_var", "block5.cond_bn2.embed_gamma.weight", "block5.cond_bn2.embed_beta.weight", "block5.snconv2d2.weight_orig", "block5.snconv2d2.weight", "block5.snconv2d2.weight_u", "block5.snconv2d2.bias", "block5.snconv2d2.weight_orig", "block5.snconv2d2.weight_u", "block5.snconv2d2.weight_v", "block5.snconv2d0.weight_orig", "block5.snconv2d0.weight", "block5.snconv2d0.weight_u", "block5.snconv2d0.bias", "block5.snconv2d0.weight_orig", "block5.snconv2d0.weight_u", "block5.snconv2d0.weight_v", "block6.cond_bn1.bn.running_mean", "block6.cond_bn1.bn.running_var", "block6.cond_bn1.embed_gamma.weight", "block6.cond_bn1.embed_beta.weight", "block6.snconv2d1.weight_orig", "block6.snconv2d1.weight", "block6.snconv2d1.weight_u", "block6.snconv2d1.bias", "block6.snconv2d1.weight_orig", "block6.snconv2d1.weight_u", "block6.snconv2d1.weight_v", "block6.cond_bn2.bn.running_mean", "block6.cond_bn2.bn.running_var", "block6.cond_bn2.embed_gamma.weight", "block6.cond_bn2.embed_beta.weight", "block6.snconv2d2.weight_orig", "block6.snconv2d2.weight", "block6.snconv2d2.weight_u", "block6.snconv2d2.bias", "block6.snconv2d2.weight_orig", "block6.snconv2d2.weight_u", "block6.snconv2d2.weight_v", "block6.snconv2d0.weight_orig", "block6.snconv2d0.weight", "block6.snconv2d0.weight_u", "block6.snconv2d0.bias", "block6.snconv2d0.weight_orig", "block6.snconv2d0.weight_u", "block6.snconv2d0.weight_v", "bn.weight", "bn.bias", "bn.running_mean", "bn.running_var", "snconv2d1.weight_orig", "snconv2d1.weight", "snconv2d1.weight_u", "snconv2d1.bias", "snconv2d1.weight_orig", "snconv2d1.weight_u", "snconv2d1.weight_v". 
Unexpected key(s) in state_dict: "module.snlinear0.bias", "module.snlinear0.weight_orig", "module.snlinear0.weight_u", "module.snlinear0.weight_v", "module.block1.cond_bn1.bn.running_mean", "module.block1.cond_bn1.bn.running_var", "module.block1.cond_bn1.bn.num_batches_tracked", "module.block1.cond_bn1.embed_gamma.weight", "module.block1.cond_bn1.embed_beta.weight", "module.block1.snconv2d1.bias", "module.block1.snconv2d1.weight_orig", "module.block1.snconv2d1.weight_u", "module.block1.snconv2d1.weight_v", "module.block1.cond_bn2.bn.running_mean", "module.block1.cond_bn2.bn.running_var", "module.block1.cond_bn2.bn.num_batches_tracked", "module.block1.cond_bn2.embed_gamma.weight", "module.block1.cond_bn2.embed_beta.weight", "module.block1.snconv2d2.bias", "module.block1.snconv2d2.weight_orig", "module.block1.snconv2d2.weight_u", "module.block1.snconv2d2.weight_v", "module.block1.snconv2d0.bias", "module.block1.snconv2d0.weight_orig", "module.block1.snconv2d0.weight_u", "module.block1.snconv2d0.weight_v", "module.block2.cond_bn1.bn.running_mean", "module.block2.cond_bn1.bn.running_var", "module.block2.cond_bn1.bn.num_batches_tracked", "module.block2.cond_bn1.embed_gamma.weight", "module.block2.cond_bn1.embed_beta.weight", "module.block2.snconv2d1.bias", "module.block2.snconv2d1.weight_orig", "module.block2.snconv2d1.weight_u", "module.block2.snconv2d1.weight_v", "module.block2.cond_bn2.bn.running_mean", "module.block2.cond_bn2.bn.running_var", "module.block2.cond_bn2.bn.num_batches_tracked", "module.block2.cond_bn2.embed_gamma.weight", "module.block2.cond_bn2.embed_beta.weight", "module.block2.snconv2d2.bias", "module.block2.snconv2d2.weight_orig", "module.block2.snconv2d2.weight_u", "module.block2.snconv2d2.weight_v", "module.block2.snconv2d0.bias", "module.block2.snconv2d0.weight_orig", "module.block2.snconv2d0.weight_u", "module.block2.snconv2d0.weight_v", "module.block3.cond_bn1.bn.running_mean", "module.block3.cond_bn1.bn.running_var", "module.block3.cond_bn1.bn.num_batches_tracked", "module.block3.cond_bn1.embed_gamma.weight", "module.block3.cond_bn1.embed_beta.weight", "module.block3.snconv2d1.bias", "module.block3.snconv2d1.weight_orig", "module.block3.snconv2d1.weight_u", "module.block3.snconv2d1.weight_v", "module.block3.cond_bn2.bn.running_mean", "module.block3.cond_bn2.bn.running_var", "module.block3.cond_bn2.bn.num_batches_tracked", "module.block3.cond_bn2.embed_gamma.weight", "module.block3.cond_bn2.embed_beta.weight", "module.block3.snconv2d2.bias", "module.block3.snconv2d2.weight_orig", "module.block3.snconv2d2.weight_u", "module.block3.snconv2d2.weight_v", "module.block3.snconv2d0.bias", "module.block3.snconv2d0.weight_orig", "module.block3.snconv2d0.weight_u", "module.block3.snconv2d0.weight_v", "module.block4.cond_bn1.bn.running_mean", "module.block4.cond_bn1.bn.running_var", "module.block4.cond_bn1.bn.num_batches_tracked", "module.block4.cond_bn1.embed_gamma.weight", "module.block4.cond_bn1.embed_beta.weight", "module.block4.snconv2d1.bias", "module.block4.snconv2d1.weight_orig", "module.block4.snconv2d1.weight_u", "module.block4.snconv2d1.weight_v", "module.block4.cond_bn2.bn.running_mean", "module.block4.cond_bn2.bn.running_var", "module.block4.cond_bn2.bn.num_batches_tracked", "module.block4.cond_bn2.embed_gamma.weight", "module.block4.cond_bn2.embed_beta.weight", "module.block4.snconv2d2.bias", "module.block4.snconv2d2.weight_orig", "module.block4.snconv2d2.weight_u", "module.block4.snconv2d2.weight_v", "module.block4.snconv2d0.bias", 
"module.block4.snconv2d0.weight_orig", "module.block4.snconv2d0.weight_u", "module.block4.snconv2d0.weight_v", "module.self_attn.sigma", "module.self_attn.snconv1x1_theta.bias", "module.self_attn.snconv1x1_theta.weight_orig", "module.self_attn.snconv1x1_theta.weight_u", "module.self_attn.snconv1x1_theta.weight_v", "module.self_attn.snconv1x1_phi.bias", "module.self_attn.snconv1x1_phi.weight_orig", "module.self_attn.snconv1x1_phi.weight_u", "module.self_attn.snconv1x1_phi.weight_v", "module.self_attn.snconv1x1_g.bias", "module.self_attn.snconv1x1_g.weight_orig", "module.self_attn.snconv1x1_g.weight_u", "module.self_attn.snconv1x1_g.weight_v", "module.self_attn.snconv1x1_attn.bias", "module.self_attn.snconv1x1_attn.weight_orig", "module.self_attn.snconv1x1_attn.weight_u", "module.self_attn.snconv1x1_attn.weight_v", "module.block5.cond_bn1.bn.running_mean", "module.block5.cond_bn1.bn.running_var", "module.block5.cond_bn1.bn.num_batches_tracked", "module.block5.cond_bn1.embed_gamma.weight", "module.block5.cond_bn1.embed_beta.weight", "module.block5.snconv2d1.bias", "module.block5.snconv2d1.weight_orig", "module.block5.snconv2d1.weight_u", "module.block5.snconv2d1.weight_v", "module.block5.cond_bn2.bn.running_mean", "module.block5.cond_bn2.bn.running_var", "module.block5.cond_bn2.bn.num_batches_tracked", "module.block5.cond_bn2.embed_gamma.weight", "module.block5.cond_bn2.embed_beta.weight", "module.block5.snconv2d2.bias", "module.block5.snconv2d2.weight_orig", "module.block5.snconv2d2.weight_u", "module.block5.snconv2d2.weight_v", "module.block5.snconv2d0.bias", "module.block5.snconv2d0.weight_orig", "module.block5.snconv2d0.weight_u", "module.block5.snconv2d0.weight_v", "module.block6.cond_bn1.bn.running_mean", "module.block6.cond_bn1.bn.running_var", "module.block6.cond_bn1.bn.num_batches_tracked", "module.block6.cond_bn1.embed_gamma.weight", "module.block6.cond_bn1.embed_beta.weight", "module.block6.snconv2d1.bias", "module.block6.snconv2d1.weight_orig", "module.block6.snconv2d1.weight_u", "module.block6.snconv2d1.weight_v", "module.block6.cond_bn2.bn.running_mean", "module.block6.cond_bn2.bn.running_var", "module.block6.cond_bn2.bn.num_batches_tracked", "module.block6.cond_bn2.embed_gamma.weight", "module.block6.cond_bn2.embed_beta.weight", "module.block6.snconv2d2.bias", "module.block6.snconv2d2.weight_orig", "module.block6.snconv2d2.weight_u", "module.block6.snconv2d2.weight_v", "module.block6.snconv2d0.bias", "module.block6.snconv2d0.weight_orig", "module.block6.snconv2d0.weight_u", "module.block6.snconv2d0.weight_v", "module.bn.weight", "module.bn.bias", "module.bn.running_mean", "module.bn.running_var", "module.bn.num_batches_tracked", "module.snconv2d1.bias", "module.snconv2d1.weight_orig", "module.snconv2d1.weight_u", "module.snconv2d1.weight_v".
So, all we have to do is make the key names match, and the problem is solved, hooray!
Looking closely, a useless 'module.' is stuck onto the front of every key. I didn't know why at the time, but this is what happens when the model was trained wrapped in nn.DataParallel (or DistributedDataParallel): the wrapper keeps the real model under the attribute .module, so every key in its state_dict picks up a 'module.' prefix.
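You can reproduce the prefix with a tiny sketch like this (again a made-up toy model, not the CcGAN one):

import torch.nn as nn

toy = nn.Linear(4, 2)
print(list(toy.state_dict().keys()))
# ['weight', 'bias']

wrapped = nn.DataParallel(toy)  # stores the real model under .module
print(list(wrapped.state_dict().keys()))
# ['module.weight', 'module.bias']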
So I removed it using a for loop and a list, as below.
To explain the code briefly: I store the trained checkpoint's keys in a variable called keys as a list,
then walk through the list one entry at a time, replacing 'module.' with '' to strip it off.
Finally, each entry is re-inserted under its renamed key.
# Collect the checkpoint's key names, then rename them one by one
keys = list(check_ge['netG_state_dict'].keys())
for i in range(len(keys)):
    temp = keys[i].replace('module.', '')  # new key name without the prefix
    # pop the entry under the old name and re-insert it under the new one
    check_ge['netG_state_dict'][temp] = check_ge['netG_state_dict'].pop(keys[i])
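As a side note, recent PyTorch versions (1.9+, if I remember correctly) ship a helper that does the same prefix stripping in place, so something like this should also work:

from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

# Strips the leading 'module.' from every key, modifying the dict in place
consume_prefix_in_state_dict_if_present(check_ge['netG_state_dict'], 'module.')
net_generator.load_state_dict(check_ge['netG_state_dict'])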
And then, the moment you run it!
Do you see it? All keys matched successfully.
You can confirm that every single key, with none missing, slid right into its place.
In [9]: net_generator.load_state_dict(check_ge['netG_state_dict'])
Out[9]: <All keys matched successfully>
When the error first appeared, I too searched high and low trying to fix it,
and I also saw people fix it with

model.load_state_dict(checkpoint, strict=False)

But you need to be a bit careful with this one. It's true that the error no longer appears,
but what it means is "don't load the keys that don't match; load only the keys that do."
In other words, it loads the weights only partially, so the whole point of training the model
and loading it back can quietly fall apart; please use it with caution.
In particular, if it were just one or two weight layers, you could maybe keep it a secret from everyone and shush it away,
but in my case every single key had 'module.' in it,
so if I had just used strict=False, nothing at all would have been loaded ㅜ
So please use it as your situation calls for. Bye bye!
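If you do reach for strict=False, at least check what it skipped: load_state_dict returns a small result object listing the mismatches, so you can verify nothing important was dropped. A minimal sketch:

# load_state_dict returns an object with missing_keys and unexpected_keys
result = net_generator.load_state_dict(check_ge['netG_state_dict'], strict=False)
print(result.missing_keys)     # keys the model wanted but the checkpoint lacked
print(result.unexpected_keys)  # keys the checkpoint had but the model did not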