본문 바로가기
  • Staying curious, growing through questions
Study Log

[pytorch] 모델 파라미터 출력하기

by Evergreen Mind 2023. 10. 20.

Inversion 실습을 하던 중에, 미션을 받았다.

분류기 파라미터를 받아서, 해당 분류기로 INVERSION을 시킨후에 training dataset을 찾아내는 미션이었다.

 

그런데 분류기 파라미터를 받아서, 파라미터들을 확인하려고 다음처럼 state_dict()를 사용했더니 오류가 나왔다. 

 for params in checkpoint.state_dict():
        print(params)

에러 문구

AttributeError: collections.OrderedDict' object has no attribute 'state_dict'

 

 


state_dict를 사용하지 않고 다음 코드처럼 모델을 바로  출력하면, layer 이름 뿐 아니라 파라미터까지 텐서값으로 같이 출력이 되어 layer들을 잘 확인할 수가 없다. 

checkpoint = torch.load(mypath)
	print(checkpoint)

'fc.bias' 뿐 아니라 tensor 값으로도 출력되는 모습


그럼 모델이 어떤 레이어로 구성되어있는지 확인하기 위해서 웨이트 값 없이 'layer이름'만 출력하려면 어떻게 해야할까?

checkpoint = checkpoint.keys()
    print(checkpoint)

'checkpoint'가 dictionary이므로 key값만 출력하면 된다!

 

그런데 checkpoint.state_dict()가 왜 안되는지는 아직 잘 모르겠다.