feat: add QLoRA PDF pipeline and Gitea CI workflow
- qlora_pdf_pipeline.py: 6-step QLoRA fine-tuning pipeline (S3 PDFs → prepare data → train → evaluate → push to Gitea → MLflow)
- .gitea/workflows/compile-upload.yaml: auto-compile and upload all pipelines to Kubeflow on push, with ntfy notifications
This commit is contained in:
705
qlora_pdf_pipeline.py
Normal file
705
qlora_pdf_pipeline.py
Normal file
@@ -0,0 +1,705 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
QLoRA Fine-Tuning Pipeline – Kubeflow Pipelines SDK
|
||||
|
||||
Fetches PDFs from a Quobjects S3 bucket, extracts instruction-tuning
|
||||
data, trains a QLoRA adapter on the Llama 3.1 70B base model using
|
||||
the Strix Halo's 128 GB unified memory, evaluates it, and pushes the
|
||||
adapter weights to a Gitea repository.
|
||||
|
||||
Usage:
|
||||
pip install kfp==2.12.1
|
||||
python qlora_pdf_pipeline.py
|
||||
# Upload qlora_pdf_pipeline.yaml to Kubeflow Pipelines UI
|
||||
|
||||
Prerequisites in-cluster:
|
||||
- Secret mlpipeline-minio-artifact (namespace kubeflow) for S3 creds
|
||||
- Secret gitea-admin-secret (namespace gitea) for Gitea push
|
||||
- Node khelben with amd.com/gpu and the ROCm PyTorch image
|
||||
"""
|
||||
|
||||
from kfp import compiler, dsl
|
||||
from typing import NamedTuple
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
# 1. Fetch PDFs from Quobjects S3
# ──────────────────────────────────────────────────────────────
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=["boto3"],
)
def fetch_pdfs_from_s3(
    s3_endpoint: str,
    s3_bucket: str,
    s3_prefix: str,
    aws_access_key_id: str,
    aws_secret_access_key: str,
) -> NamedTuple("PDFOutput", [("pdf_dir", str), ("num_files", int)]):
    """Download all PDFs from a Quobjects S3 bucket.

    Args:
        s3_endpoint: Host (and optional port) of the S3-compatible endpoint.
            Plain HTTP is assumed (in-cluster traffic).
        s3_bucket: Bucket to list.
        s3_prefix: Key prefix to restrict the listing ("" = whole bucket).
        aws_access_key_id: Static S3 access key.
        aws_secret_access_key: Static S3 secret key.

    Returns:
        PDFOutput(pdf_dir, num_files): local directory holding the PDFs
        and the number of files downloaded.
    """
    import os
    import boto3
    from botocore.client import Config

    out_dir = "/tmp/pdfs"
    os.makedirs(out_dir, exist_ok=True)

    client = boto3.client(
        "s3",
        # NOTE: http:// — assumes a non-TLS, in-cluster endpoint.
        endpoint_url=f"http://{s3_endpoint}",
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        region_name="us-east-1",
        config=Config(signature_version="s3v4"),
    )

    paginator = client.get_paginator("list_objects_v2")
    count = 0
    for page in paginator.paginate(Bucket=s3_bucket, Prefix=s3_prefix):
        for obj in page.get("Contents", []):
            key = obj["Key"]
            if not key.lower().endswith(".pdf"):
                continue
            # BUG FIX: flatten the full key into a unique local filename.
            # Using os.path.basename(key) silently overwrote PDFs that share
            # a basename under different prefixes (e.g. a/doc.pdf, b/doc.pdf).
            safe_name = key.lstrip("/").replace("/", "__")
            local_path = os.path.join(out_dir, safe_name)
            print(f"Downloading: {key} → {local_path}")
            client.download_file(s3_bucket, key, local_path)
            count += 1

    print(f"Downloaded {count} PDFs to {out_dir}")
    from collections import namedtuple

    return namedtuple("PDFOutput", ["pdf_dir", "num_files"])(
        pdf_dir=out_dir, num_files=count
    )
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
# 2. Extract text from PDFs → instruction-tuning dataset
# ──────────────────────────────────────────────────────────────
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=["pymupdf"],
)
def prepare_training_data(
    pdf_dir: str,
    max_seq_length: int = 2048,
    chunk_size: int = 512,
    chunk_overlap: int = 64,
) -> NamedTuple("DataOutput", [("dataset_path", str), ("num_train", int), ("num_val", int)]):
    """Extract text from PDFs, chunk it, and format as instruction-tuning pairs.

    Args:
        pdf_dir: Directory containing the downloaded ``*.pdf`` files.
        max_seq_length: NOTE(review): accepted for pipeline symmetry but not
            used here — chunking is word-count based. Confirm whether chunks
            should be capped to this token length.
        chunk_size: Chunk length in words.
        chunk_overlap: Words shared between consecutive chunks.

    Returns:
        DataOutput(dataset_path, num_train, num_val): directory holding
        train.json / val.json and the split sizes.

    Raises:
        ValueError: if chunk_overlap >= chunk_size, or no text could be
            extracted from any PDF.
    """
    import json
    import os
    import fitz  # PyMuPDF

    # BUG FIX: a non-positive stride would either raise ("range() arg 3 must
    # not be zero") or silently produce no chunks and surface as a misleading
    # "No text extracted" error below — fail fast with a clear message.
    if chunk_overlap >= chunk_size:
        raise ValueError(
            f"chunk_overlap ({chunk_overlap}) must be smaller than chunk_size ({chunk_size})"
        )

    out_dir = "/tmp/training_data"
    os.makedirs(out_dir, exist_ok=True)

    # 1. Extract text from all PDFs
    all_chunks: list[dict] = []
    for fname in sorted(os.listdir(pdf_dir)):
        if not fname.lower().endswith(".pdf"):
            continue
        path = os.path.join(pdf_dir, fname)
        print(f"Extracting: {fname}")
        try:
            doc = fitz.open(path)
            full_text = ""
            for page in doc:
                full_text += page.get_text() + "\n"
            doc.close()
        except Exception as e:
            # Best-effort: one corrupt PDF should not kill the whole run.
            print(f" SKIP ({e})")
            continue

        # 2. Chunk text with overlap
        words = full_text.split()
        for i in range(0, len(words), chunk_size - chunk_overlap):
            chunk_words = words[i : i + chunk_size]
            if len(chunk_words) < 50:
                continue  # skip tiny trailing chunks
            chunk_text = " ".join(chunk_words)
            all_chunks.append({"text": chunk_text, "source": fname})

    print(f"Total chunks: {len(all_chunks)}")
    if not all_chunks:
        raise ValueError("No text extracted from PDFs — check your bucket")

    # 3. Format as Llama 3 chat training pairs
    # We create self-supervised pairs: model learns to continue/explain the content
    samples = []
    for chunk in all_chunks:
        text = chunk["text"]
        source = chunk["source"]
        # Split chunk roughly in half for input/output
        words = text.split()
        mid = len(words) // 2
        context = " ".join(words[:mid])
        continuation = " ".join(words[mid:])

        samples.append(
            {
                "messages": [
                    {
                        "role": "system",
                        "content": (
                            "You are a knowledgeable assistant. "
                            "Continue the information accurately and coherently."
                        ),
                    },
                    {
                        "role": "user",
                        "content": f"Continue the following passage from {source}:\n\n{context}",
                    },
                    {"role": "assistant", "content": continuation},
                ]
            }
        )

    # 4. Train/val split (90/10)
    import random

    random.seed(42)  # deterministic shuffle so reruns produce the same split
    random.shuffle(samples)
    # BUG FIX: int(n * 0.9) is 0 for a single-chunk corpus, which would leave
    # the training split empty — always keep at least one training sample.
    split = max(1, int(len(samples) * 0.9))
    train = samples[:split]
    val = samples[split:]

    train_path = os.path.join(out_dir, "train.json")
    val_path = os.path.join(out_dir, "val.json")
    with open(train_path, "w") as f:
        json.dump(train, f)
    with open(val_path, "w") as f:
        json.dump(val, f)

    print(f"Train: {len(train)} samples, Val: {len(val)} samples")
    from collections import namedtuple

    return namedtuple("DataOutput", ["dataset_path", "num_train", "num_val"])(
        dataset_path=out_dir, num_train=len(train), num_val=len(val)
    )
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
# 3. QLoRA training on Strix Halo (ROCm, 128 GB unified)
# ──────────────────────────────────────────────────────────────
@dsl.component(
    # Use a ROCm base image with PyTorch + PEFT pre-installed.
    # Falls back to pip-installing if not present.
    # NOTE(review): base_image below is plain python-slim, not a ROCm image —
    # confirm the ROCm wheel of torch is what pip resolves on this node.
    base_image="python:3.13-slim",
    packages_to_install=[
        "torch",
        "transformers",
        "peft",
        "datasets",
        "accelerate",
        "bitsandbytes",
        "scipy",
        "trl",
    ],
)
def train_qlora(
    dataset_path: str,
    base_model: str,
    learning_rate: float = 2e-4,
    num_epochs: int = 3,
    batch_size: int = 2,
    gradient_accumulation_steps: int = 8,
    max_seq_length: int = 2048,
    lora_r: int = 64,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
) -> NamedTuple(
    "TrainOutput",
    [("adapter_path", str), ("train_loss", float), ("eval_loss", float)],
):
    """QLoRA fine-tune Llama 3.1 70B with 4-bit NF4 quantization.

    Reads train.json / val.json (lists of {"messages": [...]} chat samples)
    from *dataset_path*, applies the model's chat template, trains a LoRA
    adapter on top of the 4-bit-quantized base model with trl's SFTTrainer,
    and saves the adapter + tokenizer + a training_metadata.json under
    /tmp/qlora_output/adapter.

    Returns:
        TrainOutput(adapter_path, train_loss, eval_loss).
    """
    import json
    import os

    import torch
    from datasets import Dataset
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        BitsAndBytesConfig,
        TrainingArguments,
    )
    from trl import SFTTrainer

    output_dir = "/tmp/qlora_output"
    os.makedirs(output_dir, exist_ok=True)

    # ── Load data ───────────────────────────────────────────
    # Files are produced by the prepare_training_data step.
    with open(os.path.join(dataset_path, "train.json")) as f:
        train_data = json.load(f)
    with open(os.path.join(dataset_path, "val.json")) as f:
        val_data = json.load(f)

    print(f"Loaded {len(train_data)} train / {len(val_data)} val samples")

    # ── Tokenizer ───────────────────────────────────────────
    print(f"Loading tokenizer: {base_model}")
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Llama tokenizers ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # ── Format with chat template ───────────────────────────
    def format_chat(sample):
        # Render the messages into a single training string using the
        # model's own chat template (no generation prompt — full turns).
        return {"text": tokenizer.apply_chat_template(
            sample["messages"], tokenize=False, add_generation_prompt=False
        )}

    train_ds = Dataset.from_list(train_data).map(format_chat)
    val_ds = Dataset.from_list(val_data).map(format_chat)

    # ── 4-bit quantisation ──────────────────────────────────
    # NF4 + double quantisation, bf16 compute — the standard QLoRA recipe.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )

    print(f"Loading model: {base_model} (4-bit NF4)")
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    # Freezes quantized weights, enables gradient checkpointing-friendly setup.
    model = prepare_model_for_kbit_training(model)

    # ── LoRA config ─────────────────────────────────────────
    # Adapters on every attention and MLP projection.
    lora_config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # ── Training args ───────────────────────────────────────
    training_args = TrainingArguments(
        output_dir=os.path.join(output_dir, "checkpoints"),
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        bf16=True,
        logging_steps=5,
        eval_strategy="steps",
        eval_steps=50,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        report_to="none",
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        # Paged 8-bit AdamW keeps optimizer state off the GPU-resident budget.
        optim="paged_adamw_8bit",
        max_grad_norm=0.3,
        # NOTE(review): group_by_length has no effect once packing=True in the
        # trainer below — confirm which of the two is intended.
        group_by_length=True,
    )

    # ── SFTTrainer ──────────────────────────────────────────
    # NOTE(review): tokenizer/max_seq_length/dataset_text_field/packing moved
    # into SFTConfig in newer trl releases — pin trl if this stops accepting
    # these keyword arguments.
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        tokenizer=tokenizer,
        max_seq_length=max_seq_length,
        dataset_text_field="text",
        packing=True,  # pack short samples for efficiency
    )

    print("Starting QLoRA training …")
    result = trainer.train()
    train_loss = result.training_loss

    eval_result = trainer.evaluate()
    # Defaults to 0.0 if the trainer reports no eval_loss (e.g. empty val set).
    eval_loss = eval_result.get("eval_loss", 0.0)

    print(f"Train loss: {train_loss:.4f}, Eval loss: {eval_loss:.4f}")

    # ── Save adapter ────────────────────────────────────────
    # Only the LoRA adapter weights are saved, not the 70B base model.
    adapter_path = os.path.join(output_dir, "adapter")
    model.save_pretrained(adapter_path)
    tokenizer.save_pretrained(adapter_path)

    # Record the full hyper-parameter set next to the weights so the adapter
    # is self-describing once pushed to Gitea.
    metadata = {
        "base_model": base_model,
        "lora_r": lora_r,
        "lora_alpha": lora_alpha,
        "lora_dropout": lora_dropout,
        "learning_rate": learning_rate,
        "num_epochs": num_epochs,
        "batch_size": batch_size,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "max_seq_length": max_seq_length,
        "train_samples": len(train_data),
        "val_samples": len(val_data),
        "train_loss": train_loss,
        "eval_loss": eval_loss,
    }
    with open(os.path.join(adapter_path, "training_metadata.json"), "w") as f:
        json.dump(metadata, f, indent=2)

    print(f"Adapter saved to {adapter_path}")

    from collections import namedtuple

    return namedtuple("TrainOutput", ["adapter_path", "train_loss", "eval_loss"])(
        adapter_path=adapter_path,
        train_loss=train_loss,
        eval_loss=eval_loss,
    )
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
# 4. Quick sanity evaluation
# ──────────────────────────────────────────────────────────────
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=[
        "torch", "transformers", "peft", "bitsandbytes", "accelerate", "scipy",
    ],
)
def evaluate_adapter(
    adapter_path: str,
    base_model: str,
) -> NamedTuple("EvalOutput", [("report", str), ("passed", bool)]):
    """Load the QLoRA adapter and run a few sanity-check prompts.

    Args:
        adapter_path: Directory holding the saved PEFT adapter.
        base_model: Hugging Face model id the adapter was trained on.

    Returns:
        EvalOutput(report, passed): a human-readable Q/A transcript and
        a flag indicating every prompt produced a non-trivial response.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from peft import PeftModel

    # Same 4-bit NF4 config as training so adapter weights line up with
    # the quantized base model.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )

    print(f"Loading base model {base_model} …")
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

    print(f"Loading adapter from {adapter_path} …")
    model = PeftModel.from_pretrained(model, adapter_path)
    model.eval()

    test_prompts = [
        "Summarise the key points from the training material.",
        "What are the main topics covered in the source documents?",
        "Explain the most important concept from the training data.",
    ]

    responses = []
    lines = []
    for prompt in test_prompts:
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ]
        input_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
        with torch.no_grad():
            out = model.generate(**inputs, max_new_tokens=128, temperature=0.7, do_sample=True)
        # Decode only the newly generated tokens (skip the echoed prompt).
        response = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
        responses.append(response)
        lines.append(f"Q: {prompt}\nA: {response}\n")
        print(lines[-1])

    report = "\n".join(lines)
    # Simple heuristic: did the model produce non-empty responses?
    # BUG FIX: previously re-parsed the formatted report lines with
    # l.split("A:")[1], which truncates the check (and can misjudge) whenever
    # a response itself contains "A:". Check the raw responses instead.
    passed = all(len(r.strip()) > 10 for r in responses)
    print(f"Evaluation passed: {passed}")

    from collections import namedtuple

    return namedtuple("EvalOutput", ["report", "passed"])(
        report=report, passed=passed
    )
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
# 5. Push adapter to Gitea repo
# ──────────────────────────────────────────────────────────────
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=["requests"],
)
def push_adapter_to_gitea(
    adapter_path: str,
    gitea_url: str,
    gitea_owner: str,
    gitea_repo: str,
    gitea_username: str,
    gitea_password: str,
    branch: str = "main",
    commit_message: str = "feat: add QLoRA adapter from PDF training pipeline",
) -> NamedTuple("PushOutput", [("repo_url", str), ("files_pushed", int)]):
    """Push the QLoRA adapter files to a Gitea repository via the API.

    Creates the repository if it does not exist, then walks *adapter_path*
    and creates/updates every file through Gitea's contents API
    (POST = create, PUT = update with the existing blob SHA).

    Returns:
        PushOutput(repo_url, files_pushed): browsable repo URL and how many
        files were accepted by the API.
    """
    import base64
    import os
    from urllib.parse import quote

    import requests

    api_base = f"{gitea_url}/api/v1"
    auth = (gitea_username, gitea_password)
    repo_api = f"{api_base}/repos/{gitea_owner}/{gitea_repo}"

    # Check if repo exists, create if not
    resp = requests.get(repo_api, auth=auth, timeout=30)
    if resp.status_code == 404:
        print(f"Creating repo {gitea_owner}/{gitea_repo} …")
        # Org-owned repos are created via /orgs/{org}/repos; repos under the
        # authenticated user go through /user/repos.
        create_resp = requests.post(
            f"{api_base}/orgs/{gitea_owner}/repos"
            if gitea_owner != gitea_username
            else f"{api_base}/user/repos",
            auth=auth,
            json={
                "name": gitea_repo,
                "description": "QLoRA adapters trained from PDF documents",
                "private": False,
                "auto_init": True,
            },
            timeout=30,
        )
        create_resp.raise_for_status()
        print(f"Created: {create_resp.json().get('html_url')}")

    # Collect all adapter files (base64-encoded, as the contents API expects)
    files_to_push = []
    for root, _dirs, files in os.walk(adapter_path):
        for fname in files:
            fpath = os.path.join(root, fname)
            rel_path = os.path.relpath(fpath, adapter_path)
            with open(fpath, "rb") as f:
                content = base64.b64encode(f.read()).decode("utf-8")
            files_to_push.append({"path": rel_path, "content": content})

    print(f"Pushing {len(files_to_push)} files to {gitea_owner}/{gitea_repo}")

    # Push each file via Gitea contents API
    pushed = 0
    for item in files_to_push:
        # BUG FIX: percent-encode the file path so names with spaces, '#',
        # '?' etc. build a valid URL; '/' is left alone (quote's default
        # safe set) so nested paths keep working.
        file_api = f"{repo_api}/contents/{quote(item['path'])}"

        # Check if file already exists (need SHA for update)
        existing = requests.get(file_api, auth=auth, params={"ref": branch}, timeout=30)
        payload = {
            "message": commit_message,
            "content": item["content"],
            "branch": branch,
        }
        if existing.status_code == 200:
            payload["sha"] = existing.json()["sha"]
            resp = requests.put(file_api, auth=auth, json=payload, timeout=60)
        else:
            resp = requests.post(file_api, auth=auth, json=payload, timeout=60)

        if resp.status_code in (200, 201):
            pushed += 1
            print(f" ✓ {item['path']}")
        else:
            # Best-effort: report the failure but keep pushing the rest.
            print(f" ✗ {item['path']}: {resp.status_code} {resp.text[:200]}")

    repo_url = f"{gitea_url}/{gitea_owner}/{gitea_repo}"
    print(f"Pushed {pushed}/{len(files_to_push)} files to {repo_url}")

    from collections import namedtuple

    return namedtuple("PushOutput", ["repo_url", "files_pushed"])(
        repo_url=repo_url, files_pushed=pushed
    )
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
# 6. Log metrics to MLflow
# ──────────────────────────────────────────────────────────────
@dsl.component(
    base_image="python:3.13-slim",
    packages_to_install=["mlflow==2.22.0"],
)
def log_training_metrics(
    base_model: str,
    train_loss: float,
    eval_loss: float,
    num_train: int,
    num_val: int,
    num_pdfs: int,
    lora_r: int,
    lora_alpha: int,
    learning_rate: float,
    num_epochs: int,
    repo_url: str,
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
    experiment_name: str = "qlora-pdf-training",
):
    """Log the full training run to MLflow.

    Records the hyper-parameters as run params, the losses and dataset
    sizes as metrics, and tags the run with the Gitea adapter repo URL.
    """
    import mlflow

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    mlflow.set_experiment(experiment_name)

    # Name the run after the model (strip the HF org prefix).
    run_name = f"qlora-{base_model.split('/')[-1]}"

    run_params = {
        "base_model": base_model,
        "lora_r": lora_r,
        "lora_alpha": lora_alpha,
        "learning_rate": learning_rate,
        "num_epochs": num_epochs,
        "num_pdfs": num_pdfs,
        "data_source": "quobjects/training-data",
    }
    # MLflow metrics must be floats; sample counts are cast explicitly.
    run_metrics = {
        "train_loss": train_loss,
        "eval_loss": eval_loss,
        "train_samples": float(num_train),
        "val_samples": float(num_val),
    }

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(run_params)
        mlflow.log_metrics(run_metrics)
        mlflow.set_tag("adapter_repo", repo_url)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
