diff --git a/main/xiaozhi-server/config/logger.py b/main/xiaozhi-server/config/logger.py index 8866b1a84..e6f9ae472 100644 --- a/main/xiaozhi-server/config/logger.py +++ b/main/xiaozhi-server/config/logger.py @@ -38,6 +38,11 @@ def formatter(record): return record["message"] +def get_logger(tag): + """获取预配置的日志记录器,避免循环导入""" + # 基本配置,不依赖config + return logger.bind(tag=tag) + def setup_logging(): check_config_file() """从配置文件中读取日志配置,并设置日志输出格式和级别""" diff --git a/main/xiaozhi-server/core/providers/tts/local_cosyvoice.py b/main/xiaozhi-server/core/providers/tts/local_cosyvoice.py new file mode 100644 index 000000000..347ad4a51 --- /dev/null +++ b/main/xiaozhi-server/core/providers/tts/local_cosyvoice.py @@ -0,0 +1,107 @@ +import os +import sys +import uuid +import requests +from config.logger import get_logger +from datetime import datetime +from core.providers.tts.base import TTSProviderBase +import torch +import torchaudio + +TAG = __name__ +logger = get_logger(TAG) + + +class TTSProvider(TTSProviderBase): + + def _initialize_model(self): + # 保存原始 sys.path + original_path = None + + try: + original_path = sys.path.copy() + # 动态修改 sys.path + sys.path.insert(0, self.matcha_tts_path) + sys.path.insert(0, self.cosy_voice_path) + + # 导入必要的模块 + from cosyvoice.cli.cosyvoice import CosyVoice2 + from cosyvoice.utils.file_utils import load_wav + + # 初始化模型 + self.model = CosyVoice2(self.cosy_voice_model_dir, + load_jit=False, load_trt=False, fp16=False) + + # 保存导入的模块供之后使用 + self.CosyVoice2 = CosyVoice2 + self.prompt_speech_16k = load_wav(self.prompt_speech_16k, 16000) + + return True + except ImportError as e: + logger.bind(tag=TAG).error(f"导入 CosyVoice 模块失败: {e}") + raise ImportError(f"导入 CosyVoice 模块失败: {e}") + finally: + # 恢复原始 sys.path + if original_path: + sys.path = original_path + + def __init__(self, config, delete_audio_file): + super().__init__(config, delete_audio_file) + self.cosy_voice_path = config.get("cosyvoice_path") + self.cosy_voice_model_dir = config.get("cosyvoice_model_dir") + self.matcha_tts_path = config.get("matcha_tts_path") if config.get( + "matcha_tts_path") else f"{self.cosy_voice_path}/third_party/Matcha-TTS" + + self.prompt_speech_16k = config.get("prompt_speech_16k") if config.get( + "prompt_speech_16k") else f"{self.cosy_voice_path}/asset/zero_shot_prompt.wav" + # 非必传参数,如果不传,则使用默认的16k采样率的提示音频 + self.prompt_speech_16k_text = config.get("prompt_speech_16k_text") if config.get( + "prompt_speech_16k_text") else None + + self._initialize_model() + + def inference_to_single_file(self, inference_func, output_path, *args, **kwargs): + """ + 执行推理并将结果保存为单个音频文件 + + 参数: + inference_func: 推理函数(如cosyvoice.inference_zero_shot) + output_path: 输出文件路径 + *args, **kwargs: 传递给推理函数的参数 + + 返回: + 合并后的语音张量 + """ + speech_segments = [] + for segment in inference_func(*args, **kwargs): + speech_segments.append(segment['tts_speech']) + if speech_segments: + combined_speech = torch.cat(speech_segments, dim=1) + torchaudio.save(output_path, combined_speech, self.model.sample_rate) + return combined_speech + return None + + def generate_filename(self): + return os.path.join(self.output_file, f"tts-{datetime.now().date()}@{uuid.uuid4().hex}.{self.format}") + + async def text_to_speak(self, text, output_file): + try: + if not self.prompt_speech_16k_text: + self.inference_to_single_file( + self.model.inference_cross_lingual, + output_file, + text, + self.prompt_speech_16k, + stream=False + ) + else: + self.inference_to_single_file( + self.model.inference_zero_shot, + output_file, + text, + self.prompt_speech_16k_text, + self.prompt_speech_16k, + stream=False + ) + except Exception as e: + logger.bind(tag=TAG).exception(f"CosyVoice TTS请求失败: {e}") \ No newline at end of file diff --git a/main/xiaozhi-server/core/utils/tts.py b/main/xiaozhi-server/core/utils/tts.py index eb3836672..24a4b33fa 100644 --- a/main/xiaozhi-server/core/utils/tts.py +++ b/main/xiaozhi-server/core/utils/tts.py @@ -1,10 +1,10 @@ import os import re import sys -from config.logger import setup_logging +from config.logger import get_logger import importlib -logger = setup_logging() +logger = get_logger(__name__) def create_instance(class_name, *args, **kwargs): diff --git a/main/xiaozhi-server/models/SenseVoiceSmall/demo.py b/main/xiaozhi-server/models/SenseVoiceSmall/demo.py index 531e97985..e0c619760 100644 --- a/main/xiaozhi-server/models/SenseVoiceSmall/demo.py +++ b/main/xiaozhi-server/models/SenseVoiceSmall/demo.py @@ -12,9 +12,19 @@ hub="hf", ) +res0 = model.generate( + input=f"{model.model_path}/example/en.mp3", + cache={}, + language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech" + use_itn=True, + batch_size_s=60, + merge_vad=True, # + merge_length_s=15, +) + # en res = model.generate( - input=f"{model.model_path}/example/en.mp3", + input=f"{model.model_path}/example/zh.mp3", cache={}, language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech" use_itn=True, @@ -22,6 +32,6 @@ merge_vad=True, # merge_length_s=15, ) -text = rich_transcription_postprocess(res[0]["text"]) -print(text) +print(rich_transcription_postprocess(res[0]["text"])) +print(rich_transcription_postprocess(res0[0]["text"])) diff --git a/main/xiaozhi-server/requirements.txt b/main/xiaozhi-server/requirements.txt index 2b6f1dc9b..c32ec4536 100755 --- a/main/xiaozhi-server/requirements.txt +++ b/main/xiaozhi-server/requirements.txt @@ -1,12 +1,12 @@ pyyml==0.0.2 torch==2.2.2 +torchaudio==2.2.2 silero_vad==5.1.2 websockets==14.2 opuslib_next==1.1.2 numpy==1.26.4 pydub==0.25.1 funasr==1.2.3 -torchaudio==2.2.2 openai==1.61.0 google-generativeai==0.8.4 edge_tts==7.0.0 diff --git a/main/xiaozhi-server/test/unit/tts/cosyvoice_requirements.txt b/main/xiaozhi-server/test/unit/tts/cosyvoice_requirements.txt new file mode 100644 index 000000000..f09d9bca1 --- /dev/null +++ b/main/xiaozhi-server/test/unit/tts/cosyvoice_requirements.txt @@ -0,0 +1,14 @@ +HyperPyYAML==1.2.2 +openai-whisper==20231117 +inflect==7.3.1 +transformers==4.40.1 +conformer==0.3.2 +diffusers==0.29.0 +lightning==2.2.4 +rich==13.7.1 +gdown==5.1.0 +matplotlib==3.7.5 +wget==3.2 +pyarrow==19.0.1 +pyworld==0.3.4 +onnxruntime-gpu==1.21.0 \ No newline at end of file diff --git a/main/xiaozhi-server/test/unit/tts/local_cosyvoice_real.py b/main/xiaozhi-server/test/unit/tts/local_cosyvoice_real.py new file mode 100644 index 000000000..f8987e056 --- /dev/null +++ b/main/xiaozhi-server/test/unit/tts/local_cosyvoice_real.py @@ -0,0 +1,81 @@ +import os +import unittest +import time +from datetime import datetime +import uuid +import wave +import torchaudio +from core.providers.tts.local_cosyvoice import TTSProvider + + +class TestRealTTSGeneration(unittest.TestCase): + """真实语音生成集成测试,需要实际的 CosyVoice 环境""" + + def setUp(self): + # 使用真实配置 - 请确保这些路径在您的环境中有效 + self.config = { + "output_dir": "/tmp", + "cosyvoice_path": "/home/shangjun/xt_workspace/python_workspace/CosyVoice", + "cosyvoice_model_dir": "/home/shangjun/xt_workspace/python_workspace/CosyVoice/pretrained_models/CosyVoice2-0.5B", + "prompt_speech_16k": "/home/shangjun/xt_workspace/python_workspace/CosyVoice/asset/zero_shot_prompt.wav", + "prompt_speech_16k_text": "希望你以后能够做的比我还好呦。" + } + # 创建输出目录 + os.makedirs(self.config["output_dir"], exist_ok=True) + + def test_real_tts_generation(self): + """使用真实模型生成语音文件""" + # 初始化提供者,不删除生成的文件 + provider = TTSProvider(self.config, False) + provider.format = "wav" + + # 生成测试文本 + test_text = "王骀受了刖刑,被砍去了一只脚。孔子有个弟子叫常季,他见老师时提出了自己的疑问。他说:老师你看,王骀被砍去了一只脚,可是他的学识和品行好像都超过了先生您,至于跟平常人相比,好像水平就更高了。像他这样的人,运用心智是怎样的与众不同呢?孔子的学生觉得很是奇怪,这个人一只脚被砍掉了,但是他的名声却很大,很多人都喜欢跟他学习,这个学生感到很不理解,一见到老师就向老师提出自己心中的疑问。文中庄子又是借孔子之口,表达了自己这样的观点:说死和生都是人生中的大事,可是死和生都不能使王骀这样的人随之变化,你说王骀是个什么样的人呢?即使天翻过来地坠下去,他也不会因此而被毁灭,他通晓无所依凭的道理,当然也就不随物变迁,而是听任事物的变化而信守自己的宗本。孔子的这段话把常季给说晕了,他忍不住再问:老师您这些话是什么意思啊?孔子怎么回答的呢?这段话很重要,来看一下完整的译文:孔子说:“从事物千差万别的一面去看,邻近的肝胆虽处于一体之中,也像是楚国和越国那样相距甚远;如果从事物相同的一面来看,万事万物又都是同一的,没有差别的。像王骀这样的人,耳朵和眼睛最适宜何种声音和色彩这样的事,已经不在他考虑范围之内了。他让自己的心思自由自在地遨游在忘形、忘情的浑同境域之中,就把这些东西的差别都忘掉了。所以他看待自己丧失了一只脚这件事,就像是看待失落的土块一样。”学了前面《庄子》的几篇文章,这段话的观点我们已不陌生。另外,有没有觉得这段话的句式很熟悉?中学时我们就学过苏东坡的《前赤壁赋》,其中就有这样的句式:“自其变者而观之,则天地曾不能以一瞬;自其不变者而观之,则物与我皆无尽也。”可以说,东坡不仅化用了庄子的句式,而且思想也和庄子是一样的。" + + # 执行文本到语音转换 + start_time = time.time() + result_file = provider.to_tts(test_text) + end_time = time.time() + + # 输出生成信息 + print(f"语音生成耗时: {end_time - start_time:.2f}秒") + print(f"生成的文件路径: {result_file}") + + # 验证文件是否存在 + self.assertTrue(os.path.exists(result_file), "语音文件未成功生成") + + # 验证文件格式 + self.assertTrue(result_file.endswith(".wav"), "生成的不是WAV文件") + + # 验证文件内容 + try: + # 检查音频文件属性 + audio_info = torchaudio.info(result_file) + print(f"采样率: {audio_info.sample_rate}Hz") + print(f"声道数: {audio_info.num_channels}") + print(f"音频长度: {audio_info.num_frames / audio_info.sample_rate:.2f}秒") + + # 加载音频文件 + waveform, sample_rate = torchaudio.load(result_file) + + # 验证音频基本特性 + self.assertEqual(sample_rate, 24000, "采样率应为24kHz") + self.assertTrue(waveform.size(0) > 0, "音频数据不应为空") + self.assertTrue(waveform.size(1) > 0, "音频长度不应为0") + + print(f"音频形状: {waveform.shape}") + print(f"最大值: {waveform.max().item():.4f}, 最小值: {waveform.min().item():.4f}") + + except Exception as e: + self.fail(f"验证音频文件失败: {e}") + + # 如果需要,可以在这里播放音频进行人工验证 + # import IPython.display as ipd + # ipd.Audio(result_file) + + def tearDown(self): + # 清理临时文件(可选) + # 注意:如果想保留文件以便检查,可以注释掉下面的代码 + # import shutil + # shutil.rmtree(self.config["output_dir"]) + pass \ No newline at end of file diff --git a/main/xiaozhi-server/test/unit/tts/local_cosyvoice_test.py b/main/xiaozhi-server/test/unit/tts/local_cosyvoice_test.py new file mode 100644 index 000000000..c5a88502d --- /dev/null +++ b/main/xiaozhi-server/test/unit/tts/local_cosyvoice_test.py @@ -0,0 +1,244 @@ +import unittest +from unittest.mock import patch, MagicMock, AsyncMock +import os +import torch +import uuid +from datetime import datetime +import asyncio +import numpy as np +from core.providers.tts.base import TTSProviderBase +from core.providers.tts.local_cosyvoice import TTSProvider + + +class MockTTSProvider(TTSProviderBase): + def generate_filename(self): + return f"/tmp/mock_tts_{str(uuid.uuid4())}.wav" + + async def text_to_speak(self, text, output_file): + pass + + +class TestTTSProviderBase(unittest.TestCase): + + def setUp(self): + self.config = {"output_dir": "/tmp/audio"} + self.provider = MockTTSProvider(self.config, True) + + @patch('core.providers.tts.base.MarkdownCleaner.clean_markdown') + @patch('asyncio.run') + @patch('os.path.exists') + def test_successful_tts_generation(self, mock_exists, mock_run, mock_clean): + mock_clean.return_value = "清理后的文本" + mock_exists.side_effect = [False, True, True] + + result = self.provider.to_tts("你好世界") + + self.assertIsNotNone(result) + mock_run.assert_called_once() + mock_clean.assert_called_once_with("你好世界") + + @patch('core.providers.tts.base.MarkdownCleaner.clean_markdown') + @patch('asyncio.run') + @patch('os.path.exists') + def test_tts_with_retries(self, mock_exists, mock_run, mock_clean): + mock_clean.return_value = "清理后的文本" + mock_exists.side_effect = [False, False, False, False, True] + + result = self.provider.to_tts("你好世界") + + self.assertIsNotNone(result) + self.assertEqual(mock_run.call_count, 2) + + @patch('core.providers.tts.base.MarkdownCleaner.clean_markdown') + @patch('asyncio.run') + @patch('os.path.exists') + def test_tts_max_retries_exceeded(self, mock_exists, mock_run, mock_clean): + mock_clean.return_value = "清理后的文本" + mock_exists.return_value = False + + result = self.provider.to_tts("你好世界") + + self.assertEqual(mock_run.call_count, 5) + + @patch('core.providers.tts.base.MarkdownCleaner.clean_markdown') + @patch('asyncio.run') + def test_tts_with_exception(self, mock_run, mock_clean): + mock_clean.return_value = "清理后的文本" + mock_run.side_effect = Exception("TTS 错误") + + result = self.provider.to_tts("你好世界") + + self.assertIsNone(result) + + @patch('pydub.AudioSegment.from_file') + @patch('opuslib_next.Encoder') + def test_audio_to_opus_conversion(self, mock_encoder_class, mock_from_file): + mock_audio = MagicMock() + mock_audio.set_channels.return_value = mock_audio + mock_audio.set_frame_rate.return_value = mock_audio + mock_audio.set_sample_width.return_value = mock_audio + mock_audio.__len__.return_value = 3000 + mock_audio.raw_data = b'\x00\x01' * 48000 + mock_from_file.return_value = mock_audio + + mock_encoder = MagicMock() + mock_encoder.encode.return_value = b'encoded_data' + mock_encoder_class.return_value = mock_encoder + + result, duration = self.provider.audio_to_opus_data("test.wav") + + self.assertEqual(duration, 3.0) + self.assertTrue(len(result) > 0) + self.assertTrue(all(item == b'encoded_data' for item in result)) + + +class TestCosyVoiceTTSProvider(unittest.TestCase): + + def setUp(self): + self.config = { + "output_dir": "/tmp/audio", + "cosyvoice_path": "/path/to/cosyvoice", + "cosyvoice_model_dir": "/path/to/models" + } + self.mock_cosyvoice = MagicMock() + self.mock_torch = MagicMock() + self.mock_torchaudio = MagicMock() + + self.patches = [ + patch('sys.path', new_callable=list), + patch.dict('sys.modules', { + 'cosyvoice.cli.cosyvoice': MagicMock(), + 'cosyvoice.utils.file_utils': MagicMock(), + 'torch': self.mock_torch, + 'torchaudio': self.mock_torchaudio + }) + ] + for p in self.patches: + p.start() + + def tearDown(self): + for p in self.patches: + p.stop() + + @patch('core.providers.tts.local_cosyvoice.CosyVoice2') + def test_initialization_with_defaults(self, mock_cosyvoice_class): + provider = TTSProvider(self.config, True) + + self.assertEqual(provider.cosy_voice_path, "/path/to/cosyvoice") + self.assertEqual(provider.cosy_voice_model_dir, "/path/to/models") + self.assertEqual(provider.matcha_tts_path, "/path/to/cosyvoice/third_party/Matcha-TTS") + self.assertTrue(provider.prompt_speech_16k.endswith("/asset/zero_shot_prompt.wav")) + self.assertIsNone(provider.prompt_speech_16k_text) + + @patch('core.providers.tts.local_cosyvoice.CosyVoice2') + def test_initialization_with_custom_values(self, mock_cosyvoice_class): + custom_config = self.config.copy() + custom_config.update({ + "matcha_tts_path": "/custom/matcha", + "prompt_speech_16k": "/custom/prompt.wav", + "prompt_speech_16k_text": "你好" + }) + + provider = TTSProvider(custom_config, True) + + self.assertEqual(provider.matcha_tts_path, "/custom/matcha") + self.assertEqual(provider.prompt_speech_16k, "/custom/prompt.wav") + self.assertEqual(provider.prompt_speech_16k_text, "你好") + + @patch.object(TTSProvider, '_initialize_model') + @patch('uuid.uuid4') + @patch('core.providers.tts.local_cosyvoice.datetime') # 修改这一行 + def test_generate_filename(self, mock_datetime, mock_uuid, _): + mock_datetime.now.return_value.date.return_value = "2023-01-01" + mock_uuid.return_value.hex = "abcd1234" + + provider = TTSProvider(self.config, True) + provider.format = "wav" + filename = provider.generate_filename() + + self.assertEqual(filename, "/tmp/audio/tts-2023-01-01@abcd1234.wav") + + @patch('core.providers.tts.local_cosyvoice.CosyVoice2') + def test_inference_to_single_file(self, _): + provider = TTSProvider(self.config, True) + cosyvoice = MagicMock() + cosyvoice.sample_rate = 16000 + + mock_speech1 = torch.tensor([0.1, 0.2]) + mock_speech2 = torch.tensor([0.3, 0.4]) + mock_inference_func = MagicMock() + mock_inference_func.return_value = [ + {"tts_speech": mock_speech1}, + {"tts_speech": mock_speech2} + ] + + with patch.object(torch, 'cat', return_value="combined_speech"): + result = provider.inference_to_single_file( + mock_inference_func, "/tmp/output.wav", "测试文本" + ) + + mock_inference_func.assert_called_once_with("测试文本") + self.assertEqual(result, "combined_speech") + + @patch('core.providers.tts.local_cosyvoice.CosyVoice2') + def test_inference_to_single_file_empty_result(self, _): + provider = TTSProvider(self.config, True) + + mock_inference_func = MagicMock() + mock_inference_func.return_value = [] + + result = provider.inference_to_single_file( + mock_inference_func, "/tmp/output.wav", "测试文本" + ) + + self.assertIsNone(result) + + @patch('core.providers.tts.local_cosyvoice.CosyVoice2') + @patch('core.providers.tts.local_cosyvoice.cosyvoice', new_callable=MagicMock) + @patch('core.providers.tts.local_cosyvoice.torch', new_callable=MagicMock) + @patch('core.providers.tts.local_cosyvoice.torchaudio', new_callable=MagicMock) + async def test_text_to_speak_with_prompt_text(self, _, __, mock_cosyvoice, ___): + provider = TTSProvider(self.config, True) + provider.prompt_speech_16k_text = "提示文本" + + with patch.object(provider, 'inference_to_single_file') as mock_inference: + await provider.text_to_speak("你好世界", "/tmp/output.wav") + + mock_inference.assert_called_once_with( + mock_cosyvoice.inference_cross_lingual, + "/tmp/output.wav", + "你好世界", + provider.prompt_speech_16k_text, + provider.prompt_speech_16k, + stream=False + ) + + @patch('core.providers.tts.local_cosyvoice.CosyVoice2') + @patch('core.providers.tts.local_cosyvoice.cosyvoice', new_callable=MagicMock) + @patch('core.providers.tts.local_cosyvoice.torch', new_callable=MagicMock) + @patch('core.providers.tts.local_cosyvoice.torchaudio', new_callable=MagicMock) + async def test_text_to_speak_without_prompt_text(self, _, __, mock_cosyvoice, ___): + provider = TTSProvider(self.config, True) + provider.prompt_speech_16k_text = None + + with patch.object(provider, 'inference_to_single_file') as mock_inference: + await provider.text_to_speak("你好世界", "/tmp/output.wav") + + mock_inference.assert_called_once_with( + mock_cosyvoice.inference_zero_shot, + "/tmp/output.wav", + "你好世界", + provider.prompt_speech_16k, + stream=False + ) + + @patch('core.providers.tts.local_cosyvoice.CosyVoice2') + async def test_text_to_speak_handles_exception(self, _): + provider = TTSProvider(self.config, True) + + with patch.object(provider, 'inference_to_single_file', side_effect=Exception("TTS错误")): + try: + await provider.text_to_speak("你好世界", "/tmp/output.wav") + # 如果不抛出异常则测试通过 + except Exception: + self.fail("text_to_speak方法没有正确处理异常")