切换语言为:繁体

SpringBoot 集成 TensorFlow : 本地实现图片合规性安全检测

  • 爱糖宝
  • 2024-08-31
  • 2049
  • 0
  • 0

一、简介

你是否还在为如何处理非法图片而感到困惑?在涉及用户文件上传的系统中,图片内容的审核变得至关重要。不当图片不仅影响用户体验,还可能带来法律风险。依赖外部服务进行审核会带来数据隐私问题和速度瓶颈。

为了解决这些问题,我们引入了一个技术亮点:基于 TensorFlow 模型的本地图片内容安全检测,并集成在 SpringBoot 应用中。这一方案可以在本地环境直接审核图片内容,无需外部 API,确保数据私密性,同时提升审核速度和系统稳定性。通过这种本地化的智能审核,系统能够更自主高效地处理非法图片,为用户提供更安全的体验。

二、NSFW 介绍

因为项目中会使用到 NSFW的模型,所以这里简单介绍一下

1. 描述

NSFW(Not Safe For Work)是一类用于检测不适合在公共场所或工作环境中展示内容的模型。这些模型通常用于识别包含成人内容、暴力或其他不宜公开展示的图片或视频。NSFW模型通过深度学习算法进行训练,能够自动检测这些不适合公开的内容,广泛应用于社交媒体平台、内容审核系统,以及其他需要过滤敏感内容的应用场景。

2. 评判指标

NSFW 模型通过以下几个分类指标来判断图片属于哪一类内容。用户可以设定对应的概率阈值来判定图片是否符合某一类:

  • DRAWINGS: 卡通或漫画图片,这类图片通常为手绘或数字绘图风格。

  • HENTAI: 带有情色成分的动画或漫画,包含成人内容但以动画风格呈现。

  • NEUTRAL: 正常的、适合公开展示的图像,无不当内容。

  • PORN: 色情图片,包含明显的成人内容和性行为。

  • SEXY: 暗示性强的图像,虽然不完全属于色情类别,但具有强烈的性暗示。

这些分类指标帮助系统自动化地对图像内容进行分类和过滤,从而有效地防止不适当内容的传播。

三、功能演示

  • 当我们上传一张正常的图片

SpringBoot 集成 TensorFlow : 本地实现图片合规性安全检测

  • 当我们上传一张具有违规行为的图片

SpringBoot 集成 TensorFlow : 本地实现图片合规性安全检测

四、编码实现

1. 引入依赖

<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow</artifactId>
</dependency>

2. 初始化NSFW模型

NSFW 模型可以通过 GitHub 获取,推荐使用 nsfwjs 项目中的模型。可以使用 Python 将模型转换为 TensorFlow 的 saved_model 格式,然后在 SpringBoot 应用中进行加载。当然,也可以直接找到已转换好的 NSFW saved_model 格式的模型进行加载。

注意事项:

  • 在打包后 new ClassPathResource("").getFile().getAbsolutePath(); 的访问会出现问题(文章末尾描述解决办法)

/**
 * NSFW 模型
 *
 * @author : YiFei
 */
@Getter
@Component
public class NSFWModelService {

    // 提供方法来获取 TensorFlow Session
    private Session session;

