Skip to content

Fine-tuning the Dolphin model encounters a strange issue. #144

@htdung167work0

Description

@htdung167work0

I tried fine-tuning Dolphin with Transformers for Vietnamese document images.

Clearly, the model had some ability with Vietnamese, but after fine-tuning it completely lost that ability and only outputs strange characters. It’s likely there’s something wrong in my code.

For the code, I referred to fine-tuning examples from other models like Donut and TrOCR.

My fine-tuned code:

import os
import torch
from datasets import load_from_disk
from transformers import (
    AutoProcessor,
    VisionEncoderDecoderModel,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    EarlyStoppingCallback
)

import wandb
from PIL import Image
import cv2
import numpy as np

############### UTILS ###############
def crop_margin(img: Image.Image) -> Image.Image:
    """Crop margins from image"""
    try:
        width, height = img.size
        if width == 0 or height == 0:
            print("Warning: Image has zero width or height")
            return img

        data = np.array(img.convert("L"))
        data = data.astype(np.uint8)
        max_val = data.max()
        min_val = data.min()
        if max_val == min_val:
            return img
        data = (data - min_val) / (max_val - min_val) * 255
        gray = 255 * (data < 200).astype(np.uint8)

        coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
        if coords is None:
            return img
        a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box

        # Ensure crop coordinates are within image bounds
        a = max(0, a)
        b = max(0, b)
        w = min(w, width - a)
        h = min(h, height - b)

        # Only crop if we have a valid region
        if w > 0 and h > 0:
            return img.crop((a, b, a + w, b + h))
        return img
    except Exception as e:
        print(f"crop_margin error: {str(e)}")
        return img  # Return original image on error

############### LOGGING ###############
run_id = os.getenv("WANDB_RUN_ID", None)
wandb_name = os.getenv("WANDB_NAME", None)
wandb_project = os.getenv("WANDB_PROJECT", None)
wandb_entity = os.getenv("WANDB_ENTITY", None)
wandb.init(
    project=wandb_project,
    name=wandb_name,
    entity=wandb_entity,
    id=run_id,        
    resume="allow"
)


############### HYPER-PARAMS & PATHS ###############
SEQ_LEN = 2048
OUTPUT_DIR = "./trained_models/vi_dolphin_ver1_20250913_0621"
HF_REPO = "htdung167/ViDolphin-v1"
HF_STRATEGY = "every_save"
HF_TOKEN = os.getenv("HF_TOKEN")

PER_DEVICE_TRAIN_BATCH = 8
PER_DEVICE_EVAL_BATCH  = 8
GRAD_ACCUM    = 4
LR            = 2e-5
WARMUP_RATIO  = 0.01
EPOCHS        = 10
LOG_STEPS     = 50
EVAL_STEPS    = 500
SAVE_STEPS    = 500
SAVE_LIMIT    = 2
IGNORE_ID     = -100

############### MODEL & PROCESSOR ###############
processor = AutoProcessor.from_pretrained("ByteDance/Dolphin")
model = VisionEncoderDecoderModel.from_pretrained("ByteDance/Dolphin")

# Configure model for training
model.config.decoder_start_token_id = processor.tokenizer.bos_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder.max_length = SEQ_LEN
model.generation_config.max_length = SEQ_LEN

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

############### DATA ###############
processed_dataset = load_from_disk("./dataset/md_text_table")
print(processed_dataset)

############### CORRECTED COLLATE FN ###############
def collate_fn(batch):
    # Preprocess images with crop_margin
    processed_images = [crop_margin(ex["image"]) for ex in batch]
    pixel_values = processor(
        images=processed_images,
        return_tensors="pt"
    ).pixel_values

    # Prepare the text part for the decoder
    texts = [ex['text'] for ex in batch]

    # Tokenize the texts to create labels.
    labels = processor.tokenizer(
        texts,
        add_special_tokens=False,
        max_length=SEQ_LEN,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).input_ids

    # The returned dictionary must contain keys that match the model's forward pass arguments
    return {
        "pixel_values": pixel_values,
        "labels": labels
    }


############### TRAINING ARGS ###############
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=EPOCHS,
    learning_rate=LR,
    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH,
    per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH,
    gradient_accumulation_steps=GRAD_ACCUM,
    warmup_ratio=WARMUP_RATIO,
    weight_decay=0.01,
    lr_scheduler_type="linear",
    fp16=True,
    logging_steps=LOG_STEPS,
    save_strategy="steps",
    save_steps=SAVE_STEPS,
    save_total_limit=SAVE_LIMIT,
    eval_strategy="steps",
    eval_steps=EVAL_STEPS,
    predict_with_generate=True,
    report_to="wandb",
    dataloader_num_workers=4,
    remove_unused_columns=False,

    # push to hub
    push_to_hub=True,
    hub_model_id=HF_REPO,
    hub_strategy=HF_STRATEGY,
    hub_token=HF_TOKEN,

    seed=42,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

############### TRAINER ###############
early_stopper = EarlyStoppingCallback(
    early_stopping_patience=3,
    early_stopping_threshold=0.0
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=processed_dataset["train"],
    eval_dataset=processed_dataset["validation"],
    data_collator=collate_fn,
    callbacks=[early_stopper]
)

