优秀的编程知识分享平台

网站首页 > 技术文章 正文

TF Lite Model Maker: 构建安卓图片分类器

nanyue 2024-12-28 14:35:47 技术文章 5 ℃

TensorFlow最近才通过livestream(由于COVID-19全球流行病)结束了一年一度的Dev Summit,并发布了许多令人兴奋的公告,其中大多数都集中在将机器学习推向新的高度。从一个强大的新发行版的核心TensorFlow平台(TF2.2)到新的Google Cloud AI平台,使TensorFlow在生产中的使用更加容易,更加广泛。

但这并不是这篇文章的重点。我们将深入探究今年的突破性公告之一:TensorFlow Lite Model Maker。

使用TF Lite Model Maker(它被放入TF Lite support库中)为移动和边缘设备构建模型非常容易。此外,Android Studio 4.1(目前是Canary版本),具有新的针对TF Lite模型的代码生成功能,可以自动生成TF Lite模型的Java包装类,从而简化了移动机器学习开发人员的模型开发和部署过程。

TensorFlow Lite是一个轻量级的跨平台解决方案,用于在移动和嵌入式设备上部署ML模型。如果您想了解今年TF Dev Summit上有关TensorFlow Lite的所有新闻和发布,我肯定建议您查看以下资源:

https://heartbeat.fritz.ai/tensorflow-dev-summit-2020-tensorflow-lite-19dde3153335

TensorFlow Lite Model Maker

TF Lite Model Maker是一个Python API,它使从头构建机器学习模型变得轻而易举。只需要5行代码(不包括imports),如下所示:

是的,我们将使用NSFW数据集。

在上面的演示中,我们加载了一个数据集并将其拆分为训练集和测试集。随后,我们训练,评估,并输出TF Lite model,以及标签(从子文件夹中取出)。

在幕后,Model Maker API使用迁移学习,用不同数据集和类别来重新训练模型。默认情况下,Model Maker API使用EfficientNet-Lite0作为基础模型。

EfficientNet-Lite是最近才发布的,属于能够在边缘设备上实现最新精度的图像分类模型家族。下图显示了EfficientNet-Lite模型的精度与大小比较,并将其与MobileNet和ResNet进行了比较。

来自文档,EfficientNet-Lite针对移动推理进行了优化

Model Maker API还允许我们切换底层模型。例如:

model = image_classifier.create(train_data,
model_spec=mobilenet_v2_spec, validation_data=validation_data)

或者,我们也可以从TensorFlow Hub传递模型,使用自定义的input shape,如下所示:

inception_v3_spec = ImageModelSpec(uri='tfhub_url_goes_here')
inception_v3_spec.input_image_shape = [299, 299]
//pass this spec into model_spec

我们还可以在Model Maker API的create函数中微调训练超参数,如epoch、dropout_rate和batch_size。

model = image_classifier.create(train_data, epochs=10)

现在我们已经很好地了解了modelmakerapi的核心功能,现在让我们收紧运行上述Python脚本所需的依赖关系。

升级TensorFlow

确保您运行的是Python3.6或更高版本,并且在你的macOS上安装了最新的pip版本。TensorFlow 2软件包要求pip版本>19.0。随后,pip安装以下更新TensorFlow:

pip install --user --upgrade tensorflow

注意:建议在virtual环境中安装Python包。使用virtual环境,您可以测试不同版本的库。要了解这个过程,可以查看本教程。

在我们的终端上,我们使用以下命令快速测试是否安装了最新的TensorFlow版本:

python3 -c 'import tensorflow as tf; print(tf.__version__)'

安装Model Maker库

在终端上执行如下命令安装Model Maker库:

pip3 install
git+https://github.com/tensorflow/examples.git#egg=tensorflow-examples[model_maker]

现在一切都准备好了,这意味着是时候训练我们的模型了。只需从macOS终端运行Python脚本。对于这个演示,我们使用了来自Kaggle的一个适当的NSFW数据集。一旦我们的模型准备好了,就可以导入到我们新的Android Studio项目中了。

  • 首先,我们直接从import菜单导入tflite模型,并将其放在ml文件夹中。只需转到File > New > Other > TensorFlow Lite Model

注意:在这样一个场景中,您不像上面那样导入模型,而是直接将其放在assets文件夹中,模型绑定将被禁用,Android Studio的自动代码生成不会为分类器创建Java包装类,除非您将其移到ml文件夹中。


  • 其次,Android Studio现在有一个模型查看器,可以显示元数据摘要-输入和输出张量、这些张量的描述以及示例代码,如下所示:

