Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions bin/cyhy-import
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ Options:
-h --help Show this screen.
--version Show version.

-A --allow-restricted-ips Allow restricted IPs to be added to the database.
[default: False]
-i STAGE --init-stage STAGE Override the init-stage specified in file
-f --force Force import of existing request, destroying original
-s SECTION --section=SECTION Configuration section to use.
Expand Down Expand Up @@ -75,7 +77,7 @@ def has_intersections(db, nets, filename, owner):
return True


def has_restricted_ips(nets):
def has_restricted_ips(nets, allow_restricted=False):
geo_loc_db = GeoLocDB()
restricted_dict = defaultdict(lambda: IPSet())
for ip in nets:
Expand All @@ -92,12 +94,15 @@ def has_restricted_ips(nets):
print "%s:" % country
for cidr in cidrs.iter_cidrs():
print " %s" % cidr
print "Cannot continue!\nSome addresses associated with restricted countries."
return True
if allow_restricted:
print "WARNING: Restricted IPs will be added to the database."
else:
print "Cannot continue!\nSome addresses associated with restricted countries."
return True
return False


def import_request(db, request, source, force=False, init_stage=None):
def import_request(db, request, source, force=False, init_stage=None, allow_restricted=False):
owner = request["_id"]
# Check if owner already exists
db_request = db.RequestDoc.get_by_owner(owner)
Expand All @@ -116,7 +121,7 @@ def import_request(db, request, source, force=False, init_stage=None):
request["init_stage"] = init_stage
if has_intersections(db, nets, source, owner):
return False
if has_restricted_ips(nets):
if has_restricted_ips(nets, allow_restricted):
# This is not combined with the previous check so the output
# from each function can be seen so the user can get all messages
# without multiple runs
Expand Down Expand Up @@ -146,7 +151,7 @@ def import_request(db, request, source, force=False, init_stage=None):
return True


def import_file(db, filename, force, init_stage=None):
def import_file(db, filename, force, init_stage=None, allow_restricted=False):
try:
with open(filename, "r") as f:
# For py3, encoding should move from json.load into open statement
Expand All @@ -155,28 +160,28 @@ def import_file(db, filename, force, init_stage=None):
print ("Document contains a non-ASCII character: {}".format(e))
return False

return import_request(db, request, filename, force, init_stage)
return import_request(db, request, filename, force, init_stage, allow_restricted)


def import_stdin(db, force, init_stage=None):
def import_stdin(db, force, init_stage=None, allow_restricted=False):
try:
# This would need to be changed for py3
request = json.load(sys.stdin, encoding="ascii")
except UnicodeDecodeError as e:
print ("Document contains a non-ASCII character: {}".format(e))
return False

return import_request(db, request, "from stdin", force, init_stage)
return import_request(db, request, "from stdin", force, init_stage, allow_restricted)


def main():
args = docopt(__doc__, version="v0.0.2")
db = database.db_from_config(args["--section"])

if args["FILE"] != None:
success = import_file(db, args["FILE"], args["--force"], args["--init-stage"])
success = import_file(db, args["FILE"], args["--force"], args["--init-stage"], args["--allow-restricted-ips"])
else:
success = import_stdin(db, args["--force"], args["--init-stage"])
success = import_stdin(db, args["--force"], args["--init-stage"], args["--allow-restricted-ips"])

if not success:
sys.exit(-1)
Expand Down
13 changes: 9 additions & 4 deletions bin/cyhy-ip
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ Options:
-h --help Show this screen.
--version Show version.

-A --allow-restricted-ips Allow restricted IPs to be added to the database.
[default: False]
-f FILENAME --file=FILENAME Read addresses from a file.
-s SECTION --section=SECTION Configuration section to use.

Expand Down Expand Up @@ -165,7 +167,7 @@ def do_list_all(db):
print_intersections(intersections)


def add(db, owner, cidrs):
def add(db, owner, cidrs, allow_restricted=False):
intersections = db.RequestDoc.get_all_intersections(
cidrs
) # intersections from database
Expand All @@ -185,8 +187,11 @@ def add(db, owner, cidrs):
print "%s:" % country
for cidr in cidrs.iter_cidrs():
print " %s" % cidr
print "Cannot continue!\nSome addresses associated with restricted countries."
sys.exit(-1)
if allow_restricted:
print "WARNING: Restricted IPs will be added to the database."
else:
print "Cannot continue!\nSome addresses associated with restricted countries."
sys.exit(-1)

intersections.update(get_special_intersections(cidrs)) # intersections with RFCs
if intersections:
Expand Down Expand Up @@ -481,7 +486,7 @@ def main():
print "# %d" % len(nets)
print_cidrs(nets, indent=0)
elif args["add"]:
add(db, args["OWNER"], nets)
add(db, args["OWNER"], nets, args["--allow-restricted-ips"])
elif args["remove"]:
wrapped_remove(db, args["OWNER"], nets)
elif args["compare"]:
Expand Down