5 Star 34 Fork 9

MindSpore Lab / mindcv

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
vit.py 17.42 KB
一键复制 编辑 原始数据 按行查看 历史
import functools
from typing import Callable, Optional
import numpy as np
import mindspore as ms
from mindspore import Parameter, Tensor, nn, ops
from mindspore.common.initializer import TruncatedNormal, XavierUniform, initializer
from .helpers import load_pretrained
from .layers.compatibility import Dropout
from .layers.drop_path import DropPath
from .layers.mlp import Mlp
from .layers.patch_dropout import PatchDropout
from .layers.patch_embed import PatchEmbed
from .layers.pos_embed import resample_abs_pos_embed
from .registry import register_model
# Public API of this module: the backbone class and every model factory defined below.
__all__ = [
    "VisionTransformer",
    "vit_b_16_224",
    "vit_b_16_384",
    "vit_l_16_224",  # with pretrained weights
    "vit_l_16_384",
    "vit_b_32_224",  # with pretrained weights
    "vit_b_32_384",
    "vit_l_32_224",  # with pretrained weights
]
def _cfg(url="", **kwargs):
    """Build a default model configuration dict.

    Args:
        url (str): Download URL of the pretrained checkpoint; "" when none is configured.
        **kwargs: Extra entries merged on top of the defaults. Keys that already exist
            (e.g. ``input_size`` or ``num_classes``) override the default value.

    Returns:
        dict: Configuration with ``url``, ``num_classes``, ``input_size``, ``first_conv``,
        ``classifier`` plus any ``kwargs``.
    """
    return {
        "url": url,
        "num_classes": 1000,
        "input_size": (3, 224, 224),
        "first_conv": "patch_embed.proj",
        "classifier": "head",
        # Merge last so callers can override defaults (e.g. 384x384 input_size).
        **kwargs,
    }
