В нашей предыдущей статье мы рассказали, что исследователи запустили Mamba, новую архитектуру, которая бросает вызов Transformer.
Их исследования показывают, что Mamba — это модель пространства состояний (SSM), которая демонстрирует превосходную производительность в различных модальностях, таких как язык, аудио и временные ряды. Чтобы проиллюстрировать эту точку зрения, исследователи провели эксперименты по моделированию языка с использованием модели Mamba-3B. Модель превосходит другие модели, основанные на Трансформаторах того же размера, а во время предварительного обучения и последующей оценки она работает так же хорошо, как модели Трансформеров, в два раза превышающие ее размер.
Mamba уникальна своими возможностями быстрой обработки, выборочным слоем SSM и удобным для аппаратного обеспечения дизайном, вдохновленным FlashAttention. Эти особенности делают Мамбу превосходящей Трансформера (Трансформер не имеет традиционного внимания и блоков MLP).
Многие люди хотят протестировать эффект Mamba самостоятельно, поэтому в этой статье собран код, который может полностью запускать Mamba на Colab. Код также использует официальную модель 3B Mamba для реального тестирования работы.
Сначала устанавливаем зависимости, которые представлены на официальном сайте:
!pip install causal-conv1d==1.0.0
!pip install mamba-ssm==1.0.1
Затем напрямую используйте библиотеку трансформеров, чтобы прочитать предварительно обученную Мамбу-3Б.
import torch
import os
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained(os.path.expanduser("state-spaces/mamba-2.8b"), device="cuda", dtype=torch.bfloat16)
Как видите, модель 3b имеет 11G.
Затем идет тестовый контент.
tokens = tokenizer("What is the meaning of life", return_tensors="pt")
input_ids = tokens.input_ids.to(device="cuda")
max_length = input_ids.shape[1] + 80
fn = lambda: model.generate(
input_ids=input_ids, max_length=max_length, cg=True,
return_dict_in_generate=True, output_scores=True,
enable_timing=False, temperature=0.1, top_k=10, top_p=0.1,)
out = fn()
print(tokenizer.decode(out[0][0]))
Вот еще один пример чата
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained("havenhq/mamba-chat")
tokenizer.eos_token = "<|endoftext|>"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.chat_template = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta").chat_template
model = MambaLMHeadModel.from_pretrained("havenhq/mamba-chat", device="cuda", dtype=torch.float16)
messages = []
user_message = """
What is the date for announcement
On August 10 said that its arm JSW Neo Energy has agreed to buy a portfolio of 1753 mega watt renewable energy generation capacity from Mytrah Energy India Pvt Ltd for Rs 10,530 crore.
"""
messages.append(dict(role="user",content=user_message))
input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to("cuda")
out = model.generate(input_ids=input_ids, max_length=2000, temperature=0.9, top_p=0.7, eos_token_id=tokenizer.eos_token_id)
decoded = tokenizer.batch_decode(out)
messages.append(dict(role="assistant",content=decoded[0].split("<|assistant|>\n")[-1]))
print("Model:", decoded[0].split("<|assistant|>\n")[-1])
Здесь я организовал весь код в блокноте Colab. Желающие могут использовать его напрямую:
https://colab.research.google.com/drive/1JyZpvncfSvtFZNOr3TU17Ff0BW5Nd_my?usp=sharing