@@ -98,22 +98,24 @@ def wrapper(*args, **kwargs):
9898 return decorator
9999
100100
101- def skip_if_no_xpu (message = None ):
102- """Decorator to skip tests on ROCm platform with custom message.
101+ def skip_if_no_xpu ():
102+ try :
103+ import pytest
103104
104- Args:
105- message (str, optional): Additional information about why the test is skipped.
106- """
107- import unittest
105+ has_pytest = True
106+ except ImportError :
107+ has_pytest = False
108+ import unittest
108109
109110 def decorator (func ):
110111 @functools .wraps (func )
111112 def wrapper (* args , ** kwargs ):
112113 if not torch .xpu .is_available ():
113- skip_message = "Skipping the test in XPU"
114- if message :
115- skip_message += f": { message } "
116- unittest .skip (skip_message )
114+ skip_message = "No XPU available"
115+ if has_pytest :
116+ pytest .skip (skip_message )
117+ else :
118+ unittest .skip (skip_message )
117119 return func (* args , ** kwargs )
118120
119121 return wrapper
@@ -123,19 +125,39 @@ def wrapper(*args, **kwargs):
123125
124126def skip_if_xpu (message = None ):
125127 """
126- Decorator to skip tests if XPU is available .
128+ Decorator to skip tests on XPU platform with custom message .
127129
128130 Args:
129131 message (str, optional): Additional information about why the test is skipped.
130132 """
133+ try :
134+ import pytest
135+
136+ has_pytest = True
137+ except ImportError :
138+ has_pytest = False
139+ import unittest
131140
132141 def decorator (func ):
133- reason = "Skipping the test on XPU"
134- if message :
135- reason += f": { message } "
142+ @functools .wraps (func )
143+ def wrapper (* args , ** kwargs ):
144+ if torch .xpu .is_available ():
145+ skip_message = "Skipping the test in XPU"
146+ if message :
147+ skip_message += f": { message } "
148+ if has_pytest :
149+ pytest .skip (skip_message )
150+ else :
151+ unittest .skip (skip_message )
152+ return func (* args , ** kwargs )
136153
137- return unittest . skipIf ( torch . xpu . is_available (), reason )( func )
154+ return wrapper
138155
156+ # Handle both @skip_if_xpu and @skip_if_xpu() syntax
157+ if callable (message ):
158+ func = message
159+ message = None
160+ return decorator (func )
139161 return decorator
140162
141163
0 commit comments