    @PostConstruct
    public void init() {
        // 加载 TensorFlow 模型
        try {
            String modelAbsolutePath = new ClassPathResource("tensorflow/saved_model/nsfw").getFile().getAbsolutePath();
            SavedModelBundle model = SavedModelBundle.load(modelPath, "serve");
            this.session = model.session();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * 在销毁 Bean 时关闭 TensorFlow Session
     */
    @PreDestroy
    public void closeSession() {
        this.session.close();
    }
}

如果需要输出模型的详细信息,可以加上以下代码在 SavedModelBundle model = SavedModelBundle.load(modelPath, "serve");

//            以下是获取模型 Inputs 数据格式 、输入张量名  , output 数据格式 、输出张量名
//            MetaGraphDef metaGraphDef = MetaGraphDef.parseFrom(model.metaGraphDef());
//            Map<String, SignatureDef> signatureDefMap = metaGraphDef.getSignatureDefMap();
//
//            for (Map.Entry<String, SignatureDef> entry : signatureDefMap.entrySet()) {
//                System.out.println("SignatureDef key: " + entry.getKey());
//
//                SignatureDef signatureDef = entry.getValue();
//                System.out.println("Inputs:");
//                for (Map.Entry<String, TensorInfo> inputEntry : signatureDef.getInputsMap().entrySet()) {
//                    String inputKey = inputEntry.getKey();
//                    TensorInfo inputTensorInfo = inputEntry.getValue();
//
//                    // 打印输入张量的名称
//                    System.out.println("  Key: " + inputKey);
//                    System.out.println("  Name: " + inputTensorInfo.getName());
//
//                    // 打印输入张量的形状
//                    if (inputTensorInfo.hasTensorShape()) {
//                        TensorShapeProto tensorShape = inputTensorInfo.getTensorShape();
//                        System.out.println("  Shape: " + tensorShape);
//                    }
//
//                    // 打印输入张量的数据类型
//                    System.out.println("  Data Type: " + inputTensorInfo.getDtype());
//                }
//
//                System.out.println("Outputs:");
//                for (Map.Entry<String, TensorInfo> outputEntry : signatureDef.getOutputsMap().entrySet()) {
//                    String outputKey = outputEntry.getKey();
//                    TensorInfo outputTensorInfo = outputEntry.getValue();
//
//                    // 打印输出张量的名称
//                    System.out.println("  Key: " + outputKey);
//                    System.out.println("  Name: " + outputTensorInfo.getName());
//
//                    // 打印输出张量的形状
//                    if (outputTensorInfo.hasTensorShape()) {
//                        TensorShapeProto tensorShape = outputTensorInfo.getTensorShape();
//                        System.out.println("  Shape: " + tensorShape.toString());
//                    }
//
//                    // 打印输出张量的数据类型
//                    System.out.println("  Data Type: " + outputTensorInfo.getDtype());
//                }
//            }

3. 编写工具类

尽管 TensorFlow API 已经相对简洁,但在实际使用中仍可能显得繁琐。为了解决这一问题,我们可以封装一个工具类,使接口更加友好,让开发者只需一行代码即可完成图片内容安全的校验,而无需编写大量冗余的代码。通过这个工具类,您可以更轻松地集成 NSFW 模型,并提高项目的开发效率。

工具类的具体实现读者无需深入理解,只需了解其输入输出信息以及如何使用即可。由于在之前的文章中工具类代码过多,影响了阅读体验,因此在这里我贴出了使用到的工具类源码以及大致介绍。

  • NSFWAnalyzerUtils.java

    • Map<String, String> getNsfwPredictions(MultipartFile file):该方法接收一个 MultipartFile 对象,先通过 Image.read() 将图片转换为 BufferedImage,然后提取图片的 RGB 值作为模型输入张量,最后解析模型输出张量并返回各个 NSFW 指标的概率。

    • boolean isNsfwFile(MultipartFile file):该方法与 getNsfwPredictions(MultipartFile file) 的处理过程类似,但增加了开发者设定的阈值判断。当检测到的非法指标概率达到或超过该阈值时,图片将被判定为非法。

  • TensorflowUtil.java

    • static String getModelPath(String classPathResource) 将Resource下的save_model文件转换到临时文件中,返回临时文件绝对路径。

4. 编写 RESTful 接口

@RestController
@RequestMapping("nsfw")
@RequiredArgsConstructor
public class NsfwController {

    private final NSFWAnalyzerUtils nsfwAnalyzerUtils;

    @Operation(summary = "图片检测")
    @PreventDuplicateSubmit
    @PostMapping("/check")
    public Result<Map<String, String>> nsfwCheck(MultipartFile file) {
        try {
            return Result.success(nsfwAnalyzerUtils.getNsfwPredictions(file));
        } catch (Exception e) {
            throw new ServiceException(ResultCode.FILE_ANALYZER_ERROR);
        }
    }

}

五、解决模型加载时部署问题

在项目被打成 Jar 包后,使用 new ClassPathResource("XXX").getFile().getAbsolutePath(); 是无法访问到资源的绝对路径的。为了解决这个问题,我们通过将 resource 中 save_model 目录下的所有文件转存到临时目录,再将临时目录的路径返回给模型进行加载。

0条评论

您的电子邮件等信息不会被公开,以下所有项均必填

OK! You can skip this field.