# llm_twin/train.py: LoRA fine-tuning of tiiuae/falcon-7b-instruct on AzureML
# =========================
# 1. Imports
# =========================
import os
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
# =========================
# 2. Environment settings
# =========================
os.environ["HF_DISABLE_MLFLOW"] = "1" # Disable MLflow for AzureML
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
# =========================
# 3. Load tokenizer and model
# =========================
model_name = "tiiuae/falcon-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Falcon tokenizer may not have pad_token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Memory-efficient model loading
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",       # Auto-assign layers to GPU/CPU
    torch_dtype="auto",      # Keep the checkpoint's native precision
    low_cpu_mem_usage=True,  # Prevent meta tensor errors on load
)
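# Optional memory savers (my addition, not in the original script). Gradient
# checkpointing trades extra compute for much lower activation memory; with a
# frozen base model the input embeddings must be told to require grads so the
# checkpointed segments still backpropagate into the LoRA adapters, and the KV
# cache must be off during training. Uncomment if you hit CUDA OOM:
# model.gradient_checkpointing_enable()
# model.enable_input_require_grads()
# model.config.use_cache = False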
# =========================
# 4. LoRA configuration
# =========================
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["query_key_value"],  # Falcon's fused attention projection
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",  # Added so PEFT wraps the model for causal-LM training/generation
)
model = get_peft_model(model, lora_config)
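# Sanity check (my addition): PEFT reports how many parameters are trainable;
# with r=8 on Falcon-7B's query_key_value this is a small fraction of 1%.
model.print_trainable_parameters()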
# =========================
# 5. Load dataset
# =========================
dataset = load_dataset("json", data_files="output_medium.jsonl", split="train")
print("Dataset columns:", dataset.column_names)
# =========================
# 6. Tokenization function
# =========================
def tokenize(batch):
    prompts = []
    for i in range(len(batch["title"])):
        title = batch["title"][i] or ""
        subtitle = batch.get("subtitle", [""] * len(batch["title"]))[i] or ""
        content = batch.get("content", [""] * len(batch["title"]))[i] or ""
        full_text = (title + " " + subtitle).strip() + "\n" + content.strip()
        prompts.append(full_text)
    encodings = tokenizer(prompts, truncation=True, padding="max_length", max_length=128)
    # Causal LM labels are the input ids; mask padding positions with -100 so the
    # loss ignores them (pad_token was aliased to eos_token above).
    encodings["labels"] = [
        [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
        for ids, attn in zip(encodings["input_ids"], encodings["attention_mask"])
    ]
    return encodings
dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)
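# Sanity check (my addition): decode the first tokenized example to verify the
# "title subtitle\ncontent" prompt survived tokenization and truncation.
print(tokenizer.decode(dataset[0]["input_ids"], skip_special_tokens=True)[:200])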
# =========================
# 7. Training arguments
# =========================
training_args = TrainingArguments(
    output_dir="./outputs",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    learning_rate=2e-4,
    logging_steps=10,
    save_steps=500,
    save_total_limit=2,
    fp16=True,
    report_to=[],  # Disable MLflow / WandB
)
# =========================
# 8. Trainer
# =========================
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    # No data_collator needed: every example is already padded to max_length and
    # carries its own labels, so the default collator only stacks tensors.
)
# =========================
# 9. Train
# =========================
trainer.train()
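# My addition: persist the trainer state (global step, log history) next to the
# checkpoints in ./outputs for later inspection.
trainer.save_state()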
# =========================
# 10. Save model & tokenizer
# =========================
model.save_pretrained("./outputs")      # Saves only the LoRA adapter weights, not the full base model
tokenizer.save_pretrained("./outputs")
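# =========================
# 11. Quick generation check (my addition, not in the original script)
# =========================
# Generate a short continuation with the tuned adapter still in memory, just to
# confirm the model produces text. The prompt is an arbitrary placeholder; swap
# in a title from your own dataset.
model.eval()
sample = tokenizer("An example article title\n", return_tensors="pt").to(model.device)
output_ids = model.generate(
    **sample,
    max_new_tokens=50,
    do_sample=True,
    top_p=0.9,
    pad_token_id=tokenizer.eos_token_id,  # Falcon reuses eos as pad (set above)
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))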