Skip to content

Commit 40be6cc

Browse files
vrkhareJonathan Esterhazy
authored and
Jonathan Esterhazy
committed
add Traveling Salesman Problem RL example
1 parent e443287 commit 40be6cc

33 files changed

+4596
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Traveling Salesman and Vehicle Routing with Amazon SageMaker RL
2+
3+
The travelling salesman problem (TSP) is a classic algorithmic problem in the field of computer science and operations research. Given a list of cities and the distances between each pair of cities, the problem is to find the shortest possible route that visits each city and returns to the origin city.
4+
5+
The problem is NP-hard: the number of possible routes grows factorially with the number of cities, so exhaustively checking every route quickly becomes infeasible as more cities are added.
6+
7+
In the classic version of the problem, the salesman picks a city to start, travels through remaining cities and returns to the original city.
8+
9+
In this version, we have slightly modified the problem, presenting it as a restaurant order delivery problem on a 2D gridworld. The agent (driver) starts at the restaurant, a fixed point on the grid. Then, delivery orders appear elsewhere on the grid. The driver needs to visit the orders, and return to the restaurant, to obtain rewards. Rewards are proportional to the time taken to do this (equivalent to the distance, as each timestep moves one square on the grid).
10+
11+
Vehicle Routing is a similar problem where the algorithm optimizes the movement of a fleet of vehicles. In our formulation, we have a delivery driver who accepts orders from customers, picks up food from a restaurant and delivers it to the customer. The driver optimizes to increase the number of successful deliveries within a time limit.
12+
13+
## Contents
14+
15+
* `rl_traveling_salesman_vehicle_routing_coach`: notebook used for training traveling salesman and vehicle routing policies.
16+
* `src/`
17+
* `TSP_env.py`: traveling salesman problem is defined here.
18+
* `TSP_view_2D.py`: visualizer for the traveling salesman problem.
19+
* `TSP_baseline.py`: baseline implementation of traveling salesman.
20+
* `TSP_baseline_util.py`: helper file for baseline implementation.
21+
* `VRP_env.py`: vehicle routing problem is defined here.
22+
* `VRP_abstract_env.py`: defines an easier version of vehicle routing problem where the driver knows the path to go from one place to another.
23+
* `VRP_view_2D.py`: visualizer for the vehicle routing problem.
24+
* `VRP_baseline.py`: baseline implementation of vehicle routing.
25+
* `VRP_baseline_util.py`: helper file for baseline implementation.
26+
* `train-coach.py`: launcher for coach training.
27+
* `evaluate-coach.py`: launcher for coach evaluation.
28+
* `preset-tsp-easy.py`: coach preset for Clipped PPO for the easy version of TSP.
29+
* `preset-tsp-medium.py`: coach preset for Clipped PPO for the medium version of TSP.
30+
* `preset-vrp-easy.py`: coach preset for Clipped PPO for the easy version of VRP.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
2+
{
3+
"default-runtime": "nvidia",
4+
"runtimes": {
5+
"nvidia": {
6+
"path": "/usr/bin/nvidia-container-runtime",
7+
"runtimeArgs": []
8+
}
9+
}
10+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
from __future__ import absolute_import
15+
16+
import base64
17+
import contextlib
18+
import os
19+
import time
20+
import shlex
21+
import shutil
22+
import subprocess
23+
import sys
24+
import tempfile
25+
26+
import boto3
27+
import json
28+
29+
# Format string for a fully qualified ECR image URI; expects the keyword
# fields account, region, image_name and version.
# NOTE(review): not referenced by the functions visible in this file -- confirm
# whether other modules import it before removing.
IMAGE_TEMPLATE = "{account}.dkr.ecr.{region}.amazonaws.com/{image_name}:{version}"
30+
31+
32+
def build_and_push_docker_image(repository_name, dockerfile='Dockerfile', build_args={}):
    """Build a docker image from a Dockerfile and publish it to ECR.

    The steps are: look up the Dockerfile's base image, log in to ECR if that
    base image has to be pulled from ECR, run the docker build, and push the
    result (push() creates the ECR repository when it is missing).

    Args:
        repository_name (str): name given to the local image; also used as
            the tag handed to push().
        dockerfile (str): path of the Dockerfile to build from.
        build_args (dict): key/value pairs forwarded to docker as
            ``--build-arg`` options. The default dict is only read, never
            modified.

    Returns:
        str: the name of the created docker image in ECR.
    """
    # The login has to happen before the build so docker can pull the base image.
    _ecr_login_if_needed(_find_base_image_in_dockerfile(dockerfile))
    _build_from_dockerfile(repository_name, dockerfile, build_args)
    return push(repository_name)
43+
44+
45+
def _build_from_dockerfile(repository_name, dockerfile='Dockerfile', build_args={}):
    """Run ``docker build`` in the current directory.

    Args:
        repository_name (str): tag (``-t``) given to the built image.
        dockerfile (str): Dockerfile path passed with ``-f``.
        build_args (dict): each item becomes a ``--build-arg key=value`` pair
            appended after the build context argument.

    Raises:
        RuntimeError: propagated from _execute() when docker exits non-zero.
    """
    # Flatten {"K": "V"} into ['--build-arg', 'K=V', ...], keeping dict order.
    extra_flags = [
        token
        for key, value in build_args.items()
        for token in ('--build-arg', '%s=%s' % (key, value))
    ]
    docker_cmd = ['docker', 'build', '-t', repository_name, '-f', dockerfile, '.'] + extra_flags

    print("Building docker image %s from %s" % (repository_name, dockerfile))
    _execute(docker_cmd)
    print("Done building docker image %s" % repository_name)
53+
54+
55+
def _find_base_image_in_dockerfile(dockerfile):
    """Return the base image named by the first FROM line of a Dockerfile.

    Only lines that start with the literal ``"FROM "`` are considered.
    Option flags (tokens starting with ``--``, e.g. ``--platform=...``) and a
    trailing ``AS <stage>`` alias are ignored, so the result is just the image
    reference and can be passed to docker commands as a single argument.

    Args:
        dockerfile (str): path of the Dockerfile to inspect.

    Returns:
        str: the image reference, or an empty string when the FROM line names
        no image.

    Raises:
        IndexError: if the file contains no line starting with ``"FROM "``.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(dockerfile) as handle:
        for line in handle:
            if not line.startswith("FROM "):
                continue
            # Drop the FROM keyword, then skip option flags; the first
            # remaining token is the image, anything after it is "AS <name>".
            candidates = [tok for tok in line.split()[1:] if not tok.startswith("--")]
            return candidates[0] if candidates else ""

    # Keep the historical exception type, but say what actually went wrong.
    raise IndexError("No FROM line found in %s" % dockerfile)
60+
61+
62+
def push(tag, aws_account=None, aws_region=None):
    """
    Push a locally built image tag to ECR.

    Creates the ECR repository if it does not exist yet, logs docker in to
    the registry, then tags and pushes the image.

    Args:
        tag (string): tag which you named your algo, optionally in
            ``name:version`` form
        aws_account (string): aws account of the ECR repo; defaults to the
            account of the current credentials (looked up through STS)
        aws_region (string): aws region where the repo is located; defaults
            to the region of the default boto3 session

    Returns:
        (string): ECR repo image that was pushed
    """
    session = boto3.Session()
    if not aws_account:
        aws_account = session.client("sts").get_caller_identity()['Account']
    if not aws_region:
        aws_region = session.region_name

    # "name:version" -> repository is "name"; any other shape (no colon, or
    # more than one) uses the whole tag as the repository name.
    pieces = tag.split(':')
    repository_name = pieces[0] if len(pieces) == 2 else tag

    ecr_client = session.client('ecr', region_name=aws_region)
    _create_ecr_repo(ecr_client, repository_name)
    _ecr_login(ecr_client, aws_account)

    return _push(aws_account, aws_region, tag)
89+
90+
91+
def _push(aws_account, aws_region, tag):
    """Tag a local image with its ECR name and push it.

    Args:
        aws_account (str): account id that forms the registry host name.
        aws_region (str): region that forms the registry host name.
        tag (str): local image name, reused as the path under the registry.

    Returns:
        str: the full ``<registry>/<tag>`` name that was pushed.

    Raises:
        RuntimeError: propagated from _execute() when a docker command fails.
    """
    registry_host = '%s.dkr.ecr.%s.amazonaws.com' % (aws_account, aws_region)
    remote_name = '%s/%s' % (registry_host, tag)

    # docker can only push an image whose name carries the registry host.
    _execute(['docker', 'tag', tag, remote_name])

    print("Pushing docker image to ECR repository %s/%s\n" % (registry_host, tag))
    _execute(['docker', 'push', remote_name])
    print("Done pushing %s" % remote_name)

    return remote_name
99+
100+
101+
def _create_ecr_repo(ecr_client, repository_name):
    """
    Make sure an ECR repository exists, creating it when necessary.

    Args:
        ecr_client: boto3 ECR client used to issue the create call.
        repository_name (str): name of the repository to create.

    An "already exists" response from ECR is treated as success; any other
    error raised by the client propagates to the caller.
    """
    try:
        ecr_client.create_repository(repositoryName=repository_name)
    except ecr_client.exceptions.RepositoryAlreadyExistsException:
        print("ECR repository already exists: %s" % repository_name)
    else:
        print("Created new ECR repository: %s" % repository_name)
110+
111+
112+
def _ecr_login(ecr_client, aws_account):
    """Log the local docker client in to an ECR registry.

    Fetches an authorization token for the registry of ``aws_account`` and
    runs ``docker login`` against the proxy endpoint returned with it.

    Args:
        ecr_client: boto3 ECR client used to request the token.
        aws_account (str): account id whose registry to log in to.

    Raises:
        RuntimeError: propagated from _execute() when ``docker login`` fails.
    """
    auth = ecr_client.get_authorization_token(registryIds=[aws_account])
    authorization_data = auth['authorizationData'][0]

    raw_token = base64.b64decode(authorization_data['authorizationToken'])
    decoded = raw_token.decode('utf-8')

    # The decoded token has the form "AWS:<password>". Remove exactly that
    # leading prefix: str.strip('AWS:') would treat the argument as a set of
    # characters and also eat any 'A', 'W', 'S' or ':' at either end of the
    # password itself, corrupting it.
    user_prefix = 'AWS:'
    token = decoded[len(user_prefix):] if decoded.startswith(user_prefix) else decoded
    ecr_url = authorization_data['proxyEndpoint']

    cmd = ['docker', 'login', '-u', 'AWS', '-p', token, ecr_url]
    # quiet=True keeps the password out of the echoed command line.
    _execute(cmd, quiet=True)
    print("Logged into ECR")
123+
124+
125+
def _ecr_login_if_needed(image):
    """Log in to ECR when ``image`` is an ECR image that is not available locally.

    Nothing happens (and no AWS client is created) when the image is not
    hosted in ECR or when docker already has it.

    Args:
        image (str): image reference, e.g. the base image of a Dockerfile.

    Raises:
        Exception: propagated from _check_output() when ``docker images`` fails.
        RuntimeError: propagated from _ecr_login() when ``docker login`` fails.
    """
    # Only ECR images need login
    if not ('dkr.ecr' in image and 'amazonaws.com' in image):
        return

    # do we have the image?
    if _check_output('docker images -q %s' % image).strip():
        return

    aws_account = image.split('.')[0]

    # An ECR host looks like <account>.dkr.ecr.<region>.amazonaws.com. Ask for
    # the token in the registry's own region; a token from the session's
    # default region would log docker in to a different registry endpoint.
    host_parts = image.split('/')[0].split('.')
    region = None
    if len(host_parts) > 3 and host_parts[1:3] == ['dkr', 'ecr']:
        region = host_parts[3]

    # Created only now, so non-ECR or already-present images never require
    # AWS configuration. region_name=None falls back to the default region.
    ecr_client = boto3.client('ecr', region_name=region)
    _ecr_login(ecr_client, aws_account)
138+
139+
140+
@contextlib.contextmanager
def _tmpdir(suffix='', prefix='tmp', dir=None):
    """Create a temporary directory with a context manager.

    The directory and everything in it is deleted when the context exits,
    including when the body of the ``with`` statement raises.

    The prefix, suffix, and dir arguments are passed to tempfile.mkdtemp().

    Args:
        suffix (str): If suffix is specified, the directory name will end with that suffix, otherwise there will be no
                      suffix.
        prefix (str): If prefix is specified, the directory name will begin with that prefix; otherwise,
                      a default prefix is used.
        dir (str):  If dir is specified, the directory will be created inside it; otherwise, a default directory is
                    used.

    Yields:
        str: path to the directory
    """
    tmp = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=dir)
    try:
        yield tmp
    finally:
        # Without the finally, an exception in the with-body would skip the
        # cleanup and leak the directory.
        shutil.rmtree(tmp)
159+
160+
161+
def _execute(command, quiet=False):
    """Run a command, streaming its combined stdout/stderr to sys.stdout.

    Args:
        command (list): program and arguments, passed to subprocess.Popen.
        quiet (bool): when True the command line is not echoed before running.

    Raises:
        RuntimeError: when _stream_output() reports a failure; the message
            carries the command and the original error text.
    """
    if not quiet:
        print("$ %s" % ' '.join(command))

    popen_options = {'stdout': subprocess.PIPE, 'stderr': subprocess.STDOUT}
    child = subprocess.Popen(command, **popen_options)

    try:
        _stream_output(child)
    except RuntimeError as error:
        # _stream_output() only knows the exit code, so re-raise with the
        # command line attached.
        # NOTE(review): the command is included even when quiet=True, so
        # secrets passed on the command line end up in the error message.
        raise RuntimeError("Failed to run: %s, %s" % (command, str(error)))
174+
175+
176+
def _stream_output(process):
    """Stream the output of a process to stdout

    Reads the process's stdout pipe line by line until end-of-file, writing
    each line to sys.stdout, then waits for the process to finish. Only
    stdout is read, so start the process with stderr=STDOUT to see errors.

    Args:
        process(subprocess.Popen): a process that has been started with
            stdout=PIPE and stderr=STDOUT

    Returns (int): process exit code (always 0, since any other code raises)

    Raises:
        RuntimeError: if the process exits with a non-zero code.
    """
    # Read to EOF rather than stopping as soon as poll() reports an exit:
    # output still buffered in the pipe when the process ends would otherwise
    # be dropped, and polling at EOF while the process is alive would spin.
    for raw_line in iter(process.stdout.readline, b''):
        sys.stdout.write(raw_line.decode("utf-8"))

    exit_code = process.wait()

    if exit_code != 0:
        raise RuntimeError("Process exited with code: %s" % exit_code)

    return exit_code
197+
198+
199+
def _check_output(cmd, *popenargs, **kwargs):
    """Run a command and return its decoded output.

    Args:
        cmd (str or list): command to run; a string is split with shlex.split().
        *popenargs, **kwargs: forwarded to subprocess.check_output().

    Returns:
        str: the command's output decoded as UTF-8.

    Raises:
        Exception: if the command exits non-zero; the captured output is
            printed first.
    """
    argv = shlex.split(cmd) if isinstance(cmd, str) else cmd

    failure = None
    try:
        raw_output = subprocess.check_output(argv, *popenargs, **kwargs)
    except subprocess.CalledProcessError as error:
        # Keep a reference: the "as" name is cleared when the handler ends.
        failure = error
        raw_output = error.output

    text = raw_output.decode("utf-8")
    if failure is not None:
        print("Command output: %s" % text)
        raise Exception("Failed to run %s" % ",".join(argv))

    return text

0 commit comments

Comments
 (0)