默认情况下,ModelMaker API只生成由输入和输出形状组成的最小元数据。为了扩展和添加更多的上下文,如作者、版本、许可证以及输入和输出描述,我们可以利用新的扩展元数据特性(目前处于试验阶段)。

打开ML Model binding

尽管将tflite模型放在ml目录中,但model binding也不会自动启用。您需要在您的app的build.gradle 中添加buildFeatures和aaptOptions来打开:

android{
 buildFeatures {
        mlModelBinding true
    }
    aaptOptions {
        noCompress "tflite"
    }
}

我们的模型分类器现在可供我们运行推断。是时候build.gradle文件中添加tensorflow-lite依赖了。

dependencies {

  implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
  implementation 'org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly'
  implementation 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly'
  implementation 'org.apache.commons:commons-compress:1.19'
}

设置Activity Layout

现在是时候在我们的Activity中放置UI元素了。简单一点,我们的 activity_main.xml文件由一个RecyclerView和一个Button组成:

<?xml version="1.0" encoding="utf-8"?>
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:orientation="vertical"
    tools:context=".MainActivity">

    <androidx.recyclerview.widget.RecyclerView
        android:id="@+id/recyclerView"
        android:layout_width="match_parent"
        android:layout_above="@+id/btnClassifier"
        android:layout_height="match_parent" />


    <Button
        android:id="@+id/btnClassifier"
        android:text="Run Classifier"
        android:layout_alignParentBottom="true"
        android:layout_centerHorizontal="true"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"/>


</RelativeLayout>

为了填充RecyclerView的adapter,我们需要一个模型。下面的Kotlin数据类包含一个图像、预测文本,以及一个布尔标志,用于指示输入图像是否为NSFW。

data class DataModel(var drawableID: Int, 
                     var isNSFW: Boolean,
                     var prediction: String)

以下XML代码表示RecyclerView的每一行的布局:

<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
    android:layout_width="match_parent"
    android:layout_height="wrap_content"
    android:orientation="horizontal"
    android:padding="10dp">
    
    <ImageView
        android:id="@+id/imageView"
        android:layout_width="100dp"
        android:layout_height="100dp"
        android:scaleType="centerCrop" />
    
    <TextView
        android:id="@+id/tvPrediction"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_gravity="center"
        android:gravity="end|center_vertical"
        android:textSize="20sp" />

</LinearLayout>

现在我们已经创建了数据模型和视图,是时候将它们提供给RecyclerView的适配器了。

设置RecyclerView的Adapter

以下代码创建RecyclerView的适配器类:

import android.content.Context
import android.graphics.Color
import android.view.LayoutInflater
import android.view.View
import android.view.ViewGroup
import androidx.core.content.ContextCompat
import androidx.recyclerview.widget.RecyclerView
import kotlinx.android.synthetic.main.item_row.view.*


class RecyclerViewAdapter(val items: ArrayList<DataModel>, val context: Context) :
    RecyclerView.Adapter<ViewHolder>() {

    override fun getItemCount(): Int {
        return items.size
    }

    override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): ViewHolder {
        return ViewHolder(
            LayoutInflater.from(context).inflate(
                R.layout.item_row,
                parent,
                false
            )
        )
    }

    override fun onBindViewHolder(holder: ViewHolder, position: Int) {

        val data = items.get(position)

        holder.tvPrediction?.text = data.prediction

        val image = ContextCompat.getDrawable(context, data.drawableID)
        holder.imageView.setImageDrawable(image)

        if (data.isNSFW) {
            holder.imageView.setColorFilter(Color.BLACK)
        } else {
            holder.imageView.setColorFilter(Color.TRANSPARENT)
        }
    }
}

class ViewHolder(view: View) : RecyclerView.ViewHolder(view) {
    val tvPrediction = view.tvPrediction
    val imageView = view.imageView
}

我们根据NSFW的输出在ImageView上设置了一个颜色过滤器(NSFW的图像由于明显的原因隐藏在黑色中)。

最后,来到了我们的MainActivity.kt,在这里我们初始化上面的适配器,更重要的是,对图像列表运行推断。

运行TF Lite图像分类器

为了运行模型,我们需要预处理输入以满足模型的要求。TensorFlow lite有大把的内置预处理方法。要使用它们,我们首先需要初始化ImageProcessor,然后添加所需的运算符:

预处理输入图像

在下面的代码中,我们将输入图像的大小调整为224×224,即模型要求的输入形状的尺寸:

val imageProcessor = ImageProcessor.Builder()
    .add(ResizeOp(224, 224, ResizeOp.ResizeMethod.BILINEAR))
    .build()

var tImage = TensorImage(DataType.FLOAT32)

tImage.load(bitmap)
tImage = imageProcessor.process(tImage)

