-
Notifications
You must be signed in to change notification settings - Fork 630
Open
Description
I tried fine-tuning Dolphin with Transformers for Vietnamese document images.
Clearly, the model had some ability with Vietnamese, but after fine-tuning it completely lost that ability and only outputs strange characters. It’s likely there’s something wrong in my code.
For the code, I referred to fine-tuning examples from other models like Donut and TrOCR.
My fine-tuned code:
import os
import torch
from datasets import load_from_disk
from transformers import (
AutoProcessor,
VisionEncoderDecoderModel,
Seq2SeqTrainingArguments,
Seq2SeqTrainer,
EarlyStoppingCallback
)
import wandb
from PIL import Image
import cv2
import numpy as np
############### UTILS ###############
def crop_margin(img: Image.Image) -> Image.Image:
"""Crop margins from image"""
try:
width, height = img.size
if width == 0 or height == 0:
print("Warning: Image has zero width or height")
return img
data = np.array(img.convert("L"))
data = data.astype(np.uint8)
max_val = data.max()
min_val = data.min()
if max_val == min_val:
return img
data = (data - min_val) / (max_val - min_val) * 255
gray = 255 * (data < 200).astype(np.uint8)
coords = cv2.findNonZero(gray) # Find all non-zero points (text)
if coords is None:
return img
a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box
# Ensure crop coordinates are within image bounds
a = max(0, a)
b = max(0, b)
w = min(w, width - a)
h = min(h, height - b)
# Only crop if we have a valid region
if w > 0 and h > 0:
return img.crop((a, b, a + w, b + h))
return img
except Exception as e:
print(f"crop_margin error: {str(e)}")
return img # Return original image on error
############### LOGGING ###############
run_id = os.getenv("WANDB_RUN_ID", None)
wandb_name = os.getenv("WANDB_NAME", None)
wandb_project = os.getenv("WANDB_PROJECT", None)
wandb_entity = os.getenv("WANDB_ENTITY", None)
wandb.init(
project=wandb_project,
name=wandb_name,
entity=wandb_entity,
id=run_id,
resume="allow"
)
############### HYPER-PARAMS & PATHS ###############
SEQ_LEN = 2048
OUTPUT_DIR = "./trained_models/vi_dolphin_ver1_20250913_0621"
HF_REPO = "htdung167/ViDolphin-v1"
HF_STRATEGY = "every_save"
HF_TOKEN = os.getenv("HF_TOKEN")
PER_DEVICE_TRAIN_BATCH = 8
PER_DEVICE_EVAL_BATCH = 8
GRAD_ACCUM = 4
LR = 2e-5
WARMUP_RATIO = 0.01
EPOCHS = 10
LOG_STEPS = 50
EVAL_STEPS = 500
SAVE_STEPS = 500
SAVE_LIMIT = 2
IGNORE_ID = -100
############### MODEL & PROCESSOR ###############
processor = AutoProcessor.from_pretrained("ByteDance/Dolphin")
model = VisionEncoderDecoderModel.from_pretrained("ByteDance/Dolphin")
# Configure model for training
model.config.decoder_start_token_id = processor.tokenizer.bos_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder.max_length = SEQ_LEN
model.generation_config.max_length = SEQ_LEN
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
############### DATA ###############
processed_dataset = load_from_disk("./dataset/md_text_table")
print(processed_dataset)
############### CORRECTED COLLATE FN ###############
def collate_fn(batch):
# Preprocess images with crop_margin
processed_images = [crop_margin(ex["image"]) for ex in batch]
pixel_values = processor(
images=processed_images,
return_tensors="pt"
).pixel_values
# Prepare the text part for the decoder
texts = [ex['text'] for ex in batch]
# Tokenize the texts to create labels.
labels = processor.tokenizer(
texts,
add_special_tokens=False,
max_length=SEQ_LEN,
padding="max_length",
truncation=True,
return_tensors="pt",
).input_ids
# The returned dictionary must contain keys that match the model's forward pass arguments
return {
"pixel_values": pixel_values,
"labels": labels
}
############### TRAINING ARGS ###############
training_args = Seq2SeqTrainingArguments(
output_dir=OUTPUT_DIR,
num_train_epochs=EPOCHS,
learning_rate=LR,
per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH,
per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH,
gradient_accumulation_steps=GRAD_ACCUM,
warmup_ratio=WARMUP_RATIO,
weight_decay=0.01,
lr_scheduler_type="linear",
fp16=True,
logging_steps=LOG_STEPS,
save_strategy="steps",
save_steps=SAVE_STEPS,
save_total_limit=SAVE_LIMIT,
eval_strategy="steps",
eval_steps=EVAL_STEPS,
predict_with_generate=True,
report_to="wandb",
dataloader_num_workers=4,
remove_unused_columns=False,
# push to hub
push_to_hub=True,
hub_model_id=HF_REPO,
hub_strategy=HF_STRATEGY,
hub_token=HF_TOKEN,
seed=42,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
)
############### TRAINER ###############
early_stopper = EarlyStoppingCallback(
early_stopping_patience=3,
early_stopping_threshold=0.0
)
trainer = Seq2SeqTrainer(
model=model,
args=training_args,
train_dataset=processed_dataset["train"],
eval_dataset=processed_dataset["validation"],
data_collator=collate_fn,
callbacks=[early_stopper]
)
############### TRAINING ###############
last_checkpoint = None
if os.path.isdir(OUTPUT_DIR):
checkpoints = [os.path.join(OUTPUT_DIR, d) for d in os.listdir(OUTPUT_DIR) if d.startswith("checkpoint-")]
if len(checkpoints) > 0:
last_checkpoint = max(checkpoints, key=lambda x: int(x.split("-")[-1]))
print(f"Resuming from checkpoint: {last_checkpoint}")
trainer.train(resume_from_checkpoint=last_checkpoint)
# Save & push
processor.save_pretrained(OUTPUT_DIR)
trainer.create_model_card()
trainer.push_to_hub()A sample in the dataset (using load_from_disk):
{'id': '33c599ccc91748dd5fa5_0',
'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=1100x1022>,
'text': '<s>Read text in the image. <Answer/>## Bán Trà CERY-Sản phẩm Chữa tê thấp, chữa bệnh GOUT tốt, giá ổn - ID1661394 - Các mặt hàng khác - Mặt\nBán Trà CERY-Sản phẩm Chữa tê thấp, chữa bệnh GOUT tốt, giá ổn\nKhu vực: Tp. Hồ Chí Minh Q. 3 - Ngày đăng: 16/05/2016 06:54 - Đã xem: 117\nMã tin : 1661394 Nâng cấp tin qua ĐTDĐ\nNgày hết hạn: 14/08/2016 06:54\nXem từ 1 đến 15 (của 49685 Rao vặt) 1 2 3 4 5 6 7 8 9 10 ...\n\n- Bán Cà Gai LEO, tốt- - Chữa bệnh gan, Thải độc gan, hạ cholesterol, giã rượu (118)\n- Bán Ấm Pha Trà, Chất lượng -Rất Nhiều mẫu mới, hình dáng đẹp giá rẻ (52)\n- Lá Dâm Dương Hoắc-Sản phẩm làm tăng sinh lý mạnh, cho quý ông-giá thật tốt (34)\n- Bán ATISO Đà Lạt- mát gan, hạ cholesterol, giải nhiệt mùa nắng , giá rẻ (159)\n- Bán Loại Trà Dây SAPA-Sử dụng Chữa Dạ dày, tá tràng, ăn tốt và ngủ tốt, rẻ (119)\n- Bột Quế và Mật Ong- nhiều công dụng thiết thực, quý giá -giá tốt (132)\n- Lá Vằng, tốt. nhất- Gúp ăn ngủ tốt, thông huyết, ,bà mẹ sữa tốt, hồi phục nhanh (123)\n- Lá NEEM, chất lượng-ĐểChữa tiểu đường, bớt nhức mỏi và tiêu viêm (152)\n- Nụ Hoa TAM THẤT- Sản phẩm cho giấc ngủ ngon, tăng sức đề kháng, sứ khỏe tốt (98)\n- Trà Hồng Đài- Phòng lão hóa, thanh nhiệt, giảm cholesterol, bảo vệ mắt (69)\n- Bán Sản phẩm Giúp Hỗ trợ điều trị ung thư tốt-Trà Lá DỪA CẠn (105)\n- Bán sản phẩm Giúp Phòng tai biến và đột quỵ tốt, giá rẻ (93)\n- cách phân biệt giá đỗ sạch và giá đỗ ngâm thuốc (40)\n\nTìm liên quan Bán Trà CERY-Sản phẩm Chữa tê thấp, chữa bệnh GOUT tốt, giá ổn\nĐang xem Bán Trà CERY-Sản phẩm Chữa tê thấp, chữa bệnh GOUT tốt, giá ổn</s>'}After finetuning, output like:
Sh s mt�anhoố tiả �ầ khnghá
hiuì ha nmớ dnà ko nhcô ngệ thậ lợ v cynê chp|ă ph�ựếạ bch�ang�|huịụứ,�i�ề qu tr x hngâ doh
ia|ó/ọ�ờ gyn�|ong�|ể tra%
�e�úư1023-ò tin�
in��:uyn�
ấ rng� yut�|ổ|ung�jpg5687ThCChHXMTKVKhGdễ �à sx�|ắ gii�
ha|ín"||||||||||||||||||||||||||||||...|uaí| Thơ t hongở|ữ hngơ hngamho ttnạ BN|oauynhim#kngô phii|||ủ chê vin�_ỉử|ý nmmao|ó tin�||||| Phn binẻ||||||||||||||||||||auho tt kin hin|||th ttn tin|gu Si Gn Dinễ �à KinThY Ka dim||ắ tt|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Khmmmu|||||||||||||||||||||||||||||||||||
---|||||||||||||||||||||||||||
###á|á|||á||||||||||||||||||||||||||
Bgy|||||||||||||||||||||||||||||||||||
Hmy|||||||||||||||||||||||||||||
Hynnn||||ay||||||||||||||||||||||||||||||||||
==================================================
htdung167work0
Metadata
Metadata
Assignees
Labels
No labels