# Weighted image sampling with torch.utils.data.WeightedRandomSampler
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
# Per-class sampling weights: a larger value means samples of that class are
# drawn more often. To rebalance an imbalanced dataset these must be
# proportional to the *inverse* class frequency -- passing raw class
# frequencies would oversample the majority class instead of the minority
# one. inverse_frequency_weights() below derives such weights from labels.
class_weights = {0: 0.1, 1: 0.9}  # Example class weights


def inverse_frequency_weights(labels):
    """Return ``{class: 1 / count}`` for every class that occurs in *labels*.

    With these weights each class receives the same total sampling mass
    (count * 1/count == 1), so a ``WeightedRandomSampler`` draws every class
    equally often on average.

    Args:
        labels: Iterable of integer class labels (Python ints, NumPy integers
            or 0-d integer tensors).

    Returns:
        Dict mapping each integer class label to its inverse frequency.
    """
    from collections import Counter

    # int() so tensor labels are counted by value, not by object identity.
    counts = Counter(int(label) for label in labels)
    return {cls: 1.0 / count for cls, count in counts.items()}


def compute_sample_weights(labels, class_weights):
    """Return one sampling weight per sample, looked up from its class label.

    Args:
        labels: Iterable of integer class labels (Python ints, NumPy integers
            or 0-d integer tensors).
        class_weights: Mapping from integer class label to sampling weight.

    Returns:
        1-D ``torch.double`` tensor with one weight per label, in input order,
        suitable for ``WeightedRandomSampler``.

    Raises:
        KeyError: if a label has no entry in *class_weights*.
    """
    # torch.Tensor hashes by identity, so a 0-d tensor label would never match
    # the int keys of class_weights; int() normalises it to a plain int.
    return torch.tensor(
        [class_weights[int(label)] for label in labels], dtype=torch.double
    )


def make_weighted_dataloader(dataset, class_weights, batch_size, labels=None):
    """Build a DataLoader that samples *dataset* according to class weights.

    Samples are drawn with replacement, one draw per dataset element per
    epoch, with probability proportional to their class weight.

    Args:
        dataset: Dataset whose items are ``(input, label)`` pairs.
        class_weights: Mapping from integer class label to sampling weight
            (see ``inverse_frequency_weights`` for a balancing choice).
        batch_size: Number of samples per batch.
        labels: Optional precomputed labels, one per dataset item, in dataset
            order. Pass these when available: otherwise every item is loaded
            once (decoding every image) just to read its label.

    Returns:
        A ``DataLoader`` driven by a ``WeightedRandomSampler``.
    """
    if labels is None:
        labels = [label for _, label in dataset]
    weights = compute_sample_weights(labels, class_weights)
    sampler = WeightedRandomSampler(
        weights, num_samples=len(weights), replacement=True
    )
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)


if __name__ == "__main__":
    # Assuming dataset is your custom dataset yielding (image, label) pairs
    dataset = CustomDataset(...)
    batch_size = 32  # example value; set to suit your model / hardware
    dataloader = make_weighted_dataloader(dataset, class_weights, batch_size)
    # Use the dataloader for training
    for inputs, labels in dataloader:
        ...  # training step goes here