Starting with Ray 2.3.0, Ray can run on top of an Apache Spark cluster, so distributed machine-learning training jobs that previously relied on Spark can be moved over to Ray instead.
%pip install "ray[default]>=2.3.0"
ray.util.spark.setup_ray_cluster
This API creates a Ray cluster on top of the Spark cluster.
from ray.util.spark import setup_ray_cluster, shutdown_ray_cluster
setup_ray_cluster(
    num_worker_nodes=2,
    num_cpus_per_node=4,
    collect_log_to_path="/dbfs/path/to/ray_collected_logs",
)
With the cluster running, ordinary Ray code works unchanged; the classic Monte Carlo estimate of π makes a quick end-to-end test:

import ray
import random
import time
from fractions import Fraction

# Connects to the Ray cluster created by setup_ray_cluster above.
ray.init()

@ray.remote
def pi4_sample(sample_count):
    """pi4_sample runs sample_count experiments, and returns the
    fraction of time it was inside the circle.
    """
    in_count = 0
    for i in range(sample_count):
        x = random.random()
        y = random.random()
        # Count points that land inside the unit quarter-circle.
        if x * x + y * y <= 1:
            in_count += 1
    return Fraction(in_count, sample_count)

SAMPLE_COUNT = 1000 * 1000
start = time.time()
future = pi4_sample.remote(sample_count=SAMPLE_COUNT)
pi4 = ray.get(future)
end = time.time()
dur = end - start
print(f'Running {SAMPLE_COUNT} tests took {dur} seconds')
pi = pi4 * 4
print(float(pi))
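Because each call to pi4_sample is independent, the same task fans out across the Spark-hosted workers with no code changes. A minimal sketch, reusing pi4_sample and SAMPLE_COUNT from the example above (the batch count of 100 is arbitrary, and ray.cluster_resources() is shown only to confirm the workers registered):

# Confirm the worker nodes registered with Ray.
print(ray.cluster_resources())

# Fan the sampling task out across the cluster and average
# the per-batch estimates of pi/4.
BATCHES = 100
futures = [pi4_sample.remote(sample_count=SAMPLE_COUNT) for _ in range(BATCHES)]
pi4 = sum(ray.get(futures)) / BATCHES
print(float(pi4 * 4))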
Next, read the data through a Spark DataFrame and convert it into a Ray Dataset. The helper below writes the DataFrame to DBFS as Parquet, then reads it back into a Ray Dataset through the /dbfs FUSE mount.
import ray
from urllib.parse import urlparse

def create_ray_dataset_from_spark_dataframe(spark_dataframe, dbfs_tmp_path):
    # Materialize the Spark DataFrame as Parquet files on DBFS.
    spark_dataframe.write.mode('overwrite').parquet(dbfs_tmp_path)
    # Re-map the dbfs:/ URI to the /dbfs FUSE mount so Ray can read it.
    fuse_path = "/dbfs" + urlparse(dbfs_tmp_path).path
    return ray.data.read_parquet(fuse_path)
# For example, read a Delta table as a Spark DataFrame
spark_df = spark.read.table("diviner_demo.diviner_pedestrians_data_500")

# Provide a DBFS location to write the table to
data_location_2 = (
    "dbfs:/home/example.user@databricks.com/data/ray_test/test_data_2"
)

# Convert the Spark DataFrame to a Ray dataset
ray_dataset = create_ray_dataset_from_spark_dataframe(
    spark_dataframe=spark_df,
    dbfs_tmp_path=data_location_2,
)
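The returned object is an ordinary Ray Dataset, so the standard Ray Data operations apply to it; for example:

# Basic sanity checks on the converted dataset.
print(ray_dataset.count())  # number of rows read from the Parquet files
ray_dataset.show(5)         # print the first five rows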
ray.util.spark.shutdown_ray_cluster
This API shuts down the Ray cluster running on Spark and returns its resources to Spark.
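Call it once the Ray workload is finished, mirroring the setup call above:

from ray.util.spark import shutdown_ray_cluster

# Tear down the Ray-on-Spark cluster and release the CPU slots
# its worker nodes were holding.
shutdown_ray_cluster()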
API Reference:
ray.util.spark.setup_ray_cluster
ray.util.spark.shutdown_ray_cluster