Skip to content

Кастомные middleware в Axum

Middleware в Axum позволяют выполнять действия до и после обработки запроса обработчиками. В этом разделе мы рассмотрим, как создавать собственные middleware для разных сценариев использования.

Содержание

Основы middleware в Axum

В Axum middleware реализуются с помощью комбинации слоёв Tower и обработчиков. Есть два основных подхода к созданию middleware:

  1. Функциональные middleware — простые функции, которые обрабатывают запросы
  2. Tower middleware — более мощные, на основе трейта Service

Функциональные middleware

Самый простой способ создать middleware в Axum — использовать функции:

rust
use axum::{
    middleware::{self, Next},
    response::Response,
    routing::get,
    Router,
    body::Body,
    http::Request,
};
use std::time::Instant;

// Middleware для замера времени обработки запроса
async fn timing_middleware(
    req: Request<Body>,
    next: Next<Body>,
) -> Response {
    let start = Instant::now();
    
    // Передача запроса следующему обработчику в цепочке
    let response = next.run(req).await;
    
    // Измерение после выполнения запроса
    let latency = start.elapsed();
    println!("Запрос обработан за: {:?}", latency);
    
    // Возврат полученного ответа
    response
}

// Применение middleware к маршрутизатору
let app = Router::new()
    .route("/", get(|| async { "Привет, мир!" }))
    .layer(middleware::from_fn(timing_middleware));

Middleware на основе Tower Service

Для более сложных сценариев можно создать middleware, реализующий трейт Service:

rust
use axum::{
    body::Body,
    http::{Request, StatusCode},
    response::{IntoResponse, Response},
    Router,
    routing::get,
};
use futures::future::BoxFuture;
use std::{
    task::{Context, Poll},
    sync::{Arc, Mutex},
};
use tower::{Layer, Service};

// Структура для хранения статистики
#[derive(Default, Clone)]
struct RequestCounter {
    count: Arc<Mutex<usize>>,
}

impl RequestCounter {
    fn new() -> Self {
        Self {
            count: Arc::new(Mutex::new(0)),
        }
    }
    
    fn increment(&self) -> usize {
        let mut count = self.count.lock().unwrap();
        *count += 1;
        *count
    }
    
    fn get_count(&self) -> usize {
        *self.count.lock().unwrap()
    }
}

// Структура middleware
struct CounterMiddleware<S> {
    inner: S,
    counter: RequestCounter,
}

// Реализация Service для нашего middleware
impl<S, ReqBody> Service<Request<ReqBody>> for CounterMiddleware<S>
where
    S: Service<Request<ReqBody>, Response = Response> + Clone + Send + 'static,
    S::Future: Send + 'static,
    ReqBody: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
        // Увеличиваем счетчик
        let count = self.counter.increment();
        println!("Запрос #{}: {}", count, request.uri());
        
        // Клонируем внутренний сервис для использования в future
        let mut inner = self.inner.clone();
        
        Box::pin(async move {
            // Вызываем внутренний сервис
            let response = inner.call(request).await?;
            Ok(response)
        })
    }
}

// Layer для создания нашего middleware
#[derive(Clone)]
struct CounterLayer {
    counter: RequestCounter,
}

impl CounterLayer {
    fn new() -> Self {
        Self {
            counter: RequestCounter::new(),
        }
    }
}

impl<S> Layer<S> for CounterLayer {
    type Service = CounterMiddleware<S>;

    fn layer(&self, service: S) -> Self::Service {
        CounterMiddleware {
            inner: service,
            counter: self.counter.clone(),
        }
    }
}

// Обработчики для демонстрации
async fn hello() -> &'static str {
    "Привет, мир!"
}

async fn get_count(Extension(counter): Extension<RequestCounter>) -> String {
    format!("Всего запросов: {}", counter.get_count())
}

// Создание приложения с middleware
let counter = RequestCounter::new();
let counter_layer = CounterLayer {
    counter: counter.clone(),
};

let app = Router::new()
    .route("/", get(hello))
    .route("/count", get(get_count))
    .layer(counter_layer)
    .layer(Extension(counter));

Совместное использование данных

Для передачи данных между middleware и обработчиками можно использовать Extension:

rust
use axum::{
    extract::Extension,
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::get,
    Router,
    body::Body,
    http::Request,
};

// Структура для передачи данных
#[derive(Clone, Default)]
struct RequestMetadata {
    user_agent: Option<String>,
    request_id: String,
    path: String,
}