# Pipeline definition
# ──────────────────────────────────────────────────────────────
@dsl.pipeline(
    name="QLoRA PDF Fine-Tuning",
    description=(
        "Fine-tune Llama 3.1 70B via QLoRA on PDFs from the Quobjects "
        "training-data bucket. Pushes the adapter to Gitea and logs "
        "metrics to MLflow."
    ),
)
def qlora_pdf_pipeline(
    # ── S3 / Quobjects ──
    s3_endpoint: str = "candlekeep.lab.daviestechlabs.io",
    s3_bucket: str = "training-data",
    s3_prefix: str = "",
    aws_access_key_id: str = "",
    aws_secret_access_key: str = "",
    # ── Model ──
    base_model: str = "meta-llama/Llama-3.1-70B-Instruct",
    # ── Training hyper-params ──
    learning_rate: float = 2e-4,
    num_epochs: int = 3,
    batch_size: int = 2,
    gradient_accumulation_steps: int = 8,
    max_seq_length: int = 2048,
    lora_r: int = 64,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    # ── Data prep ──
    chunk_size: int = 512,
    chunk_overlap: int = 64,
    # ── Gitea ──
    gitea_url: str = "http://gitea-http.gitea.svc.cluster.local:3000",
    gitea_owner: str = "daviestechlabs",
    gitea_repo: str = "qlora-adapters",
    gitea_username: str = "",
    gitea_password: str = "",
    # ── MLflow ──
    mlflow_tracking_uri: str = "http://mlflow.mlflow.svc.cluster.local:80",
):
    """Wire the six steps: fetch → prepare → train → evaluate → push → log.

    Credentials (S3 and Gitea) default to "" and are expected to be supplied
    at run time; GPU-heavy steps are pinned to an AMD GPU node.
    """
    # Step 1 — Fetch PDFs from S3
    pdfs = fetch_pdfs_from_s3(
        s3_endpoint=s3_endpoint,
        s3_bucket=s3_bucket,
        s3_prefix=s3_prefix,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
    )

    # Step 2 — Extract text and build training dataset
    data = prepare_training_data(
        pdf_dir=pdfs.outputs["pdf_dir"],
        max_seq_length=max_seq_length,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )

    # Step 3 — QLoRA training (GPU-heavy)
    trained = train_qlora(
        dataset_path=data.outputs["dataset_path"],
        base_model=base_model,
        learning_rate=learning_rate,
        num_epochs=num_epochs,
        batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        max_seq_length=max_seq_length,
        lora_r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
    )
    # Ask for a GPU on khelben.
    # BUG FIX: the accelerator type must be the node's extended-resource name
    # (amd.com/gpu, per the module docstring's prerequisites), not the literal
    # string "gpu"; set_accelerator_limit replaces the deprecated
    # set_gpu_limit in KFP 2.x.
    trained.set_accelerator_type("amd.com/gpu")
    trained.set_accelerator_limit(1)

    # Step 4 — Quick evaluation
    evaluated = evaluate_adapter(
        adapter_path=trained.outputs["adapter_path"],
        base_model=base_model,
    )
    evaluated.set_accelerator_type("amd.com/gpu")
    evaluated.set_accelerator_limit(1)

    # Step 5 — Push adapter to Gitea
    pushed = push_adapter_to_gitea(
        adapter_path=trained.outputs["adapter_path"],
        gitea_url=gitea_url,
        gitea_owner=gitea_owner,
        gitea_repo=gitea_repo,
        gitea_username=gitea_username,
        gitea_password=gitea_password,
    )

    # Step 6 — Log to MLflow
    log_training_metrics(
        base_model=base_model,
        train_loss=trained.outputs["train_loss"],
        eval_loss=trained.outputs["eval_loss"],
        num_train=data.outputs["num_train"],
        num_val=data.outputs["num_val"],
        num_pdfs=pdfs.outputs["num_files"],
        lora_r=lora_r,
        lora_alpha=lora_alpha,
        learning_rate=learning_rate,
        num_epochs=num_epochs,
        repo_url=pushed.outputs["repo_url"],
        mlflow_tracking_uri=mlflow_tracking_uri,
    )
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
# Compile
# ──────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Compile the pipeline graph into a YAML package that can be uploaded
    # through the Kubeflow Pipelines UI or API.
    package_path = "qlora_pdf_pipeline.yaml"
    compiler.Compiler().compile(
        pipeline_func=qlora_pdf_pipeline,
        package_path=package_path,
    )
    print(f"Compiled: {package_path}")
|
||||
Reference in New Issue
Block a user