A 20-Year-Old Algorithm Can Help Us Understand Transformer Embeddings
Как 20-летний алгоритм помогает понять эмбеддинги трансформеров
Чтобы понять, о чём думает LLM, когда она слышит «Java», нужно разложить внутренние векторы на понятные человеку концепции. Это формулируется как задача dictionary learning: эмбеддинг представляется как разреженная сумма базовых векторов-концептов. В 2023 г. Bricken и др. предложили учить словарь через sparse autoencoder (SAE), отказавшись от классических методов из-за масштабируемости и опасения «слишком сильного» восстановления признаков.
Мы показали, что 20-летний алгоритм KSVD, с минимальными доработками, справляется с миллионами примеров и тысячами измерений. Наивная реализация требовала бы 30 дней; наша версия DB-KSVD ускорена в 10 000 раз и работает 8 минут. DB-KSVD обобщает k-means, но позволяет приписывать объект сразу нескольким «кластерам» (концептам).
Библиотека KSVD.jl доступна из Python:
import torch, juliacall; jl = juliacall.Main
jl.seval("using KSVD")
Y = torch.rand(128, 5000, dtype=torch.float32)
res = jl.ksvd(Y.numpy(), 256, 3) # словарь 256, sparsity 3
На бенчмарке SAEBench DB-KSVD и расширение MatryoshkaDB-KSVD показывают результаты, сравнимые с SAE, по шести метрикам: восстановление эмбеддингов, разделение концептов, их интерпретируемость и др.