mindsporeliteexamplestrain_lenetmodeltrain_utils.py注解
** "mindsporeliteexamplestrain_lenetmodeltrain_utils.py"**
一、代码用处
这段代码块主要使用于 训练数据的模型
二、代码注释
"""train_utils."""
import mindspore.nn as nn#导入mindspore包
from mindspore.common.parameter
import ParameterTuple
def TrainWrap(net, loss_fn=None, optimizer=None, weights=None):#定义一个包装函数
"""
TrainWrap
"""
if loss_fn is None:#判断是否有损失
loss_fn = nn.SoftmaxCrossEntropyWithLogits(reduction=‘mean‘, sparse=True)#调用方法使用 Logits 的软最大交叉熵
loss_net = nn.WithLossCell(net, loss_fn)
loss_net.set_train()
if weights is None:
weights = ParameterTuple(net.trainable_params())
if optimizer is None:#优化器
optimizer = nn.Adam(weights, learning_rate=0.003, beta1=0.9, beta2=0.999, eps=1e-5, use_locking=False,
use_nesterov=False, weight_decay=4e-5, loss_scale=1.0)#进行优化
train_net = nn.TrainOneStepCell(loss_net, optimizer)
return train_net#返回训练数据
mindsporeliteexamplestrain_lenetmodeltrain_utils.py注解
[db:回答]
THE END
二维码