# Per-variant default configurations, keyed by factory name.
# An empty "url" means no pretrained checkpoint URL is configured for that variant.
default_cfgs = {
    "vit_b_16_224": _cfg(url=""),
    "vit_b_16_384": _cfg(
        url="", input_size=(3, 384, 384)
    ),
    "vit_l_16_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_16_224-d2635f8b.ckpt"),
    "vit_l_16_384": _cfg(
        url="", input_size=(3, 384, 384)
    ),
    "vit_b_32_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_b_32_224-4a1c9d8e.ckpt"),
    "vit_b_32_384": _cfg(
        url="", input_size=(3, 384, 384)
    ),
    "vit_l_32_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_32_224-8c8ea164.ckpt"),
}
# TODO: Flash Attention
class Attention(nn.Cell):
    """
    Multi-head self-attention layer. Input and output are both B x N x hidden size.

    Args:
        dim (int): The dimension of input features. Must be divisible by ``num_heads``.
        num_heads (int): The number of attention heads. Default: 8.
        qkv_bias (bool): Specifies whether the qkv linear layer uses a bias vector. Default: True.
        qk_norm (bool): Specifies whether to normalize q and k (per head). Default: False.
        attn_drop (float): The drop rate applied to the attention weights. Default: 0.0.
        proj_drop (float): The drop rate applied to the output projection. Default: 0.0.
        norm_layer (nn.Cell): Normalization layer built over ``head_dim`` and applied to q and k
            when ``qk_norm`` is True. Default: nn.LayerNorm.

    Returns:
        Tensor, output tensor with the same shape as the input.

    Examples:
        >>> attn = Attention(768, 12)
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Cell = nn.LayerNorm,
    ):
        super(Attention, self).__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = Tensor(self.head_dim ** -0.5)

        self.qkv = nn.Dense(dim, dim * 3, has_bias=qkv_bias)
        self.q_norm = norm_layer((self.head_dim,)) if qk_norm else nn.Identity()
        self.k_norm = norm_layer((self.head_dim,)) if qk_norm else nn.Identity()
        self.attn_drop = Dropout(attn_drop)
        self.proj = nn.Dense(dim, dim)
        self.proj_drop = Dropout(proj_drop)

        self.mul = ops.Mul()
        self.reshape = ops.Reshape()
        self.transpose = ops.Transpose()
        self.unstack = ops.Unstack(axis=0)
        self.attn_matmul_v = ops.BatchMatMul()
        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)

    def construct(self, x):
        b, n, c = x.shape
        qkv = self.qkv(x)
        # (b, n, 3*c) -> (b, n, 3, heads, head_dim) -> (3, b, heads, n, head_dim)
        qkv = self.reshape(qkv, (b, n, 3, self.num_heads, self.head_dim))
        qkv = self.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = self.unstack(qkv)
        q, k = self.q_norm(q), self.k_norm(k)

        # The 1/sqrt(head_dim) factor is split as sqrt(scale) on q and on k, which gives the
        # same logits as scaling q @ k^T once; presumably done to keep the intermediate
        # products small under low precision — confirm.
        q = self.mul(q, self.scale**0.5)
        k = self.mul(k, self.scale**0.5)
        attn = self.q_matmul_k(q, k)

        # Softmax is computed in float32 and cast back to the input dtype.
        attn = ops.softmax(attn.astype(ms.float32), axis=-1).astype(attn.dtype)
        attn = self.attn_drop(attn)

        # (b, heads, n, head_dim) -> (b, n, heads, head_dim) -> (b, n, c)
        out = self.attn_matmul_v(attn, v)
        out = self.transpose(out, (0, 2, 1, 3))
        out = self.reshape(out, (b, n, c))
        out = self.proj(out)
        out = self.proj_drop(out)
        return out
class LayerScale(nn.Cell):
    """
    Layer scale: multiplies its input by a learnable vector ``gamma`` of length ``dim``.

    Layer scale helps ViT improve its training dynamics, allowing the training of deeper,
    high-capacity image transformers that benefit from depth.

    Args:
        dim (int): The output dimension of the attention layer or mlp layer.
        init_values (float): Initial value for every entry of ``gamma``. Default: 1e-5.

    Returns:
        Tensor, ``gamma * x`` (``gamma`` broadcast against the last axis of ``x``).

    Examples:
        >>> ls = LayerScale(768, 0.01)
    """

    def __init__(
        self,
        dim: int,
        init_values: float = 1e-5
    ):
        super(LayerScale, self).__init__()
        self.gamma = Parameter(initializer(init_values, dim))

    def construct(self, x):
        return self.gamma * x
class Block(nn.Cell):
    """
    Transformer block: pre-norm attention and MLP sub-layers, each wrapped in a residual
    connection with optional LayerScale and drop path.

    Args:
        dim (int): The dimension of embedding.
        num_heads (int): The number of attention heads. Default: 8.
        mlp_ratio (float): Ratio of the MLP hidden dimension to ``dim``. Default: 4.0.
        qkv_bias (bool): Specifies whether the qkv linear layer uses a bias vector. Default: False.
        qk_norm (bool): Specifies whether attention normalizes q and k. Default: False.
        proj_drop (float): The drop rate of the attention projection and MLP outputs. Default: 0.0.
        attn_drop (float): The drop rate of attention weights. Default: 0.0.
        init_values (Optional[float]): If truthy, LayerScale is applied to both branches with this
            initial value; otherwise no LayerScale is used. Default: None.
        drop_path (float): The drop rate for drop path; 0 disables it. Default: 0.0.
        act_layer (nn.Cell): Activation passed to the MLP. Default: nn.GELU.
        norm_layer (nn.Cell): Normalization applied before the attention and MLP sub-layers
            (and passed to attention for q/k normalization). Default: nn.LayerNorm.
        mlp_layer (Callable): Constructor for the MLP sub-layer. Default: Mlp.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> block = Block(768, 12)
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        mlp_ratio: float = 4.,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_drop: float = 0.,
        attn_drop: float = 0.,
        init_values: Optional[float] = None,
        drop_path: float = 0.,
        act_layer: nn.Cell = nn.GELU,
        norm_layer: nn.Cell = nn.LayerNorm,
        mlp_layer: Callable = Mlp,
    ):
        super(Block, self).__init__()
        self.norm1 = norm_layer((dim,))
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = LayerScale(dim=dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer((dim,))
        # NOTE(review): keyword names assume the Mlp signature in .layers.mlp
        # (in_features / hidden_features / act_layer / drop) — confirm.
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = LayerScale(dim=dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def construct(self, x):
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x
class VisionTransformer(nn.Cell):
    """
    Vision Transformer (ViT) classifier.

    The input is split into patch embeddings by ``embed_layer``. Learned absolute position
    embeddings are added, and an optional class token is prepended. The sequence then passes
    through ``depth`` blocks built by ``block_fn``, is pooled according to ``global_pool``,
    and is projected by ``head``.

    Args:
        image_size (int): Input image size. Default: 224.
        patch_size (int): Patch size. Default: 16.
        in_channels (int): Number of input channels. Default: 3.
        global_pool (str): Pooling before the head: 'token' (first token, requires
            ``class_token``), 'avg' (mean of the non-prefix tokens) or '' (no pooling).
            Default: 'token'.
        embed_dim (int): Embedding dimension. Default: 768.
        depth (int): Number of transformer blocks. Default: 12.
        num_heads (int): Number of attention heads. Default: 12.
        mlp_ratio (float): Ratio of MLP hidden dimension to ``embed_dim``. Default: 4.0.
        qkv_bias (bool): Whether the qkv projection has a bias. Default: True.
        qk_norm (bool): Whether attention normalizes q and k. Default: False.
        drop_rate (float): Dropout applied right before the head. Default: 0.0.
        pos_drop_rate (float): Dropout applied after adding position embeddings. Default: 0.0.
        patch_drop_rate (float): If > 0, PatchDropout is applied after position embedding.
            Default: 0.0.
        proj_drop_rate (float): Drop rate of projection/MLP outputs in each block. Default: 0.0.
        attn_drop_rate (float): Drop rate of attention weights in each block. Default: 0.0.
        drop_path_rate (float): Maximum drop-path rate; the rate grows linearly from 0 over the
            blocks. Default: 0.0.
        weight_init (bool): Whether to run ``_init_weights`` at construction. Default: True.
        init_values (Optional[float]): LayerScale initial value for each block; None disables it.
        no_embed_class (bool): If True, position embeddings cover only the patch tokens and are
            added before the class token is prepended. Default: False.
        pre_norm (bool): If True, normalize tokens before the blocks and build the patch
            embedding without bias. Default: False.
        fc_norm (Optional[bool]): If True, normalize after pooling instead of after the blocks.
            None means True exactly when ``global_pool == 'avg'``.
        dynamic_img_size (bool): Resample position embeddings to each input's patch grid.
            Default: False.
        dynamic_img_pad (bool): Passed to ``embed_layer``; also enables position-embedding
            resampling. Default: False.
        act_layer (nn.Cell): Activation used in the MLPs. Default: nn.GELU.
        embed_layer (Callable): Patch-embedding constructor. Default: PatchEmbed.
        norm_layer (nn.Cell): Normalization layer constructor. Default: nn.LayerNorm.
        mlp_layer (Callable): MLP constructor for each block. Default: Mlp.
        class_token (bool): Whether to prepend a learnable class token. Default: True.
        block_fn (Callable): Constructor for each transformer block. Default: Block.
        num_classes (int): Output size of ``head``; if <= 0 the head is an identity. Default: 1000.

    Returns:
        Tensor, the output of ``head`` (class logits when ``num_classes`` > 0).
    """

    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        global_pool: str = 'token',
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        drop_rate: float = 0.,
        pos_drop_rate: float = 0.,
        patch_drop_rate: float = 0.,
        proj_drop_rate: float = 0.,
        attn_drop_rate: float = 0.,
        drop_path_rate: float = 0.,
        weight_init: bool = True,
        init_values: Optional[float] = None,
        no_embed_class: bool = False,
        pre_norm: bool = False,
        fc_norm: Optional[bool] = None,
        dynamic_img_size: bool = False,
        dynamic_img_pad: bool = False,
        act_layer: nn.Cell = nn.GELU,
        embed_layer: Callable = PatchEmbed,
        norm_layer: nn.Cell = nn.LayerNorm,
        mlp_layer: Callable = Mlp,
        class_token: bool = True,
        block_fn: Callable = Block,
        num_classes: int = 1000,
    ):
        super(VisionTransformer, self).__init__()
        assert global_pool in ('', 'avg', 'token')
        assert class_token or global_pool != 'token'
        use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm

        self.global_pool = global_pool
        self.num_prefix_tokens = 1 if class_token else 0
        self.no_embed_class = no_embed_class
        self.dynamic_img_size = dynamic_img_size
        self.dynamic_img_pad = dynamic_img_pad

        embed_args = {}
        if dynamic_img_size:
            # flatten deferred until after pos embed
            embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
        elif dynamic_img_pad:
            # _pos_embed unpacks (B, H, W, C) whenever dynamic_img_pad is set, so the patch
            # embedding must also keep NHWC output here.
            embed_args.update(dict(output_fmt='NHWC'))
        # NOTE(review): keyword names assume the PatchEmbed signature in .layers.patch_embed
        # (image_size / patch_size / in_chans / embed_dim / bias / dynamic_img_pad) — confirm.
        self.patch_embed = embed_layer(
            image_size=image_size,
            patch_size=patch_size,
            in_chans=in_channels,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
            dynamic_img_pad=dynamic_img_pad,
            **embed_args,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = Parameter(initializer(TruncatedNormal(0.02), (1, 1, embed_dim))) if class_token else None
        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        self.pos_embed = Parameter(initializer(TruncatedNormal(0.02), (1, embed_len, embed_dim)))
        self.pos_drop = Dropout(pos_drop_rate)
        if patch_drop_rate > 0:
            # NOTE(review): argument names assume the PatchDropout signature — confirm.
            self.patch_drop = PatchDropout(
                patch_drop_rate,
                num_prefix_tokens=self.num_prefix_tokens,
            )
        else:
            self.patch_drop = nn.Identity()
        self.norm_pre = norm_layer((embed_dim,)) if pre_norm else nn.Identity()

        # Stochastic depth: drop-path rate increases linearly from 0 to drop_path_rate.
        dpr = [x.item() for x in np.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.CellList([
            block_fn(
                dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,
                attn_drop=attn_drop_rate, proj_drop=proj_drop_rate,
                mlp_ratio=mlp_ratio, drop_path=dpr[i], init_values=init_values,
                act_layer=act_layer, norm_layer=norm_layer, mlp_layer=mlp_layer,
            ) for i in range(depth)
        ])
        self.norm = norm_layer((embed_dim,)) if not use_fc_norm else nn.Identity()

        # Classifier head
        self.fc_norm = norm_layer((embed_dim,)) if use_fc_norm else nn.Identity()
        self.head_drop = Dropout(drop_rate)
        self.head = nn.Dense(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if weight_init:
            self._init_weights()

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    def _init_weights(self):
        """Re-initialize parameters in place.

        The patch-embedding projection weight is Xavier-uniform initialized as a 2-D
        (shape[0], prod(shape[1:])) matrix and reshaped back. Every nn.Dense gets
        Xavier-uniform weights and zero bias; every nn.LayerNorm gets gamma=1, beta=0.
        """
        w = self.patch_embed.proj.weight
        w_shape_flatted = (w.shape[0], functools.reduce(lambda x, y: x * y, w.shape[1:]))
        w.set_data(initializer(XavierUniform(), w_shape_flatted, w.dtype).reshape(w.shape))
        for _, cell in self.cells_and_names():
            # The initializer results must be written back via set_data; merely building
            # them leaves the parameters unchanged.
            if isinstance(cell, nn.Dense):
                cell.weight.set_data(initializer(XavierUniform(), cell.weight.shape, cell.weight.dtype))
                if cell.bias is not None:
                    cell.bias.set_data(initializer('zeros', cell.bias.shape, cell.bias.dtype))
            elif isinstance(cell, nn.LayerNorm):
                cell.gamma.set_data(initializer('ones', cell.gamma.shape, cell.gamma.dtype))
                cell.beta.set_data(initializer('zeros', cell.beta.shape, cell.beta.dtype))

    def _pos_embed(self, x):
        """Add position embeddings (and the class token, if any), then apply pos dropout."""
        if self.dynamic_img_size or self.dynamic_img_pad:
            # bhwc format
            B, H, W, C = x.shape
            # NOTE(review): first argument reconstructed as the learned table — confirm
            # against resample_abs_pos_embed's signature.
            pos_embed = resample_abs_pos_embed(
                self.pos_embed,
                (H, W),
                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
            )
            x = ops.reshape(x, (B, -1, C))
        else:
            pos_embed = self.pos_embed

        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + pos_embed
            if self.cls_token is not None:
                cls_tokens = ops.broadcast_to(self.cls_token, (x.shape[0], -1, -1))
                cls_tokens = cls_tokens.astype(x.dtype)
                x = ops.concat((cls_tokens, x), axis=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if self.cls_token is not None:
                cls_tokens = ops.broadcast_to(self.cls_token, (x.shape[0], -1, -1))
                cls_tokens = cls_tokens.astype(x.dtype)
                x = ops.concat((cls_tokens, x), axis=1)
            x = x + pos_embed
        return self.pos_drop(x)

    def forward_features(self, x):
        """Run patch embedding, position embedding and all blocks; return the token sequence."""
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x

    def forward_head(self, x):
        """Pool the token sequence according to ``global_pool`` and apply the head."""
        if self.global_pool:
            x = x[:, self.num_prefix_tokens:].mean(axis=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.fc_norm(x)
        x = self.head_drop(x)
        x = self.head(x)
        return x

    def construct(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
@register_model
def vit_b_16_224(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
    """Build ViT-B/16 for 224x224 inputs (embed_dim=768, depth=12, num_heads=12).

    Extra ``kwargs`` are forwarded to ``VisionTransformer``. When ``pretrained`` is True,
    weights are loaded via ``load_pretrained`` using ``default_cfgs["vit_b_16_224"]``,
    whose URL is currently empty.
    """
    default_cfg = default_cfgs["vit_b_16_224"]
    model = VisionTransformer(
        image_size=224, patch_size=16, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12,
        num_classes=num_classes, **kwargs
    )
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    return model
@register_model
def vit_b_16_384(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
    """Build ViT-B/16 for 384x384 inputs (embed_dim=768, depth=12, num_heads=12).

    Extra ``kwargs`` are forwarded to ``VisionTransformer``. When ``pretrained`` is True,
    weights are loaded via ``load_pretrained`` using ``default_cfgs["vit_b_16_384"]``,
    whose URL is currently empty.
    """
    default_cfg = default_cfgs["vit_b_16_384"]
    model = VisionTransformer(
        image_size=384, patch_size=16, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12,
        num_classes=num_classes, **kwargs
    )
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    return model
@register_model
def vit_l_16_224(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
    """Build ViT-L/16 for 224x224 inputs (embed_dim=1024, depth=24, num_heads=16).

    Extra ``kwargs`` are forwarded to ``VisionTransformer``. When ``pretrained`` is True,
    weights are loaded via ``load_pretrained`` using ``default_cfgs["vit_l_16_224"]``.
    """
    default_cfg = default_cfgs["vit_l_16_224"]
    model = VisionTransformer(
        image_size=224, patch_size=16, in_channels=in_channels, embed_dim=1024, depth=24, num_heads=16,
        num_classes=num_classes, **kwargs
    )
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    return model
@register_model
def vit_l_16_384(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
    """Build ViT-L/16 for 384x384 inputs (embed_dim=1024, depth=24, num_heads=16).

    Extra ``kwargs`` are forwarded to ``VisionTransformer``. When ``pretrained`` is True,
    weights are loaded via ``load_pretrained`` using ``default_cfgs["vit_l_16_384"]``,
    whose URL is currently empty.
    """
    default_cfg = default_cfgs["vit_l_16_384"]
    model = VisionTransformer(
        image_size=384, patch_size=16, in_channels=in_channels, embed_dim=1024, depth=24, num_heads=16,
        num_classes=num_classes, **kwargs
    )
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    return model
@register_model
def vit_b_32_224(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
    """Build ViT-B/32 for 224x224 inputs (embed_dim=768, depth=12, num_heads=12).

    Extra ``kwargs`` are forwarded to ``VisionTransformer``. When ``pretrained`` is True,
    weights are loaded via ``load_pretrained`` using ``default_cfgs["vit_b_32_224"]``.
    """
    default_cfg = default_cfgs["vit_b_32_224"]
    model = VisionTransformer(
        image_size=224, patch_size=32, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12,
        num_classes=num_classes, **kwargs
    )
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    return model
@register_model
def vit_b_32_384(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
    """Build ViT-B/32 for 384x384 inputs (embed_dim=768, depth=12, num_heads=12).

    Extra ``kwargs`` are forwarded to ``VisionTransformer``. When ``pretrained`` is True,
    weights are loaded via ``load_pretrained`` using ``default_cfgs["vit_b_32_384"]``,
    whose URL is currently empty.
    """
    default_cfg = default_cfgs["vit_b_32_384"]
    model = VisionTransformer(
        image_size=384, patch_size=32, in_channels=in_channels, embed_dim=768, depth=12, num_heads=12,
        num_classes=num_classes, **kwargs
    )
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    return model
@register_model
def vit_l_32_224(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
    """Build ViT-L/32 for 224x224 inputs (embed_dim=1024, depth=24, num_heads=16).

    Extra ``kwargs`` are forwarded to ``VisionTransformer``. When ``pretrained`` is True,
    weights are loaded via ``load_pretrained`` using ``default_cfgs["vit_l_32_224"]``.
    """
    default_cfg = default_cfgs["vit_l_32_224"]
    model = VisionTransformer(
        image_size=224, patch_size=32, in_channels=in_channels, embed_dim=1024, depth=24, num_heads=16,
        num_classes=num_classes, **kwargs
    )
    if pretrained:
        load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
    return model
