반응형
LlamaIndex와 PyTorch를 활용한 텍스트 분류 모델 구축
LlamaIndex는 대규모 언어 모델과의 통합을 통해 데이터를 효율적으로 처리할 수 있도록 돕는 라이브러리입니다. 이를 활용하면 PyTorch와 같은 딥러닝 프레임워크에서 사용할 수 있는 고품질의 학습 데이터를 준비할 수 있습니다. 이번 글에서는 LlamaIndex로 데이터 준비를 하고 PyTorch를 활용해 텍스트 분류 모델을 구축하는 과정을 소개합니다.
1. 프로젝트 개요
목표
- LlamaIndex를 활용해 텍스트 데이터를 인덱싱하고 전처리합니다.
- PyTorch를 사용해 텍스트 분류 모델을 설계, 학습 및 평가합니다.
주요 기능
- 텍스트 데이터를 효율적으로 인덱싱 및 전처리
- PyTorch 기반 딥러닝 모델 설계 및 학습
- 분류 성능 평가
2. 개발 환경 준비
2.1 필수 라이브러리 설치
다음 명령어를 실행하여 필요한 라이브러리를 설치합니다:
pip install llama-index torch torchvision transformers
참고:
transformers
는 사전 학습된 언어 모델(BERT, GPT 등)을 사용하기 위해 설치합니다.
2.2 기본 설정
Python 3.8 이상의 환경을 권장하며, 가상 환경(venv, Conda)을 사용하면 의존성 관리를 간편하게 할 수 있습니다.
3. 데이터 준비 및 전처리
3.1 샘플 데이터 준비
아래와 같은 텍스트 분류 데이터를 준비합니다. 이 데이터는 긍정(1)과 부정(0)으로 분류된 영화 리뷰 데이터입니다.
data.json:
[
{"text": "The movie was fantastic and full of great moments.", "label": 1},
{"text": "I hated the movie, it was boring and slow.", "label": 0},
{"text": "Absolutely loved the characters and the story.", "label": 1},
{"text": "The plot was weak and uninteresting.", "label": 0}
]
3.2 LlamaIndex로 데이터 전처리
LlamaIndex를 활용하여 텍스트 데이터를 인덱싱합니다.
import os
import json
from llama_index import GPTSimpleVectorIndex, Document
# OpenAI API 키 설정
os.environ["OPENAI_API_KEY"] = "YOUR_OPENAI_API_KEY"
# 데이터 로드 및 인덱싱
def load_and_index_data(file_path):
with open(file_path, "r", encoding="utf-8") as file:
data = json.load(file)
documents = [Document(f"Text: {item['text']}\nLabel: {item['label']}") for item in data]
index = GPTSimpleVectorIndex.from_documents(documents)
return index, data
index, raw_data = load_and_index_data("data.json")
print("데이터 인덱싱 완료!")
3.3 텍스트와 라벨 분리
LlamaIndex로부터 데이터를 가져와 PyTorch 모델에 적합한 형태로 변환합니다.
texts = [item['text'] for item in raw_data]
labels = [item['label'] for item in raw_data]
4. PyTorch를 사용한 텍스트 분류 모델 구축
4.1 데이터 토크나이제이션
transformers
라이브러리의 토크나이저를 사용하여 데이터를 벡터화합니다.
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
def tokenize_data(texts, labels):
encodings = tokenizer(texts, truncation=True, padding=True, max_length=128, return_tensors="pt")
return encodings, labels
encodings, labels = tokenize_data(texts, labels)
4.2 데이터 로더 생성
PyTorch의 DataLoader
를 사용하여 학습 데이터를 준비합니다.
import torch
from torch.utils.data import DataLoader, TensorDataset
# 텐서 데이터셋 생성
dataset = TensorDataset(encodings['input_ids'], encodings['attention_mask'], torch.tensor(labels))
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
4.3 모델 정의
사전 학습된 BERT 모델을 사용하여 텍스트 분류 모델을 정의합니다.
from transformers import BertForSequenceClassification
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
model = model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
4.4 모델 학습
모델을 학습시키는 간단한 루프를 작성합니다.
from torch.optim import AdamW
optimizer = AdamW(model.parameters(), lr=5e-5)
criterion = torch.nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.train()
for epoch in range(3):
for batch in dataloader:
input_ids, attention_mask, labels = [b.to(device) for b in batch]
outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
loss = outputs.loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch {epoch + 1}, Loss: {loss.item()}")
5. 모델 평가
5.1 테스트 데이터 준비
테스트 데이터에 대해 예측을 수행하고 정확도를 평가합니다.
from sklearn.metrics import accuracy_score
# 테스트 데이터 예측
model.eval()
predictions = []
with torch.no_grad():
for batch in dataloader:
input_ids, attention_mask, labels = [b.to(device) for b in batch]
outputs = model(input_ids, attention_mask=attention_mask)
logits = outputs.logits
predictions.extend(torch.argmax(logits, dim=-1).cpu().numpy())
accuracy = accuracy_score(labels.cpu().numpy(), predictions)
print(f"Accuracy: {accuracy * 100:.2f}%")
참고 자료
반응형
'LlamaIndex' 카테고리의 다른 글
LlamaIndex와 OpenAI API 연동 (0) | 2025.01.22 |
---|---|
LlamaIndex와 Streamlit을 활용한 데이터 검색 웹 애플리케이션 (0) | 2025.01.22 |
LlamaIndex와 Pandas를 이용한 데이터 분석 및 검색 (0) | 2025.01.22 |
LlamaIndex와 FastAPI를 결합한 검색 API 개발 (0) | 2025.01.22 |
LlamaIndex를 활용한 웹 크롤러 개발 (0) | 2025.01.22 |