// Middleware, добавляющий метаданные в запрос
async fn metadata_middleware(
    mut req: Request<Body>,
    next: Next<Body>,
) -> Response {
    let path = req.uri().path().to_string();
    
    let user_agent = req
        .headers()
        .get(header::USER_AGENT)
        .and_then(|v| v.to_str().ok())
        .map(|s| s.to_string());
    
    // Генерируем уникальный ID запроса
    let request_id = uuid::Uuid::new_v4().to_string();
    
    // Создаем структуру метаданных
    let metadata = RequestMetadata {
        user_agent,
        request_id,
        path,
    };
    
    // Добавляем в расширения запроса
    req.extensions_mut().insert(metadata);
    
    // Продолжаем выполнение запроса
    next.run(req).await
}

// Обработчик, использующий метаданные
async fn handler(
    Extension(metadata): Extension<RequestMetadata>,
) -> impl IntoResponse {
    format!(
        "Path: {}\nRequest ID: {}\nUser-Agent: {}",
        metadata.path,
        metadata.request_id,
        metadata.user_agent.unwrap_or_else(|| "Unknown".to_string())
    )
}

// Создание приложения
let app = Router::new()
    .route("/", get(handler))
    .layer(middleware::from_fn(metadata_middleware));

Примеры практичных middleware

1. Middleware для логирования

rust
use axum::{
    body::{Body, Bytes},
    http::{Request, Response, StatusCode},
    middleware::Next,
};
use futures::StreamExt;
use std::time::Instant;
use tower::ServiceBuilder;
use tracing::{info, error};

// Middleware для логирования
async fn request_logging_middleware(
    req: Request<Body>,
    next: Next<Body>,
) -> Result<Response, StatusCode> {
    let method = req.method().clone();
    let uri = req.uri().clone();
    
    // Получаем IP-адрес клиента из заголовков
    let forwarded_for = req
        .headers()
        .get("X-Forwarded-For")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("-");
    
    let start_time = Instant::now();
    info!(
        ">> {} {} - от {}",
        method, uri, forwarded_for
    );
    
    // Передаем управление следующему обработчику
    let response = next.run(req).await;
    
    // Логируем информацию о результате запроса
    let status = response.status();
    let latency = start_time.elapsed();
    
    if status.is_success() {
        info!(
            "<< {} {} - статус {} (за {:?})",
            method, uri, status, latency
        );
    } else {
        error!(
            "<< {} {} - статус {} (за {:?})",
            method, uri, status, latency
        );
    }
    
    Ok(response)
}

// Применение с помощью ServiceBuilder
let app = Router::new()
    // маршруты...
    .layer(
        ServiceBuilder::new()
            .layer(middleware::from_fn(request_logging_middleware))
            // другие middleware...
    );

2. Middleware для CORS

rust
use axum::{
    http::{header, Method, HeaderValue, Response, StatusCode},
    middleware::Next,
};

// Middleware для настройки CORS
async fn cors_middleware(
    req: Request<Body>,
    next: Next<Body>,
) -> Result<Response<Body>, StatusCode> {
    // Проверяем, является ли запрос preflight
    let is_preflight = req.method() == Method::OPTIONS;
    
    // Выполняем обычный запрос, если это не preflight
    let mut response = if is_preflight {
        Response::builder()
            .status(StatusCode::NO_CONTENT)
            .body(Body::empty())
            .unwrap()
    } else {
        next.run(req).await
    };
    
    // Добавляем CORS заголовки к ответу
    let headers = response.headers_mut();
    
    headers.insert(
        header::ACCESS_CONTROL_ALLOW_ORIGIN,
        HeaderValue::from_static("*"),
    );
    
    headers.insert(
        header::ACCESS_CONTROL_ALLOW_METHODS,
        HeaderValue::from_static("GET, POST, PUT, DELETE, OPTIONS"),
    );
    
    headers.insert(
        header::ACCESS_CONTROL_ALLOW_HEADERS,
        HeaderValue::from_static("Content-Type, Authorization"),
    );
    
    Ok(response)
}

3. Middleware для аутентификации

rust
use axum::{
    extract::TypedHeader,
    headers::{authorization::Bearer, Authorization},
    http::{Request, StatusCode},
    middleware::Next,
    response::{Response, IntoResponse},
    RequestPartsExt,
};
use jsonwebtoken::{decode, DecodingKey, Validation, Algorithm};
use serde::{Deserialize, Serialize};

// Структура данных JWT
#[derive(Debug, Serialize, Deserialize)]
struct Claims {
    sub: String,
    exp: usize,
    role: String,
}

