Chapter 4 Requirement 1: Top 10 Popular Categories

Create the offline submodule

Create the submodule: sparkmall-offline

Add dependencies

<dependencies>
    <dependency>
        <groupId>com.atguigu</groupId>
        <artifactId>sparkmall-common</artifactId>
        <version>1.0-SNAPSHOT</version>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-core_2.11</artifactId>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-sql_2.11</artifactId>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-hive_2.11</artifactId>
    </dependency>
</dependencies>

MySQL database preparation

-- ----------------------------
-- create database sparkmall
-- ----------------------------
drop database if EXISTS sparkmall;
create database sparkmall DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci;


-- ----------------------------
-- create table category_top10
-- ----------------------------
use sparkmall;
drop table if exists category_top10;
create table category_top10 (
    taskId text,
    category_id text,
    click_count bigint(20) default null,
    order_count bigint(20) default null,
    pay_count bigint(20) default null
)ENGINE=InnoDB DEFAULT CHARSET=utf8;

4.1 Requirement 1: Overview

A category is a classification of products. Some e-commerce sites use multi-level categories; in this project there is only one level.

Different companies may define "popular" differently.

We rank popular categories by the number of clicks, orders, and payments in each category.

Once the computation finishes, the resulting statistics are written to MySQL.


4.2 Requirement 1: Approach

Approach 1: use SQL statements

Count, for each category, the number of clicks, orders, and payments separately.

1. categoryId => click_count
  select click_category_id, count(*) as click_count from user_visit_action where click_category_id != -1 group by click_category_id

2. ...

Drawback: computing the three metrics means scanning the whole table three times, which is far too slow. Not feasible.

Approach 2: one pass over the data, computing all three metrics

We should traverse the data only once and compute all three metrics in that single pass (see the sketch after the list below).

  1. Traverse the whole action log table and accumulate counts per category id and action type. This requires an accumulator.

    • Define the accumulator
    • For order and payment actions, remember to split the id field to get the individual category ids
  2. After the traversal we have the count for every (category id, action type) pair.

  3. Sort by clicks, then orders, then payments.

  4. Take the top 10.

  5. Write the result to MySQL.
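
Before wiring this into Spark, here is a minimal local sketch of the one-pass counting idea. The Action record and its field names are hypothetical simplifications of UserVisitAction, used only to illustrate the (categoryId, actionType) -> count key design:

import scala.collection.mutable

// Hypothetical, simplified action record (stand-in for UserVisitAction)
case class Action(clickCategoryId: Long, orderCategoryIds: String, payCategoryIds: String)

def countActions(actions: Seq[Action]): mutable.Map[(String, String), Long] = {
    // key: (categoryId, actionType), value: accumulated count
    val counts = mutable.Map[(String, String), Long]()
    def inc(key: (String, String)): Unit = counts(key) = counts.getOrElse(key, 0L) + 1
    actions.foreach { a =>
        if (a.clickCategoryId != -1) inc((a.clickCategoryId.toString, "click"))
        else if (a.orderCategoryIds != null) a.orderCategoryIds.split(",").foreach(id => inc((id, "order")))
        else if (a.payCategoryIds != null) a.payCategoryIds.split(",").foreach(id => inc((id, "pay")))
    }
    counts
}

// Example: one click on category 1, one order covering categories 1 and 2
countActions(Seq(Action(1, null, null), Action(-1, "1,2", null)))
// => Map((1,click) -> 1, (1,order) -> 1, (2,order) -> 1)   (map order may vary)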


4.3 Requirement 1: Implementation

1. Utility class

JDBCUtil.scala

This class lives in the sparkmall-common module.

package com.atguigu.sparkmall.common.util

import java.util.Properties

import com.alibaba.druid.pool.DruidDataSourceFactory

object JDBCUtil {

    val dataSource = initConnection()

    /**
      * Initialize the connection pool
      */
    def initConnection() = {
        val properties = new Properties()
        val config = ConfigurationUtil("config.properties")
        properties.setProperty("driverClassName", "com.mysql.jdbc.Driver")
        properties.setProperty("url", config.getString("jdbc.url"))
        properties.setProperty("username", config.getString("jdbc.user"))
        properties.setProperty("password", config.getString("jdbc.password"))
        properties.setProperty("maxActive", config.getString("jdbc.maxActive"))
        DruidDataSourceFactory.createDataSource(properties)
    }
    /**
      * Execute a single statement, e.g.
      *
      * "insert into xxx values (?,?,?)"
      */
    def executeUpdate(sql: String, args: Array[Any]) = {
        val conn = dataSource.getConnection
        conn.setAutoCommit(false)
        val ps = conn.prepareStatement(sql)
        if (args != null && args.length > 0) {
            (0 until args.length).foreach {
                i => ps.setObject(i + 1, args(i))
            }
        }
        ps.executeUpdate
        conn.commit()
        ps.close()
        conn.close() // return the connection to the Druid pool
    }

    /**
      * Execute a batch of statements
      */
    def executeBatchUpdate(sql: String, argsList: Iterable[Array[Any]]) = {
        val conn = dataSource.getConnection
        conn.setAutoCommit(false)
        val ps = conn.prepareStatement(sql)
        argsList.foreach {
            case args: Array[Any] => {
                (0 until args.length).foreach {
                    i => ps.setObject(i + 1, args(i))
                }
                ps.addBatch()
            }
        }
        ps.executeBatch()
        conn.commit()
        ps.close()
        conn.close() // return the connection to the Druid pool
    }
}
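
A quick usage sketch (the task id and counts below are made up for illustration; the table is the category_top10 table created above):

// Single statement
JDBCUtil.executeUpdate(
    "insert into category_top10 values(?, ?, ?, ?, ?)",
    Array[Any]("task-001", "1", 100L, 50L, 20L))

// Batch: one Array[Any] per row
val rows = List(
    Array[Any]("task-001", "2", 90L, 40L, 15L),
    Array[Any]("task-001", "3", 80L, 30L, 10L))
JDBCUtil.executeBatchUpdate("insert into category_top10 values(?, ?, ?, ?, ?)", rows)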

Package object that defines helpers for checking whether a string is empty

package com.atguigu.sparkmall

package object offline {
    def isNotEmpty(text: String): Boolean = text != null && text.length > 0

    def isEmpty(text: String): Boolean = !isNotEmpty(text)
}
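
For example (from another package, the helpers become visible by importing the package object's members):

import com.atguigu.sparkmall.offline._

isNotEmpty("2019-07-01")   // true
isNotEmpty("")             // false
isEmpty(null)              // true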

2. Bean classes

Condition encapsulates the filter conditions used when reading data from Hive.

package com.atguigu.sparkmall.offline.bean

case class Condition(var startDate: String,
                     var endDate: String,
                     var startAge: Int,
                     var endAge: Int,
                     var professionals: String,
                     var city: String,
                     var gender: String,
                     var keywords: String,
                     var categoryIds: String,
                     var targetPageFlow: String)

CategoryCountInfo encapsulates the rows written to MySQL.

package com.atguigu.sparkmall.offline.bean

case class CategoryCountInfo(taskId: String,
                             categoryId: String,
                             clickCount: Long,
                             orderCount: Long,
                             payCount: Long)

Business logic

1. Entry class of the whole offline module

OfflineApp.scala

package com.atguigu.sparkmall.offline

import java.util.UUID

import com.alibaba.fastjson.JSON
import com.atguigu.sparkmall.common.bean.UserVisitAction
import com.atguigu.sparkmall.common.util.ConfigurationUtil
import com.atguigu.sparkmall.offline.app.CategoryTop10App
import com.atguigu.sparkmall.offline.bean.Condition
import org.apache.spark.sql.SparkSession

object OfflineApp {
    def main(args: Array[String]): Unit = {


        val spark: SparkSession = SparkSession
            .builder()
            .master("local[*]")
            .appName("OfflineApp")
            .enableHiveSupport()
            .config("spark.sql.warehouse.dir", "hdfs://hadoop201:9000/user/hive/warehouse")
            .getOrCreate()
        val taskId = UUID.randomUUID().toString
        // Read the RDD filtered by the conditions; the filter conditions are defined in the config file
        val userVisitActionRDD = readUserVisitActionRDD(spark, readConditions)


        println("任务1: 开始")
        CategoryTop10App.statCategoryTop10(spark, userVisitActionRDD, taskId)
        println("任务1: 结束")

    }

    /**
      * Read the UserVisitAction RDD matching the given conditions
      *
      * @param spark
      * @param condition
      */
    def readUserVisitActionRDD(spark: SparkSession, condition: Condition) = {
        var sql = s"select v.* from user_visit_action v join user_info u on v.user_id=u.user_id where 1=1"
        if (isNotEmpty(condition.startDate)) {
            sql += s" and v.date>='${condition.startDate}'"
        }
        if (isNotEmpty(condition.endDate)) {
            sql += s" and v.date<='${condition.endDate}'"
        }

        if (condition.startAge != 0) {
            sql += s" and u.age>=${condition.startAge}"
        }
        if (condition.endAge != 0) {
            sql += s" and u.age<=${condition.endAge}"
        }
        import spark.implicits._
        spark.sql("use sparkmall")
        spark.sql(sql).as[UserVisitAction].rdd
    }

    /**
      * Read the filter conditions
      *
      * @return
      */
    def readConditions: Condition = {
        // Read the config file
        val config = ConfigurationUtil("conditions.properties")
        // Read the JSON string from it
        val conditionString = config.getString("condition.params.json")
        // Parse it into a Condition object
        JSON.parseObject(conditionString, classOf[Condition])

    }
}
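
readConditions expects conditions.properties to contain a condition.params.json key whose value is a JSON object with the same field names as Condition. A sketch of what parsing such a value looks like (the concrete dates and ages are assumptions, not values from the project):

import com.alibaba.fastjson.JSON
import com.atguigu.sparkmall.offline.bean.Condition

// Illustrative JSON only; in the project this string comes from
// config.getString("condition.params.json")
val conditionString =
    """{"startDate":"2019-06-01","endDate":"2019-06-30",
      |"startAge":20,"endAge":50,
      |"professionals":"","city":"","gender":"",
      |"keywords":"","categoryIds":"","targetPageFlow":""}""".stripMargin
val condition = JSON.parseObject(conditionString, classOf[Condition])

// With these values, readUserVisitActionRDD appends
// " and v.date>='2019-06-01' and v.date<='2019-06-30' and u.age>=20 and u.age<=50"
// to the base query.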

2. The Map accumulator used

package com.atguigu.sparkmall.offline.acc

import org.apache.spark.util.AccumulatorV2

import scala.collection.mutable

class MapAccumulator extends AccumulatorV2[(String, String), mutable.Map[(String, String), Long]] {
    val map = mutable.Map[(String, String), Long]()

    override def isZero: Boolean = map.isEmpty

    override def copy(): AccumulatorV2[(String, String), mutable.Map[(String, String), Long]] = {
        val newAcc = new MapAccumulator
        map.synchronized {
            newAcc.map ++= map
        }
        newAcc
    }

    override def reset(): Unit = map.clear


    override def add(v: (String, String)): Unit = {
        map(v) = map.getOrElseUpdate(v, 0) + 1
    }

    // otherMap: (1, click) -> 20   this: (1, click) -> 10   => after merge this: (1, click) -> 30
    // otherMap: (1, order) -> 5                              => after merge this: (1, order) -> 5
    override def merge(other: AccumulatorV2[(String, String), mutable.Map[(String, String), Long]]): Unit = {
        val otherMap: mutable.Map[(String, String), Long] = other.value
        otherMap.foreach {
            kv => map.put(kv._1, map.getOrElse(kv._1, 0L) + kv._2)
        }
    }
    override def value: mutable.Map[(String, String), Long] = map
}
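
A small local sketch of the accumulator's add/merge semantics (calling the methods directly, outside of Spark), in the spirit of the example in the merge comment:

val acc1 = new MapAccumulator
acc1.add(("1", "click"))
acc1.add(("1", "click"))
acc1.add(("1", "order"))

val acc2 = new MapAccumulator
acc2.add(("1", "click"))
acc2.add(("2", "pay"))

acc1.merge(acc2)
println(acc1.value)  // e.g. Map((1,click) -> 3, (1,order) -> 1, (2,pay) -> 1), order may vary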

3. Concrete implementation of Requirement 1

package com.atguigu.sparkmall.offline.app

import com.atguigu.sparkmall.common.bean.UserVisitAction
import com.atguigu.sparkmall.common.util.JDBCUtil
import com.atguigu.sparkmall.offline.acc.MapAccumulator
import com.atguigu.sparkmall.offline.bean.CategoryCountInfo
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession

object CategoryTop10App {

    // Compute the Top 10 popular categories
    def statCategoryTop10(spark: SparkSession, userVisitActionRDD: RDD[UserVisitAction], taskId: String) = {
        // 1. Register the accumulator
        val acc = new MapAccumulator
        spark.sparkContext.register(acc, "CategoryActionAcc")

        // 2. Traverse the action log
        userVisitActionRDD.foreach {
            visitAction => {
                if (visitAction.click_category_id != -1) {
                    acc.add((visitAction.click_category_id.toString, "click"))
                } else if (visitAction.order_category_ids != null) {
                    visitAction.order_category_ids.split(",").foreach {
                        oid => acc.add((oid, "order"))
                    }
                } else if (visitAction.pay_category_ids != null) {
                    visitAction.pay_category_ids.split(",").foreach {
                        pid => acc.add((pid, "pay"))
                    }
                }
            }
        }
        // 3. After the traversal we have the count of every (category id, action type) pair. Group them by category id.
        val actionCountByCategoryIdMap = acc.value.groupBy(_._1._1)

        // 4. Aggregate into a collection of CategoryCountInfo
        val categoryCountInfoList = actionCountByCategoryIdMap.map {
            case (cid, actionMap) => CategoryCountInfo(
                taskId,
                cid,
                actionMap.getOrElse((cid, "click"), 0),
                actionMap.getOrElse((cid, "order"), 0),
                actionMap.getOrElse((cid, "pay"), 0)
            )
        }.toList

        // 5. Sort descending by clicks, then orders, then payments
        val sortedCategoryInfoList = categoryCountInfoList.sortBy(info => (info.clickCount, info.orderCount, info.payCount))(Ordering.Tuple3(Ordering.Long.reverse, Ordering.Long.reverse, Ordering.Long.reverse))

        // 6. Take the top 10
        val top10 = sortedCategoryInfoList.take(10)

        // 7. Insert into the database
        val argsList = top10.map(info => Array(info.taskId, info.categoryId, info.clickCount, info.orderCount, info.payCount))
        JDBCUtil.executeBatchUpdate("insert into category_top10 values(?, ?, ?, ?, ?)", argsList)
    }

}

/*
1. Traverse the whole action log table and accumulate counts per category id and action type. This requires an accumulator.

    - Define the accumulator
        What does the accumulator store? A map
         key:   (categoryId, action)    represented as a tuple
         value: count
    - For order and payment actions, remember to split the id field to get the individual category ids

2. After the traversal we have the count of every (category id, action type) pair. Then aggregate by category id into CategoryCountInfo values.

3. Sort by clicks, then orders, then payments.

4. Take the top 10.

5. Write the result to MySQL.
*/
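
To make the ordering in step 3 concrete, here is a minimal illustration of the sortBy call with made-up CategoryCountInfo values:

import com.atguigu.sparkmall.offline.bean.CategoryCountInfo

val infos = List(
    CategoryCountInfo("t", "1", 10L, 5L, 2L),
    CategoryCountInfo("t", "2", 10L, 7L, 1L),
    CategoryCountInfo("t", "3", 20L, 1L, 1L))

val desc = Ordering.Tuple3(Ordering.Long.reverse, Ordering.Long.reverse, Ordering.Long.reverse)
val sorted = infos.sortBy(info => (info.clickCount, info.orderCount, info.payCount))(desc)
// categoryId order: 3, 2, 1  (clicks first, ties broken by orders, then payments)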