  • [Python Torch] A Quick-and-Dirty Fix for a load_state_dict Error
    Quick-and-Dirty Error Fixes 2022. 3. 1. 14:36

    Many people save their model every few epochs or iterations while training,

    and eventually the time comes to load that saved model back in...

    And that's when, sadly, an error often pops up...

     

    I was training CcGAN (https://github.com/UBCDingXin/improved_CcGAN)

    and when I tried to load the trained model back, an error appeared... haha

     

    import os
    import torch

    gener_path = os.path.join(check_path, name)                     # path to the saved checkpoint file
    net_generator = CcGAN_SAGAN_Generator(dim_z=256, dim_embed=64)  # freshly constructed generator
    check_ge = torch.load(gener_path)                               # checkpoint dict saved during training
    net_generator.load_state_dict(check_ge['netG_state_dict'])      # restore the trained weights
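    One side note on the torch.load call (my addition, not something from the original setup): if the checkpoint was saved on a GPU machine and you are loading it on a CPU-only one, you may also hit a device error, which map_location resolves:

    # only needed when loading a GPU-saved checkpoint on CPU (hypothetical scenario)
    check_ge = torch.load(gener_path, map_location=torch.device('cpu'))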

     

    Here, CcGAN_SAGAN_Generator is the class shown below.

    I'm not posting it for a code review; it's just there in case anyone wonders what CcGAN_SAGAN_Generator actually is.

    In any case, the net_generator defined above is the model we are loading the weights into.

    class CcGAN_SAGAN_Generator(nn.Module):
        """Generator."""
    
        def __init__(self, dim_z, dim_embed=128, nc=3, gene_ch=64):
            super(CcGAN_SAGAN_Generator, self).__init__()
    
            self.dim_z = dim_z
            self.gene_ch = gene_ch
    
            self.snlinear0 = snlinear(in_features=dim_z, out_features=gene_ch*16*4*4)
            self.block1 = GenBlock(gene_ch*16, gene_ch*16, dim_embed)
            self.block2 = GenBlock(gene_ch*16, gene_ch*8, dim_embed)
            self.block3 = GenBlock(gene_ch*8, gene_ch*4, dim_embed)
            self.block4 = GenBlock(gene_ch*4, gene_ch*2, dim_embed)
            self.self_attn = Self_Attn(gene_ch*2)
            self.block5 = GenBlock(gene_ch*2, gene_ch*2, dim_embed)
            self.block6 = GenBlock(gene_ch*2, gene_ch, dim_embed)
            self.bn = nn.BatchNorm2d(gene_ch, eps=1e-5, momentum=0.0001, affine=True)
            self.relu = nn.ReLU(inplace=True)
            self.snconv2d1 = snconv2d(in_channels=gene_ch, out_channels=nc, kernel_size=3, stride=1, padding=1)
            self.tanh = nn.Tanh()
    
            # Weight init
            self.apply(init_weights)
    
        def forward(self, z, labels):
            # n x dim_z
            out = self.snlinear0(z)                   # n x (gene_ch*16*4*4)
            out = out.view(-1, self.gene_ch*16, 4, 4) # 4 x 4
            out = self.block1(out, labels)    # 8 x 8
            out = self.block2(out, labels)    # 16 x 16
            out = self.block3(out, labels)    # 32 x 32
            out = self.block4(out, labels)    # 64 x 64
            out = self.self_attn(out)         # 64 x 64
            out = self.block5(out, labels)    # 128 x 128
            out = self.block6(out, labels)    # 256 x 256
            out = self.bn(out)
            out = self.relu(out)
            out = self.snconv2d1(out)
            out = self.tanh(out)
            return out
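    Judging from the shape comments in forward, using it looks roughly like this (a sketch that assumes the repo's helper modules such as snlinear, GenBlock, Self_Attn, and snconv2d are importable):

    import torch

    z = torch.randn(2, 256)          # n x dim_z latent vectors
    labels = torch.randn(2, 64)      # n x dim_embed regression-label embeddings
    fake = net_generator(z, labels)  # 2 x 3 x 256 x 256 images, per the comments above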

    Then torch.load reads the checkpoint at the given path and stores it in check_ge.

     

    But first, what

    model.load_state_dict(trained_checkpoint['state_dict']) means is:

    "put the weights (and buffers) I trained earlier back into the model I just defined."

    In other words, it restores the model to the state it had at training time.
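    As a minimal sketch of that save-and-restore round trip (the file name ckpt.pth and the tiny model here are hypothetical):

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 2)                                          # stand-in for a real network

    # during training: save a checkpoint dict, like 'netG_state_dict' above
    torch.save({'netG_state_dict': model.state_dict()}, 'ckpt.pth')

    # later: rebuild the same architecture and restore the trained weights
    restored = nn.Linear(8, 2)
    ckpt = torch.load('ckpt.pth')
    restored.load_state_dict(ckpt['netG_state_dict'])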

     

    But when I ran net_generator.load_state_dict(check_ge['netG_state_dict']), the error below appeared.

     

    Interpreting the error:

    the parameter names in the model I defined (net_generator) differ from the parameter names in the checkpoint I loaded.

     

    Looking closely below, the keys the model (net_generator) expects are 'snlinear0.weight_orig' and so on,

    while the keys in the trained checkpoint (check_ge) are 'module.snlinear0.weight_orig' and so on.

    RuntimeError: Error(s) in loading state_dict for CcGAN_SAGAN_Generator:
    	Missing key(s) in state_dict: "snlinear0.weight_orig", "snlinear0.weight", "snlinear0.weight_u", "snlinear0.bias", "snlinear0.weight_orig", "snlinear0.weight_u", "snlinear0.weight_v", "block1.cond_bn1.bn.running_mean", "block1.cond_bn1.bn.running_var", "block1.cond_bn1.embed_gamma.weight", "block1.cond_bn1.embed_beta.weight", "block1.snconv2d1.weight_orig", "block1.snconv2d1.weight", "block1.snconv2d1.weight_u", "block1.snconv2d1.bias", "block1.snconv2d1.weight_orig", "block1.snconv2d1.weight_u", "block1.snconv2d1.weight_v", "block1.cond_bn2.bn.running_mean", "block1.cond_bn2.bn.running_var", "block1.cond_bn2.embed_gamma.weight", "block1.cond_bn2.embed_beta.weight", "block1.snconv2d2.weight_orig", "block1.snconv2d2.weight", "block1.snconv2d2.weight_u", "block1.snconv2d2.bias", "block1.snconv2d2.weight_orig", "block1.snconv2d2.weight_u", "block1.snconv2d2.weight_v", "block1.snconv2d0.weight_orig", "block1.snconv2d0.weight", "block1.snconv2d0.weight_u", "block1.snconv2d0.bias", "block1.snconv2d0.weight_orig", "block1.snconv2d0.weight_u", "block1.snconv2d0.weight_v", "block2.cond_bn1.bn.running_mean", "block2.cond_bn1.bn.running_var", "block2.cond_bn1.embed_gamma.weight", "block2.cond_bn1.embed_beta.weight", "block2.snconv2d1.weight_orig", "block2.snconv2d1.weight", "block2.snconv2d1.weight_u", "block2.snconv2d1.bias", "block2.snconv2d1.weight_orig", "block2.snconv2d1.weight_u", "block2.snconv2d1.weight_v", "block2.cond_bn2.bn.running_mean", "block2.cond_bn2.bn.running_var", "block2.cond_bn2.embed_gamma.weight", "block2.cond_bn2.embed_beta.weight", "block2.snconv2d2.weight_orig", "block2.snconv2d2.weight", "block2.snconv2d2.weight_u", "block2.snconv2d2.bias", "block2.snconv2d2.weight_orig", "block2.snconv2d2.weight_u", "block2.snconv2d2.weight_v", "block2.snconv2d0.weight_orig", "block2.snconv2d0.weight", "block2.snconv2d0.weight_u", "block2.snconv2d0.bias", "block2.snconv2d0.weight_orig", "block2.snconv2d0.weight_u", "block2.snconv2d0.weight_v", "block3.cond_bn1.bn.running_mean", "block3.cond_bn1.bn.running_var", "block3.cond_bn1.embed_gamma.weight", "block3.cond_bn1.embed_beta.weight", "block3.snconv2d1.weight_orig", "block3.snconv2d1.weight", "block3.snconv2d1.weight_u", "block3.snconv2d1.bias", "block3.snconv2d1.weight_orig", "block3.snconv2d1.weight_u", "block3.snconv2d1.weight_v", "block3.cond_bn2.bn.running_mean", "block3.cond_bn2.bn.running_var", "block3.cond_bn2.embed_gamma.weight", "block3.cond_bn2.embed_beta.weight", "block3.snconv2d2.weight_orig", "block3.snconv2d2.weight", "block3.snconv2d2.weight_u", "block3.snconv2d2.bias", "block3.snconv2d2.weight_orig", "block3.snconv2d2.weight_u", "block3.snconv2d2.weight_v", "block3.snconv2d0.weight_orig", "block3.snconv2d0.weight", "block3.snconv2d0.weight_u", "block3.snconv2d0.bias", "block3.snconv2d0.weight_orig", "block3.snconv2d0.weight_u", "block3.snconv2d0.weight_v", "block4.cond_bn1.bn.running_mean", "block4.cond_bn1.bn.running_var", "block4.cond_bn1.embed_gamma.weight", "block4.cond_bn1.embed_beta.weight", "block4.snconv2d1.weight_orig", "block4.snconv2d1.weight", "block4.snconv2d1.weight_u", "block4.snconv2d1.bias", "block4.snconv2d1.weight_orig", "block4.snconv2d1.weight_u", "block4.snconv2d1.weight_v", "block4.cond_bn2.bn.running_mean", "block4.cond_bn2.bn.running_var", "block4.cond_bn2.embed_gamma.weight", "block4.cond_bn2.embed_beta.weight", "block4.snconv2d2.weight_orig", "block4.snconv2d2.weight", "block4.snconv2d2.weight_u", "block4.snconv2d2.bias", "block4.snconv2d2.weight_orig", "block4.snconv2d2.weight_u", 
"block4.snconv2d2.weight_v", "block4.snconv2d0.weight_orig", "block4.snconv2d0.weight", "block4.snconv2d0.weight_u", "block4.snconv2d0.bias", "block4.snconv2d0.weight_orig", "block4.snconv2d0.weight_u", "block4.snconv2d0.weight_v", "self_attn.sigma", "self_attn.snconv1x1_theta.weight_orig", "self_attn.snconv1x1_theta.weight", "self_attn.snconv1x1_theta.weight_u", "self_attn.snconv1x1_theta.bias", "self_attn.snconv1x1_theta.weight_orig", "self_attn.snconv1x1_theta.weight_u", "self_attn.snconv1x1_theta.weight_v", "self_attn.snconv1x1_phi.weight_orig", "self_attn.snconv1x1_phi.weight", "self_attn.snconv1x1_phi.weight_u", "self_attn.snconv1x1_phi.bias", "self_attn.snconv1x1_phi.weight_orig", "self_attn.snconv1x1_phi.weight_u", "self_attn.snconv1x1_phi.weight_v", "self_attn.snconv1x1_g.weight_orig", "self_attn.snconv1x1_g.weight", "self_attn.snconv1x1_g.weight_u", "self_attn.snconv1x1_g.bias", "self_attn.snconv1x1_g.weight_orig", "self_attn.snconv1x1_g.weight_u", "self_attn.snconv1x1_g.weight_v", "self_attn.snconv1x1_attn.weight_orig", "self_attn.snconv1x1_attn.weight", "self_attn.snconv1x1_attn.weight_u", "self_attn.snconv1x1_attn.bias", "self_attn.snconv1x1_attn.weight_orig", "self_attn.snconv1x1_attn.weight_u", "self_attn.snconv1x1_attn.weight_v", "block5.cond_bn1.bn.running_mean", "block5.cond_bn1.bn.running_var", "block5.cond_bn1.embed_gamma.weight", "block5.cond_bn1.embed_beta.weight", "block5.snconv2d1.weight_orig", "block5.snconv2d1.weight", "block5.snconv2d1.weight_u", "block5.snconv2d1.bias", "block5.snconv2d1.weight_orig", "block5.snconv2d1.weight_u", "block5.snconv2d1.weight_v", "block5.cond_bn2.bn.running_mean", "block5.cond_bn2.bn.running_var", "block5.cond_bn2.embed_gamma.weight", "block5.cond_bn2.embed_beta.weight", "block5.snconv2d2.weight_orig", "block5.snconv2d2.weight", "block5.snconv2d2.weight_u", "block5.snconv2d2.bias", "block5.snconv2d2.weight_orig", "block5.snconv2d2.weight_u", "block5.snconv2d2.weight_v", "block5.snconv2d0.weight_orig", "block5.snconv2d0.weight", "block5.snconv2d0.weight_u", "block5.snconv2d0.bias", "block5.snconv2d0.weight_orig", "block5.snconv2d0.weight_u", "block5.snconv2d0.weight_v", "block6.cond_bn1.bn.running_mean", "block6.cond_bn1.bn.running_var", "block6.cond_bn1.embed_gamma.weight", "block6.cond_bn1.embed_beta.weight", "block6.snconv2d1.weight_orig", "block6.snconv2d1.weight", "block6.snconv2d1.weight_u", "block6.snconv2d1.bias", "block6.snconv2d1.weight_orig", "block6.snconv2d1.weight_u", "block6.snconv2d1.weight_v", "block6.cond_bn2.bn.running_mean", "block6.cond_bn2.bn.running_var", "block6.cond_bn2.embed_gamma.weight", "block6.cond_bn2.embed_beta.weight", "block6.snconv2d2.weight_orig", "block6.snconv2d2.weight", "block6.snconv2d2.weight_u", "block6.snconv2d2.bias", "block6.snconv2d2.weight_orig", "block6.snconv2d2.weight_u", "block6.snconv2d2.weight_v", "block6.snconv2d0.weight_orig", "block6.snconv2d0.weight", "block6.snconv2d0.weight_u", "block6.snconv2d0.bias", "block6.snconv2d0.weight_orig", "block6.snconv2d0.weight_u", "block6.snconv2d0.weight_v", "bn.weight", "bn.bias", "bn.running_mean", "bn.running_var", "snconv2d1.weight_orig", "snconv2d1.weight", "snconv2d1.weight_u", "snconv2d1.bias", "snconv2d1.weight_orig", "snconv2d1.weight_u", "snconv2d1.weight_v". 
    	Unexpected key(s) in state_dict: "module.snlinear0.bias", "module.snlinear0.weight_orig", "module.snlinear0.weight_u", "module.snlinear0.weight_v", "module.block1.cond_bn1.bn.running_mean", "module.block1.cond_bn1.bn.running_var", "module.block1.cond_bn1.bn.num_batches_tracked", "module.block1.cond_bn1.embed_gamma.weight", "module.block1.cond_bn1.embed_beta.weight", "module.block1.snconv2d1.bias", "module.block1.snconv2d1.weight_orig", "module.block1.snconv2d1.weight_u", "module.block1.snconv2d1.weight_v", "module.block1.cond_bn2.bn.running_mean", "module.block1.cond_bn2.bn.running_var", "module.block1.cond_bn2.bn.num_batches_tracked", "module.block1.cond_bn2.embed_gamma.weight", "module.block1.cond_bn2.embed_beta.weight", "module.block1.snconv2d2.bias", "module.block1.snconv2d2.weight_orig", "module.block1.snconv2d2.weight_u", "module.block1.snconv2d2.weight_v", "module.block1.snconv2d0.bias", "module.block1.snconv2d0.weight_orig", "module.block1.snconv2d0.weight_u", "module.block1.snconv2d0.weight_v", "module.block2.cond_bn1.bn.running_mean", "module.block2.cond_bn1.bn.running_var", "module.block2.cond_bn1.bn.num_batches_tracked", "module.block2.cond_bn1.embed_gamma.weight", "module.block2.cond_bn1.embed_beta.weight", "module.block2.snconv2d1.bias", "module.block2.snconv2d1.weight_orig", "module.block2.snconv2d1.weight_u", "module.block2.snconv2d1.weight_v", "module.block2.cond_bn2.bn.running_mean", "module.block2.cond_bn2.bn.running_var", "module.block2.cond_bn2.bn.num_batches_tracked", "module.block2.cond_bn2.embed_gamma.weight", "module.block2.cond_bn2.embed_beta.weight", "module.block2.snconv2d2.bias", "module.block2.snconv2d2.weight_orig", "module.block2.snconv2d2.weight_u", "module.block2.snconv2d2.weight_v", "module.block2.snconv2d0.bias", "module.block2.snconv2d0.weight_orig", "module.block2.snconv2d0.weight_u", "module.block2.snconv2d0.weight_v", "module.block3.cond_bn1.bn.running_mean", "module.block3.cond_bn1.bn.running_var", "module.block3.cond_bn1.bn.num_batches_tracked", "module.block3.cond_bn1.embed_gamma.weight", "module.block3.cond_bn1.embed_beta.weight", "module.block3.snconv2d1.bias", "module.block3.snconv2d1.weight_orig", "module.block3.snconv2d1.weight_u", "module.block3.snconv2d1.weight_v", "module.block3.cond_bn2.bn.running_mean", "module.block3.cond_bn2.bn.running_var", "module.block3.cond_bn2.bn.num_batches_tracked", "module.block3.cond_bn2.embed_gamma.weight", "module.block3.cond_bn2.embed_beta.weight", "module.block3.snconv2d2.bias", "module.block3.snconv2d2.weight_orig", "module.block3.snconv2d2.weight_u", "module.block3.snconv2d2.weight_v", "module.block3.snconv2d0.bias", "module.block3.snconv2d0.weight_orig", "module.block3.snconv2d0.weight_u", "module.block3.snconv2d0.weight_v", "module.block4.cond_bn1.bn.running_mean", "module.block4.cond_bn1.bn.running_var", "module.block4.cond_bn1.bn.num_batches_tracked", "module.block4.cond_bn1.embed_gamma.weight", "module.block4.cond_bn1.embed_beta.weight", "module.block4.snconv2d1.bias", "module.block4.snconv2d1.weight_orig", "module.block4.snconv2d1.weight_u", "module.block4.snconv2d1.weight_v", "module.block4.cond_bn2.bn.running_mean", "module.block4.cond_bn2.bn.running_var", "module.block4.cond_bn2.bn.num_batches_tracked", "module.block4.cond_bn2.embed_gamma.weight", "module.block4.cond_bn2.embed_beta.weight", "module.block4.snconv2d2.bias", "module.block4.snconv2d2.weight_orig", "module.block4.snconv2d2.weight_u", "module.block4.snconv2d2.weight_v", "module.block4.snconv2d0.bias", 
"module.block4.snconv2d0.weight_orig", "module.block4.snconv2d0.weight_u", "module.block4.snconv2d0.weight_v", "module.self_attn.sigma", "module.self_attn.snconv1x1_theta.bias", "module.self_attn.snconv1x1_theta.weight_orig", "module.self_attn.snconv1x1_theta.weight_u", "module.self_attn.snconv1x1_theta.weight_v", "module.self_attn.snconv1x1_phi.bias", "module.self_attn.snconv1x1_phi.weight_orig", "module.self_attn.snconv1x1_phi.weight_u", "module.self_attn.snconv1x1_phi.weight_v", "module.self_attn.snconv1x1_g.bias", "module.self_attn.snconv1x1_g.weight_orig", "module.self_attn.snconv1x1_g.weight_u", "module.self_attn.snconv1x1_g.weight_v", "module.self_attn.snconv1x1_attn.bias", "module.self_attn.snconv1x1_attn.weight_orig", "module.self_attn.snconv1x1_attn.weight_u", "module.self_attn.snconv1x1_attn.weight_v", "module.block5.cond_bn1.bn.running_mean", "module.block5.cond_bn1.bn.running_var", "module.block5.cond_bn1.bn.num_batches_tracked", "module.block5.cond_bn1.embed_gamma.weight", "module.block5.cond_bn1.embed_beta.weight", "module.block5.snconv2d1.bias", "module.block5.snconv2d1.weight_orig", "module.block5.snconv2d1.weight_u", "module.block5.snconv2d1.weight_v", "module.block5.cond_bn2.bn.running_mean", "module.block5.cond_bn2.bn.running_var", "module.block5.cond_bn2.bn.num_batches_tracked", "module.block5.cond_bn2.embed_gamma.weight", "module.block5.cond_bn2.embed_beta.weight", "module.block5.snconv2d2.bias", "module.block5.snconv2d2.weight_orig", "module.block5.snconv2d2.weight_u", "module.block5.snconv2d2.weight_v", "module.block5.snconv2d0.bias", "module.block5.snconv2d0.weight_orig", "module.block5.snconv2d0.weight_u", "module.block5.snconv2d0.weight_v", "module.block6.cond_bn1.bn.running_mean", "module.block6.cond_bn1.bn.running_var", "module.block6.cond_bn1.bn.num_batches_tracked", "module.block6.cond_bn1.embed_gamma.weight", "module.block6.cond_bn1.embed_beta.weight", "module.block6.snconv2d1.bias", "module.block6.snconv2d1.weight_orig", "module.block6.snconv2d1.weight_u", "module.block6.snconv2d1.weight_v", "module.block6.cond_bn2.bn.running_mean", "module.block6.cond_bn2.bn.running_var", "module.block6.cond_bn2.bn.num_batches_tracked", "module.block6.cond_bn2.embed_gamma.weight", "module.block6.cond_bn2.embed_beta.weight", "module.block6.snconv2d2.bias", "module.block6.snconv2d2.weight_orig", "module.block6.snconv2d2.weight_u", "module.block6.snconv2d2.weight_v", "module.block6.snconv2d0.bias", "module.block6.snconv2d0.weight_orig", "module.block6.snconv2d0.weight_u", "module.block6.snconv2d0.weight_v", "module.bn.weight", "module.bn.bias", "module.bn.running_mean", "module.bn.running_var", "module.bn.num_batches_tracked", "module.snconv2d1.bias", "module.snconv2d1.weight_orig", "module.snconv2d1.weight_u", "module.snconv2d1.weight_v".

     

    So, once the key names are made to match, the problem is solved. Hooray!

    Looking closely, every key carries an extra 'module.' prefix. That prefix is added when a model is wrapped in nn.DataParallel (or DistributedDataParallel) during training, because the wrapper registers the original network as a submodule named module.
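    You can see the prefix appear with a tiny experiment (the Sequential model is just a hypothetical stand-in; the CcGAN training script presumably did the same wrap):

    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 2))
    print(list(model.state_dict().keys()))    # ['0.weight', '0.bias']

    wrapped = nn.DataParallel(model)          # the wrap adds the prefix
    print(list(wrapped.state_dict().keys()))  # ['module.0.weight', 'module.0.bias']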

     

    So I renamed the keys using a for loop and a list, as shown below.

     

    Briefly: the variable keys stores the checkpoint's key names as a list,

    then the loop walks through the list and removes 'module.' by replacing it with ''.

    Finally, each value is popped from its old key and re-inserted under the new name.

    keys = list(check_ge['netG_state_dict'].keys())  # snapshot the key names first
    for i in range(len(keys)):
        temp = keys[i].replace('module.', '')        # key name without the prefix
        # move the value from the old key to the new one
        check_ge['netG_state_dict'][temp] = check_ge['netG_state_dict'].pop(keys[i])
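    Equivalently, a dict comprehension does the renaming in one pass (a sketch; it builds a new dict instead of mutating the checkpoint, and strips only the first occurrence):

    clean_state = {k.replace('module.', '', 1): v
                   for k, v in check_ge['netG_state_dict'].items()}
    net_generator.load_state_dict(clean_state)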

     

    And then, run it again!

    Do you see it? "All keys matched successfully"

    You can confirm that every single key found its way back to its proper place.

    In [9]: net_generator.load_state_dict(check_ge['netG_state_dict'])
    Out[9]: <All keys matched successfully>

     

    While searching all over for a way to fix the error, I also saw people suppress it like this:

    model.load_state_dict(checkpoint, strict=False)

    But this calls for some caution. It does make the error go away,

    but it means "load only the keys that match, and silently skip the ones that don't."

    Since only part of the weights gets loaded, the whole point of restoring

    the model you trained can quietly be defeated, so use it carefully.

     

    If only one or two weight layers were affected, you could maybe keep quiet and get away with it,

    but in my case every single key carried the 'module.' prefix,

    so with strict=False nothing would have been loaded at all.
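    If you do reach for strict=False, check its return value, which reports exactly what was skipped (a short sketch using the names from this post):

    result = net_generator.load_state_dict(check_ge['netG_state_dict'], strict=False)
    print(result.missing_keys)     # model parameters that received no checkpoint value
    print(result.unexpected_keys)  # checkpoint entries that were ignored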

     

    So pick whichever fits your situation. Bye!

