# ShopTRAINING/server/models/optimized_kan_forecaster.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class OptimizedKANLinear(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(OptimizedKANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                nn.init.kaiming_uniform_(
                    self.spline_scaler, a=math.sqrt(5) * self.scale_spline
                )
    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline basis functions for the given input tensor
        (memory-optimized version).
        """
        # Ensure the input is a 2D tensor
        if x.dim() > 2:
            x = x.reshape(-1, self.in_features)
        assert x.size(-1) == self.in_features

        grid: torch.Tensor = self.grid  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points
        (memory-optimized version).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        # Compute the least-squares solution in a memory-efficient way
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()
    @property
    def scaled_spline_weight(self):
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )
    def forward(self, x: torch.Tensor):
        """
        Forward pass (memory-optimized version).
        """
        # Remember the original shape and flatten the input to 2D
        if x.dim() > 2:
            original_shape = x.shape
            x = x.reshape(-1, self.in_features)
        else:
            original_shape = None
        assert x.size(-1) == self.in_features

        # The base activation and the B-spline bases are each computed only once
        base_activated = self.base_activation(x)
        splines = self.b_splines(x)

        # Base (residual) output
        base_output = F.linear(base_activated, self.base_weight)
        # Spline output, computed in a memory-efficient way
        spline_output = F.linear(
            splines.view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        # Combine the two branches
        output = base_output + spline_output

        # Restore the original shape if needed
        if original_shape is not None:
            output = output.view(*original_shape[:-1], self.out_features)
        return output
    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        """
        Update the spline grid from the input distribution (memory-optimized version).
        """
        # Flatten the input to 2D
        if x.dim() > 2:
            x = x.reshape(-1, self.in_features)
        assert x.size(-1) == self.in_features
        batch = x.size(0)

        # Compute the spline output in a memory-efficient way
        splines = self.b_splines(x)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        # Batched matrix multiplication gives the unreduced spline output
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # Sort each channel separately to collect the data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        # Blend the uniform and adaptive grids
        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        # Use torch.cat rather than torch.concatenate for better compatibility
        grid = torch.cat(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))
    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss (memory-optimized version).
        """
        # More efficient approximation of the L1 regularization
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        # Entropy regularization
        p = l1_fake / (regularization_loss_activation + 1e-8)
        regularization_loss_entropy = -torch.sum(p * torch.log(p + 1e-8))
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )
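

# Illustrative sketch (not part of the original module): a minimal example of
# exercising a single OptimizedKANLinear layer. The layer sizes and batch size
# below are arbitrary placeholders chosen for the sketch.
def _demo_optimized_kan_linear():
    layer = OptimizedKANLinear(in_features=8, out_features=16)
    x = torch.randn(32, 8)                     # (batch, in_features)
    y = layer(x)                               # (batch, out_features)
    reg = layer.regularization_loss(1.0, 1.0)  # scalar penalty usable during training
    return y.shape, reg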


class OptimizedKAN(nn.Module):
    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(OptimizedKAN, self).__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        self.layers = nn.ModuleList()
        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                OptimizedKANLinear(
                    in_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
            )
    def forward(self, x: torch.Tensor, update_grid=False):
        # Each OptimizedKANLinear layer reshapes 3D inputs to 2D and back internally
        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )
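

# A minimal sketch, assuming flat feature vectors per sample, of driving the
# stacked OptimizedKAN: passing update_grid=True re-fits each layer's spline
# grid to the current batch before the forward pass. Sizes are illustrative only.
def _demo_optimized_kan():
    kan = OptimizedKAN(layers_hidden=[8, 32, 4])
    x = torch.randn(64, 8)
    y = kan(x, update_grid=True)  # refresh grids from this batch, then forward
    return y.shape                # expected: torch.Size([64, 4])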


class OptimizedKANForecaster(nn.Module):
    def __init__(
        self,
        input_features,
        hidden_sizes=[64, 128, 64],
        output_size=1,
        grid_size=5,
        spline_order=3,
        dropout_rate=0.1,
        output_sequence_length=1,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(OptimizedKANForecaster, self).__init__()
        # Input projection layer
        self.input_projection = nn.Linear(input_features, hidden_sizes[0])
        # Optimized KAN layers
        layers_hidden = [hidden_sizes[0]] + hidden_sizes + [hidden_sizes[-1]]
        self.kan = OptimizedKAN(
            layers_hidden=layers_hidden,
            grid_size=grid_size,
            spline_order=spline_order,
            scale_noise=scale_noise,
            scale_base=scale_base,
            scale_spline=scale_spline,
            base_activation=base_activation,
            grid_eps=grid_eps,
            grid_range=grid_range,
        )
        # Output layer (its width is set by output_sequence_length; output_size is currently unused)
        self.output_layer = nn.Linear(hidden_sizes[-1], output_sequence_length)
        self.dropout = nn.Dropout(dropout_rate)
        self.output_sequence_length = output_sequence_length
    def forward(self, x, update_grid=False):
        """
        Forward pass (memory-optimized version).
        x shape: [batch_size, seq_length, features]
        """
        batch_size, seq_len, features = x.shape

        # Apply the input projection
        x_reshaped = x.reshape(-1, features)
        x = self.input_projection(x_reshaped)
        x = F.relu(x)
        x = self.dropout(x)

        # Reshape back to a 3D tensor to keep batch and time-step information
        x = x.view(batch_size, seq_len, -1)

        # Pass through the KAN network; dimension handling is done internally
        x = self.kan(x, update_grid)

        # Aggregate the time dimension (take the last time step)
        x = x[:, -1, :]

        # Output layer
        x = self.output_layer(x)
        return x
    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss of the KAN network.
        """
        return self.kan.regularization_loss(
            regularize_activation=regularize_activation,
            regularize_entropy=regularize_entropy,
        )
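

# A hedged, self-contained smoke test of the forecaster, assuming sliding-window
# inputs of shape (batch, seq_length, features). The tensor sizes, dummy loss, and
# regularization weight below are placeholders for the sketch, not project values.
if __name__ == "__main__":
    model = OptimizedKANForecaster(
        input_features=6, hidden_sizes=[32, 64, 32], output_sequence_length=7
    )
    x = torch.randn(16, 30, 6)        # 16 windows of 30 time steps x 6 features
    y_pred = model(x)                 # (16, 7): 7-step-ahead forecast per window
    loss = y_pred.pow(2).mean() + 1e-4 * model.regularization_loss()
    loss.backward()                   # gradients flow through the spline weights too
    print(y_pred.shape, float(loss))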