import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class OptimizedKANLinear(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(OptimizedKANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                nn.init.kaiming_uniform_(
                    self.spline_scaler, a=math.sqrt(5) * self.scale_spline
                )

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline basis functions for the given input tensor
        (memory-optimized version).

        Uses the Cox-de Boor recursion: the order-0 bases are interval
        indicators, and each pass of the loop raises the spline order by one.
        Returns a tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        # Flatten any leading dimensions so the input is 2D; callers are
        # responsible for restoring the original shape.
        if x.dim() > 2:
            x = x.reshape(-1, self.in_features)

        assert x.size(-1) == self.in_features

        grid: torch.Tensor = self.grid
        x = x.unsqueeze(-1)
        # Order-0 bases: indicator of which grid interval each input falls into.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the spline coefficients that interpolate the given points
        (memory-optimized version).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)

        # Solve the batched least-squares problem A @ coeff = B.
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        """
        Forward pass (memory-optimized version).
        """
        # Save the original shape and flatten the input to 2D.
        if x.dim() > 2:
            original_shape = x.shape
            x = x.reshape(-1, self.in_features)
        else:
            original_shape = None

        assert x.size(-1) == self.in_features

        # Compute the base activation and the B-spline bases (once each).
        base_activated = self.base_activation(x)
        splines = self.b_splines(x)

        # Base branch: a plain linear layer on the activated input.
        base_output = F.linear(base_activated, self.base_weight)

        # Spline branch: flatten the bases and apply the (scaled) spline
        # weights as a single linear layer, avoiding large intermediates.
        spline_output = F.linear(
            splines.view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )

        # Combine the two branches.
        output = base_output + spline_output

        # Restore the original shape if needed.
        if original_shape is not None:
            output = output.view(*original_shape[:-1], self.out_features)

        return output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        """
        Update the spline grid from the input distribution (memory-optimized
        version).
        """
        # Flatten any leading dimensions so the input is 2D.
        if x.dim() > 2:
            x = x.reshape(-1, self.in_features)

        assert x.size(-1) == self.in_features
        batch = x.size(0)

        # Compute the unreduced spline outputs in a memory-friendly way.
        splines = self.b_splines(x)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)

        # Batched matrix multiply for the unreduced spline outputs.
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # Sort each channel separately to capture the data distribution.
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        # Blend the uniform and adaptive grids, controlled by grid_eps.
        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive

        # Extend the grid by spline_order knots on each side; use torch.cat
        # rather than torch.concatenate for wider compatibility.
        grid = torch.cat(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss (memory-optimized version).

        Uses the mean absolute spline weight as a cheap stand-in for the L1
        norm of the learned activations, plus an entropy term over the
        resulting distribution.
        """
        # More efficient L1-style regularization on the spline weights.
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()

        # Entropy regularization over the normalized weight magnitudes.
        p = l1_fake / (regularization_loss_activation + 1e-8)
        regularization_loss_entropy = -torch.sum(p * torch.log(p + 1e-8))

        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


class OptimizedKAN(nn.Module):
    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(OptimizedKAN, self).__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        self.layers = nn.ModuleList()
        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                OptimizedKANLinear(
                    in_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
            )

    def forward(self, x: torch.Tensor, update_grid=False):
        # Each layer flattens and restores leading dimensions internally,
        # so 3D inputs pass straight through.
        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )


class OptimizedKANForecaster(nn.Module):
    def __init__(
        self,
        input_features,
        hidden_sizes=[64, 128, 64],
        output_size=1,
        grid_size=5,
        spline_order=3,
        dropout_rate=0.1,
        output_sequence_length=1,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(OptimizedKANForecaster, self).__init__()

        # Input projection layer.
        self.input_projection = nn.Linear(input_features, hidden_sizes[0])

        # Optimized KAN layers. The first and last hidden sizes are repeated
        # so the KAN both starts and ends at the projected width.
        layers_hidden = [hidden_sizes[0]] + hidden_sizes + [hidden_sizes[-1]]
        self.kan = OptimizedKAN(
            layers_hidden=layers_hidden,
            grid_size=grid_size,
            spline_order=spline_order,
            scale_noise=scale_noise,
            scale_base=scale_base,
            scale_spline=scale_spline,
            base_activation=base_activation,
            grid_eps=grid_eps,
            grid_range=grid_range,
        )

        # Output head. Note that `output_size` is accepted but currently
        # unused: the head maps to `output_sequence_length` values per sample.
        self.output_layer = nn.Linear(hidden_sizes[-1], output_sequence_length)

        self.dropout = nn.Dropout(dropout_rate)
        self.output_sequence_length = output_sequence_length

    def forward(self, x, update_grid=False):
        """
        Forward pass (memory-optimized version).

        x shape: [batch_size, seq_length, features]
        """
        batch_size, seq_len, features = x.shape

        # Input projection.
        x_reshaped = x.reshape(-1, features)
        x = self.input_projection(x_reshaped)
        x = F.relu(x)
        x = self.dropout(x)

        # Reshape back to 3D to keep batch and time-step information.
        x = x.view(batch_size, seq_len, -1)

        # Run through the KAN network, which handles the reshaping internally.
        x = self.kan(x, update_grid)

        # Aggregate over time by taking the last time step.
        x = x[:, -1, :]

        # Output head.
        x = self.output_layer(x)

        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss of the underlying KAN network.
        """
        return self.kan.regularization_loss(
            regularize_activation=regularize_activation,
            regularize_entropy=regularize_entropy,
        )
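

# Minimal usage sketch (a smoke test, not part of the module's API): all
# sizes below are illustrative assumptions, chosen only to show the expected
# tensor shapes for the layer and the forecaster.
if __name__ == "__main__":
    torch.manual_seed(0)

    # Single layer: maps 8 input features to 16 outputs, with one learned
    # spline per input-output pair.
    layer = OptimizedKANLinear(8, 16)
    out = layer(torch.randn(4, 8))
    print(out.shape)  # torch.Size([4, 16])

    # Full forecaster: (batch, seq_len, features) -> (batch, horizon).
    model = OptimizedKANForecaster(
        input_features=8,
        hidden_sizes=[32, 64, 32],
        output_sequence_length=4,
    )
    x = torch.randn(2, 16, 8)
    y = model(x)
    print(y.shape)  # torch.Size([2, 4])

    # The regularization term is typically added to the task loss, e.g.:
    # loss = F.mse_loss(y, target) + 1e-4 * model.regularization_loss()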