Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -133,56 +133,50 @@ private[spark] object PythonWorkerUtils extends Logging {
// number of different broadcasts
val toRemove = oldBids.diff(newBids)
val addedBids = newBids.diff(oldBids)
val cnt = toRemove.size + addedBids.size
val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty
dataOut.writeBoolean(needsDecryptionServer)
dataOut.writeInt(cnt)
def sendBidsToRemove(): Unit = {
for (bid <- toRemove) {
// remove the broadcast from worker
dataOut.writeLong(-bid - 1) // bid >= 0
oldBids.remove(bid)
}
}
var connInfo: Option[Any] = None
var secret: Option[String] = None

if (needsDecryptionServer) {
// if there is encryption, we setup a server which reads the encrypted files, and sends
// the decrypted data to python
val idsAndFiles = broadcastVars.flatMap { broadcast =>
if (!oldBids.contains(broadcast.id)) {
oldBids.add(broadcast.id)
Some((broadcast.id, broadcast.value.path))
} else {
None
}
val idsAndFiles = broadcastVars.filter(b => !oldBids.contains(b.id)).map { broadcast =>
(broadcast.id, broadcast.value.path)
}
val server = new EncryptedPythonBroadcastServer(env, idsAndFiles)
server.connInfo match {
case portNum: Int =>
dataOut.writeInt(portNum)
writeUTF(server.secret, dataOut)
connInfo = Some(portNum)
secret = Some(server.secret)
case sockPath: String =>
dataOut.writeInt(-1)
writeUTF(sockPath, dataOut)
connInfo = Some(sockPath)
}
logTrace(s"broadcast decryption server setup on ${server.connInfo}")
sendBidsToRemove()
idsAndFiles.foreach { case (id, _) =>
// send new broadcast
dataOut.writeLong(id)
}
dataOut.flush()
} else {
sendBidsToRemove()
for (broadcast <- broadcastVars) {
if (!oldBids.contains(broadcast.id)) {
// send new broadcast
dataOut.writeLong(broadcast.id)
writeUTF(broadcast.value.path, dataOut)
oldBids.add(broadcast.id)
}
}
}
dataOut.flush()

val json = Serialization.write(Map(
"broadcast_decryption_server" -> needsDecryptionServer,
"conn_info" -> connInfo.orNull,
"auth_secret" -> secret.orNull,
"broadcast_variables" -> (
broadcastVars.filter(b => !oldBids.contains(b.id)).map { broadcast =>
Map(
"bid" -> broadcast.id,
"path" -> broadcast.value.path
)
} ++ toRemove.map { bid =>
Map(
"bid" -> (-bid - 1),
"path" -> null
)
}
)
))

oldBids.clear()
oldBids ++= newBids

writeUTF(json, dataOut)
}

/**
Expand Down
39 changes: 21 additions & 18 deletions python/pyspark/worker_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"""
import importlib
from inspect import currentframe, getframeinfo
import json
import os
import sys
from typing import Any, IO, Optional
Expand All @@ -42,7 +43,6 @@
from pyspark.errors import PySparkRuntimeError
from pyspark.util import local_connect_and_auth
from pyspark.serializers import (
read_bool,
read_int,
read_long,
write_int,
Expand Down Expand Up @@ -155,39 +155,42 @@ def setup_spark_files(infile: IO) -> None:
def setup_broadcasts(infile: IO) -> None:
"""
Set up broadcasted variables.
{
"conn_info": int | str | None,
"auth_secret": str | None,
"broadcast_variables": [
{
"bid": int,
"path": str | None,
}
]
}
"""
if not is_remote_only():
from pyspark.core.broadcast import Broadcast, _broadcastRegistry

# fetch names and values of broadcast variables
needs_broadcast_decryption_server = read_bool(infile)
num_broadcast_variables = read_int(infile)
if needs_broadcast_decryption_server:
data = json.loads(utf8_deserializer.loads(infile))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would we use this info for debugging? I feel like you can just log it here instead of JSON ser/de

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not quite get the question. We are not using it for debugging. This is part of the protocol. We need this from JVM.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you mind describing how we're going to use this in the PR description?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, use what? This is part of the JVM <-> Python worker protocol. We are not adding any new features. The JVM used to send broadcast variable information to Python integer by integer (with some strings in the middle). Now, instead of that raw, fragile protocol, we send all of the broadcast variable information in a single JSON string.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am trying to understand why we need this. Is this to purely make the protocol more stable?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If that's the case, I am not super supportive of this change. This could impact jobs like Structured Streaming (with micro-batches) or ML jobs that disable spark.python.worker.reuse (which happens often in practice, to work around problems with long-lived daemon workers). Considering the overhead vs. the benefit, I would prefer to just leave it.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Performance impact is a big red herring. This change introduced two kinds of "overhead":

  • CPU time to encode/decode json
  • Extra bytes through the JVM/worker network (it's on the same machine)

Decoding a small JSON string takes about 1us. It's probably in the same range on the Scala side. The local network runs at 10Gbps or more, so an extra 100 bytes takes about 0.1us.

That's the overhead we introduce for every UDF run.

Currently, without worker reuse, each worker takes a few hundred ms to spawn. I made an optimization a few weeks ago that eliminated 100-200ms per spawn for reused workers, and no one even noticed it.

1us is 0.001% of 100ms. That's literally nothing. If we care about 1us, we have serious issues with our current UDF. I can recover many such 1us savings from our current code if that's what we need in order to make our protocol more stable.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My point is that the benefit is not super compelling. If we plan to refactor the whole protocol or something similar, then yes, we should probably go ahead. But doing this alone doesn't look worthwhile to me.

If we do want to refactor the whole protocol, we had better have a bigger picture first.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is part of the effort to refactor the whole protocol. I'm doing it piece by piece so that eventually we can have a structured message from the JVM, ideally a multi-phase message. All of the initialization (and probably the UDF definition) should be sent in a single message. The message should be relatively resistant to new changes. For example, it won't get stuck if we decide to add something new, and it should report a clear error when the message does not follow the protocol.

JSON is good in the sense that, if we decide to add something else to the broadcast variable protocol, it's easy. If we get it wrong, we can find out quickly too. Otherwise we have to be really careful about where to insert the read_long, and if we do not do it correctly, the worker could get stuck at an arbitrary point.

This refactor actually eliminated some unnecessary fields. Our old protocol is so fragile that no one is willing to touch it. I want to gradually convert it to a more structured form. Switching everything over in a single change is a bit too dangerous.


broadcast_sock_file = None
if data["broadcast_decryption_server"]:
# read the decrypted data from a server in the jvm
conn_info = read_int(infile)
auth_secret = None
if conn_info == -1:
conn_info = utf8_deserializer.loads(infile)
else:
auth_secret = utf8_deserializer.loads(infile)
(broadcast_sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
(broadcast_sock_file, _) = local_connect_and_auth(data["conn_info"], data["auth_secret"])

for _ in range(num_broadcast_variables):
bid = read_long(infile)
for broadcast_variable in data["broadcast_variables"]:
bid = broadcast_variable["bid"]
if bid >= 0:
if needs_broadcast_decryption_server:
if broadcast_sock_file is not None:
read_bid = read_long(broadcast_sock_file)
assert read_bid == bid
_broadcastRegistry[bid] = Broadcast(sock_file=broadcast_sock_file)
else:
path = utf8_deserializer.loads(infile)
_broadcastRegistry[bid] = Broadcast(path=path)
_broadcastRegistry[bid] = Broadcast(path=broadcast_variable["path"])

else:
bid = -bid - 1
_broadcastRegistry.pop(bid)

if needs_broadcast_decryption_server:
if broadcast_sock_file is not None:
broadcast_sock_file.write(b"1")
broadcast_sock_file.close()

Expand Down