TensorImage是喂给我们的TensorFlow Lite模型的输入。但是,在我们运行推断之前,让我们创建一个后处理器来规范化输出概率。

设置后处理器

后处理器基本上是一个容器,它将对结果进行反量化:

val probabilityProcessor =
    TensorProcessor.Builder().add(NormalizeOp(0f, 255f)).build()

运行推断

下面几行代码实例化从模型自动生成的分类器,传递输入张量图像,并在outputBuffer中获得结果:

val model = NsfwClassifier.newInstance(this@MainActivity)
val outputs =
    model.process(probabilityProcessor.process(tImage.tensorBuffer))
val outputBuffer = outputs.outputFeature0AsTensorBuffer
val tensorLabel = TensorLabel(labelsList, outputBuffer)

TensorLabel用于将相关概率映射到标签。在我们的模型中,只有两个标签:”NSFW”和”SFW”。我们已经把它们设置给labelsList。在不同的场景中,您可以解析labels.txt文件来获取所有类别的信息,比如这里。

最后,使用mapWithFloatValue函数,我们可以获得NSFW和SFW类别的概率。

MainActivity.kt的完整代码如下所示。它在每个图像上运行上述图像分类器,并使用相应的数据更改更新RecyclerView适配器:

class MainActivity : AppCompatActivity() {

    val dataArray: ArrayList<DataModel> = ArrayList()
    val labelsList = arrayListOf("NSFW", "SFW")

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)

        btnClassifier.setOnClickListener {
            val iterate = dataArray.listIterator()
            while (iterate.hasNext()) {
                val oldValue = iterate.next()
                runImageClassifier(oldValue)
            }
            recyclerView.adapter?.notifyDataSetChanged()
        }

        populateData()
        recyclerView.layoutManager = LinearLayoutManager(this)
        recyclerView.adapter = RecyclerViewAdapter(dataArray, this)

    }

    fun populateData()
    {
        dataArray.add(DataModel(R.drawable.sfw_1,true,""))
        dataArray.add(DataModel(R.drawable.nsfw,true,""))
        dataArray.add(DataModel(R.drawable.nsfw2,true,""))
        dataArray.add(DataModel(R.drawable.sfw,true,""))
    }

    fun runImageClassifier(data: DataModel)
    {

        val bitmap =
            BitmapFactory.decodeResource(applicationContext.resources, data.drawableID)

        try {

            val probabilityProcessor =
                TensorProcessor.Builder().add(NormalizeOp(0f, 255f)).build()

            val imageProcessor = ImageProcessor.Builder()
                .add(ResizeOp(224, 224, ResizeOp.ResizeMethod.BILINEAR))
                .build()
            
            var tImage = TensorImage(DataType.FLOAT32)

            tImage.load(bitmap)
            tImage = imageProcessor.process(tImage)
            val model = NsfwClassifier.newInstance(this@MainActivity)
            val outputs =
                model.process(probabilityProcessor.process(tImage.tensorBuffer))
            val outputBuffer = outputs.outputFeature0AsTensorBuffer
            val tensorLabel = TensorLabel(labelsList, outputBuffer)

            val nsfwProbability = tensorLabel.mapWithFloatValue.get("NSFW")
            if (nsfwProbability?.compareTo(0.5)!! < 0){
                data.isNSFW = false
            }
            data.prediction =  "NSFW : "+ tensorLabel.mapWithFloatValue.get("NSFW")
            

        } catch (e: Exception) {
            Log.d("TAG", "Exception is " + e.localizedMessage)
        }
    }
}

下面是上述应用程序的实际输出:

只有SFW图像被显示

结束语

可以说,Model Maker Python库将继续存在,并将被希望在设备上快速部署ML模型的移动开发人员广泛使用。

对于那些了解苹果机器学习技术的人来说,TF Lite Model Maker与Create ML类似,至少在理论上是这样。目前,Model Maker API只支持图像和文本分类用例,目标检测和QR阅读器预计很快就会推出。

Android Studio对ML模型绑定和自动代码生成的支持消除了与ByteBuffer交互的需要,正如我们在之前的TensorFlow Lite Android教程中所做的那样。

扩展元数据(在编写时处于实验阶段)还允许我们生成定制的、特定于平台的包装器代码,从而进一步减少了我们需要编写的例行代码的数量。在将来的教程中,我们将研究自定义代码生成以及更多内容。

以上教程的完整源代码可以在这个GitHub仓库 https://github.com/anupamchugh/AndroidTFLiteModelMaker (或者码云克隆 https://gitee.com/freehug/anupamchugh-android-tflite-model-maker )中找到。

感谢阅读。

原文链接

Tags:

最近发表
标签列表