// Middleware для проверки JWT
async fn auth_middleware(
    mut req: Request<Body>,
    next: Next<Body>,
) -> Result<Response, impl IntoResponse> {
    // Извлекаем заголовок Authorization
    let auth_header = req
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(|value| {
            if value.starts_with("Bearer ") {
                Some(value[7..].to_string())
            } else {
                None
            }
        });
    
    // Проверяем наличие и валидность токена
    let token = match auth_header {
        Some(token) => token,
        None => return Err((
            StatusCode::UNAUTHORIZED,
            "Отсутствует токен авторизации".to_string(),
        )),
    };
    
    // Проверяем JWT
    let claims = match verify_token(&token) {
        Ok(claims) => claims,
        Err(e) => return Err((
            StatusCode::UNAUTHORIZED,
            format!("Недействительный токен: {}", e),
        )),
    };
    
    // Добавляем информацию о пользователе в расширения запроса
    req.extensions_mut().insert(claims);
    
    // Продолжаем выполнение запроса
    Ok(next.run(req).await)
}

// Функция проверки JWT токена
fn verify_token(token: &str) -> Result<Claims, String> {
    // В реальном приложении ключ должен храниться в секретном месте
    let secret = "your_jwt_secret";
    
    let token_data = decode::<Claims>(
        token,
        &DecodingKey::from_secret(secret.as_bytes()),
        &Validation::new(Algorithm::HS256),
    )
    .map_err(|e| e.to_string())?;
    
    Ok(token_data.claims)
}

// Защищенный обработчик, использующий информацию из JWT
async fn protected_handler(
    Extension(claims): Extension<Claims>,
) -> impl IntoResponse {
    format!("Привет, {}! Ваша роль: {}", claims.sub, claims.role)
}

// Применение middleware только к определенным маршрутам
let app = Router::new()
    .route("/public", get(public_handler))
    .route(
        "/protected",
        get(protected_handler)
            .layer(middleware::from_fn(auth_middleware))
    );

Лучшие практики

1. Разделение ответственности

Каждый middleware должен отвечать за одну конкретную задачу:

rust
// Правильно: отдельные middleware для разных задач
let app = Router::new()
    .route("/", get(handler))
    .layer(middleware::from_fn(error_handling_middleware))
    .layer(middleware::from_fn(logging_middleware))
    .layer(middleware::from_fn(auth_middleware));

// Неправильно: один middleware делает всё
async fn do_everything_middleware(
    req: Request<Body>,
    next: Next<Body>,
) -> Response {
    // Логирование
    // Аутентификация
    // Обработка ошибок
    // ...
}

2. Порядок middleware имеет значение

Middleware выполняются в порядке "снаружи внутрь" - первыми обрабатываются те, что указаны последними:

rust
let app = Router::new()
    .route("/", get(handler))
    .layer(middleware::from_fn(first_middleware))   // Выполнится третьим
    .layer(middleware::from_fn(second_middleware))  // Выполнится вторым
    .layer(middleware::from_fn(third_middleware));  // Выполнится первым

3. Используйте асинхронные операции эффективно

rust
// Плохо: блокирующая операция в middleware
async fn bad_middleware(
    req: Request<Body>,
    next: Next<Body>,
) -> Response {
    std::thread::sleep(std::time::Duration::from_millis(100));
    next.run(req).await
}

// Хорошо: асинхронная операция
async fn good_middleware(
    req: Request<Body>,
    next: Next<Body>,
) -> Response {
    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
    next.run(req).await
}

4. Обработка ошибок в middleware

rust
async fn error_handling_middleware(
    req: Request<Body>,
    next: Next<Body>,
) -> Response {
    // Оборачиваем выполнение следующего middleware в try-catch
    let response = match next.run(req).await {
        Ok(response) => response,
        Err(error) => {
            // Преобразуем ошибку в HTTP-ответ
            error_to_response(error)
        }
    };
    
    response
}

5. Комбинирование middleware

Используйте tower::ServiceBuilder для комбинирования нескольких middleware:

rust
use tower::ServiceBuilder;

let app = Router::new()
    .route("/", get(handler))
    .layer(
        ServiceBuilder::new()
            .layer(middleware::from_fn(first_middleware))
            .layer(middleware::from_fn(second_middleware))
            .layer(middleware::from_fn(third_middleware))
            .into_inner(),
    );

Создание собственных middleware в Axum позволяет структурировать логику приложения, отделяя сквозную функциональность (логирование, аутентификацию, обработку ошибок) от бизнес-логики в обработчиках. Используя возможности Tower, вы можете создавать модульные, тестируемые и поддерживаемые веб-приложения на Axum.