Įvadas
„TinyBERT“ yra kompaktiška BERT (transformatorių dvikrypčių kodavimo priemonių) versija, sukurta panašiai našumui su žymiai mažesniu modelio dydžiu. Šioje pamokoje parodysime, kaip naudoti TinyBERT_General_4L_312D
teksto klasifikavimui.
Būtinos sąlygos
- Python 3.6 arba naujesnė versija
- PyTorch
- „Transformerių“ biblioteka, sukurta Hugging Face
- Duomenų rinkiniai mokymams ir testavimui
1 veiksmas: įdiekite reikalingas bibliotekas
Pirmiausia įdiegkime reikiamas bibliotekas:
pip install torch transformers datasets
2 veiksmas: įkelkite TinyBERT modelį ir tokenizatorių
Turime įkelti TinyBERT modelį ir atitinkamą žetonų įtaisą iš Hugging Face Transformers bibliotekos.
from transformers import BertTokenizer, BertForSequenceClassification
# Load TinyBERT tokenizer and model
tokenizer = BertTokenizer.from_pretrained('huawei-noah/TinyBERT_General_4L_312D')
model = BertForSequenceClassification.from_pretrained('huawei-noah/TinyBERT_General_4L_312D')
3 veiksmas: paruoškite duomenų rinkinį
Mes naudosime datasets
biblioteką duomenų rinkiniui įkelti ir iš anksto apdoroti. Šiame pavyzdyje dvejetainiam nuotaikų klasifikavimui naudosime IMDB duomenų rinkinį.
from datasets import load_dataset
# Load the IMDB dataset
dataset = load_dataset('imdb')
# Split the dataset into train and test sets
train_dataset = dataset('train')
test_dataset = dataset('test')
4 veiksmas: suaktyvinkite duomenų rinkinį
Turime sutvirtinti tekstinius duomenis, kad juos būtų galima įvesti į TinyBERT modelį.
def tokenize_function(examples):
return tokenizer(examples('text'), padding='max_length', truncation=True, max_length=128)
# Tokenize the dataset
train_dataset = train_dataset.map(tokenize_function, batched=True)
test_dataset = test_dataset.map(tokenize_function, batched=True)
5 veiksmas: paruoškite duomenų įkėlimo įrenginius
„PyTorch“ reikalauja, kad duomenys būtų įkeliami paketais. Mes naudosime DataLoader
klasė šiam tikslui.
from torch.utils.data import DataLoader
# Define data collator
data_collator = lambda data: {
'input_ids': torch.tensor((f('input_ids') for f in data), dtype=torch.long),
'attention_mask': torch.tensor((f('attention_mask') for f in data), dtype=torch.long),
'labels': torch.tensor((f('label') for f in data), dtype=torch.long)
}
# Create data loaders
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=data_collator)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=data_collator)
6 veiksmas: išmokykite modelį
Dabar nustatysime mokymo kilpą, kad galėtume tiksliai suderinti TinyBERT modelį mūsų duomenų rinkinyje.
import torch
from torch.optim import AdamW
from tqdm import tqdm
# Set device (GPU or CPU)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
# Define optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)
# Training loop
model.train()
for epoch in range(3): # Train for 3 epochs
loop = tqdm(train_dataloader, leave=True)
for batch in loop:
optimizer.zero_grad()
# Move batch to device
batch = {k: v.to(device) for k, v in batch.items()}
# Forward pass
outputs = model(**batch)
loss = outputs.loss
# Backward pass
loss.backward()
optimizer.step()
# Update progress bar
loop.set_description(f'Epoch {epoch}')
loop.set_postfix(loss=loss.item())
7 veiksmas: įvertinkite modelį
Po mokymo turime įvertinti modelio našumą bandymo duomenų rinkinyje.
model.eval()
correct = 0
total = 0
with torch.no_grad():
for batch in test_dataloader:
# Move batch to device
batch = {k: v.to(device) for k, v in batch.items()}
# Forward pass
outputs = model(**batch)
logits = outputs.logits
# Calculate accuracy
predictions = torch.argmax(logits, dim=-1)
correct += (predictions == batch('labels')).sum().item()
total += len(batch('labels'))
accuracy = correct / total
print(f'Test Accuracy: {accuracy:.4f}')
Išvada
Šioje pamokoje parodėme, kaip naudoti TinyBERT_General_4L_312D
teksto klasifikavimui. Įkėlėme modelį ir prieigos raktą, paruošėme duomenų rinkinį, apmokėme modelį ir įvertinome jo veikimą. „TinyBERT“ siūlo lengvą, bet veiksmingą alternatyvą originaliam BERT modeliui, todėl jis tinkamas naudoti ribotoje aplinkoje.
Jei tekste radote klaidą, siųskite pranešimą autoriui pažymėdami klaidą ir paspausdami Ctrl-Enter.