Chapter 8 Pattern Matching

函数式编程下的模式匹配

模式匹配主要是为了避免 if/else 结构可维护性差、可读性差的问题。它通过匹配对象本身,或者对对象的成员进行提取(Extract)后再匹配,来判断对象的类型或者执行相应的操作。文中主要用 Scala 进行演示,通过 match 和 case 关键字来完成模式匹配。

例如有一个需求:Customer 对象的成员 name、state、domain 都是 String 类型,要求提供一个创建方法,若其中任何一个字段为空字符串,则返回 null。

最简单直接的方式就是用 if/else

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
// Plain if/else validation: the first blank field prints its message and the
// whole expression evaluates to null; otherwise a Customer is constructed.
if (name.isEmpty) {
  println("Name cannot be blank")
  null
} else if (state.isEmpty) {
  println("State cannot be blank")
  null
} else if (domain.isEmpty) {
  println("Domain cannot be blank")
  null
} else {
  // Reached only when name, state and domain are all non-empty.
  new Customer(0, name, state, domain, true, new Contract(Calendar.getInstance, true), List())
}

在通过 match/case 进行重构之前,我认为可以先简单地用卫语句(guard clause)来重写,让逻辑稍微清晰一点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// Guard-clause version: every blank field prints its message and returns null
// right away, so the construction at the bottom needs no else branch.
if (name.isEmpty) {
  println("Name cannot be blank")
  return null
}

if (state.isEmpty) {
  println("State cannot be blank")
  return null
}

if (domain.isEmpty) {
  println("Domain cannot be blank")
  return null
}

// Reached only when none of the guards above returned.
new Customer(
  0,
  name,
  state,
  domain,
  true,
  new Contract(Calendar.getInstance, true),
  List()
)

现在使用 case 和 match 来重构,完成模式匹配。主要做法是将三个属性封装在一个对象或者数据结构里,如 Tuple 元组 (name, state, domain),然后对这个元组进行匹配。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
/**
 * Builds a Customer from the three fields, or returns null when any of them is blank.
 *
 * The fields are packed into a tuple so that each validation rule becomes one pattern.
 * Cases are tried top to bottom, so name is checked before state, and state before domain.
 *
 * @param name   customer name; an empty string is rejected
 * @param state  customer state; an empty string is rejected
 * @param domain customer domain; an empty string is rejected
 * @return the new Customer, or null (after printing a message) if a field is blank
 */
def createCustomer(name: String, state: String, domain: String): Customer = {
  // The scrutinee must be built from the actual parameters (name, state, domain).
  (name, state, domain) match {
    case ("", _, _) =>
      println("Name cannot be blank")
      null

    case (_, "", _) =>
      println("State cannot be blank")
      null

    case (_, _, "") =>
      println("Domain cannot be blank")
      null

    // No field matched the empty string, so the customer can be created.
    case _ =>
      new Customer(0, name, state, domain, true, new Contract(Calendar.getInstance, true), List())
  }
}

继续来重构之前用递归写的一个方法(注意:递归调用的结果还要再与 cust 拼接,因此它并不是严格意义上的尾递归):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
/**
 * Applies `cls` to the customers whose id appears in `ids`.
 *
 * For each id, in order, the first customer with that id is looked up. When one is
 * found, the result of `cls` is emitted and every customer with that id is removed
 * from the list searched next. Once `ids` is exhausted, the customers that were never
 * matched are appended unchanged.
 *
 * @param initialIds customers still eligible for matching
 * @param ids        ids of the customers to update
 * @param cls        transformation applied to each matched customer
 * @return the transformed customers (in `ids` order) followed by the unmatched ones
 */
def updateCustomerByIdList(initialIds: List[Customer],
                           ids: List[Integer],
                           cls: Customer => Customer): List[Customer] = {
  if (ids.size <= 0) {
    initialIds
  } else if (initialIds.size <= 0) {
    initialIds
  } else {
    val precust = initialIds.find(cust => cust.customer_id == ids(0))
    val cust = if (precust.isEmpty) { List() } else { List(cls(precust.get)) }
    // Continue with the customers that were NOT just handled. Filtering with `==`
    // would keep only the matched customer and emit it a second time, unchanged.
    cust ::: updateCustomerByIdList(
      initialIds.filter(cust => cust.customer_id != ids(0)),
      ids.tail,
      cls
    )
  }
}

同样使用元组、case、match 来进行重构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
/**
 * Applies `cls` to the customers whose id appears in `ids`, using a match on the
 * pair (initialIds, ids) instead of an if/else chain.
 *
 * @param initialIds customers still eligible for matching
 * @param ids        ids of the customers to update
 * @param cls        transformation applied to each matched customer
 * @return the transformed customers (in `ids` order) followed by the unmatched ones
 */
def updateCustomerByIdList(initialIds: List[Customer],
                           ids: List[Integer],
                           cls: Customer => Customer): List[Customer] = {
  (initialIds, ids) match {
    // Nothing left to search, or no ids left to process: return what remains.
    case (List(), _) => initialIds
    case (_, List()) => initialIds
    case _ => {
      val precust = initialIds.find(cust => cust.customer_id == ids(0))
      val cust = if (precust.isEmpty) { List() } else { List(cls(precust.get)) }
      // Recurse on the customers that were NOT just handled (`!=`); keeping the
      // matched one would emit it again, unchanged, when the ids run out.
      cust ::: updateCustomerByIdList(
        initialIds.filter(cust => cust.customer_id != ids(0)),
        ids.drop(1),
        cls
      )
    }
  }
}

但是,能否进一步减少这个方法的复杂程度呢?通过 Extracting Lists 来实现。

x :: y 模式出现在 case 语句中时,告诉 Scala 被匹配的对象应该是一个非空的 List:这个 List 的首元素将被绑定到 x,剩下的元素(tail,也是一个 List)将被绑定到 y。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
/**
 * Applies `cls` to the customers whose id appears in `ids`, destructuring the id
 * list with the `head :: tail` pattern instead of indexing into it.
 *
 * @param initialIds customers still eligible for matching
 * @param ids        ids of the customers to update
 * @param cls        transformation applied to each matched customer
 * @return the transformed customers (in `ids` order) followed by the unmatched ones
 */
def updateCustomerByIdList(initialIds: List[Customer],
                           ids: List[Integer],
                           cls: Customer => Customer): List[Customer] = {
  (initialIds, ids) match {
    case (List(), _) => initialIds
    case (_, List()) => initialIds
    // `id` is the first id, `tailIds` the remaining ones.
    case (_, id :: tailIds) => {
      val precust = initialIds.find(cust => cust.customer_id == id)
      val cust = if (precust.isEmpty) { List() } else { List(cls(precust.get)) }
      // Recurse on the customers that were NOT just handled (`!=`); keeping the
      // matched one would emit it again, unchanged, when the ids run out.
      cust ::: updateCustomerByIdList(
        initialIds.filter(cust => cust.customer_id != id),
        tailIds,
        cls
      )
    }
  }
}

上面的代码中,find 方法返回的是一个 Option 对象,可以将它转换为 List,继续进行 Extracting Lists

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
/**
 * Applies `cls` to the customers whose id appears in `ids`. The Option returned by
 * `find` is converted to a List so that the lookup result can be destructured with
 * list patterns as well.
 *
 * @param initialIds customers still eligible for matching
 * @param ids        ids of the customers to update
 * @param cls        transformation applied to each matched customer
 * @return the transformed customers (in `ids` order) followed by the unmatched ones
 */
def updateCustomerByIdList(initialIds: List[Customer],
                           ids: List[Integer],
                           cls: Customer => Customer): List[Customer] = {
  (initialIds, ids) match {
    case (List(), _) => initialIds
    case (_, List()) => initialIds
    case (_, id :: tailIds) => {
      // toList yields List() when nothing was found, or a one-element list otherwise.
      val precust = initialIds.find(cust => cust.customer_id == id).toList
      precust match {
        // No customer has this id: move on to the next id with the list untouched.
        case List() => updateCustomerByIdList(initialIds, tailIds, cls)
        // Found: emit the transformed customer, then continue with the customers
        // whose id differs, so the matched one is not emitted a second time.
        case cust :: _ =>
          cls(cust) :: updateCustomerByIdList(
            initialIds.filter(other => other.customer_id != id),
            tailIds,
            cls
          )
      }
    }
  }
}

None 和 Some 是 Option 的两个子类型:None 不包含任何值,Some 则包装了实际存在的值。因此可以对 find 返回的 Option 对象直接进行模式匹配,包含两种模式:case None 和 case Some(x)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
/**
 * Applies `cls` to the customers whose id appears in `ids`, matching directly on the
 * Option returned by `find` (None / Some) instead of converting it to a List.
 *
 * @param initialIds customers still eligible for matching
 * @param ids        ids of the customers to update
 * @param cls        transformation applied to each matched customer
 * @return the transformed customers (in `ids` order) followed by the unmatched ones
 */
def updateCustomerByIdList(initialIds: List[Customer],
                           ids: List[Integer],
                           cls: Customer => Customer): List[Customer] = {
  (initialIds, ids) match {
    case (List(), _) => initialIds
    case (_, List()) => initialIds
    case (_, id :: tailIds) => {
      val precust = initialIds.find(cust => cust.customer_id == id)
      precust match {
        // No customer has this id: move on to the next id with the list untouched.
        case None => updateCustomerByIdList(initialIds, tailIds, cls)
        // Found: emit the transformed customer, then continue with the customers
        // whose id differs, so the matched one is not emitted a second time.
        case Some(cust) =>
          cls(cust) :: updateCustomerByIdList(
            initialIds.filter(other => other.customer_id != id),
            tailIds,
            cls
          )
      }
    }
  }
}

可以让之前的 createCustomer 直接返回一个 Option 对象,更加函数化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/**
 * Validates the three fields and wraps the outcome in an Option.
 *
 * @param name   customer name; an empty string is rejected
 * @param state  customer state; an empty string is rejected
 * @param domain customer domain; an empty string is rejected
 * @return Some(customer) when every field is non-empty; otherwise None, after the
 *         message for the first blank field (name, then state, then domain) is printed
 */
def createCustomer(name: String,
                   state: String,
                   domain: String): Option[Customer] = {
  // Prints the validation message and produces the "no customer" result.
  def reject(reason: String): Option[Customer] = {
    println(reason)
    None
  }

  val fields = (name, state, domain)
  fields match {
    case ("", _, _) => reject("Name cannot be blank")
    case (_, "", _) => reject("State cannot be blank")
    case (_, _, "") => reject("Domain cannot be blank")
    case _ =>
      val customer = new Customer(
        0, name, state, domain, true, new Contract(Calendar.getInstance, true), List()
      )
      Some(customer)
  }
}

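用 case 关键字修饰类(case class)。编译器会为 case class 自动生成 unapply 方法,因此可以在 case 模式中直接按字段解构对象,例如 Customer(_, _, _, _, true, _, cont)。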

1
2
3
4
5
6
7
8
/**
 * Immutable customer record. Being a case class, it can be destructured
 * field-by-field in patterns, in the order the fields are declared here.
 *
 * @param customer_id identifier compared against id lists when looking customers up
 * @param name        customer name
 * @param state       customer state
 * @param domain      customer domain
 * @param enabled     enabled flag (the fifth field in extractor patterns)
 * @param contract    the customer's contract
 * @param contacts    the customer's contacts
 */
case class Customer(
  customer_id: Integer,
  name: String,
  state: String,
  domain: String,
  enabled: Boolean,
  contract: Contract,
  contacts: List[Contact]
)
1
2
3
4
5
6
7
8
9
10
11
12
/**
 * Recursively tallies the customers that are enabled and have at least one enabled
 * contact, destructuring the list with `head :: tail`.
 *
 * NOTE(review): the name says "NoEnabledContacts", but the condition counts customers
 * that DO have an enabled contact; confirm which of the two is intended.
 *
 * @param customers customers still to be inspected
 * @param sum       running total carried through the recursion
 * @return `sum` plus the number of qualifying customers in `customers`
 */
def countEnabledCustomersWithNoEnabledContacts(customers: List[Customer], sum: Integer): Integer = {
  customers match {
    // End of the list: the accumulated total is the answer.
    case Nil => sum

    case head :: rest =>
      val qualifies = head.enabled && head.contacts.exists(_.enabled)
      if (qualifies) countEnabledCustomersWithNoEnabledContacts(rest, sum + 1)
      else countEnabledCustomersWithNoEnabledContacts(rest, sum)
  }
}
1
2
3
4
5
6
7
8
9
10
11
/**
 * Recursively tallies the customers that are enabled and have at least one enabled
 * contact, using the Customer extractor plus a pattern guard instead of an if/else.
 *
 * NOTE(review): the name says "NoEnabledContacts", but the guard counts customers
 * that DO have an enabled contact; confirm which of the two is intended.
 *
 * @param customers customers still to be inspected
 * @param sum       running total carried through the recursion
 * @return `sum` plus the number of qualifying customers in `customers`
 */
def countEnabledCustomersWithNoEnabledContacts(customers: List[Customer], sum: Integer): Integer = {
  customers match {
    // End of the list: the accumulated total is the answer.
    case Nil => sum

    // Head is a customer whose `enabled` field is true and whose contacts
    // include an enabled one: count it.
    case Customer(_, _, _, _, true, _, contacts) :: rest if contacts.exists(_.enabled) =>
      countEnabledCustomersWithNoEnabledContacts(rest, sum + 1)

    // Any other head is skipped without changing the total.
    case _ :: rest =>
      countEnabledCustomersWithNoEnabledContacts(rest, sum)
  }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
/**
 * Recursively tallies the customers that are enabled and have at least one enabled
 * contact. An enabled customer with an empty contact list gets its own pattern and
 * is skipped before the guard is evaluated.
 *
 * NOTE(review): the name says "NoEnabledContacts", but the guard counts customers
 * that DO have an enabled contact; confirm which of the two is intended.
 *
 * @param customers customers still to be inspected
 * @param sum       running total carried through the recursion
 * @return `sum` plus the number of qualifying customers in `customers`
 */
def countEnabledCustomersWithNoEnabledContacts(customers: List[Customer], sum: Integer): Integer = {
  customers match {
    // End of the list: the accumulated total is the answer.
    case Nil => sum

    // Enabled customer without any contacts: not counted.
    case Customer(_, _, _, _, true, _, Nil) :: rest =>
      countEnabledCustomersWithNoEnabledContacts(rest, sum)

    // Enabled customer whose contacts include an enabled one: count it.
    case Customer(_, _, _, _, true, _, contacts) :: rest if contacts.exists(_.enabled) =>
      countEnabledCustomersWithNoEnabledContacts(rest, sum + 1)

    // Any other head is skipped without changing the total.
    case _ :: rest =>
      countEnabledCustomersWithNoEnabledContacts(rest, sum)
  }
}