Įvadas

„TinyBERT“ yra kompaktiška BERT (transformatorių dvikrypčių kodavimo priemonių) versija, sukurta panašiai našumui su žymiai mažesniu modelio dydžiu. Šioje pamokoje parodysime, kaip naudoti TinyBERT_General_4L_312D teksto klasifikavimui.

Būtinos sąlygos

  • Python 3.6 arba naujesnė versija
  • PyTorch
  • „Transformerių“ biblioteka, sukurta Hugging Face
  • Duomenų rinkiniai mokymams ir testavimui

1 veiksmas: įdiekite reikalingas bibliotekas

Pirmiausia įdiegkime reikiamas bibliotekas:

pip install torch transformers datasets

2 veiksmas: įkelkite TinyBERT modelį ir tokenizatorių

Turime įkelti TinyBERT modelį ir atitinkamą žetonų įtaisą iš Hugging Face Transformers bibliotekos.

from transformers import BertTokenizer, BertForSequenceClassification

# Load TinyBERT tokenizer and model
tokenizer = BertTokenizer.from_pretrained('huawei-noah/TinyBERT_General_4L_312D')
model = BertForSequenceClassification.from_pretrained('huawei-noah/TinyBERT_General_4L_312D')

3 veiksmas: paruoškite duomenų rinkinį

Mes naudosime datasets biblioteką duomenų rinkiniui įkelti ir iš anksto apdoroti. Šiame pavyzdyje dvejetainiam nuotaikų klasifikavimui naudosime IMDB duomenų rinkinį.

from datasets import load_dataset

# Load the IMDB dataset
dataset = load_dataset('imdb')

# Split the dataset into train and test sets
train_dataset = dataset('train')
test_dataset = dataset('test')

4 veiksmas: suaktyvinkite duomenų rinkinį

Turime sutvirtinti tekstinius duomenis, kad juos būtų galima įvesti į TinyBERT modelį.

def tokenize_function(examples):
    return tokenizer(examples('text'), padding='max_length', truncation=True, max_length=128)

# Tokenize the dataset
train_dataset = train_dataset.map(tokenize_function, batched=True)
test_dataset = test_dataset.map(tokenize_function, batched=True)

5 veiksmas: paruoškite duomenų įkėlimo įrenginius

„PyTorch“ reikalauja, kad duomenys būtų įkeliami paketais. Mes naudosime DataLoader klasė šiam tikslui.

from torch.utils.data import DataLoader

# Define data collator
data_collator = lambda data: {
    'input_ids': torch.tensor((f('input_ids') for f in data), dtype=torch.long),
    'attention_mask': torch.tensor((f('attention_mask') for f in data), dtype=torch.long),
    'labels': torch.tensor((f('label') for f in data), dtype=torch.long)
}

# Create data loaders
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=data_collator)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=data_collator)

6 veiksmas: išmokykite modelį

Dabar nustatysime mokymo kilpą, kad galėtume tiksliai suderinti TinyBERT modelį mūsų duomenų rinkinyje.

import torch
from torch.optim import AdamW
from tqdm import tqdm

# Set device (GPU or CPU)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

# Define optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)

# Training loop
model.train()
for epoch in range(3):  # Train for 3 epochs
    loop = tqdm(train_dataloader, leave=True)
    for batch in loop:
        optimizer.zero_grad()
        
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items()}
        
        # Forward pass
        outputs = model(**batch)
        loss = outputs.loss
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Update progress bar
        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=loss.item())

7 veiksmas: įvertinkite modelį

Po mokymo turime įvertinti modelio našumą bandymo duomenų rinkinyje.

model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch in test_dataloader:
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items()}
        
        # Forward pass
        outputs = model(**batch)
        logits = outputs.logits
        
        # Calculate accuracy
        predictions = torch.argmax(logits, dim=-1)
        correct += (predictions == batch('labels')).sum().item()
        total += len(batch('labels'))

accuracy = correct / total
print(f'Test Accuracy: {accuracy:.4f}')

Išvada

Šioje pamokoje parodėme, kaip naudoti TinyBERT_General_4L_312D teksto klasifikavimui. Įkėlėme modelį ir prieigos raktą, paruošėme duomenų rinkinį, apmokėme modelį ir įvertinome jo veikimą. „TinyBERT“ siūlo lengvą, bet veiksmingą alternatyvą originaliam BERT modeliui, todėl jis tinkamas naudoti ribotoje aplinkoje.

Jei tekste radote klaidą, siųskite pranešimą autoriui pažymėdami klaidą ir paspausdami Ctrl-Enter.



Source link

By admin

Draugai: - Marketingo paslaugos - Teisinės konsultacijos - Skaidrių skenavimas - Fotofilmų kūrimas - Karščiausios naujienos - Ultragarsinis tyrimas - Saulius Narbutas - Įvaizdžio kūrimas - Veidoskaita - Nuotekų valymo įrenginiai -  Padelio treniruotės - Pranešimai spaudai -