Отсечение градиента — это метод оптимизации, который предотвращает взрыв или исчезновение градиента. Он может масштабировать или усекать градиент в процессе обратного распространения ошибки, чтобы поддерживать его в разумных пределах. Существует два распространенных метода обрезки градиента:
В PyTorch вы можете использовать torch.nn.utils.clip_grad_value_ и torch.nn.utils.clip_grad_norm_ Эти две функции используются для реализации отсечения градиента. Они вызываются после завершения расчета градиента и до обновления весов.
torch.nn.utils.clip_grad_value_ — это функция, которая обрезает градиент параметра, чтобы он не превышал заданное значение. Это может предотвратить проблему взрыва или исчезновения градиента и улучшить эффект обучения модели.
import torch
import torch.nn as nn
# Определите простую линейную модель
model = nn.Linear(2, 1)
# Определить оптимизацию
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Определите функцию потерь
criterion = nn.MSELoss()
# Создайте случайный ввод и цель
x = torch.randn(4, 2)
y = torch.randn(4, 1)
# прямое распространение
output = model(x)
# Рассчитать потери
loss = criterion(output, y)
# Обратное распространение ошибки
loss.backward()
# Прежде чем обновлять веса, обрежьте градиент так, чтобы он не превышал 0,5.
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
# Обновить веса
optimizer.step()
В этом коде мы используем torch.nn.utils.clip_grad_value_ Функция, принимающая два параметра: один — параметр модели, другой — обрезанное значение. Он обрезает градиент каждого параметра так, чтобы он был В диапазоне [-0,5, 0,5]. Это может предотвратить слишком большой или слишком маленький градиент, влияющий на сходимость модели.
import torch
import torch.nn as nn
import torch.optim as optim
# Предположим, у нас есть простая полностью подключенная сеть.
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc = nn.Linear(10, 1)
def forward(self, x):
return self.fc(x)
# Создание сети, оптимизация и функция потерь
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()
# Предположим, у нас есть случайные входные данные и цель.
data = torch.randn(5, 10)
target = torch.randn(5, 1)
# Этапы обучения
outputs = model(data) # прямое распространение
loss = loss_fn(outputs, target) # Рассчитать потери
optimizer.zero_grad() # Четкий градиент
loss.backward() # Обратное распространение ошибки,Рассчитать градиент
# Перед этапом оптимизации мы используем обрезку градиента.
nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
optimizer.step() # Обновить параметры модели
В PyTorch,nn.utils.clip_grad_norm_ Функция используется для реализации обрезки градиента. Эта функция сначала вычисляет норму градиента, а затем ограничивает ее максимальным значением. Это может предотвратить проблемы численной нестабильности, вызванные чрезмерными градиентами во время обратного распространения ошибки.
Параметры этой функции следующие:
Рабочий процесс этого кода выглядит следующим образом:
# Для каждого параметра Модели рассчитайте норму L2 ее градиента.
for param in model.parameters():
grad_norm = torch.norm(param.grad, p=2)
print(grad_norm)
В этом коде мы используем torch.norm Функция, принимающая два параметра: один — тензор для вычисления нормы, а другой — тип нормы. Указанный тип нормы — 2, что означает расчет нормы L2. Таким образом можно получить норму L2 градиента каждого параметра.
Отсечение градиента в основном используется для решения проблемы взрыва градиента при обучении нейронных сетей. Вот несколько ситуаций, в которых вы можете использовать обрезку градиента:
(1) Глубокая нейронная сеть:глубинанейронная сеть,Особенно РНН,Проблема градиентного взрыва часто возникает в процессе обучения.。Это потому, что в Обратное распространение ошибкив процессе,Градиент увеличивается экспоненциально с увеличением количества слоев.
(2) Обучение нестабильно:нравиться果你在训练в процессе观察到Модель Убыток внезапно становится очень большим или становитсяNaN,Это может быть вызвано взрывом градиентов. в этом случае,Использование градиентного отсечения может помочь стабилизировать тренировку.
(3) Длинная последовательность тренировок:Обработка данных длинной последовательности(нравиться机устройство翻译或语音识别)час,За счет увеличения длины последовательности,Градиент может быть в Обратное распространение ошибкив процессе累加并导致爆炸。Градиентное отсечение предотвращает это.。
Важно отметить, что хотя отсечение градиента может помочь предотвратить взрыв градиента, оно не решает проблему исчезновения градиентов. Для решения проблем исчезновения градиента может потребоваться использовать другие методы, такие как сети с вентилируемыми рекуррентными единицами (GRU) или сети с длинной краткосрочной памятью (LSTM), или использовать такие методы, как остаточные соединения.
Хотя отсечение градиента является эффективным методом предотвращения взрыва градиента, оно также имеет некоторые потенциальные недостатки:
(1) Выберите подходящий порог обрезки:Выбор подходящего порога ограничения градиента может оказаться затруднительным.。нравиться果阈值设置的太大,Тогда отсечение градиента может оказаться не в состоянии предотвратить взрыв градиента, если порог установлен слишком мал;,Тогда это может ограничить способность Модели к обучению. в целом,Этот порог необходимо определять экспериментально.
(2) Проблема исчезающего градиента не может быть решена:Отсечение градиента предотвращает только взрыв градиента.,Но это не может решить проблему исчезновения градиента. Подробно нейронная сеть,Исчезающий градиент также является распространенной проблемой.,Это затрудняет обучение глубоких частей сети.
(3) Может повлиять на производительность оптимизатора:некоторыйоптимизацияустройство,Такие как АдамиRMSProp,Включены механизмы предотвращения взрыва градиентов. Использование ограничения градиента в этих оптимизациях может повлиять на их внутренние рабочие механизмы.,Тем самым влияя на эффективность тренировок.
(4) Могут быть введены дополнительные вычислительные затраты:вычислитьи应用梯度裁剪需要额外的вычислить资源,Особенно в Модели, где количество параметров очень велико.
Ссылка: Глубокое обучение на графах и LLM для больших моделей.