From 177c96a4097519c73e46654bf6364ca9d7da940e Mon Sep 17 00:00:00 2001 From: Wennie396 Date: Tue, 23 Sep 2025 20:54:39 +0800 Subject: [PATCH 1/3] =?UTF-8?q?hack=20offload=20optimizer=E5=87=8F?= =?UTF-8?q?=E5=B0=91=E4=B8=80=E6=AC=A1master=20weight=E7=9A=84offload&relo?= =?UTF-8?q?ad?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddlenlp/trainer/utils/offload_optimizer.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/paddlenlp/trainer/utils/offload_optimizer.py b/paddlenlp/trainer/utils/offload_optimizer.py index f20066f1e29b..bb74c0984f4b 100644 --- a/paddlenlp/trainer/utils/offload_optimizer.py +++ b/paddlenlp/trainer/utils/offload_optimizer.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import paddle from paddle import _C_ops @@ -37,6 +38,10 @@ def reload(tensor): assert new_tensor is tensor, "to_device must be inplace operation" +def is_offload_opt_cache_master_weight(): + return os.getenv("FLAGS_offload_opt_master_weight_cache", "0").lower() in ["true", "1"] + + def hack_offload_optimizer(): # Step 1: mock _add_accumulator origin_add_accumulator = getattr(Optimizer, "_add_accumulator") @@ -60,9 +65,10 @@ def new_opt_op(*args): ret = origin_op(*args) is_offload_opt = getattr(args[0], "is_offload_opt", False) for i, arg in enumerate(args): - if ( - i >= 2 and isinstance(arg, paddle.Tensor) and is_offload_opt - ): # do not offload parameter and gradient + need_offload_arg = i >= 2 and isinstance(arg, paddle.Tensor) and is_offload_opt + if is_offload_opt_cache_master_weight(): + need_offload_arg = need_offload_arg and i != 8 + if need_offload_arg: # do not offload parameter and gradient offload(arg) return ret From 74878bdfce228a0f3da7d62a88da9805aa01b5b9 Mon Sep 17 00:00:00 2001 From: Wennie396 Date: Wed, 24 Sep 2025 17:15:29 +0800 Subject: [PATCH 2/3] revert change --- paddlenlp/trainer/utils/offload_optimizer.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/paddlenlp/trainer/utils/offload_optimizer.py b/paddlenlp/trainer/utils/offload_optimizer.py index bb74c0984f4b..a38a8d81e093 100644 --- a/paddlenlp/trainer/utils/offload_optimizer.py +++ b/paddlenlp/trainer/utils/offload_optimizer.py @@ -65,10 +65,13 @@ def new_opt_op(*args): ret = origin_op(*args) is_offload_opt = getattr(args[0], "is_offload_opt", False) for i, arg in enumerate(args): - need_offload_arg = i >= 2 and isinstance(arg, paddle.Tensor) and is_offload_opt - if is_offload_opt_cache_master_weight(): - need_offload_arg = need_offload_arg and i != 8 - if need_offload_arg: # do not offload parameter and gradient + # need_offload_arg = i >= 2 and isinstance(arg, paddle.Tensor) and is_offload_opt + # if is_offload_opt_cache_master_weight(): + # need_offload_arg = need_offload_arg and i != 8 + # if need_offload_arg: # do not offload parameter and gradient + if ( + i >= 2 and isinstance(arg, paddle.Tensor) and is_offload_opt + ): # do not offload parameter and gradient offload(arg) return ret From 3a3d93fe5134cddfc18212c32c39a1f7e0ea1908 Mon Sep 17 00:00:00 2001 From: Wennie396 Date: Thu, 9 Oct 2025 13:23:48 +0800 Subject: [PATCH 3/3] add fix --- paddlenlp/trainer/utils/offload_optimizer.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/paddlenlp/trainer/utils/offload_optimizer.py b/paddlenlp/trainer/utils/offload_optimizer.py index a38a8d81e093..bb74c0984f4b 100644 --- a/paddlenlp/trainer/utils/offload_optimizer.py +++ b/paddlenlp/trainer/utils/offload_optimizer.py @@ -65,13 +65,10 @@ def new_opt_op(*args): ret = origin_op(*args) is_offload_opt = getattr(args[0], "is_offload_opt", False) for i, arg in enumerate(args): - # need_offload_arg = i >= 2 and isinstance(arg, paddle.Tensor) and is_offload_opt - # if is_offload_opt_cache_master_weight(): - # need_offload_arg = need_offload_arg and i != 8 - # if need_offload_arg: # do not offload parameter and gradient - if ( - i >= 2 and isinstance(arg, paddle.Tensor) and is_offload_opt - ): # do not offload parameter and gradient + need_offload_arg = i >= 2 and isinstance(arg, paddle.Tensor) and is_offload_opt + if is_offload_opt_cache_master_weight(): + need_offload_arg = need_offload_arg and i != 8 + if need_offload_arg: # do not offload parameter and gradient offload(arg) return ret