############### TRAINING ###############
last_checkpoint = None
if os.path.isdir(OUTPUT_DIR):
    checkpoints = [os.path.join(OUTPUT_DIR, d) for d in os.listdir(OUTPUT_DIR) if d.startswith("checkpoint-")]
    if len(checkpoints) > 0:
        last_checkpoint = max(checkpoints, key=lambda x: int(x.split("-")[-1]))
        print(f"Resuming from checkpoint: {last_checkpoint}")

trainer.train(resume_from_checkpoint=last_checkpoint)

# Save & push
processor.save_pretrained(OUTPUT_DIR)
trainer.create_model_card()
trainer.push_to_hub()

A sample in the dataset (using load_from_disk):

{'id': '33c599ccc91748dd5fa5_0',
 'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=1100x1022>,
 'text': '<s>Read text in the image. <Answer/>## Bán Trà CERY-Sản phẩm Chữa tê thấp, chữa bệnh GOUT tốt, giá ổn - ID1661394 - Các mặt hàng khác - Mặt\nBán Trà CERY-Sản phẩm Chữa tê thấp, chữa bệnh GOUT tốt, giá ổn\nKhu vực: Tp. Hồ Chí Minh  Q. 3 - Ngày đăng: 16/05/2016 06:54 - Đã xem: 117\nMã tin : 1661394 Nâng cấp tin qua ĐTDĐ\nNgày hết hạn: 14/08/2016 06:54\nXem từ 1 đến 15 (của 49685 Rao vặt) 1 2 3 4 5 6 7 8 9 10 ...\n\n- Bán Cà Gai LEO, tốt- - Chữa bệnh gan, Thải độc gan, hạ cholesterol, giã rượu (118)\n- Bán Ấm Pha Trà, Chất lượng -Rất Nhiều mẫu mới, hình dáng đẹp giá rẻ (52)\n- Lá Dâm Dương Hoắc-Sản phẩm làm tăng sinh lý mạnh, cho quý ông-giá thật tốt (34)\n- Bán ATISO Đà Lạt- mát gan, hạ cholesterol, giải nhiệt mùa nắng , giá rẻ (159)\n- Bán Loại Trà Dây SAPA-Sử dụng Chữa Dạ dày, tá tràng, ăn tốt và ngủ tốt, rẻ (119)\n- Bột Quế và Mật Ong- nhiều công dụng thiết thực, quý giá -giá tốt (132)\n- Lá Vằng, tốt. nhất- Gúp ăn ngủ tốt, thông huyết, ,bà mẹ sữa tốt, hồi phục nhanh (123)\n- Lá NEEM, chất lượng-ĐểChữa tiểu đường, bớt nhức mỏi và tiêu viêm (152)\n- Nụ Hoa TAM THẤT- Sản phẩm cho giấc ngủ ngon, tăng sức đề kháng, sứ khỏe tốt (98)\n- Trà Hồng Đài- Phòng lão hóa, thanh nhiệt, giảm cholesterol, bảo vệ mắt (69)\n- Bán Sản phẩm Giúp Hỗ trợ điều trị ung thư tốt-Trà Lá DỪA CẠn (105)\n- Bán sản phẩm Giúp Phòng tai biến và đột quỵ tốt, giá rẻ (93)\n- cách phân biệt giá đỗ sạch và giá đỗ ngâm thuốc (40)\n\nTìm liên quan  Bán Trà CERY-Sản phẩm Chữa tê thấp, chữa bệnh GOUT tốt, giá ổn\nĐang xem  Bán Trà CERY-Sản phẩm Chữa tê thấp, chữa bệnh GOUT tốt, giá ổn</s>'}

After finetuning, output like:

Sh s mt�anhoố tiả �ầ khnghá
hiuì ha nmớ dnà ko nhcô ngệ thậ lợ v cynê chp|ă ph�ựếạ bch�ang�|huịụứ,�i�ề qu tr x hngâ doh
ia|ó/ọ�ờ gyn�|ong�|ể tra%
�e�úư1023-ò tin�
in��:uyn�
ấ rng� yut�|ổ|ung�jpg5687ThCChHXMTKVKhGdễ �à sx�|ắ gii�
ha|ín"||||||||||||||||||||||||||||||...|uaí| Thơ t hongở|ữ hngơ hngamho ttnạ BN|oauynhim#kngô phii|||ủ chê vin�_ỉử|ý nmmao|ó tin�||||| Phn binẻ||||||||||||||||||||auho tt kin hin|||th ttn tin|gu Si Gn Dinễ �à KinThY Ka dim||ắ tt|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
   Khmmmu|||||||||||||||||||||||||||||||||||
---|||||||||||||||||||||||||||
###á|á|||á||||||||||||||||||||||||||
Bgy|||||||||||||||||||||||||||||||||||
 Hmy|||||||||||||||||||||||||||||

 Hynnn||||ay||||||||||||||||||||||||||||||||||
==================================================

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions