@@ -75,7 +75,7 @@ def result_ready(self):
7575 async def run(self, config=None):
7676 print("Worker started with config:", config)
7777 # Wait until stop is requested
78- await self._tsignal_stopping.wait ()
78+ await self.wait_for_stop ()
7979 print("Worker finishing...")
8080
8181 async def do_work(self, data):
@@ -108,9 +108,7 @@ def __init__(self):
108108 All operations that access or modify worker's lifecycle state must be
109109 performed while holding this lock.
110110 """
111- self ._tsignal_lifecycle_lock = (
112- threading .RLock ()
113- ) # Renamed lock for loop and thread
111+ self ._tsignal_lifecycle_lock = threading .RLock ()
114112 self ._tsignal_stopping = asyncio .Event ()
115113 self ._tsignal_affinity = object ()
116114 self ._tsignal_process_queue_task = None
@@ -120,8 +118,10 @@ def __init__(self):
120118 @property
121119 def event_loop (self ) -> asyncio .AbstractEventLoop :
122120 """Returns the worker's event loop"""
121+
123122 if not self ._tsignal_loop :
124123 raise RuntimeError ("Worker not started" )
124+
125125 return self ._tsignal_loop
126126
127127 @t_signal
@@ -134,6 +134,7 @@ def stopped(self):
134134
135135 async def run (self , * args , ** kwargs ):
136136 """Run the worker."""
137+
137138 logger .debug ("[WorkerClass][run] calling super" )
138139
139140 super_run = getattr (super (), _WorkerConstants .RUN , None )
@@ -161,8 +162,10 @@ async def run(self, *args, **kwargs):
161162
162163 async def _process_queue (self ):
163164 """Process the task queue."""
165+
164166 while not self ._tsignal_stopping .is_set ():
165167 coro = await self ._tsignal_task_queue .get ()
168+
166169 try :
167170 await coro
168171 except Exception as e :
@@ -176,12 +179,14 @@ async def _process_queue(self):
176179
177180 async def start_queue (self ):
178181 """Start the task queue processing. Returns the queue task."""
182+
179183 self ._tsignal_process_queue_task = asyncio .create_task (
180184 self ._process_queue ()
181185 )
182186
183187 def queue_task (self , coro ):
184188 """Method to add a task to the queue"""
189+
185190 if not asyncio .iscoroutine (coro ):
186191 logger .error (
187192 "[WorkerClass][queue_task] Task must be a coroutine object: %s" ,
@@ -196,6 +201,7 @@ def queue_task(self, coro):
196201
197202 def start (self , * args , ** kwargs ):
198203 """Start the worker thread."""
204+
199205 run_coro = kwargs .pop (_WorkerConstants .RUN_CORO , None )
200206
201207 if run_coro is not None and not asyncio .iscoroutine (run_coro ):
@@ -207,6 +213,7 @@ def start(self, *args, **kwargs):
207213
208214 def thread_main ():
209215 """Thread main function."""
216+
210217 self ._tsignal_task_queue = asyncio .Queue ()
211218
212219 with self ._tsignal_lifecycle_lock :
@@ -215,6 +222,7 @@ def thread_main():
215222
216223 async def runner ():
217224 """Runner function."""
225+
218226 self .started .emit ()
219227
220228 if run_coro is not None :
@@ -324,4 +332,9 @@ def move_to_thread(self, target):
324332 self ._tsignal_affinity ,
325333 )
326334
335+ async def wait_for_stop (self ):
336+ """Wait for the worker to stop."""
337+
338+ await self ._tsignal_stopping .wait ()
339+
327340 return WorkerClass
0 commit comments