Skip to content

Commit f4eed65

Browse files
committed
refine the xpu skip func
1 parent b4933ae commit f4eed65

File tree

1 file changed

+37
-15
lines changed

1 file changed

+37
-15
lines changed

torchao/testing/utils.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -98,22 +98,24 @@ def wrapper(*args, **kwargs):
9898
return decorator
9999

100100

101-
def skip_if_no_xpu(message=None):
102-
"""Decorator to skip tests on ROCm platform with custom message.
101+
def skip_if_no_xpu():
102+
try:
103+
import pytest
103104

104-
Args:
105-
message (str, optional): Additional information about why the test is skipped.
106-
"""
107-
import unittest
105+
has_pytest = True
106+
except ImportError:
107+
has_pytest = False
108+
import unittest
108109

109110
def decorator(func):
110111
@functools.wraps(func)
111112
def wrapper(*args, **kwargs):
112113
if not torch.xpu.is_available():
113-
skip_message = "Skipping the test in XPU"
114-
if message:
115-
skip_message += f": {message}"
116-
unittest.skip(skip_message)
114+
skip_message = "No XPU available"
115+
if has_pytest:
116+
pytest.skip(skip_message)
117+
else:
118+
unittest.skip(skip_message)
117119
return func(*args, **kwargs)
118120

119121
return wrapper
@@ -123,19 +125,39 @@ def wrapper(*args, **kwargs):
123125

124126
def skip_if_xpu(message=None):
125127
"""
126-
Decorator to skip tests if XPU is available.
128+
Decorator to skip tests on XPU platform with custom message.
127129
128130
Args:
129131
message (str, optional): Additional information about why the test is skipped.
130132
"""
133+
try:
134+
import pytest
135+
136+
has_pytest = True
137+
except ImportError:
138+
has_pytest = False
139+
import unittest
131140

132141
def decorator(func):
133-
reason = "Skipping the test on XPU"
134-
if message:
135-
reason += f": {message}"
142+
@functools.wraps(func)
143+
def wrapper(*args, **kwargs):
144+
if torch.xpu.is_available():
145+
skip_message = "Skipping the test in XPU"
146+
if message:
147+
skip_message += f": {message}"
148+
if has_pytest:
149+
pytest.skip(skip_message)
150+
else:
151+
unittest.skip(skip_message)
152+
return func(*args, **kwargs)
136153

137-
return unittest.skipIf(torch.xpu.is_available(), reason)(func)
154+
return wrapper
138155

156+
# Handle both @skip_if_xpu and @skip_if_xpu() syntax
157+
if callable(message):
158+
func = message
159+
message = None
160+
return decorator(func)
139161
return decorator
140162

141163

0 commit comments

Comments
 (0)