@@ -1711,6 +1711,11 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
         ):
             dataset_class = MTBenchDataset
             args.hf_split = "train"
+        elif (
+            args.dataset_path in MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
+            or args.hf_name in MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
+        ):
+            dataset_class = MultiModalConversationDataset
         elif (
             args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS
             or args.hf_name in ConversationDataset.SUPPORTED_DATASET_PATHS
@@ -2272,12 +2277,71 @@ def load_data(self) -> None:


 class ConversationDataset(HuggingFaceDataset):
-    """Dataset for conversation data with multimodal support."""
+    """Dataset for text-only conversation data."""

     SUPPORTED_DATASET_PATHS = {
-        "lmms-lab/LLaVA-OneVision-Data",
         "Aeala/ShareGPT_Vicuna_unfiltered",
     }
+    IS_MULTIMODAL = False
+
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        output_len: int | None = None,
+        enable_multimodal_chat: bool = False,
+        request_id_prefix: str = "",
+        no_oversample: bool = False,
+        **kwargs,
+    ) -> list:
+        # Filter examples with at least 2 conversations
+        filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2)
+        sampled_requests = []
+        ind = 0
+        dynamic_output = output_len is None
+
+        for item in filtered_data:
+            if len(sampled_requests) >= num_requests:
+                break
+            conv = item["conversations"]
+            prompt, completion = conv[0]["value"], conv[1]["value"]
+
+            prompt_ids = tokenizer(prompt).input_ids
+            completion_ids = tokenizer(completion).input_ids
+            prompt_len = len(prompt_ids)
+            completion_len = len(completion_ids)
+            output_len = completion_len if dynamic_output else output_len
+            assert isinstance(output_len, int) and output_len > 0
+            if dynamic_output and not is_valid_sequence(prompt_len, completion_len):
+                continue
+            mm_content = process_image(item["image"]) if "image" in item else None
+            if enable_multimodal_chat:
+                # Note: when chat is enabled the request prompt_len is no longer
+                # accurate and we will be using request output to count the
+                # actual prompt len and output len
+                prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
+            sampled_requests.append(
+                SampleRequest(
+                    prompt=prompt,
+                    prompt_len=prompt_len,
+                    expected_output_len=output_len,
+                    multi_modal_data=mm_content,
+                    request_id=request_id_prefix + str(ind),
+                )
+            )
+            ind += 1
+        self.maybe_oversample_requests(
+            sampled_requests, num_requests, request_id_prefix, no_oversample
+        )
+        return sampled_requests
+
+
+class MultiModalConversationDataset(HuggingFaceDataset):
+    """Dataset for multimodal conversation data."""
+
+    SUPPORTED_DATASET_PATHS = {
+        "lmms-lab/LLaVA-OneVision-Data",
+    }
     IS_MULTIMODAL = True

     def sample(
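For reference, a minimal standalone sketch of the path-based dispatch that the first hunk extends. The dataset class names and SUPPORTED_DATASET_PATHS values mirror the diff; Args and pick_dataset_class are hypothetical stand-ins for the benchmark's parsed CLI arguments and selection chain, not the real helpers in this file.

from dataclasses import dataclass


class ConversationDataset:
    SUPPORTED_DATASET_PATHS = {"Aeala/ShareGPT_Vicuna_unfiltered"}
    IS_MULTIMODAL = False


class MultiModalConversationDataset:
    SUPPORTED_DATASET_PATHS = {"lmms-lab/LLaVA-OneVision-Data"}
    IS_MULTIMODAL = True


@dataclass
class Args:
    # Hypothetical stand-in for the benchmark's argparse namespace.
    dataset_path: str | None = None
    hf_name: str | None = None


def pick_dataset_class(args: Args) -> type:
    # Mirrors the elif chain in get_samples(): each candidate class is matched
    # against either the dataset path or the HF dataset name. After this change
    # the two classes no longer share a supported path, so match order between
    # them does not matter.
    for cls in (MultiModalConversationDataset, ConversationDataset):
        if (
            args.dataset_path in cls.SUPPORTED_DATASET_PATHS
            or args.hf_name in cls.SUPPORTED_DATASET_PATHS
        ):
            return cls
    raise ValueError(f"Unsupported dataset path: {args.dataset_path!r}")


if __name__ == "__main__":
    assert (
        pick_dataset_class(Args(dataset_path="lmms-lab/LLaVA-OneVision-Data"))
        is MultiModalConversationDataset
    )
    assert (
        pick_dataset_class(Args(hf_name="Aeala/ShareGPT_Vicuna_unfiltered"))
        is ConversationDataset
    )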