from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
# Model/tokenizer setup: DistilBERT with a fresh 3-way classification head.
model_name = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
# num_labels=3 randomly initializes the sequence-classification head with
# 3 output classes; the transformer encoder weights are pretrained.
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
def tokenize(batch):
    """Tokenize a batch of examples for the DistilBERT model.

    Every sequence is truncated/padded to exactly 256 tokens, so the
    Trainer's default collator can stack batches without dynamic padding.

    Args:
        batch: mapping with a 'text' key holding a list of strings — the
            shape ``datasets.map(batched=True)`` supplies.

    Returns:
        dict of tokenizer outputs (input_ids, attention_mask, ...).
    """
    # NOTE: fix for the original, whose return statement was not indented
    # under the def (IndentationError as written — possibly a paste artifact).
    return tokenizer(batch['text'], truncation=True, padding='max_length', max_length=256)
# Build HF Datasets from pandas frames (keeping only the columns Trainer
# needs) and tokenize them in batches.
# NOTE(review): train_df / valid_df are not defined anywhere in this file —
# presumably loaded in an earlier chunk or notebook cell; confirm, otherwise
# this line raises NameError at import time.
train_dataset = Dataset.from_pandas(train_df[['text', 'label']]).map(tokenize, batched=True)
valid_dataset = Dataset.from_pandas(valid_df[['text', 'label']]).map(tokenize, batched=True)
# Training configuration: evaluate and checkpoint once per epoch, then
# restore the checkpoint with the best (lowest) eval loss at the end.
args = TrainingArguments(
    output_dir='artifacts/topic_classifier',
    # Per-epoch cadence; save_strategy must match evaluation_strategy for
    # load_best_model_at_end to take effect.
    evaluation_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    # Optimization hyperparameters.
    num_train_epochs=3,
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
)
# Wire model, config, and datasets into the Trainer and run fine-tuning.
# The default data collator suffices because tokenize() already padded
# every example to a fixed length.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
)
trainer.train()