# Source: https://gitee.com/mindspore-lab/mindcv -- train_step.py
"""Customized TrainOneStepCell.
Supported algorithms are list as follows:
* Exponential Moving Average (EMA)
* Gradient Clipping
* Gradient Accumulation
"""
import mindspore as ms
from mindspore import Parameter, RowTensor, Tensor, boost, nn, ops
from mindspore.boost.grad_accumulation import gradient_accumulation_op, gradient_clear_op
from mindspore.ops import functional as F

__all__ = [
    "GradientAccumulation",
    "TrainStep",
]

_ema_op = ops.MultitypeFuncGraph("ema_op")
_grad_scale = ops.MultitypeFuncGraph("grad_scale")
reciprocal = ops.Reciprocal()


@_ema_op.register("Tensor", "Tensor", "Tensor")
def ema_ops(factor, ema_weight, weight):
    """Apply one EMA update: ema_weight <- ema_weight * factor + weight * (1 - factor)."""
    return F.assign(ema_weight, ema_weight * factor + weight * (1 - factor))


@_grad_scale.register("Tensor", "Tensor")
def grad_scale_tensor(scale, grad):
    """Divide a dense gradient by the loss-scale factor to undo loss scaling."""
    return grad * F.cast(reciprocal(scale), F.dtype(grad))


@_grad_scale.register("Tensor", "RowTensor")
def grad_scale_row_tensor(scale, grad):
    """Divide a sparse (RowTensor) gradient's values by the loss-scale factor to undo loss scaling."""
    return RowTensor(
        grad.indices,
        grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)),
        grad.dense_shape,
    )
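

# A note on loss scaling (illustrative numbers, not from the original file): with a loss scale of
# 1024, the loss is multiplied by 1024 before backpropagation so that small float16 gradients stay
# representable; `TrainStep.construct` below then uses the `_grad_scale` ops above to divide every
# gradient by the same factor, recovering the unscaled gradients before clipping, accumulation, and
# the optimizer update.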


class GradientAccumulation(boost.GradientAccumulation):
    """
    Accumulate the gradients over multiple steps, then call the optimizer to apply the accumulated update.

    Note:
        The implementation is based on mindspore.boost.GradientAccumulation, with the following modifications:

        1. The learning rate is updated at every iteration step, whereas in the original implementation it is
           only updated once every M iteration steps.
        2. For distributed training, during the gradient accumulation stage each device maintains its own
           accumulated gradients; at the parameter update stage, the gradients are first synchronized across
           the devices and then used to perform the update.

        (A worked example with concrete numbers follows this class definition.)

    Args:
        max_accumulation_step (int): Number of steps over which gradients are accumulated.
        optimizer (nn.Cell): Optimizer used.
        grad_reducer (nn.Cell): Gradient reducer, which synchronizes gradients across the devices.
    """

    def __init__(self, max_accumulation_step, optimizer, grad_reducer):
        super().__init__(max_accumulation_step, optimizer)
        self.grad_reducer = grad_reducer

    def construct(self, loss, grads):
        loss = F.depend(
            loss,
            self.hyper_map(
                F.partial(gradient_accumulation_op, self._max_accumulation_step), self._grad_accumulation, grads
            ),
        )
        self._accumulation_step += 1

        if self._accumulation_step >= self._max_accumulation_step:
            # accumulate the gradient at each device, don't sync them until updating the weight
            reduced_grad_accumulation = self.grad_reducer(self._grad_accumulation)
            loss = F.depend(loss, self.optimizer(reduced_grad_accumulation))
            loss = F.depend(loss, self.hyper_map(F.partial(gradient_clear_op), self._grad_accumulation))
            self._accumulation_step = 0
        else:
            # update the learning rate, do not update the parameter
            loss = F.depend(loss, self.optimizer.get_lr())

        return loss
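
# Worked example (assumed numbers, not from the original file): with max_accumulation_step = 4, a
# per-device batch size of 32, and 8 devices, the optimizer only runs every 4th step on the
# synchronized accumulated gradients, so the effective global batch size is 4 * 32 * 8 = 1024,
# while the learning-rate schedule still advances once per micro-batch step (modification 1 above).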


class TrainStep(nn.TrainOneStepWithLossScaleCell):
    """Training step with loss scale.

    The customized TrainOneStepCell also supports the following algorithms:
        * Exponential Moving Average (EMA)
        * Gradient Clipping
        * Gradient Accumulation

    (A usage sketch follows at the end of this file.)
    """

    def __init__(
        self,
        network,
        optimizer,
        scale_sense=1.0,
        ema=False,
        ema_decay=0.9999,
        clip_grad=False,
        clip_value=15.0,
        gradient_accumulation_steps=1,
    ):
        super(TrainStep, self).__init__(network, optimizer, scale_sense)
        self.ema = ema
        self.ema_decay = ema_decay
        self.updates = Parameter(Tensor(0.0, ms.float32))
        self.clip_grad = clip_grad
        self.clip_value = clip_value
        if self.ema:
            self.weights_all = ms.ParameterTuple(list(network.get_parameters()))
            self.ema_weight = self.weights_all.clone("ema", init="same")
        self.accumulate_grad = gradient_accumulation_steps > 1
        if self.accumulate_grad:
            self.gradient_accumulation = GradientAccumulation(
                gradient_accumulation_steps, optimizer, self.grad_reducer
            )

    def ema_update(self):
        self.updates += 1
        # the ema factor is corrected by (1 - exp(-t / T)), where `t` is the update count and `T` is a
        # time constant (2000 steps here), so that the moving average warms up at the start of training
        ema_decay = self.ema_decay * (1 - F.exp(-self.updates / 2000))
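        # Illustrative numbers (not from the original file): at self.updates == 100 the corrected factor is
        # 0.9999 * (1 - exp(-100 / 2000)) ~= 0.0488, so the EMA weights closely track the raw weights early
        # on; as training proceeds the factor approaches 0.9999 and the EMA becomes a long-horizon average.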
        # update trainable parameters
        success = self.hyper_map(F.partial(_ema_op, ema_decay), self.ema_weight, self.weights_all)
        return success

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        scaling_sens = self.scale_sense

        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)

        scaling_sens_filled = ops.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
        grads = self.grad(self.network, weights)(*inputs, scaling_sens_filled)
        grads = self.hyper_map(F.partial(_grad_scale, scaling_sens), grads)
        # todo: When to clip grad? Do we need to clip grad after grad reduction? What if grad accumulation is needed?
        if self.clip_grad:
            grads = ops.clip_by_global_norm(grads, clip_norm=self.clip_value)

        if self.loss_scaling_manager:  # scale_sense = update_cell: Cell --> TrainOneStepWithLossScaleCell.construct
            if self.accumulate_grad:
                # todo: GradientAccumulation only calls grad_reducer at the step where the accumulation is completed,
                #  so here the overflow status is checked before the gradients are reduced. Is this correct?
                # get the overflow buffer
                cond = self.get_overflow_status(status, grads)
                overflow = self.process_loss_scale(cond)
                # if there is no overflow, do optimize
                if not overflow:
                    loss = self.gradient_accumulation(loss, grads)
                    if self.ema:
                        loss = F.depend(loss, self.ema_update())
            else:
                # apply grad reducer on grads
                grads = self.grad_reducer(grads)
                # get the overflow buffer
                cond = self.get_overflow_status(status, grads)
                overflow = self.process_loss_scale(cond)
                # if there is no overflow, do optimize
                if not overflow:
                    loss = F.depend(loss, self.optimizer(grads))
                    if self.ema:
                        loss = F.depend(loss, self.ema_update())
        else:  # scale_sense = loss_scale: Tensor --> TrainOneStepCell.construct
            if self.accumulate_grad:
                loss = self.gradient_accumulation(loss, grads)
            else:
                grads = self.grad_reducer(grads)
                loss = F.depend(loss, self.optimizer(grads))
            if self.ema:
                loss = F.depend(loss, self.ema_update())

        return loss
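

# --- Usage sketch (not part of the original module) ---
# A minimal, hedged example of wiring `TrainStep` into a bare training step. The toy network,
# optimizer, loss-scale settings, and dummy data below are placeholders chosen for illustration;
# only `TrainStep` itself and the standard MindSpore cells come from this file / MindSpore.
# Overflow detection in `TrainOneStepWithLossScaleCell` assumes a GPU/Ascend context.
if __name__ == "__main__":
    import numpy as np

    backbone = nn.Dense(16, 10)
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_with_loss = nn.WithLossCell(backbone, loss_fn)  # TrainStep expects a cell that returns the loss

    optimizer = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.01, momentum=0.9)

    # passing an update cell as `scale_sense` enables dynamic loss scaling
    # (the `self.loss_scaling_manager` branch in `construct`)
    scale_sense = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**16, scale_factor=2, scale_window=2000)

    train_step = TrainStep(
        net_with_loss,
        optimizer,
        scale_sense=scale_sense,
        ema=True,
        clip_grad=True,
        clip_value=15.0,
        gradient_accumulation_steps=1,
    )
    train_step.set_train()

    # a single step on dummy data; real training would loop over a dataset
    data = Tensor(np.random.randn(8, 16).astype(np.float32))
    label = Tensor(np.zeros(8).astype(np.int32))
    print(train_step(data, label))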