-
프레임워크: PyTorch
import torch
import torchvision
import torchvision.transforms as transforms
데이터셋을 로드하는 코드
# Load the ImageNet training split.
#
# torchvision.datasets.ImageNet selects the subset with `split='train'` /
# `split='val'`; it accepts neither a `train` nor a `download` argument, so the
# original `train=True, download=True` call failed when constructing the dataset.
# ImageNet cannot be downloaded automatically: the official ILSVRC2012 archives
# (training images + devkit) must already be present under `root`.
#
# NOTE(review): Resize(256) + CenterCrop(224) is the usual *evaluation*
# preprocessing; for training/fine-tuning, RandomResizedCrop(224) +
# RandomHorizontalFlip() is the common choice — consider switching.
transform = transforms.Compose([
    transforms.Resize(256),        # resize so the shorter side is 256 px
    transforms.CenterCrop(224),    # take the central 224x224 patch
    transforms.ToTensor(),         # PIL image -> float tensor scaled to [0, 1]
    # Per-channel mean/std expected by torchvision's ImageNet-pretrained models.
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
trainset = torchvision.datasets.ImageNet(root='./data', split='train',
                                         transform=transform)
# Shuffled mini-batches of 64 samples, loaded by 2 worker processes.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=2)
사전 학습된(pretrained) 모델을 정의한다.
# Define the model: ResNet-18 initialised with ImageNet-pretrained weights.
#
# torchvision >= 0.13 replaced the `pretrained=True` flag with the `weights=`
# enum API (the old flag is deprecated and emits a warning). IMAGENET1K_V1 is
# the checkpoint that `pretrained=True` maps to, so both branches load the same
# weights; the fallback keeps older torchvision releases (no enum) working.
_resnet18_weights = getattr(torchvision.models, 'ResNet18_Weights', None)
if _resnet18_weights is not None:
    model = torchvision.models.resnet18(weights=_resnet18_weights.IMAGENET1K_V1)
else:
    model = torchvision.models.resnet18(pretrained=True)
손실 함수(loss function)와 optimizer를 정의한다. 학습률(lr)은 0.001, momentum은 0.9로 설정한다.
# Define the loss function and optimizer.
# Cross-entropy loss on the model's outputs, and plain SGD with momentum over
# every parameter of `model`.
learning_rate = 0.001
momentum = 0.9
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=learning_rate,
    momentum=momentum,
)
모델을 학습한다.
# Train the model for a fixed number of passes over `trainloader`.
# Every `log_every` mini-batches, print the mean loss over that window and
# reset the accumulator.
num_epochs = 2
log_every = 2000
for epoch_idx in range(num_epochs):
    window_loss = 0.0
    # Each batch from the loader unpacks into (inputs, labels).
    for step, (images, targets) in enumerate(trainloader):
        # Gradients accumulate across backward() calls, so clear them first.
        optimizer.zero_grad()
        logits = model(images)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        window_loss += batch_loss.item()
        if (step + 1) % log_every == 0:
            print('[%d, %5d] loss: %.3f' % (epoch_idx + 1, step + 1, window_loss / log_every))
            window_loss = 0.0
'study' 카테고리의 다른 글
[논문] DETR: End to End Object Detection with Transformers (ECCV 2020) (0) 2023.02.09 [코드] yolov5 구현 (0) 2023.02.09 [study] EfficientNet 모델 구조 (0) 2023.02.09 [논문리뷰] Improving Pixel Embedding Learning through Intermediate Distance Regression Supervision for Instance Segmentation (0) 2023.02.09 json to yolo (2023.01.10) (0) 2023.02.07