Spark DAGScheduler Module Source Code Analysis

2015-01-25 by guozhongxin

Background on Spark's DAGScheduler

When a Spark application reaches an action operator, the SparkContext generates a Job and hands the resulting DAG to the DAGScheduler, which splits it into Stages.

Stage

A Stage is Spark's partitioning of the DAG, and it is the unit on which a job is divided into tasks and scheduled.
The intuition is that everything inside a Stage needs no shuffle and can therefore run with arbitrary parallelism; a stage boundary is exactly where a shuffle is required (a concrete code sketch follows the Stage list below).

The figure below shows an example of how a job is divided into stages.

There are two kinds of Stage:

  • ShuffleMapStage
    This kind of Stage has a shuffle as its output boundary. Its input boundary can be data read from external storage or the output of another ShuffleMapStage, and its output becomes the input of a downstream Stage. The last tasks of a ShuffleMapStage are ShuffleMapTasks. A Job may contain any number of Stages of this kind, including none.
    In the figure above, Stage 1 and Stage 2 are both ShuffleMapStages.
  • ResultStage
    This kind of Stage outputs the final result directly. Its input boundary can be data read from external storage or the output of a ShuffleMapStage. The last tasks of a ResultStage are ResultTasks. A Job consists of one or more Stages and always ends in exactly one ResultStage.
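To make the boundary concrete, here is a minimal sketch of a job whose DAG splits at a shuffle (the input path and a SparkContext sc are assumed; in Spark versions of this era the pair-RDD operations also need the SparkContext._ import):

import org.apache.spark.SparkContext._           // pair-RDD operations (pre-1.3 style)

val words = sc.textFile("hdfs:///some/input")    // illustrative input path
  .flatMap(_.split(" "))
  .map(w => (w, 1))                              // narrow dependencies: same stage
val counts = words.reduceByKey(_ + _)            // ShuffleDependency: new stage begins

println(counts.toDebugString)                    // prints the lineage with its shuffle boundary
counts.collect()                                 // action: SparkContext submits the job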

DAGScheduler

The main responsibilities of DAGScheduler are:

  • receive jobs submitted by the user;
  • split the job into stages at shuffle boundaries, record which RDDs and Stage outputs have been materialized, generate each stage's tasks, and wrap them into a TaskSet;
  • determine each Task's preferred location (so that tasks run on the nodes where their data lives), taking the current cache state into account, and submit the TaskSet to the TaskScheduler (see the locality sketch after the note below);
  • resubmit stages whose shuffle output has been lost to the TaskScheduler.

Note: failures inside a Stage that are not caused by loss of shuffle output are none of DAGScheduler's business; the TaskScheduler is responsible for retrying the failed tasks.
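As a hedged illustration of the locality information DAGScheduler consults (the path is illustrative and a SparkContext sc is assumed): every RDD reports the preferred hosts of each of its partitions, and for HDFS-backed RDDs these are the hosts holding the block replicas.

val lines = sc.textFile("hdfs:///some/path")   // illustrative input
val part0 = lines.partitions(0)
// Hosts where this partition's data already lives; DAGScheduler prefers
// to run the corresponding task on one of them
println(lines.preferredLocations(part0))       // e.g. List(host1, host2)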

Spark DAGScheduler Source Code Analysis

The DAGScheduler is created when the user constructs a new SparkContext. (Note that inside SparkContext the TaskScheduler is created before the DAGScheduler, so by the time dagScheduler = new DAGScheduler(this) runs, this.taskScheduler already exists; that taskScheduler also becomes a member of the dagScheduler.)

@volatile private[spark] var dagScheduler: DAGScheduler = _
try {
  dagScheduler = new DAGScheduler(this)
} catch {
  case e: Exception => throw
    new SparkException("DAGScheduler cannot be initialized due to %s".format(e.getMessage))
}

When an action operator executes, Spark calls sc.runJob(). For example, count() as defined in RDD.scala:

def count(): Long = sc.runJob(this, Utils.getIteratorSize _).sum
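runJob() runs a function over each partition and collects one result per partition; count() simply sums the per-partition sizes. A minimal sketch of the same pattern, calling sc.runJob directly (assumes a SparkContext sc):

val rdd = sc.parallelize(1 to 10, 3)             // 3 partitions
// One Long per partition, each computed by a task of the job
val perPartition: Array[Long] = sc.runJob(rdd, (it: Iterator[Int]) => it.size.toLong)
println(perPartition.sum)                        // 10, same as rdd.count()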

Stepping into the runJob() method in SparkContext.scala:

def runJob[T, U: ClassTag](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    allowLocal: Boolean,
    resultHandler: (Int, U) => Unit) {
  if (dagScheduler == null) {
    throw new SparkException("SparkContext has been shutdown")
  }
  val callSite = getCallSite
  val cleanedFunc = clean(func)
  logInfo("Starting job: " + callSite.shortForm)
  val start = System.nanoTime
  dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
    resultHandler, localProperties.get)
  logInfo(
    "Job finished: " + callSite.shortForm + ", took " + (System.nanoTime - start) / 1e9 + " s")
  rdd.doCheckpoint()
}

sc.runJob() delegates to dagScheduler.runJob(). Stepping into DAGScheduler.runJob():

def runJob[T, U: ClassTag](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    callSite: CallSite,
    allowLocal: Boolean,
    resultHandler: (Int, U) => Unit,
    properties: Properties = null)
{
  val start = System.nanoTime
  val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties)
  waiter.awaitResult() match {
    case JobSucceeded => {
      logInfo("Job %d finished: %s, took %f s".format
        (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
    }
    case JobFailed(exception: Exception) =>
      logInfo("Job %d failed: %s, took %f s".format
        (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
      throw exception
  }
}

When the job is submitted successfully, submitJob() returns a JobWaiter and posts a JobSubmitted event:

val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
eventProcessActor ! JobSubmitted(
  jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
waiter
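runJob() then blocks on waiter.awaitResult() until the whole job finishes. As a much-simplified, hypothetical sketch of what such a waiter does (the real class is org.apache.spark.scheduler.JobWaiter; this toy version only mirrors the idea):

import java.util.concurrent.CountDownLatch

// Toy stand-in for JobWaiter: counts finished tasks and lets the caller block.
class ToyJobWaiter(totalTasks: Int) {
  private val latch = new CountDownLatch(totalTasks)
  @volatile private var failure: Option[Exception] = None

  def taskSucceeded(): Unit = latch.countDown()    // called once per finished task
  def jobFailed(e: Exception): Unit = {
    failure = Some(e)
    while (latch.getCount > 0) latch.countDown()   // unblock the waiting caller
  }
  def awaitResult(): Either[Unit, Exception] = {
    latch.await()                                  // block until success or failure
    failure match {
      case Some(e) => Right(e)                     // job failed
      case None    => Left(())                     // job succeeded
    }
  }
}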

DAGScheduler works as a producer-consumer model. When the dagScheduler instance is created in SparkContext, it spawns an event-processing actor that responds to the various events generated while DAGScheduler works.

private def initializeEventProcessActor() {
  // blocking the thread until supervisor is started, which ensures eventProcessActor is
  // not null before any job is submitted
  implicit val timeout = Timeout(30 seconds)
  val initEventActorReply =
    dagSchedulerActorSupervisor ? Props(new DAGSchedulerEventProcessActor(this))
  eventProcessActor = Await.result(initEventActorReply, timeout.duration).
    asInstanceOf[ActorRef]
}
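The actor is simply the consumer end of a producer-consumer queue: callers post events, and a single thread dispatches them in order. Below is a hypothetical sketch of the same pattern built on a plain blocking queue; all names are illustrative (later Spark versions did in fact replace the actor with such an event loop):

import java.util.concurrent.LinkedBlockingQueue

sealed trait ToyEvent
case class ToyJobSubmitted(jobId: Int) extends ToyEvent
case object ToyShutdown extends ToyEvent

class ToyEventLoop(handle: ToyEvent => Unit) {
  private val queue = new LinkedBlockingQueue[ToyEvent]()
  private val thread = new Thread("toy-dag-event-loop") {
    override def run(): Unit = {
      var running = true
      while (running) {
        queue.take() match {                 // consumer: blocks until an event arrives
          case ToyShutdown => running = false
          case e           => handle(e)      // dispatch, e.g. to a handleJobSubmitted
        }
      }
    }
  }
  thread.setDaemon(true)
  def start(): Unit = thread.start()
  def post(e: ToyEvent): Unit = queue.put(e) // producer side
}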

The DAGSchedulerEventProcessActor class is defined in DAGScheduler.scala. It receives the events generated while DAGScheduler works and handles each one by calling the corresponding method on the dagScheduler passed in. The events it handles are:

  • JobSubmitted
  • StageCancelled
  • JobCancelled
  • JobGroupCancelled
  • AllJobsCancelled
  • ExecutorAdded
  • ExecutorLost
  • BeginEvent
  • GettingResultEvent
  • CompletionEvent
  • ResubmitFailedStages

Take the JobSubmitted event as an example:

case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
  dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite,
    listener, properties)

dagScheduler.handleJobSubmitted resolves the dependency chain of the finalRDD it receives and generates the stages, i.e., the structure of the whole DAG; it then calls further methods that package each stage's tasks into a TaskSet and hand it to the taskScheduler. Following this method, handleJobSubmitted, is the easiest way to understand DAGScheduler's main functionality and how it is implemented.

01 private[scheduler] def handleJobSubmitted(jobId: Int,
02     finalRDD: RDD[_],
03     func: (TaskContext, Iterator[_]) => _,
04     partitions: Array[Int],
05     allowLocal: Boolean,
06     callSite: CallSite,
07     listener: JobListener,
08     properties: Properties = null)
09 {
10   var finalStage: Stage = null
11   try {
12     // New stage creation may throw an exception if, for example, jobs are run on a
13     // HadoopRDD whose underlying HDFS files have been deleted.
14     finalStage = newStage(finalRDD, partitions.size, None, jobId, callSite)
15   } catch {
16     case e: Exception =>
17       logWarning("Creating new stage failed due to exception - job: " + jobId, e)
18       listener.jobFailed(e)
19       return
20   }
21   if (finalStage != null) {
22     val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
23     clearCacheLocs()
24     logInfo("Got job %s (%s) with %d output partitions (allowLocal=%s)".format(
25       job.jobId, callSite.shortForm, partitions.length, allowLocal))
26     logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")")
27     logInfo("Parents of final stage: " + finalStage.parents)
28     logInfo("Missing parents: " + getMissingParentStages(finalStage))
29     val shouldRunLocally =
30       localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1
31     if (shouldRunLocally) {
32       // Compute very short actions like first() or take() with no parent stages locally.
33       listenerBus.post(SparkListenerJobStart(job.jobId, Seq.empty, properties))
34       runLocally(job)
35     } else {
36       jobIdToActiveJob(jobId) = job
37       activeJobs += job
38       finalStage.resultOfJob = Some(job)
39       val stageIds = jobIdToStageIds(jobId).toArray
40       val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
41       listenerBus.post(SparkListenerJobStart(job.jobId, stageInfos, properties))
42       submitStage(finalStage)
43     }
44   }
45   submitWaitingStages()
46 }

As this shows, DAGScheduler builds the stages by working backwards from the last RDD. (That RDD is passed in through the call chain sc.runJob() -> dagScheduler.runJob() -> dagScheduler.submitJob() -> JobSubmitted -> dagScheduler.handleJobSubmitted().)

This line,

finalStage = newStage(finalRDD, partitions.size, None, jobId, callSite)

creates finalStage by calling newStage(). Internally, newStage() calls getParentStages(), which traces backwards from finalRDD to build the parent stages.

private def getParentStages(rdd: RDD[_], jobId: Int): List[Stage] = {
  val parents = new HashSet[Stage]
  val visited = new HashSet[RDD[_]]
  // We are manually maintaining a stack here to prevent StackOverflowError
  // caused by recursively visiting
  val waitingForVisit = new Stack[RDD[_]]
  def visit(r: RDD[_]) {
    if (!visited(r)) {
      visited += r
      // Kind of ugly: need to register RDDs with the cache here since
      // we can't do it in its constructor because # of partitions is unknown
      for (dep <- r.dependencies) {
        dep match {
          case shufDep: ShuffleDependency[_, _, _] =>
            parents += getShuffleMapStage(shufDep, jobId)
          case _ =>
            waitingForVisit.push(dep.rdd)
        }
      }
    }
  }
  waitingForVisit.push(rdd)
  while (!waitingForVisit.isEmpty) {
    visit(waitingForVisit.pop())
  }
  parents.toList
}
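The match on dep is where the stage-boundary decision is made: a ShuffleDependency closes off a parent stage via getShuffleMapStage(), while any other (narrow) dependency just continues the backwards walk. A small, hedged illustration of what those dependencies look like on concrete RDDs (assumes a SparkContext sc; older Spark also needs the SparkContext._ import for pair operations):

import org.apache.spark.SparkContext._

val pairs   = sc.parallelize(Seq(("a", 1), ("b", 2)))
val mapped  = pairs.mapValues(_ + 1)      // narrow dependency: stays in the same stage
val reduced = mapped.reduceByKey(_ + _)   // ShuffleDependency: forms a stage boundary

println(mapped.dependencies)    // e.g. List(org.apache.spark.OneToOneDependency@...)
println(reduced.dependencies)   // e.g. List(org.apache.spark.ShuffleDependency@...)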

Back in handleJobSubmitted(), look at lines 27 and 28: "Parents of final stage: " prints what getParentStages() built, while "Missing parents: " comes from getMissingParentStages(). Here in handleJobSubmitted() there is no real difference between the two, but at other call sites the two functions do give different results.

01 private def getMissingParentStages(stage: Stage): List[Stage] = {
02  val missing = new HashSet[Stage]
03  val visited = new HashSet[RDD[_]]
04  // We are manually maintaining a stack here to prevent StackOverflowError
05  // caused by recursively visiting
06  val waitingForVisit = new Stack[RDD[_]]
07  def visit(rdd: RDD[_]) {
08    if (!visited(rdd)) {
09      visited += rdd
10      if (getCacheLocs(rdd).contains(Nil)) {
11        for (dep <- rdd.dependencies) {
12          dep match {
13            case shufDep: ShuffleDependency[_, _, _] =>
14              val mapStage = getShuffleMapStage(shufDep, stage.jobId)
15              if (!mapStage.isAvailable) {
16                missing += mapStage
17              }
18            case narrowDep: NarrowDependency[_] =>
19              waitingForVisit.push(narrowDep.rdd)
20          }
21        }
22      }
23    }
24  }
25  waitingForVisit.push(stage.rdd)
26  while (!waitingForVisit.isEmpty) {
27    visit(waitingForVisit.pop())
28  }
29  missing.toList
30 }

As the code above shows, getMissingParentStages() differs from getParentStages() at lines 15 and 16: a parent ShuffleMapStage counts as missing only if it is not yet available, i.e., its shuffle output has not been fully computed (and line 10 additionally skips subtrees whose RDD is fully cached).

Back at lines 41 and 42 of handleJobSubmitted(): DAGScheduler posts a JobStart event to the listener bus, then calls submitStage() to submit the generated finalStage.

/** Submits stage, but first recursively submits any missing parents. */
private def submitStage(stage: Stage) {
  val jobId = activeJobForStage(stage)
  if (jobId.isDefined) {
    logDebug("submitStage(" + stage + ")")
    if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
      val missing = getMissingParentStages(stage).sortBy(_.id)
      logDebug("missing: " + missing)
      if (missing == Nil) {
        logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
        submitMissingTasks(stage, jobId.get)
      } else {
        for (parent <- missing) {
          submitStage(parent)
        }
        waitingStages += stage
      }
    }
  } else {
    abortStage(stage, "No active job for stage " + stage.id)
  }
}

submitStage() recurses: if any parent stage is missing, the parents are submitted first and the current stage parks in waitingStages until their shuffle output has been registered (submitWaitingStages() then picks it up again). Once a stage has no missing parents, submitMissingTasks() splits it into tasks, packages them into a TaskSet, and hands that to the TaskScheduler.

/** Called when stage's parents are available and we can now do its task. */
private def submitMissingTasks(stage: Stage, jobId: Int) {
  logDebug("submitMissingTasks(" + stage + ")")
  // Get our pending tasks and remember them in our pendingTasks entry
  stage.pendingTasks.clear()

  // ...

  val tasks: Seq[Task[_]] = if (stage.isShuffleMap) {
    partitionsToCompute.map { id =>
      val locs = getPreferredLocs(stage.rdd, id)
      val part = stage.rdd.partitions(id)
      new ShuffleMapTask(stage.id, taskBinary, part, locs)
    }
  } else {
    val job = stage.resultOfJob.get
    partitionsToCompute.map { id =>
      val p: Int = job.partitions(id)
      val part = stage.rdd.partitions(p)
      val locs = getPreferredLocs(stage.rdd, p)
      new ResultTask(stage.id, taskBinary, part, locs, id)
    }
  }

  if (tasks.size > 0) {
    // Preemptively serialize a task to make sure it can be serialized.
    try {
      closureSerializer.serialize(tasks.head)
    } catch {
      case e: NotSerializableException =>
        abortStage(stage, "Task not serializable: " + e.toString)
        runningStages -= stage
        return
      case NonFatal(e) => // Other exceptions, such as IllegalArgumentException from Kryo.
        abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}")
        runningStages -= stage
        return
    }

    logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
    stage.pendingTasks ++= tasks
    logDebug("New pending tasks: " + stage.pendingTasks)
    taskScheduler.submitTasks(
      new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
    stage.latestInfo.submissionTime = Some(clock.getTime())
  } else {
    // Because we posted SparkListenerStageSubmitted earlier, we should post
    // SparkListenerStageCompleted here in case there are no tasks to run.
    listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
    logDebug("Stage " + stage + " is actually done; %b %d %d".format(
      stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
    runningStages -= stage
  }
}


From here on, the work is handed off to the TaskScheduler.

I'll tidy up the rest when I have time.

